├── LICENSE ├── README.md ├── block_diagonalization.ipynb ├── configs ├── 3dshapes │ └── lstsq │ │ ├── lstsq.yml │ │ ├── lstsq_multi.yml │ │ ├── lstsq_rec.yml │ │ ├── neuralM.yml │ │ └── neural_trans.yml ├── mnist │ ├── lstsq │ │ ├── lstsq.yml │ │ ├── lstsq_multi.yml │ │ ├── lstsq_rec.yml │ │ ├── neuralM.yml │ │ ├── neuralM_latentpred.yml │ │ └── neural_trans.yml │ └── simclr │ │ ├── cpc.yml │ │ └── simclr.yml ├── mnist_accl │ └── lstsq │ │ ├── holstsq.yml │ │ ├── lstsq.yml │ │ └── neural_trans.yml ├── mnist_bg │ ├── lstsq │ │ ├── lstsq.yml │ │ ├── lstsq_multi.yml │ │ ├── lstsq_rec.yml │ │ ├── neuralM.yml │ │ ├── neuralM_latentpred.yml │ │ └── neural_trans.yml │ └── simclr │ │ ├── cpc.yml │ │ └── simclr.yml └── smallNORB │ └── lstsq │ ├── lstsq.yml │ ├── lstsq_multi.yml │ ├── lstsq_rec.yml │ ├── neuralM.yml │ └── neural_trans.yml ├── datasets ├── __init__.py ├── seq_mnist.py ├── small_norb.py └── three_dim_shapes.py ├── equivariance_error.ipynb ├── extrp.ipynb ├── figs ├── comp_error_3dshapes.pdf ├── comp_error_mnist.pdf ├── comp_error_mnist_accl.pdf ├── comp_error_mnist_bg.pdf ├── comp_error_mnist_bg_full.pdf ├── comp_error_smallNORB.pdf ├── disentangle_3dshapes.pdf ├── disentangle_mnist.pdf ├── disentangle_mnist_bg.pdf ├── disentangle_mnist_bg_full.pdf ├── disentangle_smallNORB.pdf ├── equiv_error_3dshapes.pdf ├── equiv_error_mnist.pdf ├── equiv_error_mnist_accl.pdf ├── equiv_error_mnist_bg.pdf ├── equiv_error_mnist_bg_full.pdf ├── equiv_error_smallNORB.pdf ├── sbd_3dshapes.pdf ├── sbd_3dshapes_component0.pdf ├── sbd_3dshapes_component1.pdf ├── sbd_3dshapes_component2.pdf ├── sbd_3dshapes_component3.pdf ├── sbd_3dshapes_component5.pdf ├── sbd_3dshapes_components.pdf ├── sbd_mnist.pdf ├── sbd_mnist_bg.pdf ├── sbd_mnist_bg_components.pdf ├── sbd_mnist_bg_full.pdf ├── sbd_mnist_bg_full_component0.pdf ├── sbd_mnist_bg_full_component1.pdf ├── sbd_mnist_bg_full_component2.pdf ├── sbd_mnist_bg_full_components.pdf ├── sbd_mnist_component0.pdf ├── sbd_mnist_component1.pdf ├── sbd_mnist_component2.pdf ├── sbd_smallNORB.pdf ├── sbd_smallNORB_component0.pdf ├── sbd_smallNORB_component1.pdf ├── sbd_smallNORB_component2.pdf ├── sbd_smallNORB_component3.pdf ├── sbd_smallNORB_component4.pdf ├── sbd_smallNORB_component5.pdf └── sbd_smallNORB_components.pdf ├── gen_images.ipynb ├── gen_images ├── 3dshapes_lstsq.png ├── 3dshapes_lstsq_multi.png ├── 3dshapes_lstsq_rec.png ├── 3dshapes_neuralM.png ├── 3dshapes_neural_trans.png ├── mnist_accl_holstsq_equiv.png ├── mnist_bg_full_lstsq.png ├── mnist_bg_full_lstsq_multi.png ├── mnist_bg_full_lstsq_rec.png ├── mnist_bg_full_neuralM.png ├── mnist_bg_full_neural_trans.png ├── mnist_bg_lstsq.png ├── mnist_bg_lstsq_multi.png ├── mnist_bg_lstsq_rec.png ├── mnist_bg_neuralM.png ├── mnist_bg_neural_trans.png ├── mnist_lstsq.png ├── mnist_lstsq_multi.png ├── mnist_lstsq_rec.png ├── mnist_neuralM.png ├── mnist_neural_trans.png ├── smallNORB_lstsq.png ├── smallNORB_lstsq_multi.png ├── smallNORB_lstsq_rec.png ├── smallNORB_neuralM.png └── smallNORB_neural_trans.png ├── jobs ├── 09022022 │ ├── training_allmodels_1.sh │ └── training_allmodels_2.sh ├── 09152022 │ ├── training_allmodels_1.sh │ ├── training_allmodels_2.sh │ ├── training_lstsqs.sh │ └── training_neurals.sh └── 09232022 │ ├── training_neuralMlatentpred_mnist.sh │ └── training_neuralMlatentpred_mnist_bg.sh ├── models ├── __init__.py ├── base_networks.py ├── dynamics_models.py ├── resblock.py ├── seqae.py └── simclr_models.py ├── requirements.txt ├── run.py ├── training_allmodels.sh ├── training_loops.py └── utils 
├── __init__.py ├── clr.py ├── emb2d.py ├── laplacian.py ├── misc.py ├── optimize_bd_cob.py ├── weight_standarization.py └── yaml_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Anonymous 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta-sequential prediction (MSP) 2 | 3 | 4 |

5 | image 6 |

7 | 8 | 9 | This repository contains the code for the NeurIPS 2022 paper: Unsupervised Learning of Equivariant Structure from Sequences. 10 | A simple encoder-decoder model trained with *meta-sequential prediction* captures the hidden disentangled structure underlying the datasets. 11 | 12 | [[arXiv]](https://arxiv.org/abs/2210.05972) [[OpenReview]](https://openreview.net/forum?id=7b7iGkuVqlZ) 13 | 14 | ## Implementation of MSP and simultaneous block diagonalization (SBD) 15 | - `SeqAELSTSQ` in `./models/seqae.py` is the implementation of *meta-sequential prediction*. 16 | - `tracenorm_of_normalized_laplacian` in `./utils/laplacian.py` is used to calculate the block diagonalization loss in our paper (a minimal sketch of this loss appears at the end of this listing). 17 | 18 | # Setup 19 | ## Prerequisite 20 | Python 3.7, CUDA 11.2, cuDNN 21 | 22 | Install the additional Python libraries with: 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | ## Download datasets 28 | Download the compressed dataset files from the following link: 29 | 30 | https://drive.google.com/drive/folders/1_uXjx06U48to9OSyGY1ezqipAbbuT0vq?usp=sharing 31 | 32 | The following is an example script that downloads and decompresses the files: 33 | ``` 34 | # If gdown is not installed: 35 | pip install gdown 36 | 37 | export DATADIR_ROOT=/tmp/path/to/datadir/ 38 | cd /tmp 39 | gdown --folder https://drive.google.com/drive/folders/1_uXjx06U48to9OSyGY1ezqipAbbuT0vq?usp=sharing 40 | mv datasets/* ${DATADIR_ROOT}; rm datasets -r 41 | cd - 42 | tar xzf ${DATADIR_ROOT}/MNIST.tar.gz -C $DATADIR_ROOT 43 | tar xzf ${DATADIR_ROOT}/3dshapes.tar.gz -C $DATADIR_ROOT 44 | tar xzf ${DATADIR_ROOT}/smallNORB.tar.gz -C $DATADIR_ROOT 45 | ``` 46 | 47 | ## Training with MSP 48 | 1. Select the dataset on which you want to train the model: 49 | ``` 50 | # Sequential MNIST 51 | export CONFIG=configs/mnist/lstsq/lstsq.yml 52 | # Sequential MNIST-bg with digit 4 53 | export CONFIG=configs/mnist_bg/lstsq/lstsq.yml 54 | # 3DShapes 55 | export CONFIG=configs/3dshapes/lstsq/lstsq.yml 56 | # SmallNORB 57 | export CONFIG=configs/smallNORB/lstsq/lstsq.yml 58 | # Accelerated sequential MNIST 59 | export CONFIG=configs/mnist_accl/lstsq/holstsq.yml 60 | ``` 61 | 62 | 2. 
Run: 63 | ``` 64 | export LOGDIR=/tmp/path/to/logdir 65 | export DATADIR_ROOT=/tmp/path/to/datadir 66 | python run.py --config_path=$CONFIG --log_dir=$LOGDIR --attr train_data.args.root=$DATADIR_ROOT 67 | ``` 68 | 69 | ## Training all the methods we tested in our NeurIPS paper 70 | ``` 71 | export LOGDIR=/tmp/path/to/logdir 72 | export DATADIR_ROOT=/tmp/path/to/datadir 73 | bash training_allmodels.sh $LOGDIR $DATADIR_ROOT 74 | ``` 75 | 76 | ## Evaluations 77 | - Generated images: `gen_images.ipynb` 78 | - Equivariance errors: `equivariance_error.ipynb` 79 | - Prediction errors: `extrp.ipynb` 80 | - Simultaneous block diagonalization: `block_diagonalization.ipynb` 81 | -------------------------------------------------------------------------------- /configs/3dshapes/lstsq/lstsq.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/three_dim_shapes.py 14 | name: ThreeDimShapesDataset 15 | args: 16 | root: /tmp/datasets/3dshapes/ 17 | train: True 18 | T: 3 19 | 20 | model: 21 | fn: ./models/seqae.py 22 | name: SeqAELSTSQ 23 | args: 24 | dim_m: 256 25 | dim_a: 16 26 | ch_x: 3 27 | k: 1.0 28 | bottom_width: 8 29 | predictive: True 30 | 31 | training_loop: 32 | fn: ./training_loops.py 33 | name: loop_seqmodel 34 | args: 35 | lr_decay_iter: 40000 36 | reconst_iter: 0 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /configs/3dshapes/lstsq/lstsq_multi.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | optimizer: Adam 12 | 13 | train_data: 14 | fn: ./datasets/three_dim_shapes.py 15 | name: ThreeDimShapesDataset 16 | args: 17 | root: /tmp/datasets/3dshapes/ 18 | train: True 19 | T: 3 20 | 21 | model: 22 | fn: ./models/seqae 23 | name: SeqAEMultiLSTSQ 24 | args: 25 | dim_m: 256 26 | dim_a: 16 27 | ch_x: 3 28 | k: 1.0 29 | bottom_width: 8 30 | alignment: False 31 | change_of_basis: False 32 | global_average_pooling: False # Important. (Why?)
33 | discretized: False 34 | K: 8 35 | 36 | 37 | training_loop: 38 | fn: ./training_loops.py 39 | name: loop_seqmodel 40 | args: 41 | 42 | lr_decay_iter: 40000 43 | reconst_iter: 0 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /configs/3dshapes/lstsq/lstsq_rec.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 3 9 | T_pred: 0 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/three_dim_shapes.py 14 | name: ThreeDimShapesDataset 15 | args: 16 | root: /tmp/datasets/3dshapes/ 17 | train: True 18 | T: 3 19 | 20 | model: 21 | fn: ./models/seqae.py 22 | name: SeqAELSTSQ 23 | args: 24 | dim_m: 256 25 | dim_a: 16 26 | ch_x: 3 27 | k: 1.0 28 | bottom_width: 8 29 | predictive: False 30 | 31 | training_loop: 32 | fn: ./training_loops.py 33 | name: loop_seqmodel 34 | args: 35 | lr_decay_iter: 40000 36 | reconst_iter: 0 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /configs/3dshapes/lstsq/neuralM.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/three_dim_shapes.py 14 | name: ThreeDimShapesDataset 15 | args: 16 | root: /tmp/datasets/3dshapes/ 17 | train: True 18 | T: 3 19 | 20 | model: 21 | fn: ./models/seqae.py 22 | name: SeqAENeuralM 23 | args: 24 | dim_m: 256 25 | dim_a: 16 26 | ch_x: 3 27 | k: 1.0 28 | bottom_width: 8 29 | 30 | 31 | training_loop: 32 | fn: ./training_loops.py 33 | name: loop_seqmodel 34 | args: 35 | lr_decay_iter: 40000 36 | reconst_iter: 50000 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/3dshapes/lstsq/neural_trans.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/three_dim_shapes.py 14 | name: ThreeDimShapesDataset 15 | args: 16 | root: /tmp/datasets/3dshapes/ 17 | train: True 18 | T: 3 19 | 20 | model: 21 | fn: ./models/seqae.py 22 | name: SeqAENeuralTransition 23 | args: 24 | dim_m: 256 25 | dim_a: 16 26 | ch_x: 3 27 | k: 1.0 28 | bottom_width: 8 29 | 30 | 31 | training_loop: 32 | fn: ./training_loops.py 33 | name: loop_seqmodel 34 | args: 35 | lr_decay_iter: 40000 36 | reconst_iter: 50000 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/mnist/lstsq/lstsq.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | 
max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: False 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAELSTSQ 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 2. 33 | predictive: True 34 | 35 | training_loop: 36 | fn: ./training_loops.py 37 | name: loop_seqmodel 38 | args: 39 | lr_decay_iter: 40000 40 | reconst_iter: 0 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /configs/mnist/lstsq/lstsq_multi.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: False 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAEMultiLSTSQ 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 2. 33 | predictive: True 34 | K: 8 35 | 36 | 37 | 38 | training_loop: 39 | fn: ./training_loops.py 40 | name: loop_seqmodel 41 | args: 42 | lr_decay_iter: 40000 43 | reconst_iter: 0 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /configs/mnist/lstsq/lstsq_rec.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 3 9 | T_pred: 0 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: False 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAELSTSQ 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 2. 33 | predictive: False 34 | 35 | training_loop: 36 | fn: ./training_loops.py 37 | name: loop_seqmodel 38 | args: 39 | lr_decay_iter: 40000 40 | reconst_iter: 0 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /configs/mnist/lstsq/neuralM.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | 24 | model: 25 | fn: ./models/seqae.py 26 | name: SeqAENeuralM 27 | args: 28 | dim_m: 256 29 | dim_a: 16 30 | ch_x: 3 31 | k: 2. 
32 | 33 | 34 | 35 | training_loop: 36 | fn: ./training_loops.py 37 | name: loop_seqmodel 38 | args: 39 | lr_decay_iter: 40000 40 | reconst_iter: 50000 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /configs/mnist/lstsq/neuralM_latentpred.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | 24 | model: 25 | fn: ./models/seqae.py 26 | name: SeqAENeuralMLatentPredict 27 | args: 28 | dim_m: 256 29 | dim_a: 16 30 | ch_x: 3 31 | k: 2. 32 | loss_reconst_coeff: 0.0 33 | loss_pred_coeff: 1.0 34 | loss_latent_coeff: 0.0 35 | normalize: True 36 | 37 | 38 | 39 | 40 | training_loop: 41 | fn: ./training_loops.py 42 | name: loop_seqmodel 43 | args: 44 | lr_decay_iter: 40000 45 | reconst_iter: 50000 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /configs/mnist/lstsq/neural_trans.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | 24 | model: 25 | fn: ./models/seqae.py 26 | name: SeqAENeuralTransition 27 | args: 28 | dim_m: 256 29 | dim_a: 16 30 | ch_x: 3 31 | k: 2. 32 | 33 | training_loop: 34 | fn: ./training_loops.py 35 | name: loop_seqmodel 36 | args: 37 | lr_decay_iter: 40000 38 | reconst_iter: 50000 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/mnist/simclr/cpc.yml: -------------------------------------------------------------------------------- 1 | batchsize: 64 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 10000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: False 24 | 25 | model: 26 | fn: ./models/seqae 27 | name: CPC 28 | args: 29 | dim_m: 1 30 | dim_a: 512 31 | k: 2. 
32 | loss_type: cossim 33 | normalize: True 34 | temp: 0.01 35 | 36 | training_loop: 37 | fn: ./training_loops.py 38 | name: loop_seqmodel 39 | args: 40 | lr_decay_iter: 40000 41 | reconst_iter: 0 42 | -------------------------------------------------------------------------------- /configs/mnist/simclr/simclr.yml: -------------------------------------------------------------------------------- 1 | batchsize: 64 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 10000 7 | num_workers: 6 8 | lr: 0.0003 9 | 10 | train_data: 11 | fn: ./datasets/seq_mnist.py 12 | name: SequentialMNIST 13 | args: 14 | root: /tmp/datasets/MNIST 15 | train: True 16 | T: 3 17 | max_T: 9 18 | max_angle_velocity_ratio: [-0.5, 0.5] 19 | max_color_velocity_ratio: [-0.5, 0.5] 20 | only_use_digit4: True 21 | backgrnd: False 22 | 23 | model: 24 | fn: ./models/simclr_models.py 25 | name: ResNetwProjHead 26 | args: 27 | dim_mlp: 512 28 | dim_head: 512 29 | k: 2 30 | 31 | 32 | training_loop: 33 | fn: ./training_loops.py 34 | name: loop_simclr 35 | args: 36 | lr_decay_iter: 40000 37 | temp: 0.01 38 | loss_type: cossim -------------------------------------------------------------------------------- /configs/mnist_accl/lstsq/holstsq.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 5 9 | T_pred: 5 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 10 19 | max_T: 10 20 | max_angle_velocity_ratio: [-0.2, 0.2] 21 | max_angle_accl_ratio: [-0.025, 0.025] 22 | max_color_velocity_ratio: [-0.2, 0.2] 23 | max_color_accl_ratio: [-0.025, 0.025] 24 | max_pos: [0., 0.] 25 | max_trans_accl: [0., 0.] 26 | only_use_digit4: True 27 | 28 | model: 29 | fn: ./models/seqae 30 | name: SeqAEHOLSTSQ 31 | args: 32 | dim_m: 256 33 | dim_a: 16 34 | ch_x: 3 35 | k: 2. 36 | n_order: 2 37 | predictive: True 38 | kernel_size: 3 39 | 40 | training_loop: 41 | fn: ./training_loops.py 42 | name: loop_seqmodel 43 | args: 44 | lr_decay_iter: 80000 45 | reconst_iter: 0 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /configs/mnist_accl/lstsq/lstsq.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 5 9 | T_pred: 5 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 10 19 | max_T: 10 20 | max_angle_velocity_ratio: [-0.2, 0.2] 21 | max_angle_accl_ratio: [-0.025, 0.025] 22 | max_color_velocity_ratio: [-0.2, 0.2] 23 | max_color_accl_ratio: [-0.025, 0.025] 24 | max_pos: [0., 0.] 25 | max_trans_accl: [0., 0.] 26 | only_use_digit4: True 27 | 28 | model: 29 | fn: ./models/seqae 30 | name: SeqAELSTSQ 31 | args: 32 | dim_m: 256 33 | dim_a: 16 34 | ch_x: 3 35 | k: 2. 
36 | predictive: True 37 | kernel_size: 3 38 | 39 | training_loop: 40 | fn: ./training_loops.py 41 | name: loop_seqmodel 42 | args: 43 | lr_decay_iter: 80000 44 | reconst_iter: 0 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /configs/mnist_accl/lstsq/neural_trans.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 5 9 | T_pred: 5 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 10 19 | max_T: 10 20 | max_angle_velocity_ratio: [-0.2, 0.2] 21 | max_angle_accl_ratio: [-0.025, 0.025] 22 | max_color_velocity_ratio: [-0.2, 0.2] 23 | max_color_accl_ratio: [-0.025, 0.025] 24 | max_pos: [0., 0.] 25 | max_trans_accl: [0., 0.] 26 | only_use_digit4: True 27 | 28 | model: 29 | fn: ./models/seqae.py 30 | name: SeqAENeuralTransition 31 | args: 32 | dim_m: 256 33 | dim_a: 16 34 | ch_x: 3 35 | k: 2. 36 | T_cond: 5 37 | predictive: True 38 | kernel_size: 3 39 | 40 | 41 | training_loop: 42 | fn: ./training_loops.py 43 | name: loop_seqmodel 44 | args: 45 | lr_decay_iter: 80000 46 | reconst_iter: 100000 -------------------------------------------------------------------------------- /configs/mnist_bg/lstsq/lstsq.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: True 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAELSTSQ 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 4. 33 | predictive: True 34 | 35 | training_loop: 36 | fn: ./training_loops.py 37 | name: loop_seqmodel 38 | args: 39 | lr_decay_iter: 80000 40 | reconst_iter: 0 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /configs/mnist_bg/lstsq/lstsq_multi.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: True 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAEMultiLSTSQ 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 4. 
33 | predictive: True 34 | K: 8 35 | 36 | 37 | 38 | training_loop: 39 | fn: ./training_loops.py 40 | name: loop_seqmodel 41 | args: 42 | lr_decay_iter: 80000 43 | reconst_iter: 0 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /configs/mnist_bg/lstsq/lstsq_rec.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 3 9 | T_pred: 0 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: True 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAELSTSQ 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 4. 33 | predictive: False 34 | 35 | training_loop: 36 | fn: ./training_loops.py 37 | name: loop_seqmodel 38 | args: 39 | lr_decay_iter: 80000 40 | reconst_iter: 0 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /configs/mnist_bg/lstsq/neuralM.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: True 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAENeuralM 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 4. 33 | 34 | 35 | 36 | training_loop: 37 | fn: ./training_loops.py 38 | name: loop_seqmodel 39 | args: 40 | lr_decay_iter: 80000 41 | reconst_iter: 100000 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /configs/mnist_bg/lstsq/neuralM_latentpred.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: True 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAENeuralMLatentPredict 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 4. 
33 | loss_reconst_coeff: 0.0 34 | loss_pred_coeff: 1.0 35 | loss_latent_coeff: 0.0 36 | normalize: True 37 | 38 | 39 | training_loop: 40 | fn: ./training_loops.py 41 | name: loop_seqmodel 42 | args: 43 | lr_decay_iter: 80000 44 | reconst_iter: 100000 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /configs/mnist_bg/lstsq/neural_trans.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: True 24 | 25 | model: 26 | fn: ./models/seqae.py 27 | name: SeqAENeuralTransition 28 | args: 29 | dim_m: 256 30 | dim_a: 16 31 | ch_x: 3 32 | k: 4. 33 | 34 | training_loop: 35 | fn: ./training_loops.py 36 | name: loop_seqmodel 37 | args: 38 | lr_decay_iter: 80000 39 | reconst_iter: 100000 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/mnist_bg/simclr/cpc.yml: -------------------------------------------------------------------------------- 1 | batchsize: 64 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 10000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/seq_mnist.py 14 | name: SequentialMNIST 15 | args: 16 | root: /tmp/datasets/MNIST 17 | train: True 18 | T: 3 19 | max_T: 9 20 | max_angle_velocity_ratio: [-0.5, 0.5] 21 | max_color_velocity_ratio: [-0.5, 0.5] 22 | only_use_digit4: True 23 | backgrnd: False 24 | 25 | model: 26 | fn: ./models/seqae 27 | name: CPC 28 | args: 29 | dim_m: 1 30 | dim_a: 512 31 | k: 4. 
32 | loss_type: cossim 33 | normalize: True 34 | temp: 0.01 35 | 36 | training_loop: 37 | fn: ./training_loops.py 38 | name: loop_seqmodel 39 | args: 40 | lr_decay_iter: 40000 41 | reconst_iter: 0 42 | -------------------------------------------------------------------------------- /configs/mnist_bg/simclr/simclr.yml: -------------------------------------------------------------------------------- 1 | batchsize: 64 2 | seed: 1 3 | max_iteration: 100000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 10000 7 | num_workers: 6 8 | lr: 0.0003 9 | 10 | train_data: 11 | fn: ./datasets/seq_mnist.py 12 | name: SequentialMNIST 13 | args: 14 | root: /tmp/datasets/MNIST 15 | train: True 16 | T: 3 17 | max_T: 9 18 | max_angle_velocity_ratio: [-0.5, 0.5] 19 | max_color_velocity_ratio: [-0.5, 0.5] 20 | only_use_digit4: True 21 | backgrnd: False 22 | 23 | model: 24 | fn: ./models/simclr_models.py 25 | name: ResNetwProjHead 26 | args: 27 | dim_mlp: 512 28 | dim_head: 512 29 | k: 4 30 | 31 | 32 | training_loop: 33 | fn: ./training_loops.py 34 | name: loop_simclr 35 | args: 36 | lr_decay_iter: 80000 37 | temp: 0.01 38 | loss_type: cossim -------------------------------------------------------------------------------- /configs/smallNORB/lstsq/lstsq.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/small_norb.py 14 | name: SmallNORBDataset 15 | args: 16 | root: /tmp/datasets 17 | train: True 18 | T: 3 19 | 20 | model: 21 | fn: ./models/seqae.py 22 | name: SeqAELSTSQ 23 | args: 24 | dim_m: 256 25 | dim_a: 16 26 | ch_x: 1 27 | k: 1.0 28 | bottom_width: 6 29 | n_blocks: 4 30 | predictive: True 31 | 32 | 33 | training_loop: 34 | fn: ./training_loops.py 35 | name: loop_seqmodel 36 | args: 37 | lr_decay_iter: 40000 38 | reconst_iter: 0 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/smallNORB/lstsq/lstsq_multi.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 2 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | 13 | train_data: 14 | fn: ./datasets/small_norb.py 15 | name: SmallNORBDataset 16 | args: 17 | root: /tmp/datasets 18 | train: True 19 | T: 3 20 | 21 | model: 22 | fn: ./models/seqae 23 | name: SeqAEMultiLSTSQ 24 | args: 25 | dim_m: 256 26 | dim_a: 16 27 | ch_x: 1 28 | k: 1.0 29 | bottom_width: 6 30 | n_blocks: 4 31 | K: 8 32 | 33 | training_loop: 34 | fn: ./training_loops.py 35 | name: loop_seqmodel 36 | args: 37 | lr_decay_iter: 40000 38 | reconst_iter: 0 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/smallNORB/lstsq/lstsq_rec.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 3 9 | T_pred: 0 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/small_norb.py 14 | name: SmallNORBDataset 15 | args: 16 | root: /tmp/datasets 17 | train: True 18 | T: 3 19 | 20 | model: 21 | 
fn: ./models/seqae.py 22 | name: SeqAELSTSQ 23 | args: 24 | dim_m: 256 25 | dim_a: 16 26 | ch_x: 1 27 | k: 1.0 28 | bottom_width: 6 29 | n_blocks: 4 30 | predictive: False 31 | 32 | 33 | training_loop: 34 | fn: ./training_loops.py 35 | name: loop_seqmodel 36 | args: 37 | lr_decay_iter: 40000 38 | reconst_iter: 0 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/smallNORB/lstsq/neuralM.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | train_data: 13 | fn: ./datasets/small_norb.py 14 | name: SmallNORBDataset 15 | args: 16 | root: /tmp/datasets 17 | train: True 18 | T: 3 19 | 20 | model: 21 | fn: ./models/seqae.py 22 | name: SeqAENeuralM 23 | args: 24 | dim_m: 256 25 | dim_a: 16 26 | ch_x: 1 27 | k: 1.0 28 | n_blocks: 4 29 | bottom_width: 6 30 | 31 | training_loop: 32 | fn: ./training_loops.py 33 | name: loop_seqmodel 34 | args: 35 | lr_decay_iter: 40000 36 | reconst_iter: 50000 37 | 38 | 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /configs/smallNORB/lstsq/neural_trans.yml: -------------------------------------------------------------------------------- 1 | batchsize: 32 2 | seed: 1 3 | max_iteration: 50000 4 | report_freq: 1000 5 | model_snapshot_freq: 10000 6 | manager_snapshot_freq: 50000 7 | num_workers: 6 8 | T_cond: 2 9 | T_pred: 1 10 | lr: 0.0003 11 | 12 | 13 | train_data: 14 | fn: ./datasets/small_norb.py 15 | name: SmallNORBDataset 16 | args: 17 | root: /tmp/datasets 18 | train: True 19 | T: 3 20 | 21 | model: 22 | fn: ./models/seqae.py 23 | name: SeqAENeuralTransition 24 | args: 25 | dim_m: 256 26 | dim_a: 16 27 | ch_x: 1 28 | k: 1.0 29 | n_blocks: 4 30 | bottom_width: 6 31 | 32 | training_loop: 33 | fn: ./training_loops.py 34 | name: loop_seqmodel 35 | args: 36 | lr_decay_iter: 40000 37 | reconst_iter: 50000 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/seq_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torchvision 6 | import math 7 | import colorsys 8 | from skimage.transform import resize 9 | from copy import deepcopy 10 | from utils.misc import get_RTmat 11 | from utils.misc import freq_to_wave 12 | 13 | 14 | class SequentialMNIST(): 15 | # Rotate around z axis only. 
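# Each __getitem__ call renders one digit as a length-T sequence in which the
# rotation angle, hue, and position each follow a constant-velocity (optionally
# accelerated) trajectory; active_actions selects which of the three factors
# [0: angle, 1: color, 2: position] are actually animated.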
16 | 17 | default_active_actions = [0, 1, 2] 18 | 19 | def __init__( 20 | self, 21 | root, 22 | train=True, 23 | transforms=torchvision.transforms.ToTensor(), 24 | T=3, 25 | max_angle_velocity_ratio=[-0.5, 0.5], 26 | max_angle_accl_ratio=[-0.0, 0.0], 27 | max_color_velocity_ratio=[-0.5, 0.5], 28 | max_color_accl_ratio=[-0.0, 0.0], 29 | max_pos=[-10, 10], 30 | max_trans_accl=[-0.0, 0.0], 31 | label=False, 32 | label_velo=False, 33 | label_accl=False, 34 | active_actions=None, 35 | max_T=9, 36 | only_use_digit4=False, 37 | backgrnd=False, 38 | shared_transition=False, 39 | color_off=False, 40 | rng=None 41 | ): 42 | self.T = T 43 | self.max_T = max_T 44 | self.rng = rng if rng is not None else np.random 45 | self.transforms = transforms 46 | self.data = torchvision.datasets.MNIST(root, train, download=True) 47 | self.angle_velocity_range = (-max_angle_velocity_ratio, max_angle_velocity_ratio) if isinstance( 48 | max_angle_velocity_ratio, (int, float)) else max_angle_velocity_ratio 49 | self.color_velocity_range = (-max_color_velocity_ratio, max_color_velocity_ratio) if isinstance( 50 | max_color_velocity_ratio, (int, float)) else max_color_velocity_ratio 51 | self.angle_accl_range = (-max_angle_accl_ratio, max_angle_accl_ratio) if isinstance( 52 | max_angle_accl_ratio, (int, float)) else max_angle_accl_ratio 53 | self.color_accl_range = (-max_color_accl_ratio, max_color_accl_ratio) if isinstance( 54 | max_color_accl_ratio, (int, float)) else max_color_accl_ratio 55 | self.color_off = color_off 56 | 57 | self.max_pos = max_pos 58 | self.max_trans_accl = max_trans_accl 59 | self.label = label 60 | self.label_velo = label_velo 61 | self.label_accl = label_accl 62 | self.active_actions = self.default_active_actions if active_actions is None else active_actions 63 | if backgrnd: 64 | print(""" 65 | ============= 66 | background ON 67 | ============= 68 | """) 69 | fname = "MNIST/train_dat.pt" if train else "MNIST/test_dat.pt" 70 | self.backgrnd_data = torch.load(os.path.join(root, fname)) 71 | 72 | if only_use_digit4: 73 | datas = [] 74 | for pair in self.data: 75 | if pair[1] == 4: 76 | datas.append(pair) 77 | self.data = datas 78 | self.shared_transition = shared_transition 79 | if self.shared_transition: 80 | self.init_shared_transition_parameters() 81 | 82 | def init_shared_transition_parameters(self): 83 | self.angles_v = self.rng.uniform(math.pi * self.angle_velocity_range[0], 84 | math.pi * self.angle_velocity_range[1], size=1) 85 | self.angles_a = self.rng.uniform(math.pi * self.angle_accl_range[0], 86 | math.pi * self.angle_accl_range[1], size=1) 87 | self.color_v = 0.5 * self.rng.uniform(self.color_velocity_range[0], 88 | self.color_velocity_range[1], size=1) 89 | self.color_a = 0.5 * \ 90 | self.rng.uniform( 91 | self.color_accl_range[0], self.color_accl_range[1], size=1) 92 | pos0 = self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2]) 93 | pos1 = self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2]) 94 | self.pos_v = (pos1-pos0)/(self.max_T - 1) 95 | self.pos_a = self.rng.uniform( 96 | self.max_trans_accl[0], self.max_trans_accl[1], size=[2]) 97 | 98 | def __len__(self): 99 | return len(self.data) 100 | 101 | def __getitem__(self, i): 102 | image = np.array(self.data[i][0], np.float32).reshape(28, 28) 103 | image = resize(image, [24, 24]) 104 | image = cv2.copyMakeBorder( 105 | image, 4, 4, 4, 4, cv2.BORDER_CONSTANT, value=(0, 0, 0)) 106 | angles_0 = self.rng.uniform(0, 2 * math.pi, size=1) 107 | color_0 = self.rng.uniform(0, 1, size=1) 108 | pos0 = 
self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2]) 109 | pos1 = self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2]) 110 | if self.shared_transition: 111 | (angles_v, angles_a) = (self.angles_v, self.angles_a) 112 | (color_v, color_a) = (self.color_v, self.color_a) 113 | (pos_v, pos_a) = (self.pos_v, self.pos_a) 114 | else: 115 | angles_v = self.rng.uniform(math.pi * self.angle_velocity_range[0], 116 | math.pi * self.angle_velocity_range[1], size=1) 117 | angles_a = self.rng.uniform(math.pi * self.angle_accl_range[0], 118 | math.pi * self.angle_accl_range[1], size=1) 119 | color_v = 0.5 * self.rng.uniform(self.color_velocity_range[0], 120 | self.color_velocity_range[1], size=1) 121 | color_a = 0.5 * \ 122 | self.rng.uniform( 123 | self.color_accl_range[0], self.color_accl_range[1], size=1) 124 | pos_v = (pos1-pos0)/(self.max_T - 1) 125 | pos_a = self.rng.uniform( 126 | self.max_trans_accl[0], self.max_trans_accl[1], size=[2]) 127 | images = [] 128 | for t in range(self.T): 129 | angles_t = (0.5 * angles_a * t**2 + angles_v * t + 130 | angles_0) if 0 in self.active_actions else angles_0 131 | color_t = ((0.5 * color_a * t**2 + t * color_v + color_0) % 132 | 1) if 1 in self.active_actions else color_0 133 | pos_t = (0.5 * pos_a * t**2 + pos_v * t + 134 | pos0) if 2 in self.active_actions else pos0 135 | mat = get_RTmat(0, 0, float(angles_t), 32, 32, pos_t[0], pos_t[1]) 136 | _image = cv2.warpPerspective(image.copy(), mat, (32, 32)) 137 | 138 | rgb = np.asarray(colorsys.hsv_to_rgb( 139 | color_t, 1, 1), dtype=np.float32) 140 | _image = np.concatenate( 141 | [_image[:, :, None]] * 3, axis=-1) * rgb[None, None] 142 | _image = _image / 255. 143 | 144 | if hasattr(self, 'backgrnd_data'): 145 | _imagemask = (np.sum(_image, axis=2, keepdims=True) < 3e-1) 146 | _image = torch.tensor( 147 | _image) + self.backgrnd_data[i].permute([1, 2, 0]) * (_imagemask) 148 | _image = np.array(torch.clip(_image, max=1.)) 149 | 150 | images.append(self.transforms(_image.astype(np.float32))) 151 | 152 | if self.label or self.label_velo: 153 | ret = [images] 154 | if self.label: 155 | ret += [self.data[i][1]] 156 | if self.label_velo: 157 | ret += [ 158 | freq_to_wave(angles_v.astype(np.float32)), 159 | freq_to_wave((2 * math.pi * color_v).astype(np.float32)), 160 | pos_v.astype(np.float32) 161 | ] 162 | if self.label_accl: 163 | ret += [ 164 | freq_to_wave(angles_a.astype(np.float32)), 165 | freq_to_wave((2 * math.pi * color_a).astype(np.float32)), 166 | pos_a.astype(np.float32) 167 | ] 168 | return ret 169 | else: 170 | return images 171 | -------------------------------------------------------------------------------- /datasets/small_norb.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from collections import OrderedDict 5 | import os 6 | import numpy as np 7 | 8 | 9 | _FACTORS_IN_ORDER = ['category', 'instance', 10 | 'lighting', 'elevation', 'azimuth'] 11 | _ELEV_V = [30, 35, 40, 45, 50, 55, 60, 65, 70] 12 | _AZIM_V = np.arange(0, 350, 20) 13 | assert len(_AZIM_V) == 18 14 | _NUM_VALUES_PER_FACTOR = OrderedDict( 15 | {'category': 5, 'instance': 5, 'lighting': 6, 'elevation': 9, 'azimuth': 18}) 16 | 17 | 18 | def get_index(factors): 19 | """ Converts factors to indices in range(num_data) 20 | Args: 21 | factors: np array shape [6,batch_size]. 22 | factors[i]=factors[i,:] takes integer values in 23 | range(_NUM_VALUES_PER_FACTOR[_FACTORS_IN_ORDER[i]]). 
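(Note: this file defines five factors (category, instance, lighting, elevation, azimuth), so `factors` here has shape [5, batch_size]; the `[6, batch_size]` shape above was carried over from the 3dshapes version of this docstring.)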
24 | 25 | Returns: 26 | indices: np array shape [batch_size]. 27 | """ 28 | indices = 0 29 | base = 1 30 | for factor, name in reversed(list(enumerate(_FACTORS_IN_ORDER))): 31 | indices += factors[factor] * base 32 | base *= _NUM_VALUES_PER_FACTOR[name] 33 | return indices 34 | 35 | 36 | class SmallNORBDataset(object): 37 | 38 | default_active_actions = [3,4] 39 | 40 | def __init__(self, root, 41 | train=True, 42 | T=3, 43 | label=False, 44 | label_velo=False, 45 | force_moving=False, 46 | active_actions=None, 47 | transforms=torchvision.transforms.ToTensor(), 48 | shared_transition=False, 49 | rng=None): 50 | assert T <= 6 51 | self.data = torch.load(os.path.join( 52 | root, 'smallNORB/train.pt' if train else 'smallNORB/test.pt')) 53 | self.label = label 54 | self.label_velo = label_velo 55 | print(self.data.shape) 56 | self.T = T 57 | self.transforms = transforms 58 | self.active_actions = self.default_active_actions if active_actions is None else active_actions 59 | self.force_moving = force_moving 60 | self.rng = rng if rng is not None else np.random 61 | self.shared_transition = shared_transition 62 | if self.shared_transition: 63 | self.init_shared_transition_parameters() 64 | 65 | def init_shared_transition_parameters(self): 66 | self.vs = {} 67 | for kv in _NUM_VALUES_PER_FACTOR.items(): 68 | key, value = kv[0], kv[1] 69 | self.vs[key] = self.gen_v(value) 70 | 71 | def __len__(self): 72 | return 5000 73 | 74 | def gen_pos(self, max_n, v): 75 | _x = np.abs(v) * (self.T-1) 76 | if v < 0: 77 | return self.rng.randint(_x, max_n) 78 | else: 79 | return self.rng.randint(0, max_n-_x) 80 | 81 | def gen_v(self, max_n): 82 | v = self.rng.randint(1 if self.force_moving else 0, max_n//self.T + 1) 83 | if self.rng.uniform() > 0.5: 84 | v = -v 85 | return v 86 | 87 | def gen_factors(self): 88 | # initial state 89 | p_and_v_list = [] 90 | sampled_indices = [] 91 | for action_index, kv in enumerate(_NUM_VALUES_PER_FACTOR.items()): 92 | key, value = kv[0], kv[1] 93 | if key == 'category' or key == 'instance' or key == 'lighting': 94 | p_and_v_list.append([0, 0]) 95 | index = self.rng.randint(0, _NUM_VALUES_PER_FACTOR[key]) 96 | sampled_indices.append([index]*self.T) 97 | else: 98 | if not(action_index in self.active_actions): 99 | v = 0 100 | else: 101 | if self.shared_transition: 102 | v = self.vs[key] 103 | else: 104 | v = self.gen_v(value) 105 | p = self.gen_pos(value, v) 106 | p_and_v_list.append((p, v)) 107 | indices = [p + t * v for t in range(self.T)] 108 | sampled_indices.append(indices) 109 | #print(p_and_v_list) 110 | return np.array(p_and_v_list, dtype=np.uint8), np.array(sampled_indices, dtype=np.uint8).T 111 | 112 | def __getitem__(self, i): 113 | p_and_v_list, sample_indices = self.gen_factors() 114 | imgs = [] 115 | for t in range(self.T): 116 | ind = sample_indices[t] 117 | img = self.data[ind[0], ind[1], ind[2], ind[3], ind[4]] 118 | img = img/255. 
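# pixel values are stored in 0-255; dividing by 255 rescales each frame to [0, 1]
# before self.transforms (ToTensor by default) lays it out channel-first.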
119 | img = self.transforms(img[:, :, None]) 120 | imgs.append(img) 121 | if self.T == 1: 122 | imgs = imgs[0] 123 | 124 | if self.label or self.label_velo: 125 | ret = [imgs] 126 | if self.label: 127 | ret += [sample_indices[0][0]] 128 | if self.label_velo: 129 | ret += [p_and_v_list[3][1][None], p_and_v_list[4][1][None]] 130 | 131 | return ret 132 | else: 133 | return imgs 134 | -------------------------------------------------------------------------------- /datasets/three_dim_shapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from collections import OrderedDict 5 | import os 6 | 7 | 8 | _FACTORS_IN_ORDER = ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 9 | 'orientation'] 10 | _NUM_VALUES_PER_FACTOR = OrderedDict({'floor_hue': 10, 'wall_hue': 10, 'object_hue': 10, 11 | 'scale': 8, 'shape': 4, 'orientation': 15}) 12 | 13 | 14 | def get_index(factors): 15 | """ Converts factors to indices in range(num_data) 16 | Args: 17 | factors: np array shape [6,batch_size]. 18 | factors[i]=factors[i,:] takes integer values in 19 | range(_NUM_VALUES_PER_FACTOR[_FACTORS_IN_ORDER[i]]). 20 | 21 | Returns: 22 | indices: np array shape [batch_size]. 23 | """ 24 | indices = 0 25 | base = 1 26 | for factor, name in reversed(list(enumerate(_FACTORS_IN_ORDER))): 27 | indices += factors[factor] * base 28 | base *= _NUM_VALUES_PER_FACTOR[name] 29 | return indices 30 | 31 | 32 | class ThreeDimShapesDataset(object): 33 | 34 | default_active_actions = [0,1,2,3,5] 35 | 36 | def __init__(self, root, train=True, T=3, label_velo=False, transforms=torchvision.transforms.ToTensor(), 37 | active_actions=None, force_moving=False, shared_transition=False, rng=None): 38 | assert T <= 8 39 | self.images = torch.load(os.path.join(root, '3dshapes/images.pt')).astype(np.float32) 40 | self.label_velo = label_velo 41 | self.train = train 42 | self.T = T 43 | self.transforms = transforms 44 | self.active_actions = self.default_active_actions if active_actions is None else active_actions 45 | self.force_moving = force_moving 46 | self.rng = rng if rng is not None else np.random 47 | self.shared_transition = shared_transition 48 | if self.shared_transition: 49 | self.init_shared_transition_parameters() 50 | 51 | def init_shared_transition_parameters(self): 52 | vs = {} 53 | for kv in _NUM_VALUES_PER_FACTOR.items(): 54 | key, value = kv[0], kv[1] 55 | vs[key] = self.gen_v(value) 56 | self.vs = vs 57 | 58 | def __len__(self): 59 | return 5000 60 | 61 | def gen_pos(self, max_n, v): 62 | _x = np.abs(v) * (self.T-1) 63 | if v < 0: 64 | return self.rng.randint(_x, max_n) 65 | else: 66 | return self.rng.randint(0, max_n-_x) 67 | 68 | def gen_v(self, max_n): 69 | v = self.rng.randint(1 if self.force_moving else 0, max_n//self.T + 1) 70 | if self.rng.uniform() > 0.5: 71 | v = -v 72 | return v 73 | 74 | 75 | def gen_factors(self): 76 | # initial state 77 | p_and_v_list = [] 78 | sampled_indices = [] 79 | for action_index, kv in enumerate(_NUM_VALUES_PER_FACTOR.items()): 80 | key, value = kv[0], kv[1] 81 | if key == 'shape': 82 | p_and_v_list.append([0, 0]) 83 | if self.train: 84 | shape = self.rng.choice([0]) 85 | else: 86 | shape = self.rng.choice([1,2,3]) 87 | sampled_indices.append([shape]*self.T) 88 | else: 89 | if not(action_index in self.active_actions): 90 | v = 0 91 | else: 92 | if self.shared_transition: 93 | v = self.vs[key] 94 | else: 95 | v = self.gen_v(value) 96 | p = self.gen_pos(value, v) 97 | 
p_and_v_list.append((p, v)) 98 | indices = [p + t * v for t in range(self.T)] 99 | sampled_indices.append(indices) 100 | return np.array(p_and_v_list, dtype=np.uint8), np.array(sampled_indices, dtype=np.uint8).T 101 | 102 | 103 | def __getitem__(self, i): 104 | p_and_v_list, sample_indices = self.gen_factors() 105 | imgs = [] 106 | for t in range(self.T): 107 | img = self.images[get_index(sample_indices[t])] / 255. 108 | img = self.transforms(img) 109 | imgs.append(img) 110 | if self.label_velo: 111 | return imgs, p_and_v_list[0][1][None], p_and_v_list[1][1][None], p_and_v_list[2][1][None], p_and_v_list[3][1][None], p_and_v_list[5][1][None] 112 | else: 113 | return imgs 114 | 115 | -------------------------------------------------------------------------------- /figs/comp_error_3dshapes.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_3dshapes.pdf -------------------------------------------------------------------------------- /figs/comp_error_mnist.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist.pdf -------------------------------------------------------------------------------- /figs/comp_error_mnist_accl.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist_accl.pdf -------------------------------------------------------------------------------- /figs/comp_error_mnist_bg.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist_bg.pdf -------------------------------------------------------------------------------- /figs/comp_error_mnist_bg_full.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist_bg_full.pdf -------------------------------------------------------------------------------- /figs/comp_error_smallNORB.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_smallNORB.pdf -------------------------------------------------------------------------------- /figs/disentangle_3dshapes.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_3dshapes.pdf -------------------------------------------------------------------------------- /figs/disentangle_mnist.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_mnist.pdf -------------------------------------------------------------------------------- /figs/disentangle_mnist_bg.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_mnist_bg.pdf -------------------------------------------------------------------------------- /figs/disentangle_mnist_bg_full.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_mnist_bg_full.pdf -------------------------------------------------------------------------------- /figs/disentangle_smallNORB.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_smallNORB.pdf -------------------------------------------------------------------------------- /figs/equiv_error_3dshapes.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_3dshapes.pdf -------------------------------------------------------------------------------- /figs/equiv_error_mnist.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist.pdf -------------------------------------------------------------------------------- /figs/equiv_error_mnist_accl.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist_accl.pdf -------------------------------------------------------------------------------- /figs/equiv_error_mnist_bg.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist_bg.pdf -------------------------------------------------------------------------------- /figs/equiv_error_mnist_bg_full.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist_bg_full.pdf -------------------------------------------------------------------------------- /figs/equiv_error_smallNORB.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_smallNORB.pdf -------------------------------------------------------------------------------- /figs/sbd_3dshapes.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes.pdf -------------------------------------------------------------------------------- /figs/sbd_3dshapes_component0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component0.pdf 
-------------------------------------------------------------------------------- /figs/sbd_3dshapes_component1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component1.pdf -------------------------------------------------------------------------------- /figs/sbd_3dshapes_component2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component2.pdf -------------------------------------------------------------------------------- /figs/sbd_3dshapes_component3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component3.pdf -------------------------------------------------------------------------------- /figs/sbd_3dshapes_component5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component5.pdf -------------------------------------------------------------------------------- /figs/sbd_3dshapes_components.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_components.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_bg.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_bg_components.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_components.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_bg_full.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_bg_full_component0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_component0.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_bg_full_component1.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_component1.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_bg_full_component2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_component2.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_bg_full_components.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_components.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_component0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_component0.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_component1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_component1.pdf -------------------------------------------------------------------------------- /figs/sbd_mnist_component2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_component2.pdf -------------------------------------------------------------------------------- /figs/sbd_smallNORB.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB.pdf -------------------------------------------------------------------------------- /figs/sbd_smallNORB_component0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component0.pdf -------------------------------------------------------------------------------- /figs/sbd_smallNORB_component1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component1.pdf -------------------------------------------------------------------------------- /figs/sbd_smallNORB_component2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component2.pdf -------------------------------------------------------------------------------- /figs/sbd_smallNORB_component3.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component3.pdf 
-------------------------------------------------------------------------------- /figs/sbd_smallNORB_component4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component4.pdf -------------------------------------------------------------------------------- /figs/sbd_smallNORB_component5.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component5.pdf -------------------------------------------------------------------------------- /figs/sbd_smallNORB_components.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_components.pdf -------------------------------------------------------------------------------- /gen_images/3dshapes_lstsq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_lstsq.png -------------------------------------------------------------------------------- /gen_images/3dshapes_lstsq_multi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_lstsq_multi.png -------------------------------------------------------------------------------- /gen_images/3dshapes_lstsq_rec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_lstsq_rec.png -------------------------------------------------------------------------------- /gen_images/3dshapes_neuralM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_neuralM.png -------------------------------------------------------------------------------- /gen_images/3dshapes_neural_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_neural_trans.png -------------------------------------------------------------------------------- /gen_images/mnist_accl_holstsq_equiv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_accl_holstsq_equiv.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_full_lstsq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_lstsq.png -------------------------------------------------------------------------------- 
/gen_images/mnist_bg_full_lstsq_multi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_lstsq_multi.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_full_lstsq_rec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_lstsq_rec.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_full_neuralM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_neuralM.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_full_neural_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_neural_trans.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_lstsq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_lstsq.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_lstsq_multi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_lstsq_multi.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_lstsq_rec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_lstsq_rec.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_neuralM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_neuralM.png -------------------------------------------------------------------------------- /gen_images/mnist_bg_neural_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_neural_trans.png -------------------------------------------------------------------------------- /gen_images/mnist_lstsq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_lstsq.png -------------------------------------------------------------------------------- /gen_images/mnist_lstsq_multi.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_lstsq_multi.png -------------------------------------------------------------------------------- /gen_images/mnist_lstsq_rec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_lstsq_rec.png -------------------------------------------------------------------------------- /gen_images/mnist_neuralM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_neuralM.png -------------------------------------------------------------------------------- /gen_images/mnist_neural_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_neural_trans.png -------------------------------------------------------------------------------- /gen_images/smallNORB_lstsq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_lstsq.png -------------------------------------------------------------------------------- /gen_images/smallNORB_lstsq_multi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_lstsq_multi.png -------------------------------------------------------------------------------- /gen_images/smallNORB_lstsq_rec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_lstsq_rec.png -------------------------------------------------------------------------------- /gen_images/smallNORB_neuralM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_neuralM.png -------------------------------------------------------------------------------- /gen_images/smallNORB_neural_trans.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_neural_trans.png -------------------------------------------------------------------------------- /jobs/09022022/training_allmodels_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT=$1 3 | DATASET_ROOT=$2 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in mnist mnist_bg 3dshapes smallNORB; do 7 | for model_name in lstsq lstsq_multi lstsq_rec neuralM neural_trans; do 8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 10 | --attr 
seed=${seed} train_data.args.root=${DATASET_ROOT} 11 | done 12 | done 13 | done 14 | 15 | for seed in 1 2 3; do 16 | for dataset_name in mnist_accl; do 17 | for model_name in lstsq holstsq neural_trans; do 18 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 19 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 20 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 21 | done 22 | done 23 | done 24 | 25 | for seed in 1 2 3; do 26 | for dataset_name in mnist mnist_bg; do 27 | for model_name in simclr cpc; do 28 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 29 | --config_path=./configs/${dataset_name}/simclr/${model_name}.yml \ 30 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 31 | done 32 | done 33 | done 34 | -------------------------------------------------------------------------------- /jobs/09022022/training_allmodels_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT=$1 3 | DATASET_ROOT=$2 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in mnist_bg; do 7 | for model_name in lstsq lstsq_multi lstsq_rec; do 8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \ 9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 11 | done 12 | for model_name in neuralM neural_trans; do 13 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \ 14 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 15 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 training_loop.args.reconst_iter=200000 16 | done 17 | done 18 | done 19 | 20 | 21 | for seed in 1 2 3; do 22 | for dataset_name in mnist_accl; do 23 | for model_name in lstsq holstsq neural_trans; do 24 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 25 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 26 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 27 | done 28 | done 29 | done 30 | 31 | 32 | for seed in 1 2 3; do 33 | for dataset_name in mnist mnist_bg; do 34 | for model_name in simclr cpc; do 35 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 36 | --config_path=./configs/${dataset_name}/simclr/${model_name}.yml \ 37 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /jobs/09152022/training_allmodels_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022' 3 | DATASET_ROOT='/home/TakeruMiyato/datasets' 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in mnist mnist_bg 3dshapes smallNORB; do 7 | for model_name in lstsq_multi lstsq_rec neuralM neural_trans; do 8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 11 | done 12 | done 13 | done 14 | 15 | for seed in 1 2 3; do 16 | for dataset_name in mnist_accl; do 17 | for
model_name in lstsq holstsq neural_trans; do 18 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 19 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 20 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 21 | done 22 | done 23 | done 24 | 25 | for seed in 1 2 3; do 26 | for dataset_name in mnist mnist_bg; do 27 | for model_name in simclr cpc; do 28 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 29 | --config_path=./configs/${dataset_name}/simclr/${model_name}.yml \ 30 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 31 | done 32 | done 33 | done 34 | -------------------------------------------------------------------------------- /jobs/09152022/training_allmodels_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022' 3 | DATASET_ROOT='/home/TakeruMiyato/datasets' 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in mnist_bg; do 7 | for model_name in lstsq lstsq_multi lstsq_rec; do 8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \ 9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 11 | done 12 | for model_name in neuralM neural_trans; do 13 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \ 14 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 15 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 training_loop.args.reconst_iter=200000 16 | done 17 | done 18 | done 19 | 20 | 21 | -------------------------------------------------------------------------------- /jobs/09152022/training_lstsqs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022' 3 | DATASET_ROOT='/home/TakeruMiyato/datasets' 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in mnist mnist_bg 3dshapes smallNORB; do 7 | for model_name in lstsq; do 8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 11 | done 12 | done 13 | done -------------------------------------------------------------------------------- /jobs/09152022/training_neurals.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022' 3 | DATASET_ROOT='/home/TakeruMiyato/datasets' 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in 3dshapes smallNORB; do 7 | for model_name in neuralM neural_trans; do 8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \ 9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} 11 | done 12 | done 13 | done -------------------------------------------------------------------------------- /jobs/09232022/training_neuralMlatentpred_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT='/mnt/research_logs/logs/09232022' 3 | 
DATASET_ROOT='/home/TakeruMiyato/datasets' 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in mnist; do 7 | for model_name in neuralM_latentpred; do 8 | for loss_latent_coeff in 0.0001 0.001 0.01; do 9 | for loss_reconst_coeff in 0.1 1.0 10; do 10 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-lrc${loss_reconst_coeff}-llc${loss_latent_coeff}-seed${seed}/ \ 11 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 12 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} \ 13 | model.args.loss_latent_coeff=${loss_latent_coeff} model.args.loss_reconst_coeff=${loss_reconst_coeff} 14 | done 15 | done 16 | done 17 | done 18 | done -------------------------------------------------------------------------------- /jobs/09232022/training_neuralMlatentpred_mnist_bg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | LOGDIR_ROOT='/mnt/research_logs/logs/09232022' 3 | DATASET_ROOT='/home/TakeruMiyato/datasets' 4 | 5 | for seed in 1 2 3; do 6 | for dataset_name in mnist_bg; do 7 | for model_name in neuralM_latentpred; do 8 | for loss_latent_coeff in 0.0001 0.001 0.01; do 9 | for loss_reconst_coeff in 0.1 1.0 10; do 10 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-lrc${loss_reconst_coeff}-llc${loss_latent_coeff}-seed${seed}/ \ 11 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \ 12 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} \ 13 | model.args.loss_latent_coeff=${loss_latent_coeff} model.args.loss_reconst_coeff=${loss_reconst_coeff} 14 | done 15 | done 16 | done 17 | done 18 | done -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/models/__init__.py -------------------------------------------------------------------------------- /models/base_networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from models.resblock import Block, Conv1d1x1Block 5 | from einops.layers.torch import Rearrange 6 | from einops import repeat 7 | 8 | 9 | class Conv1d1x1Encoder(nn.Sequential): 10 | def __init__(self, 11 | dim_out=16, 12 | dim_hidden=128, 13 | act=nn.ReLU()): 14 | super().__init__( 15 | nn.LazyConv1d(dim_hidden, 1, 1, 0), 16 | Conv1d1x1Block(dim_hidden, dim_hidden, act=act), 17 | Conv1d1x1Block(dim_hidden, dim_hidden, act=act), 18 | Rearrange('n c s -> n s c'), 19 | nn.LayerNorm((dim_hidden)), 20 | Rearrange('n s c-> n c s'), 21 | act, 22 | nn.LazyConv1d(dim_out, 1, 1, 0) 23 | ) 24 | 25 | 26 | class ResNetEncoder(nn.Module): 27 | def __init__(self, 28 | dim_latent=1024, 29 | k=1, 30 | act=nn.ReLU(), 31 | kernel_size=3, 32 | n_blocks=3): 33 | super().__init__() 34 | self.phi = nn.Sequential( 35 | nn.LazyConv2d(int(32 * k), 3, 1, 1), 36 | *[Block(int(32 * k) * (2 ** i), int(32 * k) * (2 ** (i+1)), int(32 * k) * (2 ** (i+1)), 37 | resample='down', activation=act, kernel_size=kernel_size) for i in range(n_blocks)], 38 | nn.GroupNorm(min(32, int(32 * k) * (2 ** n_blocks)), 39 | int(32 * k) * (2 ** n_blocks)), 40 | act) 41 | self.linear = nn.LazyLinear( 42 | dim_latent) if dim_latent > 0 else lambda x: x 43 | 44 | def __call__(self, x): 45 | h = x 46 | h = self.phi(h) 47 | h = h.reshape(h.shape[0], -1) 48 | h = 
self.linear(h) 49 | return h 50 | 51 | 52 | class ResNetDecoder(nn.Module): 53 | def __init__(self, ch_x, k=1, act=nn.ReLU(), kernel_size=3, bottom_width=4, n_blocks=3): 54 | super().__init__() 55 | self.bottom_width = bottom_width 56 | self.linear = nn.LazyLinear(int(32 * k) * (2 ** n_blocks)) 57 | self.net = nn.Sequential( 58 | *[Block(int(32 * k) * (2 ** (i+1)), int(32 * k) * (2 ** i), int(32 * k) * (2 ** i), 59 | resample='up', activation=act, kernel_size=kernel_size, posemb=True) for i in range(n_blocks-1, -1, -1)], 60 | nn.GroupNorm(min(32, int(32 * k)), int(32 * k)), 61 | act, 62 | nn.Conv2d(int(32 * k), ch_x, 3, 1, 1) 63 | ) 64 | 65 | def __call__(self, x): 66 | x = self.linear(x) 67 | x = repeat(x, 'n c -> n c h w', 68 | h=self.bottom_width, w=self.bottom_width) 69 | x = self.net(x) 70 | x = torch.sigmoid(x) 71 | return x 72 | -------------------------------------------------------------------------------- /models/dynamics_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from utils.laplacian import make_identity_like, tracenorm_of_normalized_laplacian, make_identity, make_diagonal 5 | import einops 6 | import pytorch_pfn_extras as ppe 7 | 8 | 9 | def _rep_M(M, T): 10 | return einops.repeat(M, "n a1 a2 -> n t a1 a2", t=T) 11 | 12 | 13 | def _loss(A, B): 14 | return torch.sum((A-B)**2) 15 | 16 | 17 | def _solve(A, B): 18 | ATA = A.transpose(-2, -1) @ A 19 | ATB = A.transpose(-2, -1) @ B 20 | return torch.linalg.solve(ATA, ATB) 21 | 22 | 23 | def loss_bd(M_star, alignment): 24 | # Block Diagonalization Loss 25 | S = torch.abs(M_star) 26 | STS = torch.matmul(S.transpose(-2, -1), S) 27 | if alignment: 28 | laploss_sts = tracenorm_of_normalized_laplacian( 29 | torch.mean(STS, 0)) 30 | else: 31 | laploss_sts = torch.mean( 32 | tracenorm_of_normalized_laplacian(STS), 0) 33 | return laploss_sts 34 | 35 | 36 | def loss_orth(M_star): 37 | # Orthogonalization of M 38 | I = make_identity_like(M_star) 39 | return torch.mean(torch.sum((I-M_star @ M_star.transpose(-2, -1))**2, axis=(-2, -1))) 40 | 41 | 42 | class LinearTensorDynamicsLSTSQ(nn.Module): 43 | 44 | class DynFn(nn.Module): 45 | def __init__(self, M): 46 | super().__init__() 47 | self.M = M 48 | 49 | def __call__(self, H): 50 | return H @ _rep_M(self.M, T=H.shape[1]) 51 | 52 | def inverse(self, H): 53 | M = _rep_M(self.M, T=H.shape[1]) 54 | return torch.linalg.solve(M, H.transpose(-2, -1)).transpose(-2, -1) 55 | 56 | def __init__(self, alignment=True): 57 | super().__init__() 58 | self.alignment = alignment 59 | 60 | def __call__(self, H, return_loss=False, fix_indices=None): 61 | # Regress M. 62 | # Note: backpropagation is disabled when fix_indices is not None. 63 | 64 | # H0.shape = H1.shape [n, t, s, a] 65 | H0, H1 = H[:, :-1], H[:, 1:] 66 | # num_ts x ([len_ts -1] * dim_s) x dim_a 67 | # The difference between the time-shifted components 68 | loss_internal_0 = _loss(H0, H1) 69 | ppe.reporting.report({ 70 | 'loss_internal_0': loss_internal_0.item() 71 | }) 72 | _H0 = H0.reshape(H0.shape[0], -1, H0.shape[-1]) 73 | _H1 = H1.reshape(H1.shape[0], -1, H1.shape[-1]) 74 | if fix_indices is not None: 75 | # Note: backpropagation is disabled.
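# (Aside: `_solve` above fits the transition by ordinary least squares via the
# normal equations, M* = (H0^T H0)^{-1} (H0^T H1), yielding one dim_a x dim_a
# matrix per sequence; when fix_indices is given, the listed latent columns are
# pinned to the identity and the regression runs only over the remaining
# "active" columns below. A hedged usage sketch, with hypothetical sizes:
#   H = torch.randn(8, 4, 16, 6)  # [n, t, dim_m, dim_a]
#   dyn_fn, losses = LinearTensorDynamicsLSTSQ()(H, return_loss=True)
#   H_next = dyn_fn(H)            # applies H @ M* at every time step
# )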
76 | dim_a = _H0.shape[-1] 77 | active_indices = np.array(list(set(np.arange(dim_a)) - set(fix_indices))) 78 | _M_star = _solve(_H0[:, :, active_indices], 79 | _H1[:, :, active_indices]) 80 | M_star = make_identity(_H1.shape[0], _H1.shape[-1], _H1.device) 81 | M_star[:, active_indices[:, np.newaxis], active_indices] = _M_star 82 | else: 83 | M_star = _solve(_H0, _H1) 84 | dyn_fn = self.DynFn(M_star) 85 | loss_internal_T = _loss(dyn_fn(H0), H1) 86 | ppe.reporting.report({ 87 | 'loss_internal_T': loss_internal_T.item() 88 | }) 89 | 90 | # M_star is returned in the form of a module, not the matrix 91 | if return_loss: 92 | losses = (loss_bd(dyn_fn.M, self.alignment), 93 | loss_orth(dyn_fn.M), loss_internal_T) 94 | return dyn_fn, losses 95 | else: 96 | return dyn_fn 97 | 98 | 99 | class HigherOrderLinearTensorDynamicsLSTSQ(LinearTensorDynamicsLSTSQ): 100 | 101 | class DynFn(nn.Module): 102 | def __init__(self, M): 103 | super().__init__() 104 | self.M = M 105 | 106 | def __call__(self, Hs): 107 | nHs = [None]*len(Hs) 108 | for l in range(len(Hs)-1, -1, -1): 109 | if l == len(Hs)-1: 110 | nHs[l] = Hs[l] @ _rep_M(self.M, Hs[l].shape[1]) 111 | else: 112 | nHs[l] = Hs[l] @ nHs[l+1] 113 | return nHs 114 | 115 | def __init__(self, alignment=True, n_order=2): 116 | super().__init__(alignment) 117 | self.n_order = n_order 118 | 119 | def __call__(self, H, return_loss=False, fix_indices=None): 120 | assert H.shape[1] > self.n_order 121 | # H0.shape = H1.shape [n, t, s, a] 122 | H0, Hn = H[:, :-self.n_order], H[:, self.n_order:] 123 | loss_internal_0 = _loss(H0, Hn) 124 | ppe.reporting.report({ 125 | 'loss_internal_0': loss_internal_0.item() 126 | }) 127 | Ms = [] 128 | _H = H 129 | if fix_indices is not None: 130 | raise NotImplementedError 131 | else: 132 | for n in range(self.n_order): 133 | # H0.shape = H1.shape [n, t, s, a] 134 | _H0, _H1 = _H[:, :-1], _H[:, 1:] 135 | if n == self.n_order - 1: 136 | _H0 = _H0.reshape(_H0.shape[0], -1, _H0.shape[-1]) 137 | _H1 = _H1.reshape(_H1.shape[0], -1, _H1.shape[-1]) 138 | _H = _solve(_H0, _H1) # [N, a, a] 139 | else: 140 | _H = _solve(_H0, _H1)[:, 1:] # [N, T-n, a, a] 141 | Ms.append(_H) 142 | dyn_fn = self.DynFn(Ms[-1]) 143 | loss_internal_T = _loss(dyn_fn([H0] + Ms[:-1])[0], Hn) 144 | ppe.reporting.report({ 145 | 'loss_internal_T': loss_internal_T.item() 146 | }) 147 | # M_star is returned in the form of a module, not the matrix 148 | if return_loss: 149 | losses = (loss_bd(dyn_fn.M, self.alignment), 150 | loss_orth(dyn_fn.M), loss_internal_T) 151 | return dyn_fn, Ms[:-1], losses 152 | else: 153 | return dyn_fn, Ms[:-1] 154 | 155 | # The fixed block model 156 | 157 | 158 | class MultiLinearTensorDynamicsLSTSQ(LinearTensorDynamicsLSTSQ): 159 | 160 | def __init__(self, dim_a, alignment=True, K=4): 161 | super().__init__(alignment=alignment) 162 | self.dim_a = dim_a 163 | self.alignment = alignment 164 | assert dim_a % K == 0 165 | self.K = K 166 | 167 | def __call__(self, H, return_loss=False, fix_indices=None): 168 | H0, H1 = H[:, :-1], H[:, 1:] 169 | # num_ts x ([len_ts -1] * dim_s) x dim_a 170 | 171 | # The difference between the time-shifted components 172 | loss_internal_0 = _loss(H0, H1) 173 | 174 | _H0 = H0.reshape(H.shape[0], -1, H.shape[3]) 175 | _H1 = H1.reshape(H.shape[0], -1, H.shape[3]) 176 | 177 | ppe.reporting.report({ 178 | 'loss_internal_0': loss_internal_0.item() 179 | }) 180 | M_stars = [] 181 | for k in range(self.K): 182 | if fix_indices is not None and k in fix_indices: 183 | M_stars.append(make_identity( 184 | H.shape[0],
self.dim_a//self.K, H.device)) 185 | else: 186 | st = k*(self.dim_a//self.K) 187 | ed = (k+1)*(self.dim_a//self.K) 188 | M_stars.append(_solve(_H0[:, :, st:ed], _H1[:, :, st:ed])) 189 | 190 | # Construct block diagonals 191 | for k in range(self.K): 192 | if k == 0: 193 | M_star = M_stars[0] 194 | else: 195 | M1 = M_star 196 | M2 = M_stars[k] 197 | _M1 = torch.cat( 198 | [M1, torch.zeros(H.shape[0], M2.shape[1], M1.shape[2]).to(H.device)], axis=1) 199 | _M2 = torch.cat( 200 | [torch.zeros(H.shape[0], M1.shape[1], M2.shape[2]).to(H.device), M2], axis=1) 201 | M_star = torch.cat([_M1, _M2], axis=2) 202 | dyn_fn = self.DynFn(M_star) 203 | loss_internal_T = _loss(dyn_fn(H0), H1) 204 | ppe.reporting.report({ 205 | 'loss_internal_T': loss_internal_T.item() 206 | }) 207 | 208 | # M_star is returned in the form of a module, not the matrix 209 | if return_loss: 210 | losses = (loss_bd(dyn_fn.M, self.alignment), 211 | loss_orth(dyn_fn.M), loss_internal_T) 212 | return dyn_fn, losses 213 | else: 214 | return dyn_fn 215 | -------------------------------------------------------------------------------- /models/resblock.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from utils.weight_standarization import WeightStandarization, WeightStandarization1d 8 | import torch.nn.utils.parametrize as P 9 | from utils.emb2d import Emb2D 10 | 11 | 12 | def upsample_conv(x, conv): 13 | # Upsample -> Conv 14 | x = nn.Upsample(scale_factor=2, mode='nearest')(x) 15 | x = conv(x) 16 | return x 17 | 18 | 19 | def conv_downsample(x, conv): 20 | # Conv -> Downsample 21 | x = conv(x) 22 | h = F.avg_pool2d(x, 2) 23 | return h 24 | 25 | 26 | class Block(nn.Module): 27 | def __init__(self, 28 | in_channels, 29 | out_channels, 30 | hidden_channels=None, 31 | kernel_size=3, 32 | padding=None, 33 | activation=F.relu, 34 | resample=None, 35 | group_norm=True, 36 | skip_connection=True, 37 | posemb=False): 38 | super(Block, self).__init__() 39 | if padding is None: 40 | padding = (kernel_size-1) // 2 41 | self.pe = Emb2D() if posemb else lambda x: x 42 | 43 | in_ch_conv = in_channels + self.pe.dim if posemb else in_channels 44 | self.skip_connection = skip_connection 45 | self.activation = activation 46 | self.resample = resample 47 | initializer = torch.nn.init.xavier_uniform_ 48 | if self.resample is None or self.resample == 'up': 49 | hidden_channels = out_channels if hidden_channels is None else hidden_channels 50 | else: 51 | hidden_channels = in_channels if hidden_channels is None else hidden_channels 52 | self.c1 = nn.Conv2d(in_ch_conv, hidden_channels, 53 | kernel_size=kernel_size, padding=padding) 54 | self.c2 = nn.Conv2d(hidden_channels, out_channels, 55 | kernel_size=kernel_size, padding=padding) 56 | initializer(self.c1.weight, math.sqrt(2)) 57 | initializer(self.c2.weight, math.sqrt(2)) 58 | P.register_parametrization( 59 | self.c1, 'weight', WeightStandarization()) 60 | P.register_parametrization( 61 | self.c2, 'weight', WeightStandarization()) 62 | 63 | if group_norm: 64 | self.b1 = nn.GroupNorm(min(32, in_channels), in_channels) 65 | self.b2 = nn.GroupNorm(min(32, hidden_channels), hidden_channels) 66 | else: 67 | self.b1 = self.b2 = lambda x: x 68 | if self.skip_connection: 69 | self.c_sc = nn.Conv2d(in_ch_conv, out_channels, 70 | kernel_size=1, padding=0) 71 | initializer(self.c_sc.weight) 72 | 73 | def residual(self, x): 74 | x = self.b1(x) 75 | x
= self.activation(x) 76 | if self.resample == 'up': 77 | x = nn.Upsample(scale_factor=2, mode='nearest')(x) 78 | x = self.pe(x) 79 | x = self.c1(x) 80 | x = self.b2(x) 81 | x = self.activation(x) 82 | x = self.c2(x) 83 | if self.resample == 'down': 84 | x = F.avg_pool2d(x, 2) 85 | return x 86 | 87 | def shortcut(self, x): 88 | # Upsample -> Conv 89 | if self.resample == 'up': 90 | x = nn.Upsample(scale_factor=2, mode='nearest')(x) 91 | x = self.pe(x) 92 | x = self.c_sc(x) 93 | 94 | elif self.resample == 'down': 95 | x = self.pe(x) 96 | x = self.c_sc(x) 97 | x = F.avg_pool2d(x, 2) 98 | else: 99 | x = self.pe(x) 100 | x = self.c_sc(x) 101 | return x 102 | 103 | def __call__(self, x): 104 | if self.skip_connection: 105 | return self.residual(x) + self.shortcut(x) 106 | else: 107 | return self.residual(x) 108 | 109 | 110 | class Conv1d1x1Block(nn.Module): 111 | def __init__(self, 112 | in_channels, 113 | out_channels, 114 | hidden_channels=None, 115 | act=F.relu): 116 | super().__init__() 117 | 118 | self.act = act 119 | initializer = torch.nn.init.xavier_uniform_ 120 | hidden_channels = out_channels if hidden_channels is None else hidden_channels 121 | self.c1 = nn.Conv1d(in_channels, hidden_channels, 1, 1, 0) 122 | self.c2 = nn.Conv1d(hidden_channels, out_channels, 1, 1, 0) 123 | initializer(self.c1.weight, math.sqrt(2)) 124 | initializer(self.c2.weight, math.sqrt(2)) 125 | P.register_parametrization( 126 | self.c1, 'weight', WeightStandarization1d()) 127 | P.register_parametrization( 128 | self.c2, 'weight', WeightStandarization1d()) 129 | self.norm1 = nn.LayerNorm((in_channels)) 130 | self.norm2 = nn.LayerNorm((hidden_channels)) 131 | self.c_sc = nn.Conv1d(in_channels, out_channels, 1, 1, 0) 132 | initializer(self.c_sc.weight) 133 | 134 | def residual(self, x): 135 | x = self.norm1(x.transpose(-2, -1)).transpose(-2, -1) 136 | x = self.act(x) 137 | x = self.c1(x) 138 | x = self.norm2(x.transpose(-2, -1)).transpose(-2, -1) 139 | x = self.act(x) 140 | x = self.c2(x) 141 | return x 142 | 143 | def shortcut(self, x): 144 | x = self.c_sc(x) 145 | return x 146 | 147 | def __call__(self, x): 148 | return self.residual(x) + self.shortcut(x) -------------------------------------------------------------------------------- /models/seqae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from models import dynamics_models 5 | import torch.nn.utils.parametrize as P 6 | from models.dynamics_models import LinearTensorDynamicsLSTSQ, MultiLinearTensorDynamicsLSTSQ, HigherOrderLinearTensorDynamicsLSTSQ 7 | from models.base_networks import ResNetEncoder, ResNetDecoder, Conv1d1x1Encoder 8 | from einops import rearrange, repeat 9 | from utils.clr import simclr 10 | 11 | 12 | class SeqAELSTSQ(nn.Module): 13 | def __init__( 14 | self, 15 | dim_a, 16 | dim_m, 17 | alignment=False, 18 | ch_x=3, 19 | k=1.0, 20 | kernel_size=3, 21 | change_of_basis=False, 22 | predictive=True, 23 | bottom_width=4, 24 | n_blocks=3, 25 | *args, 26 | **kwargs): 27 | super().__init__() 28 | self.dim_a = dim_a 29 | self.dim_m = dim_m 30 | self.predictive = predictive 31 | self.enc = ResNetEncoder( 32 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks) 33 | self.dec = ResNetDecoder( 34 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks) 35 | self.dynamics_model = LinearTensorDynamicsLSTSQ(alignment=alignment) 36 | if change_of_basis: 37 | self.change_of_basis = nn.Parameter( 38 | 
torch.empty(dim_a, dim_a)) 39 | nn.init.eye_(self.change_of_basis) 40 | 41 | def _encode_base(self, xs, enc): 42 | shape = xs.shape 43 | x = torch.reshape(xs, (shape[0] * shape[1], *shape[2:])) 44 | H = enc(x) 45 | H = torch.reshape( 46 | H, (shape[0], shape[1], *H.shape[1:])) 47 | return H 48 | 49 | def encode(self, xs): 50 | H = self._encode_base(xs, self.enc) 51 | H = torch.reshape( 52 | H, (H.shape[0], H.shape[1], self.dim_m, self.dim_a)) 53 | if hasattr(self, "change_of_basis"): 54 | H = H @ repeat(self.change_of_basis, 55 | 'a1 a2 -> n t a1 a2', n=H.shape[0], t=H.shape[1]) 56 | return H 57 | 58 | def phi(self, xs): 59 | return self._encode_base(xs, self.enc.phi) 60 | 61 | def get_M(self, xs): 62 | dyn_fn = self.dynamics_fn(xs) 63 | return dyn_fn.M 64 | 65 | def decode(self, H): 66 | if hasattr(self, "change_of_basis"): 67 | H = H @ repeat(torch.linalg.inv(self.change_of_basis), 68 | 'a1 a2 -> n t a1 a2', n=H.shape[0], t=H.shape[1]) 69 | n, t = H.shape[:2] 70 | if hasattr(self, "pidec"): 71 | H = rearrange(H, 'n t d_s d_a -> (n t) d_a d_s') 72 | H = self.pidec(H) 73 | else: 74 | H = rearrange(H, 'n t d_s d_a -> (n t) (d_s d_a)') 75 | x_next_preds = self.dec(H) 76 | x_next_preds = torch.reshape( 77 | x_next_preds, (n, t, *x_next_preds.shape[1:])) 78 | return x_next_preds 79 | 80 | def dynamics_fn(self, xs, return_loss=False, fix_indices=None): 81 | H = self.encode(xs) 82 | return self.dynamics_model(H, return_loss=return_loss, fix_indices=fix_indices) 83 | 84 | def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False): 85 | xs_cond = xs[:, :T_cond] 86 | xs_pred = self(xs_cond, return_reg_loss=return_reg_loss, 87 | n_rolls=xs.shape[1] - T_cond, predictive=self.predictive, reconst=reconst) 88 | if return_reg_loss: 89 | xs_pred, reg_losses = xs_pred 90 | if reconst: 91 | xs_target = xs 92 | else: 93 | xs_target = xs[:, T_cond:] if self.predictive else xs[:, 1:] 94 | loss = torch.mean( 95 | torch.sum((xs_target - xs_pred) ** 2, axis=[2, 3, 4])) 96 | return (loss, reg_losses) if return_reg_loss else loss 97 | 98 | def __call__(self, xs_cond, return_reg_loss=False, n_rolls=1, fix_indices=None, predictive=True, reconst=False): 99 | # Encoded Latent. Num_ts x len_ts x dim_m x dim_a 100 | H = self.encode(xs_cond) 101 | 102 | # ==Estimate dynamics== 103 | ret = self.dynamics_model( 104 | H, return_loss=return_reg_loss, fix_indices=fix_indices) 105 | if return_reg_loss: 106 | # fn is a map by M_star.
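(A hedged usage sketch of this forward pass, with hypothetical sizes; with the default bottom_width=4 and n_blocks=3 the decoder emits 32x32 frames:
  model = SeqAELSTSQ(dim_a=6, dim_m=16, ch_x=3)
  xs = torch.rand(8, 5, 3, 32, 32)  # [n, t, c, h, w]
  loss, (l_bd, l_orth, l_int) = model.loss(xs, T_cond=2)
Here M* is regressed on the first T_cond frames and the rollout is scored against the remaining ones.)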
Loss is the training external loss 107 | fn, losses = ret 108 | else: 109 | fn = ret 110 | 111 | if predictive: 112 | H_last = H[:, -1:] 113 | H_preds = [H] if reconst else [] 114 | array = np.arange(n_rolls) 115 | else: 116 | H_last = H[:, :1] 117 | H_preds = [H[:, :1]] if reconst else [] 118 | array = np.arange(xs_cond.shape[1] + n_rolls - 1) 119 | 120 | for _ in array: 121 | H_last = fn(H_last) 122 | H_preds.append(H_last) 123 | H_preds = torch.cat(H_preds, axis=1) 124 | # Prediction in the observation space 125 | x_preds = self.decode(H_preds) 126 | if return_reg_loss: 127 | return x_preds, losses 128 | else: 129 | return x_preds 130 | 131 | 132 | def loss_equiv(self, xs, T_cond=2, reduce=False): 133 | bsize = len(xs) 134 | xs_cond = xs[:, :T_cond] 135 | xs_target = xs[:, T_cond:] 136 | H = self.encode(xs_cond[:, -1:]) 137 | dyn_fn = self.dynamics_fn(xs_cond) 138 | 139 | H_last = H 140 | H_preds = [] 141 | n_rolls = xs.shape[1] - T_cond 142 | for _ in np.arange(n_rolls): 143 | H_last = dyn_fn(H_last) 144 | H_preds.append(H_last) 145 | H_pred = torch.cat(H_preds, axis=1) 146 | # swapping M 147 | dyn_fn.M = dyn_fn.M[torch.arange(-1, bsize-1)] 148 | 149 | H_last = H 150 | H_preds_perm = [] 151 | for _ in np.arange(n_rolls): 152 | H_last = dyn_fn(H_last) 153 | H_preds_perm.append(H_last) 154 | H_pred_perm = torch.cat(H_preds_perm, axis=1) 155 | 156 | xs_pred = self.decode(H_pred) 157 | xs_pred_perm = self.decode(H_pred_perm) 158 | reduce_dim = (1,2,3,4,5) if reduce else (2,3,4) 159 | loss = torch.sum((xs_target-xs_pred)**2, dim=reduce_dim).detach().cpu().numpy() 160 | loss_perm = torch.sum((xs_target-xs_pred_perm)**2, dim=reduce_dim).detach().cpu().numpy() 161 | return loss, loss_perm 162 | 163 | 164 | class SeqAEHOLSTSQ(SeqAELSTSQ): 165 | # Higher order version of SeqAELSTSQ 166 | def __init__( 167 | self, 168 | dim_a, 169 | dim_m, 170 | alignment=False, 171 | ch_x=3, 172 | k=1.0, 173 | kernel_size=3, 174 | change_of_basis=False, 175 | predictive=True, 176 | bottom_width=4, 177 | n_blocks=3, 178 | n_order=2, 179 | *args, 180 | **kwargs): 181 | super(SeqAELSTSQ, self).__init__() 182 | self.dim_a = dim_a 183 | self.dim_m = dim_m 184 | self.predictive = predictive 185 | self.enc = ResNetEncoder( 186 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks) 187 | self.dec = ResNetDecoder( 188 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks) 189 | self.dynamics_model = HigherOrderLinearTensorDynamicsLSTSQ( 190 | alignment=alignment, n_order=n_order) 191 | if change_of_basis: 192 | self.change_of_basis = nn.Parameter( 193 | torch.empty(dim_a, dim_a)) 194 | nn.init.eye_(self.change_of_basis) 195 | 196 | def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False): 197 | if reconst: 198 | raise NotImplementedError 199 | xs_cond = xs[:, :T_cond] 200 | xs_pred = self(xs_cond, predictive=self.predictive, return_reg_loss=return_reg_loss, 201 | n_rolls=xs.shape[1] - T_cond) 202 | if return_reg_loss: 203 | xs_pred, reg_losses = xs_pred 204 | xs_target = xs[:, T_cond:] if self.predictive else xs[:, 1:] 205 | loss = torch.mean( 206 | torch.sum((xs_target - xs_pred) ** 2, axis=[2, 3, 4])) 207 | return (loss, reg_losses) if return_reg_loss else loss 208 | 209 | 210 | def __call__(self, xs, n_rolls=1, fix_indices=None, predictive=True, return_reg_loss=False): 211 | # Encoded Latent. 
Num_ts x len_ts x dim_m x dim_a 212 | H = self.encode(xs) 213 | 214 | # ==Estimate dynamics== 215 | ret = self.dynamics_model( 216 | H, return_loss=return_reg_loss, fix_indices=fix_indices) 217 | if return_reg_loss: 218 | # fn is a map by M_star. Loss is the training external loss 219 | fn, Ms, losses = ret 220 | else: 221 | fn, Ms = ret 222 | 223 | if predictive: 224 | Hs_last = [H[:, -1:]] + [M[:, -1:] for M in Ms] 225 | array = np.arange(n_rolls) 226 | else: 227 | Hs_last = [H[:, :1]] + [M[:, :1] for M in Ms] 228 | array = np.arange(xs.shape[1] + n_rolls - 1) 229 | 230 | # Create prediction for the unseen future 231 | H_preds = [] 232 | for _ in array: 233 | Hs_last = fn(Hs_last) 234 | H_preds.append(Hs_last[0]) 235 | H_preds = torch.cat(H_preds, axis=1) 236 | x_preds = self.decode(H_preds) 237 | if return_reg_loss: 238 | return x_preds, losses 239 | else: 240 | return x_preds 241 | 242 | def loss_equiv(self, xs, T_cond=5, reduce=False, return_generated_images=False): 243 | bsize = len(xs) 244 | xs_cond = xs[:, :T_cond] 245 | xs_target = xs[:, T_cond:] 246 | H = self.encode(xs_cond[:, -1:]) 247 | dyn_fn, Ms = self.dynamics_fn(xs_cond) 248 | 249 | H_last = [H] + [M[:, -1:] for M in Ms] 250 | H_preds = [] 251 | n_rolls = xs.shape[1] - T_cond 252 | for _ in np.arange(n_rolls): 253 | H_last = dyn_fn(H_last) 254 | H_preds.append(H_last[0]) 255 | H_pred = torch.cat(H_preds, axis=1) 256 | # swapping M 257 | dyn_fn.M = dyn_fn.M[torch.arange(-1, bsize-1)] 258 | Ms = [M[torch.arange(-1, bsize-1)] for M in Ms] 259 | 260 | H_last = [H] + [M[:, -1:] for M in Ms] 261 | H_preds_perm = [] 262 | for _ in np.arange(n_rolls): 263 | H_last = dyn_fn(H_last) 264 | H_preds_perm.append(H_last[0]) 265 | H_pred_perm = torch.cat(H_preds_perm, axis=1) 266 | 267 | xs_pred = self.decode(H_pred) 268 | xs_pred_perm = self.decode(H_pred_perm) 269 | loss = torch.sum((xs_target-xs_pred)**2, dim=(2,3,4)).detach().cpu().numpy() 270 | loss_perm = torch.sum((xs_target-xs_pred_perm)**2, dim=(2,3,4)).detach().cpu().numpy() 271 | if reduce: 272 | # loss and loss_perm are numpy arrays at this point 273 | loss = np.mean(loss) 274 | loss_perm = np.mean(loss_perm) 275 | if return_generated_images: 276 | return (loss, loss_perm), (xs_pred, xs_pred_perm) 277 | else: 278 | return loss, loss_perm 279 | 280 | 281 | 282 | class SeqAEMultiLSTSQ(SeqAELSTSQ): 283 | def __init__( 284 | self, 285 | dim_a, 286 | dim_m, 287 | alignment=False, 288 | ch_x=3, 289 | k=1.0, 290 | kernel_size=3, 291 | change_of_basis=False, 292 | predictive=True, 293 | bottom_width=4, 294 | n_blocks=3, 295 | K=8, 296 | *args, 297 | **kwargs): 298 | super(SeqAELSTSQ, self).__init__() 299 | self.dim_a = dim_a 300 | self.dim_m = dim_m 301 | self.predictive = predictive 302 | self.K = K 303 | self.enc = ResNetEncoder( 304 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks) 305 | self.dec = ResNetDecoder( 306 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks) 307 | self.dynamics_model = MultiLinearTensorDynamicsLSTSQ( 308 | dim_a, alignment=alignment, K=K) 309 | if change_of_basis: 310 | self.change_of_basis = nn.Parameter( 311 | torch.empty(dim_a, dim_a)) 312 | nn.init.eye_(self.change_of_basis) 313 | 314 | def get_blocks_of_M(self, xs): 315 | M = self.get_M(xs) 316 | blocks = [] 317 | for k in range(self.K): 318 | dim_block = self.dim_a // self.K 319 | blocks.append(M[:, k*dim_block:(k+1)*dim_block] 320 | [:, :, k*dim_block:(k+1)*dim_block]) 321 | blocks_of_M = torch.stack(blocks, 1) 322 | return blocks_of_M 323 | 324 | 325 | class SeqAENeuralM(SeqAELSTSQ): 326 | def __init__(
327 | self, 328 | dim_a, 329 | dim_m, 330 | ch_x=3, 331 | k=1.0, 332 | alignment=False, 333 | kernel_size=3, 334 | predictive=True, 335 | bottom_width=4, 336 | n_blocks=3, 337 | *args, 338 | **kwargs): 339 | super(SeqAELSTSQ, self).__init__() 340 | self.dim_a = dim_a 341 | self.dim_m = dim_m 342 | self.predictive = predictive 343 | self.alignment = alignment 344 | self.initial_scale_M = 0.01 345 | self.enc = ResNetEncoder( 346 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks) 347 | self.M_net = ResNetEncoder( 348 | dim_a*dim_a, k=k, kernel_size=kernel_size, n_blocks=n_blocks) 349 | self.dec = ResNetDecoder( 350 | ch_x, k=k, kernel_size=kernel_size, n_blocks=n_blocks, bottom_width=bottom_width) 351 | 352 | def dynamics_fn(self, xs): 353 | M = self.get_M(xs) 354 | dyn_fn = dynamics_models.LinearTensorDynamicsLSTSQ.DynFn(M) 355 | return dyn_fn 356 | 357 | def get_M(self, xs): 358 | xs = rearrange(xs, 'n t c h w -> n (t c) h w') 359 | M = self.M_net(xs) 360 | M = rearrange(M, 'n (a_1 a_2) -> n a_1 a_2', a_1=self.dim_a) 361 | M = self.initial_scale_M * M 362 | return M 363 | 364 | def __call__(self, xs, n_rolls=1, return_reg_loss=False, predictive=True, reconst=False): 365 | # ==Estimate dynamics== 366 | fn = self.dynamics_fn(xs) 367 | 368 | if reconst: 369 | H = self.encode(xs) 370 | if predictive: 371 | H_last = H[:, -1:] 372 | else: 373 | H_last = H[:, :1] 374 | else: 375 | H_last = self.encode(xs[:, -1:] if predictive else xs[:, :1]) 376 | 377 | if predictive: 378 | H_preds = [H] if reconst else [] 379 | array = np.arange(n_rolls) 380 | else: 381 | H_preds = [H[:, :1]] if reconst else [] 382 | array = np.arange(xs.shape[1] + n_rolls - 1) 383 | 384 | # Create prediction for the unseen future 385 | for _ in array: 386 | H_last = fn(H_last) 387 | H_preds.append(H_last) 388 | H_preds = torch.cat(H_preds, axis=1) 389 | x_preds = self.decode(H_preds) 390 | if return_reg_loss: 391 | losses = (dynamics_models.loss_bd(fn.M, self.alignment), 392 | dynamics_models.loss_orth(fn.M), 0) 393 | return x_preds, losses 394 | else: 395 | return x_preds 396 | 397 | class SeqAENeuralMLatentPredict(SeqAENeuralM): 398 | def __init__(self, 399 | dim_a, 400 | dim_m, 401 | ch_x=3, 402 | k=1.0, 403 | alignment=False, 404 | kernel_size=3, 405 | predictive=True, 406 | bottom_width=4, 407 | n_blocks=3, 408 | loss_latent_coeff=0, 409 | loss_pred_coeff=1.0, 410 | loss_reconst_coeff=0, 411 | normalize=True, 412 | *args, 413 | **kwargs): 414 | assert predictive 415 | super().__init__( 416 | dim_a=dim_a, 417 | dim_m=dim_m, 418 | ch_x=ch_x, 419 | k=k, 420 | alignment=alignment, 421 | kernel_size=kernel_size, 422 | predictive=predictive, 423 | bottom_width=bottom_width, 424 | n_blocks=n_blocks, 425 | ) 426 | self.loss_reconst_coeff = loss_reconst_coeff 427 | self.loss_pred_coeff = loss_pred_coeff 428 | self.loss_latent_coeff = loss_latent_coeff 429 | self.normalize = normalize 430 | 431 | def normalize_isotypic_copy(self, H): 432 | isotype_norm = torch.sqrt(torch.sum(H**2, axis=2, keepdims=True)) 433 | H = H / isotype_norm 434 | return H 435 | 436 | # Encoding function with isotypic column normalization 437 | def encode(self, xs): 438 | H = super().encode(xs) 439 | if self.normalize: 440 | H = self.normalize_isotypic_copy(H) 441 | return H 442 | 443 | def latent_error(self, H_preds, H_target): 444 | latent_e = torch.mean(torch.sum((H_preds - H_target)**2, axis=(2,3))) 445 | return latent_e 446 | 447 | def obs_error(self, xs_1, xs_2): 448 | obs_e = torch.mean(torch.sum((xs_1 - xs_2)**2, axis=(2,3,4))) 449 | return
obs_e 450 | 451 | def __call__(self, xs, n_rolls=1, T_cond=2, return_losses=False, return_reg_losses=False): 452 | xs_cond, xs_target = xs[:, :T_cond], xs[:, T_cond:] 453 | fn = self.dynamics_fn(xs_cond) 454 | H_cond, H_target = self.encode(xs_cond), self.encode(xs_target) 455 | H_last = H_cond[:, -1:] 456 | H_preds = [H_cond] 457 | array = np.arange(n_rolls) 458 | 459 | for _ in array: 460 | H_last = fn(H_last) 461 | H_preds.append(H_last) 462 | H_preds = torch.cat(H_preds, axis=1) 463 | xs_preds = self.decode(H_preds) 464 | ret = [xs_preds] 465 | if return_losses: 466 | losses = {} 467 | losses['loss_reconst'] = self.obs_error(xs_preds[:, :T_cond], xs_cond) if self.loss_reconst_coeff > 0 else torch.tensor([0]).to(xs.device) 468 | losses['loss_pred'] = self.obs_error(xs_preds[:, T_cond:], xs_target) if self.loss_pred_coeff > 0 else torch.tensor([0]).to(xs.device) 469 | losses['loss_latent'] = self.latent_error(H_preds[:, T_cond:], H_target) if self.loss_latent_coeff > 0 else torch.tensor([0]).to(xs.device) 470 | ret += [losses] 471 | if return_reg_losses: 472 | ret += [(dynamics_models.loss_bd(fn.M, self.alignment), 473 | dynamics_models.loss_orth(fn.M), 0)] 474 | return ret 475 | 476 | def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False): 477 | ret = self(xs, return_losses=True, return_reg_losses=return_reg_loss, T_cond=T_cond, 478 | n_rolls=xs.shape[1] - T_cond) 479 | if return_reg_loss: 480 | _, losses, reg_losses = ret 481 | else: 482 | _, losses = ret 483 | 484 | total_loss = self.loss_reconst_coeff * losses['loss_reconst'] \ 485 | + self.loss_pred_coeff * losses['loss_pred'] \ 486 | + self.loss_latent_coeff * losses['loss_latent'] 487 | return (total_loss, reg_losses) if return_reg_loss else total_loss 488 | 489 | 490 | 491 | class SeqAENeuralTransition(SeqAELSTSQ): 492 | def __init__( 493 | self, 494 | dim_a, 495 | dim_m, 496 | ch_x=3, 497 | k=1.0, 498 | kernel_size=3, 499 | T_cond=2, 500 | bottom_width=4, 501 | n_blocks=3, 502 | *args, 503 | **kwargs): 504 | super(SeqAELSTSQ, self).__init__() 505 | self.dim_a = dim_a 506 | self.dim_m = dim_m 507 | self.T_cond = T_cond 508 | self.enc = ResNetEncoder( 509 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks) 510 | self.ar = Conv1d1x1Encoder(dim_out=dim_a) 511 | self.dec = ResNetDecoder( 512 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks) 513 | 514 | def loss(self, xs, return_reg_loss=False, T_cond=2, reconst=False): 515 | assert T_cond == self.T_cond 516 | xs_cond = xs[:, :T_cond] 517 | xs_pred = self(xs_cond, n_rolls=xs.shape[1] - T_cond, reconst=reconst) 518 | xs_target = xs if reconst else xs[:, T_cond:] 519 | loss = torch.mean( 520 | torch.sum((xs_target - xs_pred) ** 2, axis=[2, 3, 4])) 521 | if return_reg_loss: 522 | return loss, [torch.Tensor(np.array(0, dtype=np.float32)).to(xs.device)] * 3 523 | else: 524 | return loss 525 | 526 | def get_M(self, xs): 527 | T = xs.shape[1] 528 | xs = rearrange(xs, 'n t c h w -> (n t) c h w') 529 | H = self.enc(xs) 530 | H = rearrange(H, '(n t) c -> n (t c)', t=T) 531 | return H 532 | 533 | def __call__(self, xs, n_rolls=1, reconst=False): 534 | # ==Estimate dynamics== 535 | H = self.encode(xs) 536 | 537 | array = np.arange(n_rolls) 538 | H_preds = [H] if reconst else [] 539 | # Create prediction for the unseen future 540 | for _ in array: 541 | H_pred = self.ar(rearrange(H, 'n t s a -> n (t a) s')) 542 | H_pred = rearrange( 543 | H_pred, 'n (t a) s -> n t s a', t=1, a=self.dim_a) 544 | H_preds.append(H_pred) 545 | H = torch.cat([H[:,
        H_preds = torch.cat(H_preds, axis=1)
        # Prediction in the observation space
        return self.decode(H_preds)


class CPC(SeqAELSTSQ):
    def __init__(
            self,
            dim_a,
            dim_m,
            k=1.0,
            kernel_size=3,
            temp=0.01,
            normalize=True,
            loss_type='cossim',
            n_blocks=3,
            *args,
            **kwargs):
        super(SeqAELSTSQ, self).__init__()
        self.dim_a = dim_a
        self.dim_m = dim_m
        self.normalize = normalize
        self.temp = temp
        self.loss_type = loss_type
        self.enc = ResNetEncoder(
            dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
        self.ar = Conv1d1x1Encoder(dim_out=dim_a*dim_m)

    def __call__(self, xs):
        H = self.encode(xs)  # [n, t, s, a]

        # Create prediction for the unseen future
        H = rearrange(H, 'n t s a -> n (t a) s')

        # Obtain c in CPC
        H_pred = self.ar(H)  # [n a s]
        H_pred = rearrange(H_pred, 'n a s -> n s a')

        return H_pred

    def get_M(self, xs):
        T = xs.shape[1]
        xs = rearrange(xs, 'n t c h w -> (n t) c h w')
        H = self.enc(xs)
        H = rearrange(H, '(n t) c -> n (t c)', t=T)
        return H

    def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False):
        T_pred = xs.shape[1] - T_cond
        assert T_pred == 1
        # Encoded latent: num_ts x len_ts x dim_m x dim_a
        H = self.encode(xs)  # [n, t, s, a]

        # Create prediction for the unseen future
        H_cond = H[:, :T_cond]
        H_cond = rearrange(H_cond, 'n t s a -> n (t a) s')

        # Obtain c in CPC
        H_pred = self.ar(H_cond)  # [n a s]
        H_pred = rearrange(H_pred, 'n a s -> n s a')

        H_true = H[:, -1]  # n s a
        H_true = rearrange(H_true, 'n s a -> n (s a)')
        H_pred = rearrange(H_pred, 'n s a -> n (s a)')
        loss = simclr([H_pred, H_true], self.temp,
                      normalize=self.normalize, loss_type=self.loss_type)
        if return_reg_loss:
            reg_losses = [torch.Tensor(np.array(0, dtype=np.float32))]*3
            return loss, reg_losses
        else:
            return loss
--------------------------------------------------------------------------------
/models/simclr_models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from models.base_networks import ResNetEncoder
from einops import rearrange


class ResNetwProjHead(nn.Module):

    def __init__(self, dim_mlp=512, dim_head=128, k=1, act=nn.ReLU(), n_blocks=3):
        super().__init__()
        self.enc = ResNetEncoder(
            dim_latent=0, k=k, n_blocks=n_blocks)
        self.projhead = nn.Sequential(
            nn.LazyLinear(dim_mlp),
            act,
            nn.LazyLinear(dim_head))

    def _encode_base(self, xs, enc):
        shape = xs.shape
        x = torch.reshape(xs, (shape[0] * shape[1], *shape[2:]))
        H = enc(x)
        H = torch.reshape(H, (shape[0], shape[1], *H.shape[1:]))
        return H

    def __call__(self, xs):
        return self._encode_base(xs, lambda x: self.projhead(self.enc(x)))

    def phi(self, xs):
        return self._encode_base(xs, self.enc.phi)

    def get_M(self, xs):
        T = xs.shape[1]
        xs = rearrange(xs, 'n t c h w -> (n t) c h w')
        H = self.enc(xs)
        H = rearrange(H, '(n t) c -> n (t c)', t=T)
        return H
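
# NOTE: a hypothetical usage sketch (not in the original file), assuming
# 32x32 RGB inputs and utils.clr.simclr as the contrastive objective; the
# temperature value is illustrative:
#
#   import torch
#   from utils.clr import simclr
#
#   model = ResNetwProjHead(dim_mlp=512, dim_head=128)
#   views = torch.randn(8, 2, 3, 32, 32)   # [n, t, c, h, w]: two views per sample
#   zs = model(views)                      # [n, 2, dim_head]
#   loss = simclr([zs[:, 0], zs[:, 1]], temperature=0.5)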
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.10.0
pytorch-pfn-extras
scikit-image
opencv-python
moviepy
ffmpeg-python
einops
tqdm
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import os
import argparse
import yaml
import copy
import functools
import random
import numpy as np

import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions
from utils import yaml_utils as yu


def train(config):

    torch.cuda.empty_cache()
    torch.manual_seed(config['seed'])
    random.seed(config['seed'])
    np.random.seed(config['seed'])

    if torch.cuda.is_available():
        device = torch.device('cuda')
        cudnn.deterministic = True
        cudnn.benchmark = True
    else:
        device = torch.device('cpu')
        gpu_index = -1

    # Dataset
    data = yu.load_component(config['train_data'])
    train_loader = DataLoader(
        data, batch_size=config['batchsize'], shuffle=True, num_workers=config['num_workers'])

    # Definition of model and optimizer
    model = yu.load_component(config['model'])
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), config['lr'])

    manager = ppe.training.ExtensionsManager(
        model, optimizer, None,
        iters_per_epoch=len(train_loader),
        out_dir=config['log_dir'],
        stop_trigger=(config['max_iteration'], 'iteration')
    )

    manager.extend(
        extensions.PrintReport(
            ['epoch', 'iteration', 'train/loss', 'train/loss_bd', 'train/loss_orth',
             'loss_internal_0', 'loss_internal_T', 'elapsed_time']),
        trigger=(config['report_freq'], 'iteration'))
    manager.extend(extensions.LogReport(
        trigger=(config['report_freq'], 'iteration')))
    manager.extend(
        extensions.snapshot(
            target=model, filename='snapshot_model_iter_{.iteration}'),
        trigger=(config['model_snapshot_freq'], 'iteration'))
    manager.extend(
        extensions.snapshot(
            target=manager, filename='snapshot_manager_iter_{.iteration}', n_retains=1),
        trigger=(config['manager_snapshot_freq'], 'iteration'))
    # Run training loop
    print("Start training...")
    yu.load_component_fxn(config['training_loop'])(
        manager, model, optimizer, train_loader, config, device)


if __name__ == '__main__':
    # Load the configuration arguments from the specified config path
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir', type=str)
    parser.add_argument('--config_path', type=str)
    parser.add_argument('-a', '--attrs', nargs='*', default=())
    parser.add_argument('-w', '--warning', action='store_true')
    args = parser.parse_args()

    with open(args.config_path, 'r') as f:
        config = yaml.safe_load(f)
    config['config_path'] = args.config_path
    config['log_dir'] = args.log_dir

    # Modify the loaded yaml config using the dotted-key attrs
    for attr in args.attrs:
        module, new_value = attr.split('=')
        keys = module.split('.')
        target = functools.reduce(dict.__getitem__, keys[:-1], config)
        if keys[-1] in target.keys():
            target[keys[-1]] = yaml.safe_load(new_value)
        else:
            raise ValueError(
                'The following key is not defined in the config file: {}'.format(keys))

    for k, v in sorted(config.items()):
        print("\t{} {}".format(k, v))

    # Create the result directory and save the yaml config
    if not os.path.exists(config['log_dir']):
        os.makedirs(config['log_dir'])

    _config = copy.deepcopy(config)
    configpath = os.path.join(config['log_dir'], "config.yml")
    with open(configpath, 'w') as f:
        f.write(yaml.dump(_config, default_flow_style=False))

    # Training
    train(config)
--------------------------------------------------------------------------------
/training_allmodels.sh:
--------------------------------------------------------------------------------
#!/bin/bash
LOGDIR_ROOT=$1
DATADIR_ROOT=$2

# Training on MNIST, MNIST-bg w/ digit 4, 3DShapes and smallNORB
for seed in 1 2 3; do
  for dataset_name in mnist mnist_bg 3dshapes smallNORB; do
    for model_name in lstsq lstsq_multi lstsq_rec neuralM neural_trans; do
      python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
        --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
        --attr seed=${seed} train_data.args.root=${DATADIR_ROOT}
    done
  done
done

# Training on MNIST-bg w/ all digits
for seed in 1 2 3; do
  for dataset_name in mnist_bg; do
    for model_name in lstsq lstsq_multi lstsq_rec; do
      python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
        --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
        --attr seed=${seed} train_data.args.root=${DATADIR_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000
    done
    for model_name in neuralM neural_trans; do
      python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
        --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
        --attr seed=${seed} train_data.args.root=${DATADIR_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 training_loop.args.reconst_iter=200000
    done
  done
done

# Training on Accelerated Sequential MNIST
for seed in 1 2 3; do
  for dataset_name in mnist_accl; do
    for model_name in lstsq holstsq neural_trans; do
      python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
        --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
        --attr seed=${seed} train_data.args.root=${DATADIR_ROOT}
    done
  done
done
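
# NOTE on the --attr flags above (argparse accepts --attr as an unambiguous
# prefix of run.py's --attrs option). run.py splits each "key.subkey=value" on
# '=', walks the dotted path through the loaded YAML dict with functools.reduce,
# and replaces the leaf with yaml.safe_load(value). A worked example with
# illustrative values:
#
#   python run.py --log_dir=/tmp/demo \
#     --config_path=./configs/mnist/lstsq/lstsq.yml \
#     --attrs seed=2 lr=0.0005 train_data.args.root=/data/mnist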
--------------------------------------------------------------------------------
/training_loops.py:
--------------------------------------------------------------------------------
import math
import torch
from torch import nn
import pytorch_pfn_extras as ppe
from utils.clr import simclr
from utils.misc import freq_to_wave
from tqdm import tqdm


def loop_seqmodel(manager, model, optimizer, train_loader, config, device):
    while not manager.stop_trigger:
        for images in train_loader:
            with manager.run_iteration():
                reconst = manager.iteration < config['training_loop']['args']['reconst_iter']
                if manager.iteration >= config['training_loop']['args']['lr_decay_iter']:
                    optimizer.param_groups[0]['lr'] = config['lr']/3.
                else:
                    optimizer.param_groups[0]['lr'] = config['lr']
                model.train()
                images = torch.stack(images).transpose(1, 0).to(device)
                loss, (loss_bd, loss_orth, _) = model.loss(
                    images, T_cond=config['T_cond'], return_reg_loss=True, reconst=reconst)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                ppe.reporting.report({
                    'train/loss': loss.item(),
                    'train/loss_bd': loss_bd.item(),
                    'train/loss_orth': loss_orth.item(),
                })

        if manager.stop_trigger:
            break


def loop_simclr(manager, model, optimizer, train_loader, config, device):
    while not manager.stop_trigger:
        for images in train_loader:
            with manager.run_iteration():
                if manager.iteration >= config['training_loop']['args']['lr_decay_iter']:
                    optimizer.param_groups[0]['lr'] = config['lr']/3.
                else:
                    optimizer.param_groups[0]['lr'] = config['lr']
                model.train()
                images = torch.stack(images, dim=1).to(device)  # n t c h w
                zs = model(images)
                zs = [zs[:, i] for i in range(zs.shape[1])]
                loss = simclr(
                    zs,
                    loss_type=config['training_loop']['args']['loss_type'],
                    temperature=config['training_loop']['args']['temp']
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                ppe.reporting.report({
                    'train/loss': loss.item(),
                })
        if manager.stop_trigger:
            break
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/clr.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_pfn_extras as ppe


def simclr(zs, temperature=1.0, normalize=True, loss_type='cossim'):
    if normalize:
        zs = [F.normalize(z, p=2, dim=1) for z in zs]
    m = len(zs)
    n = zs[0].shape[0]
    device = zs[0].device
    mask = torch.eye(n * m, device=device)
    label0 = torch.fmod(n + torch.arange(0, m * n, device=device), n * m)
    z = torch.cat(zs, 0)
    if loss_type == 'euclid':
        sim = - torch.cdist(z, z)
    elif loss_type == 'sq':
        sim = - torch.cdist(z, z) ** 2
    elif loss_type == 'cossim':
        sim = torch.matmul(z, z.transpose(0, 1))
    else:
        raise NotImplementedError
    logit_zz = sim / temperature
    logit_zz += mask * -1e8
    loss = nn.CrossEntropyLoss()(logit_zz, label0)
    return loss
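
# NOTE: a small illustration (not part of the original file). With m=2 views of
# n samples, the label0 construction above matches row i of the concatenated
# batch to row (i + n) mod 2n, i.e. to the other view of the same sample, while
# the eye mask removes each row's similarity to itself:
#
#   import torch
#   z1, z2 = torch.randn(4, 16), torch.randn(4, 16)
#   loss = simclr([z1, z2], temperature=0.1, loss_type='cossim')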
--------------------------------------------------------------------------------
/utils/emb2d.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import torch
import torch.nn as nn


class Emb2D(nn.modules.lazy.LazyModuleMixin, nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.dim = dim
        self.emb = torch.nn.parameter.UninitializedParameter()

    def __call__(self, x):
        if torch.nn.parameter.is_lazy(self.emb):
            _, h, w = x.shape[1:]
            self.emb.materialize((self.dim, h, w))
            self.emb.data = positionalencoding2d(self.dim, h, w)
        emb = torch.tile(self.emb[None].to(x.device), [x.shape[0], 1, 1, 1])
        x = torch.cat([x, emb], axis=1)
        return x


# Copied from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
def positionalencoding2d(d_model, height, width):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "a dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each spatial axis uses half of d_model
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0., d_model, 2) *
                         -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(
        pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(
        pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(
        pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1::2, :, :] = torch.cos(
        pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe
--------------------------------------------------------------------------------
/utils/laplacian.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from einops import repeat


def make_identity(N, D, device):
    if N is None:
        return torch.Tensor(np.array(np.eye(D))).to(device)
    else:
        return torch.Tensor(np.array([np.eye(D)] * N)).to(device)


def make_identity_like(A):
    assert A.shape[-2] == A.shape[-1]  # Ensure A is a batch of square matrices
    device = A.device
    shape = A.shape[:-2]
    eye = torch.eye(A.shape[-1], device=device)[(None,)*len(shape)]
    return eye.repeat(*shape, 1, 1)


def make_diagonal(vecs):
    vecs = vecs[..., None].repeat(*([1, ]*len(vecs.shape)), vecs.shape[-1])
    return vecs * make_identity_like(vecs)


# Calculate the trace norm of the normalized Laplacian
def tracenorm_of_normalized_laplacian(A):
    D_vec = torch.sum(A, axis=-1)
    D = make_diagonal(D_vec)
    L = D - A
    inv_A_diag = make_diagonal(
        1 / torch.sqrt(1e-10 + D_vec))
    L = torch.matmul(inv_A_diag, torch.matmul(L, inv_A_diag))
    sigmas = torch.linalg.svdvals(L)
    return torch.sum(sigmas, axis=-1)
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
import math
import torch
from torch import nn
from einops import repeat
import numpy as np


def freq_to_wave(freq, is_radian=True):
    _freq_rad = 2 * math.pi * freq if not is_radian else freq
    return torch.hstack([torch.cos(_freq_rad), torch.sin(_freq_rad)])


def unsqueeze_at_the_end(x, n):
    return x[(...,) + (None,)*n]


def get_RTmat(theta, phi, gamma, w, h, dx, dy):
    d = np.sqrt(h ** 2 + w ** 2)
    f = d / (2 * np.sin(gamma) if np.sin(gamma) != 0 else 1)
    # Projection 2D -> 3D matrix
    A1 = np.array([[1, 0, -w / 2],
                   [0, 1, -h / 2],
                   [0, 0, 1],
                   [0, 0, 1]])

    # Rotation matrices around the X, Y, and Z axes
    RX = np.array([[1, 0, 0, 0],
                   [0, np.cos(theta), -np.sin(theta), 0],
                   [0, np.sin(theta), np.cos(theta), 0],
                   [0, 0, 0, 1]])

    RY = np.array([[np.cos(phi), 0, -np.sin(phi), 0],
                   [0, 1, 0, 0],
                   [np.sin(phi), 0, np.cos(phi), 0],
                   [0, 0, 0, 1]])

    RZ = np.array([[np.cos(gamma), -np.sin(gamma), 0, 0],
                   [np.sin(gamma), np.cos(gamma), 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]])

    # Composed rotation matrix with (RX, RY, RZ)
    R = np.dot(np.dot(RX, RY), RZ)

    # Translation matrix
    T = np.array([[1, 0, 0, dx],
                  [0, 1, 0, dy],
                  [0, 0, 1, f],
                  [0, 0, 0, 1]])
    # Projection 3D -> 2D matrix
    A2 = np.array([[f, 0, w / 2, 0],
                   [0, f, h / 2, 0],
                   [0, 0, 1, 0]])
    return np.dot(A2, np.dot(T, np.dot(R, A1)))
--------------------------------------------------------------------------------
/utils/optimize_bd_cob.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from einops import repeat
from utils.laplacian import tracenorm_of_normalized_laplacian, make_identity_like


def optimize_bd_cob(mats, batchsize=32, n_epochs=50, epochs_monitor=10):
    # Optimize a change-of-basis matrix U by minimizing the block-diagonalization loss

    class ChangeOfBasis(torch.nn.Module):
        def __init__(self, d):
            super().__init__()
            self.U = nn.Parameter(torch.empty(d, d))
            torch.nn.init.orthogonal_(self.U)

        def __call__(self, mat):
            _U = repeat(self.U, "a1 a2 -> n a1 a2", n=mat.shape[0])
            n_mat = torch.linalg.solve(_U, mat) @ _U
            return n_mat

    change_of_basis = ChangeOfBasis(mats.shape[-1]).to(mats.device)
    dataloader = torch.utils.data.DataLoader(
        mats, batch_size=batchsize, shuffle=True, num_workers=0)
    optimizer = torch.optim.Adam(change_of_basis.parameters(), lr=0.1)
    for ep in range(n_epochs):
        total_loss, total_N = 0, 0
        for mat in dataloader:
            n_mat = change_of_basis(mat)
            n_mat = torch.abs(n_mat)
            n_mat = torch.matmul(n_mat.transpose(-2, -1), n_mat)
            loss = torch.mean(
                tracenorm_of_normalized_laplacian(n_mat))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * mat.shape[0]
            total_N += mat.shape[0]
        if ((ep+1) % epochs_monitor) == 0:
            print('ep:{} loss:{}'.format(ep, total_loss/total_N))
    return change_of_basis
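
# NOTE: a usage sketch under assumed shapes (not in the original file): given a
# batch of transition matrices M of shape [N, a, a], fit one shared basis U and
# conjugate every M into (approximately) block-diagonal form.
#
#   import torch
#   mats = torch.randn(256, 8, 8)   # stand-in for matrices collected via get_M
#   cob = optimize_bd_cob(mats, batchsize=32, n_epochs=50)
#   with torch.no_grad():
#       mats_bd = cob(mats)         # computes U^{-1} M U for each M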
--------------------------------------------------------------------------------
/utils/weight_standarization.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class WeightStandarization(nn.Module):
    def forward(self, weight):
        weight_mean = weight.mean(dim=1, keepdim=True).mean(
            dim=2, keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return weight


class WeightStandarization1d(nn.Module):
    def forward(self, weight):
        weight_mean = weight.mean(dim=1, keepdim=True).mean(
            dim=2, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return weight


class WeightStandarization0d(nn.Module):
    def forward(self, weight):
        weight_mean = weight.mean(dim=1, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return weight
--------------------------------------------------------------------------------
/utils/yaml_utils.py:
--------------------------------------------------------------------------------
import os
import sys
import functools
import argparse
import yaml
import pdb
sys.path.append('../')
sys.path.append('./')


# Originally created by @msaito
def load_module(fn, name):
    mod_name = os.path.splitext(os.path.basename(fn))[0]
    mod_path = os.path.dirname(fn)
    sys.path.insert(0, mod_path)
    return getattr(__import__(mod_name), name)


def load_component(config):
    class_fn = load_module(config['fn'], config['name'])
    return class_fn(**config['args']) if 'args' in config.keys() else class_fn()


def load_component_fxn(config):
    fxn = load_module(config['fn'], config['name'])
    return fxn


def make_function(module, name):
    fxn = getattr(module, name)
    return fxn


def make_instance(module, config=[], args=None):
    Class = getattr(module, config['name'])
    kwargs = config['args']
    if args is not None:
        kwargs.update(args)
    return Class(**kwargs)


'''
Combines multiple configs.
'''


def make_config(conf_dicts, attr_lists=None):
    def merge_dictionary(base, diff):
        for key, value in diff.items():
            if (key in base and isinstance(base[key], dict)
                    and isinstance(diff[key], dict)):
                merge_dictionary(base[key], diff[key])
            else:
                base[key] = diff[key]

    config = {}
    for diff in conf_dicts:
        merge_dictionary(config, diff)
    if attr_lists is not None:
        for attr in attr_lists:
            module, new_value = attr.split('=')
            keys = module.split('.')
            target = functools.reduce(dict.__getitem__, keys[:-1], config)
            target[keys[-1]] = yaml.safe_load(new_value)
    return config


'''
Argument parser that uses make_config.
'''


def parse_args():

    parser = argparse.ArgumentParser()
    parser.add_argument(
        'infiles', nargs='+', type=argparse.FileType('r'), default=())
    parser.add_argument('-a', '--attrs', nargs='*', default=())
    parser.add_argument('-c', '--comment', default='')
    parser.add_argument('-w', '--warning', action='store_true')
    parser.add_argument('-o', '--output-config', default='')
    args = parser.parse_args()

    conf_dicts = [yaml.safe_load(fp) for fp in args.infiles]
    config = make_config(conf_dicts, args.attrs)
    return config, args
--------------------------------------------------------------------------------
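A hypothetical end-to-end sketch of how run.py consumes these helpers: a config entry names a source file, a class inside it, and its keyword arguments, and load_component instantiates it. The values below are illustrative, not taken from the repository's configs.

    from utils import yaml_utils as yu

    model_config = {
        'fn': 'models/seqae.py',
        'name': 'SeqAELSTSQ',
        'args': {'dim_a': 16, 'dim_m': 16, 'ch_x': 1, 'k': 1.0},
    }
    model = yu.load_component(model_config)  # equivalent to SeqAELSTSQ(dim_a=16, ...)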