├── datasets
│   └── .gitkeep
├── assets
│   ├── gif1_gt.gif
│   ├── gif2_gt.gif
│   ├── teaser.png
│   ├── gif1_pred.gif
│   ├── gif1_seg.gif
│   ├── gif2_pred.gif
│   ├── gif2_seg.gif
│   ├── AllPredictors.png
│   └── docs
│       ├── INSTALL.md
│       └── TRAIN.md
├── experiments
│   ├── Obj3D
│   │   ├── Predictor_LSTM
│   │   │   ├── model_architecture.txt
│   │   │   └── experiment_params.json
│   │   ├── Predictor_Transformer
│   │   │   ├── model_architecture.txt
│   │   │   └── experiment_params.json
│   │   ├── Predictor_OCVPPar
│   │   │   ├── experiment_params.json
│   │   │   └── model_architecture.txt
│   │   ├── Predictor_OCVPSeq
│   │   │   ├── experiment_params.json
│   │   │   └── model_architecture.txt
│   │   ├── experiment_params.json
│   │   └── model_architecture.txt
│   └── MOViA
│       ├── Predictor_LSTM
│       │   ├── model_architecture.txt
│       │   └── experiment_params.json
│       ├── Predictor_OCVPPar
│       │   ├── experiment_params.json
│       │   └── model_architecture.txt
│       ├── Predictor_OCVPSeq
│       │   ├── experiment_params.json
│       │   └── model_architecture.txt
│       ├── Predictor_Transformer
│       │   └── experiment_params.json
│       ├── experiment_params.json
│       └── model_architecture.txt
├── src
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── obj3d.py
│   │   ├── load_data.py
│   │   └── MoviConvert.py
│   ├── configs
│   │   ├── __init__.py
│   │   └── README.md
│   ├── models
│   │   ├── __init__.py
│   │   ├── initializers.py
│   │   ├── model_utils.py
│   │   └── model_blocks.py
│   ├── 01_create_experiment.py
│   ├── extract_movi_dataset.py
│   ├── 01_create_predictor_experiment.py
│   ├── 03_evaluate_savi_noMasks.py
│   ├── 03_evaluate_savi.py
│   ├── lib
│   │   ├── config.py
│   │   ├── logger.py
│   │   ├── loss.py
│   │   ├── utils.py
│   │   └── schedulers.py
│   ├── 06_generate_figs_savi.py
│   ├── base
│   │   ├── baseEvaluator.py
│   │   ├── baseFigGenerator.py
│   │   └── basePredictorEvaluator.py
│   ├── CONFIG.py
│   ├── 05_evaluate_predictor.py
│   ├── 02_train_savi.py
│   └── 04_train_predictor.py
├── download_pretrained.sh
├── download_obj3d.sh
├── README.md
└── environment.yml
/datasets/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/gif1_gt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif1_gt.gif
--------------------------------------------------------------------------------
/assets/gif2_gt.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif2_gt.gif
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/gif1_pred.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif1_pred.gif
--------------------------------------------------------------------------------
/assets/gif1_seg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif1_seg.gif
--------------------------------------------------------------------------------
/assets/gif2_pred.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif2_pred.gif
--------------------------------------------------------------------------------
/assets/gif2_seg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif2_seg.gif
--------------------------------------------------------------------------------
/assets/AllPredictors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/AllPredictors.png
--------------------------------------------------------------------------------
/experiments/Obj3D/Predictor_LSTM/model_architecture.txt:
--------------------------------------------------------------------------------
1 | Total Params: 264192
2 |
3 | Params: 264192
4 | ModuleList(
5 | (0): LSTMCell(128, 128)
6 | (1): LSTMCell(128, 128)
7 | )
8 |
--------------------------------------------------------------------------------
/experiments/MOViA/Predictor_LSTM/model_architecture.txt:
--------------------------------------------------------------------------------
1 | Total Params: 264192
2 |
3 | Params: 264192
4 | LSTMPredictor(
5 | (lstm): ModuleList(
6 | (0): LSTMCell(128, 128)
7 | (1): LSTMCell(128, 128)
8 | )
9 | )
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Accessing datasets from script
3 | """
4 |
5 | from .Movi import MOVI
6 | from .obj3d import OBJ3D
7 |
8 | from .load_data import load_data, build_data_loader, unwrap_batch_data, unwrap_batch_data_masks
9 |
--------------------------------------------------------------------------------
/src/configs/__init__.py:
--------------------------------------------------------------------------------
1 | """ Configs """
2 |
3 | import os
4 | from CONFIG import CONFIG
5 |
6 |
7 | def get_available_configs():
8 | """ Getting a list with the name of the available config files """
9 | config_path = CONFIG["paths"]["configs_path"]
10 | files = sorted(os.listdir(config_path))
11 | available_configs = [f[:-5] for f in files if f[-5:] == ".json"]
12 | return available_configs
13 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Accessing models
3 | """
4 |
5 | from .model_blocks import SoftPositionEmbed
6 | from .encoders_decoders import get_encoder, get_decoder, SimpleConvEncoder, DownsamplingConvEncoder
7 | from .initializers import get_initalizer
8 |
9 | from .attention import SlotAttention, MultiHeadSelfAttention, TransformerBlock
10 | from .SAVi import SAVi
11 | from .model_utils import freeze_params
12 |
--------------------------------------------------------------------------------
/download_pretrained.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p experiments
4 |
5 | wget https://www.dropbox.com/s/mbbae2yk7cexwig/MOViA.zip?dl=1
6 | unzip MOViA.zip?dl=1
7 | rm MOViA.zip?dl=1
8 | rsync -va --delete-after MOViA experiments
9 | rm -r MOViA
10 |
11 | wget https://www.dropbox.com/s/luuzfmo3v2ka2kb/Obj3D.zip?dl=1
12 | unzip Obj3D.zip?dl=1
13 | rm Obj3D.zip?dl=1
14 | rsync -va --delete-after Obj3D experiments
15 | rm -r Obj3D
16 |
--------------------------------------------------------------------------------
/src/configs/README.md:
--------------------------------------------------------------------------------
1 | # Recommended Configs
2 |
3 | The files in this directory correspond to default ```experiment_params.json``` files for different datasets.
4 |
5 | These experiment values have been tested and lead to reasonable results.
6 |
7 |
8 | #### TODO
9 |
10 | - Add a command line argument to ```01_create_experiment.py``` to initialize the experiment's ```experiment_params.json``` with one of the configurations included here (see the manual workaround below).
11 |
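12 | 
13 | #### Manual workaround
14 | 
15 | Until that argument is added, a recommended config can be applied by hand: create the experiment as usual and then overwrite its ```experiment_params.json``` with one of the files in this directory. A minimal sketch (the experiment and config names below are placeholders):
16 | 
17 | ```
18 | # create the experiment first (see TRAIN.md), then copy a recommended config over its parameters
19 | python src/01_create_experiment.py ...
20 | cp src/configs/<config_name>.json experiments/<exp_dir>/<exp_name>/experiment_params.json
21 | ```
22 | 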
--------------------------------------------------------------------------------
/download_obj3d.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1XSLW3qBtcxxvV-5oiRruVTlDlQ_Yatzm' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1XSLW3qBtcxxvV-5oiRruVTlDlQ_Yatzm" -O datasets/OBJ3D.zip && rm -rf /tmp/cookies.txt
3 | cd datasets && unzip OBJ3D.zip && rm OBJ3D.zip
4 |
--------------------------------------------------------------------------------
/assets/docs/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | 1. Clone and enter the repository:
4 | ```
5 | git clone git@github.com:AIS-Bonn/OCVP-object-centric-video-prediction.git
6 | cd OCVP-object-centric-video-prediction
7 | ```
8 |
9 |
10 | 2. Install all required packages by installing the ```conda``` environment file included in the repository:
11 | ```
12 | conda env create -f environment.yml
13 | conda activate OCVP
14 | ```
15 |
16 |
17 | 3. Download the Obj3D and MOVi-A datasets and place them under the `datasets` directory. The folder structure should look as follows:
18 | ```
19 | OCVP
20 | ├── datasets/
21 | | ├── Obj3D/
22 | | └── MOViA/
23 | ```
24 |
25 | * **Obj3D:** Download and extract this dataset by running the following bash script:
26 | ```
27 | chmod +x download_obj3d.sh
28 | ./download_obj3d.sh
29 | ```
30 |
31 | * **MOViA:** Download the MOVi-A dataset to your local disk from [Google Cloud Storage](https://console.cloud.google.com/storage/browser/kubric-public/tfds), and preprocess the *TFRecord* files to extract the video frames and other required metadata by running the following commands:
32 | ```
33 | gsutil -m cp -r gs://kubric-public/tfds/movi_a/128x128/ .
34 | mkdir movi_a
35 | mv 128x128/ movi_a/128x128/
36 | python src/extract_movi_dataset.py
37 | ```
38 |
39 |
40 |
41 | 4. Download and extract the pretrained models, including checkpoints for the SAVi decomposition and prediction modules:
42 | ```
43 | chmod +x download_pretrained.sh
44 | ./download_pretrained.sh
45 | ```
46 |
--------------------------------------------------------------------------------
/src/01_create_experiment.py:
--------------------------------------------------------------------------------
1 | """
2 | Creating an experiment directory and initializing it with defaults
3 | """
4 |
5 | import os
6 | from lib.arguments import create_experiment_arguments
7 | from lib.config import Config
8 | from lib.logger import Logger, print_
9 | from lib.utils import create_directory, delete_directory, timestamp, clear_cmd
10 |
11 | from CONFIG import CONFIG
12 |
13 |
14 | def initialize_experiment():
15 | """
16 | Creating an experiment directory and initializing it with defaults
17 | """
18 | # reading command line args
19 | args = create_experiment_arguments()
20 | exp_dir, config, exp_name = args.exp_directory, args.config, args.name
21 | exp_name = f"experiment_{timestamp()}" if exp_name is None or len(exp_name) < 1 else exp_name
22 | exp_path = os.path.join(CONFIG["paths"]["experiments_path"], exp_dir, exp_name)
23 |
24 | # creating directories
25 | create_directory(exp_path)
26 | _ = Logger(exp_path) # initialize logger once exp_dir is created
27 | create_directory(dir_path=exp_path, dir_name="plots")
28 | create_directory(dir_path=exp_path, dir_name="tboard_logs")
29 |
30 | try:
31 | cfg = Config(exp_path=exp_path)
32 | cfg.create_exp_config_file(config=config)
33 | except FileNotFoundError as e:
34 | print_("An error has occurred...\n Removing experiment directory")
35 | delete_directory(dir_path=exp_path)
36 | print(e)
37 |
38 | return
39 |
40 |
41 | if __name__ == "__main__":
42 | clear_cmd()
43 | initialize_experiment()
44 |
45 | #
46 |
--------------------------------------------------------------------------------
/src/extract_movi_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Converting the MOVi dataset from its original TFRecord format into individual image files, so that
3 | the data can later be loaded without requiring TensorFlow and its flood of errors and warnings
4 | """
5 |
6 | import os
7 | import torch
8 | import torchvision
9 | from tqdm import tqdm
10 | from data.MoviConvert import _MoviA
11 | import lib.utils as utils
12 | from CONFIG import CONFIG
13 |
14 |
15 | PATH = os.path.join(CONFIG["paths"]["data_path"], "movi_a")
16 |
17 |
18 | def process_dataset(split="train"):
19 | """ Extracting frames, flow, masks, and object coordinates for one split of the MOVi-A dataset """
20 | data_path = os.path.join(PATH, split)
21 | utils.create_directory(data_path)
22 |
23 | db = _MoviA(split=split, num_frames=100, img_size=(128, 128))
24 | db.get_masks = True
25 | db.get_bbox = True
26 | print(f" --> {len(db) = }")
27 |
28 | # iterating and saving data
29 | for i in tqdm(range(len(db))):
30 | imgs, all_preds = db[i]
31 | bbox = all_preds["bbox_coords"]
32 | com = all_preds["com_coords"]
33 | masks = all_preds["masks"]
34 | flow = all_preds["flow"]
35 | for j in range(imgs.shape[0]):
36 | torchvision.utils.save_image(flow[j], fp=os.path.join(data_path, f"flow_{i:05d}_{j:02d}.png"))
37 | torchvision.utils.save_image(imgs[j], fp=os.path.join(data_path, f"rgb_{i:05d}_{j:02d}.png"))
38 | torch.save({"com": com, "bbox": bbox}, os.path.join(data_path, f"coords_{i:05d}.pt"))
39 | torch.save({"masks": masks}, os.path.join(data_path, f"mask_{i:05d}.pt"))
40 |
41 | return
42 |
43 |
44 | if __name__ == '__main__':
45 | os.system("clear")
46 | print("Processing Validation set")
47 | process_dataset(split="validation")
48 | print("Processing Training set")
49 | process_dataset(split="train")
50 | print("Finished")
51 |
--------------------------------------------------------------------------------
/src/data/data_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some data processing and other utils
3 | """
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def get_slots_stats(seq, masks):
10 | """
11 | Obtaining stats about the number of slots in a video sequence
12 |
13 | Args:
14 | -----
15 | seq: torch tensor
16 | Sequence of images. Shape is (N_frames, N_channels, H, W)
17 | masks: torch Tensor
18 | Instance segmentation masks. Shape is (N_frames, 1, H, W)
19 | """
20 | total_num_slots = len(torch.unique(masks))
21 | slot_dist = [len(torch.unique(m)) for m in masks]
22 |
23 | stats = {
24 | "total_num_slots": total_num_slots,
25 | "slot_dist": slot_dist,
26 | "max_num_slots": np.max(slot_dist),
27 | "min_num_slots": np.min(slot_dist)
28 | }
29 | return stats
30 |
31 |
32 | def masks_to_boxes(masks):
33 | """
34 | Converting a binary segmentation mask into a bounding box
35 |
36 | Args:
37 | -----
38 | masks: torch Tensor
39 | Binary segmentation masks. Shape is (n_masks, H, W)
40 |
41 | Returns:
42 | --------
43 | bboxes: torch Tensor
44 | Bounding boxes corresponding to the input segmentation masks, in [x1, y1, x2, y2] format.
45 | Shape is (n_masks, 4)
46 | """
47 | assert masks.unique().tolist() == [0, 1]
48 |
49 | bboxes = torch.zeros(masks.shape[0], 4)
50 | for i, mask in enumerate(masks):
51 | if mask.max() == 0:
52 | bboxes[i] = torch.ones(4) * -1
53 | continue
54 | vertical_indices = torch.where(torch.any(mask, axis=1))[0]
55 | horizontal_indices = torch.where(torch.any(mask, axis=0))[0]
56 | if horizontal_indices.shape[0]:
57 | x1, x2 = horizontal_indices[[0, -1]]
58 | y1, y2 = vertical_indices[[0, -1]]
59 | else:
60 | bboxes[i] = torch.ones(4) * -1
61 | continue
62 | bboxes[i] = torch.tensor([x1, y1, x2, y2])
63 | return bboxes.to(masks.device)
64 |
65 | #
66 |
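67 | # Hedged usage sketch for masks_to_boxes (hypothetical toy masks, not part of the project's data):
68 | #   >>> m = torch.zeros(2, 64, 64)          # two binary object masks of shape (H, W)
69 | #   >>> m[0, 10:20, 5:15] = 1               # object 1: rows 10-19, cols 5-14
70 | #   >>> m[1, 30:40, 30:50] = 1              # object 2: rows 30-39, cols 30-49
71 | #   >>> masks_to_boxes(m)                   # boxes in [x1, y1, x2, y2] format
72 | #   tensor([[ 5., 10., 14., 19.],
73 | #           [30., 30., 49., 39.]])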
--------------------------------------------------------------------------------
/experiments/Obj3D/Predictor_Transformer/model_architecture.txt:
--------------------------------------------------------------------------------
1 | Total Params: 562944
2 |
3 | Params: 16512
4 | Linear(in_features=128, out_features=128, bias=True)
5 |
6 | Params: 16512
7 | Linear(in_features=128, out_features=128, bias=True)
8 |
9 | Params: 529920
10 | Sequential(
11 | (0): TransformerEncoderLayer(
12 | (self_attn): MultiheadAttention(
13 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
14 | )
15 | (linear1): Linear(in_features=128, out_features=256, bias=True)
16 | (dropout): Dropout(p=0.1, inplace=False)
17 | (linear2): Linear(in_features=256, out_features=128, bias=True)
18 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
19 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
20 | (dropout1): Dropout(p=0.1, inplace=False)
21 | (dropout2): Dropout(p=0.1, inplace=False)
22 | )
23 | (1): TransformerEncoderLayer(
24 | (self_attn): MultiheadAttention(
25 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
26 | )
27 | (linear1): Linear(in_features=128, out_features=256, bias=True)
28 | (dropout): Dropout(p=0.1, inplace=False)
29 | (linear2): Linear(in_features=256, out_features=128, bias=True)
30 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
31 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
32 | (dropout1): Dropout(p=0.1, inplace=False)
33 | (dropout2): Dropout(p=0.1, inplace=False)
34 | )
35 | (2): TransformerEncoderLayer(
36 | (self_attn): MultiheadAttention(
37 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
38 | )
39 | (linear1): Linear(in_features=128, out_features=256, bias=True)
40 | (dropout): Dropout(p=0.1, inplace=False)
41 | (linear2): Linear(in_features=256, out_features=128, bias=True)
42 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
43 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
44 | (dropout1): Dropout(p=0.1, inplace=False)
45 | (dropout2): Dropout(p=0.1, inplace=False)
46 | )
47 | (3): TransformerEncoderLayer(
48 | (self_attn): MultiheadAttention(
49 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
50 | )
51 | (linear1): Linear(in_features=128, out_features=256, bias=True)
52 | (dropout): Dropout(p=0.1, inplace=False)
53 | (linear2): Linear(in_features=256, out_features=128, bias=True)
54 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
55 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
56 | (dropout1): Dropout(p=0.1, inplace=False)
57 | (dropout2): Dropout(p=0.1, inplace=False)
58 | )
59 | )
--------------------------------------------------------------------------------
/experiments/MOViA/Predictor_LSTM/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "MoviA",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 11,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 128,
19 | "mlp_hidden": 256,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 3,
23 | "num_iterations": 1,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "BBox"
31 | },
32 | "predictor": {
33 | "predictor_name": "LSTM",
34 | "LSTM": {
35 | "num_cells": 2,
36 | "hidden_dim": 128,
37 | "residual": true
38 | }
39 | }
40 | },
41 | "loss": [
42 | {
43 | "type": "mse",
44 | "weight": 1
45 | }
46 | ],
47 | "predictor_loss": [
48 | {
49 | "type": "pred_img_mse",
50 | "weight": 1
51 | },
52 | {
53 | "type": "pred_slot_mse",
54 | "weight": 1
55 | }
56 | ],
57 | "training_slots": {
58 | "num_epochs": 1500,
59 | "save_frequency": 10,
60 | "log_frequency": 25,
61 | "image_log_frequency": 100,
62 | "batch_size": 64,
63 | "lr": 0.0002,
64 | "optimizer": "adam",
65 | "momentum": 0,
66 | "weight_decay": 0,
67 | "nesterov": false,
68 | "scheduler": "cosine_annealing",
69 | "lr_factor": 0.5,
70 | "patience": 10,
71 | "scheduler_steps": 400000,
72 | "lr_warmup": true,
73 | "warmup_steps": 2500,
74 | "warmup_epochs": 200,
75 | "gradient_clipping": true,
76 | "clipping_max_value": 0.05
77 | },
78 | "training_prediction": {
79 | "num_context": 6,
80 | "num_preds": 8,
81 | "teacher_force": false,
82 | "skip_first_slot": false,
83 | "num_epochs": 1500,
84 | "train_iters_per_epoch": 10000000000,
85 | "save_frequency": 10,
86 | "save_frequency_iters": 10000000,
87 | "log_frequency": 25,
88 | "image_log_frequency": 100,
89 | "batch_size": 64,
90 | "sample_length": 14,
91 | "gradient_clipping": false,
92 | "clipping_max_value": 3.0
93 | },
94 | "_general": {
95 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2",
96 | "created_time": "2023-01-28_11-34-50",
97 | "last_loaded": "2023-02-09_09-52-06"
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/experiments/Obj3D/Predictor_LSTM/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "OBJ3D",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 6,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 64,
19 | "mlp_hidden": 128,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 2,
23 | "num_iterations": 2,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "LearnedRandom"
31 | },
32 | "predictor": {
33 | "predictor_name": "LSTM",
34 | "LSTM": {
35 | "num_cells": 2,
36 | "hidden_dim": 128,
37 | "residual": true
38 | }
39 | }
40 | },
41 | "loss": [
42 | {
43 | "type": "mse",
44 | "weight": 1
45 | }
46 | ],
47 | "predictor_loss": [
48 | {
49 | "type": "pred_img_mse",
50 | "weight": 1
51 | },
52 | {
53 | "type": "pred_slot_mse",
54 | "weight": 1
55 | }
56 | ],
57 | "training_slots": {
58 | "num_epochs": 2000,
59 | "save_frequency": 10,
60 | "log_frequency": 100,
61 | "image_log_frequency": 100,
62 | "batch_size": 64,
63 | "lr": 0.0001,
64 | "optimizer": "adam",
65 | "momentum": 0,
66 | "weight_decay": 0,
67 | "nesterov": false,
68 | "scheduler": "cosine_annealing",
69 | "lr_factor": 0.05,
70 | "patience": 10,
71 | "scheduler_steps": 100000,
72 | "lr_warmup": true,
73 | "warmup_steps": 2500,
74 | "warmup_epochs": 1000,
75 | "gradient_clipping": true,
76 | "clipping_max_value": 0.05
77 | },
78 | "training_prediction": {
79 | "num_context": 5,
80 | "num_preds": 5,
81 | "teacher_force": false,
82 | "skip_first_slot": false,
83 | "num_epochs": 1500,
84 | "train_iters_per_epoch": 10000000000,
85 | "save_frequency": 10,
86 | "save_frequency_iters": 1000000,
87 | "log_frequency": 100,
88 | "image_log_frequency": 100,
89 | "batch_size": 64,
90 | "sample_length": 10,
91 | "gradient_clipping": true,
92 | "clipping_max_value": 0.05
93 | },
94 | "_general": {
95 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI",
96 | "created_time": "2022-12-06_09-08-22",
97 | "last_loaded": "2022-12-15_14-51-35"
98 | }
99 | }
100 |
--------------------------------------------------------------------------------
/experiments/MOViA/Predictor_OCVPPar/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "MoviA",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 11,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 128,
19 | "mlp_hidden": 256,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 3,
23 | "num_iterations": 1,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "BBox"
31 | },
32 | "predictor": {
33 | "predictor_name": "OCVP-Par",
34 | "OCVP-Par": {
35 | "token_dim": 256,
36 | "hidden_dim": 512,
37 | "num_layers": 4,
38 | "n_heads": 8,
39 | "residual": true,
40 | "input_buffer_size": 30
41 | }
42 | }
43 | },
44 | "loss": [
45 | {
46 | "type": "mse",
47 | "weight": 1
48 | }
49 | ],
50 | "predictor_loss": [
51 | {
52 | "type": "pred_img_mse",
53 | "weight": 1
54 | },
55 | {
56 | "type": "pred_slot_mse",
57 | "weight": 1
58 | }
59 | ],
60 | "training_slots": {
61 | "num_epochs": 1500,
62 | "save_frequency": 10,
63 | "log_frequency": 25,
64 | "image_log_frequency": 100,
65 | "batch_size": 64,
66 | "lr": 0.0002,
67 | "optimizer": "adam",
68 | "momentum": 0,
69 | "weight_decay": 0,
70 | "nesterov": false,
71 | "scheduler": "cosine_annealing",
72 | "lr_factor": 0.5,
73 | "patience": 10,
74 | "scheduler_steps": 400000,
75 | "lr_warmup": true,
76 | "warmup_steps": 2500,
77 | "warmup_epochs": 200,
78 | "gradient_clipping": true,
79 | "clipping_max_value": 0.05
80 | },
81 | "training_prediction": {
82 | "num_context": 6,
83 | "num_preds": 8,
84 | "teacher_force": false,
85 | "skip_first_slot": false,
86 | "num_epochs": 1500,
87 | "train_iters_per_epoch": 10000000000,
88 | "save_frequency": 10,
89 | "save_frequency_iters": 10000000,
90 | "log_frequency": 25,
91 | "image_log_frequency": 100,
92 | "batch_size": 64,
93 | "sample_length": 14,
94 | "gradient_clipping": false,
95 | "clipping_max_value": 3.0
96 | },
97 | "_general": {
98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2",
99 | "created_time": "2023-01-28_11-34-50",
100 | "last_loaded": "2023-02-09_13-44-17"
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/experiments/MOViA/Predictor_OCVPSeq/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "MoviA",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 11,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 128,
19 | "mlp_hidden": 256,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 3,
23 | "num_iterations": 1,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "BBox"
31 | },
32 | "predictor": {
33 | "predictor_name": "OCVP-Seq",
34 | "OCVP-Seq": {
35 | "token_dim": 256,
36 | "hidden_dim": 512,
37 | "num_layers": 4,
38 | "n_heads": 8,
39 | "residual": true,
40 | "input_buffer_size": 30
41 | }
42 | }
43 | },
44 | "loss": [
45 | {
46 | "type": "mse",
47 | "weight": 1
48 | }
49 | ],
50 | "predictor_loss": [
51 | {
52 | "type": "pred_img_mse",
53 | "weight": 1
54 | },
55 | {
56 | "type": "pred_slot_mse",
57 | "weight": 1
58 | }
59 | ],
60 | "training_slots": {
61 | "num_epochs": 1500,
62 | "save_frequency": 10,
63 | "log_frequency": 25,
64 | "image_log_frequency": 100,
65 | "batch_size": 64,
66 | "lr": 0.0002,
67 | "optimizer": "adam",
68 | "momentum": 0,
69 | "weight_decay": 0,
70 | "nesterov": false,
71 | "scheduler": "cosine_annealing",
72 | "lr_factor": 0.5,
73 | "patience": 10,
74 | "scheduler_steps": 400000,
75 | "lr_warmup": true,
76 | "warmup_steps": 2500,
77 | "warmup_epochs": 200,
78 | "gradient_clipping": true,
79 | "clipping_max_value": 0.05
80 | },
81 | "training_prediction": {
82 | "num_context": 6,
83 | "num_preds": 8,
84 | "teacher_force": false,
85 | "skip_first_slot": false,
86 | "num_epochs": 1500,
87 | "train_iters_per_epoch": 10000000000,
88 | "save_frequency": 10,
89 | "save_frequency_iters": 10000000,
90 | "log_frequency": 25,
91 | "image_log_frequency": 100,
92 | "batch_size": 64,
93 | "sample_length": 14,
94 | "gradient_clipping": false,
95 | "clipping_max_value": 3.0
96 | },
97 | "_general": {
98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2",
99 | "created_time": "2023-01-28_11-34-50",
100 | "last_loaded": "2023-02-08_16-32-24"
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/experiments/MOViA/Predictor_Transformer/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "MoviA",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 11,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 128,
19 | "mlp_hidden": 256,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 3,
23 | "num_iterations": 1,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "BBox"
31 | },
32 | "predictor": {
33 | "predictor_name": "Transformer",
34 | "Transformer": {
35 | "token_dim": 256,
36 | "hidden_dim": 512,
37 | "num_layers": 4,
38 | "n_heads": 8,
39 | "residual": true,
40 | "input_buffer_size": 30
41 | }
42 | }
43 | },
44 | "loss": [
45 | {
46 | "type": "mse",
47 | "weight": 1
48 | }
49 | ],
50 | "predictor_loss": [
51 | {
52 | "type": "pred_img_mse",
53 | "weight": 1
54 | },
55 | {
56 | "type": "pred_slot_mse",
57 | "weight": 1
58 | }
59 | ],
60 | "training_slots": {
61 | "num_epochs": 1500,
62 | "save_frequency": 10,
63 | "log_frequency": 25,
64 | "image_log_frequency": 100,
65 | "batch_size": 64,
66 | "lr": 0.0002,
67 | "optimizer": "adam",
68 | "momentum": 0,
69 | "weight_decay": 0,
70 | "nesterov": false,
71 | "scheduler": "cosine_annealing",
72 | "lr_factor": 0.5,
73 | "patience": 10,
74 | "scheduler_steps": 400000,
75 | "lr_warmup": true,
76 | "warmup_steps": 2500,
77 | "warmup_epochs": 200,
78 | "gradient_clipping": true,
79 | "clipping_max_value": 0.05
80 | },
81 | "training_prediction": {
82 | "num_context": 6,
83 | "num_preds": 8,
84 | "teacher_force": false,
85 | "skip_first_slot": false,
86 | "num_epochs": 1500,
87 | "train_iters_per_epoch": 10000000000,
88 | "save_frequency": 10,
89 | "save_frequency_iters": 10000000,
90 | "log_frequency": 25,
91 | "image_log_frequency": 100,
92 | "batch_size": 64,
93 | "sample_length": 14,
94 | "gradient_clipping": false,
95 | "clipping_max_value": 3.0
96 | },
97 | "_general": {
98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2",
99 | "created_time": "2023-01-28_11-34-50",
100 | "last_loaded": "2023-02-08_09-15-58"
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/experiments/Obj3D/Predictor_OCVPPar/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "OBJ3D",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 6,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 64,
19 | "mlp_hidden": 128,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 2,
23 | "num_iterations": 2,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "LearnedRandom"
31 | },
32 | "predictor": {
33 | "predictor_name": "OCVP-Par",
34 | "OCVP-Par": {
35 | "token_dim": 128,
36 | "hidden_dim": 256,
37 | "num_layers": 4,
38 | "n_heads": 4,
39 | "residual": true,
40 | "input_buffer_size": 5
41 | }
42 | }
43 | },
44 | "loss": [
45 | {
46 | "type": "mse",
47 | "weight": 1
48 | }
49 | ],
50 | "predictor_loss": [
51 | {
52 | "type": "pred_img_mse",
53 | "weight": 1
54 | },
55 | {
56 | "type": "pred_slot_mse",
57 | "weight": 1
58 | }
59 | ],
60 | "training_slots": {
61 | "num_epochs": 2000,
62 | "save_frequency": 10,
63 | "log_frequency": 100,
64 | "image_log_frequency": 100,
65 | "batch_size": 64,
66 | "lr": 0.0001,
67 | "optimizer": "adam",
68 | "momentum": 0,
69 | "weight_decay": 0,
70 | "nesterov": false,
71 | "scheduler": "cosine_annealing",
72 | "lr_factor": 0.05,
73 | "patience": 10,
74 | "scheduler_steps": 100000,
75 | "lr_warmup": true,
76 | "warmup_steps": 2500,
77 | "warmup_epochs": 1000,
78 | "gradient_clipping": true,
79 | "clipping_max_value": 0.05
80 | },
81 | "training_prediction": {
82 | "num_context": 5,
83 | "num_preds": 5,
84 | "teacher_force": false,
85 | "skip_first_slot": false,
86 | "num_epochs": 1500,
87 | "train_iters_per_epoch": 10000000000,
88 | "save_frequency": 25,
89 | "save_frequency_iters": 1000000,
90 | "log_frequency": 100,
91 | "image_log_frequency": 100,
92 | "batch_size": 16,
93 | "sample_length": 10,
94 | "gradient_clipping": true,
95 | "clipping_max_value": 0.05
96 | },
97 | "_general": {
98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI",
99 | "created_time": "2022-12-06_09-08-22",
100 | "last_loaded": "2023-01-13_13-41-20"
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/experiments/Obj3D/Predictor_OCVPSeq/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "OBJ3D",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 6,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 64,
19 | "mlp_hidden": 128,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 2,
23 | "num_iterations": 2,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "LearnedRandom"
31 | },
32 | "predictor": {
33 | "predictor_name": "OCVP-Seq",
34 | "OCVP-Seq": {
35 | "token_dim": 128,
36 | "hidden_dim": 256,
37 | "num_layers": 4,
38 | "n_heads": 4,
39 | "residual": true,
40 | "input_buffer_size": 5
41 | }
42 | }
43 | },
44 | "loss": [
45 | {
46 | "type": "mse",
47 | "weight": 1
48 | }
49 | ],
50 | "predictor_loss": [
51 | {
52 | "type": "pred_img_mse",
53 | "weight": 1
54 | },
55 | {
56 | "type": "pred_slot_mse",
57 | "weight": 1
58 | }
59 | ],
60 | "training_slots": {
61 | "num_epochs": 2000,
62 | "save_frequency": 10,
63 | "log_frequency": 100,
64 | "image_log_frequency": 100,
65 | "batch_size": 64,
66 | "lr": 0.0001,
67 | "optimizer": "adam",
68 | "momentum": 0,
69 | "weight_decay": 0,
70 | "nesterov": false,
71 | "scheduler": "cosine_annealing",
72 | "lr_factor": 0.05,
73 | "patience": 10,
74 | "scheduler_steps": 100000,
75 | "lr_warmup": true,
76 | "warmup_steps": 2500,
77 | "warmup_epochs": 1000,
78 | "gradient_clipping": true,
79 | "clipping_max_value": 0.05
80 | },
81 | "training_prediction": {
82 | "num_context": 5,
83 | "num_preds": 5,
84 | "teacher_force": false,
85 | "skip_first_slot": false,
86 | "num_epochs": 1500,
87 | "train_iters_per_epoch": 10000000000,
88 | "save_frequency": 25,
89 | "save_frequency_iters": 1000000,
90 | "log_frequency": 100,
91 | "image_log_frequency": 100,
92 | "batch_size": 16,
93 | "sample_length": 10,
94 | "gradient_clipping": true,
95 | "clipping_max_value": 0.05
96 | },
97 | "_general": {
98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI",
99 | "created_time": "2022-12-06_09-08-22",
100 | "last_loaded": "2023-01-13_13-36-22"
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/experiments/Obj3D/Predictor_Transformer/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "OBJ3D",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 6,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 64,
19 | "mlp_hidden": 128,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 2,
23 | "num_iterations": 2,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "LearnedRandom"
31 | },
32 | "predictor": {
33 | "predictor_name": "Transformer",
34 | "Transformer": {
35 | "token_dim": 128,
36 | "hidden_dim": 256,
37 | "num_layers": 4,
38 | "n_heads": 4,
39 | "residual": true,
40 | "input_buffer_size": 5
41 | }
42 | }
43 | },
44 | "loss": [
45 | {
46 | "type": "mse",
47 | "weight": 1
48 | }
49 | ],
50 | "predictor_loss": [
51 | {
52 | "type": "pred_img_mse",
53 | "weight": 1
54 | },
55 | {
56 | "type": "pred_slot_mse",
57 | "weight": 1
58 | }
59 | ],
60 | "training_slots": {
61 | "num_epochs": 2000,
62 | "save_frequency": 10,
63 | "log_frequency": 100,
64 | "image_log_frequency": 100,
65 | "batch_size": 64,
66 | "lr": 0.0001,
67 | "optimizer": "adam",
68 | "momentum": 0,
69 | "weight_decay": 0,
70 | "nesterov": false,
71 | "scheduler": "cosine_annealing",
72 | "lr_factor": 0.05,
73 | "patience": 10,
74 | "scheduler_steps": 100000,
75 | "lr_warmup": true,
76 | "warmup_steps": 2500,
77 | "warmup_epochs": 1000,
78 | "gradient_clipping": true,
79 | "clipping_max_value": 0.05
80 | },
81 | "training_prediction": {
82 | "num_context": 5,
83 | "num_preds": 5,
84 | "teacher_force": false,
85 | "skip_first_slot": false,
86 | "num_epochs": 1500,
87 | "train_iters_per_epoch": 10000000000,
88 | "save_frequency": 25,
89 | "save_frequency_iters": 1000000,
90 | "log_frequency": 100,
91 | "image_log_frequency": 100,
92 | "batch_size": 64,
93 | "sample_length": 10,
94 | "gradient_clipping": true,
95 | "clipping_max_value": 0.05
96 | },
97 | "_general": {
98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI",
99 | "created_time": "2022-12-06_09-08-22",
100 | "last_loaded": "2022-12-29_08-41-26"
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/src/01_create_predictor_experiment.py:
--------------------------------------------------------------------------------
1 | """
2 | Creating a predictor experiment inside an existing experiment directory and initializing
3 | its experiment parameters.
4 | """
5 |
6 | import os
7 | from lib.arguments import create_predictor_experiment_arguments
8 | from lib.config import Config
9 | from lib.logger import Logger, print_
10 | from lib.utils import create_directory, delete_directory, clear_cmd
11 | from CONFIG import CONFIG
12 |
13 |
14 | def initialize_predictor_experiment():
15 | """
16 | Creating a predictor experiment directory and initializing it with defaults
17 | """
18 | # reading command line args
19 | args = create_predictor_experiment_arguments()
20 | exp_dir, exp_name, predictor_name = args.exp_directory, args.name, args.predictor_name
21 |
22 | # making sure everything adds up
23 | parent_path = os.path.join(CONFIG["paths"]["experiments_path"], exp_dir)
24 | exp_path = os.path.join(CONFIG["paths"]["experiments_path"], exp_dir, exp_name)
25 | if not os.path.exists(parent_path):
26 | raise FileNotFoundError(f"{parent_path = } does not exist")
27 | if not os.path.exists(os.path.join(parent_path, "experiment_params.json")):
28 | raise FileNotFoundError(f"{parent_path = } does not have experiment_params...")
29 | if len(os.listdir(os.path.join(parent_path, "models"))) <= 0:
30 | raise FileNotFoundError("Parent models-dir does not contain any models!...")
31 | if os.path.exists(exp_path):
32 | raise ValueError(f"{exp_path = } already exists. Choose a different name!")
33 |
34 | # creating directories
35 | create_directory(exp_path)
36 | _ = Logger(exp_path) # initialize logger once exp_dir is created
37 | create_directory(dir_path=exp_path, dir_name="plots")
38 | create_directory(dir_path=exp_path, dir_name="tboard_logs")
39 |
40 | # adding experiment parameters from the parent directory, but only with specified predictor params
41 | try:
42 | cfg = Config(exp_path=parent_path)
43 | exp_params = cfg.load_exp_config_file()
44 | new_predictor_params = {}
45 | predictor_params = exp_params["model"]["predictor"]
46 | if predictor_name not in predictor_params.keys():
47 | raise ValueError(f"{predictor_name} not in keys {predictor_params.keys() = }")
48 | new_predictor_params["predictor_name"] = predictor_name
49 | new_predictor_params[predictor_name] = predictor_params[predictor_name]
50 | exp_params["model"]["predictor"] = new_predictor_params
51 | cfg.save_exp_config_file(exp_path=exp_path, exp_params=exp_params)
52 | except FileNotFoundError as e:
53 | print_("An error has occurred...\n Removing experiment directory")
54 | delete_directory(dir_path=exp_path)
55 | print(e)
56 | print(f"Predictor experiment {exp_name} created successfully! :)")
57 | return
58 |
59 |
60 | if __name__ == "__main__":
61 | clear_cmd()
62 | initialize_predictor_experiment()
63 |
64 | #
65 |
--------------------------------------------------------------------------------
/src/03_evaluate_savi_noMasks.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluating a SAVi model checkpoint on a dataset without ground-truth masks.
3 | Since we do not have masks to compare with, we simply evaluate the visual quality
4 | of the reconstructed images using video-prediction metrics: PSNR, SSIM and LPIPS.
5 | """
6 |
7 | from data.load_data import unwrap_batch_data
8 | from lib.arguments import get_sa_eval_arguments
9 | from lib.logger import Logger, print_, log_function, for_all_methods
10 | from lib.metrics import MetricTracker
11 | import lib.utils as utils
12 |
13 | from base.baseEvaluator import BaseEvaluator
14 |
15 |
16 | @for_all_methods(log_function)
17 | class Evaluator(BaseEvaluator):
18 | """
19 | Evaluating a SAVi model checkpoint on a dataset without ground-truth masks.
20 | Since we do not have masks to compare with, we simply evaluate the visual quality
21 | of the reconstructed images using video-prediction metrics: PSNR, SSIM and LPIPS.
22 | """
23 |
24 | def set_metric_tracker(self):
25 | """
26 | Initializing the metric tracker
27 | """
28 | self.metric_tracker = MetricTracker(
29 | exp_path=exp_path,
30 | metrics=["psnr", "ssim", "lpips"]
31 | )
32 | return
33 |
34 | def forward_eval(self, batch_data, **kwargs):
35 | """
36 | Making a forward pass through the model and computing the evaluation metrics
37 |
38 | Args:
39 | -----
40 | batch_data: dict
41 | Dictionary containing the information for the current batch, including images, poses,
42 | actions, or metadata, among others.
43 |
44 | Returns:
45 | --------
46 | pred_data: dict
47 | Predictions from the model for the current batch of data
48 | """
49 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data)
50 | videos, targets = videos.to(self.device), targets.to(self.device)
51 | out_model = self.model(
52 | videos,
53 | num_imgs=videos.shape[1],
54 | **initializer_kwargs
55 | )
56 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model
57 | self.metric_tracker.accumulate(
58 | preds=reconstruction_history.clamp(0, 1),
59 | targets=targets.clamp(0, 1)
60 | )
61 | return
62 |
63 |
64 | if __name__ == "__main__":
65 | utils.clear_cmd()
66 | exp_path, args = get_sa_eval_arguments()
67 | logger = Logger(exp_path=exp_path)
68 | logger.log_info("Starting SAVi visual quality evaluation procedure", message_type="new_exp")
69 |
70 | print_("Initializing Evaluator...")
71 | print_("Args:")
72 | print_("-----")
73 | for k, v in vars(args).items():
74 | print_(f" --> {k} = {v}")
75 | evaluator = Evaluator(
76 | exp_path=exp_path,
77 | checkpoint=args.checkpoint
78 | )
79 | print_("Loading dataset...")
80 | evaluator.load_data()
81 | print_("Setting up model and loading pretrained parameters")
82 | evaluator.setup_model()
83 | print_("Starting evaluation")
84 | evaluator.evaluate()
85 |
86 | #
87 |
--------------------------------------------------------------------------------
/src/03_evaluate_savi.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluating a SAVi model checkpoint using object-centric metrics.
3 | This evaluation can only be performed on datasets with annotated segmentation masks.
4 | """
5 |
6 | import torch
7 |
8 | from data import unwrap_batch_data_masks
9 | from lib.arguments import get_sa_eval_arguments
10 | from lib.logger import Logger, print_, log_function, for_all_methods
11 | from lib.metrics import MetricTracker
12 | import lib.utils as utils
13 |
14 | from base.baseEvaluator import BaseEvaluator
15 |
16 |
17 | @for_all_methods(log_function)
18 | class Evaluator(BaseEvaluator):
19 | """
20 | Class for evaluating a SAVi model using object-centric metrics.
21 | This evaluation can only be performed on datasets with annotated segmentation masks.
22 | """
23 |
24 | def set_metric_tracker(self):
25 | """
26 | Initializing the metric tracker
27 | """
28 | self.metric_tracker = MetricTracker(
29 | exp_path,
30 | metrics=["segmentation_ari", "IoU"]
31 | )
32 | return
33 |
34 | def load_data(self):
35 | """
36 | Loading data
37 | """
38 | super().load_data()
39 | self.test_set.get_masks = True
40 | return
41 |
42 | def forward_eval(self, batch_data, **kwargs):
43 | """
44 | Making a forward pass through the model and computing the evaluation metrics
45 |
46 | Args:
47 | -----
48 | batch_data: dict
49 | Dictionary containing the information for the current batch, including images, poses,
50 | actions, or metadata, among others.
51 |
52 | Returns:
53 | --------
54 | pred_data: dict
55 | Predictions from the model for the current batch of data
56 | """
57 | videos, masks, initializer_kwargs = unwrap_batch_data_masks(self.exp_params, batch_data)
58 | videos, masks = videos.to(self.device), masks.to(self.device)
59 | out_model = self.model(
60 | videos,
61 | num_imgs=videos.shape[1],
62 | **initializer_kwargs
63 | )
64 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model
65 |
66 | # evaluation
67 | predicted_combined_masks = torch.argmax(masks_history, dim=2).squeeze(2)
68 | self.metric_tracker.accumulate(
69 | preds=predicted_combined_masks,
70 | targets=masks
71 | )
72 | return
73 |
74 |
75 | if __name__ == "__main__":
76 | utils.clear_cmd()
77 | exp_path, args = get_sa_eval_arguments()
78 | logger = Logger(exp_path=exp_path)
79 | logger.log_info("Starting SAVi object-centric evaluation procedure", message_type="new_exp")
80 |
81 | print_("Initializing Evaluator...")
82 | print_("Args:")
83 | print_("-----")
84 | for k, v in vars(args).items():
85 | print_(f" --> {k} = {v}")
86 | evaluator = Evaluator(
87 | exp_path=exp_path,
88 | checkpoint=args.checkpoint
89 | )
90 | print_("Loading dataset...")
91 | evaluator.load_data()
92 | print_("Setting up model and loading pretrained parameters")
93 | evaluator.setup_model()
94 | print_("Starting evaluation")
95 | evaluator.evaluate()
96 |
97 |
98 | #
99 |
--------------------------------------------------------------------------------
/experiments/MOViA/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "MoviA",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 11,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 128,
19 | "mlp_hidden": 256,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 3,
23 | "num_iterations": 1,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "BBox"
31 | },
32 | "predictor": {
33 | "predictor_name": "Transformer",
34 | "LSTM": {
35 | "num_cells": 2,
36 | "hidden_dim": 64,
37 | "residual": 30
38 | },
39 | "Transformer": {
40 | "token_dim": 128,
41 | "hidden_dim": 256,
42 | "num_layers": 2,
43 | "n_heads": 4,
44 | "residual": true,
45 | "input_buffer_size": 30
46 | },
47 | "OCVP-Seq": {
48 | "token_dim": 128,
49 | "hidden_dim": 256,
50 | "num_layers": 2,
51 | "n_heads": 4,
52 | "residual": true,
53 | "input_buffer_size": 30
54 | },
55 | "OCVP-Par": {
56 | "token_dim": 128,
57 | "hidden_dim": 256,
58 | "num_layers": 2,
59 | "n_heads": 4,
60 | "residual": true,
61 | "input_buffer_size": 30
62 | }
63 | }
64 | },
65 | "loss": [
66 | {
67 | "type": "mse",
68 | "weight": 1
69 | }
70 | ],
71 | "predictor_loss": [
72 | {
73 | "type": "pred_img_mse",
74 | "weight": 1
75 | },
76 | {
77 | "type": "pred_slot_mse",
78 | "weight": 1
79 | }
80 | ],
81 | "training_slots": {
82 | "num_epochs": 2500,
83 | "save_frequency": 10,
84 | "log_frequency": 25,
85 | "image_log_frequency": 100,
86 | "batch_size": 64,
87 | "lr": 0.0001,
88 | "optimizer": "adam",
89 | "momentum": 0,
90 | "weight_decay": 0,
91 | "nesterov": false,
92 | "scheduler": "cosine_annealing",
93 | "lr_factor": 0.5,
94 | "patience": 10,
95 | "scheduler_steps": 400000,
96 | "lr_warmup": true,
97 | "warmup_steps": 2500,
98 | "warmup_epochs": 200,
99 | "gradient_clipping": true,
100 | "clipping_max_value": 0.05
101 | },
102 | "training_prediction": {
103 | "num_context": 5,
104 | "num_preds": 5,
105 | "teacher_force": false,
106 | "skip_first_slot": false,
107 | "num_epochs": 1500,
108 | "train_iters_per_epoch": 10000000000,
109 | "save_frequency": 10,
110 | "save_frequency_iters": 1000,
111 | "log_frequency": 25,
112 | "image_log_frequency": 100,
113 | "batch_size": 64,
114 | "sample_length": 10,
115 | "gradient_clipping": true,
116 | "clipping_max_value": 0.05
117 | },
118 | "_general": {
119 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2",
120 | "created_time": "2023-01-28_11-34-50",
121 | "last_loaded": "2023-01-28_11-34-50"
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/experiments/Obj3D/experiment_params.json:
--------------------------------------------------------------------------------
1 | {
2 | "dataset": {
3 | "dataset_name": "OBJ3D",
4 | "shuffle_train": true,
5 | "shuffle_eval": false,
6 | "use_segmentation": true,
7 | "target": "rgb",
8 | "random_start": true
9 | },
10 | "model": {
11 | "model_name": "SAVi",
12 | "SAVi": {
13 | "num_slots": 6,
14 | "slot_dim": 128,
15 | "in_channels": 3,
16 | "encoder_type": "ConvEncoder",
17 | "num_channels": [32, 32, 32, 32],
18 | "mlp_encoder_dim": 64,
19 | "mlp_hidden": 128,
20 | "num_channels_decoder": [64, 64, 64, 64],
21 | "kernel_size": 5,
22 | "num_iterations_first": 2,
23 | "num_iterations": 2,
24 | "resolution": [64, 64],
25 | "downsample_encoder": false,
26 | "downsample_decoder": true,
27 | "decoder_resolution": [8, 8],
28 | "upsample": 2,
29 | "use_predictor": true,
30 | "initializer": "LearnedRandom"
31 | },
32 | "predictor": {
33 | "predictor_name": "Transformer",
34 | "LSTM": {
35 | "num_cells": 2,
36 | "hidden_dim": 64,
37 | "residual": true
38 | },
39 | "Transformer": {
40 | "token_dim": 128,
41 | "hidden_dim": 256,
42 | "num_layers": 2,
43 | "n_heads": 4,
44 | "residual": true,
45 | "input_buffer_size": null
46 | },
47 | "OCVP-Seq": {
48 | "token_dim": 128,
49 | "hidden_dim": 256,
50 | "num_layers": 2,
51 | "n_heads": 4,
52 | "residual": true,
53 | "input_buffer_size": null
54 | },
55 | "OCVP-Par": {
56 | "token_dim": 128,
57 | "hidden_dim": 256,
58 | "num_layers": 2,
59 | "n_heads": 4,
60 | "residual": true,
61 | "input_buffer_size": null
62 | }
63 | }
64 | },
65 | "loss": [
66 | {
67 | "type": "mse",
68 | "weight": 1
69 | }
70 | ],
71 | "predictor_loss": [
72 | {
73 | "type": "pred_img_mse",
74 | "weight": 1
75 | },
76 | {
77 | "type": "pred_slot_mse",
78 | "weight": 1
79 | }
80 | ],
81 | "training_slots": {
82 | "num_epochs": 2000,
83 | "save_frequency": 10,
84 | "log_frequency": 100,
85 | "image_log_frequency": 100,
86 | "batch_size": 64,
87 | "lr": 0.0001,
88 | "optimizer": "adam",
89 | "momentum": 0,
90 | "weight_decay": 0,
91 | "nesterov": false,
92 | "scheduler": "cosine_annealing",
93 | "lr_factor": 0.05,
94 | "patience": 10,
95 | "scheduler_steps": 100000,
96 | "lr_warmup": true,
97 | "warmup_steps": 2500,
98 | "warmup_epochs": 1000,
99 | "gradient_clipping": true,
100 | "clipping_max_value": 0.05
101 | },
102 | "training_prediction": {
103 | "num_context": 5,
104 | "num_preds": 5,
105 | "teacher_force": false,
106 | "skip_first_slot": false,
107 | "num_epochs": 1500,
108 | "train_iters_per_epoch": 10000000000,
109 | "save_frequency": 10,
110 | "save_frequency_iters": 1000000,
111 | "log_frequency": 100,
112 | "image_log_frequency": 100,
113 | "batch_size": 64,
114 | "sample_length": 10,
115 | "gradient_clipping": true,
116 | "clipping_max_value": 0.05
117 | },
118 | "_general": {
119 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI",
120 | "created_time": "2022-12-06_09-08-22",
121 | "last_loaded": "2022-12-06_09-08-22"
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/experiments/Obj3D/model_architecture.txt:
--------------------------------------------------------------------------------
1 | Total Params: 915652
2 |
3 | Params: 79328
4 | SimpleConvEncoder(
5 | (encoder): Sequential(
6 | (0): ConvBlock(
7 | (block): Sequential(
8 | (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
9 | (1): ReLU()
10 | )
11 | )
12 | (1): ConvBlock(
13 | (block): Sequential(
14 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
15 | (1): ReLU()
16 | )
17 | )
18 | (2): ConvBlock(
19 | (block): Sequential(
20 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
21 | (1): ReLU()
22 | )
23 | )
24 | (3): ConvBlock(
25 | (block): Sequential(
26 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
27 | (1): ReLU()
28 | )
29 | )
30 | )
31 | )
32 |
33 | Params: 160
34 | SoftPositionEmbed(
35 | (projection): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
36 | )
37 |
38 | Params: 6336
39 | Sequential(
40 | (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
41 | (1): Linear(in_features=32, out_features=64, bias=True)
42 | (2): ReLU()
43 | (3): Linear(in_features=64, out_features=64, bias=True)
44 | )
45 |
46 | Params: 148480
47 | TransformerBlock(
48 | (attn): MultiHeadSelfAttention(
49 | (q): Linear(in_features=128, out_features=128, bias=False)
50 | (k): Linear(in_features=128, out_features=128, bias=False)
51 | (v): Linear(in_features=128, out_features=128, bias=False)
52 | (drop): Dropout(p=0.0, inplace=False)
53 | (out_projection): Sequential(
54 | (0): Linear(in_features=128, out_features=128, bias=False)
55 | )
56 | )
57 | (mlp): Sequential(
58 | (0): Linear(in_features=128, out_features=256, bias=True)
59 | (1): ReLU()
60 | (2): Linear(in_features=256, out_features=128, bias=True)
61 | )
62 | (layernorm_query): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
63 | (layernorm_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
64 | (dense_o): Linear(in_features=128, out_features=128, bias=True)
65 | )
66 |
67 | Params: 640
68 | SoftPositionEmbed(
69 | (projection): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1))
70 | )
71 |
72 | Params: 514564
73 | Decoder(
74 | (decoder): Sequential(
75 | (0): ConvBlock(
76 | (block): Sequential(
77 | (0): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
78 | (1): ReLU()
79 | )
80 | )
81 | (1): Upsample(scale_factor=2.0, mode=nearest)
82 | (2): ConvBlock(
83 | (block): Sequential(
84 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
85 | (1): ReLU()
86 | )
87 | )
88 | (3): Upsample(scale_factor=2.0, mode=nearest)
89 | (4): ConvBlock(
90 | (block): Sequential(
91 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
92 | (1): ReLU()
93 | )
94 | )
95 | (5): Upsample(scale_factor=2.0, mode=nearest)
96 | (6): ConvBlock(
97 | (block): Sequential(
98 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
99 | (1): ReLU()
100 | )
101 | )
102 | (7): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
103 | )
104 | )
105 |
106 | Params: 166144
107 | SlotAttention(
108 | (norm_input): LayerNorm((64,), eps=0.001, elementwise_affine=True)
109 | (norm_slot): LayerNorm((128,), eps=0.001, elementwise_affine=True)
110 | (norm_mlp): LayerNorm((128,), eps=0.001, elementwise_affine=True)
111 | (to_q): Linear(in_features=128, out_features=128, bias=True)
112 | (to_k): Linear(in_features=64, out_features=128, bias=True)
113 | (to_v): Linear(in_features=64, out_features=128, bias=True)
114 | (gru): GRUCell(128, 128)
115 | (mlp): Sequential(
116 | (0): Linear(in_features=128, out_features=128, bias=True)
117 | (1): ReLU()
118 | (2): Linear(in_features=128, out_features=128, bias=True)
119 | )
120 | )
121 |
--------------------------------------------------------------------------------
/src/data/obj3d.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset class to load Obj3D dataset
3 | - Source: https://github.com/zhixuan-lin/G-SWM/blob/master/src/dataset/obj3d.py
4 | """
5 |
6 | from torch.utils.data import Dataset
7 | from torchvision import transforms
8 | import glob
9 | import os
10 | import os.path as osp
11 | import torch
12 | from PIL import Image, ImageFile
13 |
14 | from CONFIG import CONFIG
15 | PATH = CONFIG["paths"]["data_path"]
16 |
17 | ImageFile.LOAD_TRUNCATED_IMAGES = True
18 |
19 |
20 | class OBJ3D(Dataset):
21 | """
22 | DataClass for the Obj3D Dataset.
23 |
24 | During training, we sample a random subset of frames in the episode. At inference time,
25 |     we always start from the first frame, i.e., when the ball is still moving towards the
26 |     objects and before any collision happens.
27 |
28 | - Source: https://github.com/zhixuan-lin/G-SWM/blob/master/src/dataset/obj3d.py
29 |
30 | Args:
31 | -----
32 | mode: string
33 | Dataset split to load. Can be one of ['train', 'val', 'test']
34 | ep_len: int
35 | Number of frames in an episode. Default is 30
36 | sample_length: int
37 | Number of frames in the sequences to load
38 | random_start: bool
39 |         If True, the first frame of the sequence is sampled at random among the possible starting frames.
40 |         Otherwise, the sequence always starts at the first frame of the episode.
41 | """
42 |
43 | def __init__(self, mode, ep_len=30, sample_length=20, random_start=True):
44 | """
45 | Dataset Initializer
46 | """
47 | assert mode in ["train", "val", "valid", "eval", "test"], f"Unknown dataset split {mode}..."
48 | mode = "val" if mode in ["val", "valid"] else mode
49 | mode = "test" if mode in ["test", "eval"] else mode
50 | assert mode in ['train', 'val', 'test'], f"Unknown dataset split {mode}..."
51 |
52 | self.root = os.path.join(PATH, "OBJ3D", mode)
53 | self.mode = mode
54 | self.sample_length = sample_length
55 | self.random_start = random_start
56 |
57 |         # Collect all episode directories, which are named with integers
58 | self.folders = []
59 | for file in os.listdir(self.root):
60 | try:
61 | self.folders.append(int(file))
62 | except ValueError:
63 | continue
64 | self.folders.sort()
65 |
66 |         # episode-related parameters
67 |         self.episodes = []
68 | self.EP_LEN = ep_len
69 | if mode == "train" and self.random_start:
70 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1
71 | else:
72 | self.seq_per_episode = 1
73 |
74 |         # loading images from data directories and assembling them into episodes
75 | for f in self.folders:
76 | dir_name = os.path.join(self.root, str(f))
77 | paths = list(glob.glob(osp.join(dir_name, 'test_*.png')))
78 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0].partition('_')[-1])
79 | paths.sort(key=get_num)
80 |             self.episodes.append(paths)
81 | return
82 |
83 | def __getitem__(self, index):
84 | """
85 | Fetching a sequence from the dataset
86 | """
87 | imgs = []
88 |
89 | # Implement continuous indexing
90 | ep = index // self.seq_per_episode
91 | offset = index % self.seq_per_episode
92 | end = offset + self.sample_length
93 |
94 |         e = self.episodes[ep]
95 | for image_index in range(offset, end):
96 |             img = Image.open(e[image_index])
97 | img = img.resize((64, 64))
98 | img = transforms.ToTensor()(img)[:3]
99 | imgs.append(img)
100 | img = torch.stack(imgs, dim=0).float()
101 |
102 | targets = img
103 | all_reps = {"videos": img}
104 | return img, targets, all_reps
105 |
106 | def __len__(self):
107 | """
108 | Number of episodes in the dataset
109 | """
110 |         length = len(self.episodes)
111 | return length
112 |
113 |
114 | #
115 |
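116 | if __name__ == "__main__":
117 |     # Hedged usage sketch: a minimal smoke test of this dataset class, assuming the src/
118 |     # directory is on PYTHONPATH and the OBJ3D data has already been downloaded to
119 |     # CONFIG["paths"]["data_path"]/OBJ3D
120 |     from torch.utils.data import DataLoader
121 |     dataset = OBJ3D(mode="train", sample_length=10, random_start=True)
122 |     loader = DataLoader(dataset, batch_size=4, shuffle=True)
123 |     imgs, targets, all_reps = next(iter(loader))
124 |     print(imgs.shape)  # expected shape: (4, 10, 3, 64, 64)
125 |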
--------------------------------------------------------------------------------
/experiments/MOViA/Predictor_OCVPPar/model_architecture.txt:
--------------------------------------------------------------------------------
1 | Total Params: 4279680
2 |
3 | Params: 4279680
4 | OCVTransformerV2Predictor(
5 | (mlp_in): Linear(in_features=128, out_features=256, bias=True)
6 | (mlp_out): Linear(in_features=256, out_features=128, bias=True)
7 | (transformer_encoders): Sequential(
8 | (0): ObjectCentricTransformerLayerV2(
9 | (self_attn): MultiheadAttention(
10 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
11 | )
12 | (linear1): Linear(in_features=256, out_features=512, bias=True)
13 | (dropout): Dropout(p=0.1, inplace=False)
14 | (linear2): Linear(in_features=512, out_features=256, bias=True)
15 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
16 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
17 | (dropout1): Dropout(p=0.1, inplace=False)
18 | (dropout2): Dropout(p=0.1, inplace=False)
19 | (self_attn_obj): MultiheadAttention(
20 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
21 | )
22 | (self_attn_time): MultiheadAttention(
23 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
24 | )
25 | )
26 | (1): ObjectCentricTransformerLayerV2(
27 | (self_attn): MultiheadAttention(
28 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
29 | )
30 | (linear1): Linear(in_features=256, out_features=512, bias=True)
31 | (dropout): Dropout(p=0.1, inplace=False)
32 | (linear2): Linear(in_features=512, out_features=256, bias=True)
33 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
34 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
35 | (dropout1): Dropout(p=0.1, inplace=False)
36 | (dropout2): Dropout(p=0.1, inplace=False)
37 | (self_attn_obj): MultiheadAttention(
38 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
39 | )
40 | (self_attn_time): MultiheadAttention(
41 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
42 | )
43 | )
44 | (2): ObjectCentricTransformerLayerV2(
45 | (self_attn): MultiheadAttention(
46 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
47 | )
48 | (linear1): Linear(in_features=256, out_features=512, bias=True)
49 | (dropout): Dropout(p=0.1, inplace=False)
50 | (linear2): Linear(in_features=512, out_features=256, bias=True)
51 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
52 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
53 | (dropout1): Dropout(p=0.1, inplace=False)
54 | (dropout2): Dropout(p=0.1, inplace=False)
55 | (self_attn_obj): MultiheadAttention(
56 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
57 | )
58 | (self_attn_time): MultiheadAttention(
59 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
60 | )
61 | )
62 | (3): ObjectCentricTransformerLayerV2(
63 | (self_attn): MultiheadAttention(
64 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
65 | )
66 | (linear1): Linear(in_features=256, out_features=512, bias=True)
67 | (dropout): Dropout(p=0.1, inplace=False)
68 | (linear2): Linear(in_features=512, out_features=256, bias=True)
69 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
70 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
71 | (dropout1): Dropout(p=0.1, inplace=False)
72 | (dropout2): Dropout(p=0.1, inplace=False)
73 | (self_attn_obj): MultiheadAttention(
74 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
75 | )
76 | (self_attn_time): MultiheadAttention(
77 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
78 | )
79 | )
80 | )
81 | (pe): PositionalEncoding(
82 | (dropout): Dropout(p=0.1, inplace=False)
83 | )
84 | )
--------------------------------------------------------------------------------
/experiments/Obj3D/Predictor_OCVPPar/model_architecture.txt:
--------------------------------------------------------------------------------
1 | Total Params: 1091328
2 |
3 | Params: 1091328
4 | OCVTransformerV2Predictor(
5 | (mlp_in): Linear(in_features=128, out_features=128, bias=True)
6 | (mlp_out): Linear(in_features=128, out_features=128, bias=True)
7 | (transformer_encoders): Sequential(
8 | (0): ObjectCentricTransformerLayerV2(
9 | (self_attn): MultiheadAttention(
10 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
11 | )
12 | (linear1): Linear(in_features=128, out_features=256, bias=True)
13 | (dropout): Dropout(p=0.1, inplace=False)
14 | (linear2): Linear(in_features=256, out_features=128, bias=True)
15 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
16 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
17 | (dropout1): Dropout(p=0.1, inplace=False)
18 | (dropout2): Dropout(p=0.1, inplace=False)
19 | (self_attn_obj): MultiheadAttention(
20 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
21 | )
22 | (self_attn_time): MultiheadAttention(
23 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
24 | )
25 | )
26 | (1): ObjectCentricTransformerLayerV2(
27 | (self_attn): MultiheadAttention(
28 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
29 | )
30 | (linear1): Linear(in_features=128, out_features=256, bias=True)
31 | (dropout): Dropout(p=0.1, inplace=False)
32 | (linear2): Linear(in_features=256, out_features=128, bias=True)
33 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
34 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
35 | (dropout1): Dropout(p=0.1, inplace=False)
36 | (dropout2): Dropout(p=0.1, inplace=False)
37 | (self_attn_obj): MultiheadAttention(
38 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
39 | )
40 | (self_attn_time): MultiheadAttention(
41 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
42 | )
43 | )
44 | (2): ObjectCentricTransformerLayerV2(
45 | (self_attn): MultiheadAttention(
46 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
47 | )
48 | (linear1): Linear(in_features=128, out_features=256, bias=True)
49 | (dropout): Dropout(p=0.1, inplace=False)
50 | (linear2): Linear(in_features=256, out_features=128, bias=True)
51 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
52 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
53 | (dropout1): Dropout(p=0.1, inplace=False)
54 | (dropout2): Dropout(p=0.1, inplace=False)
55 | (self_attn_obj): MultiheadAttention(
56 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
57 | )
58 | (self_attn_time): MultiheadAttention(
59 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
60 | )
61 | )
62 | (3): ObjectCentricTransformerLayerV2(
63 | (self_attn): MultiheadAttention(
64 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
65 | )
66 | (linear1): Linear(in_features=128, out_features=256, bias=True)
67 | (dropout): Dropout(p=0.1, inplace=False)
68 | (linear2): Linear(in_features=256, out_features=128, bias=True)
69 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
70 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
71 | (dropout1): Dropout(p=0.1, inplace=False)
72 | (dropout2): Dropout(p=0.1, inplace=False)
73 | (self_attn_obj): MultiheadAttention(
74 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
75 | )
76 | (self_attn_time): MultiheadAttention(
77 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
78 | )
79 | )
80 | )
81 | (pe): PositionalEncoding(
82 | (dropout): Dropout(p=0.1, inplace=False)
83 | )
84 | )
--------------------------------------------------------------------------------
/experiments/MOViA/model_architecture.txt:
--------------------------------------------------------------------------------
1 | Total Params: 1013445
2 |
3 | Params: 34177
4 | CoordInit(
5 | (coord_encoder): Sequential(
6 | (0): Linear(in_features=4, out_features=256, bias=True)
7 | (1): ReLU()
8 | (2): Linear(in_features=256, out_features=128, bias=True)
9 | )
10 | )
11 |
12 | Params: 79328
13 | SimpleConvEncoder(
14 | (encoder): Sequential(
15 | (0): ConvBlock(
16 | (block): Sequential(
17 | (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
18 | (1): ReLU()
19 | )
20 | )
21 | (1): ConvBlock(
22 | (block): Sequential(
23 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
24 | (1): ReLU()
25 | )
26 | )
27 | (2): ConvBlock(
28 | (block): Sequential(
29 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
30 | (1): ReLU()
31 | )
32 | )
33 | (3): ConvBlock(
34 | (block): Sequential(
35 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
36 | (1): ReLU()
37 | )
38 | )
39 | )
40 | )
41 |
42 | Params: 160
43 | SoftPositionEmbed(
44 | (projection): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
45 | )
46 |
47 | Params: 20800
48 | Sequential(
49 | (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
50 | (1): Linear(in_features=32, out_features=128, bias=True)
51 | (2): ReLU()
52 | (3): Linear(in_features=128, out_features=128, bias=True)
53 | )
54 |
55 | Params: 148480
56 | TransformerBlock(
57 | (attn): MultiHeadSelfAttention(
58 | (q): Linear(in_features=128, out_features=128, bias=False)
59 | (k): Linear(in_features=128, out_features=128, bias=False)
60 | (v): Linear(in_features=128, out_features=128, bias=False)
61 | (drop): Dropout(p=0.0, inplace=False)
62 | (out_projection): Sequential(
63 | (0): Linear(in_features=128, out_features=128, bias=False)
64 | )
65 | )
66 | (mlp): Sequential(
67 | (0): Linear(in_features=128, out_features=256, bias=True)
68 | (1): ReLU()
69 | (2): Linear(in_features=256, out_features=128, bias=True)
70 | )
71 | (layernorm_query): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
72 | (layernorm_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
73 | (dense_o): Linear(in_features=128, out_features=128, bias=True)
74 | )
75 |
76 | Params: 640
77 | SoftPositionEmbed(
78 | (projection): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1))
79 | )
80 |
81 | Params: 514564
82 | Decoder(
83 | (decoder): Sequential(
84 | (0): ConvBlock(
85 | (block): Sequential(
86 | (0): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
87 | (1): ReLU()
88 | )
89 | )
90 | (1): Upsample(scale_factor=2)
91 | (2): ConvBlock(
92 | (block): Sequential(
93 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
94 | (1): ReLU()
95 | )
96 | )
97 | (3): Upsample(scale_factor=2)
98 | (4): ConvBlock(
99 | (block): Sequential(
100 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
101 | (1): ReLU()
102 | )
103 | )
104 | (5): Upsample(scale_factor=2)
105 | (6): ConvBlock(
106 | (block): Sequential(
107 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
108 | (1): ReLU()
109 | )
110 | )
111 | (7): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
112 | )
113 | )
114 |
115 | Params: 215296
116 | SlotAttention(
117 | (norm_input): LayerNorm((128,), eps=0.001, elementwise_affine=True)
118 | (norm_slot): LayerNorm((128,), eps=0.001, elementwise_affine=True)
119 | (norm_mlp): LayerNorm((128,), eps=0.001, elementwise_affine=True)
120 | (to_q): Linear(in_features=128, out_features=128, bias=True)
121 | (to_k): Linear(in_features=128, out_features=128, bias=True)
122 | (to_v): Linear(in_features=128, out_features=128, bias=True)
123 | (gru): GRUCell(128, 128)
124 | (mlp): Sequential(
125 | (0): Linear(in_features=128, out_features=256, bias=True)
126 | (1): ReLU()
127 | (2): Linear(in_features=256, out_features=128, bias=True)
128 | )
129 | )
130 |
--------------------------------------------------------------------------------
/src/lib/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Methods to manage parameters and configurations
3 | TODO:
4 | - Add support to change any values from the command line
5 | """
6 |
7 | import os
8 | import json
9 |
10 | from lib.logger import print_
11 | from lib.utils import timestamp
12 | from CONFIG import DEFAULTS, CONFIG
13 |
14 |
15 | class Config(dict):
16 | """
17 | """
18 | _default_values = DEFAULTS
19 |     _help = "Comments describing the purpose of each configuration value can be added here"
20 | _config_groups = ["dataset", "model", "training", "loss"]
21 |
22 | def __init__(self, exp_path):
23 | """
24 | Populating the dictionary with the default values
25 | """
26 | for key in self._default_values.keys():
27 | self[key] = self._default_values[key]
28 | self["_general"] = {}
29 | self["_general"]["exp_path"] = exp_path
30 | return
31 |
32 | def create_exp_config_file(self, exp_path=None, config=None):
33 | """
34 | Creating a JSON file with exp configs in the experiment path
35 | """
36 | exp_path = exp_path if exp_path is not None else self["_general"]["exp_path"]
37 | if not os.path.exists(exp_path):
38 | raise FileNotFoundError(f"ERROR!: exp_path {exp_path} does not exist...")
39 |
40 | if config is not None:
41 | config_file = os.path.join(CONFIG["paths"]["configs_path"], config)
42 | if not os.path.exists(config_file):
43 | raise FileNotFoundError(f"Given config file {config_file} does not exist...")
44 |
45 | with open(config_file) as file:
46 | self = json.load(file)
47 | self["_general"] = {}
48 | self["_general"]["exp_path"] = exp_path
49 | print_(f"Creating experiment parameters file from config {config}...")
50 |
51 | self["_general"]["created_time"] = timestamp()
52 | self["_general"]["last_loaded"] = timestamp()
53 | exp_config = os.path.join(exp_path, "experiment_params.json")
54 | with open(exp_config, "w") as file:
55 | json.dump(self, file)
56 | return
57 |
58 | def load_exp_config_file(self, exp_path=None):
59 | """
60 | Loading the JSON file with exp configs
61 | """
62 | if exp_path is not None:
63 | self["_general"]["exp_path"] = exp_path
64 | exp_config = os.path.join(self["_general"]["exp_path"], "experiment_params.json")
65 | if not os.path.exists(exp_config):
66 | raise FileNotFoundError(f"ERROR! exp. configs file {exp_config} does not exist...")
67 |
68 | with open(exp_config) as file:
69 | self = json.load(file)
70 | self["_general"]["last_loaded"] = timestamp()
71 | return self
72 |
73 | def update_config(self, exp_params):
74 | """
75 |         Updating an experiment's parameters file with newly added configurations from CONFIG.
76 | """
77 | # TODO: Add recursion to make it always work
78 | for group in Config._config_groups:
79 | if not isinstance(Config._default_values[group], dict):
80 | continue
81 | for k in Config._default_values[group].keys():
82 | if(k not in exp_params[group]):
83 | if(isinstance(Config._default_values[group][k], (dict))):
84 | exp_params[group][k] = {}
85 | else:
86 | exp_params[group][k] = Config._default_values[group][k]
87 |
88 | if(isinstance(Config._default_values[group][k], dict)):
89 | for q in Config._default_values[group][k].keys():
90 | if(q not in exp_params[group][k]):
91 | exp_params[group][k][q] = Config._default_values[group][k][q]
92 | return exp_params
93 |
94 | def save_exp_config_file(self, exp_path=None, exp_params=None):
95 | """
96 | Dumping experiment parameters into path
97 | """
98 | exp_path = self["_general"]["exp_path"] if exp_path is None else exp_path
99 | exp_params = self if exp_params is None else exp_params
100 |
101 | exp_config = os.path.join(exp_path, "experiment_params.json")
102 | with open(exp_config, "w") as file:
103 | json.dump(exp_params, file)
104 | return
105 |
106 | #
107 |
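108 | if __name__ == "__main__":
109 |     # Hedged usage sketch of the typical Config workflow, run with the src/ directory on
110 |     # PYTHONPATH and using a temporary directory as a stand-in for a real experiment directory
111 |     import tempfile
112 |     with tempfile.TemporaryDirectory() as tmp_exp_path:
113 |         cfg = Config(exp_path=tmp_exp_path)
114 |         cfg.create_exp_config_file()             # writes experiment_params.json filled with DEFAULTS
115 |         exp_params = cfg.load_exp_config_file()  # reads the parameters back as a dictionary
116 |         cfg.save_exp_config_file(exp_path=tmp_exp_path, exp_params=exp_params)
117 |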
--------------------------------------------------------------------------------
/src/06_generate_figs_savi.py:
--------------------------------------------------------------------------------
1 | """
2 | Generating figures using a pretrained SAVI model
3 | """
4 |
5 | import os
6 |
7 | import matplotlib.pyplot as plt
8 | import torch
9 |
10 | from base.baseFigGenerator import BaseFigGenerator
11 |
12 | from data import unwrap_batch_data
13 | from lib.arguments import get_generate_figs_savi
14 | from lib.logger import print_
15 | import lib.utils as utils
16 | from lib.visualizations import visualize_recons, visualize_decomp
17 |
18 |
19 | class FigGenerator(BaseFigGenerator):
20 | """
21 | Class for generating figures using a pretrained SAVI model
22 | """
23 |
24 | def __init__(self, exp_path, savi_model, num_seqs=10):
25 | """
26 | Initializing the figure generation module
27 | """
28 | super().__init__(
29 | exp_path=exp_path,
30 | savi_model=savi_model,
31 | num_seqs=num_seqs
32 | )
33 |
34 | model_name = savi_model.split('.')[0]
35 | self.plots_path = os.path.join(
36 | self.exp_path,
37 | "plots",
38 | f"figGeneration_SaVIModel_{model_name}"
39 | )
40 | self.models_path = os.path.join(self.exp_path, "models")
41 | utils.create_directory(self.plots_path)
42 | return
43 |
44 | @torch.no_grad()
45 | def compute_visualization(self, batch_data, img_idx):
46 | """
47 | Computing visualization
48 | """
49 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data)
50 | videos, targets = videos.to(self.device), targets.to(self.device)
51 | out_model = self.model(videos, num_imgs=videos.shape[1], **initializer_kwargs)
52 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model
53 |
54 | cur_dir = f"sequence_{img_idx:02d}"
55 | utils.create_directory(os.path.join(self.plots_path, cur_dir))
56 |
57 | N = min(10, videos.shape[1])
58 | savepath = os.path.join(self.plots_path, cur_dir, f"Recons_{img_idx+1}.png")
59 | visualize_recons(
60 | imgs=videos[0, :N].clamp(0, 1),
61 | recons=reconstruction_history[0, :N].clamp(0, 1),
62 | n_cols=10,
63 | savepath=savepath
64 | )
65 |
66 | savepath = os.path.join(self.plots_path, cur_dir, f"ReconsTargets_{img_idx+1}.png")
67 | visualize_recons(
68 | imgs=targets[0, :N].clamp(0, 1),
69 | recons=reconstruction_history[0, :N].clamp(0, 1),
70 | n_cols=10,
71 | savepath=savepath
72 | )
73 |
74 | savepath = os.path.join(self.plots_path, cur_dir, f"Objects_{img_idx+1}.png")
75 | fig, _, _ = visualize_decomp(
76 | individual_recons_history[0, :N],
77 | savepath=savepath,
78 | vmin=0,
79 | vmax=1,
80 | )
81 | plt.close(fig)
82 |
83 | savepath = os.path.join(self.plots_path, cur_dir, f"masks_{img_idx+1}.png")
84 | fig, _, _ = visualize_decomp(
85 | masks_history[0][:N],
86 | savepath=savepath,
87 | cmap="gray_r",
88 | vmin=0,
89 | vmax=1,
90 | )
91 | plt.close(fig)
92 | savepath = os.path.join(self.plots_path, cur_dir, f"maskedObj_{img_idx+1}.png")
93 | recon_combined = masks_history[0][:N] * individual_recons_history[0][:N]
94 | recon_combined = torch.clamp(recon_combined, min=0, max=1)
95 | fig, _, _ = visualize_decomp(
96 | recon_combined,
97 | savepath=savepath,
98 | vmin=0,
99 | vmax=1,
100 | )
101 | plt.close(fig)
102 | return
103 |
104 |
105 | if __name__ == "__main__":
106 | utils.clear_cmd()
107 | exp_path, args = get_generate_figs_savi()
108 | print_("Generating figures for SAVI...")
109 | figGenerator = FigGenerator(
110 | exp_path=exp_path,
111 | savi_model=args.savi_model,
112 | num_seqs=args.num_seqs
113 | )
114 | print_("Loading dataset...")
115 | figGenerator.load_data()
116 | print_("Setting up model and loading pretrained parameters")
117 | figGenerator.load_model()
118 | print_("Generating and saving figures")
119 | figGenerator.generate_figs()
120 |
121 |
122 | #
123 |
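124 | # Hedged usage sketch: the script is typically launched from the src/ directory with a
125 | # command along these lines. The exact flag names are defined in lib/arguments.py and may
126 | # differ; the ones below are illustrative only:
127 | #     python 06_generate_figs_savi.py --exp_directory experiments/Obj3D \
128 | #         --savi_model savi_checkpoint.pth --num_seqs 10
129 |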
--------------------------------------------------------------------------------
/src/lib/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Utils for creating a Logger that writes info, warnings and so on into a logs file
3 | in the experiment directory. Each experiment has its own independent logs
4 | """
5 |
6 | import os
7 | import traceback
8 | from datetime import datetime
9 |
10 | LOGGER = None
11 |
12 |
13 | def log_function(func):
14 | """
15 |     Decorator for logging a method call and any exception that it raises
16 | """
17 | def try_call_log(*args, **kwargs):
18 | """
19 |         Calling the function and logging the error message in case an exception is raised
20 | """
21 | try:
22 | if(LOGGER is not None):
23 | message = f"Calling: {func.__name__}..."
24 | LOGGER.log_info(message=message, message_type="info")
25 | return func(*args, **kwargs)
26 | except Exception as e:
27 | if(LOGGER is None):
28 | raise e
29 | message = traceback.format_exc()
30 | print_(message, message_type="error")
31 | exit()
32 | return try_call_log
33 |
34 |
35 | def for_all_methods(decorator):
36 | """
37 | Decorator that applies a decorator to all methods inside a class
38 | """
39 | def decorate(cls):
40 |         for attr in cls.__dict__: # there's probably a better way to do this
41 | if callable(getattr(cls, attr)):
42 | setattr(cls, attr, decorator(getattr(cls, attr)))
43 | return cls
44 | return decorate
45 |
46 |
47 | def print_(message, message_type="info"):
48 | """
49 |     Overloads the print method so that the message is written both to the logs file and to the console
50 | """
51 |
52 | print(message)
53 | if(LOGGER is not None):
54 | LOGGER.log_info(message, message_type)
55 | return
56 |
57 |
58 | def log_info(message, message_type="info"):
59 | if(LOGGER is not None):
60 | LOGGER.log_info(message, message_type)
61 | return
62 |
63 |
64 | class Logger():
65 | """
66 |     Class that instantiates a Logger object to write logs into a file
67 |
68 | Args:
69 | -----
70 | exp_path: string
71 | path to the root directory of an experiment where the logs are saved
72 | file_name: string
73 | name of the file where logs are stored
74 | """
75 |
76 | def __init__(self, exp_path, file_name="logs.txt"):
77 | """
78 | Initializer of the logger object
79 | """
80 |
81 | logs_path = os.path.join(exp_path, file_name)
82 | self.logs_path = logs_path
83 |
84 | if not os.path.exists(logs_path):
85 | if(not os.path.exists(exp_path)):
86 | os.makedirs(exp_path)
87 | with open(logs_path, 'w') as f:
88 | f.write("")
89 |
90 | global LOGGER
91 | LOGGER = self
92 | return
93 |
94 | def log_info(self, message, message_type="info", **kwargs):
95 | """
96 | Logging a message into the file
97 | """
98 |
99 | if(message_type not in ["new_exp", "info", "warning", "error", "params"]):
100 | message_type = "info"
101 | cur_time = self._get_datetime()
102 | format_message = self._format_message(message=message, cur_time=cur_time,
103 | message_type=message_type)
104 | with open(self.logs_path, 'a') as f:
105 | f.write(format_message)
106 |
107 | if(message_type == "error"):
108 | exit()
109 |
110 | return
111 |
112 | def log_params(self, params):
113 | """
114 |         Logging parameters in a visually appealing format
115 | Args:
116 | -----
117 | params: dictionary
118 | dictionary containing parameters and values
119 | """
120 |
121 | for param, value in params.items():
122 | message = f" {param}:{value}"
123 | self.log_info(message, message_type="params")
124 |
125 | return
126 |
127 | def _format_message(self, message, cur_time, message_type="info"):
128 | """
129 |         Formatting the message to follow a standardized template
130 | """
131 | pre_string = ""
132 | if(message_type == "new_exp"):
133 | pre_string = "\n\n\n"
134 | form_message = f"{pre_string}{cur_time} {message_type.upper()}: {message}\n"
135 | return form_message
136 |
137 | def _get_datetime(self):
138 | """
139 |         Obtaining the current date and time in the format YYYY-MM-DD-HH:MM:SS
140 | """
141 | time = datetime.today().strftime('%Y-%m-%d-%H:%M:%S')
142 | return time
143 |
144 | #
145 |
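146 | if __name__ == "__main__":
147 |     # Hedged usage sketch: instantiating the Logger registers it as the global LOGGER,
148 |     # after which print_ writes to both the console and the logs file. The experiment
149 |     # path below is hypothetical; any writable directory works.
150 |     logger = Logger(exp_path="/tmp/example_experiment")
151 |     print_("Starting a new experiment", message_type="new_exp")
152 |     print_("This message goes to the console and to /tmp/example_experiment/logs.txt")
153 |     logger.log_params({"batch_size": 64, "lr": 1e-4})
154 |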
--------------------------------------------------------------------------------
/src/data/load_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Methods for loading specific datasets, fitting data loaders, and other data-related utilities
3 | """
4 |
5 | # from torchvision import datasets
6 | from torch.utils.data import DataLoader
7 | from data import OBJ3D, MOVI
8 | from CONFIG import CONFIG, DATASETS
9 |
10 |
11 | def load_data(exp_params, split="train"):
12 | """
13 | Loading a dataset given the parameters
14 |
15 | Args:
16 | -----
17 |     exp_params: dictionary
18 |         Experiment parameters, including the name of the dataset to load
19 | split: string
20 | Split from the dataset to obtain (e.g., 'train' or 'test')
21 |
22 | Returns:
23 | --------
24 | dataset: torch dataset
25 | Dataset loaded given specifications from exp_params
26 | """
27 | dataset_name = exp_params["dataset"]["dataset_name"]
28 |
29 | if dataset_name == "OBJ3D":
30 | dataset = OBJ3D(
31 | mode=split,
32 | sample_length=exp_params["training_prediction"]["sample_length"]
33 | )
34 | elif dataset_name == "MoviA":
35 | dataset = MOVI(
36 | datapath="/home/nfs/inf6/data/datasets/MOVi/movi_a",
37 | target=exp_params["dataset"].get("target", "rgb"),
38 | split=split,
39 | num_frames=exp_params["training_prediction"]["sample_length"],
40 | img_size=(64, 64),
41 | random_start=exp_params["dataset"].get("random_start", False),
42 | slot_initializer=exp_params["model"]["SAVi"].get("initializer", "LearnedRandom")
43 | )
44 | elif dataset_name == "MoviC":
45 | dataset = MOVI(
46 | datapath="/home/nfs/inf6/data/datasets/MOVi/movi_c",
47 | target=exp_params["dataset"].get("target", "rgb"),
48 | split=split,
49 | num_frames=exp_params["training_prediction"]["sample_length"],
50 | img_size=(64, 64),
51 | random_start=exp_params["dataset"].get("random_start", False),
52 | slot_initializer=exp_params["model"]["SAVi"].get("initializer", "LearnedRandom")
53 | )
54 | else:
55 | raise NotImplementedError(
56 | f"""ERROR! Dataset'{dataset_name}' is not available.
57 | Please use one of the following: {DATASETS}..."""
58 | )
59 |
60 | return dataset
61 |
62 |
63 | def build_data_loader(dataset, batch_size=8, shuffle=False):
64 | """
65 | Fitting a data loader for the given dataset
66 |
67 | Args:
68 | -----
69 | dataset: torch dataset
70 | Dataset (or dataset split) to fit to the DataLoader
71 | batch_size: integer
72 | number of elements per mini-batch
73 | shuffle: boolean
74 |         If True, mini-batches are sampled randomly from the dataset
75 | """
76 |
77 | data_loader = DataLoader(
78 | dataset=dataset,
79 | batch_size=batch_size,
80 | shuffle=shuffle,
81 | num_workers=CONFIG["num_workers"]
82 | )
83 |
84 | return data_loader
85 |
86 |
87 | def unwrap_batch_data(exp_params, batch_data):
88 | """
89 | Unwrapping the batch data depending on the dataset that we are training on
90 | """
91 | initializer_kwargs = {}
92 | if exp_params["dataset"]["dataset_name"] in ["OBJ3D"]:
93 | videos, targets, _ = batch_data
94 | elif exp_params["dataset"]["dataset_name"] in ["MoviA", "MoviC"]:
95 | videos, targets, all_reps = batch_data
96 | initializer_kwargs["instance_masks"] = all_reps["masks"]
97 | initializer_kwargs["com_coords"] = all_reps["com_coords"]
98 | initializer_kwargs["bbox_coords"] = all_reps["bbox_coords"]
99 | else:
100 | dataset_name = exp_params["dataset"]["dataset_name"]
101 | raise NotImplementedError(f"Dataset {dataset_name} is not supported...")
102 | return videos, targets, initializer_kwargs
103 |
104 |
105 | def unwrap_batch_data_masks(exp_params, batch_data):
106 | """
107 | Unwrapping the batch data for a mask-based evaluation depending on the dataset that
108 | we are currently evaluating on
109 | """
110 | dbs = ["MoviA", "MoviC"]
111 | dataset_name = exp_params["dataset"]["dataset_name"]
112 | initializer_kwargs = {}
113 |     if dataset_name in dbs:
114 | videos, _, all_reps = batch_data
115 | masks = all_reps["masks"]
116 | initializer_kwargs["instance_masks"] = all_reps["masks"]
117 | initializer_kwargs["com_coords"] = all_reps["com_coords"]
118 | initializer_kwargs["bbox_coords"] = all_reps["bbox_coords"]
119 | else:
120 | raise ValueError(f"Only {dbs} support object-based mask evaluation. Given {dataset_name = }")
121 | return videos, masks, initializer_kwargs
122 |
123 |
124 | #
125 |
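126 | if __name__ == "__main__":
127 |     # Hedged usage sketch: loading the OBJ3D test split and fitting a data loader, assuming
128 |     # the src/ directory is on PYTHONPATH and the OBJ3D data has been downloaded. Only the
129 |     # keys accessed above are filled in; a real experiment_params.json contains many more entries.
130 |     exp_params = {
131 |         "dataset": {"dataset_name": "OBJ3D"},
132 |         "training_prediction": {"sample_length": 10},
133 |     }
134 |     test_set = load_data(exp_params=exp_params, split="test")
135 |     test_loader = build_data_loader(dataset=test_set, batch_size=8, shuffle=False)
136 |     videos, targets, _ = next(iter(test_loader))
137 |     print(videos.shape)  # expected shape: (8, 10, 3, 64, 64)
138 |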
--------------------------------------------------------------------------------
/src/base/baseEvaluator.py:
--------------------------------------------------------------------------------
1 | """
2 | Base evaluator from which all backbone evaluator modules inherit.
3 |
4 | Basically, it removes the scaffolding that is repeated across all evaluation modules
5 | """
6 |
7 | import os
8 | from tqdm import tqdm
9 | import torch
10 |
11 | from lib.config import Config
12 | from lib.logger import log_function, for_all_methods
13 | from lib.metrics import MetricTracker
14 | import lib.setup_model as setup_model
15 | import lib.utils as utils
16 | import data
17 |
18 |
19 | @for_all_methods(log_function)
20 | class BaseEvaluator:
21 | """
22 | Base Class for evaluating a model
23 |
24 | Args:
25 | -----
26 | exp_path: string
27 | Path to the experiment directory from which to read the experiment parameters,
28 | and where to store logs, plots and checkpoints
29 | checkpoint: string/None
30 | Name of a model checkpoint to evaluate.
31 | It must be stored in the models/ directory of the experiment directory.
32 | """
33 |
34 | def __init__(self, exp_path, checkpoint):
35 | """
36 |         Initializing the evaluator object
37 | """
38 | self.exp_path = exp_path
39 | self.cfg = Config(exp_path)
40 | self.exp_params = self.cfg.load_exp_config_file()
41 | self.checkpoint = checkpoint
42 | model_name = checkpoint.split(".")[0]
43 | self.results_name = f"{model_name}"
44 |
45 | self.plots_path = os.path.join(self.exp_path, "plots")
46 | utils.create_directory(self.plots_path)
47 | self.models_path = os.path.join(self.exp_path, "models")
48 | utils.create_directory(self.models_path)
49 | return
50 |
51 | def set_metric_tracker(self):
52 | """
53 | Initializing the metric tracker with evaluation metrics to track
54 | """
55 | self.metric_tracker = MetricTracker(
56 | self.exp_path,
57 | metrics=["segmentation_ari"]
58 | )
59 |
60 | def load_data(self):
61 | """
62 | Loading test-set and fitting data-loader for iterating in a batch-like fashion
63 | """
64 | batch_size = 1 # self.exp_params["training"]["batch_size"]
65 | shuffle_eval = self.exp_params["dataset"]["shuffle_eval"]
66 | self.test_set = data.load_data(
67 | exp_params=self.exp_params,
68 | split="test"
69 | )
70 | self.test_loader = data.build_data_loader(
71 | dataset=self.test_set,
72 | batch_size=batch_size,
73 | shuffle=shuffle_eval
74 | )
75 | return
76 |
77 | def setup_model(self):
78 | """
79 | Initializing model and loading pretrained parameters given checkpoint
80 | """
81 | torch.backends.cudnn.fastest = True
82 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83 |
84 | # loading model
85 | self.model = setup_model.setup_model(model_params=self.exp_params["model"])
86 | self.model = self.model.eval().to(self.device)
87 |
88 |         # loading pretrained parameters
89 | checkpoint_path = os.path.join(self.models_path, self.checkpoint)
90 | self.model = setup_model.load_checkpoint(
91 | checkpoint_path=checkpoint_path,
92 | model=self.model,
93 | only_model=True
94 | )
95 | self.set_metric_tracker()
96 | return
97 |
98 | @torch.no_grad()
99 | def evaluate(self, save_results=True):
100 | """
101 | Evaluating model
102 | """
103 | self.model = self.model.eval()
104 | progress_bar = tqdm(enumerate(self.test_loader), total=len(self.test_loader))
105 |
106 | # iterating test set and accumulating the results
107 | for i, batch_data in progress_bar:
108 | self.forward_eval(batch_data=batch_data)
109 | progress_bar.set_description(f"Iter {i}/{len(self.test_loader)}")
110 |
111 | # computing average results and saving to results file
112 | self.metric_tracker.aggregate()
113 | self.results = self.metric_tracker.summary()
114 | if save_results:
115 | self.metric_tracker.save_results(exp_path=self.exp_path, fname=self.results_name)
116 | return
117 |
118 | def forward_eval(self, batch_data, **kwargs):
119 | """
120 |         Making a forward pass through the model and computing the evaluation metrics
121 |
122 | Args:
123 | -----
124 | batch_data: dict
125 | Dictionary containing the information for the current batch, including images, poses,
126 | actions, or metadata, among others.
127 |
128 | Returns:
129 | --------
130 | pred_data: dict
131 | Predictions from the model for the current batch of data
132 | """
133 | raise NotImplementedError("Base Evaluator Module does not implement 'forward_eval'...")
134 |
135 | #
136 |
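137 | # Hedged usage sketch: a concrete evaluator (e.g., the SAVi evaluator in 03_evaluate_savi.py)
138 | # subclasses BaseEvaluator, implements 'forward_eval', and is then driven roughly as follows
139 | # (class and checkpoint names below are illustrative only):
140 | #     evaluator = MySaviEvaluator(exp_path="experiments/Obj3D", checkpoint="checkpoint_last.pth")
141 | #     evaluator.load_data()      # fits the test set and data loader
142 | #     evaluator.setup_model()    # builds the model and loads the pretrained checkpoint
143 | #     evaluator.evaluate()       # iterates the test set and saves the aggregated metrics
144 |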
--------------------------------------------------------------------------------
/src/base/baseFigGenerator.py:
--------------------------------------------------------------------------------
1 | """
2 | Base module for generating images from a pretrained model.
3 | All other Figure Generator modules inherit from here.
4 |
5 | Basically, it removes the scaffolding that is repeated across all Figure Generation modules
6 | """
7 |
8 | import os
9 | from tqdm import tqdm
10 | import torch
11 |
12 | from lib.config import Config
13 | from lib.logger import log_function, for_all_methods
14 | import lib.setup_model as setup_model
15 | import lib.utils as utils
16 | import data
17 | from models.model_utils import freeze_params
18 |
19 |
20 | @for_all_methods(log_function)
21 | class BaseFigGenerator:
22 | """
23 | Base Class for figure generation
24 |
25 | Args:
26 | -----
27 | exp_path: string
28 | Path to the experiment directory from which to read the experiment parameters,
29 | and where to store logs, plots and checkpoints
30 | savi_model: string/None
31 | Name of SAVI model checkpoint to use when generating figures.
32 | It must be stored in the models/ directory of the experiment directory.
33 | num_seqs: int
34 | Number of sequences to process and save
35 | """
36 |
37 | def __init__(self, exp_path, savi_model, num_seqs=10):
38 | """
39 | Initializing the figure generator object
40 | """
41 | self.exp_path = os.path.join(exp_path)
42 | self.cfg = Config(self.exp_path)
43 | self.exp_params = self.cfg.load_exp_config_file()
44 | self.savi_model = savi_model
45 | self.num_seqs = num_seqs
46 |
47 | model_name = savi_model.split('.')[0]
48 | self.plots_path = os.path.join(
49 | self.exp_path,
50 | "plots",
51 | f"figGeneration_SaVIModel_{model_name}"
52 | )
53 | self.models_path = os.path.join(self.exp_path, "models")
54 | utils.create_directory(self.models_path)
55 | return
56 |
57 | def load_data(self):
58 | """
59 | Loading dataset and fitting data-loader for iterating in a batch-like fashion
60 | """
61 | batch_size = 1
62 | shuffle_eval = self.exp_params["dataset"]["shuffle_eval"]
63 | # test_set = data.load_data(exp_params=self.exp_params, split="test")
64 | test_set = data.load_data(exp_params=self.exp_params, split="valid")
65 | self.test_set = test_set
66 | self.test_loader = data.build_data_loader(
67 | dataset=test_set,
68 | batch_size=batch_size,
69 | shuffle=shuffle_eval
70 | )
71 | return
72 |
73 | def load_model(self, exp_path=None):
74 | """
75 |         Load pretrained SAVi model from checkpoint
76 |
77 | Args:
78 | -----
79 |         exp_path: string/None
80 |             If None, 'self.exp_path' is used. Otherwise, the given path is used to load
81 |             the pretrained model parameters
82 | """
83 | # to use the same function for SAVi and Predictor figure generation
84 | exp_path = exp_path if exp_path is not None else self.exp_path
85 |
86 | torch.backends.cudnn.fastest = True
87 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88 |
89 | # loading model
90 | self.model = setup_model.setup_model(model_params=self.exp_params["model"])
91 | self.model = self.model.eval().to(self.device)
92 |
93 | checkpoint_path = os.path.join(exp_path, "models", self.savi_model)
94 | self.model = setup_model.load_checkpoint(
95 | checkpoint_path=checkpoint_path,
96 | model=self.model,
97 | only_model=True
98 | )
99 | freeze_params(self.model)
100 | return
101 |
102 | def load_predictor(self):
103 | """
104 | Load pretrained predictor model from the corresponding model checkpoint
105 | """
106 | # loading model
107 | predictor = setup_model.setup_predictor(exp_params=self.exp_params)
108 | predictor = predictor.eval().to(self.device)
109 |
110 | # loading pretrained predictor
111 | predictor = setup_model.load_checkpoint(
112 | checkpoint_path=os.path.join(self.models_path, self.checkpoint),
113 | model=predictor,
114 | only_model=True,
115 | )
116 | self.predictor = predictor
117 | return
118 |
119 | @torch.no_grad()
120 | def generate_figs(self):
121 | """
122 | Computing and saving visualizations
123 | """
124 | progress_bar = tqdm(enumerate(self.test_loader), total=self.num_seqs)
125 | for i, batch_data in progress_bar:
126 | if i >= self.num_seqs:
127 | break
128 | self.compute_visualization(batch_data=batch_data, img_idx=i)
129 | return
130 |
131 | def compute_visualization(self, batch_data, img_idx, **kwargs):
132 | """
133 |         Making a forward pass through the model and computing and saving the visualization
134 |
135 | Args:
136 | -----
137 | batch_data: dict
138 | Dictionary containing the information for the current batch, including images, poses,
139 | actions, or metadata, among others.
140 | img_idx: int
141 | Index of the visualization to compute and save
142 | """
143 | raise NotImplementedError("Base FigGenerator does not implement 'compute_visualization'...")
144 |
145 | #
146 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Object-Centric Video Prediction via Decoupling of Object Dynamics and Interactions
2 |
3 |
4 |
5 |
6 |
7 |
8 |