├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── gt_bach.wav
│   └── gt_counting.wav
├── dataio.py
├── diff_operators.py
├── environment.yml
├── experiment_scripts
│   ├── __init__.py
│   ├── fwi
│   │   └── data_cylinder_5.mat
│   ├── test_audio.py
│   ├── test_conv_neural_process.py
│   ├── test_neural_process.py
│   ├── test_sdf.py
│   ├── train_audio.py
│   ├── train_helmholtz.py
│   ├── train_img.py
│   ├── train_img_inpainting.py
│   ├── train_img_neural_process.py
│   ├── train_inverse_helmholtz.py
│   ├── train_poisson_grad_img.py
│   ├── train_poisson_gradcomp_img.py
│   ├── train_poisson_lapl_img.py
│   ├── train_sdf.py
│   ├── train_video.py
│   └── train_wave_equation.py
├── explore_siren.ipynb
├── loss_functions.py
├── make_figures.py
├── meta_modules.py
├── modules.py
├── sdf_meshing.py
├── torchmeta
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── assets
│   │   │   ├── cifar100
│   │   │   │   ├── cifar-fs
│   │   │   │   │   ├── test.json
│   │   │   │   │   ├── train.json
│   │   │   │   │   └── val.json
│   │   │   │   └── fc100
│   │   │   │       ├── test.json
│   │   │   │       ├── train.json
│   │   │   │       └── val.json
│   │   │   ├── cub
│   │   │   │   ├── test.json
│   │   │   │   ├── train.json
│   │   │   │   └── val.json
│   │   │   ├── doublemnist
│   │   │   │   ├── test.json
│   │   │   │   ├── train.json
│   │   │   │   └── val.json
│   │   │   ├── omniglot
│   │   │   │   ├── test.json
│   │   │   │   ├── train.json
│   │   │   │   └── val.json
│   │   │   ├── tcga
│   │   │   │   ├── cancers.json
│   │   │   │   ├── task_variables.json
│   │   │   │   ├── test.json
│   │   │   │   ├── train.json
│   │   │   │   └── val.json
│   │   │   └── triplemnist
│   │   │       ├── test.json
│   │   │       ├── train.json
│   │   │       └── val.json
│   │   ├── cifar100
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── cifar_fs.py
│   │   │   └── fc100.py
│   │   ├── cub.py
│   │   ├── doublemnist.py
│   │   ├── helpers.py
│   │   ├── miniimagenet.py
│   │   ├── omniglot.py
│   │   ├── tcga.py
│   │   ├── tieredimagenet.py
│   │   ├── triplemnist.py
│   │   └── utils.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── batchnorm.py
│   │   ├── container.py
│   │   ├── conv.py
│   │   ├── linear.py
│   │   ├── module.py
│   │   ├── normalization.py
│   │   └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_dataloaders.py
│   │   ├── test_prototype.py
│   │   ├── test_splitters.py
│   │   └── test_toy.py
│   ├── toy
│   │   ├── __init__.py
│   │   ├── harmonic.py
│   │   ├── helpers.py
│   │   ├── sinusoid.py
│   │   └── sinusoid_line.py
│   ├── transforms
│   │   ├── __init__.py
│   │   ├── augmentations.py
│   │   ├── categorical.py
│   │   ├── splitters.py
│   │   ├── target_transforms.py
│   │   └── utils.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── dataloader.py
│   │   │   ├── dataset.py
│   │   │   ├── sampler.py
│   │   │   └── task.py
│   │   ├── metrics.py
│   │   └── prototype.py
│   └── version.py
├── training.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Vincent Sitzmann
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Implicit Neural Representations with Periodic Activation Functions
2 | ### [Project Page](https://vsitzmann.github.io/siren) | [Paper](https://arxiv.org/abs/2006.09661) | [Data](https://drive.google.com/drive/folders/1_iq__37-hw7FJOEUK1tX7mdp8SKB368K?usp=sharing)
3 | [](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb)
4 |
5 | [Vincent Sitzmann](https://vsitzmann.github.io/)\*,
6 | [Julien N. P. Martel](http://www.jmartel.net)\*,
7 | [Alexander W. Bergman](http://alexanderbergman7.github.io),
8 | [David B. Lindell](http://www.davidlindell.com/),
9 | [Gordon Wetzstein](https://stanford.edu/~gordonwz/)
10 | Stanford University, \*denotes equal contribution
11 |
12 | This is the official implementation of the paper "Implicit Neural Representations with Periodic Activation Functions".
13 |
14 | [](https://www.youtube.com/watch?v=Q2fLWGBeaiI)
15 |
16 |
17 | ## Google Colab
18 | If you want to experiment with Siren, we have written a [Colab](https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb).
19 | It's quite comprehensive and comes with a no-frills, drop-in implementation of SIREN. It doesn't require
20 | installing anything, and goes through the following experiments / SIREN properties:
21 | * Fitting an image
22 | * Fitting an audio signal
23 | * Solving Poisson's equation
24 | * Initialization scheme & distribution of activations
25 | * Distribution of activations is shift-invariant
26 | * Periodicity & behavior outside of the training range.
27 |
28 | ## Tensorflow Playground
29 | You can also play around with a tiny SIREN interactively, directly in the browser, via the Tensorflow Playground [here](https://dcato98.github.io/playground/#activation=sine). Thanks to [David Cato](https://github.com/dcato98) for implementing this!
30 |
31 | ## Get started
32 | If you want to reproduce all the results (including the baselines) shown in the paper, you can find the videos, point clouds, and
33 | audio files [here](https://drive.google.com/drive/folders/1_iq__37-hw7FJOEUK1tX7mdp8SKB368K?usp=sharing).
34 |
35 | You can then set up a conda environment with all dependencies like so:
36 | ```
37 | conda env create -f environment.yml
38 | conda activate siren
39 | ```
40 |
41 | ## High-Level structure
42 | The code is organized as follows:
43 | * dataio.py loads training and testing data.
44 | * training.py contains a generic training routine.
45 | * modules.py contains layers and full neural network modules.
46 | * meta_modules.py contains hypernetwork code.
47 | * utils.py contains utility functions, most prominently related to the writing of Tensorboard summaries.
48 | * diff_operators.py contains implementations of differential operators.
49 | * loss_functions.py contains loss functions for the different experiments.
50 | * make_figures.py contains helper functions to create the convergence videos shown in the video.
51 | * ./experiment_scripts/ contains scripts to reproduce experiments in the paper.
52 |
53 | ## Reproducing experiments
54 | The directory `experiment_scripts` contains one script per experiment in the paper.
55 |
56 | To monitor progress, the training code writes tensorboard summaries into a "summaries" subdirectory in the logging_root.
57 |
58 | ### Image experiments
59 | The image experiment can be reproduced with
60 | ```
61 | python experiment_scripts/train_img.py --model_type=sine
62 | ```
63 | The figures in the paper were made by extracting images from the tensorboard summaries. Example code showing how to do this can
64 | be found in the make_figures.py script.
65 |
66 | ### Audio experiments
67 | This GitHub repository comes with both the "counting" and "bach" audio clips under ./data.
68 |
69 | They can be trained with
70 | ```
71 | python experiment_scripts/train_audio.py --model_type=sine --wav_path=
72 | ```
73 |
74 | ### Video experiments
75 | The "bikes" video sequence comes with scikit-video and need not be downloaded. The cat video can be downloaded with the
76 | link above.
77 |
78 | To fit a model to a video, run
79 | ```
80 | python experiment_scripts/train_video.py --model_type=sine --experiment_name bikes_video
81 | ```
82 |
83 | ### Poisson experiments
84 | For the Poisson experiments, there are three separate scripts: one for reconstructing an image from its gradients
85 | (train_poisson_grad_img.py), one from its Laplacian (train_poisson_lapl_img.py), and one to combine two images
86 | (train_poisson_gradcomp_img.py).
87 |
88 | Some of the experiments were run using the BSDS500 dataset, which you can download [here](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/).
89 |
90 | ### SDF Experiments
91 | To fit a Signed Distance Function (SDF) with SIREN, you first need a pointcloud in .xyz format that includes surface normals.
92 | If you only have a mesh / ply file, this can be accomplished with the open-source tool Meshlab.
93 |
94 | To reproduce our results, we provide both the Thai Statue from the Stanford 3D model repository and the living room scene used in our paper
95 | for download here.
96 |
97 | To start training a SIREN, run:
98 | ```
99 | python experiment_scripts/train_sdf.py --model_type=sine --point_cloud_path= --batch_size=250000 --experiment_name=experiment_1
100 | ```
101 | This will regularly save checkpoints in the directory specified by the rootpath in the script, in a subdirectory "experiment_1".
102 | The batch_size is typically chosen as large as your GPU memory allows.
103 | Our experiments show that with a SIREN of 3 hidden layers with 256 units each, one can set the batch size between 230,000 and 250,000 on an NVIDIA GPU with 12GB of memory.
104 |
105 | To inspect an SDF fitted to a 3D point cloud, we next need to create a mesh from the zero-level set of the SDF.
106 | This is performed with another script that uses a marching cubes algorithm (adapted from the DeepSDF github repo)
107 | and creates the mesh saved in a .ply file format. It can be called with:
108 | ```
109 | python experiment_scripts/test_sdf.py --checkpoint_path= --experiment_name=experiment_1_rec
110 | ```
111 | This will save the reconstructed mesh as a .ply file in "experiment_1_rec" (be patient, the marching cubes meshing step takes some time ;) ).
112 | In the event the machine you use for the reconstruction does not have enough RAM, running the test_sdf script will likely freeze. If this is the case,
113 | please use the option --resolution=512 in the command line above (it is set to 1600 by default), which will reconstruct the mesh at a lower spatial resolution.
114 |
115 | The .ply file can be visualized using a software such as [Meshlab](https://www.meshlab.net/#download) (a cross-platform visualizer and editor for 3D models).
116 |
117 | ### Helmholtz and wave equation experiments
118 | The Helmholtz and wave equation experiments can be reproduced with the train_wave_equation.py and train_helmholtz.py scripts.
119 |
120 | ## Torchmeta
121 | We're using the excellent [torchmeta](https://github.com/tristandeleu/pytorch-meta) to implement hypernetworks. We
122 | realized that there is a technical report, which we forgot to cite - it'll make it into the camera-ready version!
123 |
124 | ## Citation
125 | If you find our work useful in your research, please cite:
126 | ```
127 | @inproceedings{sitzmann2019siren,
128 | author = {Sitzmann, Vincent
129 | and Martel, Julien N.P.
130 | and Bergman, Alexander W.
131 | and Lindell, David B.
132 | and Wetzstein, Gordon},
133 | title = {Implicit Neural Representations
134 | with Periodic Activation Functions},
135 | booktitle = {arXiv},
136 | year={2020}
137 | }
138 | ```
139 |
140 | ## Contact
141 | If you have any questions, please feel free to email the authors.
142 |
--------------------------------------------------------------------------------
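To make the "no-frills, drop-in implementation" mentioned in the README concrete, here is a minimal sine layer following the initialization scheme described in the paper (first layer U(-1/n, 1/n), hidden layers U(-sqrt(6/n)/w0, sqrt(6/n)/w0), with w0 = 30). This is an illustrative sketch, not the repo's actual code; the real implementation lives in modules.py and the Colab.

```
import torch
from torch import nn

class SineLayer(nn.Module):
    """Linear layer followed by sin(w0 * x), initialized as in the SIREN paper."""
    def __init__(self, in_features, out_features, is_first=False, w0=30.0):
        super().__init__()
        self.w0 = w0
        self.linear = nn.Linear(in_features, out_features)
        with torch.no_grad():
            if is_first:
                # First layer: U(-1/n, 1/n), where n is the fan-in
                self.linear.weight.uniform_(-1.0 / in_features, 1.0 / in_features)
            else:
                # Hidden layers: U(-sqrt(6/n)/w0, sqrt(6/n)/w0)
                bound = (6.0 / in_features) ** 0.5 / w0
                self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        return torch.sin(self.w0 * self.linear(x))
```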
/data/gt_bach.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/data/gt_bach.wav
--------------------------------------------------------------------------------
/data/gt_counting.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/data/gt_counting.wav
--------------------------------------------------------------------------------
/diff_operators.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import grad
3 |
4 |
5 | def hessian(y, x):
6 | ''' hessian of y wrt x
7 | y: shape (meta_batch_size, num_observations, channels)
8 | x: shape (meta_batch_size, num_observations, 2)
9 | '''
10 | meta_batch_size, num_observations = y.shape[:2]
11 | grad_y = torch.ones_like(y[..., 0]).to(y.device)
12 | h = torch.zeros(meta_batch_size, num_observations, y.shape[-1], x.shape[-1], x.shape[-1]).to(y.device)
13 | for i in range(y.shape[-1]):
14 | # calculate dydx over batches for each feature value of y
15 | dydx = grad(y[..., i], x, grad_y, create_graph=True)[0]
16 |
17 | # calculate hessian on y for each x value
18 | for j in range(x.shape[-1]):
19 | h[..., i, j, :] = grad(dydx[..., j], x, grad_y, create_graph=True)[0][..., :]
20 |
21 | status = 0
22 | if torch.any(torch.isnan(h)):
23 | status = -1
24 | return h, status
25 |
26 |
27 | def laplace(y, x):
28 |     grad_y = gradient(y, x)  # renamed to avoid shadowing the imported `grad`
29 |     return divergence(grad_y, x)
30 |
31 |
32 | def divergence(y, x):
33 | div = 0.
34 | for i in range(y.shape[-1]):
35 | div += grad(y[..., i], x, torch.ones_like(y[..., i]), create_graph=True)[0][..., i:i+1]
36 | return div
37 |
38 |
39 | def gradient(y, x, grad_outputs=None):
40 | if grad_outputs is None:
41 | grad_outputs = torch.ones_like(y)
42 | grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
43 | return grad
44 |
45 |
46 | def jacobian(y, x):
47 | ''' jacobian of y wrt x '''
48 | meta_batch_size, num_observations = y.shape[:2]
49 |     jac = torch.zeros(meta_batch_size, num_observations, y.shape[-1], x.shape[-1]).to(y.device)  # (meta_batch, num_obs, dim_y, dim_x)
50 | for i in range(y.shape[-1]):
51 | # calculate dydx over batches for each feature value of y
52 | y_flat = y[...,i].view(-1, 1)
53 | jac[:, :, i, :] = grad(y_flat, x, torch.ones_like(y_flat), create_graph=True)[0]
54 |
55 | status = 0
56 | if torch.any(torch.isnan(jac)):
57 | status = -1
58 |
59 | return jac, status
60 |
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
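A minimal usage sketch for the operators above (the toy field below is mine, not repo code). The one requirement is that `x` is created with `requires_grad=True` and `y` is computed from it, so autograd can trace the graph:

```
import torch
import diff_operators

# Toy scalar field y = |x|^2 on a (meta_batch, num_observations, 2) coordinate batch.
x = torch.randn(1, 128, 2, requires_grad=True)
y = (x ** 2).sum(dim=-1, keepdim=True)

dy = diff_operators.gradient(y, x)   # shape (1, 128, 2); analytically 2*x
lap = diff_operators.laplace(y, x)   # shape (1, 128, 1); analytically 4 everywhere
```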
/environment.yml:
--------------------------------------------------------------------------------
1 | name: siren
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - absl-py=0.9.0=py36_0
9 | - attrs=19.3.0=py_0
10 | - backcall=0.1.0=py36_0
11 | - blas=1.0=mkl
12 | - bleach=3.1.4=py_0
13 | - blinker=1.4=py36_0
14 | - bzip2=1.0.8=h516909a_2
15 | - c-ares=1.15.0=h7b6447c_1001
16 | - ca-certificates=2020.6.20=hecda079_0
17 | - cachetools=3.1.1=py_0
18 | - cairo=1.14.12=h8948797_3
19 | - certifi=2020.6.20=py36h9f0ad1d_0
20 | - cffi=1.14.0=py36he30daa8_1
21 | - chardet=3.0.4=py36_1003
22 | - click=7.1.2=py_0
23 | - cloudpickle=1.4.1=py_0
24 | - configargparse=1.1=py_0
25 | - cryptography=2.9.2=py36h1ba5d50_0
26 | - cudatoolkit=10.1.243=h6bb024c_0
27 | - cycler=0.10.0=py36_0
28 | - cytoolz=0.10.1=py36h7b6447c_0
29 | - dask-core=2.17.2=py_0
30 | - dbus=1.13.14=hb2f20db_0
31 | - decorator=4.4.2=py_0
32 | - defusedxml=0.6.0=py_0
33 | - entrypoints=0.3=py36_0
34 | - expat=2.2.6=he6710b0_0
35 | - ffmpeg=4.0.2=ha0c5888_2
36 | - fftw=3.3.8=nompi_h7f3a6c3_1110
37 | - fontconfig=2.13.0=h9420a91_0
38 | - freetype=2.9.1=h8a8886c_1
39 | - gettext=0.19.8.1=h5e8e0c9_1
40 | - ghostscript=9.22=hf484d3e_1001
41 | - giflib=5.1.9=h516909a_0
42 | - glib=2.63.1=h3eb4bd4_1
43 | - gmp=6.1.2=h6c8ec71_1
44 | - gnutls=3.5.19=h2a4e5f8_1
45 | - google-auth=1.14.1=py_0
46 | - google-auth-oauthlib=0.4.1=py_2
47 | - graphite2=1.3.13=he1b5a44_1001
48 | - graphviz=2.38.0=hcf1ce16_1009
49 | - grpcio=1.27.2=py36hf8bcb03_0
50 | - gst-plugins-base=1.14.0=hbbd80ab_1
51 | - gstreamer=1.14.0=hb31296c_0
52 | - h5py=2.8.0=py36h989c5e5_3
53 | - harfbuzz=1.8.8=hffaf4a1_0
54 | - hdf5=1.10.2=hba1933b_1
55 | - icu=58.2=he6710b0_3
56 | - idna=2.9=py_1
57 | - imageio=2.8.0=py_0
58 | - imagemagick=7.0.8_11=pl526hc610aec_0
59 | - importlib-metadata=1.6.0=py36_0
60 | - importlib_metadata=1.6.0=0
61 | - intel-openmp=2020.1=217
62 | - ipykernel=5.1.4=py36h39e3cac_0
63 | - ipython=7.13.0=py36h5ca1d4c_0
64 | - ipython_genutils=0.2.0=py36_0
65 | - ipywidgets=7.5.1=py_0
66 | - jbig=2.1=h516909a_2002
67 | - jedi=0.17.0=py36_0
68 | - jinja2=2.11.2=py_0
69 | - joblib=0.15.1=py_0
70 | - jpeg=9d=h516909a_0
71 | - jsonschema=3.2.0=py36_0
72 | - jupyter=1.0.0=py36_7
73 | - jupyter_client=6.1.3=py_0
74 | - jupyter_console=6.1.0=py_0
75 | - jupyter_core=4.6.3=py36_0
76 | - kiwisolver=1.2.0=py36hfd86e86_0
77 | - ld_impl_linux-64=2.33.1=h53a641e_7
78 | - libedit=3.1.20181209=hc058e9b_0
79 | - libffi=3.3=he6710b0_1
80 | - libgcc-ng=9.1.0=hdf63c60_0
81 | - libgfortran-ng=7.3.0=hdf63c60_0
82 | - libiconv=1.15=h516909a_1006
83 | - libpng=1.6.37=hbc83047_0
84 | - libprotobuf=3.11.4=hd408876_0
85 | - libsodium=1.0.16=h1bed415_0
86 | - libstdcxx-ng=9.1.0=hdf63c60_0
87 | - libtiff=4.1.0=h2733197_1
88 | - libtool=2.4.6=h14c3975_1002
89 | - libuuid=1.0.3=h1bed415_2
90 | - libwebp=0.5.2=7
91 | - libxcb=1.13=h1bed415_1
92 | - libxml2=2.9.9=hea5a465_1
93 | - lz4-c=1.9.2=he6710b0_0
94 | - markdown=3.1.1=py36_0
95 | - markupsafe=1.1.1=py36h7b6447c_0
96 | - mistune=0.8.4=py36h7b6447c_0
97 | - mkl=2020.1=217
98 | - mkl-service=2.3.0=py36he904b0f_0
99 | - mkl_fft=1.0.15=py36ha843d7b_0
100 | - mkl_random=1.1.1=py36h0573a6f_0
101 | - nbconvert=5.6.1=py36_0
102 | - nbformat=5.0.6=py_0
103 | - ncurses=6.2=he6710b0_1
104 | - nettle=3.3=0
105 | - networkx=2.4=py_0
106 | - ninja=1.9.0=py36hfd86e86_0
107 | - notebook=6.0.3=py36_0
108 | - numpy=1.18.1=py36h4f9e942_0
109 | - numpy-base=1.18.1=py36hde5b4d6_1
110 | - oauthlib=3.1.0=py_0
111 | - olefile=0.46=py36_0
112 | - openh264=1.8.0=hdbcaa40_1000
113 | - openjpeg=2.3.1=h981e76c_3
114 | - openssl=1.1.1g=h516909a_0
115 | - pandoc=2.2.3.2=0
116 | - pandocfilters=1.4.2=py36_1
117 | - pango=1.40.14=he752989_2
118 | - parso=0.7.0=py_0
119 | - pcre=8.43=he6710b0_0
120 | - perl=5.26.2=h516909a_1006
121 | - pexpect=4.8.0=py36_0
122 | - pickleshare=0.7.5=py36_0
123 | - pillow=6.2.1=py36h34e0f95_0
124 | - pip=20.0.2=py36_3
125 | - pixman=0.38.0=h516909a_1003
126 | - pkg-config=0.29.2=h516909a_1006
127 | - prometheus_client=0.7.1=py_0
128 | - prompt-toolkit=3.0.5=py_0
129 | - prompt_toolkit=3.0.5=0
130 | - protobuf=3.11.4=py36he6710b0_0
131 | - ptyprocess=0.6.0=py36_0
132 | - pyasn1=0.4.8=py_0
133 | - pyasn1-modules=0.2.7=py_0
134 | - pycparser=2.20=py_0
135 | - pygments=2.6.1=py_0
136 | - pyjwt=1.7.1=py36_0
137 | - pyopenssl=19.1.0=py36_0
138 | - pyparsing=2.4.7=py_0
139 | - pyqt=5.9.2=py36h05f1152_2
140 | - pyrsistent=0.16.0=py36h7b6447c_0
141 | - pysocks=1.7.1=py36_0
142 | - python=3.6.10=h7579374_2
143 | - python-dateutil=2.8.1=py_0
144 | - python_abi=3.6=1_cp36m
145 | - pytorch=1.5.0=py3.6_cuda10.1.243_cudnn7.6.3_0
146 | - pywavelets=1.1.1=py36h7b6447c_0
147 | - pyyaml=5.3.1=py36h7b6447c_0
148 | - pyzmq=18.1.1=py36he6710b0_0
149 | - qt=5.9.7=h5867ecd_1
150 | - qtconsole=4.7.4=py_0
151 | - qtpy=1.9.0=py_0
152 | - readline=8.0=h7b6447c_0
153 | - requests=2.23.0=py36_0
154 | - requests-oauthlib=1.3.0=py_0
155 | - rsa=4.0=py_0
156 | - scikit-image=0.16.2=py36h0573a6f_0
157 | - scikit-learn=0.22.1=py36hd81dba3_0
158 | - scikit-video=1.1.11=pyh24bf2e0_0
159 | - scipy=1.4.1=py36h0b6359f_0
160 | - send2trash=1.5.0=py36_0
161 | - setuptools=46.4.0=py36_0
162 | - sip=4.19.8=py36hf484d3e_0
163 | - six=1.14.0=py36_0
164 | - sqlite=3.31.1=h62c20be_1
165 | - tensorboard-plugin-wit=1.6.0=py_0
166 | - terminado=0.8.3=py36_0
167 | - testpath=0.4.4=py_0
168 | - tk=8.6.8=hbc83047_0
169 | - toolz=0.10.0=py_0
170 | - torchvision=0.6.0=py36_cu101
171 | - tornado=6.0.4=py36h7b6447c_1
172 | - tqdm=4.46.0=py_0
173 | - traitlets=4.3.3=py36_0
174 | - urllib3=1.25.8=py36_0
175 | - wcwidth=0.1.9=py_0
176 | - webencodings=0.5.1=py36_1
177 | - werkzeug=1.0.1=py_0
178 | - wheel=0.34.2=py36_0
179 | - widgetsnbextension=3.5.1=py36_0
180 | - x264=1!152.20180717=h14c3975_1001
181 | - xorg-kbproto=1.0.7=h14c3975_1002
182 | - xorg-libice=1.0.10=h516909a_0
183 | - xorg-libsm=1.2.2=h470a237_5
184 | - xorg-libx11=1.6.9=h516909a_0
185 | - xorg-libxext=1.3.4=h516909a_0
186 | - xorg-libxpm=3.5.13=h516909a_0
187 | - xorg-libxrender=0.9.10=h516909a_1002
188 | - xorg-libxt=1.1.5=h516909a_1003
189 | - xorg-renderproto=0.11.1=h14c3975_1002
190 | - xorg-xextproto=7.3.0=h14c3975_1002
191 | - xorg-xproto=7.0.31=h14c3975_1007
192 | - xz=5.2.5=h7b6447c_0
193 | - yaml=0.1.7=had09818_2
194 | - zeromq=4.3.1=he6710b0_3
195 | - zipp=3.1.0=py_0
196 | - zlib=1.2.11=h7b6447c_3
197 | - zstd=1.4.4=h0b5b093_3
198 | - pip:
199 | - astor==0.8.1
200 | - cmapy==0.6.6
201 | - gast==0.2.2
202 | - google-pasta==0.2.0
203 | - imageio-ffmpeg==0.4.2
204 | - imbalanced-learn==0.6.2
205 | - keras-applications==1.0.8
206 | - keras-preprocessing==1.1.2
207 | - matplotlib==2.2.5
208 | - moviepy==1.0.3
209 | - opencv-python==3.4.9.33
210 | - opt-einsum==3.2.1
211 | - proglog==0.1.9
212 | - pytz==2020.1
213 | - tensorboard==1.15.0
214 | - tensorflow==1.15.0
215 | - tensorflow-estimator==1.15.1
216 | - termcolor==1.1.0
217 | - wrapt==1.12.1
218 | - youtube-dl==2020.6.6
219 |
220 |
--------------------------------------------------------------------------------
/experiment_scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/experiment_scripts/__init__.py
--------------------------------------------------------------------------------
/experiment_scripts/fwi/data_cylinder_5.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vsitzmann/siren/4df34baee3f0f9c8f351630992c1fe1f69114b5f/experiment_scripts/fwi/data_cylinder_5.mat
--------------------------------------------------------------------------------
/experiment_scripts/test_audio.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
5 |
6 | import dataio, utils, modules
7 |
8 | from torch.utils.data import DataLoader
9 | import configargparse
10 | import torch
11 | import scipy.io.wavfile as wavfile
12 |
13 | p = configargparse.ArgumentParser()
14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
15 |
16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
17 | p.add_argument('--experiment_name', type=str, default='audio',
18 | help='Name of subdirectory in logging_root where wav file will be saved.')
19 | p.add_argument('--gt_wav_path', type=str, default='../data/gt_bach.wav', help='ground truth wav path')
20 |
21 | p.add_argument('--model_type', type=str, default='sine',
22 |                help='Options currently are "sine" (all sine activations), "relu" (all relu activations), '
23 |                     '"nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu), '
24 |                     'and in the future: "mixed" (first layer sine, other layers tanh)')
25 | p.add_argument('--checkpoint_path', required=True, help='Checkpoint to trained model.')
26 |
27 | opt = p.parse_args()
28 |
29 | audio_dataset = dataio.AudioFile(filename=opt.gt_wav_path)
30 | coord_dataset = dataio.ImplicitAudioWrapper(audio_dataset)
31 |
32 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)
33 |
34 | # Define the model and load in checkpoint path
35 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh':
36 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', in_features=1)
37 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf':
38 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, fn_samples=len(audio_dataset.data), in_features=1)
39 | else:
40 | raise NotImplementedError
41 | model.load_state_dict(torch.load(opt.checkpoint_path))
42 | model.cuda()
43 |
44 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
45 | utils.cond_mkdir(root_path)
46 |
47 | # Get ground truth and input data
48 | model_input, gt = next(iter(dataloader))
49 | model_input = {key: value.cuda() for key, value in model_input.items()}
50 | gt = {key: value.cuda() for key, value in gt.items()}
51 |
52 | # Evaluate the trained model
53 | with torch.no_grad():
54 | model_output = model(model_input)
55 |
56 | waveform = torch.squeeze(model_output['model_out']).detach().cpu().numpy()
57 | rate = torch.squeeze(gt['rate']).detach().cpu().numpy()
58 | wavfile.write(os.path.join(opt.logging_root, opt.experiment_name, 'pred_waveform.wav'), rate, waveform)
--------------------------------------------------------------------------------
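One portability caveat for the checkpoint loading above: `torch.load` restores tensors to the device they were saved from, so a GPU-trained checkpoint fails to load on a CPU-only machine. A sketch of the standard PyTorch fix (you would also drop the `.cuda()` call):

```
# Load a GPU-trained checkpoint on a CPU-only machine.
state_dict = torch.load(opt.checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
```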
/experiment_scripts/test_conv_neural_process.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
5 |
6 | import dataio, meta_modules, utils, training, loss_functions
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 |
10 | import torch
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 |
14 | import imageio
15 | from functools import partial
16 | import random
17 | from tqdm.autonotebook import tqdm
18 | import time
19 | import utils
20 | from torch.utils.tensorboard import SummaryWriter
21 |
22 | p = configargparse.ArgumentParser()
23 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
24 |
25 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
26 | p.add_argument('--experiment_name', type=str, required=True,
27 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
28 |
29 | # General training options
30 | p.add_argument('--checkpoint_path', default=None, type=str, required=True,
31 |                help='Path to the checkpoint of the trained model to evaluate.')
32 | p.add_argument('--dataset', type=str, default='celeba_32x32',
33 |                help='Dataset to evaluate on; currently only "celeba_32x32" is supported.')
34 | p.add_argument('--model_type', type=str, default='sine',
35 | help='Nonlinearity in the neural implicit representation')
36 | p.add_argument('--test_sparsity', type=float, default=200,
37 | help='Amount of subsampled pixels input into the set encoder')
38 | p.add_argument('--partial_conv', action='store_true', default=False, help='Use a partial convolution encoder')
39 | opt = p.parse_args()
40 |
41 | if opt.experiment_name is None:
42 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_TEST'
43 | else:
44 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_' + opt.experiment_name
45 |
46 | assert opt.dataset == 'celeba_32x32'
47 | img_dataset_test = dataio.CelebA(split='test', downsampled=True)
48 | coord_dataset_test = dataio.Implicit2DWrapper(img_dataset_test, sidelength=(32, 32))
49 | generalization_dataset_test = dataio.ImageGeneralizationWrapper(coord_dataset_test, test_sparsity=200,
50 | generalization_mode='conv_cnp_test')
51 | image_resolution = (32, 32)
52 |
53 | img_dataset_train = dataio.CelebA(split='train', downsampled=True)
54 | coord_dataset_train = dataio.Implicit2DWrapper(img_dataset_train, sidelength=(32, 32))
55 | generalization_dataset_train = dataio.ImageGeneralizationWrapper(coord_dataset_train, test_sparsity=200,
56 | generalization_mode='conv_cnp_test')
57 |
58 | # Define the model.
59 | model = meta_modules.ConvolutionalNeuralProcessImplicit2DHypernet(in_features=img_dataset_test.img_channels,
60 | out_features=img_dataset_test.img_channels,
61 | image_resolution=image_resolution,
62 | partial_conv=opt.partial_conv)
63 | model.cuda()
64 | model.eval()
65 |
66 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
67 | utils.cond_mkdir(root_path)
68 |
69 |
70 | # Load checkpoint
71 | model.load_state_dict(torch.load(opt.checkpoint_path))
72 |
73 | # First experiment: Upsample training image
74 | model_input = {'coords':dataio.get_mgrid(image_resolution)[None,:].cuda(),
75 | 'img_sparse':generalization_dataset_train[0][0]['img_sparse'].unsqueeze(0).cuda()}
76 | model_output = model(model_input)
77 |
78 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1,2,0).detach().cpu().numpy()
79 | out_img += 1
80 | out_img /= 2.
81 | out_img = np.clip(out_img, 0., 1.)
82 |
83 | imageio.imwrite(os.path.join(root_path, 'upsampled_train.png'), out_img)
84 |
85 | # Second experiment: sample larger range
86 | model_input = {'coords':dataio.get_mgrid(image_resolution)[None,:].cuda()*5,
87 | 'img_sparse':generalization_dataset_train[0][0]['img_sparse'].unsqueeze(0).cuda()}
88 | model_output = model(model_input)
89 |
90 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1,2,0).detach().cpu().numpy()
91 | out_img += 1
92 | out_img /= 2.
93 | out_img = np.clip(out_img, 0., 1.)
94 |
95 | imageio.imwrite(os.path.join(root_path, 'outside_range.png'), out_img)
96 |
97 | # Third experiment: interpolate between latent codes
98 | idx1, idx2 = 57, 181
99 | model_input_1 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(),
100 | 'img_sparse': generalization_dataset_train[idx1][0]['img_sparse'].unsqueeze(0).cuda()}
101 | model_input_2 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(),
102 | 'img_sparse': generalization_dataset_train[idx2][0]['img_sparse'].unsqueeze(0).cuda()}
103 |
104 | embedding_1 = model.get_hypo_net_weights(model_input_1)[1]
105 | embedding_2 = model.get_hypo_net_weights(model_input_2)[1]
106 | for i in np.linspace(0,1,8):
107 | embedding = i*embedding_1 + (1.-i)*embedding_2
108 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 'embedding': embedding}
109 | model_output = model(model_input)
110 |
111 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1,2,0).detach().cpu().numpy()
112 | out_img += 1
113 | out_img /= 2.
114 | out_img = np.clip(out_img, 0., 1.)
115 |
116 | if i == 0.:
117 | out_img_cat = out_img
118 | else:
119 | out_img_cat = np.concatenate((out_img_cat, out_img), axis=1)
120 |
121 | imageio.imwrite(os.path.join(root_path, 'interpolated_image.png'), out_img_cat)
122 |
123 | # Fourth experiment: Fit test images
124 | def to_uint8(img):
125 | img = img * 255
126 | img = img.astype(np.uint8)
127 | return img
128 |
129 | def getTestMSE(dataloader, subdir):
130 | MSEs = []
131 | total_steps = 0
132 | utils.cond_mkdir(os.path.join(root_path, subdir))
133 | utils.cond_mkdir(os.path.join(root_path, 'ground_truth'))
134 |
135 | with tqdm(total=len(dataloader)) as pbar:
136 | for step, (model_input, gt) in enumerate(dataloader):
137 | model_input['idx'] = torch.Tensor([model_input['idx']]).long()
138 | model_input = {key: value.cuda() for key, value in model_input.items()}
139 | gt = {key: value.cuda() for key, value in gt.items()}
140 |
141 | with torch.no_grad():
142 | model_output = model(model_input)
143 |
144 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy()
145 | out_img += 1
146 | out_img /= 2.
147 | out_img = np.clip(out_img, 0., 1.)
148 | gt_img = dataio.lin2img(gt['img'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy()
149 | gt_img += 1
150 | gt_img /= 2.
151 | gt_img = np.clip(gt_img, 0., 1.)
152 |
153 | sparse_img = model_input['img_sparse'].squeeze().detach().cpu().permute(1,2,0).numpy()
154 | mask = np.sum((sparse_img == 0), axis=2) == 3
155 | sparse_img += 1
156 | sparse_img /= 2.
157 | sparse_img = np.clip(sparse_img, 0., 1.)
158 | sparse_img[mask, ...] = 1.
159 |
160 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps)+'_sparse.png'), to_uint8(sparse_img))
161 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps)+'.png'), to_uint8(out_img))
162 | imageio.imwrite(os.path.join(root_path, 'ground_truth', str(total_steps)+'.png'), to_uint8(gt_img))
163 |
164 | MSE = np.mean((out_img - gt_img) ** 2)
165 | MSEs.append(MSE)
166 |
167 | pbar.update(1)
168 | total_steps += 1
169 |
170 | return MSEs
171 |
172 | sparsities = [10, 100, 1000, 'full', 'half']
173 | for sparsity in sparsities:
174 | generalization_dataset_test.update_test_sparsity(sparsity)
175 | dataloader = DataLoader(generalization_dataset_test, shuffle=False, batch_size=1, pin_memory=True, num_workers=0)
176 | MSE = getTestMSE(dataloader, 'test_'+str(sparsity)+'_pixels')
177 | np.save(os.path.join(root_path, 'MSE_'+str(sparsity)+'_context.npy'), MSE)
178 | print(np.mean(MSE))
179 |
--------------------------------------------------------------------------------
/experiment_scripts/test_neural_process.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import os
3 | import sys
4 |
5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6 |
7 | import dataio, meta_modules
8 | import numpy as np
9 |
10 | import torch
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 |
14 | import imageio
15 | from tqdm.autonotebook import tqdm
16 | import utils
17 | import skimage
18 |
19 | p = configargparse.ArgumentParser()
20 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
21 |
22 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
23 | p.add_argument('--experiment_name', type=str, required=True,
24 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
25 |
26 | # General training options
27 | p.add_argument('--checkpoint_path', default=None, type=str, required=True,
28 |                help='Path to the checkpoint of the trained model to evaluate.')
29 | p.add_argument('--dataset', type=str, default='celeba_32x32',
30 |                help='Dataset to evaluate on; currently only "celeba_32x32" is supported.')
31 | p.add_argument('--model_type', type=str, default='sine',
32 | help='Nonlinearity in the neural implicit representation')
33 | p.add_argument('--test_sparsity', type=float, default=200,
34 | help='Amount of subsampled pixels input into the set encoder')
35 | opt = p.parse_args()
36 |
37 | if opt.experiment_name is None:
38 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_TEST'
39 | else:
40 | opt.experiment_name = opt.checkpoint_path.split('/')[-3] + '_' + opt.experiment_name
41 |
42 | assert opt.dataset == 'celeba_32x32'
43 | img_dataset_test = dataio.CelebA(split='test', downsampled=True)
44 | coord_dataset_test = dataio.Implicit2DWrapper(img_dataset_test, sidelength=(32, 32))
45 | generalization_dataset_test = dataio.ImageGeneralizationWrapper(coord_dataset_test, test_sparsity=10,
46 | generalization_mode='cnp_test')
47 | image_resolution = (32, 32)
48 |
49 | img_dataset_train = dataio.CelebA(split='train', downsampled=True)
50 | coord_dataset_train = dataio.Implicit2DWrapper(img_dataset_train, sidelength=(32, 32))
51 | generalization_dataset_train = dataio.ImageGeneralizationWrapper(coord_dataset_train, test_sparsity=10,
52 | generalization_mode='cnp_test')
53 |
54 | # Define the model.
55 | model = meta_modules.NeuralProcessImplicit2DHypernet(in_features=img_dataset_test.img_channels + 2,
56 | out_features=img_dataset_test.img_channels,
57 | image_resolution=image_resolution)
58 | model.cuda()
59 | model.eval()
60 |
61 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
62 | utils.cond_mkdir(root_path)
63 |
64 | # Load checkpoint
65 | model.load_state_dict(torch.load(opt.checkpoint_path))
66 |
67 | # First experiment: Upsample training image
68 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(),
69 | 'img_sub': generalization_dataset_train[0][0]['img_sub'].unsqueeze(0).cuda(),
70 | 'coords_sub': generalization_dataset_train[0][0]['coords_sub'].unsqueeze(0).cuda()}
71 | model_output = model(model_input)
72 |
73 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy()
74 | out_img += 1
75 | out_img /= 2.
76 | out_img = np.clip(out_img, 0., 1.)
77 |
78 | imageio.imwrite(os.path.join(root_path, 'upsampled_train.png'), out_img)
79 |
80 | # Second experiment: sample larger range
81 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda() * 5,
82 | 'img_sub': generalization_dataset_train[0][0]['img_sub'].unsqueeze(0).cuda(),
83 | 'coords_sub': generalization_dataset_train[0][0]['coords_sub'].unsqueeze(0).cuda()}
84 | model_output = model(model_input)
85 |
86 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy()
87 | out_img += 1
88 | out_img /= 2.
89 | out_img = np.clip(out_img, 0., 1.)
90 |
91 | imageio.imwrite(os.path.join(root_path, 'outside_range.png'), out_img)
92 |
93 | # Third experiment: interpolate between latent codes
94 | idx1, idx2 = 57, 181
95 | model_input_1 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(),
96 | 'img_sub': generalization_dataset_train[idx1][0]['img_sub'].unsqueeze(0).cuda(),
97 | 'coords_sub': generalization_dataset_train[idx1][0]['coords_sub'].unsqueeze(0).cuda()}
98 | model_input_2 = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(),
99 | 'img_sub': generalization_dataset_train[idx2][0]['img_sub'].unsqueeze(0).cuda(),
100 | 'coords_sub': generalization_dataset_train[idx2][0]['coords_sub'].unsqueeze(0).cuda()}
101 |
102 | embedding_1 = model.get_hypo_net_weights(model_input_1)[1]
103 | embedding_2 = model.get_hypo_net_weights(model_input_2)[1]
104 | for i in np.linspace(0, 1, 8):
105 | embedding = i * embedding_1 + (1. - i) * embedding_2
106 | model_input = {'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 'embedding': embedding}
107 | model_output = model(model_input)
108 |
109 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2,
110 | 0).detach().cpu().numpy()
111 | out_img += 1
112 | out_img /= 2.
113 | out_img = np.clip(out_img, 0., 1.)
114 |
115 | if i == 0.:
116 | out_img_cat = out_img
117 | else:
118 | out_img_cat = np.concatenate((out_img_cat, out_img), axis=1)
119 |
120 | imageio.imwrite(os.path.join(root_path, 'interpolated_image.png'), out_img_cat)
121 |
122 |
123 | # Fourth experiment: Fit test images
124 | def to_uint8(img):
125 | img = img * 255
126 | img = img.astype(np.uint8)
127 | return img
128 |
129 |
130 | def getTestMSE(dataloader, subdir):
131 | MSEs = []
132 | PSNRs = []
133 | total_steps = 0
134 | utils.cond_mkdir(os.path.join(root_path, subdir))
135 | utils.cond_mkdir(os.path.join(root_path, 'ground_truth'))
136 |
137 | with tqdm(total=len(dataloader)) as pbar:
138 | for step, (model_input, gt) in enumerate(dataloader):
139 | model_input['idx'] = torch.Tensor([model_input['idx']]).long()
140 | model_input = {key: value.cuda() for key, value in model_input.items()}
141 | gt = {key: value.cuda() for key, value in gt.items()}
142 |
143 | with torch.no_grad():
144 | model_output = model(model_input)
145 |
146 | out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute(1, 2,
147 | 0).detach().cpu().numpy()
148 | out_img += 1
149 | out_img /= 2.
150 | out_img = np.clip(out_img, 0., 1.)
151 | gt_img = dataio.lin2img(gt['img'], image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy()
152 | gt_img += 1
153 | gt_img /= 2.
154 | gt_img = np.clip(gt_img, 0., 1.)
155 |
156 | sparse_img = np.ones((image_resolution[0], image_resolution[1], 3))
157 | coords_sub = model_input['coords_sub'].squeeze().detach().cpu().numpy()
158 | rgb_sub = model_input['img_sub'].squeeze().detach().cpu().numpy()
159 | for index in range(0, coords_sub.shape[0]):
160 | r = int(round((coords_sub[index][0] + 1) / 2 * 31))
161 | c = int(round((coords_sub[index][1] + 1) / 2 * 31))
162 | sparse_img[r, c, :] = np.clip((rgb_sub[index, :] + 1) / 2, 0., 1.)
163 |
164 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps) + '_sparse.png'), to_uint8(sparse_img))
165 | imageio.imwrite(os.path.join(root_path, subdir, str(total_steps) + '.png'), to_uint8(out_img))
166 | imageio.imwrite(os.path.join(root_path, 'ground_truth', str(total_steps) + '.png'), to_uint8(gt_img))
167 |
168 | MSE = np.mean((out_img - gt_img) ** 2)
169 | MSEs.append(MSE)
170 |
171 | PSNR = skimage.measure.compare_psnr(out_img, gt_img, data_range=1)
172 | PSNRs.append(PSNR)
173 |
174 | pbar.update(1)
175 | total_steps += 1
176 |
177 | return MSEs, PSNRs
178 |
179 |
180 | sparsities = [10, 100, 1000, 'full', 'half']
181 | for sparsity in sparsities:
182 | generalization_dataset_test.update_test_sparsity(sparsity)
183 | dataloader = DataLoader(generalization_dataset_test, shuffle=False, batch_size=1, pin_memory=True, num_workers=0)
184 | MSE, PSNR = getTestMSE(dataloader, 'test_' + str(sparsity) + '_pixels')
185 | np.save(os.path.join(root_path, 'MSE_' + str(sparsity) + '_context.npy'), MSE)
186 | np.save(os.path.join(root_path, 'PSNR_' + str(sparsity) + '_context.npy'), PSNR)
187 | print(np.mean(MSE))
188 |
--------------------------------------------------------------------------------
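A compatibility note on the PSNR call above: `skimage.measure.compare_psnr` matches the scikit-image 0.16.2 pinned in environment.yml, but it was removed in scikit-image 0.18. On newer versions the equivalent call is:

```
from skimage.metrics import peak_signal_noise_ratio

# Same computation as skimage.measure.compare_psnr(out_img, gt_img, data_range=1)
PSNR = peak_signal_noise_ratio(out_img, gt_img, data_range=1)
```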
/experiment_scripts/test_sdf.py:
--------------------------------------------------------------------------------
1 | '''Test script for the SDF experiments in paper Sec. 4.2: reconstructs a mesh from a trained SDF via marching cubes.
2 | '''
3 |
4 | # Enable import from parent package
5 | import os
6 | import sys
7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 |
9 | import torch
10 | import modules, utils
11 | import sdf_meshing
12 | import configargparse
13 |
14 | p = configargparse.ArgumentParser()
15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
16 |
17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
18 | p.add_argument('--experiment_name', type=str, required=True,
19 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
20 |
21 | # General training options
22 | p.add_argument('--batch_size', type=int, default=16384)
23 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
24 |
25 | p.add_argument('--model_type', type=str, default='sine',
26 | help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)')
27 | p.add_argument('--mode', type=str, default='mlp',
28 | help='Options are "mlp" or "nerf"')
29 | p.add_argument('--resolution', type=int, default=1600)
30 |
31 | opt = p.parse_args()
32 |
33 |
34 | class SDFDecoder(torch.nn.Module):
35 | def __init__(self):
36 | super().__init__()
37 | # Define the model.
38 | if opt.mode == 'mlp':
39 | self.model = modules.SingleBVPNet(type=opt.model_type, final_layer_factor=1, in_features=3)
40 | elif opt.mode == 'nerf':
41 | self.model = modules.SingleBVPNet(type='relu', mode='nerf', final_layer_factor=1, in_features=3)
42 | self.model.load_state_dict(torch.load(opt.checkpoint_path))
43 | self.model.cuda()
44 |
45 | def forward(self, coords):
46 | model_in = {'coords': coords}
47 | return self.model(model_in)['model_out']
48 |
49 |
50 | sdf_decoder = SDFDecoder()
51 |
52 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
53 | utils.cond_mkdir(root_path)
54 |
55 | sdf_meshing.create_mesh(sdf_decoder, os.path.join(root_path, 'test'), N=opt.resolution)
56 |
--------------------------------------------------------------------------------
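For intuition, here is a condensed stand-in for what `sdf_meshing.create_mesh` does with the `SDFDecoder` above: sample the SDF on a dense grid and extract the zero-level set with marching cubes. This sketch assumes scikit-image >= 0.17 and queries the whole grid in one pass; the real script batches the GPU queries and writes a .ply file.

```
import torch
from skimage import measure

def sdf_to_mesh(decoder, N=128):
    # Dense grid over [-1, 1]^3, queried as a (1, N^3, 3) coordinate batch.
    xs = torch.linspace(-1.0, 1.0, N)
    coords = torch.cartesian_prod(xs, xs, xs).unsqueeze(0).cuda()
    with torch.no_grad():
        sdf = decoder(coords).reshape(N, N, N).cpu().numpy()
    # Vertices/faces of the zero-level set; spacing maps voxels back to world units.
    verts, faces, normals, values = measure.marching_cubes(sdf, level=0.0,
                                                           spacing=(2.0 / (N - 1),) * 3)
    return verts - 1.0, faces

# verts, faces = sdf_to_mesh(sdf_decoder, N=opt.resolution)
```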
/experiment_scripts/train_audio.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
5 |
6 | import dataio, meta_modules, utils, training, loss_functions, modules
7 |
8 | from torch.utils.data import DataLoader
9 | import configargparse
10 | from functools import partial
11 |
12 | p = configargparse.ArgumentParser()
13 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
14 |
15 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
16 | p.add_argument('--experiment_name', type=str, default='audio',
17 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
18 |
19 | p.add_argument('--wav_path', type=str, default='../data/gt_bach.wav', help='Path to the .wav file to fit.')
20 |
21 | # General training options
22 | p.add_argument('--batch_size', type=int, default=1)
23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
24 | p.add_argument('--num_epochs', type=int, default=9001,
25 | help='Number of epochs to train for.')
26 |
27 | p.add_argument('--epochs_til_ckpt', type=int, default=1000,
28 |                help='Number of epochs between checkpoints.')
29 | p.add_argument('--steps_til_summary', type=int, default=1000,
30 |                help='Number of training steps between tensorboard summaries.')
31 |
32 | p.add_argument('--model_type', type=str, default='sine',
33 |                help='Options currently are "sine" (all sine activations), "relu" (all relu activations), '
34 |                     '"nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu), '
35 |                     'and in the future: "mixed" (first layer sine, other layers tanh)')
36 |
37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
38 | opt = p.parse_args()
39 |
40 | audio_dataset = dataio.AudioFile(filename=opt.wav_path)
41 | coord_dataset = dataio.ImplicitAudioWrapper(audio_dataset)
42 |
43 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
44 |
45 | # Define the model.
46 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh':
47 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', in_features=1)
48 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf':
49 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, fn_samples=len(audio_dataset.data), in_features=1)
50 | else:
51 | raise NotImplementedError
52 | model.cuda()
53 |
54 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
55 | utils.cond_mkdir(root_path)
56 |
57 | # Define the loss
58 | loss_fn = loss_functions.function_mse
59 | summary_fn = partial(utils.write_audio_summary, root_path)
60 |
61 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
62 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
63 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn)
--------------------------------------------------------------------------------
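At its core, the run above is coordinate regression with a mean-squared error on function values (`loss_functions.function_mse`). A self-contained toy version of that loop, with a plain tanh MLP standing in for the repo's SingleBVPNet:

```
import torch

coords = torch.linspace(-1, 1, 1000).unsqueeze(-1)    # normalized time stamps
target = torch.sin(30 * coords)                       # stand-in waveform
model = torch.nn.Sequential(torch.nn.Linear(1, 256), torch.nn.Tanh(),
                            torch.nn.Linear(256, 1))
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

for step in range(2000):
    loss = ((model(coords) - target) ** 2).mean()     # what function_mse computes
    optim.zero_grad()
    loss.backward()
    optim.step()
```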
/experiment_scripts/train_helmholtz.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
5 |
6 | import dataio, meta_modules, utils, training, loss_functions, modules
7 |
8 | from torch.utils.data import DataLoader
9 | import torch
10 | import configargparse
11 | import numpy as np
12 |
13 | p = configargparse.ArgumentParser()
14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
15 |
16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
17 | p.add_argument('--experiment_name', type=str, required=True,
18 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
19 |
20 | # General training options
21 | p.add_argument('--batch_size', type=int, default=32)
22 | p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
23 | p.add_argument('--num_epochs', type=int, default=50000,
24 | help='Number of epochs to train for.')
25 |
26 | p.add_argument('--epochs_til_ckpt', type=int, default=1000,
27 |                help='Number of epochs between checkpoints.')
28 | p.add_argument('--steps_til_summary', type=int, default=100,
29 |                help='Number of training steps between tensorboard summaries.')
30 | p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid', 'relu'],
31 | help='Type of model to evaluate, default is sine.')
32 | p.add_argument('--mode', type=str, default='mlp', required=False, choices=['mlp', 'rbf', 'pinn'],
33 |                help='Model architecture: "mlp", "rbf", or "pinn".')
34 | p.add_argument('--velocity', type=str, default='uniform', required=False, choices=['uniform', 'square', 'circle'],
35 |                help='Velocity field to use: "uniform", or a "square"/"circle" perturbation.')
36 | p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
37 | p.add_argument('--use_lbfgs', action='store_true', default=False, help='use L-BFGS.')
38 |
39 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
40 | opt = p.parse_args()
41 |
42 | # if we have a velocity perturbation, offset the source
43 | if opt.velocity!='uniform':
44 | source_coords = [-0.35, 0.]
45 | else:
46 | source_coords = [0., 0.]
47 |
48 | dataset = dataio.SingleHelmholtzSource(sidelength=230, velocity=opt.velocity, source_coords=source_coords)
49 |
50 | dataloader = DataLoader(dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
51 |
52 | # Define the model.
53 | if opt.mode == 'pinn':
54 | model = modules.PINNet(out_features=2, type='tanh', mode=opt.mode)
55 | opt.use_lbfgs = True
56 | else:
57 | model = modules.SingleBVPNet(out_features=2, type=opt.model, mode=opt.mode, final_layer_factor=1.)
58 |
59 | model.cuda()
60 |
61 | # Define the loss
62 | loss_fn = loss_functions.helmholtz_pml
63 | summary_fn = utils.write_helmholtz_summary
64 |
65 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
66 |
67 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
68 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
69 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=opt.clip_grad,
70 | use_lbfgs=opt.use_lbfgs)
71 |
--------------------------------------------------------------------------------
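For reference, the PDE whose residual `loss_functions.helmholtz_pml` penalizes is the inhomogeneous Helmholtz equation, stated here in the paper's notation with squared slowness m(x) = 1/c(x)^2 (sign conventions may differ in the code):

```
\left( \nabla^2 + m(x)\,\omega^2 \right) u(x) = -f(x), \qquad m(x) = \frac{1}{c(x)^2}
```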
/experiment_scripts/train_img.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
5 |
6 | import dataio, meta_modules, utils, training, loss_functions, modules
7 |
8 | from torch.utils.data import DataLoader
9 | import configargparse
10 | from functools import partial
11 |
12 | p = configargparse.ArgumentParser()
13 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
14 |
15 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
16 | p.add_argument('--experiment_name', type=str, required=True,
17 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
18 |
19 | # General training options
20 | p.add_argument('--batch_size', type=int, default=1)
21 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
22 | p.add_argument('--num_epochs', type=int, default=10000,
23 | help='Number of epochs to train for.')
24 |
25 | p.add_argument('--epochs_til_ckpt', type=int, default=25,
26 |                help='Number of epochs between checkpoints.')
27 | p.add_argument('--steps_til_summary', type=int, default=1000,
28 |                help='Number of training steps between tensorboard summaries.')
29 |
30 | p.add_argument('--model_type', type=str, default='sine',
31 |                help='Options currently are "sine" (all sine activations), "relu" (all relu activations), '
32 |                     '"nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu), '
33 |                     'and in the future: "mixed" (first layer sine, other layers tanh)')
34 |
35 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
36 | opt = p.parse_args()
37 |
38 | img_dataset = dataio.Camera()
39 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all')
40 | image_resolution = (512, 512)
41 |
42 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
43 |
44 | # Define the model.
45 | if opt.model_type == 'sine' or opt.model_type == 'relu' or opt.model_type == 'tanh' or opt.model_type == 'selu' or opt.model_type == 'elu'\
46 | or opt.model_type == 'softplus':
47 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', sidelength=image_resolution)
48 | elif opt.model_type == 'rbf' or opt.model_type == 'nerf':
49 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, sidelength=image_resolution)
50 | else:
51 | raise NotImplementedError
52 | model.cuda()
53 |
54 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
55 |
56 | # Define the loss
57 | loss_fn = partial(loss_functions.image_mse, None)
58 | summary_fn = partial(utils.write_image_summary, image_resolution)
59 |
60 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
61 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
62 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn)
63 |
--------------------------------------------------------------------------------
/experiment_scripts/train_img_inpainting.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
5 |
6 | import dataio, meta_modules, utils, training, loss_functions, modules
7 |
8 | from torch.utils.data import DataLoader
9 | import configargparse
10 | from functools import partial
11 | import torch
12 | from PIL import Image
13 | from torchvision.transforms import ToTensor
14 | import numpy as np
15 |
16 | p = configargparse.ArgumentParser()
17 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
18 |
19 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
20 | p.add_argument('--experiment_name', type=str, required=True,
21 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
22 |
23 | # General training options
24 | p.add_argument('--batch_size', type=int, default=1)
25 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
26 | p.add_argument('--num_epochs', type=int, default=10000,
27 | help='Number of epochs to train for.')
28 | p.add_argument('--k1', type=float, default=1, help='weight on prior')
29 | p.add_argument('--sparsity', type=float, default=0.1, help='fraction of pixels observed (if no mask image is given)')
30 | p.add_argument('--prior', type=str, default=None, help='image prior regularizer; one of (TV, FH), default none')
31 | p.add_argument('--downsample', action='store_true', default=False, help='use image downsampling kernel')
32 |
33 | p.add_argument('--epochs_til_ckpt', type=int, default=25,
34 |                help='Number of epochs between checkpoints.')
35 | p.add_argument('--steps_til_summary', type=int, default=1000,
36 |                help='Number of training steps between tensorboard summaries.')
37 |
38 | p.add_argument('--dataset', type=str, default='camera',
39 |                help='Dataset to fit; one of (camera, camera_downsampled, custom).')
40 | p.add_argument('--model_type', type=str, default='sine',
41 |                help='Options currently are "sine" (all sine activations), "relu" (all relu activations), '
42 |                     '"nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu), '
43 |                     'and in the future: "mixed" (first layer sine, other layers tanh)')
44 |
45 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
46 |
47 | p.add_argument('--mask_path', type=str, default=None, help='Path to mask image')
48 | p.add_argument('--custom_image', type=str, default=None, help='Path to single training image')
49 | opt = p.parse_args()
50 |
51 |
52 | if opt.dataset == 'camera':
53 | img_dataset = dataio.Camera()
54 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=512, compute_diff='all')
55 | image_resolution = (512, 512)
56 | elif opt.dataset == 'camera_downsampled':
57 | img_dataset = dataio.Camera(downsample_factor=2)
58 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='all')
59 | image_resolution = (256, 256)
60 | elif opt.dataset == 'custom':
61 | img_dataset = dataio.ImageFile(opt.custom_image)
62 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=(img_dataset[0].size[1], img_dataset[0].size[0]),
63 | compute_diff='all')
64 | image_resolution = (img_dataset[0].size[1], img_dataset[0].size[0])
65 |
66 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
67 |
68 | # Define the model.
69 | if opt.model_type in ('sine', 'relu', 'tanh'):
70 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', out_features=img_dataset.img_channels, sidelength=image_resolution,
71 | downsample=opt.downsample)
72 | elif opt.model_type in ('rbf', 'nerf'):
73 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, out_features=img_dataset.img_channels, sidelength=image_resolution,
74 | downsample=opt.downsample)
75 | else:
76 | raise NotImplementedError
77 | model.cuda()
78 |
79 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
80 |
81 | if opt.mask_path:
82 | mask = Image.open(opt.mask_path)
83 | mask = ToTensor()(mask)
84 | mask = mask.float().cuda()
85 | percentage = torch.sum(mask).cpu().numpy() / np.prod(mask.shape)
86 | print("mask sparsity %f" % (percentage))
87 | else:
88 | mask = torch.rand(image_resolution) < opt.sparsity
89 | mask = mask.float().cuda()
90 |
91 | # Define the loss
92 | if opt.prior is None:
93 | loss_fn = partial(loss_functions.image_mse, mask.view(-1,1))
94 | elif opt.prior == 'TV':
95 | loss_fn = partial(loss_functions.image_mse_TV_prior, mask.view(-1,1), opt.k1, model)
96 | elif opt.prior == 'FH':
97 | loss_fn = partial(loss_functions.image_mse_FH_prior, mask.view(-1,1), opt.k1, model)
98 | summary_fn = partial(utils.write_image_summary, image_resolution)
99 |
100 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
101 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
102 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn)
103 |
--------------------------------------------------------------------------------
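
The inpainting objective only supervises pixels where the mask is 1; the network fills in the remaining pixels through its inductive bias. A standalone sketch of that masked data term (a hypothetical helper, not the repo's loss_functions.image_mse):

    import torch

    def masked_mse(mask, pred, gt):
        # only pixels with mask == 1 contribute to the loss
        return (mask * (pred - gt) ** 2).mean()

    mask = (torch.rand(512 * 512, 1) < 0.1).float()  # ~10% observed, as with --sparsity 0.1
    pred = torch.rand(512 * 512, 1)
    gt = torch.rand(512 * 512, 1)
    print(masked_mse(mask, pred, gt))
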
/experiment_scripts/train_img_neural_process.py:
--------------------------------------------------------------------------------
1 | # Enable import from parent package
2 | import sys
3 | import os
4 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
5 |
6 | import dataio, meta_modules, utils, training, loss_functions
7 |
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import configargparse
11 | from functools import partial
12 |
13 | p = configargparse.ArgumentParser()
14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
15 |
16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
17 | p.add_argument('--experiment_name', type=str, required=True,
18 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
19 |
20 | # General training options
21 | p.add_argument('--batch_size', type=int, default=100)
22 | p.add_argument('--lr', type=float, default=5e-5, help='learning rate. default=5e-5')
23 | p.add_argument('--num_epochs', type=int, default=401,
24 | help='Number of epochs to train for.')
25 | p.add_argument('--kl_weight', type=float, default=1e-1,
26 | help='Weight for l2 loss term on code vectors z (lambda_latent in paper).')
27 | p.add_argument('--fw_weight', type=float, default=1e2,
28 | help='Weight for the l2 loss term on the weights of the sine network')
29 | p.add_argument('--train_sparsity_range', type=int, nargs='+', default=[10, 200],
30 |                help='Two integers: lowest number of sparse pixels sampled followed by highest number of sparse '
31 |                     'pixels sampled when training the conditional neural process')
32 |
33 | p.add_argument('--epochs_til_ckpt', type=int, default=10,
34 |                help='Number of epochs between checkpoints.')
35 | p.add_argument('--steps_til_summary', type=int, default=1000,
36 |                help='Number of training steps between tensorboard summaries.')
37 |
38 | p.add_argument('--dataset', type=str, default='celeba_32x32',
39 |                help='Dataset to fit; currently only celeba_32x32 is supported.')
40 | p.add_argument('--model_type', type=str, default='sine',
41 | help='Nonlinearity for the hypo-network module')
42 |
43 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
44 |
45 | p.add_argument('--conv_encoder', action='store_true', default=False, help='Use convolutional encoder process')
46 | opt = p.parse_args()
47 |
48 |
49 | assert opt.dataset == 'celeba_32x32'
50 | if opt.conv_encoder: gmode = 'conv_cnp'
51 | else: gmode = 'cnp'
52 |
53 | img_dataset = dataio.CelebA(split='train', downsampled=True)
54 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=(32, 32))
55 | generalization_dataset = dataio.ImageGeneralizationWrapper(coord_dataset,
56 | train_sparsity_range=opt.train_sparsity_range,
57 | generalization_mode=gmode)
58 | image_resolution = (32, 32)
59 |
60 | dataloader = DataLoader(generalization_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
61 |
62 | if opt.conv_encoder:
63 | model = meta_modules.ConvolutionalNeuralProcessImplicit2DHypernet(in_features=img_dataset.img_channels,
64 | out_features=img_dataset.img_channels,
65 | image_resolution=image_resolution)
66 | else:
67 | model = meta_modules.NeuralProcessImplicit2DHypernet(in_features=img_dataset.img_channels + 2,
68 | out_features=img_dataset.img_channels,
69 | image_resolution=image_resolution)
70 | model.cuda()
71 |
72 | # Define the loss
73 | loss_fn = partial(loss_functions.image_hypernetwork_loss, None, opt.kl_weight, opt.fw_weight)
74 | summary_fn = partial(utils.write_image_summary_small, image_resolution, None)
75 |
76 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
77 |
78 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
79 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
80 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=True)
81 |
--------------------------------------------------------------------------------
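
At test time, the trained hypernetwork reconstructs an image from a handful of context pixels in a single forward pass. A sketch of that pass, under assumptions: the 'coords_sub'/'img_sub' keys carry the context pairs as in the set encoder call in meta_modules.py, the hypo-network reads dense query points from a 'coords' key, and dataio.get_mgrid() exists as used elsewhere in the repo.

    import torch
    import dataio, meta_modules

    model = meta_modules.NeuralProcessImplicit2DHypernet(
        in_features=3 + 2, out_features=3, image_resolution=(32, 32)).cuda()

    ctx_coords = torch.rand(1, 100, 2).cuda() * 2 - 1  # 100 observed pixel locations
    ctx_pixels = torch.rand(1, 100, 3).cuda()          # their RGB values
    query = dataio.get_mgrid(32).unsqueeze(0).cuda()   # dense 32x32 query grid

    out = model({'coords_sub': ctx_coords, 'img_sub': ctx_pixels, 'coords': query})
    img = out['model_out'].view(32, 32, 3).detach().cpu()
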
/experiment_scripts/train_inverse_helmholtz.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Paper Sec. 4.3, Supplement Sec. 5, inverse Helmholtz problem (full-waveform inversion).
2 | '''
3 |
4 | # Enable import from parent package
5 | import sys
6 | import os
7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
8 |
9 | import dataio, utils, training, loss_functions, modules
10 | from torch.utils.data import DataLoader
11 | import torch
12 | import configargparse
13 | from scipy.io import loadmat
14 |
15 | p = configargparse.ArgumentParser()
16 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
17 |
18 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
19 | p.add_argument('--experiment_name', type=str, required=True,
20 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
21 |
22 | # General training options
23 | p.add_argument('--batch_size', type=int, default=32)
24 | p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
25 | p.add_argument('--num_epochs', type=int, default=80000,
26 | help='Number of epochs to train for.')
27 |
28 | p.add_argument('--epochs_til_ckpt', type=int, default=1000,
29 |                help='Number of epochs between checkpoints.')
30 | p.add_argument('--steps_til_summary', type=int, default=100,
31 |                help='Number of training steps between tensorboard summaries.')
32 | p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid'],
33 | help='Type of model to evaluate, default is sine.')
34 | p.add_argument('--data', type=str, default='./fwi/data_cylinder_5.mat', required=False,
35 | help='Data file with the source/rec coordinates and data.')
36 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
37 | p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
38 | p.add_argument('--pretrain', default=False, action='store_true', help="don't solve for velocity.")
39 | p.add_argument('--load_model', type=str, default=None, required=False,
40 | help='Load pretrained model from checkpoint.')
41 |
42 | opt = p.parse_args()
43 |
44 | # we need to load source and receiver data generated by the principled solver for FWI
45 | data = loadmat(opt.data)
46 | source_coords = data['source']
47 | rec_coords = data['receivers']
48 | rec_val = data['rec_val']
49 |
50 | dataset = dataio.InverseHelmholtz(source_coords, rec_coords, rec_val, sidelength=115, pretrain=opt.pretrain)
51 | dataloader = DataLoader(dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
52 |
53 | # Define the model.
54 | N_src = source_coords.shape[0]
55 | model = modules.SingleBVPNet(in_features=2, out_features=2 * N_src + 1, type=opt.model, final_layer_factor=1.)
56 |
57 | if opt.load_model is not None:
58 | model.load_state_dict(torch.load(opt.load_model))
59 |
60 | model.cuda()
61 |
62 | # Define the loss
63 | loss_fn = loss_functions.helmholtz_pml
64 | summary_fn = utils.write_helmholtz_summary
65 |
66 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
67 |
68 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
69 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
70 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=opt.clip_grad)
71 |
--------------------------------------------------------------------------------
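
The layout the script expects from the .mat file can be checked directly; the keys below are exactly those read above, while the shape comments are only indicative:

    from scipy.io import loadmat

    data = loadmat('./fwi/data_cylinder_5.mat')
    print(data['source'].shape)     # (N_src, ...) source coordinates
    print(data['receivers'].shape)  # receiver coordinates
    print(data['rec_val'].shape)    # field values recorded at the receivers
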
/experiment_scripts/train_poisson_grad_img.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Paper Sec. 4.1, Supplement Sec. 3, reconstruction from gradient.
2 | '''
3 |
4 | # Enable import from parent package
5 | import sys
6 | import os
7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
8 |
9 | import dataio, meta_modules, utils, training, loss_functions, modules
10 |
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 |
14 | p = configargparse.ArgumentParser()
15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
16 |
17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
18 | p.add_argument('--experiment_name', type=str, required=True,
19 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
20 |
21 | # General training options
22 | p.add_argument('--batch_size', type=int, default=16384)
23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
24 | p.add_argument('--num_epochs', type=int, default=10000,
25 | help='Number of epochs to train for.')
26 |
27 | p.add_argument('--epochs_til_ckpt', type=int, default=25,
28 |                help='Number of epochs between checkpoints.')
29 | p.add_argument('--steps_til_summary', type=int, default=100,
30 |                help='Number of training steps between tensorboard summaries.')
31 |
32 | p.add_argument('--dataset', type=str, choices=['camera','bsd500'], default='camera',
33 | help='Dataset: choices=[camera,bsd500].')
34 | p.add_argument('--model_type', type=str, default='sine',
35 |                help='Options are "sine", "relu", "tanh", "softplus", "rbf", and "nerf".')
36 |
37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
38 | opt = p.parse_args()
39 |
40 | if opt.dataset == 'camera':
41 | img_dataset = dataio.Camera()
42 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='gradients')
43 | elif opt.dataset == 'bsd500':
44 |     # you can select the image you like via idx_to_sample
45 | img_dataset = dataio.BSD500ImageDataset(in_folder='../data/BSD500/train',
46 | idx_to_sample=[19])
47 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='gradients')
48 |
49 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
50 |
51 | # Define the model.
52 | if opt.model_type in ('sine', 'relu', 'tanh', 'softplus'):
53 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', sidelength=(256, 256))
54 | elif opt.model_type in ('rbf', 'nerf'):
55 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, sidelength=(256, 256))
56 | else:
57 | raise NotImplementedError
58 | model.cuda()
59 |
60 | # Define the loss & summary functions
61 | loss_fn = loss_functions.gradients_mse
62 | summary_fn = utils.write_gradients_summary
63 |
64 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
65 |
66 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
67 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
68 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False)
69 |
--------------------------------------------------------------------------------
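
The Poisson-from-gradients objective never compares intensities: it differentiates the network with respect to its input coordinates via autograd and matches those derivatives to the precomputed image gradients. A self-contained sketch of the idea (a hypothetical stand-in for loss_functions.gradients_mse):

    import torch

    coords = torch.rand(1024, 2, requires_grad=True)
    net = torch.nn.Sequential(torch.nn.Linear(2, 64), torch.nn.Tanh(),
                              torch.nn.Linear(64, 1))

    y = net(coords)
    grad = torch.autograd.grad(y, coords, torch.ones_like(y), create_graph=True)[0]

    gt_grads = torch.rand(1024, 2)          # stand-in for ground-truth image gradients
    loss = ((grad - gt_grads) ** 2).mean()  # supervise derivatives, not intensities
    loss.backward()
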
/experiment_scripts/train_poisson_gradcomp_img.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Paper Sec. 4.1, Supplement Sec. 3, poisson image editing.
2 | '''
3 |
4 | # Enable import from parent package
5 | import sys
6 | import os
7 |
8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9 |
10 | import dataio, meta_modules, utils, training, loss_functions, modules
11 |
12 | from torch.utils.data import DataLoader
13 | import configargparse
14 |
15 | p = configargparse.ArgumentParser()
16 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
17 |
18 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
19 | p.add_argument('--experiment_name', type=str, required=True,
20 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
21 |
22 | # General training options
23 | p.add_argument('--batch_size', type=int, default=16384)
24 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
25 | p.add_argument('--num_epochs', type=int, default=10000,
26 | help='Number of epochs to train for.')
27 |
28 | p.add_argument('--epochs_til_ckpt', type=int, default=25,
29 |                help='Number of epochs between checkpoints.')
30 | p.add_argument('--steps_til_summary', type=int, default=100,
31 |                help='Number of training steps between tensorboard summaries.')
32 |
33 | p.add_argument('--model_type', type=str, default='sine',
34 | help='Options are "sine" (all sine activations) and "mixed" (first layer sine, other layers tanh)')
35 |
36 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
37 | opt = p.parse_args()
38 |
39 | # Dataset
40 | img_filepath1 = '/media/data4e/jnmartel/neural_bvps/gizeh.jpg'  # adjust these to your local image paths
41 | img_filepath2 = '/media/data4e/jnmartel/neural_bvps/bear.jpg'
42 | is_color = False
43 | coord_dataset = dataio.CompositeGradients(img_filepath1, img_filepath2, is_color=is_color,
44 | sidelength=512)
45 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
46 |
47 | # Define the model.
48 | if not is_color:
49 | model = modules.SingleBVPNet(type=opt.model_type, final_layer_factor=1)
50 | loss_fn = loss_functions.gradients_mse
51 | else:
52 | model = modules.SingleBVPNet(out_features=3, type=opt.model_type, final_layer_factor=1)
53 | loss_fn = loss_functions.gradients_color_mse
54 | model.cuda()
55 |
56 | summary_fn = utils.write_gradcomp_summary
57 |
58 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
59 |
60 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
61 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
62 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False)
63 |
--------------------------------------------------------------------------------
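
Poisson compositing fits one function whose gradients match a blend of the two images' gradient fields; one simple blend is a per-pixel average. This sketches the target only; the repo's dataio.CompositeGradients builds the actual one from the two files above.

    import torch

    grads1 = torch.rand(512 * 512, 2)        # gradient field of image 1
    grads2 = torch.rand(512 * 512, 2)        # gradient field of image 2
    composite = 0.5 * grads1 + 0.5 * grads2  # blended target for gradients_mse
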
/experiment_scripts/train_poisson_lapl_img.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Paper Sec. 4.1, Supplement Sec. 3, reconstruction from laplacian.
2 | '''
3 |
4 | # Enable import from parent package
5 | import sys
6 | import os
7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
8 |
9 | import dataio, meta_modules, utils, training, loss_functions, modules
10 |
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 |
14 | p = configargparse.ArgumentParser()
15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
16 |
17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
18 | p.add_argument('--experiment_name', type=str, required=True,
19 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
20 |
21 | # General training options
22 | p.add_argument('--batch_size', type=int, default=16384)
23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
24 | p.add_argument('--num_epochs', type=int, default=10000,
25 | help='Number of epochs to train for.')
26 |
27 | p.add_argument('--epochs_til_ckpt', type=int, default=25,
28 |                help='Number of epochs between checkpoints.')
29 | p.add_argument('--steps_til_summary', type=int, default=100,
30 |                help='Number of training steps between tensorboard summaries.')
31 |
32 | p.add_argument('--dataset', type=str, choices=['camera','bsd500'], default='camera',
33 | help='Dataset: choices=[camera,bsd500].')
34 | p.add_argument('--model_type', type=str, default='sine',
35 |                help='Options are "sine", "relu", "tanh", "rbf", and "nerf".')
36 |
37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
38 | opt = p.parse_args()
39 |
40 |
41 | if opt.dataset == 'camera':
42 | img_dataset = dataio.Camera()
43 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='laplacian')
44 | elif opt.dataset == 'bsd500':
45 |     # you can select the image you like via idx_to_sample
46 |     img_dataset = dataio.BSD500ImageDataset(in_folder='/media/data3/awb/BSD500/train',  # adjust to your local BSD500 path
47 | idx_to_sample=[19])
48 | coord_dataset = dataio.Implicit2DWrapper(img_dataset, sidelength=256, compute_diff='laplacian')
49 |
50 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
51 |
52 | # Define the model
53 | if opt.model_type in ('sine', 'relu', 'tanh'):
54 | model = modules.SingleBVPNet(type=opt.model_type, mode='mlp', sidelength=(256, 256))
55 | elif opt.model_type in ('rbf', 'nerf'):
56 | model = modules.SingleBVPNet(type='relu', mode=opt.model_type, sidelength=(256, 256))
57 | else:
58 | raise NotImplementedError
59 | model.cuda()
60 |
61 | # Define the loss
62 | loss_fn = loss_functions.laplace_mse
63 | summary_fn = utils.write_laplace_summary
64 |
65 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
66 |
67 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
68 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
69 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False)
70 |
--------------------------------------------------------------------------------
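
Fitting from the Laplacian needs second derivatives, which autograd yields by differentiating the gradient once more. A self-contained sketch (the repo's diff_operators module provides the equivalent helpers used by loss_functions.laplace_mse):

    import torch

    def laplace(y, x):
        grad = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]
        lap = 0.
        for i in range(x.shape[-1]):
            lap = lap + torch.autograd.grad(grad[..., i], x, torch.ones_like(grad[..., i]),
                                            create_graph=True)[0][..., i:i + 1]
        return lap

    x = torch.rand(16, 2, requires_grad=True)
    y = (x ** 2).sum(-1, keepdim=True)  # Laplacian of |x|^2 in 2D is 4
    print(laplace(y, x))                # ~4 everywhere
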
/experiment_scripts/train_sdf.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Sec. 4.2 in main paper and Sec. 4 in Supplement.
2 | '''
3 |
4 | # Enable import from parent package
5 | import sys
6 | import os
7 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
8 |
9 | import dataio, meta_modules, utils, training, loss_functions, modules
10 |
11 | from torch.utils.data import DataLoader
12 | import configargparse
13 |
14 | p = configargparse.ArgumentParser()
15 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
16 |
17 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
18 | p.add_argument('--experiment_name', type=str, required=True,
19 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
20 |
21 | # General training options
22 | p.add_argument('--batch_size', type=int, default=1400)
23 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
24 | p.add_argument('--num_epochs', type=int, default=10000,
25 | help='Number of epochs to train for.')
26 |
27 | p.add_argument('--epochs_til_ckpt', type=int, default=1,
28 |                help='Number of epochs between checkpoints.')
29 | p.add_argument('--steps_til_summary', type=int, default=100,
30 |                help='Number of training steps between tensorboard summaries.')
31 |
32 | p.add_argument('--model_type', type=str, default='sine',
33 |                help='Model type, e.g. "sine" or "nerf" (relu with positional encoding).')
34 | p.add_argument('--point_cloud_path', type=str, default='/home/sitzmann/data/point_cloud.xyz',
35 |                help='Path to the .xyz point cloud to fit.')
36 |
37 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
38 | opt = p.parse_args()
39 |
40 |
41 | sdf_dataset = dataio.PointCloud(opt.point_cloud_path, on_surface_points=opt.batch_size)
42 | dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=1, pin_memory=True, num_workers=0)
43 |
44 | # Define the model.
45 | if opt.model_type == 'nerf':
46 | model = modules.SingleBVPNet(type='relu', mode='nerf', in_features=3)
47 | else:
48 | model = modules.SingleBVPNet(type=opt.model_type, in_features=3)
49 | model.cuda()
50 |
51 | # Define the loss
52 | loss_fn = loss_functions.sdf
53 | summary_fn = utils.write_sdf_summary
54 |
55 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
56 |
57 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
58 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
59 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, double_precision=False,
60 | clip_grad=True)
61 |
--------------------------------------------------------------------------------
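
The SDF objective combines an on-surface term, a normal-alignment term, an off-surface penalty, and an eikonal term driving |grad f| to 1 everywhere (see loss_functions.sdf for the full objective). A sketch of the eikonal term alone, on a toy network:

    import torch

    coords = (torch.rand(4096, 3) * 2 - 1).requires_grad_(True)
    net = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Softplus(),
                              torch.nn.Linear(64, 1))

    sdf = net(coords)
    grad = torch.autograd.grad(sdf, coords, torch.ones_like(sdf), create_graph=True)[0]
    eikonal = ((grad.norm(dim=-1) - 1) ** 2).mean()  # a true SDF has unit gradient norm
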
/experiment_scripts/train_video.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Supplement Sec. 7'''
2 |
3 | # Enable import from parent package
4 | import sys
5 | import os
6 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
7 |
8 | import dataio, meta_modules, utils, training, loss_functions, modules
9 |
10 | from torch.utils.data import DataLoader
11 | import configargparse
12 | from functools import partial
13 | import skvideo.datasets
14 |
15 | p = configargparse.ArgumentParser()
16 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
17 |
18 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
19 | p.add_argument('--experiment_name', type=str, required=True,
20 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
21 |
22 | # General training options
23 | p.add_argument('--batch_size', type=int, default=1)
24 | p.add_argument('--lr', type=float, default=1e-4, help='learning rate. default=1e-4')
25 | p.add_argument('--num_epochs', type=int, default=100000,
26 | help='Number of epochs to train for.')
27 |
28 | p.add_argument('--epochs_til_ckpt', type=int, default=1000,
29 |                help='Number of epochs between checkpoints.')
30 | p.add_argument('--steps_til_summary', type=int, default=100,
31 |                help='Number of training steps between tensorboard summaries.')
32 | p.add_argument('--dataset', type=str, default='bikes',
33 | help='Video dataset; one of (cat, bikes)', choices=['cat', 'bikes'])
34 | p.add_argument('--model_type', type=str, default='sine',
35 |                help='Options currently are "sine" (all sine activations), "relu" (all relu activations), '
36 |                     '"nerf" (relu activations and positional encoding as in NeRF), "rbf" (input rbf layer, rest relu)')
37 | p.add_argument('--sample_frac', type=float, default=38e-4,
38 |                help='What fraction of video pixels to sample in each batch (default 38e-4; use 1.0 for all)')
39 |
40 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
41 | opt = p.parse_args()
42 |
43 | if opt.dataset == 'cat':
44 | video_path = './data/video_512.npy'
45 | elif opt.dataset == 'bikes':
46 | video_path = skvideo.datasets.bikes()
47 |
48 | vid_dataset = dataio.Video(video_path)
49 | coord_dataset = dataio.Implicit3DWrapper(vid_dataset, sidelength=vid_dataset.shape, sample_fraction=opt.sample_frac)
50 | dataloader = DataLoader(coord_dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
51 |
52 | # Define the model.
53 | if opt.model_type in ('sine', 'relu', 'tanh'):
54 | model = modules.SingleBVPNet(type=opt.model_type, in_features=3, out_features=vid_dataset.channels,
55 | mode='mlp', hidden_features=1024, num_hidden_layers=3)
56 | elif opt.model_type in ('rbf', 'nerf'):
57 | model = modules.SingleBVPNet(type='relu', in_features=3, out_features=vid_dataset.channels, mode=opt.model_type)
58 | else:
59 | raise NotImplementedError
60 | model.cuda()
61 |
62 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
63 |
64 | # Define the loss
65 | loss_fn = partial(loss_functions.image_mse, None)
66 | summary_fn = partial(utils.write_video_summary, vid_dataset)
67 |
68 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
69 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
70 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn)
71 |
--------------------------------------------------------------------------------
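
With --sample_frac below 1, each step trains on a random subset of the flattened (t, y, x) pixel grid, which keeps the memory footprint of a 3D fit manageable. A sketch of the sampling arithmetic (video dimensions are illustrative):

    import torch

    n_pixels = 250 * 272 * 640                 # frames * height * width (illustrative)
    n_sample = int(38e-4 * n_pixels)           # matches the default --sample_frac
    idx = torch.randperm(n_pixels)[:n_sample]  # fresh random subset each step
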
/experiment_scripts/train_wave_equation.py:
--------------------------------------------------------------------------------
1 | '''Reproduces Paper Sec. 4.3 and Supplement Sec. 5'''
2 |
3 | # Enable import from parent package
4 | import sys
5 | import os
6 | sys.path.append( os.path.dirname( os.path.dirname( os.path.abspath(__file__) ) ) )
7 |
8 | import dataio, meta_modules, utils, training, loss_functions, modules
9 |
10 | from torch.utils.data import DataLoader
11 | import configargparse
12 |
13 | p = configargparse.ArgumentParser()
14 | p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
15 |
16 | p.add_argument('--logging_root', type=str, default='./logs', help='root for logging')
17 | p.add_argument('--experiment_name', type=str, required=True,
18 | help='Name of subdirectory in logging_root where summaries and checkpoints will be saved.')
19 |
20 | # General training options
21 | p.add_argument('--batch_size', type=int, default=32)
22 | p.add_argument('--lr', type=float, default=2e-5, help='learning rate. default=2e-5')
23 | p.add_argument('--num_epochs', type=int, default=100000,
24 | help='Number of epochs to train for.')
25 |
26 | p.add_argument('--epochs_til_ckpt', type=int, default=1000,
27 |                help='Number of epochs between checkpoints.')
28 | p.add_argument('--steps_til_summary', type=int, default=100,
29 |                help='Number of training steps between tensorboard summaries.')
30 | p.add_argument('--model', type=str, default='sine', required=False, choices=['sine', 'tanh', 'sigmoid', 'relu'],
31 | help='Type of model to evaluate, default is sine.')
32 | p.add_argument('--mode', type=str, default='mlp', required=False, choices=['mlp', 'rbf', 'pinn'],
33 |                help='Architecture mode of the network (mlp, rbf, or pinn).')
34 | p.add_argument('--velocity', type=str, default='uniform', required=False, choices=['uniform', 'square', 'circle'],
35 | help='Whether to use uniform velocity parameter')
36 | p.add_argument('--pretrain', action='store_true', default=False, required=False, help='Pretrain Dirichlet and Neumann conditions')
37 | p.add_argument('--clip_grad', default=0.0, type=float, help='Clip gradient.')
38 | p.add_argument('--use_lbfgs', default=False, action='store_true', help='use L-BFGS.')
39 |
40 | p.add_argument('--checkpoint_path', default=None, help='Checkpoint to trained model.')
41 | opt = p.parse_args()
42 |
43 | # the source sits at the origin; offset it here if a velocity perturbation is used
44 | source_coords = [0., 0., 0.]
45 |
46 | dataset = dataio.WaveSource(sidelength=340, velocity=opt.velocity,
47 | source_coords=source_coords, pretrain=opt.pretrain)
48 |
49 | dataloader = DataLoader(dataset, shuffle=True, batch_size=opt.batch_size, pin_memory=True, num_workers=0)
50 |
51 | model = modules.SingleBVPNet(in_features=3, out_features=1, type=opt.model, mode=opt.mode,
52 | final_layer_factor=1., hidden_features=512, num_hidden_layers=3)
53 | model.cuda()
54 |
55 | # Define the loss
56 | loss_fn = loss_functions.wave_pml
57 | summary_fn = utils.write_wave_summary
58 |
59 | root_path = os.path.join(opt.logging_root, opt.experiment_name)
60 |
61 | training.train(model=model, train_dataloader=dataloader, epochs=opt.num_epochs, lr=opt.lr,
62 | steps_til_summary=opt.steps_til_summary, epochs_til_checkpoint=opt.epochs_til_ckpt,
63 | model_dir=root_path, loss_fn=loss_fn, summary_fn=summary_fn, clip_grad=opt.clip_grad,
64 | use_lbfgs=opt.use_lbfgs)
65 |
--------------------------------------------------------------------------------
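
The wave-equation fit penalizes the PDE residual u_tt - c^2 (u_yy + u_xx) on collocation points, together with the initial conditions around the source (loss_functions.wave_pml implements the full objective). A sketch of the residual computation with autograd, on a toy network:

    import torch

    coords = torch.rand(1024, 3, requires_grad=True)  # (t, y, x) collocation points
    net = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Tanh(),
                              torch.nn.Linear(64, 1))

    u = net(coords)
    du = torch.autograd.grad(u, coords, torch.ones_like(u), create_graph=True)[0]
    d2 = [torch.autograd.grad(du[:, i], coords, torch.ones_like(du[:, i]),
                              create_graph=True)[0][:, i] for i in range(3)]

    c = 1.0                                      # uniform velocity
    residual = d2[0] - c ** 2 * (d2[1] + d2[2])  # u_tt - c^2 (u_yy + u_xx)
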
/meta_modules.py:
--------------------------------------------------------------------------------
1 | '''Modules for hypernetwork experiments, Paper Sec. 4.4
2 | '''
3 |
4 | import torch
5 | from torch import nn
6 | from collections import OrderedDict
7 | import modules
8 |
9 |
10 | class HyperNetwork(nn.Module):
11 | def __init__(self, hyper_in_features, hyper_hidden_layers, hyper_hidden_features, hypo_module):
12 | '''
13 |
14 | Args:
15 | hyper_in_features: In features of hypernetwork
16 | hyper_hidden_layers: Number of hidden layers in hypernetwork
17 | hyper_hidden_features: Number of hidden units in hypernetwork
18 | hypo_module: MetaModule. The module whose parameters are predicted.
19 | '''
20 | super().__init__()
21 |
22 | hypo_parameters = hypo_module.meta_named_parameters()
23 |
24 | self.names = []
25 | self.nets = nn.ModuleList()
26 | self.param_shapes = []
27 | for name, param in hypo_parameters:
28 | self.names.append(name)
29 | self.param_shapes.append(param.size())
30 |
31 | hn = modules.FCBlock(in_features=hyper_in_features, out_features=int(torch.prod(torch.tensor(param.size()))),
32 | num_hidden_layers=hyper_hidden_layers, hidden_features=hyper_hidden_features,
33 | outermost_linear=True, nonlinearity='relu')
34 | self.nets.append(hn)
35 |
36 | if 'weight' in name:
37 | self.nets[-1].net[-1].apply(lambda m: hyper_weight_init(m, param.size()[-1]))
38 | elif 'bias' in name:
39 | self.nets[-1].net[-1].apply(lambda m: hyper_bias_init(m))
40 |
41 | def forward(self, z):
42 | '''
43 | Args:
44 |             z: Embedding. Input to hypernetwork. Could be the output of an encoder, e.g. the set encoders used below.
45 |
46 | Returns:
47 | params: OrderedDict. Can be directly passed as the "params" parameter of a MetaModule.
48 | '''
49 | params = OrderedDict()
50 | for name, net, param_shape in zip(self.names, self.nets, self.param_shapes):
51 | batch_param_shape = (-1,) + param_shape
52 | params[name] = net(z).reshape(batch_param_shape)
53 | return params
54 |
55 |
56 | class NeuralProcessImplicit2DHypernet(nn.Module):
57 | '''A canonical 2D representation hypernetwork mapping 2D coords to out_features.'''
58 | def __init__(self, in_features, out_features, image_resolution=None, encoder_nl='sine'):
59 | super().__init__()
60 |
61 | latent_dim = 256
62 | self.hypo_net = modules.SingleBVPNet(out_features=out_features, type='sine', sidelength=image_resolution,
63 | in_features=2)
64 | self.hyper_net = HyperNetwork(hyper_in_features=latent_dim, hyper_hidden_layers=1, hyper_hidden_features=256,
65 | hypo_module=self.hypo_net)
66 | self.set_encoder = modules.SetEncoder(in_features=in_features, out_features=latent_dim, num_hidden_layers=2,
67 | hidden_features=latent_dim, nonlinearity=encoder_nl)
68 | print(self)
69 |
70 | def freeze_hypernet(self):
71 | for param in self.hyper_net.parameters():
72 | param.requires_grad = False
73 |
74 | def get_hypo_net_weights(self, model_input):
75 | pixels, coords = model_input['img_sub'], model_input['coords_sub']
76 | ctxt_mask = model_input.get('ctxt_mask', None)
77 | embedding = self.set_encoder(coords, pixels, ctxt_mask=ctxt_mask)
78 | hypo_params = self.hyper_net(embedding)
79 | return hypo_params, embedding
80 |
81 | def forward(self, model_input):
82 | if model_input.get('embedding', None) is None:
83 | pixels, coords = model_input['img_sub'], model_input['coords_sub']
84 | ctxt_mask = model_input.get('ctxt_mask', None)
85 | embedding = self.set_encoder(coords, pixels, ctxt_mask=ctxt_mask)
86 | else:
87 | embedding = model_input['embedding']
88 | hypo_params = self.hyper_net(embedding)
89 |
90 | model_output = self.hypo_net(model_input, params=hypo_params)
91 | return {'model_in':model_output['model_in'], 'model_out':model_output['model_out'], 'latent_vec':embedding,
92 | 'hypo_params':hypo_params}
93 |
94 |
95 | class ConvolutionalNeuralProcessImplicit2DHypernet(nn.Module):
96 | def __init__(self, in_features, out_features, image_resolution=None, partial_conv=False):
97 | super().__init__()
98 | latent_dim = 256
99 |
100 | if partial_conv:
101 | self.encoder = modules.PartialConvImgEncoder(channel=in_features, image_resolution=image_resolution)
102 | else:
103 | self.encoder = modules.ConvImgEncoder(channel=in_features, image_resolution=image_resolution)
104 | self.hypo_net = modules.SingleBVPNet(out_features=out_features, type='sine', sidelength=image_resolution,
105 | in_features=2)
106 | self.hyper_net = HyperNetwork(hyper_in_features=latent_dim, hyper_hidden_layers=1, hyper_hidden_features=256,
107 | hypo_module=self.hypo_net)
108 | print(self)
109 |
110 | def forward(self, model_input):
111 | if model_input.get('embedding', None) is None:
112 | embedding = self.encoder(model_input['img_sparse'])
113 | else:
114 | embedding = model_input['embedding']
115 | hypo_params = self.hyper_net(embedding)
116 |
117 | model_output = self.hypo_net(model_input, params=hypo_params)
118 |
119 | return {'model_in': model_output['model_in'], 'model_out': model_output['model_out'], 'latent_vec': embedding,
120 | 'hypo_params': hypo_params}
121 |
122 | def get_hypo_net_weights(self, model_input):
123 | embedding = self.encoder(model_input['img_sparse'])
124 | hypo_params = self.hyper_net(embedding)
125 | return hypo_params, embedding
126 |
127 | def freeze_hypernet(self):
128 | for param in self.hyper_net.parameters():
129 | param.requires_grad = False
130 | for param in self.encoder.parameters():
131 | param.requires_grad = False
132 |
133 |
134 | ############################
135 | # Initialization schemes
136 | def hyper_weight_init(m, in_features_main_net):
137 | if hasattr(m, 'weight'):
138 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
139 | m.weight.data = m.weight.data / 1.e2
140 |
141 | if hasattr(m, 'bias'):
142 | with torch.no_grad():
143 | m.bias.uniform_(-1/in_features_main_net, 1/in_features_main_net)
144 |
145 |
146 | def hyper_bias_init(m):
147 | if hasattr(m, 'weight'):
148 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
149 | m.weight.data = m.weight.data / 1.e2
150 |
151 | if hasattr(m, 'bias'):
152 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight)
153 | with torch.no_grad():
154 | m.bias.uniform_(-1/fan_in, 1/fan_in)
155 |
--------------------------------------------------------------------------------
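
A usage sketch for HyperNetwork in isolation: a latent code maps to a full batched parameter dict, which the hypo-network consumes through its params argument, the same pattern the two classes above use internally (shapes are illustrative):

    import torch
    import modules
    from meta_modules import HyperNetwork

    hypo = modules.SingleBVPNet(out_features=3, type='sine', in_features=2)
    hyper = HyperNetwork(hyper_in_features=256, hyper_hidden_layers=1,
                         hyper_hidden_features=256, hypo_module=hypo)

    z = torch.randn(4, 256)        # a batch of 4 latent codes
    params = hyper(z)              # OrderedDict of batched hypo-network weights
    out = hypo({'coords': torch.rand(4, 100, 2)}, params=params)
    print(out['model_out'].shape)  # expected: (4, 100, 3)
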
/sdf_meshing.py:
--------------------------------------------------------------------------------
1 | '''From the DeepSDF repository https://github.com/facebookresearch/DeepSDF
2 | '''
3 | #!/usr/bin/env python3
4 |
5 | import logging
6 | import numpy as np
7 | import plyfile
8 | import skimage.measure
9 | import time
10 | import torch
11 |
12 |
13 | def create_mesh(
14 | decoder, filename, N=256, max_batch=64 ** 3, offset=None, scale=None
15 | ):
16 | start = time.time()
17 | ply_filename = filename
18 |
19 | decoder.eval()
20 |
21 | # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
22 | voxel_origin = [-1, -1, -1]
23 | voxel_size = 2.0 / (N - 1)
24 |
25 | overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())
26 | samples = torch.zeros(N ** 3, 4)
27 |
28 | # transform first 3 columns
29 | # to be the x, y, z index
30 | samples[:, 2] = overall_index % N
31 |     samples[:, 1] = (overall_index // N) % N       # floor division keeps indices integral
32 |     samples[:, 0] = ((overall_index // N) // N) % N
33 |
34 | # transform first 3 columns
35 | # to be the x, y, z coordinate
36 | samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
37 | samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
38 | samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
39 |
40 | num_samples = N ** 3
41 |
42 | samples.requires_grad = False
43 |
44 | head = 0
45 |
46 | while head < num_samples:
47 |         print(head)  # progress: index of the first sample in the current batch
48 | sample_subset = samples[head : min(head + max_batch, num_samples), 0:3].cuda()
49 |
50 | samples[head : min(head + max_batch, num_samples), 3] = (
51 | decoder(sample_subset)
52 | .squeeze()#.squeeze(1)
53 | .detach()
54 | .cpu()
55 | )
56 | head += max_batch
57 |
58 | sdf_values = samples[:, 3]
59 | sdf_values = sdf_values.reshape(N, N, N)
60 |
61 | end = time.time()
62 | print("sampling takes: %f" % (end - start))
63 |
64 | convert_sdf_samples_to_ply(
65 | sdf_values.data.cpu(),
66 | voxel_origin,
67 | voxel_size,
68 | ply_filename + ".ply",
69 | offset,
70 | scale,
71 | )
72 |
73 |
74 | def convert_sdf_samples_to_ply(
75 | pytorch_3d_sdf_tensor,
76 | voxel_grid_origin,
77 | voxel_size,
78 | ply_filename_out,
79 | offset=None,
80 | scale=None,
81 | ):
82 | """
83 | Convert sdf samples to .ply
84 |
85 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
86 |     :param voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
87 |     :param voxel_size: float, the size of the voxels
88 |     :param ply_filename_out: string, path of the filename to save to
89 |
90 | This function adapted from: https://github.com/RobotLocomotion/spartan
91 | """
92 |
93 | start_time = time.time()
94 |
95 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
96 |
97 | verts, faces, normals, values = np.zeros((0, 3)), np.zeros((0, 3)), np.zeros((0, 3)), np.zeros(0)
98 | try:
99 |         verts, faces, normals, values = skimage.measure.marching_cubes(
100 |             numpy_3d_sdf_tensor, level=0.0, spacing=[voxel_size] * 3
101 |         )  # marching_cubes_lewiner was removed in scikit-image >= 0.19
102 |     except Exception:
103 |         logging.warning("marching cubes failed; writing an empty mesh")
104 |
105 | # transform from voxel coordinates to camera coordinates
106 | # note x and y are flipped in the output of marching_cubes
107 | mesh_points = np.zeros_like(verts)
108 | mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
109 | mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
110 | mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
111 |
112 | # apply additional offset and scale
113 | if scale is not None:
114 | mesh_points = mesh_points / scale
115 | if offset is not None:
116 | mesh_points = mesh_points - offset
117 |
118 | # try writing to the ply file
119 |
120 | num_verts = verts.shape[0]
121 | num_faces = faces.shape[0]
122 |
123 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
124 |
125 | for i in range(0, num_verts):
126 | verts_tuple[i] = tuple(mesh_points[i, :])
127 |
128 | faces_building = []
129 | for i in range(0, num_faces):
130 | faces_building.append(((faces[i, :].tolist(),)))
131 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
132 |
133 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
134 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
135 |
136 | ply_data = plyfile.PlyData([el_verts, el_faces])
137 | logging.debug("saving mesh to %s" % (ply_filename_out))
138 | ply_data.write(ply_filename_out)
139 |
140 | logging.debug(
141 | "converting to ply format and writing to file took {} s".format(
142 | time.time() - start_time
143 | )
144 | )
145 |
--------------------------------------------------------------------------------
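
create_mesh expects a decoder that maps a (num_points, 3) cuda tensor of coordinates directly to SDF values, so a trained SingleBVPNet needs a thin adapter around its dict interface. A sketch with a hypothetical wrapper and checkpoint path:

    import torch
    import modules, sdf_meshing

    class SDFDecoder(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model.cuda()

        def forward(self, coords):
            # adapt SingleBVPNet's dict interface to the plain-tensor call in create_mesh
            return self.model({'coords': coords})['model_out']

    model = modules.SingleBVPNet(type='sine', in_features=3)
    model.load_state_dict(torch.load('./logs/my_sdf/checkpoints/model_final.pth'))
    sdf_meshing.create_mesh(SDFDecoder(model), './my_shape', N=256)
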
/torchmeta/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta import datasets
2 | from torchmeta import modules
3 | from torchmeta import toy
4 | from torchmeta import transforms
5 | from torchmeta import utils
6 |
7 | from torchmeta.version import VERSION as __version__
8 |
--------------------------------------------------------------------------------
/torchmeta/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta.datasets.triplemnist import TripleMNIST
2 | from torchmeta.datasets.doublemnist import DoubleMNIST
3 | from torchmeta.datasets.cub import CUB
4 | from torchmeta.datasets.cifar100 import CIFARFS, FC100
5 | from torchmeta.datasets.miniimagenet import MiniImagenet
6 | from torchmeta.datasets.omniglot import Omniglot
7 | from torchmeta.datasets.tieredimagenet import TieredImagenet
8 | from torchmeta.datasets.tcga import TCGA
9 |
10 | from torchmeta.datasets import helpers
11 |
12 | __all__ = [
13 | 'TCGA',
14 | 'Omniglot',
15 | 'MiniImagenet',
16 | 'TieredImagenet',
17 | 'CIFARFS',
18 | 'FC100',
19 | 'CUB',
20 | 'DoubleMNIST',
21 | 'TripleMNIST',
22 | 'helpers'
23 | ]
24 |
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cifar100/cifar-fs/test.json:
--------------------------------------------------------------------------------
1 | [["aquatic_mammals","whale"],["flowers","poppy"],["flowers","rose"],["fruit_and_vegetables","sweet_pepper"],["household_electrical_devices","telephone"],["household_furniture","bed"],["household_furniture","table"],["household_furniture","wardrobe"],["large_carnivores","leopard"],["large_natural_outdoor_scenes","plain"],["large_omnivores_and_herbivores","chimpanzee"],["medium_mammals","fox"],["non-insect_invertebrates","snail"],["non-insect_invertebrates","worm"],["people","baby"],["people","man"],["people","woman"],["vehicles_1","bicycle"],["vehicles_1","pickup_truck"],["vehicles_2","rocket"]]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cifar100/cifar-fs/train.json:
--------------------------------------------------------------------------------
1 | [["aquatic_mammals","dolphin"],["aquatic_mammals","seal"],["fish","aquarium_fish"],["fish","ray"],["fish","trout"],["flowers","orchid"],["flowers","sunflower"],["flowers","tulip"],["food_containers","bottle"],["food_containers","bowl"],["food_containers","can"],["food_containers","cup"],["food_containers","plate"],["fruit_and_vegetables","apple"],["fruit_and_vegetables","mushroom"],["fruit_and_vegetables","orange"],["fruit_and_vegetables","pear"],["household_electrical_devices","clock"],["household_electrical_devices","keyboard"],["household_furniture","chair"],["household_furniture","couch"],["insects","bee"],["insects","caterpillar"],["insects","cockroach"],["large_carnivores","bear"],["large_carnivores","lion"],["large_carnivores","tiger"],["large_carnivores","wolf"],["large_man-made_outdoor_things","bridge"],["large_man-made_outdoor_things","castle"],["large_man-made_outdoor_things","house"],["large_man-made_outdoor_things","road"],["large_man-made_outdoor_things","skyscraper"],["large_natural_outdoor_scenes","cloud"],["large_natural_outdoor_scenes","forest"],["large_natural_outdoor_scenes","mountain"],["large_omnivores_and_herbivores","elephant"],["large_omnivores_and_herbivores","kangaroo"],["medium_mammals","porcupine"],["medium_mammals","possum"],["medium_mammals","raccoon"],["medium_mammals","skunk"],["non-insect_invertebrates","lobster"],["non-insect_invertebrates","spider"],["people","boy"],["people","girl"],["reptiles","dinosaur"],["reptiles","lizard"],["reptiles","snake"],["reptiles","turtle"],["small_mammals","hamster"],["small_mammals","mouse"],["small_mammals","rabbit"],["small_mammals","shrew"],["small_mammals","squirrel"],["trees","oak_tree"],["trees","palm_tree"],["trees","pine_tree"],["trees","willow_tree"],["vehicles_1","bus"],["vehicles_1","train"],["vehicles_2","lawn_mower"],["vehicles_2","streetcar"],["vehicles_2","tank"]]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cifar100/cifar-fs/val.json:
--------------------------------------------------------------------------------
1 | [["aquatic_mammals","beaver"],["aquatic_mammals","otter"],["fish","flatfish"],["fish","shark"],["household_electrical_devices","lamp"],["household_electrical_devices","television"],["insects","beetle"],["insects","butterfly"],["large_natural_outdoor_scenes","sea"],["large_omnivores_and_herbivores","camel"],["large_omnivores_and_herbivores","cattle"],["non-insect_invertebrates","crab"],["reptiles","crocodile"],["trees","maple_tree"],["vehicles_1","motorcycle"],["vehicles_2","tractor"]]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cifar100/fc100/test.json:
--------------------------------------------------------------------------------
1 | ["aquatic_mammals","insects","medium_mammals","people"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cifar100/fc100/train.json:
--------------------------------------------------------------------------------
1 | ["fish","flowers","food_containers","fruit_and_vegetables","household_electrical_devices","household_furniture","large_man-made_outdoor_things","large_natural_outdoor_scenes","reptiles","trees","vehicles_1","vehicles_2"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cifar100/fc100/val.json:
--------------------------------------------------------------------------------
1 | ["large_carnivores","large_omnivores_and_herbivores","non-insect_invertebrates","small_mammals"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cub/test.json:
--------------------------------------------------------------------------------
1 | ["004.Groove_billed_Ani","008.Rhinoceros_Auklet","012.Yellow_headed_Blackbird","016.Painted_Bunting","020.Yellow_breasted_Chat","024.Red_faced_Cormorant","028.Brown_Creeper","032.Mangrove_Cuckoo","036.Northern_Flicker","040.Olive_sided_Flycatcher","044.Frigatebird","048.European_Goldfinch","052.Pied_billed_Grebe","056.Pine_Grosbeak","060.Glaucous_winged_Gull","064.Ring_billed_Gull","068.Ruby_throated_Hummingbird","072.Pomarine_Jaeger","076.Dark_eyed_Junco","080.Green_Kingfisher","084.Red_legged_Kittiwake","088.Western_Meadowlark","092.Nighthawk","096.Hooded_Oriole","100.Brown_Pelican","104.American_Pipit","108.White_necked_Raven","112.Great_Grey_Shrike","116.Chipping_Sparrow","120.Fox_Sparrow","124.Le_Conte_Sparrow","128.Seaside_Sparrow","132.White_crowned_Sparrow","136.Barn_Swallow","140.Summer_Tanager","144.Common_Tern","148.Green_tailed_Towhee","152.Blue_headed_Vireo","156.White_eyed_Vireo","160.Black_throated_Blue_Warbler","164.Cerulean_Warbler","168.Kentucky_Warbler","172.Nashville_Warbler","176.Prairie_Warbler","180.Wilson_Warbler","184.Louisiana_Waterthrush","188.Pileated_Woodpecker","192.Downy_Woodpecker","196.House_Wren","200.Common_Yellowthroat"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cub/train.json:
--------------------------------------------------------------------------------
1 | ["001.Black_footed_Albatross","003.Sooty_Albatross","005.Crested_Auklet","007.Parakeet_Auklet","009.Brewer_Blackbird","011.Rusty_Blackbird","013.Bobolink","015.Lazuli_Bunting","017.Cardinal","019.Gray_Catbird","021.Eastern_Towhee","023.Brandt_Cormorant","025.Pelagic_Cormorant","027.Shiny_Cowbird","029.American_Crow","031.Black_billed_Cuckoo","033.Yellow_billed_Cuckoo","035.Purple_Finch","037.Acadian_Flycatcher","039.Least_Flycatcher","041.Scissor_tailed_Flycatcher","043.Yellow_bellied_Flycatcher","045.Northern_Fulmar","047.American_Goldfinch","049.Boat_tailed_Grackle","051.Horned_Grebe","053.Western_Grebe","055.Evening_Grosbeak","057.Rose_breasted_Grosbeak","059.California_Gull","061.Heermann_Gull","063.Ivory_Gull","065.Slaty_backed_Gull","067.Anna_Hummingbird","069.Rufous_Hummingbird","071.Long_tailed_Jaeger","073.Blue_Jay","075.Green_Jay","077.Tropical_Kingbird","079.Belted_Kingfisher","081.Pied_Kingfisher","083.White_breasted_Kingfisher","085.Horned_Lark","087.Mallard","089.Hooded_Merganser","091.Mockingbird","093.Clark_Nutcracker","095.Baltimore_Oriole","097.Orchard_Oriole","099.Ovenbird","101.White_Pelican","103.Sayornis","105.Whip_poor_Will","107.Common_Raven","109.American_Redstart","111.Loggerhead_Shrike","113.Baird_Sparrow","115.Brewer_Sparrow","117.Clay_colored_Sparrow","119.Field_Sparrow","121.Grasshopper_Sparrow","123.Henslow_Sparrow","125.Lincoln_Sparrow","127.Savannah_Sparrow","129.Song_Sparrow","131.Vesper_Sparrow","133.White_throated_Sparrow","135.Bank_Swallow","137.Cliff_Swallow","139.Scarlet_Tanager","141.Artic_Tern","143.Caspian_Tern","145.Elegant_Tern","147.Least_Tern","149.Brown_Thrasher","151.Black_capped_Vireo","153.Philadelphia_Vireo","155.Warbling_Vireo","157.Yellow_throated_Vireo","159.Black_and_white_Warbler","161.Blue_winged_Warbler","163.Cape_May_Warbler","165.Chestnut_sided_Warbler","167.Hooded_Warbler","169.Magnolia_Warbler","171.Myrtle_Warbler","173.Orange_crowned_Warbler","175.Pine_Warbler","177.Prothonotary_Warbler","179.Tennessee_Warbler","181.Worm_eating_Warbler","183.Northern_Waterthrush","185.Bohemian_Waxwing","187.American_Three_toed_Woodpecker","189.Red_bellied_Woodpecker","191.Red_headed_Woodpecker","193.Bewick_Wren","195.Carolina_Wren","197.Marsh_Wren","199.Winter_Wren"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/cub/val.json:
--------------------------------------------------------------------------------
1 | ["002.Laysan_Albatross","006.Least_Auklet","010.Red_winged_Blackbird","014.Indigo_Bunting","018.Spotted_Catbird","022.Chuck_will_Widow","026.Bronzed_Cowbird","030.Fish_Crow","034.Gray_crowned_Rosy_Finch","038.Great_Crested_Flycatcher","042.Vermilion_Flycatcher","046.Gadwall","050.Eared_Grebe","054.Blue_Grosbeak","058.Pigeon_Guillemot","062.Herring_Gull","066.Western_Gull","070.Green_Violetear","074.Florida_Jay","078.Gray_Kingbird","082.Ringed_Kingfisher","086.Pacific_Loon","090.Red_breasted_Merganser","094.White_breasted_Nuthatch","098.Scott_Oriole","102.Western_Wood_Pewee","106.Horned_Puffin","110.Geococcyx","114.Black_throated_Sparrow","118.House_Sparrow","122.Harris_Sparrow","126.Nelson_Sharp_tailed_Sparrow","130.Tree_Sparrow","134.Cape_Glossy_Starling","138.Tree_Swallow","142.Black_Tern","146.Forsters_Tern","150.Sage_Thrasher","154.Red_eyed_Vireo","158.Bay_breasted_Warbler","162.Canada_Warbler","166.Golden_winged_Warbler","170.Mourning_Warbler","174.Palm_Warbler","178.Swainson_Warbler","182.Yellow_Warbler","186.Cedar_Waxwing","190.Red_cockaded_Woodpecker","194.Cactus_Wren","198.Rock_Wren"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/doublemnist/test.json:
--------------------------------------------------------------------------------
1 | ["02", "17", "25", "32", "36", "46", "47", "49", "55", "57", "66", "67", "68", "73", "78", "80", "83", "86", "92", "96"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/doublemnist/train.json:
--------------------------------------------------------------------------------
1 | ["00", "01", "04", "05", "06", "08", "09", "11", "12", "13", "14", "15", "16", "18", "19", "20", "21", "23", "24", "26", "28", "29", "30", "31", "33", "35", "37", "38", "41", "42", "43", "44", "45", "50", "51", "53", "54", "56", "59", "60", "62", "63", "65", "69", "70", "72", "74", "75", "76", "77", "79", "81", "82", "84", "85", "87", "88", "89", "90", "91", "94", "95", "97", "98"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/doublemnist/val.json:
--------------------------------------------------------------------------------
1 | ["03", "07", "10", "22", "27", "34", "39", "40", "48", "52", "58", "61", "64", "71", "93", "99"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/omniglot/test.json:
--------------------------------------------------------------------------------
1 | {"background":{},"evaluation":{"Gurmukhi":["character42","character43","character44","character45"],"Kannada":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41"],"Keble":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26"],"Malayalam":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45","character46","character47"],"Manipuri":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40"],"Mongolian":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30"],"Old_Church_Slavonic_(Cyrillic)":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45"],"Oriya":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10
","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45","character46"],"Sylheti":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28"],"Syriac_(Serto)":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23"],"Tengwar":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25"],"Tibetan":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42"],"ULOG":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26"]}}
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/omniglot/val.json:
--------------------------------------------------------------------------------
1 | {"background":{"Armenian":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41"],"Bengali":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41","character42","character43","character44","character45","character46"],"Early_Aramaic":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22"],"Hebrew":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22"],"Mkhedruli_(Georgian)":["character01","character02","character03","character04","character05","character06","character07","character08","character09","character10","character11","character12","character13","character14","character15","character16","character17","character18","character19","character20","character21","character22","character23","character24","character25","character26","character27","character28","character29","character30","character31","character32","character33","character34","character35","character36","character37","character38","character39","character40","character41"]},"evaluation":{}}
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/tcga/cancers.json:
--------------------------------------------------------------------------------
1 | ["ACC","BLCA","BRCA","CESC","CHOL","COAD","DLBC","ESCA","FPPP","GBM","HNSC","KICH","KIRP","LAML","LGG","LIHC","LUAD","LUSC","MESO","OV","PAAD","PCPG","PRAD","READ","SARC","SKCM","STAD","TGCT","THCA","THYM","UCEC","UCS","UVM"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/tcga/task_variables.json:
--------------------------------------------------------------------------------
1 | ["gender","tumor_tissue_site","icd_10","icd_o_3_site","_EVENT","histological_type","_RFS","anatomic_neoplasm_subdivision","oct_embedded","_PANCAN_CNA_PANCAN_K8","_PANCAN_DNAMethyl_PANCAN","_PANCAN_Cluster_Cluster_PANCAN","_PANCAN_miRNA_PANCAN","_PANCAN_mutation_PANCAN","_PANCAN_UNC_RNAseq_PANCAN_K16","_PANCAN_RPPA_PANCAN_K8","height","clinical_stage","menopause_status","clinical_M","year_of_tobacco_smoking_onset","clinical_T","lymphatic_invasion","venous_invasion","synchronous_colon_cancer_present","additional_treatment_completion_success_outcome","breast_carcinoma_estrogen_receptor_status","breast_carcinoma_progesterone_receptor_status","clinical_N","white_cell_count_result","alcohol_history_documented","PAM50Call_RNAseq","mental_status_changes","lymphovascular_invasion_present","family_history_of_cancer","_PANCAN_DNAMethyl_BRCA","Node_nature2012","Tumor_nature2012","Metastasis_nature2012","ER_Status_nature2012","PR_Status_nature2012","HER2_Final_Status_nature2012","asthma_history","family_history_of_primary_brain_tumor","_PANCAN_mirna_BRCA","colon_polyps_present","animal_insect_allergy_history","Expression_Subtype","KRAS","Pathology_Updated","gleason_score","extrathyroid_carcinoma_present_extension_status","GeneExp_Subtype","AWG_cancer_type_Oct62011","hypermutation","hypertension","pregnancies","diabetes","biochemical_recurrence","_PANCAN_mirna_OV","melanoma_origin_skin_anatomic_site","barretts_esophagus","diagnosis_subtype","_PANCAN_mirna_KIRC","birth_control_pill_history_usage_category","albumin_result_specified_value","amount_of_alcohol_consumption_per_day","albumin_result_lower_limit","albumin_result_upper_limit","family_history_of_stomach_cancer","_PANCAN_DNAMethyl_GBM","_PANCAN_DNAMethyl_UCEC","_PANCAN_mirna_UCEC","melanoma_clark_level_value","_PANCAN_mirna_HNSC","fibrosis_ishak_score","adjacent_hepatic_tissue_inflammation_extent_type","_PANCAN_DNAMethyl_HNSC","antireflux_treatment","ldh1_mutation_found","_PANCAN_DNAMethyl_COADREAD","_PANCAN_mirna_LUAD","leukemia_french_american_british_morphology_code","acute_myeloid_leukemia_calgb_cytogenetics_risk_category","_PANCAN_DNAMethyl_LAML","_PANCAN_mirna_COAD","_PANCAN_mirna_LAML","metastatic_diagnosis","_PANCAN_DNAMethyl_LUAD","_PANCAN_DNAMethyl_LUSC","_PANCAN_mirna_LUSC","cancer_diagnosis_cancer_type_icd9_text_name","_PANCAN_DNAMethyl_BLCA","_PANCAN_mirna_BLCA","leiomyosarcoma_histologic_subtype","family_history_other_cancer","necrosis","weiss_venous_invasion","weiss_score","_PANCAN_mirna_READ","atypical_mitotic_figures","relative_cancer_type","metastatic_breast_carcinoma_estrogen_receptor_status","metastatic_breast_carcinoma_progesterone_receptor_status","well_differentiated_liposarcoma_primary_dx","food_allergy_types","metastatic_breast_carcinom_lb_prc_hr2_n_mmnhstchmstry_rcptr_stts"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/triplemnist/test.json:
--------------------------------------------------------------------------------
1 | ["002", "003", "008", "014", "016", "017", "039", "046", "047", "051", "056", "058", "060", "065", "067", "068", "073", "076", "077", "086", "087", "088", "096", "098", "099", "106", "109", "111", "113", "123", "129", "135", "139", "140", "141", "146", "154", "158", "176", "180", "186", "193", "198", "206", "208", "213", "214", "222", "224", "244", "253", "255", "257", "259", "262", "265", "271", "290", "296", "301", "304", "305", "315", "319", "321", "322", "323", "339", "340", "342", "354", "357", "358", "359", "360", "364", "365", "371", "377", "380", "382", "385", "390", "394", "401", "409", "411", "412", "418", "419", "420", "424", "428", "433", "434", "438", "440", "451", "460", "465", "467", "471", "472", "473", "480", "481", "484", "486", "489", "490", "493", "496", "497", "504", "510", "513", "515", "530", "537", "539", "544", "552", "555", "559", "569", "575", "576", "582", "588", "595", "596", "608", "615", "628", "630", "631", "635", "638", "645", "647", "661", "665", "680", "681", "695", "717", "720", "737", "742", "745", "755", "759", "762", "763", "765", "766", "768", "780", "781", "782", "790", "806", "812", "815", "817", "819", "823", "824", "826", "836", "837", "842", "843", "846", "860", "862", "865", "866", "867", "873", "878", "884", "885", "895", "899", "907", "919", "929", "942", "945", "951", "954", "956", "957", "964", "976", "988", "990", "996", "998"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/triplemnist/train.json:
--------------------------------------------------------------------------------
1 | ["000", "004", "005", "006", "007", "009", "011", "013", "015", "018", "020", "021", "022", "023", "024", "026", "027", "028", "030", "031", "032", "033", "034", "035", "036", "037", "038", "040", "041", "042", "043", "045", "048", "049", "050", "052", "054", "055", "057", "059", "061", "062", "063", "064", "066", "070", "071", "072", "074", "075", "078", "079", "080", "082", "084", "085", "089", "090", "091", "093", "094", "095", "097", "100", "101", "102", "103", "104", "105", "107", "110", "114", "115", "116", "117", "118", "119", "120", "121", "122", "124", "125", "127", "128", "131", "132", "134", "138", "142", "145", "147", "148", "150", "151", "153", "155", "156", "157", "159", "160", "161", "162", "163", "164", "165", "166", "167", "169", "170", "171", "172", "174", "175", "177", "178", "179", "181", "182", "183", "184", "185", "188", "189", "190", "191", "192", "195", "196", "199", "200", "201", "202", "203", "204", "205", "209", "210", "211", "212", "216", "217", "219", "221", "223", "226", "227", "228", "229", "230", "231", "232", "233", "234", "235", "236", "237", "239", "240", "241", "242", "243", "245", "246", "248", "249", "250", "252", "254", "256", "258", "260", "261", "263", "264", "266", "267", "268", "269", "270", "272", "273", "274", "275", "276", "277", "278", "279", "280", "282", "283", "284", "285", "287", "288", "289", "291", "292", "294", "295", "297", "298", "299", "300", "302", "303", "306", "307", "308", "309", "310", "311", "313", "314", "316", "317", "318", "320", "326", "327", "328", "329", "331", "332", "333", "335", "336", "338", "343", "344", "345", "346", "347", "348", "349", "351", "352", "353", "356", "362", "363", "366", "367", "368", "369", "370", "372", "373", "374", "375", "376", "378", "379", "381", "384", "386", "387", "388", "389", "391", "392", "395", "396", "397", "398", "399", "402", "403", "404", "406", "408", "414", "415", "416", "417", "421", "425", "426", "427", "429", "430", "431", "435", "437", "443", "444", "445", "446", "447", "448", "449", "450", "452", "453", "455", "456", "457", "458", "461", "462", "463", "468", "469", "470", "474", "476", "478", "479", "482", "483", "485", "487", "491", "492", "498", "499", "500", "505", "508", "509", "511", "512", "514", "517", "518", "520", "521", "522", "523", "524", "525", "526", "527", "528", "529", "531", "532", "533", "534", "536", "538", "541", "542", "543", "545", "546", "547", "548", "549", "551", "553", "554", "556", "558", "560", "561", "563", "564", "565", "566", "567", "568", "571", "577", "578", "580", "581", "584", "585", "586", "587", "590", "591", "594", "597", "598", "599", "600", "603", "605", "606", "607", "610", "613", "614", "616", "617", "618", "619", "620", "621", "623", "624", "625", "626", "632", "633", "637", "640", "641", "642", "643", "644", "646", "648", "650", "652", "653", "654", "655", "656", "657", "659", "660", "662", "663", "664", "666", "667", "668", "669", "670", "671", "672", "674", "675", "676", "677", "678", "682", "683", "684", "685", "687", "688", "689", "690", "691", "692", "693", "694", "696", "698", "699", "700", "701", "702", "704", "705", "706", "708", "709", "710", "711", "712", "713", "714", "716", "719", "721", "722", "723", "724", "725", "727", "728", "729", "730", "732", "733", "734", "735", "736", "738", "739", "743", "744", "747", "748", "749", "751", "752", "753", "754", "756", "757", "758", "760", "761", "764", "767", "769", "770", "771", "773", "774", "776", "777", "778", "783", "784", "785", "786", "787", "788", "791", "792", "793", 
"795", "797", "798", "799", "800", "801", "802", "803", "804", "805", "807", "808", "809", "810", "811", "813", "814", "818", "820", "821", "822", "825", "827", "828", "830", "831", "832", "833", "834", "835", "838", "839", "840", "844", "845", "848", "849", "850", "853", "854", "856", "857", "858", "859", "861", "863", "864", "868", "869", "870", "871", "872", "874", "875", "876", "877", "881", "882", "883", "887", "891", "892", "893", "894", "896", "897", "898", "901", "902", "903", "905", "908", "909", "913", "914", "916", "917", "918", "920", "921", "922", "923", "924", "925", "927", "928", "930", "931", "932", "933", "934", "935", "936", "937", "938", "939", "941", "943", "946", "947", "948", "949", "950", "952", "953", "958", "959", "960", "965", "966", "967", "968", "969", "970", "971", "972", "973", "974", "977", "978", "979", "982", "983", "984", "985", "986", "987", "989", "991", "993", "994", "995", "997"]
--------------------------------------------------------------------------------
/torchmeta/datasets/assets/triplemnist/val.json:
--------------------------------------------------------------------------------
1 | ["001", "010", "012", "019", "025", "029", "044", "053", "069", "081", "083", "092", "108", "112", "126", "130", "133", "136", "137", "143", "144", "149", "152", "168", "173", "187", "194", "197", "207", "215", "218", "220", "225", "238", "247", "251", "281", "286", "293", "312", "324", "325", "330", "334", "337", "341", "350", "355", "361", "383", "393", "400", "405", "407", "410", "413", "422", "423", "432", "436", "439", "441", "442", "454", "459", "464", "466", "475", "477", "488", "494", "495", "501", "502", "503", "506", "507", "516", "519", "535", "540", "550", "557", "562", "570", "572", "573", "574", "579", "583", "589", "592", "593", "601", "602", "604", "609", "611", "612", "622", "627", "629", "634", "636", "639", "649", "651", "658", "673", "679", "686", "697", "703", "707", "715", "718", "726", "731", "740", "741", "746", "750", "772", "775", "779", "789", "794", "796", "816", "829", "841", "847", "851", "852", "855", "879", "880", "886", "888", "889", "890", "900", "904", "906", "910", "911", "912", "915", "926", "940", "944", "955", "961", "962", "963", "975", "980", "981", "992", "999"]
--------------------------------------------------------------------------------
/torchmeta/datasets/cifar100/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta.datasets.cifar100.cifar_fs import CIFARFS
2 | from torchmeta.datasets.cifar100.fc100 import FC100
3 |
4 | __all__ = ['CIFARFS', 'FC100']
5 |
--------------------------------------------------------------------------------
/torchmeta/datasets/cifar100/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import json
4 | import h5py
5 | from PIL import Image
6 |
7 | from torchvision.datasets.utils import check_integrity, download_url
8 | from torchmeta.utils.data import Dataset, ClassDataset
9 |
10 |
11 | class CIFAR100ClassDataset(ClassDataset):
12 | folder = 'cifar100'
13 | subfolder = None
14 | download_url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
15 | gz_folder = 'cifar-100-python'
16 | gz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
17 | files_md5 = {
18 | 'train': '16019d7e3df5f24257cddd939b257f8d',
19 | 'test': 'f0ef6b0ae62326f3e7ffdfab6717acfc',
20 | 'meta': '7973b15100ade9c7d40fb424638fde48'
21 | }
22 |
23 | filename = 'data.hdf5'
24 | filename_labels = '{0}_labels.json'
25 | filename_fine_names = 'fine_names.json'
26 |
27 | def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
28 | meta_split=None, transform=None, class_augmentations=None,
29 | download=False):
30 | super(CIFAR100ClassDataset, self).__init__(meta_train=meta_train,
31 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
32 | class_augmentations=class_augmentations)
33 |
34 | if self.subfolder is None:
35 | raise ValueError('`subfolder` must be set by a subclass (e.g. CIFARFSClassDataset or FC100ClassDataset).')
36 |
37 | self.root = os.path.join(os.path.expanduser(root), self.folder)
38 | self.transform = transform
39 |
40 | self.split_filename_labels = os.path.join(self.root, self.subfolder,
41 | self.filename_labels.format(self.meta_split))
42 | self._data = None
43 | self._labels = None
44 |
45 | if download:
46 | self.download()
47 |
48 | if not self._check_integrity():
49 | raise RuntimeError('CIFAR100 integrity check failed')
50 | self._num_classes = len(self.labels)
51 |
52 | def __getitem__(self, index):
53 | coarse_label_name, fine_label_name = self.labels[index % self.num_classes]
54 | data = self.data['{0}/{1}'.format(coarse_label_name, fine_label_name)]
55 | transform = self.get_transform(index, self.transform)
56 | target_transform = self.get_target_transform(index)
57 |
58 | return CIFAR100Dataset(index, data, coarse_label_name, fine_label_name,
59 | transform=transform, target_transform=target_transform)
60 |
61 | @property
62 | def num_classes(self):
63 | return self._num_classes
64 |
65 | @property
66 | def data(self):
67 | if self._data is None:
68 | self._data = h5py.File(os.path.join(self.root, self.filename), 'r')
69 | return self._data
70 |
71 | @property
72 | def labels(self):
73 | if self._labels is None:
74 | with open(self.split_filename_labels, 'r') as f:
75 | self._labels = json.load(f)
76 | return self._labels
77 |
78 | def _check_integrity(self):
79 | return (self._check_integrity_data()
80 | and os.path.isfile(self.split_filename_labels)
81 | and os.path.isfile(os.path.join(self.root, self.filename_fine_names)))
82 |
83 | def _check_integrity_data(self):
84 | return os.path.isfile(os.path.join(self.root, self.filename))
85 |
86 | def close(self):
87 | if self._data is not None:
88 | self._data.close()
89 | self._data = None
90 |
91 | def download(self):
92 | import tarfile
93 | import pickle
94 | import shutil
95 |
96 | if self._check_integrity_data():
97 | return
98 |
99 | gz_filename = '{0}.tar.gz'.format(self.gz_folder)
100 | download_url(self.download_url, self.root, filename=gz_filename,
101 | md5=self.gz_md5)
102 | with tarfile.open(os.path.join(self.root, gz_filename), 'r:gz') as tar:
103 | tar.extractall(path=self.root)
104 |
105 | train_filename = os.path.join(self.root, self.gz_folder, 'train')
106 | check_integrity(train_filename, self.files_md5['train'])
107 | with open(train_filename, 'rb') as f:
108 | data = pickle.load(f, encoding='bytes')
109 | images = data[b'data']
110 | fine_labels = data[b'fine_labels']
111 | coarse_labels = data[b'coarse_labels']
112 |
113 | test_filename = os.path.join(self.root, self.gz_folder, 'test')
114 | check_integrity(test_filename, self.files_md5['test'])
115 | with open(test_filename, 'rb') as f:
116 | data = pickle.load(f, encoding='bytes')
117 | images = np.concatenate((images, data[b'data']), axis=0)
118 | fine_labels = np.concatenate((fine_labels, data[b'fine_labels']), axis=0)
119 | coarse_labels = np.concatenate((coarse_labels, data[b'coarse_labels']), axis=0)
120 |
121 | images = images.reshape((-1, 3, 32, 32))
122 | images = images.transpose((0, 2, 3, 1))
123 |
124 | meta_filename = os.path.join(self.root, self.gz_folder, 'meta')
125 | check_integrity(meta_filename, self.files_md5['meta'])
126 | with open(meta_filename, 'rb') as f:
127 | data = pickle.load(f, encoding='latin1')
128 | fine_label_names = data['fine_label_names']
129 | coarse_label_names = data['coarse_label_names']
130 |
131 | filename = os.path.join(self.root, self.filename)
132 | fine_names = dict()
133 | with h5py.File(filename, 'w') as f:
134 | for i, coarse_name in enumerate(coarse_label_names):
135 | group = f.create_group(coarse_name)
136 | fine_indices = np.unique(fine_labels[coarse_labels == i])
137 | for j in fine_indices:
138 | dataset = group.create_dataset(fine_label_names[j],
139 | data=images[fine_labels == j])
140 | fine_names[coarse_name] = [fine_label_names[j] for j in fine_indices]
141 |
142 | filename_fine_names = os.path.join(self.root, self.filename_fine_names)
143 | with open(filename_fine_names, 'w') as f:
144 | json.dump(fine_names, f)
145 |
146 | gz_folder = os.path.join(self.root, self.gz_folder)
147 | if os.path.isdir(gz_folder):
148 | shutil.rmtree(gz_folder)
149 | if os.path.isfile('{0}.tar.gz'.format(gz_folder)):
150 | os.remove('{0}.tar.gz'.format(gz_folder))
151 |
152 |
153 | class CIFAR100Dataset(Dataset):
154 | def __init__(self, index, data, coarse_label_name, fine_label_name,
155 | transform=None, target_transform=None):
156 | super(CIFAR100Dataset, self).__init__(index, transform=transform,
157 | target_transform=target_transform)
158 | self.data = data
159 | self.coarse_label_name = coarse_label_name
160 | self.fine_label_name = fine_label_name
161 |
162 | def __len__(self):
163 | return self.data.shape[0]
164 |
165 | def __getitem__(self, index):
166 | image = Image.fromarray(self.data[index])
167 | target = (self.coarse_label_name, self.fine_label_name)
168 |
169 | if self.transform is not None:
170 | image = self.transform(image)
171 |
172 | if self.target_transform is not None:
173 | target = self.target_transform(target)
174 |
175 | return (image, target)
176 |
--------------------------------------------------------------------------------
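`CIFAR100ClassDataset.download()` above converts the two CIFAR-100 pickle files into a single `data.hdf5`, grouped as `coarse_label/fine_label`, plus a `fine_names.json` index mapping each superclass to its fine classes. A minimal sketch of inspecting that layout (assuming the dataset root is the current directory and `download()` has already run):

import h5py

# Walk the HDF5 file written by CIFAR100ClassDataset.download(): one group
# per coarse label, one dataset per fine label, each holding all 600 images
# of that class (train and test combined) as a (600, 32, 32, 3) uint8 array.
with h5py.File('cifar100/data.hdf5', 'r') as f:
    for coarse_name, group in f.items():
        for fine_name, images in group.items():
            print(coarse_name, fine_name, images.shape)
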
/torchmeta/datasets/cifar100/cifar_fs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | from torchmeta.datasets.cifar100.base import CIFAR100ClassDataset
5 | from torchmeta.datasets.utils import get_asset
6 | from torchmeta.utils.data import ClassDataset, CombinationMetaDataset
7 |
8 |
9 | class CIFARFS(CombinationMetaDataset):
10 | """
11 | The CIFAR-FS dataset, introduced in [1]. This dataset contains
12 | images of 100 different classes from the CIFAR100 dataset [2].
13 |
14 | Parameters
15 | ----------
16 | root : string
17 | Root directory where the dataset folder `cifar100` exists.
18 |
19 | num_classes_per_task : int
20 | Number of classes per task. This corresponds to `N` in `N-way`
21 | classification.
22 |
23 | meta_train : bool (default: `False`)
24 | Use the meta-train split of the dataset. If set to `True`, then the
25 | arguments `meta_val` and `meta_test` must be set to `False`. Exactly one
26 | of these three arguments must be set to `True`.
27 |
28 | meta_val : bool (default: `False`)
29 | Use the meta-validation split of the dataset. If set to `True`, then the
30 | arguments `meta_train` and `meta_test` must be set to `False`. Exactly one
31 | of these three arguments must be set to `True`.
32 |
33 | meta_test : bool (default: `False`)
34 | Use the meta-test split of the dataset. If set to `True`, then the
35 | arguments `meta_train` and `meta_val` must be set to `False`. Exactly one
36 | of these three arguments must be set to `True`.
37 |
38 | meta_split : string in {'train', 'val', 'test'}, optional
39 | Name of the split to use. This overrides the arguments `meta_train`,
40 | `meta_val` and `meta_test` if all three are set to `False`.
41 |
42 | transform : callable, optional
43 | A function/transform that takes a `PIL` image, and returns a transformed
44 | version. See also `torchvision.transforms`.
45 |
46 | target_transform : callable, optional
47 | A function/transform that takes a target, and returns a transformed
48 | version. See also `torchvision.transforms`.
49 |
50 | dataset_transform : callable, optional
51 | A function/transform that takes a dataset (i.e. a task), and returns a
52 | transformed version of it. E.g. `transforms.ClassSplitter()`.
53 |
54 | class_augmentations : list of callable, optional
55 | A list of functions that augment the dataset with new classes. These classes
56 | are transformations of existing classes. E.g. `transforms.HorizontalFlip()`.
57 |
58 | download : bool (default: `False`)
59 | If `True`, downloads the pickle files and processes the dataset in the root
60 | directory (under the `cifar100` folder). If the dataset is already
61 | available, this does not download/process the dataset again.
62 |
63 | Notes
64 | -----
65 | The meta train/validation/test splits are over 64/16/20 classes from the
66 | CIFAR100 dataset.
67 |
68 | References
69 | ----------
70 | .. [1] Bertinetto L., Henriques J. F., Torr P. H.S., Vedaldi A. (2019).
71 | Meta-learning with differentiable closed-form solvers. In International
72 | Conference on Learning Representations (https://arxiv.org/abs/1805.08136)
73 |
74 | .. [2] Krizhevsky A. (2009). Learning Multiple Layers of Features from Tiny
75 | Images. (https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
76 | """
77 | def __init__(self, root, num_classes_per_task=None, meta_train=False,
78 | meta_val=False, meta_test=False, meta_split=None,
79 | transform=None, target_transform=None, dataset_transform=None,
80 | class_augmentations=None, download=False):
81 | dataset = CIFARFSClassDataset(root, meta_train=meta_train,
82 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
83 | transform=transform, class_augmentations=class_augmentations,
84 | download=download)
85 | super(CIFARFS, self).__init__(dataset, num_classes_per_task,
86 | target_transform=target_transform, dataset_transform=dataset_transform)
87 |
88 |
89 | class CIFARFSClassDataset(CIFAR100ClassDataset):
90 | subfolder = 'cifar-fs'
91 |
92 | def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
93 | meta_split=None, transform=None, class_augmentations=None,
94 | download=False):
95 | super(CIFARFSClassDataset, self).__init__(root, meta_train=meta_train,
96 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
97 | transform=transform, class_augmentations=class_augmentations,
98 | download=download)
99 |
100 | def download(self):
101 | if self._check_integrity():
102 | return
103 | super(CIFARFSClassDataset, self).download()
104 |
105 | subfolder = os.path.join(self.root, self.subfolder)
106 | if not os.path.exists(subfolder):
107 | os.makedirs(subfolder)
108 |
109 | for split in ['train', 'val', 'test']:
110 | split_filename_labels = os.path.join(subfolder,
111 | self.filename_labels.format(split))
112 | if os.path.isfile(split_filename_labels):
113 | continue
114 |
115 | data = get_asset(self.folder, self.subfolder,
116 | '{0}.json'.format(split), dtype='json')
117 | with open(split_filename_labels, 'w') as f:
118 | json.dump(data, f)
119 |
--------------------------------------------------------------------------------
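A minimal usage sketch for `CIFARFS`, assuming this vendored copy of torchmeta behaves like upstream (`Categorical`, `ClassSplitter` and `BatchMetaDataLoader` are the helpers exercised by the tests further below):

from torchvision.transforms import ToTensor

from torchmeta.datasets.cifar100 import CIFARFS
from torchmeta.transforms import Categorical, ClassSplitter
from torchmeta.utils.data import BatchMetaDataLoader

# 5-way, 1-shot tasks over the 64 meta-train classes, with 15 query
# images per class.
dataset = CIFARFS('data', num_classes_per_task=5, meta_train=True,
                  transform=ToTensor(),
                  target_transform=Categorical(num_classes=5),
                  dataset_transform=ClassSplitter(shuffle=True,
                      num_train_per_class=1, num_test_per_class=15),
                  download=True)

dataloader = BatchMetaDataLoader(dataset, batch_size=16)
batch = next(iter(dataloader))
train_inputs, train_targets = batch['train']  # (16, 5, 3, 32, 32), (16, 5)
test_inputs, test_targets = batch['test']     # (16, 75, 3, 32, 32), (16, 75)
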
/torchmeta/datasets/cifar100/fc100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | from torchmeta.datasets.cifar100.base import CIFAR100ClassDataset
5 | from torchmeta.datasets.utils import get_asset
6 | from torchmeta.utils.data import ClassDataset, CombinationMetaDataset
7 |
8 |
9 | class FC100(CombinationMetaDataset):
10 | """
11 | The Fewshot-CIFAR100 dataset, introduced in [1]. This dataset contains
12 | images of 100 different classes from the CIFAR100 dataset [2].
13 |
14 | Parameters
15 | ----------
16 | root : string
17 | Root directory where the dataset folder `cifar100` exists.
18 |
19 | num_classes_per_task : int
20 | Number of classes per task. This corresponds to `N` in `N-way`
21 | classification.
22 |
23 | meta_train : bool (default: `False`)
24 | Use the meta-train split of the dataset. If set to `True`, then the
25 | arguments `meta_val` and `meta_test` must be set to `False`. Exactly one
26 | of these three arguments must be set to `True`.
27 |
28 | meta_val : bool (default: `False`)
29 | Use the meta-validation split of the dataset. If set to `True`, then the
30 | arguments `meta_train` and `meta_test` must be set to `False`. Exactly one
31 | of these three arguments must be set to `True`.
32 |
33 | meta_test : bool (default: `False`)
34 | Use the meta-test split of the dataset. If set to `True`, then the
35 | arguments `meta_train` and `meta_val` must be set to `False`. Exactly one
36 | of these three arguments must be set to `True`.
37 |
38 | meta_split : string in {'train', 'val', 'test'}, optional
39 | Name of the split to use. This overrides the arguments `meta_train`,
40 | `meta_val` and `meta_test` if all three are set to `False`.
41 |
42 | transform : callable, optional
43 | A function/transform that takes a `PIL` image, and returns a transformed
44 | version. See also `torchvision.transforms`.
45 |
46 | target_transform : callable, optional
47 | A function/transform that takes a target, and returns a transformed
48 | version. See also `torchvision.transforms`.
49 |
50 | dataset_transform : callable, optional
51 | A function/transform that takes a dataset (i.e. a task), and returns a
52 | transformed version of it. E.g. `transforms.ClassSplitter()`.
53 |
54 | class_augmentations : list of callable, optional
55 | A list of functions that augment the dataset with new classes. These classes
56 | are transformations of existing classes. E.g. `transforms.HorizontalFlip()`.
57 |
58 | download : bool (default: `False`)
59 | If `True`, downloads the pickle files and processes the dataset in the root
60 | directory (under the `cifar100` folder). If the dataset is already
61 | available, this does not download/process the dataset again.
62 |
63 | Notes
64 | -----
65 | The meta train/validation/test splits are over 12/4/4 superclasses from the
66 | CIFAR100 dataset. The meta train/validation/test splits contain 60/20/20
67 | classes.
68 |
69 | References
70 | ----------
71 | .. [1] Oreshkin B. N., Rodriguez P., Lacoste A. (2018). TADAM: Task dependent
72 | adaptive metric for improved few-shot learning. In Advances in Neural
73 | Information Processing Systems (https://arxiv.org/abs/1805.10123)
74 |
75 | .. [2] Krizhevsky A. (2009). Learning Multiple Layers of Features from Tiny
76 | Images. (https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
77 | """
78 | def __init__(self, root, num_classes_per_task=None, meta_train=False,
79 | meta_val=False, meta_test=False, meta_split=None,
80 | transform=None, target_transform=None, dataset_transform=None,
81 | class_augmentations=None, download=False):
82 | dataset = FC100ClassDataset(root, meta_train=meta_train,
83 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
84 | transform=transform, class_augmentations=class_augmentations,
85 | download=download)
86 | super(FC100, self).__init__(dataset, num_classes_per_task,
87 | target_transform=target_transform, dataset_transform=dataset_transform)
88 |
89 |
90 | class FC100ClassDataset(CIFAR100ClassDataset):
91 | subfolder = 'fc100'
92 |
93 | def __init__(self, root, meta_train=False, meta_val=False, meta_test=False,
94 | meta_split=None, transform=None, class_augmentations=None,
95 | download=False):
96 | super(FC100ClassDataset, self).__init__(root, meta_train=meta_train,
97 | meta_val=meta_val, meta_test=meta_test, meta_split=meta_split,
98 | transform=transform, class_augmentations=class_augmentations,
99 | download=download)
100 |
101 | def download(self):
102 | if self._check_integrity():
103 | return
104 | super(FC100ClassDataset, self).download()
105 |
106 | subfolder = os.path.join(self.root, self.subfolder)
107 | if not os.path.exists(subfolder):
108 | os.makedirs(subfolder)
109 |
110 | filename_fine_names = os.path.join(self.root, self.filename_fine_names)
111 | with open(filename_fine_names, 'r') as f:
112 | fine_names = json.load(f)
113 |
114 | for split in ['train', 'val', 'test']:
115 | split_filename_labels = os.path.join(subfolder,
116 | self.filename_labels.format(split))
117 | if os.path.isfile(split_filename_labels):
118 | continue
119 |
120 | data = get_asset(self.folder, self.subfolder,
121 | '{0}.json'.format(split), dtype='json')
122 | with open(split_filename_labels, 'w') as f:
123 | labels = [[coarse_name, fine_name] for coarse_name in data
124 | for fine_name in fine_names[coarse_name]]
125 | json.dump(labels, f)
126 |
--------------------------------------------------------------------------------
/torchmeta/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | def get_asset_path(*args):
5 | basedir = os.path.dirname(__file__)
6 | return os.path.join(basedir, 'assets', *args)
7 |
8 |
9 | def get_asset(*args, dtype=None):
10 | filename = get_asset_path(*args)
11 | if not os.path.isfile(filename):
12 | raise IOError('{} not found'.format(filename))
13 |
14 | if dtype is None:
15 | _, dtype = os.path.splitext(filename)
16 | dtype = dtype[1:]
17 |
18 | if dtype == 'json':
19 | with open(filename, 'r') as f:
20 | data = json.load(f)
21 | else:
22 | raise NotImplementedError()
23 | return data
24 |
--------------------------------------------------------------------------------
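`get_asset` resolves a path under the packaged `assets/` directory and parses it according to `dtype` (only JSON is implemented). For example, the CIFAR-FS split shipped above can be loaded as:

from torchmeta.datasets.utils import get_asset

# Loads torchmeta/datasets/assets/cifar100/cifar-fs/train.json
train_classes = get_asset('cifar100', 'cifar-fs', 'train.json', dtype='json')
print(len(train_classes))  # 64 meta-train classes
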
/torchmeta/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta.modules.batchnorm import MetaBatchNorm1d, MetaBatchNorm2d, MetaBatchNorm3d
2 | from torchmeta.modules.container import MetaSequential
3 | from torchmeta.modules.conv import MetaConv1d, MetaConv2d, MetaConv3d
4 | from torchmeta.modules.linear import MetaLinear, MetaBilinear
5 | from torchmeta.modules.module import MetaModule
6 | from torchmeta.modules.normalization import MetaLayerNorm
7 |
8 | __all__ = [
9 | 'MetaBatchNorm1d', 'MetaBatchNorm2d', 'MetaBatchNorm3d',
10 | 'MetaSequential',
11 | 'MetaConv1d', 'MetaConv2d', 'MetaConv3d',
12 | 'MetaLinear', 'MetaBilinear',
13 | 'MetaModule',
14 | 'MetaLayerNorm'
15 | ]
--------------------------------------------------------------------------------
/torchmeta/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from collections import OrderedDict
5 | from torch.nn.modules.batchnorm import _BatchNorm
6 | from torchmeta.modules.module import MetaModule
7 |
8 | class _MetaBatchNorm(_BatchNorm, MetaModule):
9 | def forward(self, input, params=None):
10 | self._check_input_dim(input)
11 | if params is None:
12 | params = OrderedDict(self.named_parameters())
13 |
14 | # exponential_average_factor is set to self.momentum
15 | # (when it is available) only so that it gets updated
16 | # in the ONNX graph when this node is exported to ONNX.
17 | if self.momentum is None:
18 | exponential_average_factor = 0.0
19 | else:
20 | exponential_average_factor = self.momentum
21 |
22 | if self.training and self.track_running_stats:
23 | if self.num_batches_tracked is not None:
24 | self.num_batches_tracked += 1
25 | if self.momentum is None: # use cumulative moving average
26 | exponential_average_factor = 1.0 / float(self.num_batches_tracked)
27 | else: # use exponential moving average
28 | exponential_average_factor = self.momentum
29 |
30 | weight = params.get('weight', None)
31 | bias = params.get('bias', None)
32 |
33 | return F.batch_norm(
34 | input, self.running_mean, self.running_var, weight, bias,
35 | self.training or not self.track_running_stats,
36 | exponential_average_factor, self.eps)
37 |
38 | class MetaBatchNorm1d(_MetaBatchNorm):
39 | __doc__ = nn.BatchNorm1d.__doc__
40 |
41 | def _check_input_dim(self, input):
42 | if input.dim() != 2 and input.dim() != 3:
43 | raise ValueError('expected 2D or 3D input (got {}D input)'
44 | .format(input.dim()))
45 |
46 | class MetaBatchNorm2d(_MetaBatchNorm):
47 | __doc__ = nn.BatchNorm2d.__doc__
48 |
49 | def _check_input_dim(self, input):
50 | if input.dim() != 4:
51 | raise ValueError('expected 4D input (got {}D input)'
52 | .format(input.dim()))
53 |
54 | class MetaBatchNorm3d(_MetaBatchNorm):
55 | __doc__ = nn.BatchNorm3d.__doc__
56 |
57 | def _check_input_dim(self, input):
58 | if input.dim() != 5:
59 | raise ValueError('expected 5D input (got {}D input)'
60 | .format(input.dim()))
61 |
--------------------------------------------------------------------------------
/torchmeta/modules/container.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from torchmeta.modules.module import MetaModule
4 | from torchmeta.modules.utils import get_subdict
5 |
6 | class MetaSequential(nn.Sequential, MetaModule):
7 | __doc__ = nn.Sequential.__doc__
8 |
9 | def forward(self, input, params=None):
10 | for name, module in self._modules.items():
11 | if isinstance(module, MetaModule):
12 | input = module(input, params=get_subdict(params, name))
13 | elif isinstance(module, nn.Module):
14 | input = module(input)
15 | else:
16 | raise TypeError('The module must be either a torch module '
17 | '(inheriting from `nn.Module`), or a `MetaModule`. '
18 | 'Got type: `{0}`'.format(type(module)))
19 | return input
20 |
--------------------------------------------------------------------------------
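`MetaSequential` passes the full `params` dictionary down and relies on `get_subdict` (defined in `torchmeta/modules/utils.py` below) to route each child's entries by its name prefix. A small sketch of that routing, using `MetaLinear` from a later file in this package:

from collections import OrderedDict

import torch
from torchmeta.modules import MetaLinear, MetaSequential

model = MetaSequential(MetaLinear(3, 4), torch.nn.ReLU(), MetaLinear(4, 2))

# Keys are prefixed by the child's index in the container; get_subdict
# strips '0.' / '2.' so each MetaLinear sees plain 'weight' and 'bias'.
params = OrderedDict(model.meta_named_parameters())
print(list(params))  # ['0.weight', '0.bias', '2.weight', '2.bias']

x = torch.randn(5, 3)
out = model(x, params=params)  # identical to model(x)
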
/torchmeta/modules/conv.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from collections import OrderedDict
5 | from torch.nn.modules.utils import _single, _pair, _triple
6 | from torchmeta.modules.module import MetaModule
7 |
8 | class MetaConv1d(nn.Conv1d, MetaModule):
9 | __doc__ = nn.Conv1d.__doc__
10 |
11 | def forward(self, input, params=None):
12 | if params is None:
13 | params = OrderedDict(self.named_parameters())
14 | bias = params.get('bias', None)
15 |
16 | if self.padding_mode == 'circular':
17 | expanded_padding = ((self.padding[0] + 1) // 2, self.padding[0] // 2)
18 | return F.conv1d(F.pad(input, expanded_padding, mode='circular'),
19 | params['weight'], bias, self.stride,
20 | _single(0), self.dilation, self.groups)
21 |
22 | return F.conv1d(input, params['weight'], bias, self.stride,
23 | self.padding, self.dilation, self.groups)
24 |
25 | class MetaConv2d(nn.Conv2d, MetaModule):
26 | __doc__ = nn.Conv2d.__doc__
27 |
28 | def forward(self, input, params=None):
29 | if params is None:
30 | params = OrderedDict(self.named_parameters())
31 | bias = params.get('bias', None)
32 |
33 | if self.padding_mode == 'circular':
34 | expanded_padding = ((self.padding[1] + 1) // 2, self.padding[1] // 2,
35 | (self.padding[0] + 1) // 2, self.padding[0] // 2)
36 | return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
37 | params['weight'], bias, self.stride,
38 | _pair(0), self.dilation, self.groups)
39 |
40 | return F.conv2d(input, params['weight'], bias, self.stride,
41 | self.padding, self.dilation, self.groups)
42 |
43 | class MetaConv3d(nn.Conv3d, MetaModule):
44 | __doc__ = nn.Conv3d.__doc__
45 |
46 | def forward(self, input, params=None):
47 | if params is None:
48 | params = OrderedDict(self.named_parameters())
49 | bias = params.get('bias', None)
50 |
51 | if self.padding_mode == 'circular':
52 | expanded_padding = ((self.padding[2] + 1) // 2, self.padding[2] // 2,
53 | (self.padding[1] + 1) // 2, self.padding[1] // 2,
54 | (self.padding[0] + 1) // 2, self.padding[0] // 2)
55 | return F.conv3d(F.pad(input, expanded_padding, mode='circular'),
56 | params['weight'], bias, self.stride,
57 | _triple(0), self.dilation, self.groups)
58 |
59 | return F.conv3d(input, params['weight'], bias, self.stride,
60 | self.padding, self.dilation, self.groups)
61 |
--------------------------------------------------------------------------------
/torchmeta/modules/linear.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from collections import OrderedDict
5 | from torchmeta.modules.module import MetaModule
6 |
7 | class MetaLinear(nn.Linear, MetaModule):
8 | __doc__ = nn.Linear.__doc__
9 |
10 | def forward(self, input, params=None):
11 | if params is None:
12 | params = OrderedDict(self.named_parameters())
13 | bias = params.get('bias', None)
14 | return F.linear(input, params['weight'], bias)
15 |
16 | class MetaBilinear(nn.Bilinear, MetaModule):
17 | __doc__ = nn.Bilinear.__doc__
18 |
19 | def forward(self, input1, input2, params=None):
20 | if params is None:
21 | params = OrderedDict(self.named_parameters())
22 | bias = params.get('bias', None)
23 | return F.bilinear(input1, input2, params['weight'], bias)
24 |
--------------------------------------------------------------------------------
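The `params` override in `MetaLinear` (and the other Meta* layers) substitutes arbitrary tensors for the registered parameters for a single call, leaving module state untouched. A minimal sketch:

from collections import OrderedDict

import torch
from torchmeta.modules import MetaLinear

layer = MetaLinear(2, 3)
x = torch.randn(4, 2)

y = layer(x)  # default behaviour, identical to nn.Linear

# Override the parameters for this call only; layer.weight is unchanged.
fast_params = OrderedDict([('weight', torch.zeros(3, 2)),
                           ('bias', torch.zeros(3))])
y_fast = layer(x, params=fast_params)
assert torch.allclose(y_fast, torch.zeros(4, 3))
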
/torchmeta/modules/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from collections import OrderedDict
5 |
6 | class MetaModule(nn.Module):
7 | """
8 | Base class for PyTorch meta-learning modules. These modules accept an
9 | additional argument `params` in their `forward` method.
10 |
11 | Notes
12 | -----
13 | Objects inherited from `MetaModule` are fully compatible with PyTorch
14 | modules from `torch.nn.Module`. The argument `params` is a dictionary of
15 | tensors, with full support of the computation graph (for differentiation).
16 | """
17 | def meta_named_parameters(self, prefix='', recurse=True):
18 | gen = self._named_members(
19 | lambda module: module._parameters.items()
20 | if isinstance(module, MetaModule) else [],
21 | prefix=prefix, recurse=recurse)
22 | for elem in gen:
23 | yield elem
24 |
25 | def meta_parameters(self, recurse=True):
26 | for name, param in self.meta_named_parameters(recurse=recurse):
27 | yield param
28 |
--------------------------------------------------------------------------------
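The docstring's note that `params` keeps "full support of the computation graph" is what enables MAML-style inner loops: gradients taken with `create_graph=True` yield updated tensors that can be passed back through `params`, while staying differentiable with respect to the original parameters. A sketch of one adaptation step, using `MetaLinear` from this package (the learning rate 0.1 is arbitrary):

from collections import OrderedDict

import torch
from torchmeta.modules import MetaLinear

model = MetaLinear(1, 1)
inputs, targets = torch.randn(10, 1), torch.randn(10, 1)

# Inner loop: one SGD step on a functional copy of the parameters.
params = OrderedDict(model.meta_named_parameters())
inner_loss = ((model(inputs, params=params) - targets) ** 2).mean()
grads = torch.autograd.grad(inner_loss, params.values(), create_graph=True)
params = OrderedDict((name, param - 0.1 * grad)
                     for (name, param), grad in zip(params.items(), grads))

# Outer loss through the adapted parameters; backprop reaches the originals.
outer_loss = ((model(inputs, params=params) - targets) ** 2).mean()
outer_loss.backward()
assert model.weight.grad is not None
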
/torchmeta/modules/normalization.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from collections import OrderedDict
5 | from torchmeta.modules.module import MetaModule
6 |
7 | class MetaLayerNorm(nn.LayerNorm, MetaModule):
8 | __doc__ = nn.LayerNorm.__doc__
9 |
10 | def forward(self, input, params=None):
11 | if params is None:
12 | params = OrderedDict(self.named_parameters())
13 | weight = params.get('weight', None)
14 | bias = params.get('bias', None)
15 | return F.layer_norm(
16 | input, self.normalized_shape, weight, bias, self.eps)
17 |
--------------------------------------------------------------------------------
/torchmeta/modules/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from collections import OrderedDict
3 |
4 | def get_subdict(dictionary, key=None):
5 | if dictionary is None:
6 | return None
7 | if (key is None) or (key == ''):
8 | return dictionary
9 | key_re = re.compile(r'^{0}\.(.+)'.format(re.escape(key)))
10 | return OrderedDict((key_re.sub(r'\1', k), value) for (k, value)
11 | in dictionary.items() if key_re.match(k) is not None)
12 |
--------------------------------------------------------------------------------
/torchmeta/tests/__init__.py:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/torchmeta/tests/test_dataloaders.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 | from torchmeta.toy import Sinusoid
7 | from torchmeta.transforms import ClassSplitter
8 | from torchmeta.utils.data import Task, MetaDataLoader, BatchMetaDataLoader
9 |
10 |
11 | def test_meta_dataloader():
12 | dataset = Sinusoid(10, num_tasks=1000, noise_std=None)
13 | meta_dataloader = MetaDataLoader(dataset, batch_size=4)
14 | assert isinstance(meta_dataloader, DataLoader)
15 | assert len(meta_dataloader) == 250 # 1000 / 4
16 |
17 | batch = next(iter(meta_dataloader))
18 | assert isinstance(batch, list)
19 | assert len(batch) == 4
20 |
21 | task = batch[0]
22 | assert isinstance(task, Task)
23 | assert len(task) == 10
24 |
25 |
26 | def test_meta_dataloader_task_loader():
27 | dataset = Sinusoid(10, num_tasks=1000, noise_std=None)
28 | meta_dataloader = MetaDataLoader(dataset, batch_size=4)
29 | batch = next(iter(meta_dataloader))
30 |
31 | dataloader = DataLoader(batch[0], batch_size=5)
32 | inputs, targets = next(iter(dataloader))
33 |
34 | assert len(dataloader) == 2 # 10 / 5
35 | # PyTorch dataloaders convert numpy array to tensors
36 | assert isinstance(inputs, torch.Tensor)
37 | assert isinstance(targets, torch.Tensor)
38 | assert inputs.shape == (5, 1)
39 | assert targets.shape == (5, 1)
40 |
41 |
42 | def test_batch_meta_dataloader():
43 | dataset = Sinusoid(10, num_tasks=1000, noise_std=None)
44 | meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)
45 | assert isinstance(meta_dataloader, DataLoader)
46 | assert len(meta_dataloader) == 250 # 1000 / 4
47 |
48 | inputs, targets = next(iter(meta_dataloader))
49 | assert isinstance(inputs, torch.Tensor)
50 | assert isinstance(targets, torch.Tensor)
51 | assert inputs.shape == (4, 10, 1)
52 | assert targets.shape == (4, 10, 1)
53 |
54 |
55 | def test_batch_meta_dataloader_splitter():
56 | dataset = Sinusoid(20, num_tasks=1000, noise_std=None)
57 | dataset = ClassSplitter(dataset, num_train_per_class=5,
58 | num_test_per_class=15)
59 | meta_dataloader = BatchMetaDataLoader(dataset, batch_size=4)
60 |
61 | batch = next(iter(meta_dataloader))
62 | assert isinstance(batch, dict)
63 | assert 'train' in batch
64 | assert 'test' in batch
65 |
66 | train_inputs, train_targets = batch['train']
67 | test_inputs, test_targets = batch['test']
68 | assert isinstance(train_inputs, torch.Tensor)
69 | assert isinstance(train_targets, torch.Tensor)
70 | assert train_inputs.shape == (4, 5, 1)
71 | assert train_targets.shape == (4, 5, 1)
72 | assert isinstance(test_inputs, torch.Tensor)
73 | assert isinstance(test_targets, torch.Tensor)
74 | assert test_inputs.shape == (4, 15, 1)
75 | assert test_targets.shape == (4, 15, 1)
76 |
--------------------------------------------------------------------------------
/torchmeta/tests/test_prototype.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from torchmeta.utils.prototype import get_num_samples, get_prototypes, prototypical_loss
7 |
8 |
9 | @pytest.mark.parametrize('dtype', [None, torch.float32])
10 | def test_get_num_samples(dtype):
11 | # Numpy
12 | num_classes = 3
13 | targets_np = np.random.randint(0, num_classes, size=(2, 5))
14 |
15 | # PyTorch
16 | targets_th = torch.as_tensor(targets_np)
17 | num_samples_th = get_num_samples(targets_th, num_classes, dtype=dtype)
18 |
19 | num_samples_np = np.zeros((2, num_classes), dtype=np.int_)
20 | for i in range(2):
21 | for j in range(5):
22 | num_samples_np[i, targets_np[i, j]] += 1
23 |
24 | assert num_samples_th.shape == (2, num_classes)
25 | if dtype is not None:
26 | assert num_samples_th.dtype == dtype
27 | np.testing.assert_equal(num_samples_th.numpy(), num_samples_np)
28 |
29 |
30 | def test_get_prototypes():
31 | # Numpy
32 | num_classes = 3
33 | embeddings_np = np.random.rand(2, 5, 7).astype(np.float32)
34 | targets_np = np.random.randint(0, num_classes, size=(2, 5))
35 |
36 | # PyTorch
37 | embeddings_th = torch.as_tensor(embeddings_np)
38 | targets_th = torch.as_tensor(targets_np)
39 | prototypes_th = get_prototypes(embeddings_th, targets_th, num_classes)
40 |
41 | assert prototypes_th.shape == (2, num_classes, 7)
42 | assert prototypes_th.dtype == embeddings_th.dtype
43 |
44 | prototypes_np = np.zeros((2, num_classes, 7), dtype=np.float32)
45 | num_samples_np = np.zeros((2, num_classes), dtype=np.int_)
46 | for i in range(2):
47 | for j in range(5):
48 | num_samples_np[i, targets_np[i, j]] += 1
49 | for k in range(7):
50 | prototypes_np[i, targets_np[i, j], k] += embeddings_np[i, j, k]
51 |
52 | for i in range(2):
53 | for j in range(num_classes):
54 | for k in range(7):
55 | prototypes_np[i, j, k] /= max(num_samples_np[i, j], 1)
56 |
57 | np.testing.assert_allclose(prototypes_th.detach().numpy(), prototypes_np)
58 |
--------------------------------------------------------------------------------
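The reference loops above define prototypes as per-class means of the embeddings, with empty classes left at zero. For comparison, a vectorized sketch of the same computation with `scatter_add_` (not necessarily how `torchmeta.utils.prototype` implements it):

import torch

def class_prototypes(embeddings, targets, num_classes):
    # embeddings: (batch, num_samples, dim), targets: (batch, num_samples)
    batch_size = embeddings.size(0)
    counts = torch.zeros(batch_size, num_classes, dtype=embeddings.dtype)
    counts.scatter_add_(1, targets,
                        torch.ones_like(targets, dtype=embeddings.dtype))
    sums = torch.zeros(batch_size, num_classes, embeddings.size(-1),
                       dtype=embeddings.dtype)
    sums.scatter_add_(1, targets.unsqueeze(-1).expand_as(embeddings), embeddings)
    # Guard against empty classes, matching max(num_samples, 1) in the test.
    return sums / counts.clamp(min=1).unsqueeze(-1)
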
/torchmeta/tests/test_splitters.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 | from torchmeta.transforms.splitters import ClassSplitter
7 | from torchmeta.toy import Sinusoid
8 | from torchmeta.utils.data import Task
9 |
10 | def test_seed_class_splitter():
11 | dataset_transform = ClassSplitter(shuffle=True,
12 | num_train_per_class=5, num_test_per_class=5)
13 | dataset = Sinusoid(10, num_tasks=1000, noise_std=0.1,
14 | dataset_transform=dataset_transform)
15 | dataset.seed(1)
16 |
17 | expected_train_inputs = np.array([-2.03870077, 0.09898378, 3.75388738, 1.08565437, -1.56211897])
18 | expected_train_targets = np.array([-0.1031986 , -1.61885041, 0.91773121, -0.00309463, -1.37650356])
19 |
20 | expected_test_inputs = np.array([ 4.62078213, -2.48340416, 0.32922559, 0.76977846, -3.15504396])
21 | expected_test_targets = np.array([-0.9346262 , 0.73113509, -1.52508997, -0.4698061 , 1.86656819])
22 |
23 | task = dataset[0]
24 | train_dataset, test_dataset = task['train'], task['test']
25 |
26 | assert len(train_dataset) == 5
27 | assert len(test_dataset) == 5
28 |
29 | for i, (train_input, train_target) in enumerate(train_dataset):
30 | assert np.isclose(train_input, expected_train_inputs[i])
31 | assert np.isclose(train_target, expected_train_targets[i])
32 |
33 | for i, (test_input, test_target) in enumerate(test_dataset):
34 | assert np.isclose(test_input, expected_test_inputs[i])
35 | assert np.isclose(test_target, expected_test_targets[i])
36 |
37 | def test_class_splitter_for_fold_overlaps():
38 | class DemoTask(Task):
39 | def __init__(self):
40 | super(DemoTask, self).__init__(index=0, num_classes=None)
41 | self._inputs = np.arange(10)
42 |
43 | def __len__(self):
44 | return len(self._inputs)
45 |
46 | def __getitem__(self, index):
47 | return self._inputs[index]
48 |
49 | splitter = ClassSplitter(shuffle=True, num_train_per_class=5, num_test_per_class=5)
50 | task = DemoTask()
51 |
52 | all_train_samples = list()
53 | all_test_samples = list()
54 |
55 | # split task ten times into train and test
56 | for i in range(10):
57 | tasks_split = splitter(task)
58 | train_task = tasks_split["train"]
59 | test_task = tasks_split["test"]
60 |
61 | train_samples = set([train_task[i] for i in range(len(train_task))])
62 | test_samples = set([test_task[i] for i in range(len(test_task))])
63 |
64 | # no overlap between train and test splits at single split
65 | assert len(train_samples.intersection(test_samples)) == 0
66 |
67 | all_train_samples.append(train_samples)
68 | all_test_samples.append(test_samples)
69 |
70 | # gather unique samples from multiple splits
71 | samples_in_all_train_splits = set().union(*all_train_samples)
72 | samples_in_all_test_splits = set().union(*all_test_samples)
73 |
74 | # no overlap between train and test splits at multiple splits
75 | assert len(samples_in_all_test_splits.intersection(samples_in_all_train_splits)) == 0
--------------------------------------------------------------------------------
/torchmeta/tests/test_toy.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import numpy as np
4 | from collections import OrderedDict
5 |
6 | from torchmeta.utils.data import Task, MetaDataset
7 | from torchmeta.toy import Sinusoid, Harmonic, SinusoidAndLine
8 | from torchmeta.toy import helpers
9 |
10 |
11 | @pytest.mark.parametrize('dataset_class',
12 | [Sinusoid, Harmonic, SinusoidAndLine])
13 | def test_toy_meta_dataset(dataset_class):
14 | dataset = dataset_class(10, num_tasks=1000, noise_std=None)
15 |
16 | assert isinstance(dataset, MetaDataset)
17 | assert len(dataset) == 1000
18 |
19 |
20 | @pytest.mark.parametrize('dataset_class',
21 | [Sinusoid, Harmonic, SinusoidAndLine])
22 | def test_toy_task(dataset_class):
23 | dataset = dataset_class(10, num_tasks=1000, noise_std=None)
24 | task = dataset[0]
25 |
26 | assert isinstance(task, Task)
27 | assert len(task) == 10
28 |
29 |
30 | @pytest.mark.parametrize('dataset_class',
31 | [Sinusoid, Harmonic, SinusoidAndLine])
32 | def test_toy_sample(dataset_class):
33 | dataset = dataset_class(10, num_tasks=1000, noise_std=None)
34 | task = dataset[0]
35 | input, target = task[0]
36 |
37 | assert isinstance(input, np.ndarray)
38 | assert isinstance(target, np.ndarray)
39 | assert input.shape == (1,)
40 | assert target.shape == (1,)
41 |
42 |
43 | @pytest.mark.parametrize('name,dataset_class',
44 | [('sinusoid', Sinusoid), ('harmonic', Harmonic)])
45 | def test_toy_helpers(name, dataset_class):
46 | dataset_fn = getattr(helpers, name)
47 | dataset = dataset_fn(shots=5, test_shots=15)
48 | assert isinstance(dataset, dataset_class)
49 |
50 | task = dataset[0]
51 | assert isinstance(task, OrderedDict)
52 | assert 'train' in task
53 | assert 'test' in task
54 |
55 | train, test = task['train'], task['test']
56 | assert isinstance(train, Task)
57 | assert isinstance(test, Task)
58 | assert len(train) == 5
59 | assert len(test) == 15
60 |
--------------------------------------------------------------------------------
/torchmeta/toy/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta.toy.harmonic import Harmonic
2 | from torchmeta.toy.sinusoid import Sinusoid
3 | from torchmeta.toy.sinusoid_line import SinusoidAndLine
4 |
5 | from torchmeta.toy import helpers
6 |
7 | __all__ = ['Harmonic', 'Sinusoid', 'SinusoidAndLine', 'helpers']
8 |
--------------------------------------------------------------------------------
/torchmeta/toy/harmonic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from torchmeta.utils.data import Task, MetaDataset
4 |
5 |
6 | class Harmonic(MetaDataset):
7 | """
8 | Simple regression task, based on the sum of two sine waves, as introduced
9 | in [1].
10 |
11 | Parameters
12 | ----------
13 | num_samples_per_task : int
14 | Number of examples per task.
15 |
16 | num_tasks : int (default: 5,000)
17 | Overall number of tasks to sample.
18 |
19 | noise_std : float, optional
20 | Amount of noise to include in the targets for each task. If `None`, then
21 | no noise is included, and the target is the sum of two sine functions of
22 | the input.
23 |
24 | transform : callable, optional
25 | A function/transform that takes a numpy array of size (1,) and returns a
26 | transformed version of the input.
27 |
28 | target_transform : callable, optional
29 | A function/transform that takes a numpy array of size (1,) and returns a
30 | transformed version of the target.
31 |
32 | dataset_transform : callable, optional
33 |         A function/transform that takes a dataset (i.e. a task), and returns a
34 | transformed version of it. E.g. `torchmeta.transforms.ClassSplitter()`.
35 |
36 | Notes
37 | -----
38 | The tasks are created randomly as the sum of two sinusoid functions, with
39 |     a frequency ratio of 2. The frequencies vary within [5.0, 7.0], the amplitudes
40 |     follow N(0, 1), the phases vary within [0, 2 * pi], and the inputs follow N(mu_x, 1), with
41 | mu_x varying in [-4.0, 4.0]. Due to the way PyTorch handles datasets, the
42 | number of tasks to be sampled needs to be fixed ahead of time (with
43 | `num_tasks`). This will typically be equal to `meta_batch_size * num_batches`.
44 |
45 | References
46 | ----------
47 | .. [1] Lacoste A., Oreshkin B., Chung W., Boquet T., Rostamzadeh N.,
48 | Krueger D. (2018). Uncertainty in Multitask Transfer Learning. In
49 | Advances in Neural Information Processing Systems (https://arxiv.org/abs/1806.07528)
50 | """
51 | def __init__(self, num_samples_per_task, num_tasks=5000,
52 | noise_std=None, transform=None, target_transform=None,
53 | dataset_transform=None):
54 | super(Harmonic, self).__init__(meta_split='train',
55 | target_transform=target_transform, dataset_transform=dataset_transform)
56 | self.num_samples_per_task = num_samples_per_task
57 | self.num_tasks = num_tasks
58 | self.noise_std = noise_std
59 | self.transform = transform
60 |
61 | self._domain_range = np.array([-4.0, 4.0])
62 | self._frequency_range = np.array([5.0, 7.0])
63 | self._phase_range = np.array([0, 2 * np.pi])
64 |
65 | self._domains = None
66 | self._frequencies = None
67 | self._phases = None
68 | self._amplitudes = None
69 |
70 | @property
71 | def domains(self):
72 | if self._domains is None:
73 | self._domains = self.np_random.uniform(self._domain_range[0],
74 | self._domain_range[1], size=self.num_tasks)
75 | return self._domains
76 |
77 | @property
78 | def frequencies(self):
79 | if self._frequencies is None:
80 | self._frequencies = self.np_random.uniform(self._frequency_range[0],
81 | self._frequency_range[1], size=self.num_tasks)
82 | return self._frequencies
83 |
84 | @property
85 | def phases(self):
86 | if self._phases is None:
87 | self._phases = self.np_random.uniform(self._phase_range[0],
88 | self._phase_range[1], size=(self.num_tasks, 2))
89 | return self._phases
90 |
91 | @property
92 | def amplitudes(self):
93 | if self._amplitudes is None:
94 | self._amplitudes = self.np_random.randn(self.num_tasks, 2)
95 | return self._amplitudes
96 |
97 | def __len__(self):
98 | return self.num_tasks
99 |
100 | def __getitem__(self, index):
101 | domain = self.domains[index]
102 | frequency = self.frequencies[index]
103 | phases = self.phases[index]
104 | amplitudes = self.amplitudes[index]
105 |
106 | task = HarmonicTask(index, domain, frequency, phases, amplitudes,
107 | self.noise_std, self.num_samples_per_task, self.transform,
108 | self.target_transform, np_random=self.np_random)
109 |
110 | if self.dataset_transform is not None:
111 | task = self.dataset_transform(task)
112 |
113 | return task
114 |
115 |
116 | class HarmonicTask(Task):
117 | def __init__(self, index, domain, frequency, phases, amplitudes,
118 | noise_std, num_samples, transform=None,
119 | target_transform=None, np_random=None):
120 | super(HarmonicTask, self).__init__(index, None) # Regression task
121 | self.domain = domain
122 | self.frequency = frequency
123 | self.phases = phases
124 | self.amplitudes = amplitudes
125 | self.noise_std = noise_std
126 | self.num_samples = num_samples
127 |
128 | self.transform = transform
129 | self.target_transform = target_transform
130 |
131 | if np_random is None:
132 | np_random = np.random.RandomState(None)
133 |
134 | a_1, a_2 = self.amplitudes
135 | b_1, b_2 = self.phases
136 |
137 | self._inputs = self.domain + np_random.randn(num_samples, 1)
138 | self._targets = (a_1 * np.sin(frequency * self._inputs + b_1)
139 | + a_2 * np.sin(2 * frequency * self._inputs + b_2))
140 | if (noise_std is not None) and (noise_std > 0.):
141 | self._targets += noise_std * np_random.randn(num_samples, 1)
142 |
143 | def __len__(self):
144 | return self.num_samples
145 |
146 | def __getitem__(self, index):
147 | input, target = self._inputs[index], self._targets[index]
148 |
149 | if self.transform is not None:
150 | input = self.transform(input)
151 |
152 | if self.target_transform is not None:
153 | target = self.target_transform(target)
154 |
155 | return (input, target)
156 |
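A minimal usage sketch for `Harmonic` (the values are illustrative; the shapes follow from `HarmonicTask.__getitem__` above):

    from torchmeta.toy import Harmonic

    # 100 tasks of 16 (input, target) pairs each, with noiseless targets.
    dataset = Harmonic(num_samples_per_task=16, num_tasks=100, noise_std=None)
    task = dataset[0]               # a HarmonicTask
    x, y = task[0]                  # one (input, target) pair
    assert x.shape == (1,) and y.shape == (1,)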
--------------------------------------------------------------------------------
/torchmeta/toy/helpers.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from torchmeta.toy import Sinusoid, Harmonic
4 | from torchmeta.transforms import ClassSplitter
5 |
6 | def sinusoid(shots, shuffle=True, test_shots=None, seed=None, **kwargs):
7 | """Helper function to create a meta-dataset for the Sinusoid toy dataset.
8 |
9 | Parameters
10 | ----------
11 | shots : int
12 | Number of (training) examples in each task. This corresponds to `k` in
13 | `k-shot` classification.
14 |
15 | shuffle : bool (default: `True`)
16 | Shuffle the examples when creating the tasks.
17 |
18 | test_shots : int, optional
19 | Number of test examples in each task. If `None`, then the number of test
20 | examples is equal to the number of training examples in each task.
21 |
22 | seed : int, optional
23 | Random seed to be used in the meta-dataset.
24 |
25 | kwargs
26 | Additional arguments passed to the `Sinusoid` class.
27 |
28 | See also
29 | --------
30 | `torchmeta.toy.Sinusoid` : Meta-dataset for the Sinusoid toy dataset.
31 | """
32 | if 'num_samples_per_task' in kwargs:
33 | warnings.warn('Both arguments `shots` and `num_samples_per_task` were '
34 | 'set in the helper function for the number of samples in each task. '
35 | 'Ignoring the argument `shots`.', stacklevel=2)
36 | if test_shots is not None:
37 |             shots = kwargs.pop('num_samples_per_task') - test_shots  # pop to avoid a duplicate keyword below
38 |             if shots <= 0:
39 |                 raise ValueError('The argument `test_shots` ({0}) is greater '
40 |                     'than the number of samples per task ({1}). Either use the '
41 |                     'argument `shots` instead of `num_samples_per_task`, or '
42 |                     'increase the value of `num_samples_per_task`.'.format(
43 |                     test_shots, shots + test_shots))
44 |         else:
45 |             shots = kwargs.pop('num_samples_per_task') // 2
46 | if test_shots is None:
47 | test_shots = shots
48 |
49 | dataset = Sinusoid(num_samples_per_task=shots + test_shots, **kwargs)
50 | dataset = ClassSplitter(dataset, shuffle=shuffle,
51 | num_train_per_class=shots, num_test_per_class=test_shots)
52 | dataset.seed(seed)
53 |
54 | return dataset
55 |
56 | def harmonic(shots, shuffle=True, test_shots=None, seed=None, **kwargs):
57 | """Helper function to create a meta-dataset for the Harmonic toy dataset.
58 |
59 | Parameters
60 | ----------
61 | shots : int
62 | Number of (training) examples in each task. This corresponds to `k` in
63 | `k-shot` classification.
64 |
65 | shuffle : bool (default: `True`)
66 | Shuffle the examples when creating the tasks.
67 |
68 | test_shots : int, optional
69 | Number of test examples in each task. If `None`, then the number of test
70 | examples is equal to the number of training examples in each task.
71 |
72 | seed : int, optional
73 | Random seed to be used in the meta-dataset.
74 |
75 | kwargs
76 | Additional arguments passed to the `Harmonic` class.
77 |
78 | See also
79 | --------
80 | `torchmeta.toy.Harmonic` : Meta-dataset for the Harmonic toy dataset.
81 | """
82 | if 'num_samples_per_task' in kwargs:
83 | warnings.warn('Both arguments `shots` and `num_samples_per_task` were '
84 | 'set in the helper function for the number of samples in each task. '
85 | 'Ignoring the argument `shots`.', stacklevel=2)
86 | if test_shots is not None:
87 | shots = kwargs['num_samples_per_task'] - test_shots
88 | if shots <= 0:
89 | raise ValueError('The argument `test_shots` ({0}) is greater '
90 | 'than the number of samples per task ({1}). Either use the '
91 | 'argument `shots` instead of `num_samples_per_task`, or '
92 | 'increase the value of `num_samples_per_task`.'.format(
93 | test_shots, kwargs['num_samples_per_task']))
94 | else:
95 | shots = kwargs['num_samples_per_task'] // 2
96 | if test_shots is None:
97 | test_shots = shots
98 |
99 | dataset = Harmonic(num_samples_per_task=shots + test_shots, **kwargs)
100 | dataset = ClassSplitter(dataset, shuffle=shuffle,
101 | num_train_per_class=shots, num_test_per_class=test_shots)
102 | dataset.seed(seed)
103 |
104 | return dataset
105 |
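A sketch of how these helpers combine with `BatchMetaDataLoader` (shot counts and batch size are illustrative):

    from torchmeta.toy import helpers
    from torchmeta.utils.data import BatchMetaDataLoader

    # 5 support and 15 query examples per task, 16 tasks per meta-batch.
    dataset = helpers.sinusoid(shots=5, test_shots=15, num_tasks=1000)
    dataloader = BatchMetaDataLoader(dataset, batch_size=16)

    batch = next(iter(dataloader))
    train_inputs, train_targets = batch['train']   # shape (16, 5, 1) each
    test_inputs, test_targets = batch['test']      # shape (16, 15, 1) each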
--------------------------------------------------------------------------------
/torchmeta/toy/sinusoid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from torchmeta.utils.data import Task, MetaDataset
4 |
5 |
6 | class Sinusoid(MetaDataset):
7 | """
8 | Simple regression task, based on sinusoids, as introduced in [1].
9 |
10 | Parameters
11 | ----------
12 | num_samples_per_task : int
13 | Number of examples per task.
14 |
15 | num_tasks : int (default: 1,000,000)
16 | Overall number of tasks to sample.
17 |
18 | noise_std : float, optional
19 | Amount of noise to include in the targets for each task. If `None`, then
20 |         no noise is included, and the target is a sine function of the input.
21 |
22 | transform : callable, optional
23 | A function/transform that takes a numpy array of size (1,) and returns a
24 | transformed version of the input.
25 |
26 | target_transform : callable, optional
27 | A function/transform that takes a numpy array of size (1,) and returns a
28 | transformed version of the target.
29 |
30 | dataset_transform : callable, optional
31 |         A function/transform that takes a dataset (i.e. a task), and returns a
32 | transformed version of it. E.g. `torchmeta.transforms.ClassSplitter()`.
33 |
34 | Notes
35 | -----
36 |     The tasks are created randomly as random sinusoid functions. The amplitude
37 | varies within [0.1, 5.0], the phase within [0, pi], and the inputs are
38 | sampled uniformly in [-5.0, 5.0]. Due to the way PyTorch handles datasets,
39 | the number of tasks to be sampled needs to be fixed ahead of time (with
40 | `num_tasks`). This will typically be equal to `meta_batch_size * num_batches`.
41 |
42 | References
43 | ----------
44 | .. [1] Finn C., Abbeel P., and Levine, S. (2017). Model-Agnostic Meta-Learning
45 | for Fast Adaptation of Deep Networks. International Conference on
46 | Machine Learning (ICML) (https://arxiv.org/abs/1703.03400)
47 | """
48 | def __init__(self, num_samples_per_task, num_tasks=1000000,
49 | noise_std=None, transform=None, target_transform=None,
50 | dataset_transform=None):
51 | super(Sinusoid, self).__init__(meta_split='train',
52 | target_transform=target_transform, dataset_transform=dataset_transform)
53 | self.num_samples_per_task = num_samples_per_task
54 | self.num_tasks = num_tasks
55 | self.noise_std = noise_std
56 | self.transform = transform
57 |
58 | self._input_range = np.array([-5.0, 5.0])
59 | self._amplitude_range = np.array([0.1, 5.0])
60 | self._phase_range = np.array([0, np.pi])
61 |
62 | self._amplitudes = None
63 | self._phases = None
64 |
65 | @property
66 | def amplitudes(self):
67 | if self._amplitudes is None:
68 | self._amplitudes = self.np_random.uniform(self._amplitude_range[0],
69 | self._amplitude_range[1], size=self.num_tasks)
70 | return self._amplitudes
71 |
72 | @property
73 | def phases(self):
74 | if self._phases is None:
75 | self._phases = self.np_random.uniform(self._phase_range[0],
76 | self._phase_range[1], size=self.num_tasks)
77 | return self._phases
78 |
79 | def __len__(self):
80 | return self.num_tasks
81 |
82 | def __getitem__(self, index):
83 | amplitude, phase = self.amplitudes[index], self.phases[index]
84 | task = SinusoidTask(index, amplitude, phase, self._input_range,
85 | self.noise_std, self.num_samples_per_task, self.transform,
86 | self.target_transform, np_random=self.np_random)
87 |
88 | if self.dataset_transform is not None:
89 | task = self.dataset_transform(task)
90 |
91 | return task
92 |
93 |
94 | class SinusoidTask(Task):
95 | def __init__(self, index, amplitude, phase, input_range, noise_std,
96 | num_samples, transform=None, target_transform=None,
97 | np_random=None):
98 | super(SinusoidTask, self).__init__(index, None) # Regression task
99 | self.amplitude = amplitude
100 | self.phase = phase
101 | self.input_range = input_range
102 | self.num_samples = num_samples
103 | self.noise_std = noise_std
104 |
105 | self.transform = transform
106 | self.target_transform = target_transform
107 |
108 | if np_random is None:
109 | np_random = np.random.RandomState(None)
110 |
111 | self._inputs = np_random.uniform(input_range[0], input_range[1],
112 | size=(num_samples, 1))
113 | self._targets = amplitude * np.sin(self._inputs - phase)
114 | if (noise_std is not None) and (noise_std > 0.):
115 | self._targets += noise_std * np_random.randn(num_samples, 1)
116 |
117 | def __len__(self):
118 | return self.num_samples
119 |
120 | def __getitem__(self, index):
121 | input, target = self._inputs[index], self._targets[index]
122 |
123 | if self.transform is not None:
124 | input = self.transform(input)
125 |
126 | if self.target_transform is not None:
127 | target = self.target_transform(target)
128 |
129 | return (input, target)
130 |
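With `noise_std=None`, a task's targets are exactly `amplitude * sin(x - phase)`, which a quick sketch can verify:

    import numpy as np
    from torchmeta.toy import Sinusoid

    dataset = Sinusoid(num_samples_per_task=10, num_tasks=1000, noise_std=None)
    task = dataset[0]
    for i in range(len(task)):
        x, y = task[i]
        # Noiseless targets match the sinusoid exactly.
        assert np.allclose(y, task.amplitude * np.sin(x - task.phase))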
--------------------------------------------------------------------------------
/torchmeta/toy/sinusoid_line.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from torchmeta.utils.data import Task, MetaDataset
4 | from torchmeta.toy.sinusoid import SinusoidTask
5 |
6 |
7 | class SinusoidAndLine(MetaDataset):
8 | """
9 | Simple multimodal regression task, based on sinusoids and lines, as
10 | introduced in [1].
11 |
12 | Parameters
13 | ----------
14 | num_samples_per_task : int
15 | Number of examples per task.
16 |
17 | num_tasks : int (default: 1,000,000)
18 | Overall number of tasks to sample.
19 |
20 | noise_std : float, optional
21 | Amount of noise to include in the targets for each task. If `None`, then
22 |         no noise is included, and the target is either a sine function or a
23 | linear function of the input.
24 |
25 | transform : callable, optional
26 | A function/transform that takes a numpy array of size (1,) and returns a
27 | transformed version of the input.
28 |
29 | target_transform : callable, optional
30 | A function/transform that takes a numpy array of size (1,) and returns a
31 | transformed version of the target.
32 |
33 | dataset_transform : callable, optional
34 |         A function/transform that takes a dataset (i.e. a task), and returns a
35 | transformed version of it. E.g. `torchmeta.transforms.ClassSplitter()`.
36 |
37 | Notes
38 | -----
39 | The tasks are created randomly as either random sinusoid functions, or
40 | random linear functions. The amplitude of the sinusoids varies within
41 | [0.1, 5.0] and the phase within [0, pi]. The slope and intercept of the lines
42 | vary in [-3.0, 3.0]. The inputs are sampled uniformly in [-5.0, 5.0]. Due to
43 | the way PyTorch handles datasets, the number of tasks to be sampled needs to
44 | be fixed ahead of time (with `num_tasks`). This will typically be equal to
45 | `meta_batch_size * num_batches`.
46 |
47 | References
48 | ----------
49 | .. [1] Finn C., Xu K., Levine S. (2018). Probabilistic Model-Agnostic
50 | Meta-Learning. In Advances in Neural Information Processing Systems
51 | (https://arxiv.org/abs/1806.02817)
52 | """
53 | def __init__(self, num_samples_per_task, num_tasks=1000000,
54 | noise_std=None, transform=None, target_transform=None,
55 | dataset_transform=None):
56 | super(SinusoidAndLine, self).__init__(meta_split='train',
57 | target_transform=target_transform, dataset_transform=dataset_transform)
58 | self.num_samples_per_task = num_samples_per_task
59 | self.num_tasks = num_tasks
60 | self.noise_std = noise_std
61 | self.transform = transform
62 |
63 | self._input_range = np.array([-5.0, 5.0])
64 | self._amplitude_range = np.array([0.1, 5.0])
65 | self._phase_range = np.array([0, np.pi])
66 | self._slope_range = np.array([-3.0, 3.0])
67 | self._intercept_range = np.array([-3.0, 3.0])
68 |
69 |
70 | self._is_sinusoid = None
71 | self._amplitudes = None
72 | self._phases = None
73 | self._slopes = None
74 | self._intercepts = None
75 |
76 | @property
77 | def amplitudes(self):
78 | if self._amplitudes is None:
79 | self._amplitudes = self.np_random.uniform(self._amplitude_range[0],
80 | self._amplitude_range[1], size=self.num_tasks)
81 | return self._amplitudes
82 |
83 | @property
84 | def phases(self):
85 | if self._phases is None:
86 | self._phases = self.np_random.uniform(self._phase_range[0],
87 | self._phase_range[1], size=self.num_tasks)
88 | return self._phases
89 |
90 | @property
91 | def slopes(self):
92 | if self._slopes is None:
93 | self._slopes = self.np_random.uniform(self._slope_range[0],
94 | self._slope_range[1], size=self.num_tasks)
95 | return self._slopes
96 |
97 | @property
98 | def intercepts(self):
99 | if self._intercepts is None:
100 | self._intercepts = self.np_random.uniform(self._intercept_range[0],
101 | self._intercept_range[1], size=self.num_tasks)
102 | return self._intercepts
103 |
104 | @property
105 | def is_sinusoid(self):
106 | if self._is_sinusoid is None:
107 | self._is_sinusoid = np.zeros((self.num_tasks,), dtype=np.bool_)
108 | self._is_sinusoid[self.num_tasks // 2:] = True
109 | self.np_random.shuffle(self._is_sinusoid)
110 | return self._is_sinusoid
111 |
112 | def __len__(self):
113 | return self.num_tasks
114 |
115 | def __getitem__(self, index):
116 | if self.is_sinusoid[index]:
117 | amplitude, phase = self.amplitudes[index], self.phases[index]
118 | task = SinusoidTask(index, amplitude, phase, self._input_range,
119 | self.noise_std, self.num_samples_per_task, self.transform,
120 | self.target_transform, np_random=self.np_random)
121 | else:
122 | slope, intercept = self.slopes[index], self.intercepts[index]
123 | task = LinearTask(index, slope, intercept, self._input_range,
124 | self.noise_std, self.num_samples_per_task, self.transform,
125 | self.target_transform, np_random=self.np_random)
126 |
127 | if self.dataset_transform is not None:
128 | task = self.dataset_transform(task)
129 |
130 | return task
131 |
132 |
133 | class LinearTask(Task):
134 | def __init__(self, index, slope, intercept, input_range, noise_std,
135 | num_samples, transform=None, target_transform=None,
136 | np_random=None):
137 | super(LinearTask, self).__init__(index, None) # Regression task
138 | self.slope = slope
139 | self.intercept = intercept
140 | self.input_range = input_range
141 | self.num_samples = num_samples
142 | self.noise_std = noise_std
143 |
144 | self.transform = transform
145 | self.target_transform = target_transform
146 |
147 | if np_random is None:
148 | np_random = np.random.RandomState(None)
149 |
150 | self._inputs = np_random.uniform(input_range[0], input_range[1],
151 | size=(num_samples, 1))
152 | self._targets = intercept + slope * self._inputs
153 | if (noise_std is not None) and (noise_std > 0.):
154 | self._targets += noise_std * np_random.randn(num_samples, 1)
155 |
156 | def __len__(self):
157 | return self.num_samples
158 |
159 | def __getitem__(self, index):
160 | input, target = self._inputs[index], self._targets[index]
161 |
162 | if self.transform is not None:
163 | input = self.transform(input)
164 |
165 | if self.target_transform is not None:
166 | target = self.target_transform(target)
167 |
168 | return (input, target)
169 |
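Since `is_sinusoid` marks exactly half of the tasks and then shuffles, sampling yields a roughly even mix of the two modes; for instance:

    from torchmeta.toy import SinusoidAndLine
    from torchmeta.toy.sinusoid import SinusoidTask

    dataset = SinusoidAndLine(num_samples_per_task=10, num_tasks=1000)
    # Count how many of the first 100 tasks are sinusoids (vs. lines).
    num_sinusoids = sum(isinstance(dataset[i], SinusoidTask) for i in range(100))
    print('{0}/100 sinusoid tasks'.format(num_sinusoids))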
--------------------------------------------------------------------------------
/torchmeta/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta.transforms.categorical import Categorical, FixedCategory
2 | from torchmeta.transforms.augmentations import Rotation, HorizontalFlip, VerticalFlip
3 | from torchmeta.transforms.splitters import Splitter, ClassSplitter, WeightedClassSplitter
4 | from torchmeta.transforms.target_transforms import TargetTransform, DefaultTargetTransform
5 |
--------------------------------------------------------------------------------
/torchmeta/transforms/augmentations.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms.functional as F
2 |
3 | class Rotation(object):
4 | def __init__(self, angle, resample=False, expand=False, center=None):
5 | super(Rotation, self).__init__()
6 | if isinstance(angle, (list, tuple)):
7 | self._angles = angle
8 | self.angle = None
9 | else:
10 | self._angles = [angle]
11 | self.angle = angle
12 | if angle % 360 == 0:
13 | import warnings
14 | warnings.warn('Applying a rotation of {0} degrees (`{1}`) as a '
15 | 'class augmentation on a dataset is equivalent to the original '
16 | 'dataset.'.format(angle, self), UserWarning, stacklevel=2)
17 |
18 | self.resample = resample
19 | self.expand = expand
20 | self.center = center
21 |
22 | def __iter__(self):
23 | return iter(Rotation(angle, resample=self.resample, expand=self.expand,
24 | center=self.center) for angle in self._angles)
25 |
26 | def __call__(self, image):
27 | if self.angle is None:
28 | raise ValueError('The value of the angle is unspecified.')
29 | # QKFIX: Explicitly compute the pixel fill value due to an
30 | # incompatibility between Torchvision 0.5 and Pillow 7.0.0
31 | # https://github.com/pytorch/vision/issues/1759#issuecomment-583826810
32 | # Will be fixed in Torchvision 0.6
33 | fill = tuple([0] * len(image.getbands()))
34 | return F.rotate(image, self.angle % 360, self.resample,
35 | self.expand, self.center, fill=fill)
36 |
37 | def __hash__(self):
38 | return hash(repr(self))
39 |
40 | def __eq__(self, other):
41 | if (self.angle is None) or (other.angle is None):
42 | return self._angles == other._angles
43 | return (self.angle % 360) == (other.angle % 360)
44 |
45 | def __repr__(self):
46 | if self.angle is None:
47 | return 'Rotation({0})'.format(', '.join(map(str, self._angles)))
48 | else:
49 | return 'Rotation({0})'.format(self.angle % 360)
50 |
51 | def __str__(self):
52 | if self.angle is None:
53 | return 'Rotation({0})'.format(', '.join(map(str, self._angles)))
54 | else:
55 | return 'Rotation({0})'.format(self.angle)
56 |
57 | class HorizontalFlip(object):
58 | def __iter__(self):
59 | return iter([HorizontalFlip()])
60 |
61 | def __call__(self, image):
62 | return F.hflip(image)
63 |
64 | def __repr__(self):
65 | return 'HorizontalFlip()'
66 |
67 | class VerticalFlip(object):
68 | def __iter__(self):
69 | return iter([VerticalFlip()])
70 |
71 | def __call__(self, image):
72 | return F.vflip(image)
73 |
74 | def __repr__(self):
75 | return 'VerticalFlip()'
76 |
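These augmentations are consumed as iterables of single transforms, one new class per element; a sketch with illustrative angles, under the Torchvision version this file targets (0.5):

    from PIL import Image
    from torchmeta.transforms import Rotation

    # A Rotation over a list expands into one single-angle Rotation per element.
    augmentations = list(Rotation([90, 180, 270]))
    print(augmentations)                # [Rotation(90), Rotation(180), Rotation(270)]

    image = Image.new('L', (105, 105))
    rotated = augmentations[0](image)   # rotate by 90 degrees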
--------------------------------------------------------------------------------
/torchmeta/transforms/categorical.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmeta.transforms.utils import apply_wrapper
3 | from collections import defaultdict
4 |
5 | from torchmeta.transforms.target_transforms import TargetTransform
6 |
7 |
8 | class Categorical(TargetTransform):
9 | """Target transform to return labels in `[0, num_classes)`.
10 |
11 | Parameters
12 | ----------
13 | num_classes : int, optional
14 | Number of classes. If `None`, then the number of classes is inferred
15 | from the number of individual labels encountered.
16 |
17 | Examples
18 | --------
19 | >>> dataset = Omniglot('data', num_classes_per_task=5, meta_train=True)
20 | >>> task = dataset.sample_task()
21 | >>> task[0]
22 | (,
23 | ('images_evaluation/Glagolitic/character12', None))
24 |
25 | >>> dataset = Omniglot('data', num_classes_per_task=5, meta_train=True,
26 | ... target_transform=Categorical(5))
27 | >>> task = dataset.sample_task()
28 | >>> task[0]
29 | (, 2)
30 | """
31 | def __init__(self, num_classes=None):
32 | super(Categorical, self).__init__()
33 | self.num_classes = num_classes
34 | self._classes = None
35 | self._labels = None
36 |
37 | def reset(self):
38 | self._classes = None
39 | self._labels = None
40 |
41 | @property
42 | def classes(self):
43 | if self._classes is None:
44 | self._classes = defaultdict(None)
45 | if self.num_classes is None:
46 | default_factory = lambda: len(self._classes)
47 | else:
48 | default_factory = lambda: self.labels[len(self._classes)]
49 | self._classes.default_factory = default_factory
50 | if (self.num_classes is not None) and (len(self._classes) > self.num_classes):
51 | raise ValueError('The number of individual labels ({0}) is greater '
52 | 'than the number of classes defined by `num_classes` '
53 | '({1}).'.format(len(self._classes), self.num_classes))
54 | return self._classes
55 |
56 | @property
57 | def labels(self):
58 | if (self._labels is None) and (self.num_classes is not None):
59 | # TODO: Replace torch.randperm with seed-friendly counterpart
60 | self._labels = torch.randperm(self.num_classes).tolist()
61 | return self._labels
62 |
63 | def __call__(self, target):
64 | return self.classes[target]
65 |
66 | def __repr__(self):
67 | return '{0}({1})'.format(self.__class__.__name__, self.num_classes or '')
68 |
69 |
70 | class FixedCategory(object):
71 | def __init__(self, transform=None):
72 | self.transform = transform
73 |
74 | def __call__(self, index):
75 | return (index, self.transform)
76 |
77 | def __repr__(self):
78 | return ('{0}({1})'.format(self.__class__.__name__, self.transform))
79 |
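Beyond the docstring example, the label inference without `num_classes` can be seen directly on plain string targets (a small sketch):

    from torchmeta.transforms import Categorical

    transform = Categorical()
    # Labels are mapped to integers in order of first appearance.
    print([transform(label) for label in ['cat', 'dog', 'cat', 'bird']])
    # [0, 1, 0, 2]
    transform.reset()   # clear the mapping before the next task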
--------------------------------------------------------------------------------
/torchmeta/transforms/target_transforms.py:
--------------------------------------------------------------------------------
1 | class TargetTransform(object):
2 | def __call__(self, target):
3 | raise NotImplementedError()
4 |
5 | def __repr__(self):
6 | return str(self.__class__.__name__)
7 |
8 |
9 | class DefaultTargetTransform(TargetTransform):
10 | def __init__(self, class_augmentations):
11 | super(DefaultTargetTransform, self).__init__()
12 | self.class_augmentations = class_augmentations
13 |
14 | self._augmentations = dict((augmentation, i + 1)
15 | for (i, augmentation) in enumerate(class_augmentations))
16 | self._augmentations[None] = 0
17 |
18 | def __call__(self, target):
19 | assert isinstance(target, tuple) and len(target) == 2
20 | label, augmentation = target
21 | return (label, self._augmentations[augmentation])
22 |
--------------------------------------------------------------------------------
/torchmeta/transforms/utils.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Compose
2 | from torchmeta.utils.data.task import Task
3 |
4 | def apply_wrapper(wrapper, task_or_dataset=None):
5 | if task_or_dataset is None:
6 | return wrapper
7 |
8 | from torchmeta.utils.data import MetaDataset
9 | if isinstance(task_or_dataset, Task):
10 | return wrapper(task_or_dataset)
11 | elif isinstance(task_or_dataset, MetaDataset):
12 | if task_or_dataset.dataset_transform is None:
13 | dataset_transform = wrapper
14 | else:
15 | dataset_transform = Compose([
16 | task_or_dataset.dataset_transform, wrapper])
17 | task_or_dataset.dataset_transform = dataset_transform
18 | return task_or_dataset
19 | else:
20 | raise NotImplementedError()
21 |
22 | def wrap_transform(transform, fn, transform_type=None):
23 | if (transform_type is None) or isinstance(transform, transform_type):
24 | return fn(transform)
25 | elif isinstance(transform, Compose):
26 | return Compose([wrap_transform(subtransform, fn, transform_type)
27 | for subtransform in transform.transforms])
28 | else:
29 | return transform
30 |
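`wrap_transform` rewrites only the sub-transforms of a given type, recursing through `Compose`; a sketch where the wrapper `fn` is a placeholder:

    from torchvision.transforms import Compose, Normalize, ToTensor
    from torchmeta.transforms.utils import wrap_transform

    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    # Only the Normalize sub-transform is passed through `fn`; ToTensor is kept.
    wrapped = wrap_transform(transform, fn=lambda t: Compose([t]),
                             transform_type=Normalize)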
--------------------------------------------------------------------------------
/torchmeta/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta.utils import data
2 | from torchmeta.utils.metrics import hardness_metric
3 | from torchmeta.utils.prototype import get_num_samples, get_prototypes, prototypical_loss
--------------------------------------------------------------------------------
/torchmeta/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from torchmeta.utils.data.dataloader import MetaDataLoader, BatchMetaDataLoader
2 | from torchmeta.utils.data.dataset import ClassDataset, MetaDataset, CombinationMetaDataset
3 | from torchmeta.utils.data.sampler import CombinationSequentialSampler, CombinationRandomSampler
4 | from torchmeta.utils.data.task import Dataset, Task, ConcatTask, SubsetTask
5 |
6 | __all__ = [
7 | 'MetaDataLoader',
8 | 'BatchMetaDataLoader',
9 | 'ClassDataset',
10 | 'MetaDataset',
11 | 'CombinationMetaDataset',
12 | 'CombinationSequentialSampler',
13 | 'CombinationRandomSampler',
14 | 'Dataset',
15 | 'Task',
16 | 'ConcatTask',
17 | 'SubsetTask'
18 | ]
19 |
--------------------------------------------------------------------------------
/torchmeta/utils/data/dataloader.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from torch.utils.data import DataLoader
4 | from torch.utils.data.dataloader import default_collate
5 | from torch.utils.data.dataset import Dataset as TorchDataset
6 |
7 | from torchmeta.utils.data.dataset import CombinationMetaDataset
8 | from torchmeta.utils.data.sampler import (CombinationSequentialSampler,
9 | CombinationRandomSampler)
10 |
11 | def batch_meta_collate(collate_fn):
12 | def collate_task(task):
13 | if isinstance(task, TorchDataset):
14 | return collate_fn([task[idx] for idx in range(len(task))])
15 | elif isinstance(task, OrderedDict):
16 | return OrderedDict([(key, collate_task(subtask))
17 | for (key, subtask) in task.items()])
18 | else:
19 | raise NotImplementedError()
20 |
21 | def _collate_fn(batch):
22 | return collate_fn([collate_task(task) for task in batch])
23 |
24 | return _collate_fn
25 |
26 | def no_collate(batch):
27 | return batch
28 |
29 | class MetaDataLoader(DataLoader):
30 | def __init__(self, dataset, batch_size=1, shuffle=True, sampler=None,
31 | batch_sampler=None, num_workers=0, collate_fn=None,
32 | pin_memory=False, drop_last=False, timeout=0,
33 | worker_init_fn=None):
34 | if collate_fn is None:
35 | collate_fn = no_collate
36 |
37 | if isinstance(dataset, CombinationMetaDataset) and (sampler is None):
38 | if shuffle:
39 | sampler = CombinationRandomSampler(dataset)
40 | else:
41 | sampler = CombinationSequentialSampler(dataset)
42 | shuffle = False
43 |
44 | super(MetaDataLoader, self).__init__(dataset, batch_size=batch_size,
45 | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
46 | num_workers=num_workers, collate_fn=collate_fn,
47 | pin_memory=pin_memory, drop_last=drop_last, timeout=timeout,
48 | worker_init_fn=worker_init_fn)
49 |
50 |
51 | class BatchMetaDataLoader(MetaDataLoader):
52 | def __init__(self, dataset, batch_size=1, shuffle=True, sampler=None, num_workers=0,
53 | pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None):
54 | collate_fn = batch_meta_collate(default_collate)
55 |
56 | super(BatchMetaDataLoader, self).__init__(dataset,
57 | batch_size=batch_size, shuffle=shuffle, sampler=sampler,
58 | batch_sampler=None, num_workers=num_workers,
59 | collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last,
60 | timeout=timeout, worker_init_fn=worker_init_fn)
61 |
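The difference between the two loaders, sketched on a toy dataset: `MetaDataLoader` with the default `no_collate` yields plain lists of tasks, while `BatchMetaDataLoader` additionally stacks every task's examples into tensors.

    from torchmeta.toy import helpers
    from torchmeta.utils.data import MetaDataLoader, BatchMetaDataLoader

    dataset = helpers.harmonic(shots=5, test_shots=15, num_tasks=100)

    tasks = next(iter(MetaDataLoader(dataset, batch_size=4)))
    print(type(tasks), len(tasks))       # <class 'list'> 4

    batch = next(iter(BatchMetaDataLoader(dataset, batch_size=4)))
    print(batch['train'][0].shape)       # torch.Size([4, 5, 1])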
--------------------------------------------------------------------------------
/torchmeta/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 | from itertools import combinations
3 | from torch.utils.data.sampler import SequentialSampler, RandomSampler
4 |
5 | from torchmeta.utils.data.dataset import CombinationMetaDataset
6 |
7 | __all__ = ['CombinationSequentialSampler', 'CombinationRandomSampler']
8 |
9 |
10 | class CombinationSequentialSampler(SequentialSampler):
11 | def __init__(self, data_source):
12 | if not isinstance(data_source, CombinationMetaDataset):
13 | raise ValueError()
14 | super(CombinationSequentialSampler, self).__init__(data_source)
15 |
16 | def __iter__(self):
17 | num_classes = len(self.data_source.dataset)
18 | num_classes_per_task = self.data_source.num_classes_per_task
19 | return combinations(range(num_classes), num_classes_per_task)
20 |
21 |
22 | class CombinationRandomSampler(RandomSampler):
23 | def __init__(self, data_source):
24 | if not isinstance(data_source, CombinationMetaDataset):
25 | raise ValueError()
26 | self.data_source = data_source
27 |
28 | def __iter__(self):
29 | num_classes = len(self.data_source.dataset)
30 | num_classes_per_task = self.data_source.num_classes_per_task
31 | for _ in combinations(range(num_classes), num_classes_per_task):
32 | yield tuple(random.sample(range(num_classes), num_classes_per_task))
33 |
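Indices produced by these samplers are tuples of class indices, one `N`-way combination per task. The sequential variant enumerates them lexicographically, which is easy to see with `itertools` alone:

    from itertools import combinations

    # 4 classes, 2-way tasks: the order CombinationSequentialSampler follows.
    print(list(combinations(range(4), 2)))
    # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]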
--------------------------------------------------------------------------------
/torchmeta/utils/data/task.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import ConcatDataset, Subset
2 | from torch.utils.data import Dataset as Dataset_
3 | from torchvision.transforms import Compose
4 |
5 | __all__ = ['Dataset', 'Task', 'ConcatTask', 'SubsetTask']
6 |
7 |
8 | class Dataset(Dataset_):
9 | def __init__(self, index, transform=None, target_transform=None):
10 | self.index = index
11 | self.transform = transform
12 | self.target_transform = target_transform
13 |
14 | def target_transform_append(self, transform):
15 | if transform is None:
16 | return
17 | if self.target_transform is None:
18 | self.target_transform = transform
19 | else:
20 | self.target_transform = Compose([self.target_transform, transform])
21 |
22 | def __hash__(self):
23 | return hash(self.index)
24 |
25 |
26 | class Task(Dataset):
27 | """Base class for a classification task.
28 |
29 | Parameters
30 | ----------
31 | num_classes : int
32 | Number of classes for the classification task.
33 | """
34 | def __init__(self, index, num_classes,
35 | transform=None, target_transform=None):
36 | super(Task, self).__init__(index, transform=transform,
37 | target_transform=target_transform)
38 | self.num_classes = num_classes
39 |
40 |
41 | class ConcatTask(Task, ConcatDataset):
42 | def __init__(self, datasets, num_classes, target_transform=None):
43 | index = tuple(task.index for task in datasets)
44 | Task.__init__(self, index, num_classes)
45 | ConcatDataset.__init__(self, datasets)
46 | for task in self.datasets:
47 | task.target_transform_append(target_transform)
48 |
49 | def __getitem__(self, index):
50 | return ConcatDataset.__getitem__(self, index)
51 |
52 |
53 | class SubsetTask(Task, Subset):
54 | def __init__(self, dataset, indices, num_classes=None,
55 | target_transform=None):
56 | if num_classes is None:
57 | num_classes = dataset.num_classes
58 | Task.__init__(self, dataset.index, num_classes)
59 | Subset.__init__(self, dataset, indices)
60 | self.dataset.target_transform_append(target_transform)
61 |
62 | def __getitem__(self, index):
63 | return Subset.__getitem__(self, index)
64 |
65 | def __hash__(self):
66 | return hash((self.index, tuple(self.indices)))
67 |
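`SubsetTask` is how a task gets split into support/query sets (roughly what `ClassSplitter` produces); a sketch on a toy task:

    from torchmeta.toy import Sinusoid
    from torchmeta.utils.data import SubsetTask

    task = Sinusoid(num_samples_per_task=20, num_tasks=10)[0]
    support = SubsetTask(task, list(range(5)))       # first 5 samples
    query = SubsetTask(task, list(range(5, 20)))     # remaining 15 samples
    print(len(support), len(query))                  # 5 15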
--------------------------------------------------------------------------------
/torchmeta/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torchmeta.utils.prototype import get_prototypes
5 |
6 | __all__ = ['hardness_metric']
7 |
8 |
9 | def _pad_images(inputs, size=(224, 224), **kwargs):
10 | height, width = inputs.shape[-2:]
11 | pad_height, pad_width = (size[0] - height) // 2, (size[1] - width) // 2
12 | padding = (pad_width, size[1] - width - pad_width,
13 | pad_height, size[0] - height - pad_height)
14 | return F.pad(inputs, padding, **kwargs)
15 |
16 |
17 | def hardness_metric(batch, num_classes):
18 | """Hardness metric of an episode, as defined in [1].
19 |
20 | Parameters
21 | ----------
22 | batch : dict
23 | The batch of tasks over which the metric is computed. The batch of tasks
24 | is a dictionary containing the keys `train` (or `support`) and `test`
25 | (or `query`). This is typically the output of `BatchMetaDataLoader`.
26 |
27 | num_classes : int
28 | The number of classes in the classification task. This corresponds to
29 | the number of ways in an `N`-way classification problem.
30 |
31 | Returns
32 | -------
33 | metric : `torch.FloatTensor` instance
34 | Values of the hardness metric for each task in the batch.
35 |
36 | References
37 | ----------
38 | .. [1] Dhillon, G. S., Chaudhari, P., Ravichandran, A. and Soatto S. (2019).
39 | A Baseline for Few-Shot Image Classification. (https://arxiv.org/abs/1909.02729)
40 | """
41 | if ('train' not in batch) and ('support' not in batch):
42 | raise ValueError('The tasks do not contain any training/support set. '
43 | 'Make sure the tasks contain either the "train" or the '
44 | '"support" key.')
45 | if ('test' not in batch) and ('query' not in batch):
46 | raise ValueError('The tasks do not contain any test/query set. Make '
47 |                          'sure the tasks contain either the "test" or the '
48 | '"query" key.')
49 |
50 | train = 'train' if ('train' in batch) else 'support'
51 | test = 'test' if ('test' in batch) else 'query'
52 |
53 | with torch.no_grad():
54 | # Load a pre-trained backbone Resnet-152 model from PyTorch Hub
55 | backbone = torch.hub.load('pytorch/vision:v0.5.0',
56 | 'resnet152',
57 | pretrained=True,
58 | verbose=False)
59 | backbone.eval()
60 |
61 | train_inputs, train_targets = batch[train]
62 | test_inputs, test_targets = batch[test]
63 | batch_size, num_images, num_channels = train_inputs.shape[:3]
64 | num_test_images = test_inputs.size(1)
65 |
66 | backbone.to(device=train_inputs.device)
67 |
68 | if num_channels != 3:
69 | raise ValueError('The images must be RGB images.')
70 |
71 | # Pad the images so that they are compatible with the pre-trained model
72 | padded_train_inputs = _pad_images(train_inputs,
73 | size=(224, 224), mode='constant', value=0.)
74 | padded_test_inputs = _pad_images(test_inputs,
75 | size=(224, 224), mode='constant', value=0.)
76 |
77 | # Compute the features from the logits returned by the pre-trained
78 | # model on the train/support examples. These features are z(x, theta)_+,
79 | # averaged for each class
80 | train_logits = backbone(padded_train_inputs.view(-1, 3, 224, 224))
81 | train_logits = F.relu(train_logits.view(batch_size, num_images, -1))
82 | train_features = get_prototypes(train_logits, train_targets, num_classes)
83 |
84 | # Get the weights by normalizing the features
85 | weights = F.normalize(train_features, p=2, dim=2)
86 |
87 | # Compute and normalize the logits of the test/query examples
88 | test_logits = backbone(padded_test_inputs.view(-1, 3, 224, 224))
89 | test_logits = test_logits.view(batch_size, num_test_images, -1)
90 | test_logits = F.normalize(test_logits, p=2, dim=2)
91 |
92 | # Compute the log probabilities of the test/query examples
93 | test_logits = torch.bmm(weights, test_logits.transpose(1, 2))
94 | test_log_probas = -F.cross_entropy(test_logits, test_targets,
95 | reduction='none')
96 |
97 | # Compute the log-odds ratios for each image of the test/query set
98 | log_odds_ratios = torch.log1p(-test_log_probas.exp()) - test_log_probas
99 |
100 | return torch.mean(log_odds_ratios, dim=1)
101 |
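A usage sketch with hypothetical random stand-in data (shapes follow the docstring; the first call downloads the ResNet-152 weights from PyTorch Hub):

    import torch
    from torchmeta.utils.metrics import hardness_metric

    # Hypothetical 5-way batch: 4 tasks, 25 support/query RGB images of 84x84.
    batch = {
        'train': (torch.rand(4, 25, 3, 84, 84), torch.randint(5, (4, 25))),
        'test': (torch.rand(4, 25, 3, 84, 84), torch.randint(5, (4, 25))),
    }
    metric = hardness_metric(batch, num_classes=5)
    print(metric.shape)   # torch.Size([4]), one value per task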
--------------------------------------------------------------------------------
/torchmeta/utils/prototype.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | __all__ = ['get_num_samples', 'get_prototypes', 'prototypical_loss']
5 |
6 |
7 | def get_num_samples(targets, num_classes, dtype=None):
8 | batch_size = targets.size(0)
9 | with torch.no_grad():
10 | ones = torch.ones_like(targets, dtype=dtype)
11 | num_samples = ones.new_zeros((batch_size, num_classes))
12 | num_samples.scatter_add_(1, targets, ones)
13 | return num_samples
14 |
15 |
16 | def get_prototypes(embeddings, targets, num_classes):
17 | """Compute the prototypes (the mean vector of the embedded training/support
18 |     points belonging to its class) for each class in the task.
19 |
20 | Parameters
21 | ----------
22 | embeddings : `torch.FloatTensor` instance
23 | A tensor containing the embeddings of the support points. This tensor
24 | has shape `(batch_size, num_examples, embedding_size)`.
25 |
26 | targets : `torch.LongTensor` instance
27 | A tensor containing the targets of the support points. This tensor has
28 | shape `(batch_size, num_examples)`.
29 |
30 | num_classes : int
31 | Number of classes in the task.
32 |
33 | Returns
34 | -------
35 | prototypes : `torch.FloatTensor` instance
36 | A tensor containing the prototypes for each class. This tensor has shape
37 | `(batch_size, num_classes, embedding_size)`.
38 | """
39 | batch_size, embedding_size = embeddings.size(0), embeddings.size(-1)
40 |
41 | num_samples = get_num_samples(targets, num_classes, dtype=embeddings.dtype)
42 | num_samples.unsqueeze_(-1)
43 | num_samples = torch.max(num_samples, torch.ones_like(num_samples))
44 |
45 | prototypes = embeddings.new_zeros((batch_size, num_classes, embedding_size))
46 | indices = targets.unsqueeze(-1).expand_as(embeddings)
47 | prototypes.scatter_add_(1, indices, embeddings).div_(num_samples)
48 |
49 | return prototypes
50 |
51 |
52 | def prototypical_loss(prototypes, embeddings, targets, **kwargs):
53 | """Compute the loss (i.e. negative log-likelihood) for the prototypical
54 | network, on the test/query points.
55 |
56 | Parameters
57 | ----------
58 | prototypes : `torch.FloatTensor` instance
59 | A tensor containing the prototypes for each class. This tensor has shape
60 | `(batch_size, num_classes, embedding_size)`.
61 |
62 | embeddings : `torch.FloatTensor` instance
63 | A tensor containing the embeddings of the query points. This tensor has
64 | shape `(batch_size, num_examples, embedding_size)`.
65 |
66 | targets : `torch.LongTensor` instance
67 | A tensor containing the targets of the query points. This tensor has
68 | shape `(batch_size, num_examples)`.
69 |
70 | Returns
71 | -------
72 | loss : `torch.FloatTensor` instance
73 | The negative log-likelihood on the query points.
74 | """
75 | squared_distances = torch.sum((prototypes.unsqueeze(2)
76 | - embeddings.unsqueeze(1)) ** 2, dim=-1)
77 | return F.cross_entropy(-squared_distances, targets, **kwargs)
78 |
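Putting the two functions together on random embeddings (dimensions are illustrative):

    import torch
    from torchmeta.utils.prototype import get_prototypes, prototypical_loss

    support = torch.randn(4, 25, 64)              # (batch, examples, dim)
    support_targets = torch.randint(5, (4, 25))
    query = torch.randn(4, 15, 64)
    query_targets = torch.randint(5, (4, 15))

    prototypes = get_prototypes(support, support_targets, num_classes=5)
    loss = prototypical_loss(prototypes, query, query_targets)
    print(prototypes.shape, loss.item())          # torch.Size([4, 5, 64]) ...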
--------------------------------------------------------------------------------
/torchmeta/version.py:
--------------------------------------------------------------------------------
1 | VERSION = '1.4.0'
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | '''Implements a generic training loop.
2 | '''
3 |
4 | import torch
5 | import utils
6 | from torch.utils.tensorboard import SummaryWriter
7 | from tqdm.autonotebook import tqdm
8 | import time
9 | import numpy as np
10 | import os
11 | import shutil
12 |
13 |
14 | def train(model, train_dataloader, epochs, lr, steps_til_summary, epochs_til_checkpoint, model_dir, loss_fn,
15 | summary_fn, val_dataloader=None, double_precision=False, clip_grad=False, use_lbfgs=False, loss_schedules=None):
16 |
17 | optim = torch.optim.Adam(lr=lr, params=model.parameters())
18 |
19 | # copy settings from Raissi et al. (2019) and here
20 | # https://github.com/maziarraissi/PINNs
21 | if use_lbfgs:
22 | optim = torch.optim.LBFGS(lr=lr, params=model.parameters(), max_iter=50000, max_eval=50000,
23 | history_size=50, line_search_fn='strong_wolfe')
24 |
25 | if os.path.exists(model_dir):
26 | val = input("The model directory %s exists. Overwrite? (y/n)"%model_dir)
27 | if val == 'y':
28 | shutil.rmtree(model_dir)
29 |
30 |     os.makedirs(model_dir, exist_ok=True)
31 |
32 | summaries_dir = os.path.join(model_dir, 'summaries')
33 | utils.cond_mkdir(summaries_dir)
34 |
35 | checkpoints_dir = os.path.join(model_dir, 'checkpoints')
36 | utils.cond_mkdir(checkpoints_dir)
37 |
38 | writer = SummaryWriter(summaries_dir)
39 |
40 | total_steps = 0
41 | with tqdm(total=len(train_dataloader) * epochs) as pbar:
42 | train_losses = []
43 | for epoch in range(epochs):
44 | if not epoch % epochs_til_checkpoint and epoch:
45 | torch.save(model.state_dict(),
46 | os.path.join(checkpoints_dir, 'model_epoch_%04d.pth' % epoch))
47 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_epoch_%04d.txt' % epoch),
48 | np.array(train_losses))
49 |
50 | for step, (model_input, gt) in enumerate(train_dataloader):
51 | start_time = time.time()
52 |
53 | model_input = {key: value.cuda() for key, value in model_input.items()}
54 | gt = {key: value.cuda() for key, value in gt.items()}
55 |
56 | if double_precision:
57 | model_input = {key: value.double() for key, value in model_input.items()}
58 | gt = {key: value.double() for key, value in gt.items()}
59 |
60 | if use_lbfgs:
61 | def closure():
62 | optim.zero_grad()
63 | model_output = model(model_input)
64 | losses = loss_fn(model_output, gt)
65 | train_loss = 0.
66 | for loss_name, loss in losses.items():
67 | train_loss += loss.mean()
68 | train_loss.backward()
69 | return train_loss
70 | optim.step(closure)
71 |
72 | model_output = model(model_input)
73 | losses = loss_fn(model_output, gt)
74 |
75 | train_loss = 0.
76 | for loss_name, loss in losses.items():
77 | single_loss = loss.mean()
78 |
79 | if loss_schedules is not None and loss_name in loss_schedules:
80 | writer.add_scalar(loss_name + "_weight", loss_schedules[loss_name](total_steps), total_steps)
81 | single_loss *= loss_schedules[loss_name](total_steps)
82 |
83 | writer.add_scalar(loss_name, single_loss, total_steps)
84 | train_loss += single_loss
85 |
86 | train_losses.append(train_loss.item())
87 | writer.add_scalar("total_train_loss", train_loss, total_steps)
88 |
89 | if not total_steps % steps_til_summary:
90 | torch.save(model.state_dict(),
91 | os.path.join(checkpoints_dir, 'model_current.pth'))
92 | summary_fn(model, model_input, gt, model_output, writer, total_steps)
93 |
94 | if not use_lbfgs:
95 | optim.zero_grad()
96 | train_loss.backward()
97 |
98 | if clip_grad:
99 | if isinstance(clip_grad, bool):
100 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
101 | else:
102 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
103 |
104 | optim.step()
105 |
106 | pbar.update(1)
107 |
108 | if not total_steps % steps_til_summary:
109 | tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (epoch, train_loss, time.time() - start_time))
110 |
111 | if val_dataloader is not None:
112 | print("Running validation set...")
113 | model.eval()
114 | with torch.no_grad():
115 | val_losses = []
116 |                             for (model_input, gt) in val_dataloader:
117 |                                 model_input = {key: value.cuda() for key, value in model_input.items()}
118 |                                 gt = {key: value.cuda() for key, value in gt.items()}
119 |                                 model_output = model(model_input)
120 |                                 val_losses.append(sum(loss.mean().item() for loss in loss_fn(model_output, gt).values()))
121 | writer.add_scalar("val_loss", np.mean(val_losses), total_steps)
122 | model.train()
123 |
124 | total_steps += 1
125 |
126 | torch.save(model.state_dict(),
127 | os.path.join(checkpoints_dir, 'model_final.pth'))
128 | np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'),
129 | np.array(train_losses))
130 |
131 |
132 | class LinearDecaySchedule():
133 | def __init__(self, start_val, final_val, num_steps):
134 | self.start_val = start_val
135 | self.final_val = final_val
136 | self.num_steps = num_steps
137 |
138 |     def __call__(self, step):
139 |         return self.start_val + (self.final_val - self.start_val) * min(step / self.num_steps, 1.)
140 |
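`LinearDecaySchedule` instances are meant to be passed to `train` via `loss_schedules`, keyed by the loss names returned by `loss_fn`; for example (a sketch, with 'grad_loss' as an assumed loss name):

    # Decay the weight of a hypothetical 'grad_loss' term from 1.0 to 0.0
    # over the first 1000 steps, then hold it at 0.0.
    schedule = LinearDecaySchedule(start_val=1.0, final_val=0.0, num_steps=1000)
    print(schedule(0), schedule(500), schedule(2000))   # 1.0 0.5 0.0
    loss_schedules = {'grad_loss': schedule}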
--------------------------------------------------------------------------------