├── LICENSE
├── README.md
├── block_diagonalization.ipynb
├── configs
│   ├── 3dshapes
│   │   └── lstsq
│   │       ├── lstsq.yml
│   │       ├── lstsq_multi.yml
│   │       ├── lstsq_rec.yml
│   │       ├── neuralM.yml
│   │       └── neural_trans.yml
│   ├── mnist
│   │   ├── lstsq
│   │   │   ├── lstsq.yml
│   │   │   ├── lstsq_multi.yml
│   │   │   ├── lstsq_rec.yml
│   │   │   ├── neuralM.yml
│   │   │   ├── neuralM_latentpred.yml
│   │   │   └── neural_trans.yml
│   │   └── simclr
│   │       ├── cpc.yml
│   │       └── simclr.yml
│   ├── mnist_accl
│   │   └── lstsq
│   │       ├── holstsq.yml
│   │       ├── lstsq.yml
│   │       └── neural_trans.yml
│   ├── mnist_bg
│   │   ├── lstsq
│   │   │   ├── lstsq.yml
│   │   │   ├── lstsq_multi.yml
│   │   │   ├── lstsq_rec.yml
│   │   │   ├── neuralM.yml
│   │   │   ├── neuralM_latentpred.yml
│   │   │   └── neural_trans.yml
│   │   └── simclr
│   │       ├── cpc.yml
│   │       └── simclr.yml
│   └── smallNORB
│       └── lstsq
│           ├── lstsq.yml
│           ├── lstsq_multi.yml
│           ├── lstsq_rec.yml
│           ├── neuralM.yml
│           └── neural_trans.yml
├── datasets
│   ├── __init__.py
│   ├── seq_mnist.py
│   ├── small_norb.py
│   └── three_dim_shapes.py
├── equivariance_error.ipynb
├── extrp.ipynb
├── figs
│   ├── comp_error_3dshapes.pdf
│   ├── comp_error_mnist.pdf
│   ├── comp_error_mnist_accl.pdf
│   ├── comp_error_mnist_bg.pdf
│   ├── comp_error_mnist_bg_full.pdf
│   ├── comp_error_smallNORB.pdf
│   ├── disentangle_3dshapes.pdf
│   ├── disentangle_mnist.pdf
│   ├── disentangle_mnist_bg.pdf
│   ├── disentangle_mnist_bg_full.pdf
│   ├── disentangle_smallNORB.pdf
│   ├── equiv_error_3dshapes.pdf
│   ├── equiv_error_mnist.pdf
│   ├── equiv_error_mnist_accl.pdf
│   ├── equiv_error_mnist_bg.pdf
│   ├── equiv_error_mnist_bg_full.pdf
│   ├── equiv_error_smallNORB.pdf
│   ├── sbd_3dshapes.pdf
│   ├── sbd_3dshapes_component0.pdf
│   ├── sbd_3dshapes_component1.pdf
│   ├── sbd_3dshapes_component2.pdf
│   ├── sbd_3dshapes_component3.pdf
│   ├── sbd_3dshapes_component5.pdf
│   ├── sbd_3dshapes_components.pdf
│   ├── sbd_mnist.pdf
│   ├── sbd_mnist_bg.pdf
│   ├── sbd_mnist_bg_components.pdf
│   ├── sbd_mnist_bg_full.pdf
│   ├── sbd_mnist_bg_full_component0.pdf
│   ├── sbd_mnist_bg_full_component1.pdf
│   ├── sbd_mnist_bg_full_component2.pdf
│   ├── sbd_mnist_bg_full_components.pdf
│   ├── sbd_mnist_component0.pdf
│   ├── sbd_mnist_component1.pdf
│   ├── sbd_mnist_component2.pdf
│   ├── sbd_smallNORB.pdf
│   ├── sbd_smallNORB_component0.pdf
│   ├── sbd_smallNORB_component1.pdf
│   ├── sbd_smallNORB_component2.pdf
│   ├── sbd_smallNORB_component3.pdf
│   ├── sbd_smallNORB_component4.pdf
│   ├── sbd_smallNORB_component5.pdf
│   └── sbd_smallNORB_components.pdf
├── gen_images.ipynb
├── gen_images
│   ├── 3dshapes_lstsq.png
│   ├── 3dshapes_lstsq_multi.png
│   ├── 3dshapes_lstsq_rec.png
│   ├── 3dshapes_neuralM.png
│   ├── 3dshapes_neural_trans.png
│   ├── mnist_accl_holstsq_equiv.png
│   ├── mnist_bg_full_lstsq.png
│   ├── mnist_bg_full_lstsq_multi.png
│   ├── mnist_bg_full_lstsq_rec.png
│   ├── mnist_bg_full_neuralM.png
│   ├── mnist_bg_full_neural_trans.png
│   ├── mnist_bg_lstsq.png
│   ├── mnist_bg_lstsq_multi.png
│   ├── mnist_bg_lstsq_rec.png
│   ├── mnist_bg_neuralM.png
│   ├── mnist_bg_neural_trans.png
│   ├── mnist_lstsq.png
│   ├── mnist_lstsq_multi.png
│   ├── mnist_lstsq_rec.png
│   ├── mnist_neuralM.png
│   ├── mnist_neural_trans.png
│   ├── smallNORB_lstsq.png
│   ├── smallNORB_lstsq_multi.png
│   ├── smallNORB_lstsq_rec.png
│   ├── smallNORB_neuralM.png
│   └── smallNORB_neural_trans.png
├── jobs
│   ├── 09022022
│   │   ├── training_allmodels_1.sh
│   │   └── training_allmodels_2.sh
│   ├── 09152022
│   │   ├── training_allmodels_1.sh
│   │   ├── training_allmodels_2.sh
│   │   ├── training_lstsqs.sh
│   │   └── training_neurals.sh
│   └── 09232022
│       ├── training_neuralMlatentpred_mnist.sh
│       └── training_neuralMlatentpred_mnist_bg.sh
├── models
│   ├── __init__.py
│   ├── base_networks.py
│   ├── dynamics_models.py
│   ├── resblock.py
│   ├── seqae.py
│   └── simclr_models.py
├── requirements.txt
├── run.py
├── training_allmodels.sh
├── training_loops.py
└── utils
    ├── __init__.py
    ├── clr.py
    ├── emb2d.py
    ├── laplacian.py
    ├── misc.py
    ├── optimize_bd_cob.py
    ├── weight_standarization.py
    └── yaml_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Anonymous
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-sequential prediction (MSP)
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | This repository contains the code for the NeurIPS 2022 paper: Unsupervised Learning of Equivariant Structure from Sequences.
10 | A simple encoder-decoder model trained with *meta-sequential prediction* captures the hidden disentangled structure underlying the datasets.
11 |
12 | [[arXiv]](https://arxiv.org/abs/2210.05972) [[OpenReview]](https://openreview.net/forum?id=7b7iGkuVqlZ)
13 |
14 | ## Implementation of MSP and simultaneous block diagonalization (SBD)
15 | - `SeqAELSTSQ` in `./models/seqae.py` is the implementation of *meta-sequential prediction*.
16 | - `tracenorm_of_normalized_laplacian` in `./utils/laplacian.py` is used to calculate the block diagonalization loss in our paper.
17 |
18 | # Setup
19 | ## Prerequisites
20 | Python 3.7, CUDA 11.2, cuDNN
21 |
22 | Install the additional Python libraries with:
23 | ```
24 | pip install -r requirements.txt
25 | ```
26 |
27 | ## Download datasets
28 | Download the compressed dataset files from the following link:
29 |
30 | https://drive.google.com/drive/folders/1_uXjx06U48to9OSyGY1ezqipAbbuT0vq?usp=sharing
31 |
32 | The following is an example script to download and decompress the files:
33 | ```
34 | # If gdown is not installed:
35 | pip install gdown
36 |
37 | export DATADIR_ROOT=/tmp/path/to/datadir/
38 | cd /tmp
39 | gdown --folder https://drive.google.com/drive/folders/1_uXjx06U48to9OSyGY1ezqipAbbuT0vq?usp=sharing
40 | mv datasets/* ${DATADIR_ROOT}; rm datasets -r
41 | cd -
42 | tar xzf ${DATADIR_ROOT}/MNIST.tar.gz -C $DATADIR_ROOT
43 | tar xzf ${DATADIR_ROOT}/3dshapes.tar.gz -C $DATADIR_ROOT
44 | tar xzf ${DATADIR_ROOT}/smallNORB.tar.gz -C $DATADIR_ROOT
45 | ```
46 |
47 | ## Training with MSP
48 | 1. Select the dataset on which you want to train the model:
49 | ```
50 | # Sequential MNIST
51 | export CONFIG=configs/mnist/lstsq/lstsq.yml
52 | # Sequential MNIST-bg with digit 4
53 | export CONFIG=configs/mnist_bg/lstsq/lstsq.yml
54 | # 3DShapes
55 | export CONFIG=configs/3dshapes/lstsq/lstsq.yml
56 | # SmallNORB
57 | export CONFIG=configs/smallNORB/lstsq/lstsq.yml
58 | # Accelerated sequential MNIST
59 | export CONFIG=configs/mnist_accl/lstsq/holstsq.yml
60 | ```
61 |
62 | 2. Run:
63 | ```
64 | export LOGDIR=/tmp/path/to/logdir
65 | export DATADIR_ROOT=/tmp/path/to/datadir
66 | python run.py --config_path=$CONFIG --log_dir=$LOGDIR --attr train_data.args.root=$DATADIR_ROOT
67 | ```
68 |
69 | ## Training all the methods we tested in our NeurIPS paper
70 | ```
71 | export LOGDIR=/tmp/path/to/logdir
72 | export DATADIR_ROOT=/tmp/path/to/datadir
73 | bash training_allmodels.sh $LOGDIR $DATADIR_ROOT
74 | ```
75 |
76 | ## Evaluations
77 | - Generated images: `gen_images.ipynb`
78 | - Equivariance errors: `equivariance_error.ipynb`
79 | - Prediction errors: `extrp.ipynb`
80 | - Simultaneous block diagonalization: `block_diagonalization.ipynb`
81 |
--------------------------------------------------------------------------------
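
A note on the SBD loss named in the README above: the following is a minimal sketch of the trace norm of a symmetrically normalized graph Laplacian, the quantity that `tracenorm_of_normalized_laplacian` in `./utils/laplacian.py` computes. This is an illustrative reimplementation, not the repository code; the affinity construction and normalization details there may differ.

```
import torch

def tracenorm_of_normalized_laplacian_sketch(A, eps=1e-8):
    # A: (n, n) nonnegative symmetric affinity matrix, e.g. built from a
    # learned transition matrix M as |M| @ |M|.T. A small trace norm of the
    # normalized Laplacian indicates A is close to block diagonal.
    deg = A.sum(dim=-1)                                   # degrees D_ii
    d_inv_sqrt = torch.rsqrt(deg + eps)                   # D^{-1/2}
    lap = torch.diag(deg) - A                             # unnormalized Laplacian D - A
    norm_lap = d_inv_sqrt[:, None] * lap * d_inv_sqrt[None, :]
    return torch.linalg.svdvals(norm_lap).sum()           # sum of singular values

M = torch.randn(16, 16)
A = torch.abs(M) @ torch.abs(M).T                         # nonnegative symmetric affinity
loss = tracenorm_of_normalized_laplacian_sketch(A)
```
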
/configs/3dshapes/lstsq/lstsq.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/three_dim_shapes.py
14 | name: ThreeDimShapesDataset
15 | args:
16 | root: /tmp/datasets/3dshapes/
17 | train: True
18 | T: 3
19 |
20 | model:
21 | fn: ./models/seqae.py
22 | name: SeqAELSTSQ
23 | args:
24 | dim_m: 256
25 | dim_a: 16
26 | ch_x: 3
27 | k: 1.0
28 | bottom_width: 8
29 | predictive: True
30 |
31 | training_loop:
32 | fn: ./training_loops.py
33 | name: loop_seqmodel
34 | args:
35 | lr_decay_iter: 40000
36 | reconst_iter: 0
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
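
Every config in this repository follows the same `fn`/`name`/`args` convention: `fn` points at a Python file, `name` selects a class defined in it, and `args` holds the constructor keyword arguments (overridable from `run.py` via `--attr`). Below is a hypothetical loader sketching that pattern, assuming it is run from the repository root; the repository's actual resolution logic presumably lives in `./utils/yaml_utils.py` and may differ.

```
import importlib.util
import yaml

def load_component(spec, **overrides):
    # spec: e.g. {'fn': './models/seqae.py', 'name': 'SeqAELSTSQ', 'args': {...}}
    # Hypothetical helper, not the repository's loader.
    module_spec = importlib.util.spec_from_file_location('component', spec['fn'])
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    cls = getattr(module, spec['name'])
    return cls(**{**spec.get('args', {}), **overrides})

with open('configs/3dshapes/lstsq/lstsq.yml') as f:
    config = yaml.safe_load(f)
model = load_component(config['model'])
train_data = load_component(config['train_data'])
```
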
/configs/3dshapes/lstsq/lstsq_multi.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 | optimizer: Adam
12 |
13 | train_data:
14 | fn: ./datasets/three_dim_shapes.py
15 | name: ThreeDimShapesDataset
16 | args:
17 | root: /tmp/datasets/3dshapes/
18 | train: True
19 | T: 3
20 |
21 | model:
22 |   fn: ./models/seqae.py
23 | name: SeqAEMultiLSTSQ
24 | args:
25 | dim_m: 256
26 | dim_a: 16
27 | ch_x: 3
28 | k: 1.0
29 | bottom_width: 8
30 | alignment: False
31 | change_of_basis: False
32 | global_average_pooling: False # Important. (Why?)
33 | discretized: False
34 | K: 8
35 |
36 |
37 | training_loop:
38 | fn: ./training_loops.py
39 | name: loop_seqmodel
40 | args:
41 |
42 | lr_decay_iter: 40000
43 | reconst_iter: 0
44 |
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/configs/3dshapes/lstsq/lstsq_rec.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 3
9 | T_pred: 0
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/three_dim_shapes.py
14 | name: ThreeDimShapesDataset
15 | args:
16 | root: /tmp/datasets/3dshapes/
17 | train: True
18 | T: 3
19 |
20 | model:
21 | fn: ./models/seqae.py
22 | name: SeqAELSTSQ
23 | args:
24 | dim_m: 256
25 | dim_a: 16
26 | ch_x: 3
27 | k: 1.0
28 | bottom_width: 8
29 | predictive: False
30 |
31 | training_loop:
32 | fn: ./training_loops.py
33 | name: loop_seqmodel
34 | args:
35 | lr_decay_iter: 40000
36 | reconst_iter: 0
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/configs/3dshapes/lstsq/neuralM.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/three_dim_shapes.py
14 | name: ThreeDimShapesDataset
15 | args:
16 | root: /tmp/datasets/3dshapes/
17 | train: True
18 | T: 3
19 |
20 | model:
21 | fn: ./models/seqae.py
22 | name: SeqAENeuralM
23 | args:
24 | dim_m: 256
25 | dim_a: 16
26 | ch_x: 3
27 | k: 1.0
28 | bottom_width: 8
29 |
30 |
31 | training_loop:
32 | fn: ./training_loops.py
33 | name: loop_seqmodel
34 | args:
35 | lr_decay_iter: 40000
36 | reconst_iter: 50000
37 |
38 |
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/3dshapes/lstsq/neural_trans.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/three_dim_shapes.py
14 | name: ThreeDimShapesDataset
15 | args:
16 | root: /tmp/datasets/3dshapes/
17 | train: True
18 | T: 3
19 |
20 | model:
21 | fn: ./models/seqae.py
22 | name: SeqAENeuralTransition
23 | args:
24 | dim_m: 256
25 | dim_a: 16
26 | ch_x: 3
27 | k: 1.0
28 | bottom_width: 8
29 |
30 |
31 | training_loop:
32 | fn: ./training_loops.py
33 | name: loop_seqmodel
34 | args:
35 | lr_decay_iter: 40000
36 | reconst_iter: 50000
37 |
38 |
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/mnist/lstsq/lstsq.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: False
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAELSTSQ
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 2.
33 | predictive: True
34 |
35 | training_loop:
36 | fn: ./training_loops.py
37 | name: loop_seqmodel
38 | args:
39 | lr_decay_iter: 40000
40 | reconst_iter: 0
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
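
In these configs, `T_cond` is the number of conditioning frames used to estimate the latent transition and `T_pred` the number of future frames to predict. The following is a minimal sketch of the least-squares transition step behind `SeqAELSTSQ`, under simplifying assumptions (flat latent vectors, no encoder/decoder); in the actual model the latents are matrix-shaped (`dim_m` × `dim_a`), so even `T_cond: 2` provides `dim_m` equations per latent dimension.

```
import torch

def estimate_M(zs):
    # zs: (T_cond, d) latent sequence; solve min_M ||zs[:-1] @ M - zs[1:]||^2.
    return torch.linalg.lstsq(zs[:-1], zs[1:]).solution   # (d, d)

def rollout(z_last, M, T_pred):
    # Apply the estimated transition repeatedly to predict future latents.
    preds, z = [], z_last
    for _ in range(T_pred):
        z = z @ M
        preds.append(z)
    return torch.stack(preds)

zs = torch.randn(2, 16)                 # T_cond = 2 conditioning latents
M = estimate_M(zs)
future = rollout(zs[-1], M, T_pred=1)   # (1, 16) predicted latent
```
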
/configs/mnist/lstsq/lstsq_multi.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: False
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAEMultiLSTSQ
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 2.
33 | predictive: True
34 | K: 8
35 |
36 |
37 |
38 | training_loop:
39 | fn: ./training_loops.py
40 | name: loop_seqmodel
41 | args:
42 | lr_decay_iter: 40000
43 | reconst_iter: 0
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/configs/mnist/lstsq/lstsq_rec.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 3
9 | T_pred: 0
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: False
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAELSTSQ
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 2.
33 | predictive: False
34 |
35 | training_loop:
36 | fn: ./training_loops.py
37 | name: loop_seqmodel
38 | args:
39 | lr_decay_iter: 40000
40 | reconst_iter: 0
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/configs/mnist/lstsq/neuralM.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 |
24 | model:
25 | fn: ./models/seqae.py
26 | name: SeqAENeuralM
27 | args:
28 | dim_m: 256
29 | dim_a: 16
30 | ch_x: 3
31 | k: 2.
32 |
33 |
34 |
35 | training_loop:
36 | fn: ./training_loops.py
37 | name: loop_seqmodel
38 | args:
39 | lr_decay_iter: 40000
40 | reconst_iter: 50000
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/configs/mnist/lstsq/neuralM_latentpred.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 |
24 | model:
25 | fn: ./models/seqae.py
26 | name: SeqAENeuralMLatentPredict
27 | args:
28 | dim_m: 256
29 | dim_a: 16
30 | ch_x: 3
31 | k: 2.
32 | loss_reconst_coeff: 0.0
33 | loss_pred_coeff: 1.0
34 | loss_latent_coeff: 0.0
35 | normalize: True
36 |
37 |
38 |
39 |
40 | training_loop:
41 | fn: ./training_loops.py
42 | name: loop_seqmodel
43 | args:
44 | lr_decay_iter: 40000
45 | reconst_iter: 50000
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/configs/mnist/lstsq/neural_trans.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 |
24 | model:
25 | fn: ./models/seqae.py
26 | name: SeqAENeuralTransition
27 | args:
28 | dim_m: 256
29 | dim_a: 16
30 | ch_x: 3
31 | k: 2.
32 |
33 | training_loop:
34 | fn: ./training_loops.py
35 | name: loop_seqmodel
36 | args:
37 | lr_decay_iter: 40000
38 | reconst_iter: 50000
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/mnist/simclr/cpc.yml:
--------------------------------------------------------------------------------
1 | batchsize: 64
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 10000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: False
24 |
25 | model:
26 |   fn: ./models/seqae.py
27 | name: CPC
28 | args:
29 | dim_m: 1
30 | dim_a: 512
31 | k: 2.
32 | loss_type: cossim
33 | normalize: True
34 | temp: 0.01
35 |
36 | training_loop:
37 | fn: ./training_loops.py
38 | name: loop_seqmodel
39 | args:
40 | lr_decay_iter: 40000
41 | reconst_iter: 0
42 |
--------------------------------------------------------------------------------
/configs/mnist/simclr/simclr.yml:
--------------------------------------------------------------------------------
1 | batchsize: 64
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 10000
7 | num_workers: 6
8 | lr: 0.0003
9 |
10 | train_data:
11 | fn: ./datasets/seq_mnist.py
12 | name: SequentialMNIST
13 | args:
14 | root: /tmp/datasets/MNIST
15 | train: True
16 | T: 3
17 | max_T: 9
18 | max_angle_velocity_ratio: [-0.5, 0.5]
19 | max_color_velocity_ratio: [-0.5, 0.5]
20 | only_use_digit4: True
21 | backgrnd: False
22 |
23 | model:
24 | fn: ./models/simclr_models.py
25 | name: ResNetwProjHead
26 | args:
27 | dim_mlp: 512
28 | dim_head: 512
29 | k: 2
30 |
31 |
32 | training_loop:
33 | fn: ./training_loops.py
34 | name: loop_simclr
35 | args:
36 | lr_decay_iter: 40000
37 | temp: 0.01
38 | loss_type: cossim
--------------------------------------------------------------------------------
/configs/mnist_accl/lstsq/holstsq.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 5
9 | T_pred: 5
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 10
19 | max_T: 10
20 | max_angle_velocity_ratio: [-0.2, 0.2]
21 | max_angle_accl_ratio: [-0.025, 0.025]
22 | max_color_velocity_ratio: [-0.2, 0.2]
23 | max_color_accl_ratio: [-0.025, 0.025]
24 | max_pos: [0., 0.]
25 | max_trans_accl: [0., 0.]
26 | only_use_digit4: True
27 |
28 | model:
29 |   fn: ./models/seqae.py
30 | name: SeqAEHOLSTSQ
31 | args:
32 | dim_m: 256
33 | dim_a: 16
34 | ch_x: 3
35 | k: 2.
36 | n_order: 2
37 | predictive: True
38 | kernel_size: 3
39 |
40 | training_loop:
41 | fn: ./training_loops.py
42 | name: loop_seqmodel
43 | args:
44 | lr_decay_iter: 80000
45 | reconst_iter: 0
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/configs/mnist_accl/lstsq/lstsq.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 5
9 | T_pred: 5
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 10
19 | max_T: 10
20 | max_angle_velocity_ratio: [-0.2, 0.2]
21 | max_angle_accl_ratio: [-0.025, 0.025]
22 | max_color_velocity_ratio: [-0.2, 0.2]
23 | max_color_accl_ratio: [-0.025, 0.025]
24 | max_pos: [0., 0.]
25 | max_trans_accl: [0., 0.]
26 | only_use_digit4: True
27 |
28 | model:
29 |   fn: ./models/seqae.py
30 | name: SeqAELSTSQ
31 | args:
32 | dim_m: 256
33 | dim_a: 16
34 | ch_x: 3
35 | k: 2.
36 | predictive: True
37 | kernel_size: 3
38 |
39 | training_loop:
40 | fn: ./training_loops.py
41 | name: loop_seqmodel
42 | args:
43 | lr_decay_iter: 80000
44 | reconst_iter: 0
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/configs/mnist_accl/lstsq/neural_trans.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 5
9 | T_pred: 5
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 10
19 | max_T: 10
20 | max_angle_velocity_ratio: [-0.2, 0.2]
21 | max_angle_accl_ratio: [-0.025, 0.025]
22 | max_color_velocity_ratio: [-0.2, 0.2]
23 | max_color_accl_ratio: [-0.025, 0.025]
24 | max_pos: [0., 0.]
25 | max_trans_accl: [0., 0.]
26 | only_use_digit4: True
27 |
28 | model:
29 | fn: ./models/seqae.py
30 | name: SeqAENeuralTransition
31 | args:
32 | dim_m: 256
33 | dim_a: 16
34 | ch_x: 3
35 | k: 2.
36 | T_cond: 5
37 | predictive: True
38 | kernel_size: 3
39 |
40 |
41 | training_loop:
42 | fn: ./training_loops.py
43 | name: loop_seqmodel
44 | args:
45 | lr_decay_iter: 80000
46 | reconst_iter: 100000
--------------------------------------------------------------------------------
/configs/mnist_bg/lstsq/lstsq.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: True
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAELSTSQ
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 4.
33 | predictive: True
34 |
35 | training_loop:
36 | fn: ./training_loops.py
37 | name: loop_seqmodel
38 | args:
39 | lr_decay_iter: 80000
40 | reconst_iter: 0
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/configs/mnist_bg/lstsq/lstsq_multi.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: True
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAEMultiLSTSQ
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 4.
33 | predictive: True
34 | K: 8
35 |
36 |
37 |
38 | training_loop:
39 | fn: ./training_loops.py
40 | name: loop_seqmodel
41 | args:
42 | lr_decay_iter: 80000
43 | reconst_iter: 0
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/configs/mnist_bg/lstsq/lstsq_rec.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 3
9 | T_pred: 0
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: True
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAELSTSQ
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 4.
33 | predictive: False
34 |
35 | training_loop:
36 | fn: ./training_loops.py
37 | name: loop_seqmodel
38 | args:
39 | lr_decay_iter: 80000
40 | reconst_iter: 0
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/configs/mnist_bg/lstsq/neuralM.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: True
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAENeuralM
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 4.
33 |
34 |
35 |
36 | training_loop:
37 | fn: ./training_loops.py
38 | name: loop_seqmodel
39 | args:
40 | lr_decay_iter: 80000
41 | reconst_iter: 100000
42 |
43 |
44 |
45 |
--------------------------------------------------------------------------------
/configs/mnist_bg/lstsq/neuralM_latentpred.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: True
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAENeuralMLatentPredict
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 4.
33 | loss_reconst_coeff: 0.0
34 | loss_pred_coeff: 1.0
35 | loss_latent_coeff: 0.0
36 | normalize: True
37 |
38 |
39 | training_loop:
40 | fn: ./training_loops.py
41 | name: loop_seqmodel
42 | args:
43 | lr_decay_iter: 80000
44 | reconst_iter: 100000
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/configs/mnist_bg/lstsq/neural_trans.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: True
24 |
25 | model:
26 | fn: ./models/seqae.py
27 | name: SeqAENeuralTransition
28 | args:
29 | dim_m: 256
30 | dim_a: 16
31 | ch_x: 3
32 | k: 4.
33 |
34 | training_loop:
35 | fn: ./training_loops.py
36 | name: loop_seqmodel
37 | args:
38 | lr_decay_iter: 80000
39 | reconst_iter: 100000
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/mnist_bg/simclr/cpc.yml:
--------------------------------------------------------------------------------
1 | batchsize: 64
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 10000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/seq_mnist.py
14 | name: SequentialMNIST
15 | args:
16 | root: /tmp/datasets/MNIST
17 | train: True
18 | T: 3
19 | max_T: 9
20 | max_angle_velocity_ratio: [-0.5, 0.5]
21 | max_color_velocity_ratio: [-0.5, 0.5]
22 | only_use_digit4: True
23 | backgrnd: False
24 |
25 | model:
26 |   fn: ./models/seqae.py
27 | name: CPC
28 | args:
29 | dim_m: 1
30 | dim_a: 512
31 | k: 4.
32 | loss_type: cossim
33 | normalize: True
34 | temp: 0.01
35 |
36 | training_loop:
37 | fn: ./training_loops.py
38 | name: loop_seqmodel
39 | args:
40 | lr_decay_iter: 40000
41 | reconst_iter: 0
42 |
--------------------------------------------------------------------------------
/configs/mnist_bg/simclr/simclr.yml:
--------------------------------------------------------------------------------
1 | batchsize: 64
2 | seed: 1
3 | max_iteration: 100000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 10000
7 | num_workers: 6
8 | lr: 0.0003
9 |
10 | train_data:
11 | fn: ./datasets/seq_mnist.py
12 | name: SequentialMNIST
13 | args:
14 | root: /tmp/datasets/MNIST
15 | train: True
16 | T: 3
17 | max_T: 9
18 | max_angle_velocity_ratio: [-0.5, 0.5]
19 | max_color_velocity_ratio: [-0.5, 0.5]
20 | only_use_digit4: True
21 | backgrnd: False
22 |
23 | model:
24 | fn: ./models/simclr_models.py
25 | name: ResNetwProjHead
26 | args:
27 | dim_mlp: 512
28 | dim_head: 512
29 | k: 4
30 |
31 |
32 | training_loop:
33 | fn: ./training_loops.py
34 | name: loop_simclr
35 | args:
36 | lr_decay_iter: 80000
37 | temp: 0.01
38 | loss_type: cossim
--------------------------------------------------------------------------------
/configs/smallNORB/lstsq/lstsq.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/small_norb.py
14 | name: SmallNORBDataset
15 | args:
16 | root: /tmp/datasets
17 | train: True
18 | T: 3
19 |
20 | model:
21 | fn: ./models/seqae.py
22 | name: SeqAELSTSQ
23 | args:
24 | dim_m: 256
25 | dim_a: 16
26 | ch_x: 1
27 | k: 1.0
28 | bottom_width: 6
29 | n_blocks: 4
30 | predictive: True
31 |
32 |
33 | training_loop:
34 | fn: ./training_loops.py
35 | name: loop_seqmodel
36 | args:
37 | lr_decay_iter: 40000
38 | reconst_iter: 0
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/smallNORB/lstsq/lstsq_multi.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 2
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 |
13 | train_data:
14 | fn: ./datasets/small_norb.py
15 | name: SmallNORBDataset
16 | args:
17 | root: /tmp/datasets
18 | train: True
19 | T: 3
20 |
21 | model:
22 |   fn: ./models/seqae.py
23 | name: SeqAEMultiLSTSQ
24 | args:
25 | dim_m: 256
26 | dim_a: 16
27 | ch_x: 1
28 | k: 1.0
29 | bottom_width: 6
30 | n_blocks: 4
31 | K: 8
32 |
33 | training_loop:
34 | fn: ./training_loops.py
35 | name: loop_seqmodel
36 | args:
37 | lr_decay_iter: 40000
38 | reconst_iter: 0
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/smallNORB/lstsq/lstsq_rec.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 3
9 | T_pred: 0
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/small_norb.py
14 | name: SmallNORBDataset
15 | args:
16 | root: /tmp/datasets
17 | train: True
18 | T: 3
19 |
20 | model:
21 | fn: ./models/seqae.py
22 | name: SeqAELSTSQ
23 | args:
24 | dim_m: 256
25 | dim_a: 16
26 | ch_x: 1
27 | k: 1.0
28 | bottom_width: 6
29 | n_blocks: 4
30 | predictive: False
31 |
32 |
33 | training_loop:
34 | fn: ./training_loops.py
35 | name: loop_seqmodel
36 | args:
37 | lr_decay_iter: 40000
38 | reconst_iter: 0
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/smallNORB/lstsq/neuralM.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 | train_data:
13 | fn: ./datasets/small_norb.py
14 | name: SmallNORBDataset
15 | args:
16 | root: /tmp/datasets
17 | train: True
18 | T: 3
19 |
20 | model:
21 | fn: ./models/seqae.py
22 | name: SeqAENeuralM
23 | args:
24 | dim_m: 256
25 | dim_a: 16
26 | ch_x: 1
27 | k: 1.0
28 | n_blocks: 4
29 | bottom_width: 6
30 |
31 | training_loop:
32 | fn: ./training_loops.py
33 | name: loop_seqmodel
34 | args:
35 | lr_decay_iter: 40000
36 | reconst_iter: 50000
37 |
38 |
39 |
40 |
41 |
42 |
--------------------------------------------------------------------------------
/configs/smallNORB/lstsq/neural_trans.yml:
--------------------------------------------------------------------------------
1 | batchsize: 32
2 | seed: 1
3 | max_iteration: 50000
4 | report_freq: 1000
5 | model_snapshot_freq: 10000
6 | manager_snapshot_freq: 50000
7 | num_workers: 6
8 | T_cond: 2
9 | T_pred: 1
10 | lr: 0.0003
11 |
12 |
13 | train_data:
14 | fn: ./datasets/small_norb.py
15 | name: SmallNORBDataset
16 | args:
17 | root: /tmp/datasets
18 | train: True
19 | T: 3
20 |
21 | model:
22 | fn: ./models/seqae.py
23 | name: SeqAENeuralTransition
24 | args:
25 | dim_m: 256
26 | dim_a: 16
27 | ch_x: 1
28 | k: 1.0
29 | n_blocks: 4
30 | bottom_width: 6
31 |
32 | training_loop:
33 | fn: ./training_loops.py
34 | name: loop_seqmodel
35 | args:
36 | lr_decay_iter: 40000
37 | reconst_iter: 50000
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/seq_mnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import torch
5 | import torchvision
6 | import math
7 | import colorsys
8 | from skimage.transform import resize
9 | from copy import deepcopy
10 | from utils.misc import get_RTmat
11 | from utils.misc import freq_to_wave
12 |
13 |
14 | class SequentialMNIST():
15 | # Rotate around z axis only.
16 |
17 | default_active_actions = [0, 1, 2]
18 |
19 | def __init__(
20 | self,
21 | root,
22 | train=True,
23 | transforms=torchvision.transforms.ToTensor(),
24 | T=3,
25 | max_angle_velocity_ratio=[-0.5, 0.5],
26 | max_angle_accl_ratio=[-0.0, 0.0],
27 | max_color_velocity_ratio=[-0.5, 0.5],
28 | max_color_accl_ratio=[-0.0, 0.0],
29 | max_pos=[-10, 10],
30 | max_trans_accl=[-0.0, 0.0],
31 | label=False,
32 | label_velo=False,
33 | label_accl=False,
34 | active_actions=None,
35 | max_T=9,
36 | only_use_digit4=False,
37 | backgrnd=False,
38 | shared_transition=False,
39 | color_off=False,
40 | rng=None
41 | ):
42 | self.T = T
43 | self.max_T = max_T
44 | self.rng = rng if rng is not None else np.random
45 | self.transforms = transforms
46 | self.data = torchvision.datasets.MNIST(root, train, download=True)
47 | self.angle_velocity_range = (-max_angle_velocity_ratio, max_angle_velocity_ratio) if isinstance(
48 | max_angle_velocity_ratio, (int, float)) else max_angle_velocity_ratio
49 | self.color_velocity_range = (-max_color_velocity_ratio, max_color_velocity_ratio) if isinstance(
50 | max_color_velocity_ratio, (int, float)) else max_color_velocity_ratio
51 | self.angle_accl_range = (-max_angle_accl_ratio, max_angle_accl_ratio) if isinstance(
52 | max_angle_accl_ratio, (int, float)) else max_angle_accl_ratio
53 | self.color_accl_range = (-max_color_accl_ratio, max_color_accl_ratio) if isinstance(
54 | max_color_accl_ratio, (int, float)) else max_color_accl_ratio
55 | self.color_off = color_off
56 |
57 | self.max_pos = max_pos
58 | self.max_trans_accl = max_trans_accl
59 | self.label = label
60 | self.label_velo = label_velo
61 | self.label_accl = label_accl
62 | self.active_actions = self.default_active_actions if active_actions is None else active_actions
63 | if backgrnd:
64 | print("""
65 | =============
66 | background ON
67 | =============
68 | """)
69 | fname = "MNIST/train_dat.pt" if train else "MNIST/test_dat.pt"
70 | self.backgrnd_data = torch.load(os.path.join(root, fname))
71 |
72 | if only_use_digit4:
73 | datas = []
74 | for pair in self.data:
75 | if pair[1] == 4:
76 | datas.append(pair)
77 | self.data = datas
78 | self.shared_transition = shared_transition
79 | if self.shared_transition:
80 | self.init_shared_transition_parameters()
81 |
82 | def init_shared_transition_parameters(self):
83 | self.angles_v = self.rng.uniform(math.pi * self.angle_velocity_range[0],
84 | math.pi * self.angle_velocity_range[1], size=1)
85 | self.angles_a = self.rng.uniform(math.pi * self.angle_accl_range[0],
86 | math.pi * self.angle_accl_range[1], size=1)
87 | self.color_v = 0.5 * self.rng.uniform(self.color_velocity_range[0],
88 | self.color_velocity_range[1], size=1)
89 | self.color_a = 0.5 * \
90 | self.rng.uniform(
91 | self.color_accl_range[0], self.color_accl_range[1], size=1)
92 | pos0 = self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2])
93 | pos1 = self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2])
94 | self.pos_v = (pos1-pos0)/(self.max_T - 1)
95 | self.pos_a = self.rng.uniform(
96 | self.max_trans_accl[0], self.max_trans_accl[1], size=[2])
97 |
98 | def __len__(self):
99 | return len(self.data)
100 |
101 | def __getitem__(self, i):
102 | image = np.array(self.data[i][0], np.float32).reshape(28, 28)
103 | image = resize(image, [24, 24])
104 | image = cv2.copyMakeBorder(
105 | image, 4, 4, 4, 4, cv2.BORDER_CONSTANT, value=(0, 0, 0))
106 | angles_0 = self.rng.uniform(0, 2 * math.pi, size=1)
107 | color_0 = self.rng.uniform(0, 1, size=1)
108 | pos0 = self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2])
109 | pos1 = self.rng.uniform(self.max_pos[0], self.max_pos[1], size=[2])
110 | if self.shared_transition:
111 | (angles_v, angles_a) = (self.angles_v, self.angles_a)
112 | (color_v, color_a) = (self.color_v, self.color_a)
113 | (pos_v, pos_a) = (self.pos_v, self.pos_a)
114 | else:
115 | angles_v = self.rng.uniform(math.pi * self.angle_velocity_range[0],
116 | math.pi * self.angle_velocity_range[1], size=1)
117 | angles_a = self.rng.uniform(math.pi * self.angle_accl_range[0],
118 | math.pi * self.angle_accl_range[1], size=1)
119 | color_v = 0.5 * self.rng.uniform(self.color_velocity_range[0],
120 | self.color_velocity_range[1], size=1)
121 | color_a = 0.5 * \
122 | self.rng.uniform(
123 | self.color_accl_range[0], self.color_accl_range[1], size=1)
124 | pos_v = (pos1-pos0)/(self.max_T - 1)
125 | pos_a = self.rng.uniform(
126 | self.max_trans_accl[0], self.max_trans_accl[1], size=[2])
127 | images = []
128 | for t in range(self.T):
129 | angles_t = (0.5 * angles_a * t**2 + angles_v * t +
130 | angles_0) if 0 in self.active_actions else angles_0
131 | color_t = ((0.5 * color_a * t**2 + t * color_v + color_0) %
132 | 1) if 1 in self.active_actions else color_0
133 | pos_t = (0.5 * pos_a * t**2 + pos_v * t +
134 | pos0) if 2 in self.active_actions else pos0
135 | mat = get_RTmat(0, 0, float(angles_t), 32, 32, pos_t[0], pos_t[1])
136 | _image = cv2.warpPerspective(image.copy(), mat, (32, 32))
137 |
138 | rgb = np.asarray(colorsys.hsv_to_rgb(
139 | color_t, 1, 1), dtype=np.float32)
140 | _image = np.concatenate(
141 | [_image[:, :, None]] * 3, axis=-1) * rgb[None, None]
142 | _image = _image / 255.
143 |
144 | if hasattr(self, 'backgrnd_data'):
145 | _imagemask = (np.sum(_image, axis=2, keepdims=True) < 3e-1)
146 | _image = torch.tensor(
147 | _image) + self.backgrnd_data[i].permute([1, 2, 0]) * (_imagemask)
148 | _image = np.array(torch.clip(_image, max=1.))
149 |
150 | images.append(self.transforms(_image.astype(np.float32)))
151 |
152 |         if self.label or self.label_velo or self.label_accl:
153 | ret = [images]
154 | if self.label:
155 | ret += [self.data[i][1]]
156 | if self.label_velo:
157 | ret += [
158 | freq_to_wave(angles_v.astype(np.float32)),
159 | freq_to_wave((2 * math.pi * color_v).astype(np.float32)),
160 | pos_v.astype(np.float32)
161 | ]
162 | if self.label_accl:
163 | ret += [
164 | freq_to_wave(angles_a.astype(np.float32)),
165 | freq_to_wave((2 * math.pi * color_a).astype(np.float32)),
166 | pos_a.astype(np.float32)
167 | ]
168 | return ret
169 | else:
170 | return images
171 |
--------------------------------------------------------------------------------
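
Each active factor in `SequentialMNIST` (rotation angle, hue, position) evolves with constant-acceleration kinematics, `x_t = 0.5 * a * t**2 + v * t + x_0`, with per-sequence velocity `v` and acceleration `a` drawn from the configured ranges. A usage sketch, assuming this repository is on `PYTHONPATH` and that torchvision may download MNIST under the given root:

```
from torch.utils.data import DataLoader
from datasets.seq_mnist import SequentialMNIST

ds = SequentialMNIST(root='/tmp/datasets/MNIST', train=True, T=3,
                     max_angle_velocity_ratio=[-0.5, 0.5],
                     max_color_velocity_ratio=[-0.5, 0.5],
                     only_use_digit4=True)
loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=2)
images = next(iter(loader))            # list of T batched frames
print(len(images), images[0].shape)    # 3 torch.Size([32, 3, 32, 32])
```
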
/datasets/small_norb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision
4 | from collections import OrderedDict
5 | import os
6 |
7 |
8 |
9 | _FACTORS_IN_ORDER = ['category', 'instance',
10 | 'lighting', 'elevation', 'azimuth']
11 | _ELEV_V = [30, 35, 40, 45, 50, 55, 60, 65, 70]
12 | _AZIM_V = np.arange(0, 350, 20)
13 | assert len(_AZIM_V) == 18
14 | _NUM_VALUES_PER_FACTOR = OrderedDict(
15 | {'category': 5, 'instance': 5, 'lighting': 6, 'elevation': 9, 'azimuth': 18})
16 |
17 |
18 | def get_index(factors):
19 | """ Converts factors to indices in range(num_data)
20 | Args:
21 |         factors: np array shape [5, batch_size].
22 | factors[i]=factors[i,:] takes integer values in
23 | range(_NUM_VALUES_PER_FACTOR[_FACTORS_IN_ORDER[i]]).
24 |
25 | Returns:
26 | indices: np array shape [batch_size].
27 | """
28 | indices = 0
29 | base = 1
30 | for factor, name in reversed(list(enumerate(_FACTORS_IN_ORDER))):
31 | indices += factors[factor] * base
32 | base *= _NUM_VALUES_PER_FACTOR[name]
33 | return indices
34 |
35 |
36 | class SmallNORBDataset(object):
37 |
38 | default_active_actions = [3,4]
39 |
40 | def __init__(self, root,
41 | train=True,
42 | T=3,
43 | label=False,
44 | label_velo=False,
45 | force_moving=False,
46 | active_actions=None,
47 | transforms=torchvision.transforms.ToTensor(),
48 | shared_transition=False,
49 | rng=None):
50 | assert T <= 6
51 | self.data = torch.load(os.path.join(
52 | root, 'smallNORB/train.pt' if train else 'smallNORB/test.pt'))
53 | self.label = label
54 | self.label_velo = label_velo
55 | print(self.data.shape)
56 | self.T = T
57 | self.transforms = transforms
58 | self.active_actions = self.default_active_actions if active_actions is None else active_actions
59 | self.force_moving = force_moving
60 | self.rng = rng if rng is not None else np.random
61 | self.shared_transition = shared_transition
62 | if self.shared_transition:
63 | self.init_shared_transition_parameters()
64 |
65 | def init_shared_transition_parameters(self):
66 | self.vs = {}
67 | for kv in _NUM_VALUES_PER_FACTOR.items():
68 | key, value = kv[0], kv[1]
69 | self.vs[key] = self.gen_v(value)
70 |
71 | def __len__(self):
72 | return 5000
73 |
74 | def gen_pos(self, max_n, v):
75 | _x = np.abs(v) * (self.T-1)
76 | if v < 0:
77 | return self.rng.randint(_x, max_n)
78 | else:
79 | return self.rng.randint(0, max_n-_x)
80 |
81 | def gen_v(self, max_n):
82 | v = self.rng.randint(1 if self.force_moving else 0, max_n//self.T + 1)
83 | if self.rng.uniform() > 0.5:
84 | v = -v
85 | return v
86 |
87 | def gen_factors(self):
88 | # initial state
89 | p_and_v_list = []
90 | sampled_indices = []
91 | for action_index, kv in enumerate(_NUM_VALUES_PER_FACTOR.items()):
92 | key, value = kv[0], kv[1]
93 | if key == 'category' or key == 'instance' or key == 'lighting':
94 | p_and_v_list.append([0, 0])
95 | index = self.rng.randint(0, _NUM_VALUES_PER_FACTOR[key])
96 | sampled_indices.append([index]*self.T)
97 | else:
98 |                 if action_index not in self.active_actions:
99 | v = 0
100 | else:
101 | if self.shared_transition:
102 | v = self.vs[key]
103 | else:
104 | v = self.gen_v(value)
105 | p = self.gen_pos(value, v)
106 | p_and_v_list.append((p, v))
107 | indices = [p + t * v for t in range(self.T)]
108 | sampled_indices.append(indices)
109 | #print(p_and_v_list)
110 | return np.array(p_and_v_list, dtype=np.uint8), np.array(sampled_indices, dtype=np.uint8).T
111 |
112 | def __getitem__(self, i):
113 | p_and_v_list, sample_indices = self.gen_factors()
114 | imgs = []
115 | for t in range(self.T):
116 | ind = sample_indices[t]
117 | img = self.data[ind[0], ind[1], ind[2], ind[3], ind[4]]
118 | img = img/255.
119 | img = self.transforms(img[:, :, None])
120 | imgs.append(img)
121 | if self.T == 1:
122 | imgs = imgs[0]
123 |
124 | if self.label or self.label_velo:
125 | ret = [imgs]
126 | if self.label:
127 | ret += [sample_indices[0][0]]
128 | if self.label_velo:
129 | ret += [p_and_v_list[3][1][None], p_and_v_list[4][1][None]]
130 |
131 | return ret
132 | else:
133 | return imgs
134 |
--------------------------------------------------------------------------------
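
An illustrative property check (not repository code) of the sampling invariant behind `gen_v` and `gen_pos` above: velocities are capped at `max_n // T`, and the start position is chosen so that every index `p + t * v` for `t in range(T)` stays inside `[0, max_n)`.

```
import numpy as np

def check(max_n=18, T=6, trials=10000):
    # Mirrors gen_v / gen_pos with smallNORB's azimuth radix (18 values).
    rng = np.random.RandomState(0)
    for _ in range(trials):
        v = rng.randint(0, max_n // T + 1)
        if rng.uniform() > 0.5:
            v = -v
        span = abs(v) * (T - 1)
        p = rng.randint(span, max_n) if v < 0 else rng.randint(0, max_n - span)
        assert all(0 <= p + t * v < max_n for t in range(T))

check()  # passes: every draw yields T valid factor indices
```
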
/datasets/three_dim_shapes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchvision
4 | from collections import OrderedDict
5 | import os
6 |
7 |
8 | _FACTORS_IN_ORDER = ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape',
9 | 'orientation']
10 | _NUM_VALUES_PER_FACTOR = OrderedDict({'floor_hue': 10, 'wall_hue': 10, 'object_hue': 10,
11 | 'scale': 8, 'shape': 4, 'orientation': 15})
12 |
13 |
14 | def get_index(factors):
15 | """ Converts factors to indices in range(num_data)
16 | Args:
17 | factors: np array shape [6,batch_size].
18 | factors[i]=factors[i,:] takes integer values in
19 | range(_NUM_VALUES_PER_FACTOR[_FACTORS_IN_ORDER[i]]).
20 |
21 | Returns:
22 | indices: np array shape [batch_size].
23 | """
24 | indices = 0
25 | base = 1
26 | for factor, name in reversed(list(enumerate(_FACTORS_IN_ORDER))):
27 | indices += factors[factor] * base
28 | base *= _NUM_VALUES_PER_FACTOR[name]
29 | return indices
30 |
31 |
32 | class ThreeDimShapesDataset(object):
33 |
34 | default_active_actions = [0,1,2,3,5]
35 |
36 | def __init__(self, root, train=True, T=3, label_velo=False, transforms=torchvision.transforms.ToTensor(),
37 | active_actions=None, force_moving=False, shared_transition=False, rng=None):
38 | assert T <= 8
39 | self.images = torch.load(os.path.join(root, '3dshapes/images.pt')).astype(np.float32)
40 | self.label_velo = label_velo
41 | self.train = train
42 | self.T = T
43 | self.transforms = transforms
44 | self.active_actions = self.default_active_actions if active_actions is None else active_actions
45 | self.force_moving = force_moving
46 | self.rng = rng if rng is not None else np.random
47 | self.shared_transition = shared_transition
48 | if self.shared_transition:
49 | self.init_shared_transition_parameters()
50 |
51 | def init_shared_transition_parameters(self):
52 | vs = {}
53 | for kv in _NUM_VALUES_PER_FACTOR.items():
54 | key, value = kv[0], kv[1]
55 | vs[key] = self.gen_v(value)
56 | self.vs = vs
57 |
58 | def __len__(self):
59 | return 5000
60 |
61 | def gen_pos(self, max_n, v):
62 | _x = np.abs(v) * (self.T-1)
63 | if v < 0:
64 | return self.rng.randint(_x, max_n)
65 | else:
66 | return self.rng.randint(0, max_n-_x)
67 |
68 | def gen_v(self, max_n):
69 | v = self.rng.randint(1 if self.force_moving else 0, max_n//self.T + 1)
70 | if self.rng.uniform() > 0.5:
71 | v = -v
72 | return v
73 |
74 |
75 | def gen_factors(self):
76 | # initial state
77 | p_and_v_list = []
78 | sampled_indices = []
79 | for action_index, kv in enumerate(_NUM_VALUES_PER_FACTOR.items()):
80 | key, value = kv[0], kv[1]
81 | if key == 'shape':
82 | p_and_v_list.append([0, 0])
83 | if self.train:
84 | shape = self.rng.choice([0])
85 | else:
86 | shape = self.rng.choice([1,2,3])
87 | sampled_indices.append([shape]*self.T)
88 | else:
89 |                 if action_index not in self.active_actions:
90 | v = 0
91 | else:
92 | if self.shared_transition:
93 | v = self.vs[key]
94 | else:
95 | v = self.gen_v(value)
96 | p = self.gen_pos(value, v)
97 | p_and_v_list.append((p, v))
98 | indices = [p + t * v for t in range(self.T)]
99 | sampled_indices.append(indices)
100 | return np.array(p_and_v_list, dtype=np.uint8), np.array(sampled_indices, dtype=np.uint8).T
101 |
102 |
103 | def __getitem__(self, i):
104 | p_and_v_list, sample_indices = self.gen_factors()
105 | imgs = []
106 | for t in range(self.T):
107 | img = self.images[get_index(sample_indices[t])] / 255.
108 | img = self.transforms(img)
109 | imgs.append(img)
110 | if self.label_velo:
111 | return imgs, p_and_v_list[0][1][None], p_and_v_list[1][1][None], p_and_v_list[2][1][None], p_and_v_list[3][1][None], p_and_v_list[5][1][None]
112 | else:
113 | return imgs
114 |
115 |
--------------------------------------------------------------------------------
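
`get_index` treats the six factors as mixed-radix digits with the last factor (`orientation`) varying fastest, matching the flattened storage order of the 3dshapes images. A worked example with the radices `[10, 10, 10, 8, 4, 15]`:

```
# Radices follow _NUM_VALUES_PER_FACTOR:
# floor_hue, wall_hue, object_hue, scale, shape, orientation
radices = [10, 10, 10, 8, 4, 15]
factors = [2, 0, 1, 3, 0, 7]       # one sample's factor values

index, base = 0, 1
for digit, radix in reversed(list(zip(factors, radices))):
    index += digit * base
    base *= radix
print(index)   # 96667 == 2*48000 + 1*480 + 3*60 + 7
```
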
/figs/comp_error_3dshapes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_3dshapes.pdf
--------------------------------------------------------------------------------
/figs/comp_error_mnist.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist.pdf
--------------------------------------------------------------------------------
/figs/comp_error_mnist_accl.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist_accl.pdf
--------------------------------------------------------------------------------
/figs/comp_error_mnist_bg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist_bg.pdf
--------------------------------------------------------------------------------
/figs/comp_error_mnist_bg_full.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_mnist_bg_full.pdf
--------------------------------------------------------------------------------
/figs/comp_error_smallNORB.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/comp_error_smallNORB.pdf
--------------------------------------------------------------------------------
/figs/disentangle_3dshapes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_3dshapes.pdf
--------------------------------------------------------------------------------
/figs/disentangle_mnist.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_mnist.pdf
--------------------------------------------------------------------------------
/figs/disentangle_mnist_bg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_mnist_bg.pdf
--------------------------------------------------------------------------------
/figs/disentangle_mnist_bg_full.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_mnist_bg_full.pdf
--------------------------------------------------------------------------------
/figs/disentangle_smallNORB.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/disentangle_smallNORB.pdf
--------------------------------------------------------------------------------
/figs/equiv_error_3dshapes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_3dshapes.pdf
--------------------------------------------------------------------------------
/figs/equiv_error_mnist.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist.pdf
--------------------------------------------------------------------------------
/figs/equiv_error_mnist_accl.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist_accl.pdf
--------------------------------------------------------------------------------
/figs/equiv_error_mnist_bg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist_bg.pdf
--------------------------------------------------------------------------------
/figs/equiv_error_mnist_bg_full.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_mnist_bg_full.pdf
--------------------------------------------------------------------------------
/figs/equiv_error_smallNORB.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/equiv_error_smallNORB.pdf
--------------------------------------------------------------------------------
/figs/sbd_3dshapes.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes.pdf
--------------------------------------------------------------------------------
/figs/sbd_3dshapes_component0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component0.pdf
--------------------------------------------------------------------------------
/figs/sbd_3dshapes_component1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component1.pdf
--------------------------------------------------------------------------------
/figs/sbd_3dshapes_component2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component2.pdf
--------------------------------------------------------------------------------
/figs/sbd_3dshapes_component3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component3.pdf
--------------------------------------------------------------------------------
/figs/sbd_3dshapes_component5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_component5.pdf
--------------------------------------------------------------------------------
/figs/sbd_3dshapes_components.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_3dshapes_components.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_bg.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_bg_components.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_components.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_bg_full.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_bg_full_component0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_component0.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_bg_full_component1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_component1.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_bg_full_component2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_component2.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_bg_full_components.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_bg_full_components.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_component0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_component0.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_component1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_component1.pdf
--------------------------------------------------------------------------------
/figs/sbd_mnist_component2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_mnist_component2.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB_component0.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component0.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB_component1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component1.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB_component2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component2.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB_component3.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component3.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB_component4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component4.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB_component5.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_component5.pdf
--------------------------------------------------------------------------------
/figs/sbd_smallNORB_components.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/figs/sbd_smallNORB_components.pdf
--------------------------------------------------------------------------------
/gen_images/3dshapes_lstsq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_lstsq.png
--------------------------------------------------------------------------------
/gen_images/3dshapes_lstsq_multi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_lstsq_multi.png
--------------------------------------------------------------------------------
/gen_images/3dshapes_lstsq_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_lstsq_rec.png
--------------------------------------------------------------------------------
/gen_images/3dshapes_neuralM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_neuralM.png
--------------------------------------------------------------------------------
/gen_images/3dshapes_neural_trans.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/3dshapes_neural_trans.png
--------------------------------------------------------------------------------
/gen_images/mnist_accl_holstsq_equiv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_accl_holstsq_equiv.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_full_lstsq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_lstsq.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_full_lstsq_multi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_lstsq_multi.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_full_lstsq_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_lstsq_rec.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_full_neuralM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_neuralM.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_full_neural_trans.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_full_neural_trans.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_lstsq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_lstsq.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_lstsq_multi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_lstsq_multi.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_lstsq_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_lstsq_rec.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_neuralM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_neuralM.png
--------------------------------------------------------------------------------
/gen_images/mnist_bg_neural_trans.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_bg_neural_trans.png
--------------------------------------------------------------------------------
/gen_images/mnist_lstsq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_lstsq.png
--------------------------------------------------------------------------------
/gen_images/mnist_lstsq_multi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_lstsq_multi.png
--------------------------------------------------------------------------------
/gen_images/mnist_lstsq_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_lstsq_rec.png
--------------------------------------------------------------------------------
/gen_images/mnist_neuralM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_neuralM.png
--------------------------------------------------------------------------------
/gen_images/mnist_neural_trans.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/mnist_neural_trans.png
--------------------------------------------------------------------------------
/gen_images/smallNORB_lstsq.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_lstsq.png
--------------------------------------------------------------------------------
/gen_images/smallNORB_lstsq_multi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_lstsq_multi.png
--------------------------------------------------------------------------------
/gen_images/smallNORB_lstsq_rec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_lstsq_rec.png
--------------------------------------------------------------------------------
/gen_images/smallNORB_neuralM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_neuralM.png
--------------------------------------------------------------------------------
/gen_images/smallNORB_neural_trans.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/gen_images/smallNORB_neural_trans.png
--------------------------------------------------------------------------------
/jobs/09022022/training_allmodels_1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT=$1
3 | DATASET_ROOT=$2
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in mnist mnist_bg 3dshapes smallNORB; do
7 | for model_name in lstsq lstsq_multi lstsq_rec neuralM neural_trans; do
8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
11 | done
12 | done
13 | done
14 |
15 | for seed in 1 2 3; do
16 | for dataset_name in mnist_accl; do
17 | for model_name in lstsq holstsq neural_trans; do
18 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
19 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
20 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
21 | done
22 | done
23 | done
24 |
25 | for seed in 1 2 3; do
26 | for dataset_name in mnist mnist_bg; do
27 | for model_name in simclr cpc; do
28 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
29 | --config_path=./configs/${dataset_name}/simclr/${model_name}.yml \
30 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
31 | done
32 | done
33 | done
34 |
--------------------------------------------------------------------------------
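Note: the jobs/09022022 scripts take the log directory and the dataset root as positional arguments ($1 and $2) and sweep seeds, datasets, and model configs in nested loops. For reference, a hedged Python equivalent of the first sweep above; the paths are placeholders, and the run.py flags are copied verbatim from the script:

# Hypothetical driver reproducing the first nested sweep of
# training_allmodels_1.sh; LOGDIR_ROOT/DATASET_ROOT stand in for $1/$2.
import itertools
import subprocess

LOGDIR_ROOT = '/path/to/logs'
DATASET_ROOT = '/path/to/datasets'

sweep = itertools.product(
    [1, 2, 3],                                        # seeds
    ['mnist', 'mnist_bg', '3dshapes', 'smallNORB'],   # datasets
    ['lstsq', 'lstsq_multi', 'lstsq_rec', 'neuralM', 'neural_trans'],
)
for seed, dataset, model in sweep:
    subprocess.run([
        'python', 'run.py',
        f'--log_dir={LOGDIR_ROOT}/{dataset}-{model}-seed{seed}/',
        f'--config_path=./configs/{dataset}/lstsq/{model}.yml',
        '--attr', f'seed={seed}', f'train_data.args.root={DATASET_ROOT}',
    ], check=True)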
/jobs/09022022/training_allmodels_2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT=$1
3 | DATASET_ROOT=$2
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in mnist_bg; do
7 | for model_name in lstsq lstsq_multi lstsq_rec; do
8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000
11 | done
12 | for model_name in neuralM neural_trans; do
13 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
14 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
15 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 training_loop.args.reconst_iter=200000
16 | done
17 | done
18 | done
19 |
20 |
21 | for seed in 1 2 3; do
22 | for dataset_name in mnist_accl; do
23 | for model_name in lstsq holstsq neural_trans; do
24 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
25 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
26 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
27 | done
28 | done
29 | done
30 |
31 |
32 | for seed in 1 2 3; do
33 | for dataset_name in mnist mnist_bg; do
34 | for model_name in simclr cpc; do
35 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
36 | --config_path=./configs/${dataset_name}/simclr/${model_name}.yml \
37 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
38 | done
39 | done
40 | done
41 |
--------------------------------------------------------------------------------
/jobs/09152022/training_allmodels_1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022'
3 | DATASET_ROOT='/home/TakeruMiyato/datasets'
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in mnist mnist_bg 3dshapes smallNORB; do
7 | for model_name in lstsq_multi lstsq_rec neuralM neural_trans; do
8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
11 | done
12 | done
13 | done
14 |
15 | for seed in 1 2 3; do
16 | for dataset_name in mnist_accl; do
17 | for model_name in lstsq holstsq neural_trans; do
18 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
19 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
20 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
21 | done
22 | done
23 | done
24 |
25 | for seed in 1 2 3; do
26 | for dataset_name in mnist mnist_bg; do
27 | for model_name in simclr cpc; do
28 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
29 | --config_path=./configs/${dataset_name}/simclr/${model_name}.yml \
30 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
31 | done
32 | done
33 | done
34 |
--------------------------------------------------------------------------------
/jobs/09152022/training_allmodels_2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022'
3 | DATASET_ROOT='/home/TakeruMiyato/datasets'
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in mnist_bg; do
7 | for model_name in lstsq lstsq_multi lstsq_rec; do
8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000
11 | done
12 | for model_name in neuralM neural_trans; do
13 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
14 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
15 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 training_loop.args.reconst_iter=200000
16 | done
17 | done
18 | done
19 |
20 |
21 |
--------------------------------------------------------------------------------
/jobs/09152022/training_lstsqs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022'
3 | DATASET_ROOT='/home/TakeruMiyato/datasets'
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in mnist mnist_bg 3dshapes smallNORB; do
7 | for model_name in lstsq; do
8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
11 | done
12 | done
13 | done
--------------------------------------------------------------------------------
/jobs/09152022/training_neurals.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT='/mnt/research_logs/logs/09152022'
3 | DATASET_ROOT='/home/TakeruMiyato/datasets'
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in 3dshapes smallNORB; do
7 | for model_name in neuralM neural_trans; do
8 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
9 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
10 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT}
11 | done
12 | done
13 | done
--------------------------------------------------------------------------------
/jobs/09232022/training_neuralMlatentpred_mnist.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT='/mnt/research_logs/logs/09232022'
3 | DATASET_ROOT='/home/TakeruMiyato/datasets'
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in mnist; do
7 | for model_name in neuralM_latentpred; do
8 | for loss_latent_coeff in 0.0001 0.001 0.01; do
9 | for loss_reconst_coeff in 0.1 1.0 10; do
10 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-lrc${loss_reconst_coeff}-llc${loss_latent_coeff}-seed${seed}/ \
11 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
12 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} \
13 | model.args.loss_latent_coeff=${loss_latent_coeff} model.args.loss_reconst_coeff=${loss_reconst_coeff}
14 | done
15 | done
16 | done
17 | done
18 | done
--------------------------------------------------------------------------------
/jobs/09232022/training_neuralMlatentpred_mnist_bg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT='/mnt/research_logs/logs/09232022'
3 | DATASET_ROOT='/home/TakeruMiyato/datasets'
4 |
5 | for seed in 1 2 3; do
6 | for dataset_name in mnist_bg; do
7 | for model_name in neuralM_latentpred; do
8 | for loss_latent_coeff in 0.0001 0.001 0.01; do
9 | for loss_reconst_coeff in 0.1 1.0 10; do
10 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-lrc${loss_reconst_coeff}-llc${loss_latent_coeff}-seed${seed}/ \
11 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
12 | --attr seed=${seed} train_data.args.root=${DATASET_ROOT} \
13 | model.args.loss_latent_coeff=${loss_latent_coeff} model.args.loss_reconst_coeff=${loss_reconst_coeff}
14 | done
15 | done
16 | done
17 | done
18 | done
--------------------------------------------------------------------------------
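The `--attr key=value` pairs in these job scripts override nested YAML config entries by dotted path (e.g. `model.args.loss_latent_coeff=0.001`). run.py itself is not shown in this section, so the following is only a minimal sketch of such an override mechanism under that assumption, not the actual implementation:

# Minimal stand-in for a dotted-path config override; the real parser in
# run.py may differ (for one, this sketch does not coerce value types).
def set_by_dotted_path(config: dict, dotted_key: str, value) -> None:
    *parents, leaf = dotted_key.split('.')
    node = config
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

cfg = {'model': {'args': {}}, 'train_data': {'args': {}}}
set_by_dotted_path(cfg, 'model.args.loss_latent_coeff', 0.001)
set_by_dotted_path(cfg, 'train_data.args.root', '/path/to/datasets')
print(cfg)  # nested keys filled in by dotted path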
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/models/__init__.py
--------------------------------------------------------------------------------
/models/base_networks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from models.resblock import Block, Conv1d1x1Block
5 | from einops.layers.torch import Rearrange
6 | from einops import repeat
7 |
8 |
9 | class Conv1d1x1Encoder(nn.Sequential):
10 | def __init__(self,
11 | dim_out=16,
12 | dim_hidden=128,
13 | act=nn.ReLU()):
14 | super().__init__(
15 | nn.LazyConv1d(dim_hidden, 1, 1, 0),
16 | Conv1d1x1Block(dim_hidden, dim_hidden, act=act),
17 | Conv1d1x1Block(dim_hidden, dim_hidden, act=act),
18 | Rearrange('n c s -> n s c'),
19 | nn.LayerNorm(dim_hidden),
20 | Rearrange('n s c -> n c s'),
21 | act,
22 | nn.LazyConv1d(dim_out, 1, 1, 0)
23 | )
24 |
25 |
26 | class ResNetEncoder(nn.Module):
27 | def __init__(self,
28 | dim_latent=1024,
29 | k=1,
30 | act=nn.ReLU(),
31 | kernel_size=3,
32 | n_blocks=3):
33 | super().__init__()
34 | self.phi = nn.Sequential(
35 | nn.LazyConv2d(int(32 * k), 3, 1, 1),
36 | *[Block(int(32 * k) * (2 ** i), int(32 * k) * (2 ** (i+1)), int(32 * k) * (2 ** (i+1)),
37 | resample='down', activation=act, kernel_size=kernel_size) for i in range(n_blocks)],
38 | nn.GroupNorm(min(32, int(32 * k) * (2 ** n_blocks)),
39 | int(32 * k) * (2 ** n_blocks)),
40 | act)
41 | self.linear = nn.LazyLinear(
42 | dim_latent) if dim_latent > 0 else lambda x: x
43 |
44 | def __call__(self, x):
45 | h = x
46 | h = self.phi(h)
47 | h = h.reshape(h.shape[0], -1)
48 | h = self.linear(h)
49 | return h
50 |
51 |
52 | class ResNetDecoder(nn.Module):
53 | def __init__(self, ch_x, k=1, act=nn.ReLU(), kernel_size=3, bottom_width=4, n_blocks=3):
54 | super().__init__()
55 | self.bottom_width = bottom_width
56 | self.linear = nn.LazyLinear(int(32 * k) * (2 ** n_blocks))
57 | self.net = nn.Sequential(
58 | *[Block(int(32 * k) * (2 ** (i+1)), int(32 * k) * (2 ** i), int(32 * k) * (2 ** i),
59 | resample='up', activation=act, kernel_size=kernel_size, posemb=True) for i in range(n_blocks-1, -1, -1)],
60 | nn.GroupNorm(min(32, int(32 * k)), int(32 * k)),
61 | act,
62 | nn.Conv2d(int(32 * k), ch_x, 3, 1, 1)
63 | )
64 |
65 | def __call__(self, x):
66 | x = self.linear(x)
67 | x = repeat(x, 'n c -> n c h w',
68 | h=self.bottom_width, w=self.bottom_width)
69 | x = self.net(x)
70 | x = torch.sigmoid(x)
71 | return x
72 |
--------------------------------------------------------------------------------
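A minimal shape check for the ResNetEncoder/ResNetDecoder pair above. It assumes the repository is on PYTHONPATH (the blocks depend on models/resblock.py and utils/) and that torch and einops are installed; the 32x32 input size is illustrative, not taken from the configs:

# With the default n_blocks=3, the encoder downsamples 32x32 -> 4x4 and the
# decoder (bottom_width=4) mirrors it back; the lazy layers are materialized
# on the first forward pass.
import torch
from models.base_networks import ResNetEncoder, ResNetDecoder

enc = ResNetEncoder(dim_latent=256)
dec = ResNetDecoder(ch_x=3, bottom_width=4)

x = torch.randn(8, 3, 32, 32)
z = enc(x)       # -> (8, 256)
x_rec = dec(z)   # -> (8, 3, 32, 32), sigmoid-bounded to [0, 1]
print(z.shape, x_rec.shape)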
/models/dynamics_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from utils.laplacian import make_identity_like, tracenorm_of_normalized_laplacian, make_identity, make_diagonal
5 | import einops
6 | import pytorch_pfn_extras as ppe
7 |
8 |
9 | def _rep_M(M, T):
10 | return einops.repeat(M, "n a1 a2 -> n t a1 a2", t=T)
11 |
12 |
13 | def _loss(A, B):
14 | return torch.sum((A-B)**2)
15 |
16 |
17 | def _solve(A, B):
18 | ATA = A.transpose(-2, -1) @ A
19 | ATB = A.transpose(-2, -1) @ B
20 | return torch.linalg.solve(ATA, ATB)
21 |
22 |
23 | def loss_bd(M_star, alignment):
24 | # Block Diagonalization Loss
25 | S = torch.abs(M_star)
26 | STS = torch.matmul(S.transpose(-2, -1), S)
27 | if alignment:
28 | laploss_sts = tracenorm_of_normalized_laplacian(
29 | torch.mean(STS, 0))
30 | else:
31 | laploss_sts = torch.mean(
32 | tracenorm_of_normalized_laplacian(STS), 0)
33 | return laploss_sts
34 |
35 |
36 | def loss_orth(M_star):
37 | # Orthogonalization of M
38 | I = make_identity_like(M_star)
39 | return torch.mean(torch.sum((I-M_star @ M_star.transpose(-2, -1))**2, axis=(-2, -1)))
40 |
41 |
42 | class LinearTensorDynamicsLSTSQ(nn.Module):
43 |
44 | class DynFn(nn.Module):
45 | def __init__(self, M):
46 | super().__init__()
47 | self.M = M
48 |
49 | def __call__(self, H):
50 | return H @ _rep_M(self.M, T=H.shape[1])
51 |
52 | def inverse(self, H):
53 | M = _rep_M(self.M, T=H.shape[1])
54 | return torch.linalg.solve(M, H.transpose(-2, -1)).transpose(-2, -1)
55 |
56 | def __init__(self, alignment=True):
57 | super().__init__()
58 | self.alignment = alignment
59 |
60 | def __call__(self, H, return_loss=False, fix_indices=None):
61 | # Regress M.
62 | # Note: backpropagation is disabled when fix_indices is not None.
63 |
64 | # H0.shape = H1.shape = [n, t, s, a]
65 | H0, H1 = H[:, :-1], H[:, 1:]
66 | # num_ts x ([len_ts -1] * dim_s) x dim_a
67 | # The difference between the time-shifted components
68 | loss_internal_0 = _loss(H0, H1)
69 | ppe.reporting.report({
70 | 'loss_internal_0': loss_internal_0.item()
71 | })
72 | _H0 = H0.reshape(H0.shape[0], -1, H0.shape[-1])
73 | _H1 = H1.reshape(H1.shape[0], -1, H1.shape[-1])
74 | if fix_indices is not None:
75 | # Note: backpropagation is disabled.
76 | dim_a = _H0.shape[-1]
77 | active_indices = np.array(list(set(np.arange(dim_a)) - set(fix_indices)))
78 | _M_star = _solve(_H0[:, :, active_indices],
79 | _H1[:, :, active_indices])
80 | M_star = make_identity(_H1.shape[0], _H1.shape[-1], _H1.device)
81 | M_star[:, active_indices[:, np.newaxis], active_indices] = _M_star
82 | else:
83 | M_star = _solve(_H0, _H1)
84 | dyn_fn = self.DynFn(M_star)
85 | loss_internal_T = _loss(dyn_fn(H0), H1)
86 | ppe.reporting.report({
87 | 'loss_internal_T': loss_internal_T.item()
88 | })
89 |
90 | # M_star is returned as a module (DynFn), not as a raw matrix
91 | if return_loss:
92 | losses = (loss_bd(dyn_fn.M, self.alignment),
93 | loss_orth(dyn_fn.M), loss_internal_T)
94 | return dyn_fn, losses
95 | else:
96 | return dyn_fn
97 |
98 |
99 | class HigherOrderLinearTensorDynamicsLSTSQ(LinearTensorDynamicsLSTSQ):
100 |
101 | class DynFn(nn.Module):
102 | def __init__(self, M):
103 | super().__init__()
104 | self.M = M
105 |
106 | def __call__(self, Hs):
107 | nHs = [None]*len(Hs)
108 | for l in range(len(Hs)-1, -1, -1):
109 | if l == len(Hs)-1:
110 | nHs[l] = Hs[l] @ _rep_M(self.M, Hs[l].shape[1])
111 | else:
112 | nHs[l] = Hs[l] @ nHs[l+1]
113 | return nHs
114 |
115 | def __init__(self, alignment=True, n_order=2):
116 | super().__init__(alignment)
117 | self.n_order = n_order
118 |
119 | def __call__(self, H, return_loss=False, fix_indices=None):
120 | assert H.shape[1] > self.n_order
121 | # H0.shape = Hn.shape = [n, t, s, a]
122 | H0, Hn = H[:, :-self.n_order], H[:, self.n_order:]
123 | loss_internal_0 = _loss(H0, Hn)
124 | ppe.reporting.report({
125 | 'loss_internal_0': loss_internal_0.item()
126 | })
127 | Ms = []
128 | _H = H
129 | if fix_indices is not None:
130 | raise NotImplementedError
131 | else:
132 | for n in range(self.n_order):
133 | # _H0.shape = _H1.shape = [n, t, s, a]
134 | _H0, _H1 = _H[:, :-1], _H[:, 1:]
135 | if n == self.n_order - 1:
136 | _H0 = _H0.reshape(_H0.shape[0], -1, _H0.shape[-1])
137 | _H1 = _H1.reshape(_H1.shape[0], -1, _H1.shape[-1])
138 | _H = _solve(_H0, _H1) # [N, a, a]
139 | else:
140 | _H = _solve(_H0, _H1)[:, 1:] # [N, T-n, a, a]
141 | Ms.append(_H)
142 | dyn_fn = self.DynFn(Ms[-1])
143 | loss_internal_T = _loss(dyn_fn([H0] + Ms[:-1])[0], Hn)
144 | ppe.reporting.report({
145 | 'loss_internal_T': loss_internal_T.item()
146 | })
147 | # M_star is returned as a module (DynFn), not as a raw matrix
148 | if return_loss:
149 | losses = (loss_bd(dyn_fn.M, self.alignment),
150 | loss_orth(dyn_fn.M), loss_internal_T)
151 | return dyn_fn, Ms[:-1], losses
152 | else:
153 | return dyn_fn, Ms[:-1]
154 |
155 | # The fixed-block model (MultiLinearTensorDynamicsLSTSQ below)
156 |
157 |
158 | class MultiLinearTensorDynamicsLSTSQ(LinearTensorDynamicsLSTSQ):
159 |
160 | def __init__(self, dim_a, alignment=True, K=4):
161 | super().__init__(alignment=alignment)
162 | self.dim_a = dim_a
163 | self.alignment = alignment
164 | assert dim_a % K == 0
165 | self.K = K
166 |
167 | def __call__(self, H, return_loss=False, fix_indices=None):
168 | H0, H1 = H[:, :-1], H[:, 1:]
169 | # num_ts x ([len_ts -1] * dim_s) x dim_a
170 |
171 | # The difference between the time-shifted components
172 | loss_internal_0 = _loss(H0, H1)
173 |
174 | _H0 = H0.reshape(H.shape[0], -1, H.shape[3])
175 | _H1 = H1.reshape(H.shape[0], -1, H.shape[3])
176 |
177 | ppe.reporting.report({
178 | 'loss_internal_0': loss_internal_0.item()
179 | })
180 | M_stars = []
181 | for k in range(self.K):
182 | if fix_indices is not None and k in fix_indices:
183 | M_stars.append(make_identity(
184 | H.shape[0], self.dim_a//self.K, H.device))
185 | else:
186 | st = k*(self.dim_a//self.K)
187 | ed = (k+1)*(self.dim_a//self.K)
188 | M_stars.append(_solve(_H0[:, :, st:ed], _H1[:, :, st:ed]))
189 |
190 | # Construct the block-diagonal M from the per-block solutions
191 | for k in range(self.K):
192 | if k == 0:
193 | M_star = M_stars[0]
194 | else:
195 | M1 = M_star
196 | M2 = M_stars[k]
197 | _M1 = torch.cat(
198 | [M1, torch.zeros(H.shape[0], M2.shape[1], M1.shape[2]).to(H.device)], axis=1)
199 | _M2 = torch.cat(
200 | [torch.zeros(H.shape[0], M1.shape[1], M2.shape[2]).to(H.device), M2], axis=1)
201 | M_star = torch.cat([_M1, _M2], axis=2)
202 | dyn_fn = self.DynFn(M_star)
203 | loss_internal_T = _loss(dyn_fn(H0), H1)
204 | ppe.reporting.report({
205 | 'loss_internal_T': loss_internal_T.item()
206 | })
207 |
208 | # M_star is returned as a module (DynFn), not as a raw matrix
209 | if return_loss:
210 | losses = (loss_bd(dyn_fn.M, self.alignment),
211 | loss_orth(dyn_fn.M), loss_internal_T)
212 | return dyn_fn, losses
213 | else:
214 | return dyn_fn
215 |
--------------------------------------------------------------------------------
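`_solve` above fits each sequence's transition matrix by least squares via the normal equations, M* = (H0^T H0)^{-1} H0^T H1. A small self-contained sanity check of that identity (torch only; the shapes are illustrative, and float64 keeps the normal equations well-conditioned):

import torch

n, ts, a = 4, 10, 6                # batch, (time x dim_s) rows, dim_a
H0 = torch.randn(n, ts, a, dtype=torch.float64)
M_true = torch.randn(n, a, a, dtype=torch.float64)
H1 = H0 @ M_true                   # noiseless linear transitions

# Same computation as _solve(H0, H1) in dynamics_models.py:
ATA = H0.transpose(-2, -1) @ H0
ATB = H0.transpose(-2, -1) @ H1
M_star = torch.linalg.solve(ATA, ATB)
print(torch.allclose(M_star, M_true, atol=1e-8))  # True for full-rank H0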
/models/resblock.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | from utils.weight_standarization import WeightStandarization, WeightStandarization1d
8 | import torch.nn.utils.parametrize as P
9 | from utils.emb2d import Emb2D
10 |
11 |
12 | def upsample_conv(x, conv):
13 | # Upsample -> Conv
14 | x = nn.Upsample(scale_factor=2, mode='nearest')(x)
15 | x = conv(x)
16 | return x
17 |
18 |
19 | def conv_downsample(x, conv):
20 | # Conv -> Downsample
21 | x = conv(x)
22 | h = F.avg_pool2d(x, 2)
23 | return h
24 |
25 |
26 | class Block(nn.Module):
27 | def __init__(self,
28 | in_channels,
29 | out_channels,
30 | hidden_channels=None,
31 | kernel_size=3,
32 | padding=None,
33 | activation=F.relu,
34 | resample=None,
35 | group_norm=True,
36 | skip_connection=True,
37 | posemb=False):
38 | super(Block, self).__init__()
39 | if padding is None:
40 | padding = (kernel_size-1) // 2
41 | self.pe = Emb2D() if posemb else lambda x: x
42 |
43 | in_ch_conv = in_channels + self.pe.dim if posemb else in_channels
44 | self.skip_connection = skip_connection
45 | self.activation = activation
46 | self.resample = resample
47 | initializer = torch.nn.init.xavier_uniform_
48 | if self.resample is None or self.resample == 'up':
49 | hidden_channels = out_channels if hidden_channels is None else hidden_channels
50 | else:
51 | hidden_channels = in_channels if hidden_channels is None else hidden_channels
52 | self.c1 = nn.Conv2d(in_ch_conv, hidden_channels,
53 | kernel_size=kernel_size, padding=padding)
54 | self.c2 = nn.Conv2d(hidden_channels, out_channels,
55 | kernel_size=kernel_size, padding=padding)
56 | initializer(self.c1.weight, math.sqrt(2))
57 | initializer(self.c2.weight, math.sqrt(2))
58 | P.register_parametrization(
59 | self.c1, 'weight', WeightStandarization())
60 | P.register_parametrization(
61 | self.c2, 'weight', WeightStandarization())
62 |
63 | if group_norm:
64 | self.b1 = nn.GroupNorm(min(32, in_channels), in_channels)
65 | self.b2 = nn.GroupNorm(min(32, hidden_channels), hidden_channels)
66 | else:
67 | self.b1 = self.b2 = lambda x: x
68 | if self.skip_connection:
69 | self.c_sc = nn.Conv2d(in_ch_conv, out_channels,
70 | kernel_size=1, padding=0)
71 | initializer(self.c_sc.weight)
72 |
73 | def residual(self, x):
74 | x = self.b1(x)
75 | x = self.activation(x)
76 | if self.resample == 'up':
77 | x = nn.Upsample(scale_factor=2, mode='nearest')(x)
78 | x = self.pe(x)
79 | x = self.c1(x)
80 | x = self.b2(x)
81 | x = self.activation(x)
82 | x = self.c2(x)
83 | if self.resample == 'down':
84 | x = F.avg_pool2d(x, 2)
85 | return x
86 |
87 | def shortcut(self, x):
88 | # 'up': upsample -> 1x1 conv; 'down': 1x1 conv -> avg-pool; else: 1x1 conv only
89 | if self.resample == 'up':
90 | x = nn.Upsample(scale_factor=2, mode='nearest')(x)
91 | x = self.pe(x)
92 | x = self.c_sc(x)
93 |
94 | elif self.resample == 'down':
95 | x = self.pe(x)
96 | x = self.c_sc(x)
97 | x = F.avg_pool2d(x, 2)
98 | else:
99 | x = self.pe(x)
100 | x = self.c_sc(x)
101 | return x
102 |
103 | def __call__(self, x):
104 | if self.skip_connection:
105 | return self.residual(x) + self.shortcut(x)
106 | else:
107 | return self.residual(x)
108 |
109 |
110 | class Conv1d1x1Block(nn.Module):
111 | def __init__(self,
112 | in_channels,
113 | out_channels,
114 | hidden_channels=None,
115 | act=F.relu):
116 | super().__init__()
117 |
118 | self.act = act
119 | initializer = torch.nn.init.xavier_uniform_
120 | hidden_channels = out_channels if hidden_channels is None else hidden_channels
121 | self.c1 = nn.Conv1d(in_channels, hidden_channels, 1, 1, 0)
122 | self.c2 = nn.Conv1d(hidden_channels, out_channels, 1, 1, 0)
123 | initializer(self.c1.weight, math.sqrt(2))
124 | initializer(self.c2.weight, math.sqrt(2))
125 | P.register_parametrization(
126 | self.c1, 'weight', WeightStandarization1d())
127 | P.register_parametrization(
128 | self.c2, 'weight', WeightStandarization1d())
129 | self.norm1 = nn.LayerNorm(in_channels)
130 | self.norm2 = nn.LayerNorm(hidden_channels)
131 | self.c_sc = nn.Conv1d(in_channels, out_channels, 1, 1, 0)
132 | initializer(self.c_sc.weight)
133 |
134 | def residual(self, x):
135 | x = self.norm1(x.transpose(-2, -1)).transpose(-2, -1)
136 | x = self.act(x)
137 | x = self.c1(x)
138 | x = self.norm2(x.transpose(-2, -1)).transpose(-2, -1)
139 | x = self.act(x)
140 | x = self.c2(x)
141 | return x
142 |
143 | def shortcut(self, x):
144 | x = self.c_sc(x)
145 | return x
146 |
147 | def __call__(self, x):
148 | return self.residual(x) + self.shortcut(x)
--------------------------------------------------------------------------------
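A quick shape check for the pre-activation residual Block above; it assumes the repository's utils (Emb2D and the weight-standardization parametrizations) are importable, since Block depends on them:

import torch
from models.resblock import Block

down = Block(32, 64, resample='down')  # halves spatial size, doubles channels
up = Block(64, 32, resample='up')      # the inverse direction

x = torch.randn(2, 32, 16, 16)
h = down(x)                            # -> (2, 64, 8, 8)
y = up(h)                              # -> (2, 32, 16, 16)
print(h.shape, y.shape)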
/models/seqae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from models import dynamics_models
5 | import torch.nn.utils.parametrize as P
6 | from models.dynamics_models import LinearTensorDynamicsLSTSQ, MultiLinearTensorDynamicsLSTSQ, HigherOrderLinearTensorDynamicsLSTSQ
7 | from models.base_networks import ResNetEncoder, ResNetDecoder, Conv1d1x1Encoder
8 | from einops import rearrange, repeat
9 | from utils.clr import simclr
10 |
11 |
12 | class SeqAELSTSQ(nn.Module):
13 | def __init__(
14 | self,
15 | dim_a,
16 | dim_m,
17 | alignment=False,
18 | ch_x=3,
19 | k=1.0,
20 | kernel_size=3,
21 | change_of_basis=False,
22 | predictive=True,
23 | bottom_width=4,
24 | n_blocks=3,
25 | *args,
26 | **kwargs):
27 | super().__init__()
28 | self.dim_a = dim_a
29 | self.dim_m = dim_m
30 | self.predictive = predictive
31 | self.enc = ResNetEncoder(
32 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
33 | self.dec = ResNetDecoder(
34 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks)
35 | self.dynamics_model = LinearTensorDynamicsLSTSQ(alignment=alignment)
36 | if change_of_basis:
37 | self.change_of_basis = nn.Parameter(
38 | torch.empty(dim_a, dim_a))
39 | nn.init.eye_(self.change_of_basis)
40 |
41 | def _encode_base(self, xs, enc):
42 | shape = xs.shape
43 | x = torch.reshape(xs, (shape[0] * shape[1], *shape[2:]))
44 | H = enc(x)
45 | H = torch.reshape(
46 | H, (shape[0], shape[1], *H.shape[1:]))
47 | return H
48 |
49 | def encode(self, xs):
50 | H = self._encode_base(xs, self.enc)
51 | H = torch.reshape(
52 | H, (H.shape[0], H.shape[1], self.dim_m, self.dim_a))
53 | if hasattr(self, "change_of_basis"):
54 | H = H @ repeat(self.change_of_basis,
55 | 'a1 a2 -> n t a1 a2', n=H.shape[0], t=H.shape[1])
56 | return H
57 |
58 | def phi(self, xs):
59 | return self._encode_base(xs, self.enc.phi)
60 |
61 | def get_M(self, xs):
62 | dyn_fn = self.dynamics_fn(xs)
63 | return dyn_fn.M
64 |
65 | def decode(self, H):
66 | if hasattr(self, "change_of_basis"):
67 | H = H @ repeat(torch.linalg.inv(self.change_of_basis),
68 | 'a1 a2 -> n t a1 a2', n=H.shape[0], t=H.shape[1])
69 | n, t = H.shape[:2]
70 | if hasattr(self, "pidec"):
71 | H = rearrange(H, 'n t d_s d_a -> (n t) d_a d_s')
72 | H = self.pidec(H)
73 | else:
74 | H = rearrange(H, 'n t d_s d_a -> (n t) (d_s d_a)')
75 | x_next_preds = self.dec(H)
76 | x_next_preds = torch.reshape(
77 | x_next_preds, (n, t, *x_next_preds.shape[1:]))
78 | return x_next_preds
79 |
80 | def dynamics_fn(self, xs, return_loss=False, fix_indices=None):
81 | H = self.encode(xs)
82 | return self.dynamics_model(H, return_loss=return_loss, fix_indices=fix_indices)
83 |
84 | def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False):
85 | xs_cond = xs[:, :T_cond]
86 | xs_pred = self(xs_cond, return_reg_loss=return_reg_loss,
87 | n_rolls=xs.shape[1] - T_cond, predictive=self.predictive, reconst=reconst)
88 | if return_reg_loss:
89 | xs_pred, reg_losses = xs_pred
90 | if reconst:
91 | xs_target = xs
92 | else:
93 | xs_target = xs[:, T_cond:] if self.predictive else xs[:, 1:]
94 | loss = torch.mean(
95 | torch.sum((xs_target - xs_pred) ** 2, axis=[2, 3, 4]))
96 | return (loss, reg_losses) if return_reg_loss else loss
97 |
98 | def __call__(self, xs_cond, return_reg_loss=False, n_rolls=1, fix_indices=None, predictive=True, reconst=False):
99 | # Encoded Latent. Num_ts x len_ts x dim_m x dim_a
100 | H = self.encode(xs_cond)
101 |
102 | # ==Estimate dynamics==
103 | ret = self.dynamics_model(
104 | H, return_loss=return_reg_loss, fix_indices=fix_indices)
105 | if return_reg_loss:
106 | # fn is the map given by M_star; losses are the training regularization terms
107 | fn, losses = ret
108 | else:
109 | fn = ret
110 |
111 | if predictive:
112 | H_last = H[:, -1:]
113 | H_preds = [H] if reconst else []
114 | array = np.arange(n_rolls)
115 | else:
116 | H_last = H[:, :1]
117 | H_preds = [H[:, :1]] if reconst else []
118 | array = np.arange(xs_cond.shape[1] + n_rolls - 1)
119 |
120 | for _ in array:
121 | H_last = fn(H_last)
122 | H_preds.append(H_last)
123 | H_preds = torch.cat(H_preds, axis=1)
124 | # Prediction in the observation space
125 | x_preds = self.decode(H_preds)
126 | if return_reg_loss:
127 | return x_preds, losses
128 | else:
129 | return x_preds
130 |
131 |
132 | def loss_equiv(self, xs, T_cond=2, reduce=False):
133 | bsize = len(xs)
134 | xs_cond = xs[:, :T_cond]
135 | xs_target = xs[:, T_cond:]
136 | H = self.encode(xs_cond[:, -1:])
137 | dyn_fn = self.dynamics_fn(xs_cond)
138 |
139 | H_last = H
140 | H_preds = []
141 | n_rolls = xs.shape[1] - T_cond
142 | for _ in np.arange(n_rolls):
143 | H_last = dyn_fn(H_last)
144 | H_preds.append(H_last)
145 | H_pred = torch.cat(H_preds, axis=1)
146 | # Swap M across the batch (roll by one) so each sequence is decoded with another sequence's dynamics
147 | dyn_fn.M = dyn_fn.M[torch.arange(-1, bsize-1)]
148 |
149 | H_last = H
150 | H_preds_perm = []
151 | for _ in np.arange(n_rolls):
152 | H_last = dyn_fn(H_last)
153 | H_preds_perm.append(H_last)
154 | H_pred_perm = torch.cat(H_preds_perm, axis=1)
155 |
156 | xs_pred = self.decode(H_pred)
157 | xs_pred_perm = self.decode(H_pred_perm)
158 | reduce_dim = (1, 2, 3, 4) if reduce else (2, 3, 4)  # xs_target is 5-D: (n, t, c, h, w)
159 | loss = torch.sum((xs_target-xs_pred)**2, dim=reduce_dim).detach().cpu().numpy()
160 | loss_perm = torch.sum((xs_target-xs_pred_perm)**2, dim=reduce_dim).detach().cpu().numpy()
161 | return loss, loss_perm
162 |
163 |
164 | class SeqAEHOLSTSQ(SeqAELSTSQ):
165 | # Higher-order version of SeqAELSTSQ
166 | def __init__(
167 | self,
168 | dim_a,
169 | dim_m,
170 | alignment=False,
171 | ch_x=3,
172 | k=1.0,
173 | kernel_size=3,
174 | change_of_basis=False,
175 | predictive=True,
176 | bottom_width=4,
177 | n_blocks=3,
178 | n_order=2,
179 | *args,
180 | **kwargs):
181 | super(SeqAELSTSQ, self).__init__()
182 | self.dim_a = dim_a
183 | self.dim_m = dim_m
184 | self.predictive = predictive
185 | self.enc = ResNetEncoder(
186 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
187 | self.dec = ResNetDecoder(
188 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks)
189 | self.dynamics_model = HigherOrderLinearTensorDynamicsLSTSQ(
190 | alignment=alignment, n_order=n_order)
191 | if change_of_basis:
192 | self.change_of_basis = nn.Parameter(
193 | torch.empty(dim_a, dim_a))
194 | nn.init.eye_(self.change_of_basis)
195 |
196 | def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False):
197 | if reconst:
198 | raise NotImplementedError
199 | xs_cond = xs[:, :T_cond]
200 | xs_pred = self(xs_cond, predictive=self.predictive, return_reg_loss=return_reg_loss,
201 | n_rolls=xs.shape[1] - T_cond)
202 | if return_reg_loss:
203 | xs_pred, reg_losses = xs_pred
204 | xs_target = xs[:, T_cond:] if self.predictive else xs[:, 1:]
205 | loss = torch.mean(
206 | torch.sum((xs_target - xs_pred) ** 2, axis=[2, 3, 4]))
207 | return (loss, reg_losses) if return_reg_loss else loss
208 |
209 |
210 | def __call__(self, xs, n_rolls=1, fix_indices=None, predictive=True, return_reg_loss=False):
211 | # Encoded Latent. Num_ts x len_ts x dim_m x dim_a
212 | H = self.encode(xs)
213 |
214 | # ==Estimate dynamics==
215 | ret = self.dynamics_model(
216 | H, return_loss=return_reg_loss, fix_indices=fix_indices)
217 | if return_reg_loss:
218 | # fn is the map given by M_star; losses are the training regularization terms
219 | fn, Ms, losses = ret
220 | else:
221 | fn, Ms = ret
222 |
223 | if predictive:
224 | Hs_last = [H[:, -1:]] + [M[:, -1:] for M in Ms]
225 | array = np.arange(n_rolls)
226 | else:
227 | Hs_last = [H[:, :1]] + [M[:, :1] for M in Ms]
228 | array = np.arange(xs.shape[1] + n_rolls - 1)
229 |
230 | # Create prediction for the unseen future
231 | H_preds = []
232 | for _ in array:
233 | Hs_last = fn(Hs_last)
234 | H_preds.append(Hs_last[0])
235 | H_preds = torch.cat(H_preds, axis=1)
236 | x_preds = self.decode(H_preds)
237 | if return_reg_loss:
238 | return x_preds, losses
239 | else:
240 | return x_preds
241 |
242 | def loss_equiv(self, xs, T_cond=5, reduce=False, return_generated_images=False):
243 | bsize = len(xs)
244 | xs_cond = xs[:, :T_cond]
245 | xs_target = xs[:, T_cond:]
246 | H = self.encode(xs_cond[:, -1:])
247 | dyn_fn, Ms = self.dynamics_fn(xs_cond)
248 |
249 | H_last = [H] + [M[:, -1:] for M in Ms]
250 | H_preds = []
251 | n_rolls = xs.shape[1] - T_cond
252 | for _ in np.arange(n_rolls):
253 | H_last = dyn_fn(H_last)
254 | H_preds.append(H_last[0])
255 | H_pred = torch.cat(H_preds, axis=1)
256 | # Swap M (and the higher-order Ms) across the batch, rolling by one
257 | dyn_fn.M = dyn_fn.M[torch.arange(-1, bsize-1)]
258 | Ms = [M[torch.arange(-1, bsize-1)] for M in Ms]
259 |
260 | H_last = [H] + [M[:, -1:] for M in Ms]
261 | H_preds_perm = []
262 | for _ in np.arange(n_rolls):
263 | H_last = dyn_fn(H_last)
264 | H_preds_perm.append(H_last[0])
265 | H_pred_perm = torch.cat(H_preds_perm, axis=1)
266 |
267 | xs_pred = self.decode(H_pred)
268 | xs_pred_perm = self.decode(H_pred_perm)
269 | loss = torch.sum((xs_target-xs_pred)**2, dim=(2,3,4)).detach().cpu().numpy()
270 | loss_perm = torch.sum((xs_target-xs_pred_perm)**2, dim=(2,3,4)).detach().cpu().numpy()
271 | if reduce:
272 | loss = np.mean(loss)  # loss is already a NumPy array here
273 | loss_perm = np.mean(loss_perm)
274 | if return_generated_images:
275 | return (loss, loss_perm), (xs_pred, xs_pred_perm)
276 | else:
277 | return loss, loss_perm
278 |
279 |
280 |
281 | class SeqAEMultiLSTSQ(SeqAELSTSQ):
282 | def __init__(
283 | self,
284 | dim_a,
285 | dim_m,
286 | alignment=False,
287 | ch_x=3,
288 | k=1.0,
289 | kernel_size=3,
290 | change_of_basis=False,
291 | predictive=True,
292 | bottom_width=4,
293 | n_blocks=3,
294 | K=8,
295 | *args,
296 | **kwargs):
297 | super(SeqAELSTSQ, self).__init__()
298 | self.dim_a = dim_a
299 | self.dim_m = dim_m
300 | self.predictive = predictive
301 | self.K = K
302 | self.enc = ResNetEncoder(
303 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
304 | self.dec = ResNetDecoder(
305 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks)
306 | self.dynamics_model = MultiLinearTensorDynamicsLSTSQ(
307 | dim_a, alignment=alignment, K=K)
308 | if change_of_basis:
309 | self.change_of_basis = nn.Parameter(
310 | torch.empty(dim_a, dim_a))
311 | nn.init.eye_(self.change_of_basis)
312 |
313 | def get_blocks_of_M(self, xs):
314 | M = self.get_M(xs)
315 | blocks = []
316 | for k in range(self.K):
317 | dim_block = self.dim_a // self.K
318 | blocks.append(M[:, k*dim_block:(k+1)*dim_block]
319 | [:, :, k*dim_block:(k+1)*dim_block])
320 | blocks_of_M = torch.stack(blocks, 1)
321 | return blocks_of_M
322 |
323 |
324 | class SeqAENeuralM(SeqAELSTSQ):
325 | def __init__(
326 | self,
327 | dim_a,
328 | dim_m,
329 | ch_x=3,
330 | k=1.0,
331 | alignment=False,
332 | kernel_size=3,
333 | predictive=True,
334 | bottom_width=4,
335 | n_blocks=3,
336 | *args,
337 | **kwargs):
338 | super(SeqAELSTSQ, self).__init__()
339 | self.dim_a = dim_a
340 | self.dim_m = dim_m
341 | self.predictive = predictive
342 | self.alignment = alignment
343 | self.initial_scale_M = 0.01
344 | self.enc = ResNetEncoder(
345 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
346 | self.M_net = ResNetEncoder(
347 | dim_a*dim_a, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
348 | self.dec = ResNetDecoder(
349 | ch_x, k=k, kernel_size=kernel_size, n_blocks=n_blocks, bottom_width=bottom_width)
350 |
351 | def dynamics_fn(self, xs):
352 | M = self.get_M(xs)
353 | dyn_fn = dynamics_models.LinearTensorDynamicsLSTSQ.DynFn(M)
354 | return dyn_fn
355 |
356 | def get_M(self, xs):
357 | xs = rearrange(xs, 'n t c h w -> n (t c) h w')
358 | M = self.M_net(xs)
359 | M = rearrange(M, 'n (a_1 a_2) -> n a_1 a_2', a_1=self.dim_a)
360 | M = self.initial_scale_M * M
361 | return M
362 |
363 | def __call__(self, xs, n_rolls=1, return_reg_loss=False, predictive=True, reconst=False):
364 | # ==Estimate dynamics==
365 | fn = self.dynamics_fn(xs)
366 |
367 | if reconst:
368 | H = self.encode(xs)
369 | if predictive:
370 | H_last = H[:, -1:]
371 | else:
372 | H_last = H[:, :1]
373 | else:
374 | H_last = self.encode(xs[:, -1:] if predictive else xs[:, :1])
375 |
376 | if predictive:
377 | H_preds = [H] if reconst else []
378 | array = np.arange(n_rolls)
379 | else:
380 | H_preds = [H[:, :1]] if reconst else []
381 | array = np.arange(xs.shape[1] + n_rolls - 1)
382 |
383 | # Create prediction for the unseen future
384 | for _ in array:
385 | H_last = fn(H_last)
386 | H_preds.append(H_last)
387 | H_preds = torch.cat(H_preds, axis=1)
388 | x_preds = self.decode(H_preds)
389 | if return_reg_loss:
390 | losses = (dynamics_models.loss_bd(fn.M, self.alignment),
391 | dynamics_models.loss_orth(fn.M), 0)
392 | return x_preds, losses
393 | else:
394 | return x_preds
395 |
396 | class SeqAENeuralMLatentPredict(SeqAENeuralM):
397 | def __init__(self,
398 | dim_a,
399 | dim_m,
400 | ch_x=3,
401 | k=1.0,
402 | alignment=False,
403 | kernel_size=3,
404 | predictive=True,
405 | bottom_width=4,
406 | n_blocks=3,
407 | loss_latent_coeff=0,
408 | loss_pred_coeff=1.0,
409 | loss_reconst_coeff=0,
410 | normalize=True,
411 | *args,
412 | **kwargs):
413 | assert predictive
414 | super().__init__(
415 | dim_a=dim_a,
416 | dim_m=dim_m,
417 | ch_x=ch_x,
418 | k=k,
419 | alignment=alignment,
420 | kernel_size=kernel_size,
421 | predictive=predictive,
422 | bottom_width=bottom_width,
423 | n_blocks=n_blocks,
424 | )
425 | self.loss_reconst_coeff = loss_reconst_coeff
426 | self.loss_pred_coeff = loss_pred_coeff
427 | self.loss_latent_coeff = loss_latent_coeff
428 | self.normalize = normalize
429 |
430 | def normalize_isotypic_copy(self, H):
431 | isotype_norm = torch.sqrt(torch.sum(H**2, axis=2, keepdims=True))
432 | H = H / isotype_norm
433 | return H
434 |
435 | # Encoding function with isotypic column normalization
436 | def encode(self, xs):
437 | H = super().encode(xs)
438 | if self.normalize:
439 | H = self.normalize_isotypic_copy(H)
440 | return H
441 |
442 | def latent_error(self, H_preds, H_target):
443 | latent_e = torch.mean(torch.sum((H_preds - H_target)**2, axis=(2,3)))
444 | return latent_e
445 |
446 | def obs_error(self, xs_1, xs_2):
447 | obs_e = torch.mean(torch.sum((xs_1 - xs_2)**2, axis=(2,3,4)))
448 | return obs_e
449 |
450 | def __call__(self, xs, n_rolls=1, T_cond=2, return_losses=False, return_reg_losses=False):
451 | xs_cond, xs_target = xs[:, :T_cond], xs[:, T_cond:]
452 | fn = self.dynamics_fn(xs_cond)
453 | H_cond, H_target = self.encode(xs_cond), self.encode(xs_target)
454 | H_last = H_cond[:, -1:]
455 | H_preds = [H_cond]
456 | array = np.arange(n_rolls)
457 |
458 | for _ in array:
459 | H_last = fn(H_last)
460 | H_preds.append(H_last)
461 | H_preds = torch.cat(H_preds, axis=1)
462 | xs_preds = self.decode(H_preds)
463 | ret = [xs_preds]
464 | if return_losses:
465 | losses = {}
466 | losses['loss_reconst'] = self.obs_error(xs_preds[:, :T_cond], xs_cond) if self.loss_reconst_coeff > 0 else torch.tensor([0]).to(xs.device)
467 | losses['loss_pred'] = self.obs_error(xs_preds[:, T_cond:], xs_target) if self.loss_pred_coeff > 0 else torch.tensor([0]).to(xs.device)
468 | losses['loss_latent'] = self.latent_error(H_preds[:, T_cond:], H_target) if self.loss_latent_coeff > 0 else torch.tensor([0]).to(xs.device)
469 | ret += [losses]
470 | if return_reg_losses:
471 | ret += [(dynamics_models.loss_bd(fn.M, self.alignment),
472 | dynamics_models.loss_orth(fn.M), 0)]
473 | return ret
474 |
475 | def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False):
476 | ret = self(xs, return_losses=True, return_reg_losses=return_reg_loss, T_cond=T_cond,
477 | n_rolls=xs.shape[1] - T_cond)
478 | if return_reg_loss:
479 | _, losses, reg_losses = ret
480 | else:
481 | _, losses = ret
482 |
483 | total_loss = self.loss_reconst_coeff * losses['loss_reconst'] \
484 | + self.loss_pred_coeff * losses['loss_pred'] \
485 | + self.loss_latent_coeff * losses['loss_latent']
486 | return (total_loss, reg_losses) if return_reg_loss else total_loss
487 |
488 |
489 |
490 | class SeqAENeuralTransition(SeqAELSTSQ):
491 | def __init__(
492 | self,
493 | dim_a,
494 | dim_m,
495 | ch_x=3,
496 | k=1.0,
497 | kernel_size=3,
498 | T_cond=2,
499 | bottom_width=4,
500 | n_blocks=3,
501 | *args,
502 | **kwargs):
503 | super(SeqAELSTSQ, self).__init__()
504 | self.dim_a = dim_a
505 | self.dim_m = dim_m
506 | self.T_cond = T_cond
507 | self.enc = ResNetEncoder(
508 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
509 | self.ar = Conv1d1x1Encoder(dim_out=dim_a)
510 | self.dec = ResNetDecoder(
511 | ch_x, k=k, kernel_size=kernel_size, bottom_width=bottom_width, n_blocks=n_blocks)
512 |
513 | def loss(self, xs, return_reg_loss=False, T_cond=2, reconst=False):
514 | assert T_cond == self.T_cond
515 | xs_cond = xs[:, :T_cond]
516 | xs_pred = self(xs_cond, n_rolls=xs.shape[1] - T_cond, reconst=reconst)
517 | xs_target = xs if reconst else xs[:, T_cond:]
518 | loss = torch.mean(
519 | torch.sum((xs_target - xs_pred) ** 2, axis=[2, 3, 4]))
520 | if return_reg_loss:
521 | return loss, [torch.Tensor(np.array(0, dtype=np.float32)).to(xs.device)] * 3
522 | else:
523 | return loss
524 |
525 | def get_M(self, xs):
526 | T = xs.shape[1]
527 | xs = rearrange(xs, 'n t c h w -> (n t) c h w')
528 | H = self.enc(xs)
529 | H = rearrange(H, '(n t) c -> n (t c)', t=T)
530 | return H
531 |
532 | def __call__(self, xs, n_rolls=1, reconst=False):
533 | # ==Estimate dynamics==
534 | H = self.encode(xs)
535 |
536 | array = np.arange(n_rolls)
537 | H_preds = [H] if reconst else []
538 | # Create prediction for the unseen future
539 | for _ in array:
540 | H_pred = self.ar(rearrange(H, 'n t s a -> n (t a) s'))
541 | H_pred = rearrange(
542 | H_pred, 'n (t a) s -> n t s a', t=1, a=self.dim_a)
543 | H_preds.append(H_pred)
544 | H = torch.cat([H[:, 1:], H_pred], dim=1)
545 | H_preds = torch.cat(H_preds, axis=1)
546 | # Prediction in the observation space
547 | return self.decode(H_preds)
548 |
549 |
550 | class CPC(SeqAELSTSQ):
551 | def __init__(
552 | self,
553 | dim_a,
554 | dim_m,
555 | k=1.0,
556 | kernel_size=3,
557 | temp=0.01,
558 | normalize=True,
559 | loss_type='cossim',
560 | n_blocks=3,
561 | *args,
562 | **kwargs):
563 | super(SeqAELSTSQ, self).__init__()
564 | self.dim_a = dim_a
565 | self.dim_m = dim_m
566 | self.normalize = normalize
567 | self.temp = temp
568 | self.loss_type = loss_type
569 | self.enc = ResNetEncoder(
570 | dim_a*dim_m, k=k, kernel_size=kernel_size, n_blocks=n_blocks)
571 | self.ar = Conv1d1x1Encoder(dim_out=dim_a*dim_m)
572 |
573 | def __call__(self, xs):
574 | H = self.encode(xs) # [n, t, s, a]
575 |
576 | # Create prediction for the unseen future
577 | H = rearrange(H, 'n t s a -> n (t a) s')
578 |
579 | # Obtain c in CPC
580 | H_pred = self.ar(H) # [n a s]
581 | H_pred = rearrange(H_pred, 'n a s -> n s a')
582 |
583 | return H_pred
584 |
585 | def get_M(self, xs):
586 | T = xs.shape[1]
587 | xs = rearrange(xs, 'n t c h w -> (n t) c h w')
588 | H = self.enc(xs)
589 | H = rearrange(H, '(n t) c -> n (t c)', t=T)
590 | return H
591 |
592 | def loss(self, xs, return_reg_loss=True, T_cond=2, reconst=False):
593 | T_pred = xs.shape[1] - T_cond
594 | assert T_pred == 1
595 | # Encoded latent, shape: num_seqs x len_seq x dim_m x dim_a
596 | H = self.encode(xs) # [n, t, s, a]
597 |
598 | # Create prediction for the unseen future
599 | H_cond = H[:, :T_cond]
600 | H_cond = rearrange(H_cond, 'n t s a -> n (t a) s')
601 |
602 | # Obtain c in CPC
603 | H_pred = self.ar(H_cond) # [n a s]
604 | H_pred = rearrange(H_pred, 'n a s -> n s a')
605 |
606 | H_true = H[:, -1] # n s a
607 | H_true = rearrange(H_true, 'n s a -> n (s a)')
608 | H_pred = rearrange(H_pred, 'n s a -> n (s a)')
609 | loss = simclr([H_pred, H_true], self.temp,
610 | normalize=self.normalize, loss_type=self.loss_type)
611 | if return_reg_loss:
612 | reg_losses = [torch.Tensor(np.array(0, dtype=np.float32))]*3
613 | return loss, reg_losses
614 | else:
615 | return loss
616 |
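A minimal hedged sketch of how these sequence autoencoders are driven (it mirrors loop_seqmodel in training_loops.py further below; the constructor arguments, image resolution, and sequence length are illustrative assumptions, not values taken from the configs):

import torch
from models.seqae import SeqAENeuralMLatentPredict  # assumes the repo root is on PYTHONPATH

model = SeqAENeuralMLatentPredict(dim_a=16, dim_m=32, ch_x=1, loss_latent_coeff=1.0)
xs = torch.randn(8, 4, 1, 32, 32)  # [n, t, c, h, w]: a batch of 4-frame sequences
# Condition on the first T_cond=2 frames and predict the remaining 2
loss, (loss_bd, loss_orth, _) = model.loss(xs, T_cond=2, return_reg_loss=True)
loss.backward()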
--------------------------------------------------------------------------------
/models/simclr_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.base_networks import ResNetEncoder
4 | from einops import rearrange
5 |
6 |
7 | class ResNetwProjHead(nn.Module):
8 |
9 | def __init__(self, dim_mlp=512, dim_head=128, k=1, act=nn.ReLU(), n_blocks=3):
10 | super().__init__()
11 | self.enc = ResNetEncoder(
12 | dim_latent=0, k=k, n_blocks=n_blocks)
13 | self.projhead = nn.Sequential(
14 | nn.LazyLinear(dim_mlp),
15 | act,
16 | nn.LazyLinear(dim_head))
17 |
18 | def _encode_base(self, xs, enc):
19 | shape = xs.shape
20 | x = torch.reshape(xs, (shape[0] * shape[1], *shape[2:]))
21 | H = enc(x)
22 | H = torch.reshape(H, (shape[0], shape[1], *H.shape[1:]))
23 | return H
24 |
25 | def __call__(self, xs):
26 | return self._encode_base(xs, lambda x: self.projhead(self.enc(x)))
27 |
28 | def phi(self, xs):
29 | return self._encode_base(xs, self.enc.phi)
30 |
31 | def get_M(self, xs):
32 | T = xs.shape[1]
33 | xs = rearrange(xs, 'n t c h w -> (n t) c h w')
34 | H = self.enc(xs)
35 | H = rearrange(H, '(n t) c -> n (t c)', t=T)
36 | return H
37 |
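A hedged usage sketch for the projection-head encoder (repo root on PYTHONPATH assumed; batch size, view count, and resolution are illustrative). The per-view slices at the end are what loop_simclr in training_loops.py feeds to the simclr loss:

import torch
from models.simclr_models import ResNetwProjHead

model = ResNetwProjHead(dim_mlp=512, dim_head=128)
xs = torch.randn(4, 2, 3, 32, 32)  # [n, t, c, h, w]: two augmented views per sample
zs = model(xs)                     # [4, 2, 128] projected embeddings
views = [zs[:, i] for i in range(zs.shape[1])]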
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.10.0
2 | pytorch-pfn-extras
3 | scikit-image
4 | opencv-python
5 | moviepy
6 | ffmpeg-python
7 | einops
8 | tqdm
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import yaml
4 | import copy
5 | import functools
6 | import random
8 | import numpy as np
9 |
10 | import torch
11 | import torch.backends.cudnn as cudnn
12 | from torch.utils.data import DataLoader
13 | import pytorch_pfn_extras as ppe
14 | from pytorch_pfn_extras.training import extensions
15 | from utils import yaml_utils as yu
16 |
17 |
18 | def train(config):
19 |
20 | torch.cuda.empty_cache()
21 | torch.manual_seed(config['seed'])
22 | random.seed(config['seed'])
23 | np.random.seed(config['seed'])
24 |
25 | if torch.cuda.is_available():
26 | device = torch.device('cuda')
27 | cudnn.deterministic = True
28 | cudnn.benchmark = True
29 | else:
30 | device = torch.device('cpu')
31 | gpu_index = -1
32 |
33 | # Dataset
34 | data = yu.load_component(config['train_data'])
35 | train_loader = DataLoader(
36 | data, batch_size=config['batchsize'], shuffle=True, num_workers=config['num_workers'])
37 |
38 | # Define model and optimizer
39 | model = yu.load_component(config['model'])
40 | model.to(device)
41 | optimizer = torch.optim.Adam(model.parameters(), config['lr'])
42 |
43 | manager = ppe.training.ExtensionsManager(
44 | model, optimizer, None,
45 | iters_per_epoch=len(train_loader),
46 | out_dir=config['log_dir'],
47 | stop_trigger=(config['max_iteration'], 'iteration')
48 | )
49 |
50 | manager.extend(
51 | extensions.PrintReport(
52 | ['epoch', 'iteration', 'train/loss', 'train/loss_bd', 'train/loss_orth', 'loss_internal_0', 'loss_internal_T', 'elapsed_time']),
53 | trigger=(config['report_freq'], 'iteration'))
54 | manager.extend(extensions.LogReport(
55 | trigger=(config['report_freq'], 'iteration')))
56 | manager.extend(
57 | extensions.snapshot(
58 | target=model, filename='snapshot_model_iter_{.iteration}'),
59 | trigger=(config['model_snapshot_freq'], 'iteration'))
60 | manager.extend(
61 | extensions.snapshot(
62 | target=manager, filename='snapshot_manager_iter_{.iteration}', n_retains=1),
63 | trigger=(config['manager_snapshot_freq'], 'iteration'))
64 | # Run training loop
65 | print("Start training...")
66 | yu.load_component_fxn(config['training_loop'])(
67 | manager, model, optimizer, train_loader, config, device)
68 |
69 |
70 | if __name__ == '__main__':
71 | # Load the configuration arguments from the specified config path
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--log_dir', type=str)
74 | parser.add_argument('--config_path', type=str)
75 | parser.add_argument('-a', '--attrs', nargs='*', default=())
76 | parser.add_argument('-w', '--warning', action='store_true')
77 | args = parser.parse_args()
78 |
79 | with open(args.config_path, 'r') as f:
80 | config = yaml.safe_load(f)
81 | config['config_path'] = args.config_path
82 | config['log_dir'] = args.log_dir
83 |
84 | # Override entries of the loaded config with the --attrs arguments
85 | for attr in args.attrs:
86 | module, new_value = attr.split('=')
87 | keys = module.split('.')
88 | target = functools.reduce(dict.__getitem__, keys[:-1], config)
89 | if keys[-1] in target.keys():
90 | target[keys[-1]] = yaml.safe_load(new_value)
91 | else:
92 | raise ValueError('The following key is not defined in the config file: {}'.format(keys))
93 |
94 | for k, v in sorted(config.items()):
95 | print("\t{} {}".format(k, v))
96 |
97 | # create the result directory and save yaml
98 | if not os.path.exists(config['log_dir']):
99 | os.makedirs(config['log_dir'])
100 |
101 | _config = copy.deepcopy(config)
102 | configpath = os.path.join(config['log_dir'], "config.yml")
103 | with open(configpath, 'w') as f:
104 | f.write(
105 | yaml.dump(_config, default_flow_style=False))
106 |
107 | # Training
108 | train(config)
109 |
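The --attrs mechanism above walks the nested config with functools.reduce and YAML-parses the new value; a self-contained illustration of that override logic (the key and value here are made up):

import functools
import yaml

config = {'train_data': {'args': {'root': './data'}}, 'seed': 0}
attr = 'train_data.args.root=/mnt/datasets'   # as passed by `-a train_data.args.root=/mnt/datasets`
module, new_value = attr.split('=')
keys = module.split('.')                      # ['train_data', 'args', 'root']
target = functools.reduce(dict.__getitem__, keys[:-1], config)
target[keys[-1]] = yaml.safe_load(new_value)  # YAML parsing lets numbers/booleans through as well
assert config['train_data']['args']['root'] == '/mnt/datasets'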
--------------------------------------------------------------------------------
/training_allmodels.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | LOGDIR_ROOT=$1
3 | DATADIR_ROOT=$2
4 |
5 | # Training on MNIST, MNIST-bg w/ digit 4, 3DShapes and smallNORB
6 | for seed in 1 2 3; do
7 | for dataset_name in mnist mnist_bg 3dshapes smallNORB; do
8 | for model_name in lstsq lstsq_multi lstsq_rec neuralM neural_trans; do
9 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
10 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
11 | --attr seed=${seed} train_data.args.root=${DATADIR_ROOT}
12 | done
13 | done
14 | done
15 |
16 | # Training on MNIST-bg w/ all digits
17 | for seed in 1 2 3; do
18 | for dataset_name in mnist_bg; do
19 | for model_name in lstsq lstsq_multi lstsq_rec; do
20 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
21 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
22 | --attr seed=${seed} train_data.args.root=${DATADIR_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000
23 | done
24 | for model_name in neuralM neural_trans; do
25 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}_full-${model_name}-seed${seed}/ \
26 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
27 | --attr seed=${seed} train_data.args.root=${DATADIR_ROOT} train_data.args.only_use_digit4=False max_iteration=200000 training_loop.args.lr_decay_iter=160000 training_loop.args.reconst_iter=200000
28 | done
29 | done
30 | done
31 |
32 | # Training on Accelerated Sequential MNIST
33 | for seed in 1 2 3; do
34 | for dataset_name in mnist_accl; do
35 | for model_name in lstsq holstsq neural_trans; do
36 | python run.py --log_dir=${LOGDIR_ROOT}/${dataset_name}-${model_name}-seed${seed}/ \
37 | --config_path=./configs/${dataset_name}/lstsq/${model_name}.yml \
38 | --attr seed=${seed} train_data.args.root=${DATADIR_ROOT}
39 | done
40 | done
41 | done
42 |
43 |
--------------------------------------------------------------------------------
/training_loops.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import pytorch_pfn_extras as ppe
5 | from utils.clr import simclr
6 | from utils.misc import freq_to_wave
7 | from tqdm import tqdm
8 |
9 |
10 | def loop_seqmodel(manager, model, optimizer, train_loader, config, device):
11 | while not manager.stop_trigger:
12 | for images in train_loader:
13 | with manager.run_iteration():
14 | reconst = manager.iteration < config['training_loop']['args']['reconst_iter']
15 | if manager.iteration >= config['training_loop']['args']['lr_decay_iter']:
16 | optimizer.param_groups[0]['lr'] = config['lr']/3.
17 | else:
18 | optimizer.param_groups[0]['lr'] = config['lr']
19 | model.train()
20 | images = torch.stack(images).transpose(1, 0).to(device)
21 | loss, (loss_bd, loss_orth, _) = model.loss(images, T_cond=config['T_cond'], return_reg_loss=True, reconst=reconst)
22 | optimizer.zero_grad()
23 | loss.backward()
24 | optimizer.step()
25 | ppe.reporting.report({
26 | 'train/loss': loss.item(),
27 | 'train/loss_bd': loss_bd.item(),
28 | 'train/loss_orth': loss_orth.item(),
29 | })
30 |
31 | if manager.stop_trigger:
32 | break
33 |
34 |
35 | def loop_simclr(manager, model, optimizer, train_loader, config, device):
36 | while not manager.stop_trigger:
37 | for images in train_loader:
38 | with manager.run_iteration():
39 | if manager.iteration >= config['training_loop']['args']['lr_decay_iter']:
40 | optimizer.param_groups[0]['lr'] = config['lr']/3.
41 | else:
42 | optimizer.param_groups[0]['lr'] = config['lr']
43 | model.train()
44 | images = torch.stack(images, dim=1).to(device) # n t c h w
45 | zs = model(images)
46 | zs = [zs[:, i] for i in range(zs.shape[1])]
47 | loss = simclr(
48 | zs,
49 | loss_type=config['training_loop']['args']['loss_type'],
50 | temperature=config['training_loop']['args']['temp']
51 | )
52 | optimizer.zero_grad()
53 | loss.backward()
54 | optimizer.step()
55 | ppe.reporting.report({
56 | 'train/loss': loss.item(),
57 | })
58 | if manager.stop_trigger:
59 | break
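
Both loops read their schedules from the config dictionary; a minimal fragment with the keys referenced above (the values are illustrative, not taken from a shipped yaml):

config = {
    'lr': 1e-3,
    'T_cond': 2,  # loop_seqmodel: number of conditioning frames
    'training_loop': {
        'args': {
            'reconst_iter': 10000,    # loop_seqmodel: train with reconstruction before this iteration
            'lr_decay_iter': 160000,  # both loops: divide the learning rate by 3 afterwards
            'loss_type': 'cossim',    # loop_simclr
            'temp': 0.01,             # loop_simclr
        }
    },
}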
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takerum/meta_sequential_prediction/8cdd2174cbf176cd35fb235efa7042daeff92a0f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/clr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import pytorch_pfn_extras as ppe
5 |
6 |
7 | def simclr(zs, temperature=1.0, normalize=True, loss_type='cossim'):
8 | if normalize:
9 | zs = [F.normalize(z, p=2, dim=1) for z in zs]
10 | m = len(zs)
11 | n = zs[0].shape[0]
12 | device = zs[0].device
13 | mask = torch.eye(n * m, device=device)
14 | label0 = torch.fmod(n + torch.arange(0, m * n, device=device), n * m)
15 | z = torch.cat(zs, 0)
16 | if loss_type == 'euclid':
17 | sim = - torch.cdist(z, z)
18 | elif loss_type == 'sq':
19 | sim = - torch.cdist(z, z) ** 2
20 | elif loss_type == 'cossim':
21 | sim = torch.matmul(z, z.transpose(0, 1))
22 | else:
23 | raise NotImplementedError
24 | logit_zz = sim / temperature
25 | logit_zz += mask * -1e8
26 | loss = nn.CrossEntropyLoss()(logit_zz, label0)
27 | return loss
28 |
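With m views of n samples each, the positive for stacked row i sits at index (i + n) mod (n*m), which is exactly what label0 encodes. A small hedged usage sketch (repo root on PYTHONPATH assumed; shapes are illustrative):

import torch
from utils.clr import simclr

z1 = torch.randn(8, 16)  # embeddings of view 1
z2 = torch.randn(8, 16)  # embeddings of view 2, row-aligned positives of z1
loss = simclr([z1, z2], temperature=0.1, loss_type='cossim')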
--------------------------------------------------------------------------------
/utils/emb2d.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class Emb2D(nn.modules.lazy.LazyModuleMixin, nn.Module):
8 | def __init__(self, dim=64):
9 | super().__init__()
10 | self.dim = dim
11 | self.emb = torch.nn.parameter.UninitializedParameter()
12 |
13 | def __call__(self, x):
14 | if torch.nn.parameter.is_lazy(self.emb):
15 | _, h, w = x.shape[1:]
16 | self.emb.materialize((self.dim, h, w))
17 | self.emb.data = positionalencoding2d(self.dim, h, w)
18 | emb = torch.tile(self.emb[None].to(x.device), [x.shape[0], 1, 1, 1])
19 | x = torch.cat([x, emb], axis=1)
20 | return x
21 |
22 | # Copied from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
23 |
24 |
25 | def positionalencoding2d(d_model, height, width):
26 | """
27 | :param d_model: dimension of the model
28 | :param height: height of the positions
29 | :param width: width of the positions
30 | :return: d_model*height*width position matrix
31 | """
32 | if d_model % 4 != 0:
33 | raise ValueError("Cannot use sin/cos positional encoding with "
34 | "a dimension not divisible by 4 (got dim={:d})".format(d_model))
35 | pe = torch.zeros(d_model, height, width)
36 | # Each dimension use half of d_model
37 | d_model = int(d_model / 2)
38 | div_term = torch.exp(torch.arange(0., d_model, 2) *
39 | -(math.log(10000.0) / d_model))
40 | pos_w = torch.arange(0., width).unsqueeze(1)
41 | pos_h = torch.arange(0., height).unsqueeze(1)
42 | pe[0:d_model:2, :, :] = torch.sin(
43 | pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
44 | pe[1:d_model:2, :, :] = torch.cos(
45 | pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
46 | pe[d_model::2, :, :] = torch.sin(
47 | pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
48 | pe[d_model + 1::2, :,
49 | :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
50 |
51 | return pe
52 |
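A hedged usage sketch: Emb2D materializes the encoding lazily on the first call and concatenates it along the channel axis (repo root on PYTHONPATH assumed; dim must be divisible by 4):

import torch
from utils.emb2d import Emb2D

emb = Emb2D(dim=64)
x = torch.randn(2, 3, 16, 16)
y = emb(x)
print(y.shape)  # torch.Size([2, 67, 16, 16]): 3 input channels + 64 encoding channels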
--------------------------------------------------------------------------------
/utils/laplacian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from einops import repeat
4 |
5 |
6 | def make_identity(N, D, device):
7 | if N is None:
8 | return torch.Tensor(np.array(np.eye(D))).to(device)
9 | else:
10 | return torch.Tensor(np.array([np.eye(D)] * N)).to(device)
11 |
12 | def make_identity_like(A):
13 | assert A.shape[-2] == A.shape[-1] # Ensure A is a batch of square matrices
14 | device = A.device
15 | shape = A.shape[:-2]
16 | eye = torch.eye(A.shape[-1], device=device)[(None,)*len(shape)]
17 | return eye.repeat(*shape, 1, 1)
18 |
19 |
20 | def make_diagonal(vecs):
21 | vecs = vecs[..., None].repeat(*([1,]*len(vecs.shape)), vecs.shape[-1])
22 | return vecs * make_identity_like(vecs)
23 |
24 | # Trace norm of the symmetrically normalized graph Laplacian
25 | def tracenorm_of_normalized_laplacian(A):
26 | D_vec = torch.sum(A, axis=-1)
27 | D = make_diagonal(D_vec)
28 | L = D - A
29 | inv_A_diag = make_diagonal(
30 | 1 / torch.sqrt(1e-10 + D_vec))
31 | L = torch.matmul(inv_A_diag, torch.matmul(L, inv_A_diag))
32 | sigmas = torch.linalg.svdvals(L)
33 | return torch.sum(sigmas, axis=-1)
34 |
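This trace norm acts as a block-diagonality score: affinity matrices that decompose into disconnected blocks give a smaller value than fully dense ones. A small hedged check (repo root on PYTHONPATH assumed):

import torch
from utils.laplacian import tracenorm_of_normalized_laplacian

A_block = torch.block_diag(torch.ones(2, 2), torch.ones(2, 2))[None]  # two disconnected blocks
A_dense = torch.ones(1, 4, 4)                                         # fully connected
print(tracenorm_of_normalized_laplacian(A_block))  # ~2.0
print(tracenorm_of_normalized_laplacian(A_dense))  # ~3.0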
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from einops import repeat
5 | import numpy as np
6 |
7 |
8 | def freq_to_wave(freq, is_radian=True):
9 | _freq_rad = 2 * math.pi * freq if not is_radian else freq
10 | return torch.hstack([torch.cos(_freq_rad), torch.sin(_freq_rad)])
11 |
12 |
13 | def unsqueeze_at_the_end(x, n):
14 | return x[(...,) + (None,)*n]
15 |
16 |
17 | def get_RTmat(theta, phi, gamma, w, h, dx, dy):
18 | d = np.sqrt(h ** 2 + w ** 2)
19 | f = d / (2 * np.sin(gamma) if np.sin(gamma) != 0 else 1)
20 | # Projection 2D -> 3D matrix
21 | A1 = np.array([[1, 0, -w / 2],
22 | [0, 1, -h / 2],
23 | [0, 0, 1],
24 | [0, 0, 1]])
25 |
26 | # Rotation matrices around the X, Y, and Z axis
27 | RX = np.array([[1, 0, 0, 0],
28 | [0, np.cos(theta), -np.sin(theta), 0],
29 | [0, np.sin(theta), np.cos(theta), 0],
30 | [0, 0, 0, 1]])
31 |
32 | RY = np.array([[np.cos(phi), 0, -np.sin(phi), 0],
33 | [0, 1, 0, 0],
34 | [np.sin(phi), 0, np.cos(phi), 0],
35 | [0, 0, 0, 1]])
36 |
37 | RZ = np.array([[np.cos(gamma), -np.sin(gamma), 0, 0],
38 | [np.sin(gamma), np.cos(gamma), 0, 0],
39 | [0, 0, 1, 0],
40 | [0, 0, 0, 1]])
41 |
42 | # Composed rotation matrix with (RX, RY, RZ)
43 | R = np.dot(np.dot(RX, RY), RZ)
44 |
45 | # Translation matrix
46 | T = np.array([[1, 0, 0, dx],
47 | [0, 1, 0, dy],
48 | [0, 0, 1, f],
49 | [0, 0, 0, 1]])
50 | # Projection 3D -> 2D matrix
51 | A2 = np.array([[f, 0, w / 2, 0],
52 | [0, f, h / 2, 0],
53 | [0, 0, 1, 0]])
54 | return np.dot(A2, np.dot(T, np.dot(R, A1)))
55 |
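get_RTmat composes a 2D-to-3D lift, rotations, a translation, and the 3D-to-2D projection into one 3x3 homography. A hedged sketch (the angle and image size are made up; applying it with cv2.warpPerspective relies on opencv-python from requirements.txt):

import numpy as np
from utils.misc import get_RTmat  # assumes the repo root is on PYTHONPATH

# In-plane rotation of 10 degrees for a 64x64 image
M = get_RTmat(theta=0.0, phi=0.0, gamma=np.deg2rad(10), w=64, h=64, dx=0, dy=0)
print(M.shape)  # (3, 3); e.g. cv2.warpPerspective(img, M, (64, 64)) would apply it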
--------------------------------------------------------------------------------
/utils/optimize_bd_cob.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import repeat
4 | from utils.laplacian import tracenorm_of_normalized_laplacian, make_identity_like
5 |
6 |
7 | def optimize_bd_cob(mats, batchsize=32, n_epochs=50, epochs_monitor=10):
8 | # Optimize change of basis matrix U by minimizing block diagonalization loss
9 |
10 | class ChangeOfBasis(torch.nn.Module):
11 | def __init__(self, d):
12 | super().__init__()
13 | self.U = nn.Parameter(torch.empty(d, d))
14 | torch.nn.init.orthogonal_(self.U)
15 |
16 | def __call__(self, mat):
17 | _U = repeat(self.U, "a1 a2 -> n a1 a2", n=mat.shape[0])
18 | n_mat = torch.linalg.solve(_U, mat) @ _U
19 | return n_mat
20 |
21 | change_of_basis = ChangeOfBasis(mats.shape[-1]).to(mats.device)
22 | dataloader = torch.utils.data.DataLoader(
23 | mats, batch_size=batchsize, shuffle=True, num_workers=0)
24 | optimizer = torch.optim.Adam(change_of_basis.parameters(), lr=0.1)
25 | for ep in range(n_epochs):
26 | total_loss, total_N = 0, 0
27 | for mat in dataloader:
28 | n_mat = change_of_basis(mat)
29 | n_mat = torch.abs(n_mat)
30 | n_mat = torch.matmul(n_mat.transpose(-2, -1), n_mat)
31 | loss = torch.mean(
32 | tracenorm_of_normalized_laplacian(n_mat))
33 | optimizer.zero_grad()
34 | loss.backward()
35 | optimizer.step()
36 | total_loss += loss.item() * mat.shape[0]
37 | total_N += mat.shape[0]
38 | if ((ep+1) % epochs_monitor) == 0:
39 | print('ep:{} loss:{}'.format(ep, total_loss/total_N))
40 | return change_of_basis
41 |
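A hedged usage sketch: fit one change-of-basis U to a batch of transition matrices (for instance the M estimated by the LSTSQ dynamics) so that U^-1 M U is as block-diagonal as possible. The batch below is random, so no real structure will emerge; it only shows the call:

import torch
from utils.optimize_bd_cob import optimize_bd_cob

mats = torch.randn(256, 6, 6)  # stand-in for transition matrices collected over sequences
cob = optimize_bd_cob(mats, batchsize=32, n_epochs=50, epochs_monitor=10)
mats_bd = cob(mats)            # conjugated matrices U^-1 M U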
--------------------------------------------------------------------------------
/utils/weight_standarization.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class WeightStandarization(nn.Module):
5 | def forward(self, weight):
6 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
7 | keepdim=True).mean(dim=3, keepdim=True)
8 | weight = weight - weight_mean
9 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
10 | weight = weight / std.expand_as(weight)
11 | return weight
12 |
13 |
14 | class WeightStandarization1d(nn.Module):
15 | def forward(self, weight):
16 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
17 | keepdim=True)
18 | weight = weight - weight_mean
19 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1) + 1e-5
20 | weight = weight / std.expand_as(weight)
21 | return weight
22 |
23 |
24 | class WeightStandarization0d(nn.Module):
25 | def forward(self, weight):
26 | weight_mean = weight.mean(dim=1, keepdim=True)
27 | weight = weight - weight_mean
28 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1) + 1e-5
29 | weight = weight / std.expand_as(weight)
30 | return weight
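
These modules map a raw weight tensor to its standardized form. One hedged way to attach one to a layer is torch.nn.utils.parametrize (available in the pinned torch 1.10); base_networks.py may wire them up differently:

import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from utils.weight_standarization import WeightStandarization

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
parametrize.register_parametrization(conv, 'weight', WeightStandarization())
y = conv(torch.randn(1, 16, 8, 8))  # the forward pass now sees the standardized weight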
--------------------------------------------------------------------------------
/utils/yaml_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import functools
4 | import argparse
5 | import yaml
6 | import pdb
7 | sys.path.append('../')
8 | sys.path.append('./')
9 |
10 | # Originally created by @msaito
11 | def load_module(fn, name):
12 | mod_name = os.path.splitext(os.path.basename(fn))[0]
13 | mod_path = os.path.dirname(fn)
14 | sys.path.insert(0, mod_path)
15 | return getattr(__import__(mod_name), name)
16 |
17 |
18 | def load_component(config):
19 | class_fn = load_module(config['fn'], config['name'])
20 | return class_fn(**config['args']) if 'args' in config.keys() else class_fn()
21 |
22 |
23 | def load_component_fxn(config):
24 | fxn = load_module(config['fn'], config['name'])
25 | return fxn
26 |
27 |
28 | def make_function(module, name):
29 | fxn = getattr(module, name)
30 | return fxn
31 |
32 |
33 | def make_instance(module, config=[], args=None):
34 | Class = getattr(module, config['name'])
35 | kwargs = config['args']
36 | if args is not None:
37 | kwargs.update(args)
38 | return Class(**kwargs)
39 |
40 |
41 | '''
42 | combines multiple configs
43 | '''
44 |
45 |
46 | def make_config(conf_dicts, attr_lists=None):
47 | def merge_dictionary(base, diff):
48 | for key, value in diff.items():
49 | if (key in base and isinstance(base[key], dict)
50 | and isinstance(diff[key], dict)):
51 | merge_dictionary(base[key], diff[key])
52 | else:
53 | base[key] = diff[key]
54 |
55 | config = {}
56 | for diff in conf_dicts:
57 | merge_dictionary(config, diff)
58 | if attr_lists is not None:
59 | for attr in attr_lists:
60 | module, new_value = attr.split('=')
61 | keys = module.split('.')
62 | target = functools.reduce(dict.__getitem__, keys[:-1], config)
63 | target[keys[-1]] = yaml.safe_load(new_value)
64 | return config
65 |
66 |
67 | '''
68 | argument parser that uses make_config
69 | '''
70 |
71 |
72 | def parse_args():
73 |
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument(
76 | 'infiles', nargs='+', type=argparse.FileType('r'), default=())
77 | parser.add_argument('-a', '--attrs', nargs='*', default=())
78 | parser.add_argument('-c', '--comment', default='')
79 | parser.add_argument('-w', '--warning', action='store_true')
80 | parser.add_argument('-o', '--output-config', default='')
81 | args = parser.parse_args()
82 |
83 | conf_dicts = [yaml.safe_load(fp) for fp in args.infiles]
84 | config = make_config(conf_dicts, args.attrs)
85 | return config, args
86 |
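A small hedged illustration of make_config's merge-then-override behavior (the dictionaries are made up):

from utils.yaml_utils import make_config

base = {'model': {'args': {'dim_a': 16}}, 'lr': 0.001}
diff = {'model': {'args': {'dim_a': 32}}}
config = make_config([base, diff], attr_lists=['lr=0.01'])
# -> {'model': {'args': {'dim_a': 32}}, 'lr': 0.01}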
--------------------------------------------------------------------------------