├── datasets └── .gitkeep ├── assets ├── gif1_gt.gif ├── gif2_gt.gif ├── teaser.png ├── gif1_pred.gif ├── gif1_seg.gif ├── gif2_pred.gif ├── gif2_seg.gif ├── AllPredictors.png └── docs │ ├── INSTALL.md │ └── TRAIN.md ├── experiments ├── Obj3D │ ├── Predictor_LSTM │ │ ├── model_architecture.txt │ │ └── experiment_params.json │ ├── Predictor_Transformer │ │ ├── model_architecture.txt │ │ └── experiment_params.json │ ├── Predictor_OCVPPar │ │ ├── experiment_params.json │ │ └── model_architecture.txt │ ├── Predictor_OCVPSeq │ │ ├── experiment_params.json │ │ └── model_architecture.txt │ ├── experiment_params.json │ └── model_architecture.txt └── MOViA │ ├── Predictor_LSTM │ ├── model_architecture.txt │ └── experiment_params.json │ ├── Predictor_OCVPPar │ ├── experiment_params.json │ └── model_architecture.txt │ ├── Predictor_OCVPSeq │ ├── experiment_params.json │ └── model_architecture.txt │ ├── Predictor_Transformer │ └── experiment_params.json │ ├── experiment_params.json │ └── model_architecture.txt ├── src ├── data │ ├── __init__.py │ ├── data_utils.py │ ├── obj3d.py │ ├── load_data.py │ └── MoviConvert.py ├── configs │ ├── __init__.py │ └── README.md ├── models │ ├── __init__.py │ ├── initializers.py │ ├── model_utils.py │ └── model_blocks.py ├── 01_create_experiment.py ├── extract_movi_dataset.py ├── 01_create_predictor_experiment.py ├── 03_evaluate_savi_noMasks.py ├── 03_evaluate_savi.py ├── lib │ ├── config.py │ ├── logger.py │ ├── loss.py │ ├── utils.py │ └── schedulers.py ├── 06_generate_figs_savi.py ├── base │ ├── baseEvaluator.py │ ├── baseFigGenerator.py │ └── basePredictorEvaluator.py ├── CONFIG.py ├── 05_evaluate_predictor.py ├── 02_train_savi.py └── 04_train_predictor.py ├── download_pretrained.sh ├── download_obj3d.sh ├── README.md └── environment.yml /datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/gif1_gt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif1_gt.gif -------------------------------------------------------------------------------- /assets/gif2_gt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif2_gt.gif -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /assets/gif1_pred.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif1_pred.gif -------------------------------------------------------------------------------- /assets/gif1_seg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif1_seg.gif -------------------------------------------------------------------------------- /assets/gif2_pred.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif2_pred.gif -------------------------------------------------------------------------------- /assets/gif2_seg.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/gif2_seg.gif -------------------------------------------------------------------------------- /assets/AllPredictors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-Bonn/OCVP-object-centric-video-prediction/HEAD/assets/AllPredictors.png -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_LSTM/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 264192 2 | 3 | Params: 264192 4 | ModuleList( 5 | (0): LSTMCell(128, 128) 6 | (1): LSTMCell(128, 128) 7 | ) 8 | -------------------------------------------------------------------------------- /experiments/MOViA/Predictor_LSTM/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 264192 2 | 3 | Params: 264192 4 | LSTMPredictor( 5 | (lstm): ModuleList( 6 | (0): LSTMCell(128, 128) 7 | (1): LSTMCell(128, 128) 8 | ) 9 | ) -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Accessing datasets from script 3 | """ 4 | 5 | from .Movi import MOVI 6 | from .obj3d import OBJ3D 7 | 8 | from .load_data import load_data, build_data_loader, unwrap_batch_data, unwrap_batch_data_masks 9 | -------------------------------------------------------------------------------- /src/configs/__init__.py: -------------------------------------------------------------------------------- 1 | """ Configs """ 2 | 3 | import os 4 | from CONFIG import CONFIG 5 | 6 | 7 | def get_available_configs(): 8 | """ Getting a list with the name of the available config files """ 9 | config_path = CONFIG["paths"]["configs_path"] 10 | files = sorted(os.listdir(config_path)) 11 | available_configs = [f[:-5] for f in files if f[-5:] == ".json"] 12 | return available_configs 13 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Accesing models 3 | """ 4 | 5 | from .model_blocks import SoftPositionEmbed 6 | from .encoders_decoders import get_encoder, get_decoder, SimpleConvEncoder, DownsamplingConvEncoder 7 | from .initializers import get_initalizer 8 | 9 | from .attention import SlotAttention, MultiHeadSelfAttention, TransformerBlock 10 | from .SAVi import SAVi 11 | from .model_utils import freeze_params 12 | -------------------------------------------------------------------------------- /download_pretrained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir experiments 4 | 5 | wget https://www.dropbox.com/s/mbbae2yk7cexwig/MOViA.zip?dl=1 6 | unzip MOViA.zip?dl=1 7 | rm MOViA.zip?dl=1 8 | rsync -va --delete-after MOViA experiments 9 | rm -r MOViA 10 | 11 | wget https://www.dropbox.com/s/luuzfmo3v2ka2kb/Obj3D.zip?dl=1 12 | unzip Obj3D.zip?dl=1 13 | rm Obj3D.zip?dl=1 14 | rsync -va --delete-after Obj3D 
experiments 15 | rm -r Obj3D 16 | -------------------------------------------------------------------------------- /src/configs/README.md: -------------------------------------------------------------------------------- 1 | # Recommended Configs 2 | 3 | The files in this directory correspond to default ```experiment_parameters.json``` files for different datasets. 4 | 5 | These parameter values have been tested and lead to reasonable results. 6 | 7 | 8 | #### TODO 9 | 10 | - Add a command line argument to ```01_create_experiment.py``` to initialize the experiment's ```experiment_parameters.json``` with one of the configurations included here. 11 | -------------------------------------------------------------------------------- /download_obj3d.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1XSLW3qBtcxxvV-5oiRruVTlDlQ_Yatzm' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1XSLW3qBtcxxvV-5oiRruVTlDlQ_Yatzm" -O datasets/OBJ3D.zip && rm -rf /tmp/cookies.txt 3 | cd datasets && unzip OBJ3D.zip && rm OBJ3D.zip 4 | -------------------------------------------------------------------------------- /assets/docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | 1. Clone and enter the repository: 4 | ``` 5 | git clone git@github.com:AIS-Bonn/OCVP-object-centric-video-prediction.git 6 | cd OCVP-object-centric-video-prediction 7 | ``` 8 | 9 | 10 | 2. Install all required packages by creating the ```conda``` environment from the file included in the repository: 11 | ``` 12 | conda env create -f environment.yml 13 | conda activate OCVP 14 | ``` 15 | 16 | 17 | 3. Download the Obj3D and MOVi-A datasets, and place them under the `datasets` directory. The folder structure should look as follows: 18 | ``` 19 | OCVP 20 | ├── datasets/ 21 | | ├── Obj3D/ 22 | | └── MOViA/ 23 | ``` 24 | 25 | * **Obj3D:** Download and extract this dataset by running the following bash script: 26 | ``` 27 | chmod +x download_obj3d.sh 28 | ./download_obj3d.sh 29 | ``` 30 | 31 | * **MOViA:** Download the MOVi-A dataset to your local disk from [Google Cloud Storage](https://console.cloud.google.com/storage/browser/kubric-public/tfds), and preprocess the *TFRecord* files to extract the video frames and other required metadata by running the following commands: 32 | ``` 33 | gsutil -m cp -r gs://kubric-public/tfds/movi_a/128x128/ . 34 | mkdir movi_a 35 | mv 128x128/ movi_a/128x128/ 36 | python src/extract_movi_dataset.py 37 | ``` 38 | 39 | 40 | 41 | 4.
Download and extract the pretrained models, including checkpoints for the SAVi decomposition and prediction modules: 42 | ``` 43 | chmod +x download_pretrained.sh 44 | ./download_pretrained.sh 45 | ``` 46 | -------------------------------------------------------------------------------- /src/01_create_experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creating an experiment directory and initializing it with defaults 3 | """ 4 | 5 | import os 6 | from lib.arguments import create_experiment_arguments 7 | from lib.config import Config 8 | from lib.logger import Logger, print_ 9 | from lib.utils import create_directory, delete_directory, timestamp, clear_cmd 10 | 11 | from CONFIG import CONFIG 12 | 13 | 14 | def initialize_experiment(): 15 | """ 16 | Creating an experiment directory and initializing it with defaults 17 | """ 18 | # reading command line args 19 | args = create_experiment_arguments() 20 | exp_dir, config, exp_name = args.exp_directory, args.config, args.name 21 | exp_name = f"experiment_{timestamp()}" if exp_name is None or len(exp_name) < 1 else exp_name 22 | exp_path = os.path.join(CONFIG["paths"]["experiments_path"], exp_dir, exp_name) 23 | 24 | # creating directories 25 | create_directory(exp_path) 26 | _ = Logger(exp_path) # initialize logger once exp_dir is created 27 | create_directory(dir_path=exp_path, dir_name="plots") 28 | create_directory(dir_path=exp_path, dir_name="tboard_logs") 29 | 30 | try: 31 | cfg = Config(exp_path=exp_path) 32 | cfg.create_exp_config_file(config=config) 33 | except FileNotFoundError as e: 34 | print_("An error has occurred...\n Removing experiment directory") 35 | delete_directory(dir_path=exp_path) 36 | print(e) 37 | 38 | return 39 | 40 | 41 | if __name__ == "__main__": 42 | clear_cmd() 43 | initialize_experiment() 44 | 45 | # 46 | -------------------------------------------------------------------------------- /src/extract_movi_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converting the MOVi dataset from its original TFRecord format into images and metadata files, so that 3 | they can later be loaded without requiring TensorFlow or triggering its many errors and warnings 4 | """ 5 | 6 | import os 7 | import torch 8 | import torchvision 9 | from tqdm import tqdm 10 | from data.MoviConvert import _MoviA 11 | import lib.utils as utils 12 | from CONFIG import CONFIG 13 | 14 | 15 | PATH = os.path.join(CONFIG["paths"]["data_path"], "movi_a") 16 | 17 | 18 | def process_dataset(split="train"): 19 | """ Extracting and saving the frames, flow, masks and object coordinates of a dataset split """ 20 | data_path = os.path.join(PATH, split) 21 | utils.create_directory(data_path) 22 | 23 | db = _MoviA(split=split, num_frames=100, img_size=(128, 128)) 24 | db.get_masks = True 25 | db.get_bbox = True 26 | print(f" --> {len(db) = }") 27 | 28 | # iterating and saving data 29 | for i in tqdm(range(len(db))): 30 | imgs, all_preds = db[i] 31 | bbox = all_preds["bbox_coords"] 32 | com = all_preds["com_coords"] 33 | masks = all_preds["masks"] 34 | flow = all_preds["flow"] 35 | for j in range(imgs.shape[0]): 36 | torchvision.utils.save_image(flow[j], fp=os.path.join(data_path, f"flow_{i:05d}_{j:02d}.png")) 37 | torchvision.utils.save_image(imgs[j], fp=os.path.join(data_path, f"rgb_{i:05d}_{j:02d}.png")) 38 | torch.save({"com": com, "bbox": bbox}, os.path.join(data_path, f"coords_{i:05d}.pt")) 39 | torch.save({"masks": masks}, os.path.join(data_path, f"mask_{i:05d}.pt")) 40 | 41 | return 42 | 43 | 44 | if __name__ == '__main__': 45 | os.system("clear") 46 | 
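# Descriptive comment summarizing what process_dataset() above writes for every episode i of a split,
# inside <data_path>/movi_a/<split>/:
#   - rgb_{i:05d}_{j:02d}.png and flow_{i:05d}_{j:02d}.png for each frame j
#   - coords_{i:05d}.pt with the per-episode "bbox" and "com" object coordinates
#   - mask_{i:05d}.pt with the instance segmentation masks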
print("Processing Validation set") 47 | process_dataset(split="validation") 48 | print("Processing Training set") 49 | process_dataset(split="train") 50 | print("Finished") 51 | -------------------------------------------------------------------------------- /src/data/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some data processing and other utils 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def get_slots_stats(seq, masks): 10 | """ 11 | Obtaining stats about the number of slots in a video sequence 12 | 13 | Args: 14 | ----- 15 | seq: torch tensor 16 | Sequence of images. Shape is (N_frames, N_channels, H, W) 17 | masks: torch Tensor 18 | Instance segmentation masks. Shape is (N_frames, 1, H, W) 19 | """ 20 | total_num_slots = len(torch.unique(masks)) 21 | slot_dist = [len(torch.unique(m)) for m in masks] 22 | 23 | stats = { 24 | "total_num_slots": total_num_slots, 25 | "slot_dist": slot_dist, 26 | "max_num_slots": np.max(slot_dist), 27 | "min_num_slots": np.min(slot_dist) 28 | } 29 | return stats 30 | 31 | 32 | def masks_to_boxes(masks): 33 | """ 34 | Converting a binary segmentation mask into a bounding box 35 | 36 | Args: 37 | ----- 38 | masks: torch Tensor 39 | Segmentation masks. Shape is (n_imgs, n_objs, H, W) 40 | 41 | Returns: 42 | -------- 43 | bboxes: torch Tensor 44 | Bounding boxes corresponding the input segmentation masks in format [x1, y1, x2, y2]. 45 | Shape is (n_imgs, 4) 46 | """ 47 | assert masks.unique().tolist() == [0, 1] 48 | 49 | bboxes = torch.zeros(masks.shape[0], 4) 50 | for i, mask in enumerate(masks): 51 | if mask.max() == 0: 52 | bboxes[i] = torch.ones(4) * -1 53 | continue 54 | vertical_indices = torch.where(torch.any(mask, axis=1))[0] 55 | horizontal_indices = torch.where(torch.any(mask, axis=0))[0] 56 | if horizontal_indices.shape[0]: 57 | x1, x2 = horizontal_indices[[0, -1]] 58 | y1, y2 = vertical_indices[[0, -1]] 59 | else: 60 | bboxes[i] = torch.ones(4) * -1 61 | continue 62 | bboxes[i] = torch.tensor([x1, y1, x2, y2]) 63 | return bboxes.to(masks.device) 64 | 65 | # 66 | -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_Transformer/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 562944 2 | 3 | Params: 16512 4 | Linear(in_features=128, out_features=128, bias=True) 5 | 6 | Params: 16512 7 | Linear(in_features=128, out_features=128, bias=True) 8 | 9 | Params: 529920 10 | Sequential( 11 | (0): TransformerEncoderLayer( 12 | (self_attn): MultiheadAttention( 13 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 14 | ) 15 | (linear1): Linear(in_features=128, out_features=256, bias=True) 16 | (dropout): Dropout(p=0.1, inplace=False) 17 | (linear2): Linear(in_features=256, out_features=128, bias=True) 18 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 19 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 20 | (dropout1): Dropout(p=0.1, inplace=False) 21 | (dropout2): Dropout(p=0.1, inplace=False) 22 | ) 23 | (1): TransformerEncoderLayer( 24 | (self_attn): MultiheadAttention( 25 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 26 | ) 27 | (linear1): Linear(in_features=128, out_features=256, bias=True) 28 | (dropout): Dropout(p=0.1, inplace=False) 29 | (linear2): Linear(in_features=256, out_features=128, bias=True) 30 | (norm1): 
LayerNorm((128,), eps=1e-05, elementwise_affine=True) 31 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 32 | (dropout1): Dropout(p=0.1, inplace=False) 33 | (dropout2): Dropout(p=0.1, inplace=False) 34 | ) 35 | (2): TransformerEncoderLayer( 36 | (self_attn): MultiheadAttention( 37 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 38 | ) 39 | (linear1): Linear(in_features=128, out_features=256, bias=True) 40 | (dropout): Dropout(p=0.1, inplace=False) 41 | (linear2): Linear(in_features=256, out_features=128, bias=True) 42 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 43 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 44 | (dropout1): Dropout(p=0.1, inplace=False) 45 | (dropout2): Dropout(p=0.1, inplace=False) 46 | ) 47 | (3): TransformerEncoderLayer( 48 | (self_attn): MultiheadAttention( 49 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 50 | ) 51 | (linear1): Linear(in_features=128, out_features=256, bias=True) 52 | (dropout): Dropout(p=0.1, inplace=False) 53 | (linear2): Linear(in_features=256, out_features=128, bias=True) 54 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 55 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 56 | (dropout1): Dropout(p=0.1, inplace=False) 57 | (dropout2): Dropout(p=0.1, inplace=False) 58 | ) 59 | ) -------------------------------------------------------------------------------- /experiments/MOViA/Predictor_LSTM/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "MoviA", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 11, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 128, 19 | "mlp_hidden": 256, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 3, 23 | "num_iterations": 1, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "BBox" 31 | }, 32 | "predictor": { 33 | "predictor_name": "LSTM", 34 | "LSTM": { 35 | "num_cells": 2, 36 | "hidden_dim": 128, 37 | "residual": true 38 | } 39 | } 40 | }, 41 | "loss": [ 42 | { 43 | "type": "mse", 44 | "weight": 1 45 | } 46 | ], 47 | "predictor_loss": [ 48 | { 49 | "type": "pred_img_mse", 50 | "weight": 1 51 | }, 52 | { 53 | "type": "pred_slot_mse", 54 | "weight": 1 55 | } 56 | ], 57 | "training_slots": { 58 | "num_epochs": 1500, 59 | "save_frequency": 10, 60 | "log_frequency": 25, 61 | "image_log_frequency": 100, 62 | "batch_size": 64, 63 | "lr": 0.0002, 64 | "optimizer": "adam", 65 | "momentum": 0, 66 | "weight_decay": 0, 67 | "nesterov": false, 68 | "scheduler": "cosine_annealing", 69 | "lr_factor": 0.5, 70 | "patience": 10, 71 | "scheduler_steps": 400000, 72 | "lr_warmup": true, 73 | "warmup_steps": 2500, 74 | "warmup_epochs": 200, 75 | "gradient_clipping": true, 76 | "clipping_max_value": 0.05 77 | }, 78 | "training_prediction": { 79 | "num_context": 6, 80 | "num_preds": 8, 81 | "teacher_force": false, 82 | "skip_first_slot": false, 83 | "num_epochs": 1500, 84 | "train_iters_per_epoch": 10000000000, 85 | 
"save_frequency": 10, 86 | "save_frequency_iters": 10000000, 87 | "log_frequency": 25, 88 | "image_log_frequency": 100, 89 | "batch_size": 64, 90 | "sample_length": 14, 91 | "gradient_clipping": false, 92 | "clipping_max_value": 3.0 93 | }, 94 | "_general": { 95 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2", 96 | "created_time": "2023-01-28_11-34-50", 97 | "last_loaded": "2023-02-09_09-52-06" 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_LSTM/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "OBJ3D", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 6, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 64, 19 | "mlp_hidden": 128, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 2, 23 | "num_iterations": 2, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "LearnedRandom" 31 | }, 32 | "predictor": { 33 | "predictor_name": "LSTM", 34 | "LSTM": { 35 | "num_cells": 2, 36 | "hidden_dim": 128, 37 | "residual": true 38 | } 39 | } 40 | }, 41 | "loss": [ 42 | { 43 | "type": "mse", 44 | "weight": 1 45 | } 46 | ], 47 | "predictor_loss": [ 48 | { 49 | "type": "pred_img_mse", 50 | "weight": 1 51 | }, 52 | { 53 | "type": "pred_slot_mse", 54 | "weight": 1 55 | } 56 | ], 57 | "training_slots": { 58 | "num_epochs": 2000, 59 | "save_frequency": 10, 60 | "log_frequency": 100, 61 | "image_log_frequency": 100, 62 | "batch_size": 64, 63 | "lr": 0.0001, 64 | "optimizer": "adam", 65 | "momentum": 0, 66 | "weight_decay": 0, 67 | "nesterov": false, 68 | "scheduler": "cosine_annealing", 69 | "lr_factor": 0.05, 70 | "patience": 10, 71 | "scheduler_steps": 100000, 72 | "lr_warmup": true, 73 | "warmup_steps": 2500, 74 | "warmup_epochs": 1000, 75 | "gradient_clipping": true, 76 | "clipping_max_value": 0.05 77 | }, 78 | "training_prediction": { 79 | "num_context": 5, 80 | "num_preds": 5, 81 | "teacher_force": false, 82 | "skip_first_slot": false, 83 | "num_epochs": 1500, 84 | "train_iters_per_epoch": 10000000000, 85 | "save_frequency": 10, 86 | "save_frequency_iters": 1000000, 87 | "log_frequency": 100, 88 | "image_log_frequency": 100, 89 | "batch_size": 64, 90 | "sample_length": 10, 91 | "gradient_clipping": true, 92 | "clipping_max_value": 0.05 93 | }, 94 | "_general": { 95 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI", 96 | "created_time": "2022-12-06_09-08-22", 97 | "last_loaded": "2022-12-15_14-51-35" 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /experiments/MOViA/Predictor_OCVPPar/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "MoviA", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 11, 14 
| "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 128, 19 | "mlp_hidden": 256, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 3, 23 | "num_iterations": 1, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "BBox" 31 | }, 32 | "predictor": { 33 | "predictor_name": "OCVP-Par", 34 | "OCVP-Par": { 35 | "token_dim": 256, 36 | "hidden_dim": 512, 37 | "num_layers": 4, 38 | "n_heads": 8, 39 | "residual": true, 40 | "input_buffer_size": 30 41 | } 42 | } 43 | }, 44 | "loss": [ 45 | { 46 | "type": "mse", 47 | "weight": 1 48 | } 49 | ], 50 | "predictor_loss": [ 51 | { 52 | "type": "pred_img_mse", 53 | "weight": 1 54 | }, 55 | { 56 | "type": "pred_slot_mse", 57 | "weight": 1 58 | } 59 | ], 60 | "training_slots": { 61 | "num_epochs": 1500, 62 | "save_frequency": 10, 63 | "log_frequency": 25, 64 | "image_log_frequency": 100, 65 | "batch_size": 64, 66 | "lr": 0.0002, 67 | "optimizer": "adam", 68 | "momentum": 0, 69 | "weight_decay": 0, 70 | "nesterov": false, 71 | "scheduler": "cosine_annealing", 72 | "lr_factor": 0.5, 73 | "patience": 10, 74 | "scheduler_steps": 400000, 75 | "lr_warmup": true, 76 | "warmup_steps": 2500, 77 | "warmup_epochs": 200, 78 | "gradient_clipping": true, 79 | "clipping_max_value": 0.05 80 | }, 81 | "training_prediction": { 82 | "num_context": 6, 83 | "num_preds": 8, 84 | "teacher_force": false, 85 | "skip_first_slot": false, 86 | "num_epochs": 1500, 87 | "train_iters_per_epoch": 10000000000, 88 | "save_frequency": 10, 89 | "save_frequency_iters": 10000000, 90 | "log_frequency": 25, 91 | "image_log_frequency": 100, 92 | "batch_size": 64, 93 | "sample_length": 14, 94 | "gradient_clipping": false, 95 | "clipping_max_value": 3.0 96 | }, 97 | "_general": { 98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2", 99 | "created_time": "2023-01-28_11-34-50", 100 | "last_loaded": "2023-02-09_13-44-17" 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /experiments/MOViA/Predictor_OCVPSeq/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "MoviA", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 11, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 128, 19 | "mlp_hidden": 256, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 3, 23 | "num_iterations": 1, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "BBox" 31 | }, 32 | "predictor": { 33 | "predictor_name": "OCVP-Seq", 34 | "OCVP-Seq": { 35 | "token_dim": 256, 36 | "hidden_dim": 512, 37 | "num_layers": 4, 38 | "n_heads": 8, 39 | "residual": true, 40 | "input_buffer_size": 30 41 | } 42 | } 43 | }, 44 | "loss": [ 45 | { 46 | "type": "mse", 47 | "weight": 1 48 | } 49 | ], 50 | "predictor_loss": [ 51 | { 52 | "type": 
"pred_img_mse", 53 | "weight": 1 54 | }, 55 | { 56 | "type": "pred_slot_mse", 57 | "weight": 1 58 | } 59 | ], 60 | "training_slots": { 61 | "num_epochs": 1500, 62 | "save_frequency": 10, 63 | "log_frequency": 25, 64 | "image_log_frequency": 100, 65 | "batch_size": 64, 66 | "lr": 0.0002, 67 | "optimizer": "adam", 68 | "momentum": 0, 69 | "weight_decay": 0, 70 | "nesterov": false, 71 | "scheduler": "cosine_annealing", 72 | "lr_factor": 0.5, 73 | "patience": 10, 74 | "scheduler_steps": 400000, 75 | "lr_warmup": true, 76 | "warmup_steps": 2500, 77 | "warmup_epochs": 200, 78 | "gradient_clipping": true, 79 | "clipping_max_value": 0.05 80 | }, 81 | "training_prediction": { 82 | "num_context": 6, 83 | "num_preds": 8, 84 | "teacher_force": false, 85 | "skip_first_slot": false, 86 | "num_epochs": 1500, 87 | "train_iters_per_epoch": 10000000000, 88 | "save_frequency": 10, 89 | "save_frequency_iters": 10000000, 90 | "log_frequency": 25, 91 | "image_log_frequency": 100, 92 | "batch_size": 64, 93 | "sample_length": 14, 94 | "gradient_clipping": false, 95 | "clipping_max_value": 3.0 96 | }, 97 | "_general": { 98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2", 99 | "created_time": "2023-01-28_11-34-50", 100 | "last_loaded": "2023-02-08_16-32-24" 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /experiments/MOViA/Predictor_Transformer/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "MoviA", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 11, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 128, 19 | "mlp_hidden": 256, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 3, 23 | "num_iterations": 1, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "BBox" 31 | }, 32 | "predictor": { 33 | "predictor_name": "Transformer", 34 | "Transformer": { 35 | "token_dim": 256, 36 | "hidden_dim": 512, 37 | "num_layers": 4, 38 | "n_heads": 8, 39 | "residual": true, 40 | "input_buffer_size": 30 41 | } 42 | } 43 | }, 44 | "loss": [ 45 | { 46 | "type": "mse", 47 | "weight": 1 48 | } 49 | ], 50 | "predictor_loss": [ 51 | { 52 | "type": "pred_img_mse", 53 | "weight": 1 54 | }, 55 | { 56 | "type": "pred_slot_mse", 57 | "weight": 1 58 | } 59 | ], 60 | "training_slots": { 61 | "num_epochs": 1500, 62 | "save_frequency": 10, 63 | "log_frequency": 25, 64 | "image_log_frequency": 100, 65 | "batch_size": 64, 66 | "lr": 0.0002, 67 | "optimizer": "adam", 68 | "momentum": 0, 69 | "weight_decay": 0, 70 | "nesterov": false, 71 | "scheduler": "cosine_annealing", 72 | "lr_factor": 0.5, 73 | "patience": 10, 74 | "scheduler_steps": 400000, 75 | "lr_warmup": true, 76 | "warmup_steps": 2500, 77 | "warmup_epochs": 200, 78 | "gradient_clipping": true, 79 | "clipping_max_value": 0.05 80 | }, 81 | "training_prediction": { 82 | "num_context": 6, 83 | "num_preds": 8, 84 | "teacher_force": false, 85 | "skip_first_slot": false, 86 | "num_epochs": 1500, 87 | "train_iters_per_epoch": 10000000000, 88 | 
"save_frequency": 10, 89 | "save_frequency_iters": 10000000, 90 | "log_frequency": 25, 91 | "image_log_frequency": 100, 92 | "batch_size": 64, 93 | "sample_length": 14, 94 | "gradient_clipping": false, 95 | "clipping_max_value": 3.0 96 | }, 97 | "_general": { 98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2", 99 | "created_time": "2023-01-28_11-34-50", 100 | "last_loaded": "2023-02-08_09-15-58" 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_OCVPPar/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "OBJ3D", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 6, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 64, 19 | "mlp_hidden": 128, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 2, 23 | "num_iterations": 2, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "LearnedRandom" 31 | }, 32 | "predictor": { 33 | "predictor_name": "OCVP-Par", 34 | "OCVP-Par": { 35 | "token_dim": 128, 36 | "hidden_dim": 256, 37 | "num_layers": 4, 38 | "n_heads": 4, 39 | "residual": true, 40 | "input_buffer_size": 5 41 | } 42 | } 43 | }, 44 | "loss": [ 45 | { 46 | "type": "mse", 47 | "weight": 1 48 | } 49 | ], 50 | "predictor_loss": [ 51 | { 52 | "type": "pred_img_mse", 53 | "weight": 1 54 | }, 55 | { 56 | "type": "pred_slot_mse", 57 | "weight": 1 58 | } 59 | ], 60 | "training_slots": { 61 | "num_epochs": 2000, 62 | "save_frequency": 10, 63 | "log_frequency": 100, 64 | "image_log_frequency": 100, 65 | "batch_size": 64, 66 | "lr": 0.0001, 67 | "optimizer": "adam", 68 | "momentum": 0, 69 | "weight_decay": 0, 70 | "nesterov": false, 71 | "scheduler": "cosine_annealing", 72 | "lr_factor": 0.05, 73 | "patience": 10, 74 | "scheduler_steps": 100000, 75 | "lr_warmup": true, 76 | "warmup_steps": 2500, 77 | "warmup_epochs": 1000, 78 | "gradient_clipping": true, 79 | "clipping_max_value": 0.05 80 | }, 81 | "training_prediction": { 82 | "num_context": 5, 83 | "num_preds": 5, 84 | "teacher_force": false, 85 | "skip_first_slot": false, 86 | "num_epochs": 1500, 87 | "train_iters_per_epoch": 10000000000, 88 | "save_frequency": 25, 89 | "save_frequency_iters": 1000000, 90 | "log_frequency": 100, 91 | "image_log_frequency": 100, 92 | "batch_size": 16, 93 | "sample_length": 10, 94 | "gradient_clipping": true, 95 | "clipping_max_value": 0.05 96 | }, 97 | "_general": { 98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI", 99 | "created_time": "2022-12-06_09-08-22", 100 | "last_loaded": "2023-01-13_13-41-20" 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_OCVPSeq/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "OBJ3D", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 
| }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 6, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 64, 19 | "mlp_hidden": 128, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 2, 23 | "num_iterations": 2, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "LearnedRandom" 31 | }, 32 | "predictor": { 33 | "predictor_name": "OCVP-Seq", 34 | "OCVP-Seq": { 35 | "token_dim": 128, 36 | "hidden_dim": 256, 37 | "num_layers": 4, 38 | "n_heads": 4, 39 | "residual": true, 40 | "input_buffer_size": 5 41 | } 42 | } 43 | }, 44 | "loss": [ 45 | { 46 | "type": "mse", 47 | "weight": 1 48 | } 49 | ], 50 | "predictor_loss": [ 51 | { 52 | "type": "pred_img_mse", 53 | "weight": 1 54 | }, 55 | { 56 | "type": "pred_slot_mse", 57 | "weight": 1 58 | } 59 | ], 60 | "training_slots": { 61 | "num_epochs": 2000, 62 | "save_frequency": 10, 63 | "log_frequency": 100, 64 | "image_log_frequency": 100, 65 | "batch_size": 64, 66 | "lr": 0.0001, 67 | "optimizer": "adam", 68 | "momentum": 0, 69 | "weight_decay": 0, 70 | "nesterov": false, 71 | "scheduler": "cosine_annealing", 72 | "lr_factor": 0.05, 73 | "patience": 10, 74 | "scheduler_steps": 100000, 75 | "lr_warmup": true, 76 | "warmup_steps": 2500, 77 | "warmup_epochs": 1000, 78 | "gradient_clipping": true, 79 | "clipping_max_value": 0.05 80 | }, 81 | "training_prediction": { 82 | "num_context": 5, 83 | "num_preds": 5, 84 | "teacher_force": false, 85 | "skip_first_slot": false, 86 | "num_epochs": 1500, 87 | "train_iters_per_epoch": 10000000000, 88 | "save_frequency": 25, 89 | "save_frequency_iters": 1000000, 90 | "log_frequency": 100, 91 | "image_log_frequency": 100, 92 | "batch_size": 16, 93 | "sample_length": 10, 94 | "gradient_clipping": true, 95 | "clipping_max_value": 0.05 96 | }, 97 | "_general": { 98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI", 99 | "created_time": "2022-12-06_09-08-22", 100 | "last_loaded": "2023-01-13_13-36-22" 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_Transformer/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "OBJ3D", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 6, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 64, 19 | "mlp_hidden": 128, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 2, 23 | "num_iterations": 2, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "LearnedRandom" 31 | }, 32 | "predictor": { 33 | "predictor_name": "Transformer", 34 | "Transformer": { 35 | "token_dim": 128, 36 | "hidden_dim": 256, 37 | "num_layers": 4, 38 | "n_heads": 4, 39 | "residual": true, 40 | "input_buffer_size": 5 41 | } 42 | } 43 | }, 44 | "loss": [ 
45 | { 46 | "type": "mse", 47 | "weight": 1 48 | } 49 | ], 50 | "predictor_loss": [ 51 | { 52 | "type": "pred_img_mse", 53 | "weight": 1 54 | }, 55 | { 56 | "type": "pred_slot_mse", 57 | "weight": 1 58 | } 59 | ], 60 | "training_slots": { 61 | "num_epochs": 2000, 62 | "save_frequency": 10, 63 | "log_frequency": 100, 64 | "image_log_frequency": 100, 65 | "batch_size": 64, 66 | "lr": 0.0001, 67 | "optimizer": "adam", 68 | "momentum": 0, 69 | "weight_decay": 0, 70 | "nesterov": false, 71 | "scheduler": "cosine_annealing", 72 | "lr_factor": 0.05, 73 | "patience": 10, 74 | "scheduler_steps": 100000, 75 | "lr_warmup": true, 76 | "warmup_steps": 2500, 77 | "warmup_epochs": 1000, 78 | "gradient_clipping": true, 79 | "clipping_max_value": 0.05 80 | }, 81 | "training_prediction": { 82 | "num_context": 5, 83 | "num_preds": 5, 84 | "teacher_force": false, 85 | "skip_first_slot": false, 86 | "num_epochs": 1500, 87 | "train_iters_per_epoch": 10000000000, 88 | "save_frequency": 25, 89 | "save_frequency_iters": 1000000, 90 | "log_frequency": 100, 91 | "image_log_frequency": 100, 92 | "batch_size": 64, 93 | "sample_length": 10, 94 | "gradient_clipping": true, 95 | "clipping_max_value": 0.05 96 | }, 97 | "_general": { 98 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI", 99 | "created_time": "2022-12-06_09-08-22", 100 | "last_loaded": "2022-12-29_08-41-26" 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/01_create_predictor_experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creating a predictor experiment inside an existing experiment directory and initializing 3 | its experiment parameters. 4 | """ 5 | 6 | import os 7 | from lib.arguments import create_predictor_experiment_arguments 8 | from lib.config import Config 9 | from lib.logger import Logger, print_ 10 | from lib.utils import create_directory, delete_directory, clear_cmd 11 | from CONFIG import CONFIG 12 | 13 | 14 | def initialize_predictor_experiment(): 15 | """ 16 | Creating predictor experiment directory and initializing it with defauls 17 | """ 18 | # reading command line args 19 | args = create_predictor_experiment_arguments() 20 | exp_dir, exp_name, predictor_name = args.exp_directory, args.name, args.predictor_name 21 | 22 | # making sure everything adds up 23 | parent_path = os.path.join(CONFIG["paths"]["experiments_path"], exp_dir) 24 | exp_path = os.path.join(CONFIG["paths"]["experiments_path"], exp_dir, exp_name) 25 | if not os.path.exists(parent_path): 26 | raise FileNotFoundError(f"{parent_path = } does not exist") 27 | if not os.path.exists(os.path.join(parent_path, "experiment_params.json")): 28 | raise FileNotFoundError(f"{parent_path = } does not have experiment_params...") 29 | if len(os.listdir(os.path.join(parent_path, "models"))) <= 0: 30 | raise FileNotFoundError("Parent models-dir does not contain any models!...") 31 | if os.path.exists(exp_path): 32 | raise ValueError(f"{exp_path = } already exists. 
Choose a different name!") 33 | 34 | # creating directories 35 | create_directory(exp_path) 36 | _ = Logger(exp_path) # initialize logger once exp_dir is created 37 | create_directory(dir_path=exp_path, dir_name="plots") 38 | create_directory(dir_path=exp_path, dir_name="tboard_logs") 39 | 40 | # adding experiment parameters from the parent directory, but only with specified predictor params 41 | try: 42 | cfg = Config(exp_path=parent_path) 43 | exp_params = cfg.load_exp_config_file() 44 | new_predictor_params = {} 45 | predictor_params = exp_params["model"]["predictor"] 46 | if predictor_name not in predictor_params.keys(): 47 | raise ValueError(f"{predictor_name} not in keys {predictor_params.keys() = }") 48 | new_predictor_params["predictor_name"] = predictor_name 49 | new_predictor_params[predictor_name] = predictor_params[predictor_name] 50 | exp_params["model"]["predictor"] = new_predictor_params 51 | cfg.save_exp_config_file(exp_path=exp_path, exp_params=exp_params) 52 | except FileNotFoundError as e: 53 | print_("An error has occurred...\n Removing experiment directory") 54 | delete_directory(dir_path=exp_path) 55 | print(e) 56 | print(f"Predictor experiment {exp_name} created successfully! :)") 57 | return 58 | 59 | 60 | if __name__ == "__main__": 61 | clear_cmd() 62 | initialize_predictor_experiment() 63 | 64 | # 65 | -------------------------------------------------------------------------------- /src/03_evaluate_savi_noMasks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluating a SAVI model checkpoint on a dataset without ground truth masks. 3 | Since we do not have masks to compare with, we simply evaluate the visual qualitiy 4 | of the reconstructed images using video-prediction metrics: PSNR, SSIM and LPIPS. 5 | """ 6 | 7 | from data.load_data import unwrap_batch_data 8 | from lib.arguments import get_sa_eval_arguments 9 | from lib.logger import Logger, print_, log_function, for_all_methods 10 | from lib.metrics import MetricTracker 11 | import lib.utils as utils 12 | 13 | from base.baseEvaluator import BaseEvaluator 14 | 15 | 16 | @for_all_methods(log_function) 17 | class Evaluator(BaseEvaluator): 18 | """ 19 | Evaluating a SAVI model checkpoint on a dataset without ground truth masks. 20 | Since we do not have masks to compare with, we simply evaluate the visual qualitiy 21 | of the reconstructed images using video-prediction metrics: PSNR, SSIM and LPIPS. 22 | """ 23 | 24 | def set_metric_tracker(self): 25 | """ 26 | Initializing the metric tracker 27 | """ 28 | self.metric_tracker = MetricTracker( 29 | exp_path=exp_path, 30 | metrics=["psnr", "ssim", "lpips"] 31 | ) 32 | return 33 | 34 | def forward_eval(self, batch_data, **kwargs): 35 | """ 36 | Making a forwad pass through the model and computing the evaluation metrics 37 | 38 | Args: 39 | ----- 40 | batch_data: dict 41 | Dictionary containing the information for the current batch, including images, poses, 42 | actions, or metadata, among others. 
43 | 44 | Returns: 45 | -------- 46 | pred_data: dict 47 | Predictions from the model for the current batch of data 48 | """ 49 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data) 50 | videos, targets = videos.to(self.device), targets.to(self.device) 51 | out_model = self.model( 52 | videos, 53 | num_imgs=videos.shape[1], 54 | **initializer_kwargs 55 | ) 56 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model 57 | self.metric_tracker.accumulate( 58 | preds=reconstruction_history.clamp(0, 1), 59 | targets=targets.clamp(0, 1) 60 | ) 61 | return 62 | 63 | 64 | if __name__ == "__main__": 65 | utils.clear_cmd() 66 | exp_path, args = get_sa_eval_arguments() 67 | logger = Logger(exp_path=exp_path) 68 | logger.log_info("Starting SAVi visual quality evaluation procedure", message_type="new_exp") 69 | 70 | print_("Initializing Evaluator...") 71 | print_("Args:") 72 | print_("-----") 73 | for k, v in vars(args).items(): 74 | print_(f" --> {k} = {v}") 75 | evaluator = Evaluator( 76 | exp_path=exp_path, 77 | checkpoint=args.checkpoint 78 | ) 79 | print_("Loading dataset...") 80 | evaluator.load_data() 81 | print_("Setting up model and loading pretrained parameters") 82 | evaluator.setup_model() 83 | print_("Starting evaluation") 84 | evaluator.evaluate() 85 | 86 | # 87 | -------------------------------------------------------------------------------- /src/03_evaluate_savi.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluating a SAVI model checkpoint using object-centric metrics 3 | This evalaution can only be performed on datasets with annotated segmentation masks 4 | """ 5 | 6 | import torch 7 | 8 | from data import unwrap_batch_data_masks 9 | from lib.arguments import get_sa_eval_arguments 10 | from lib.logger import Logger, print_, log_function, for_all_methods 11 | from lib.metrics import MetricTracker 12 | import lib.utils as utils 13 | 14 | from base.baseEvaluator import BaseEvaluator 15 | 16 | 17 | @for_all_methods(log_function) 18 | class Evaluator(BaseEvaluator): 19 | """ 20 | Class for evaluating a SAVI model using object-centric metrics. 21 | This evalaution can only be performed on datasets with annotated segmentation masks 22 | """ 23 | 24 | def set_metric_tracker(self): 25 | """ 26 | Initializing the metric tracker 27 | """ 28 | self.metric_tracker = MetricTracker( 29 | exp_path, 30 | metrics=["segmentation_ari", "IoU"] 31 | ) 32 | return 33 | 34 | def load_data(self): 35 | """ 36 | Loading data 37 | """ 38 | super().load_data() 39 | self.test_set.get_masks = True 40 | return 41 | 42 | def forward_eval(self, batch_data, **kwargs): 43 | """ 44 | Making a forwad pass through the model and computing the evaluation metrics 45 | 46 | Args: 47 | ----- 48 | batch_data: dict 49 | Dictionary containing the information for the current batch, including images, poses, 50 | actions, or metadata, among others. 
51 | 52 | Returns: 53 | -------- 54 | pred_data: dict 55 | Predictions from the model for the current batch of data 56 | """ 57 | videos, masks, initializer_kwargs = unwrap_batch_data_masks(self.exp_params, batch_data) 58 | videos, masks = videos.to(self.device), masks.to(self.device) 59 | out_model = self.model( 60 | videos, 61 | num_imgs=videos.shape[1], 62 | **initializer_kwargs 63 | ) 64 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model 65 | 66 | # evaluation 67 | predicted_combined_masks = torch.argmax(masks_history, dim=2).squeeze(2) 68 | self.metric_tracker.accumulate( 69 | preds=predicted_combined_masks, 70 | targets=masks 71 | ) 72 | return 73 | 74 | 75 | if __name__ == "__main__": 76 | utils.clear_cmd() 77 | exp_path, args = get_sa_eval_arguments() 78 | logger = Logger(exp_path=exp_path) 79 | logger.log_info("Starting SAVi object-cetric evaluation procedure", message_type="new_exp") 80 | 81 | print_("Initializing Evaluator...") 82 | print_("Args:") 83 | print_("-----") 84 | for k, v in vars(args).items(): 85 | print_(f" --> {k} = {v}") 86 | evaluator = Evaluator( 87 | exp_path=exp_path, 88 | checkpoint=args.checkpoint 89 | ) 90 | print_("Loading dataset...") 91 | evaluator.load_data() 92 | print_("Setting up model and loading pretrained parameters") 93 | evaluator.setup_model() 94 | print_("Starting evaluation") 95 | evaluator.evaluate() 96 | 97 | 98 | # 99 | -------------------------------------------------------------------------------- /experiments/MOViA/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "MoviA", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 11, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 128, 19 | "mlp_hidden": 256, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 3, 23 | "num_iterations": 1, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "BBox" 31 | }, 32 | "predictor": { 33 | "predictor_name": "Transformer", 34 | "LSTM": { 35 | "num_cells": 2, 36 | "hidden_dim": 64, 37 | "residual": 30 38 | }, 39 | "Transformer": { 40 | "token_dim": 128, 41 | "hidden_dim": 256, 42 | "num_layers": 2, 43 | "n_heads": 4, 44 | "residual": true, 45 | "input_buffer_size": 30 46 | }, 47 | "OCVP-Seq": { 48 | "token_dim": 128, 49 | "hidden_dim": 256, 50 | "num_layers": 2, 51 | "n_heads": 4, 52 | "residual": true, 53 | "input_buffer_size": 30 54 | }, 55 | "OCVP-Par": { 56 | "token_dim": 128, 57 | "hidden_dim": 256, 58 | "num_layers": 2, 59 | "n_heads": 4, 60 | "residual": true, 61 | "input_buffer_size": 30 62 | } 63 | } 64 | }, 65 | "loss": [ 66 | { 67 | "type": "mse", 68 | "weight": 1 69 | } 70 | ], 71 | "predictor_loss": [ 72 | { 73 | "type": "pred_img_mse", 74 | "weight": 1 75 | }, 76 | { 77 | "type": "pred_slot_mse", 78 | "weight": 1 79 | } 80 | ], 81 | "training_slots": { 82 | "num_epochs": 2500, 83 | "save_frequency": 10, 84 | "log_frequency": 25, 85 | "image_log_frequency": 100, 86 | "batch_size": 64, 87 | "lr": 0.0001, 88 | "optimizer": "adam", 89 | "momentum": 0, 90 | 
"weight_decay": 0, 91 | "nesterov": false, 92 | "scheduler": "cosine_annealing", 93 | "lr_factor": 0.5, 94 | "patience": 10, 95 | "scheduler_steps": 400000, 96 | "lr_warmup": true, 97 | "warmup_steps": 2500, 98 | "warmup_epochs": 200, 99 | "gradient_clipping": true, 100 | "clipping_max_value": 0.05 101 | }, 102 | "training_prediction": { 103 | "num_context": 5, 104 | "num_preds": 5, 105 | "teacher_force": false, 106 | "skip_first_slot": false, 107 | "num_epochs": 1500, 108 | "train_iters_per_epoch": 10000000000, 109 | "save_frequency": 10, 110 | "save_frequency_iters": 1000, 111 | "log_frequency": 25, 112 | "image_log_frequency": 100, 113 | "batch_size": 64, 114 | "sample_length": 10, 115 | "gradient_clipping": true, 116 | "clipping_max_value": 0.05 117 | }, 118 | "_general": { 119 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/MoviExps/exp2", 120 | "created_time": "2023-01-28_11-34-50", 121 | "last_loaded": "2023-01-28_11-34-50" 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /experiments/Obj3D/experiment_params.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset": { 3 | "dataset_name": "OBJ3D", 4 | "shuffle_train": true, 5 | "shuffle_eval": false, 6 | "use_segmentation": true, 7 | "target": "rgb", 8 | "random_start": true 9 | }, 10 | "model": { 11 | "model_name": "SAVi", 12 | "SAVi": { 13 | "num_slots": 6, 14 | "slot_dim": 128, 15 | "in_channels": 3, 16 | "encoder_type": "ConvEncoder", 17 | "num_channels": [32, 32, 32, 32], 18 | "mlp_encoder_dim": 64, 19 | "mlp_hidden": 128, 20 | "num_channels_decoder": [64, 64, 64, 64], 21 | "kernel_size": 5, 22 | "num_iterations_first": 2, 23 | "num_iterations": 2, 24 | "resolution": [64, 64], 25 | "downsample_encoder": false, 26 | "downsample_decoder": true, 27 | "decoder_resolution": [8, 8], 28 | "upsample": 2, 29 | "use_predictor": true, 30 | "initializer": "LearnedRandom" 31 | }, 32 | "predictor": { 33 | "predictor_name": "Transformer", 34 | "LSTM": { 35 | "num_cells": 2, 36 | "hidden_dim": 64, 37 | "residual": true 38 | }, 39 | "Transformer": { 40 | "token_dim": 128, 41 | "hidden_dim": 256, 42 | "num_layers": 2, 43 | "n_heads": 4, 44 | "residual": true, 45 | "input_buffer_size": null 46 | }, 47 | "OCVP-Seq": { 48 | "token_dim": 128, 49 | "hidden_dim": 256, 50 | "num_layers": 2, 51 | "n_heads": 4, 52 | "residual": true, 53 | "input_buffer_size": null 54 | }, 55 | "OCVP-Par": { 56 | "token_dim": 128, 57 | "hidden_dim": 256, 58 | "num_layers": 2, 59 | "n_heads": 4, 60 | "residual": true, 61 | "input_buffer_size": null 62 | } 63 | } 64 | }, 65 | "loss": [ 66 | { 67 | "type": "mse", 68 | "weight": 1 69 | } 70 | ], 71 | "predictor_loss": [ 72 | { 73 | "type": "pred_img_mse", 74 | "weight": 1 75 | }, 76 | { 77 | "type": "pred_slot_mse", 78 | "weight": 1 79 | } 80 | ], 81 | "training_slots": { 82 | "num_epochs": 2000, 83 | "save_frequency": 10, 84 | "log_frequency": 100, 85 | "image_log_frequency": 100, 86 | "batch_size": 64, 87 | "lr": 0.0001, 88 | "optimizer": "adam", 89 | "momentum": 0, 90 | "weight_decay": 0, 91 | "nesterov": false, 92 | "scheduler": "cosine_annealing", 93 | "lr_factor": 0.05, 94 | "patience": 10, 95 | "scheduler_steps": 100000, 96 | "lr_warmup": true, 97 | "warmup_steps": 2500, 98 | "warmup_epochs": 1000, 99 | "gradient_clipping": true, 100 | "clipping_max_value": 0.05 101 | }, 102 | "training_prediction": { 103 | "num_context": 5, 104 | "num_preds": 5, 105 | "teacher_force": false, 106 | 
"skip_first_slot": false, 107 | "num_epochs": 1500, 108 | "train_iters_per_epoch": 10000000000, 109 | "save_frequency": 10, 110 | "save_frequency_iters": 1000000, 111 | "log_frequency": 100, 112 | "image_log_frequency": 100, 113 | "batch_size": 64, 114 | "sample_length": 10, 115 | "gradient_clipping": true, 116 | "clipping_max_value": 0.05 117 | }, 118 | "_general": { 119 | "exp_path": "/home/data/user/villar/ObjectCentricVideoPred/experiments/NewSAVI/NewSAVI", 120 | "created_time": "2022-12-06_09-08-22", 121 | "last_loaded": "2022-12-06_09-08-22" 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /experiments/Obj3D/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 915652 2 | 3 | Params: 79328 4 | SimpleConvEncoder( 5 | (encoder): Sequential( 6 | (0): ConvBlock( 7 | (block): Sequential( 8 | (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 9 | (1): ReLU() 10 | ) 11 | ) 12 | (1): ConvBlock( 13 | (block): Sequential( 14 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 15 | (1): ReLU() 16 | ) 17 | ) 18 | (2): ConvBlock( 19 | (block): Sequential( 20 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 21 | (1): ReLU() 22 | ) 23 | ) 24 | (3): ConvBlock( 25 | (block): Sequential( 26 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 27 | (1): ReLU() 28 | ) 29 | ) 30 | ) 31 | ) 32 | 33 | Params: 160 34 | SoftPositionEmbed( 35 | (projection): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1)) 36 | ) 37 | 38 | Params: 6336 39 | Sequential( 40 | (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True) 41 | (1): Linear(in_features=32, out_features=64, bias=True) 42 | (2): ReLU() 43 | (3): Linear(in_features=64, out_features=64, bias=True) 44 | ) 45 | 46 | Params: 148480 47 | TransformerBlock( 48 | (attn): MultiHeadSelfAttention( 49 | (q): Linear(in_features=128, out_features=128, bias=False) 50 | (k): Linear(in_features=128, out_features=128, bias=False) 51 | (v): Linear(in_features=128, out_features=128, bias=False) 52 | (drop): Dropout(p=0.0, inplace=False) 53 | (out_projection): Sequential( 54 | (0): Linear(in_features=128, out_features=128, bias=False) 55 | ) 56 | ) 57 | (mlp): Sequential( 58 | (0): Linear(in_features=128, out_features=256, bias=True) 59 | (1): ReLU() 60 | (2): Linear(in_features=256, out_features=128, bias=True) 61 | ) 62 | (layernorm_query): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 63 | (layernorm_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 64 | (dense_o): Linear(in_features=128, out_features=128, bias=True) 65 | ) 66 | 67 | Params: 640 68 | SoftPositionEmbed( 69 | (projection): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1)) 70 | ) 71 | 72 | Params: 514564 73 | Decoder( 74 | (decoder): Sequential( 75 | (0): ConvBlock( 76 | (block): Sequential( 77 | (0): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 78 | (1): ReLU() 79 | ) 80 | ) 81 | (1): Upsample(scale_factor=2.0, mode=nearest) 82 | (2): ConvBlock( 83 | (block): Sequential( 84 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 85 | (1): ReLU() 86 | ) 87 | ) 88 | (3): Upsample(scale_factor=2.0, mode=nearest) 89 | (4): ConvBlock( 90 | (block): Sequential( 91 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 92 | (1): ReLU() 93 | ) 94 | ) 95 | (5): Upsample(scale_factor=2.0, mode=nearest) 96 | (6): ConvBlock( 97 | (block): 
Sequential( 98 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 99 | (1): ReLU() 100 | ) 101 | ) 102 | (7): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 103 | ) 104 | ) 105 | 106 | Params: 166144 107 | SlotAttention( 108 | (norm_input): LayerNorm((64,), eps=0.001, elementwise_affine=True) 109 | (norm_slot): LayerNorm((128,), eps=0.001, elementwise_affine=True) 110 | (norm_mlp): LayerNorm((128,), eps=0.001, elementwise_affine=True) 111 | (to_q): Linear(in_features=128, out_features=128, bias=True) 112 | (to_k): Linear(in_features=64, out_features=128, bias=True) 113 | (to_v): Linear(in_features=64, out_features=128, bias=True) 114 | (gru): GRUCell(128, 128) 115 | (mlp): Sequential( 116 | (0): Linear(in_features=128, out_features=128, bias=True) 117 | (1): ReLU() 118 | (2): Linear(in_features=128, out_features=128, bias=True) 119 | ) 120 | ) 121 | -------------------------------------------------------------------------------- /src/data/obj3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset class to load Obj3D dataset 3 | - Source: https://github.com/zhixuan-lin/G-SWM/blob/master/src/dataset/obj3d.py 4 | """ 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | import glob 9 | import os 10 | import os.path as osp 11 | import torch 12 | from PIL import Image, ImageFile 13 | 14 | from CONFIG import CONFIG 15 | PATH = CONFIG["paths"]["data_path"] 16 | 17 | ImageFile.LOAD_TRUNCATED_IMAGES = True 18 | 19 | 20 | class OBJ3D(Dataset): 21 | """ 22 | DataClass for the Obj3D Dataset. 23 | 24 | During training, we sample a random subset of frames in the episode. At inference time, 25 | we always start from the first frame, e.g., when the ball moves towards the objects, and 26 | before any collision happens. 27 | 28 | - Source: https://github.com/zhixuan-lin/G-SWM/blob/master/src/dataset/obj3d.py 29 | 30 | Args: 31 | ----- 32 | mode: string 33 | Dataset split to load. Can be one of ['train', 'val', 'test'] 34 | ep_len: int 35 | Number of frames in an episode. Default is 30 36 | sample_length: int 37 | Number of frames in the sequences to load 38 | random_start: bool 39 | If True, first frame of the sequence is sampled at random between the possible starting frames. 40 | Otherwise, starting frame is always the first frame in the sequence. 41 | """ 42 | 43 | def __init__(self, mode, ep_len=30, sample_length=20, random_start=True): 44 | """ 45 | Dataset Initializer 46 | """ 47 | assert mode in ["train", "val", "valid", "eval", "test"], f"Unknown dataset split {mode}..." 48 | mode = "val" if mode in ["val", "valid"] else mode 49 | mode = "test" if mode in ["test", "eval"] else mode 50 | assert mode in ['train', 'val', 'test'], f"Unknown dataset split {mode}..." 
51 | 52 | self.root = os.path.join(PATH, "OBJ3D", mode) 53 | self.mode = mode 54 | self.sample_length = sample_length 55 | self.random_start = random_start 56 | 57 | # Get all numbers 58 | self.folders = [] 59 | for file in os.listdir(self.root): 60 | try: 61 | self.folders.append(int(file)) 62 | except ValueError: 63 | continue 64 | self.folders.sort() 65 | 66 | # episode-related paramters 67 | self.epsisodes = [] 68 | self.EP_LEN = ep_len 69 | if mode == "train" and self.random_start: 70 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1 71 | else: 72 | self.seq_per_episode = 1 73 | 74 | # loading images from data directories and assembling then into episodes 75 | for f in self.folders: 76 | dir_name = os.path.join(self.root, str(f)) 77 | paths = list(glob.glob(osp.join(dir_name, 'test_*.png'))) 78 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0].partition('_')[-1]) 79 | paths.sort(key=get_num) 80 | self.epsisodes.append(paths) 81 | return 82 | 83 | def __getitem__(self, index): 84 | """ 85 | Fetching a sequence from the dataset 86 | """ 87 | imgs = [] 88 | 89 | # Implement continuous indexing 90 | ep = index // self.seq_per_episode 91 | offset = index % self.seq_per_episode 92 | end = offset + self.sample_length 93 | 94 | e = self.epsisodes[ep] 95 | for image_index in range(offset, end): 96 | img = Image.open(osp.join(e[image_index])) 97 | img = img.resize((64, 64)) 98 | img = transforms.ToTensor()(img)[:3] 99 | imgs.append(img) 100 | img = torch.stack(imgs, dim=0).float() 101 | 102 | targets = img 103 | all_reps = {"videos": img} 104 | return img, targets, all_reps 105 | 106 | def __len__(self): 107 | """ 108 | Number of episodes in the dataset 109 | """ 110 | length = len(self.epsisodes) 111 | return length 112 | 113 | 114 | # 115 | -------------------------------------------------------------------------------- /experiments/MOViA/Predictor_OCVPPar/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 4279680 2 | 3 | Params: 4279680 4 | OCVTransformerV2Predictor( 5 | (mlp_in): Linear(in_features=128, out_features=256, bias=True) 6 | (mlp_out): Linear(in_features=256, out_features=128, bias=True) 7 | (transformer_encoders): Sequential( 8 | (0): ObjectCentricTransformerLayerV2( 9 | (self_attn): MultiheadAttention( 10 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 11 | ) 12 | (linear1): Linear(in_features=256, out_features=512, bias=True) 13 | (dropout): Dropout(p=0.1, inplace=False) 14 | (linear2): Linear(in_features=512, out_features=256, bias=True) 15 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 16 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 17 | (dropout1): Dropout(p=0.1, inplace=False) 18 | (dropout2): Dropout(p=0.1, inplace=False) 19 | (self_attn_obj): MultiheadAttention( 20 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 21 | ) 22 | (self_attn_time): MultiheadAttention( 23 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 24 | ) 25 | ) 26 | (1): ObjectCentricTransformerLayerV2( 27 | (self_attn): MultiheadAttention( 28 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 29 | ) 30 | (linear1): Linear(in_features=256, out_features=512, bias=True) 31 | (dropout): Dropout(p=0.1, inplace=False) 32 | (linear2): Linear(in_features=512, out_features=256, bias=True) 33 | (norm1): 
LayerNorm((256,), eps=1e-05, elementwise_affine=True) 34 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 35 | (dropout1): Dropout(p=0.1, inplace=False) 36 | (dropout2): Dropout(p=0.1, inplace=False) 37 | (self_attn_obj): MultiheadAttention( 38 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 39 | ) 40 | (self_attn_time): MultiheadAttention( 41 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 42 | ) 43 | ) 44 | (2): ObjectCentricTransformerLayerV2( 45 | (self_attn): MultiheadAttention( 46 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 47 | ) 48 | (linear1): Linear(in_features=256, out_features=512, bias=True) 49 | (dropout): Dropout(p=0.1, inplace=False) 50 | (linear2): Linear(in_features=512, out_features=256, bias=True) 51 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 52 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 53 | (dropout1): Dropout(p=0.1, inplace=False) 54 | (dropout2): Dropout(p=0.1, inplace=False) 55 | (self_attn_obj): MultiheadAttention( 56 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 57 | ) 58 | (self_attn_time): MultiheadAttention( 59 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 60 | ) 61 | ) 62 | (3): ObjectCentricTransformerLayerV2( 63 | (self_attn): MultiheadAttention( 64 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 65 | ) 66 | (linear1): Linear(in_features=256, out_features=512, bias=True) 67 | (dropout): Dropout(p=0.1, inplace=False) 68 | (linear2): Linear(in_features=512, out_features=256, bias=True) 69 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 70 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 71 | (dropout1): Dropout(p=0.1, inplace=False) 72 | (dropout2): Dropout(p=0.1, inplace=False) 73 | (self_attn_obj): MultiheadAttention( 74 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 75 | ) 76 | (self_attn_time): MultiheadAttention( 77 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 78 | ) 79 | ) 80 | ) 81 | (pe): PositionalEncoding( 82 | (dropout): Dropout(p=0.1, inplace=False) 83 | ) 84 | ) -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_OCVPPar/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 1091328 2 | 3 | Params: 1091328 4 | OCVTransformerV2Predictor( 5 | (mlp_in): Linear(in_features=128, out_features=128, bias=True) 6 | (mlp_out): Linear(in_features=128, out_features=128, bias=True) 7 | (transformer_encoders): Sequential( 8 | (0): ObjectCentricTransformerLayerV2( 9 | (self_attn): MultiheadAttention( 10 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 11 | ) 12 | (linear1): Linear(in_features=128, out_features=256, bias=True) 13 | (dropout): Dropout(p=0.1, inplace=False) 14 | (linear2): Linear(in_features=256, out_features=128, bias=True) 15 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 16 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 17 | (dropout1): Dropout(p=0.1, inplace=False) 18 | (dropout2): Dropout(p=0.1, inplace=False) 19 | (self_attn_obj): MultiheadAttention( 20 | (out_proj): 
NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 21 | ) 22 | (self_attn_time): MultiheadAttention( 23 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 24 | ) 25 | ) 26 | (1): ObjectCentricTransformerLayerV2( 27 | (self_attn): MultiheadAttention( 28 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 29 | ) 30 | (linear1): Linear(in_features=128, out_features=256, bias=True) 31 | (dropout): Dropout(p=0.1, inplace=False) 32 | (linear2): Linear(in_features=256, out_features=128, bias=True) 33 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 34 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 35 | (dropout1): Dropout(p=0.1, inplace=False) 36 | (dropout2): Dropout(p=0.1, inplace=False) 37 | (self_attn_obj): MultiheadAttention( 38 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 39 | ) 40 | (self_attn_time): MultiheadAttention( 41 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 42 | ) 43 | ) 44 | (2): ObjectCentricTransformerLayerV2( 45 | (self_attn): MultiheadAttention( 46 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 47 | ) 48 | (linear1): Linear(in_features=128, out_features=256, bias=True) 49 | (dropout): Dropout(p=0.1, inplace=False) 50 | (linear2): Linear(in_features=256, out_features=128, bias=True) 51 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 52 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 53 | (dropout1): Dropout(p=0.1, inplace=False) 54 | (dropout2): Dropout(p=0.1, inplace=False) 55 | (self_attn_obj): MultiheadAttention( 56 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 57 | ) 58 | (self_attn_time): MultiheadAttention( 59 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 60 | ) 61 | ) 62 | (3): ObjectCentricTransformerLayerV2( 63 | (self_attn): MultiheadAttention( 64 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 65 | ) 66 | (linear1): Linear(in_features=128, out_features=256, bias=True) 67 | (dropout): Dropout(p=0.1, inplace=False) 68 | (linear2): Linear(in_features=256, out_features=128, bias=True) 69 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 70 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 71 | (dropout1): Dropout(p=0.1, inplace=False) 72 | (dropout2): Dropout(p=0.1, inplace=False) 73 | (self_attn_obj): MultiheadAttention( 74 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 75 | ) 76 | (self_attn_time): MultiheadAttention( 77 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 78 | ) 79 | ) 80 | ) 81 | (pe): PositionalEncoding( 82 | (dropout): Dropout(p=0.1, inplace=False) 83 | ) 84 | ) -------------------------------------------------------------------------------- /experiments/MOViA/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 1013445 2 | 3 | Params: 34177 4 | CoordInit( 5 | (coord_encoder): Sequential( 6 | (0): Linear(in_features=4, out_features=256, bias=True) 7 | (1): ReLU() 8 | (2): Linear(in_features=256, out_features=128, bias=True) 9 | ) 10 | ) 11 | 12 | Params: 79328 13 | SimpleConvEncoder( 14 | (encoder): Sequential( 15 | (0): 
ConvBlock( 16 | (block): Sequential( 17 | (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 18 | (1): ReLU() 19 | ) 20 | ) 21 | (1): ConvBlock( 22 | (block): Sequential( 23 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 24 | (1): ReLU() 25 | ) 26 | ) 27 | (2): ConvBlock( 28 | (block): Sequential( 29 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 30 | (1): ReLU() 31 | ) 32 | ) 33 | (3): ConvBlock( 34 | (block): Sequential( 35 | (0): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 36 | (1): ReLU() 37 | ) 38 | ) 39 | ) 40 | ) 41 | 42 | Params: 160 43 | SoftPositionEmbed( 44 | (projection): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1)) 45 | ) 46 | 47 | Params: 20800 48 | Sequential( 49 | (0): LayerNorm((32,), eps=1e-05, elementwise_affine=True) 50 | (1): Linear(in_features=32, out_features=128, bias=True) 51 | (2): ReLU() 52 | (3): Linear(in_features=128, out_features=128, bias=True) 53 | ) 54 | 55 | Params: 148480 56 | TransformerBlock( 57 | (attn): MultiHeadSelfAttention( 58 | (q): Linear(in_features=128, out_features=128, bias=False) 59 | (k): Linear(in_features=128, out_features=128, bias=False) 60 | (v): Linear(in_features=128, out_features=128, bias=False) 61 | (drop): Dropout(p=0.0, inplace=False) 62 | (out_projection): Sequential( 63 | (0): Linear(in_features=128, out_features=128, bias=False) 64 | ) 65 | ) 66 | (mlp): Sequential( 67 | (0): Linear(in_features=128, out_features=256, bias=True) 68 | (1): ReLU() 69 | (2): Linear(in_features=256, out_features=128, bias=True) 70 | ) 71 | (layernorm_query): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 72 | (layernorm_mlp): LayerNorm((128,), eps=1e-06, elementwise_affine=True) 73 | (dense_o): Linear(in_features=128, out_features=128, bias=True) 74 | ) 75 | 76 | Params: 640 77 | SoftPositionEmbed( 78 | (projection): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1)) 79 | ) 80 | 81 | Params: 514564 82 | Decoder( 83 | (decoder): Sequential( 84 | (0): ConvBlock( 85 | (block): Sequential( 86 | (0): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 87 | (1): ReLU() 88 | ) 89 | ) 90 | (1): Upsample(scale_factor=2) 91 | (2): ConvBlock( 92 | (block): Sequential( 93 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 94 | (1): ReLU() 95 | ) 96 | ) 97 | (3): Upsample(scale_factor=2) 98 | (4): ConvBlock( 99 | (block): Sequential( 100 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 101 | (1): ReLU() 102 | ) 103 | ) 104 | (5): Upsample(scale_factor=2) 105 | (6): ConvBlock( 106 | (block): Sequential( 107 | (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 108 | (1): ReLU() 109 | ) 110 | ) 111 | (7): Conv2d(64, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 112 | ) 113 | ) 114 | 115 | Params: 215296 116 | SlotAttention( 117 | (norm_input): LayerNorm((128,), eps=0.001, elementwise_affine=True) 118 | (norm_slot): LayerNorm((128,), eps=0.001, elementwise_affine=True) 119 | (norm_mlp): LayerNorm((128,), eps=0.001, elementwise_affine=True) 120 | (to_q): Linear(in_features=128, out_features=128, bias=True) 121 | (to_k): Linear(in_features=128, out_features=128, bias=True) 122 | (to_v): Linear(in_features=128, out_features=128, bias=True) 123 | (gru): GRUCell(128, 128) 124 | (mlp): Sequential( 125 | (0): Linear(in_features=128, out_features=256, bias=True) 126 | (1): ReLU() 127 | (2): Linear(in_features=256, out_features=128, bias=True) 128 | ) 129 | ) 130 | 
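The per-module parameter counts reported in the model_architecture.txt summaries above (e.g., "Params: 160" for a SoftPositionEmbed whose projection is a 1x1 Conv2d from 4 to 32 channels) can be reproduced by summing the number of elements of every parameter tensor in a module. The snippet below is a minimal sketch of that bookkeeping, not the repository's own logging code; the helper name `count_params` is ours.

```
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Sum of the number of elements of all parameter tensors in a module."""
    return sum(p.numel() for p in module.parameters())

# 1x1 convolution used by SoftPositionEmbed to project the 4-channel
# position grid onto the 32-channel encoder feature maps (see summary above)
proj = nn.Conv2d(4, 32, kernel_size=1)
print(count_params(proj))  # 4 * 32 weights + 32 biases = 160
```

Applying the same helper to a full model would yield the "Total Params" figure listed at the top of each summary.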
-------------------------------------------------------------------------------- /src/lib/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Methods to manage parameters and configurations 3 | TODO: 4 | - Add support to change any values from the command line 5 | """ 6 | 7 | import os 8 | import json 9 | 10 | from lib.logger import print_ 11 | from lib.utils import timestamp 12 | from CONFIG import DEFAULTS, CONFIG 13 | 14 | 15 | class Config(dict): 16 | """ 17 | """ 18 | _default_values = DEFAULTS 19 | _help = "Potentially you can add here comments for what your configs are" 20 | _config_groups = ["dataset", "model", "training", "loss"] 21 | 22 | def __init__(self, exp_path): 23 | """ 24 | Populating the dictionary with the default values 25 | """ 26 | for key in self._default_values.keys(): 27 | self[key] = self._default_values[key] 28 | self["_general"] = {} 29 | self["_general"]["exp_path"] = exp_path 30 | return 31 | 32 | def create_exp_config_file(self, exp_path=None, config=None): 33 | """ 34 | Creating a JSON file with exp configs in the experiment path 35 | """ 36 | exp_path = exp_path if exp_path is not None else self["_general"]["exp_path"] 37 | if not os.path.exists(exp_path): 38 | raise FileNotFoundError(f"ERROR!: exp_path {exp_path} does not exist...") 39 | 40 | if config is not None: 41 | config_file = os.path.join(CONFIG["paths"]["configs_path"], config) 42 | if not os.path.exists(config_file): 43 | raise FileNotFoundError(f"Given config file {config_file} does not exist...") 44 | 45 | with open(config_file) as file: 46 | self = json.load(file) 47 | self["_general"] = {} 48 | self["_general"]["exp_path"] = exp_path 49 | print_(f"Creating experiment parameters file from config {config}...") 50 | 51 | self["_general"]["created_time"] = timestamp() 52 | self["_general"]["last_loaded"] = timestamp() 53 | exp_config = os.path.join(exp_path, "experiment_params.json") 54 | with open(exp_config, "w") as file: 55 | json.dump(self, file) 56 | return 57 | 58 | def load_exp_config_file(self, exp_path=None): 59 | """ 60 | Loading the JSON file with exp configs 61 | """ 62 | if exp_path is not None: 63 | self["_general"]["exp_path"] = exp_path 64 | exp_config = os.path.join(self["_general"]["exp_path"], "experiment_params.json") 65 | if not os.path.exists(exp_config): 66 | raise FileNotFoundError(f"ERROR! exp. configs file {exp_config} does not exist...") 67 | 68 | with open(exp_config) as file: 69 | self = json.load(file) 70 | self["_general"]["last_loaded"] = timestamp() 71 | return self 72 | 73 | def update_config(self, exp_params): 74 | """ 75 | Updating an experiments parameters file with newly added configurations from CONFIG. 
76 | """ 77 | # TODO: Add recursion to make it always work 78 | for group in Config._config_groups: 79 | if not isinstance(Config._default_values[group], dict): 80 | continue 81 | for k in Config._default_values[group].keys(): 82 | if(k not in exp_params[group]): 83 | if(isinstance(Config._default_values[group][k], (dict))): 84 | exp_params[group][k] = {} 85 | else: 86 | exp_params[group][k] = Config._default_values[group][k] 87 | 88 | if(isinstance(Config._default_values[group][k], dict)): 89 | for q in Config._default_values[group][k].keys(): 90 | if(q not in exp_params[group][k]): 91 | exp_params[group][k][q] = Config._default_values[group][k][q] 92 | return exp_params 93 | 94 | def save_exp_config_file(self, exp_path=None, exp_params=None): 95 | """ 96 | Dumping experiment parameters into path 97 | """ 98 | exp_path = self["_general"]["exp_path"] if exp_path is None else exp_path 99 | exp_params = self if exp_params is None else exp_params 100 | 101 | exp_config = os.path.join(exp_path, "experiment_params.json") 102 | with open(exp_config, "w") as file: 103 | json.dump(exp_params, file) 104 | return 105 | 106 | # 107 | -------------------------------------------------------------------------------- /src/06_generate_figs_savi.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generating figures using a pretrained SAVI model 3 | """ 4 | 5 | import os 6 | 7 | import matplotlib.pyplot as plt 8 | import torch 9 | 10 | from base.baseFigGenerator import BaseFigGenerator 11 | 12 | from data import unwrap_batch_data 13 | from lib.arguments import get_generate_figs_savi 14 | from lib.logger import print_ 15 | import lib.utils as utils 16 | from lib.visualizations import visualize_recons, visualize_decomp 17 | 18 | 19 | class FigGenerator(BaseFigGenerator): 20 | """ 21 | Class for generating figures using a pretrained SAVI model 22 | """ 23 | 24 | def __init__(self, exp_path, savi_model, num_seqs=10): 25 | """ 26 | Initializing the figure generation module 27 | """ 28 | super().__init__( 29 | exp_path=exp_path, 30 | savi_model=savi_model, 31 | num_seqs=num_seqs 32 | ) 33 | 34 | model_name = savi_model.split('.')[0] 35 | self.plots_path = os.path.join( 36 | self.exp_path, 37 | "plots", 38 | f"figGeneration_SaVIModel_{model_name}" 39 | ) 40 | self.models_path = os.path.join(self.exp_path, "models") 41 | utils.create_directory(self.plots_path) 42 | return 43 | 44 | @torch.no_grad() 45 | def compute_visualization(self, batch_data, img_idx): 46 | """ 47 | Computing visualization 48 | """ 49 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data) 50 | videos, targets = videos.to(self.device), targets.to(self.device) 51 | out_model = self.model(videos, num_imgs=videos.shape[1], **initializer_kwargs) 52 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model 53 | 54 | cur_dir = f"sequence_{img_idx:02d}" 55 | utils.create_directory(os.path.join(self.plots_path, cur_dir)) 56 | 57 | N = min(10, videos.shape[1]) 58 | savepath = os.path.join(self.plots_path, cur_dir, f"Recons_{img_idx+1}.png") 59 | visualize_recons( 60 | imgs=videos[0, :N].clamp(0, 1), 61 | recons=reconstruction_history[0, :N].clamp(0, 1), 62 | n_cols=10, 63 | savepath=savepath 64 | ) 65 | 66 | savepath = os.path.join(self.plots_path, cur_dir, f"ReconsTargets_{img_idx+1}.png") 67 | visualize_recons( 68 | imgs=targets[0, :N].clamp(0, 1), 69 | recons=reconstruction_history[0, :N].clamp(0, 1), 70 | n_cols=10, 71 | savepath=savepath 
72 | ) 73 | 74 | savepath = os.path.join(self.plots_path, cur_dir, f"Objects_{img_idx+1}.png") 75 | fig, _, _ = visualize_decomp( 76 | individual_recons_history[0, :N], 77 | savepath=savepath, 78 | vmin=0, 79 | vmax=1, 80 | ) 81 | plt.close(fig) 82 | 83 | savepath = os.path.join(self.plots_path, cur_dir, f"masks_{img_idx+1}.png") 84 | fig, _, _ = visualize_decomp( 85 | masks_history[0][:N], 86 | savepath=savepath, 87 | cmap="gray_r", 88 | vmin=0, 89 | vmax=1, 90 | ) 91 | plt.close(fig) 92 | savepath = os.path.join(self.plots_path, cur_dir, f"maskedObj_{img_idx+1}.png") 93 | recon_combined = masks_history[0][:N] * individual_recons_history[0][:N] 94 | recon_combined = torch.clamp(recon_combined, min=0, max=1) 95 | fig, _, _ = visualize_decomp( 96 | recon_combined, 97 | savepath=savepath, 98 | vmin=0, 99 | vmax=1, 100 | ) 101 | plt.close(fig) 102 | return 103 | 104 | 105 | if __name__ == "__main__": 106 | utils.clear_cmd() 107 | exp_path, args = get_generate_figs_savi() 108 | print_("Generating figures for SAVI...") 109 | figGenerator = FigGenerator( 110 | exp_path=exp_path, 111 | savi_model=args.savi_model, 112 | num_seqs=args.num_seqs 113 | ) 114 | print_("Loading dataset...") 115 | figGenerator.load_data() 116 | print_("Setting up model and loading pretrained parameters") 117 | figGenerator.load_model() 118 | print_("Generating and saving figures") 119 | figGenerator.generate_figs() 120 | 121 | 122 | # 123 | -------------------------------------------------------------------------------- /src/lib/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utils for creating a Logger that writes info, warnings and so on into a logs file 3 | in the experiment directory. Each experiment has its own independent logs 4 | """ 5 | 6 | import os 7 | import traceback 8 | from datetime import datetime 9 | 10 | LOGGER = None 11 | 12 | 13 | def log_function(func): 14 | """ 15 | Decorator for logging a method in case of raising an exception 16 | """ 17 | def try_call_log(*args, **kwargs): 18 | """ 19 | Calling the function but calling the logger in case an exception is raised 20 | """ 21 | try: 22 | if(LOGGER is not None): 23 | message = f"Calling: {func.__name__}..." 
24 | LOGGER.log_info(message=message, message_type="info") 25 | return func(*args, **kwargs) 26 | except Exception as e: 27 | if(LOGGER is None): 28 | raise e 29 | message = traceback.format_exc() 30 | print_(message, message_type="error") 31 | exit() 32 | return try_call_log 33 | 34 | 35 | def for_all_methods(decorator): 36 | """ 37 | Decorator that applies a decorator to all methods inside a class 38 | """ 39 | def decorate(cls): 40 | for attr in cls.__dict__: # there's propably a better way to do this 41 | if callable(getattr(cls, attr)): 42 | setattr(cls, attr, decorator(getattr(cls, attr))) 43 | return cls 44 | return decorate 45 | 46 | 47 | def print_(message, message_type="info"): 48 | """ 49 | Overloads the print method so that the message is written both in logs file and console 50 | """ 51 | 52 | print(message) 53 | if(LOGGER is not None): 54 | LOGGER.log_info(message, message_type) 55 | return 56 | 57 | 58 | def log_info(message, message_type="info"): 59 | if(LOGGER is not None): 60 | LOGGER.log_info(message, message_type) 61 | return 62 | 63 | 64 | class Logger(): 65 | """ 66 | Class that instanciates a Logger object to write logs into a file 67 | 68 | Args: 69 | ----- 70 | exp_path: string 71 | path to the root directory of an experiment where the logs are saved 72 | file_name: string 73 | name of the file where logs are stored 74 | """ 75 | 76 | def __init__(self, exp_path, file_name="logs.txt"): 77 | """ 78 | Initializer of the logger object 79 | """ 80 | 81 | logs_path = os.path.join(exp_path, file_name) 82 | self.logs_path = logs_path 83 | 84 | if not os.path.exists(logs_path): 85 | if(not os.path.exists(exp_path)): 86 | os.makedirs(exp_path) 87 | with open(logs_path, 'w') as f: 88 | f.write("") 89 | 90 | global LOGGER 91 | LOGGER = self 92 | return 93 | 94 | def log_info(self, message, message_type="info", **kwargs): 95 | """ 96 | Logging a message into the file 97 | """ 98 | 99 | if(message_type not in ["new_exp", "info", "warning", "error", "params"]): 100 | message_type = "info" 101 | cur_time = self._get_datetime() 102 | format_message = self._format_message(message=message, cur_time=cur_time, 103 | message_type=message_type) 104 | with open(self.logs_path, 'a') as f: 105 | f.write(format_message) 106 | 107 | if(message_type == "error"): 108 | exit() 109 | 110 | return 111 | 112 | def log_params(self, params): 113 | """ 114 | Logging parameters so that it is visually appealing 115 | Args: 116 | ----- 117 | params: dictionary 118 | dictionary containing parameters and values 119 | """ 120 | 121 | for param, value in params.items(): 122 | message = f" {param}:{value}" 123 | self.log_info(message, message_type="params") 124 | 125 | return 126 | 127 | def _format_message(self, message, cur_time, message_type="info"): 128 | """ 129 | Formatting the message to have a standarizied template 130 | """ 131 | pre_string = "" 132 | if(message_type == "new_exp"): 133 | pre_string = "\n\n\n" 134 | form_message = f"{pre_string}{cur_time} {message_type.upper()}: {message}\n" 135 | return form_message 136 | 137 | def _get_datetime(self): 138 | """ 139 | Obtaining current data and time in format YYYY-MM-DD-HH-MM-SS 140 | """ 141 | time = datetime.today().strftime('%Y-%m-%d-%H:%M:%S') 142 | return time 143 | 144 | # 145 | -------------------------------------------------------------------------------- /src/data/load_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Methods for loading specific datasets, fitting data loaders and other 3 
| """ 4 | 5 | # from torchvision import datasets 6 | from torch.utils.data import DataLoader 7 | from data import OBJ3D, MOVI 8 | from CONFIG import CONFIG, DATASETS 9 | 10 | 11 | def load_data(exp_params, split="train"): 12 | """ 13 | Loading a dataset given the parameters 14 | 15 | Args: 16 | ----- 17 | dataset_name: string 18 | name of the dataset to load 19 | split: string 20 | Split from the dataset to obtain (e.g., 'train' or 'test') 21 | 22 | Returns: 23 | -------- 24 | dataset: torch dataset 25 | Dataset loaded given specifications from exp_params 26 | """ 27 | dataset_name = exp_params["dataset"]["dataset_name"] 28 | 29 | if dataset_name == "OBJ3D": 30 | dataset = OBJ3D( 31 | mode=split, 32 | sample_length=exp_params["training_prediction"]["sample_length"] 33 | ) 34 | elif dataset_name == "MoviA": 35 | dataset = MOVI( 36 | datapath="/home/nfs/inf6/data/datasets/MOVi/movi_a", 37 | target=exp_params["dataset"].get("target", "rgb"), 38 | split=split, 39 | num_frames=exp_params["training_prediction"]["sample_length"], 40 | img_size=(64, 64), 41 | random_start=exp_params["dataset"].get("random_start", False), 42 | slot_initializer=exp_params["model"]["SAVi"].get("initializer", "LearnedRandom") 43 | ) 44 | elif dataset_name == "MoviC": 45 | dataset = MOVI( 46 | datapath="/home/nfs/inf6/data/datasets/MOVi/movi_c", 47 | target=exp_params["dataset"].get("target", "rgb"), 48 | split=split, 49 | num_frames=exp_params["training_prediction"]["sample_length"], 50 | img_size=(64, 64), 51 | random_start=exp_params["dataset"].get("random_start", False), 52 | slot_initializer=exp_params["model"]["SAVi"].get("initializer", "LearnedRandom") 53 | ) 54 | else: 55 | raise NotImplementedError( 56 | f"""ERROR! Dataset'{dataset_name}' is not available. 57 | Please use one of the following: {DATASETS}...""" 58 | ) 59 | 60 | return dataset 61 | 62 | 63 | def build_data_loader(dataset, batch_size=8, shuffle=False): 64 | """ 65 | Fitting a data loader for the given dataset 66 | 67 | Args: 68 | ----- 69 | dataset: torch dataset 70 | Dataset (or dataset split) to fit to the DataLoader 71 | batch_size: integer 72 | number of elements per mini-batch 73 | shuffle: boolean 74 | If True, mini-batches are sampled randomly from the database 75 | """ 76 | 77 | data_loader = DataLoader( 78 | dataset=dataset, 79 | batch_size=batch_size, 80 | shuffle=shuffle, 81 | num_workers=CONFIG["num_workers"] 82 | ) 83 | 84 | return data_loader 85 | 86 | 87 | def unwrap_batch_data(exp_params, batch_data): 88 | """ 89 | Unwrapping the batch data depending on the dataset that we are training on 90 | """ 91 | initializer_kwargs = {} 92 | if exp_params["dataset"]["dataset_name"] in ["OBJ3D"]: 93 | videos, targets, _ = batch_data 94 | elif exp_params["dataset"]["dataset_name"] in ["MoviA", "MoviC"]: 95 | videos, targets, all_reps = batch_data 96 | initializer_kwargs["instance_masks"] = all_reps["masks"] 97 | initializer_kwargs["com_coords"] = all_reps["com_coords"] 98 | initializer_kwargs["bbox_coords"] = all_reps["bbox_coords"] 99 | else: 100 | dataset_name = exp_params["dataset"]["dataset_name"] 101 | raise NotImplementedError(f"Dataset {dataset_name} is not supported...") 102 | return videos, targets, initializer_kwargs 103 | 104 | 105 | def unwrap_batch_data_masks(exp_params, batch_data): 106 | """ 107 | Unwrapping the batch data for a mask-based evaluation depending on the dataset that 108 | we are currently evaluating on 109 | """ 110 | dbs = ["MoviA", "MoviC"] 111 | dataset_name = exp_params["dataset"]["dataset_name"] 112 | 
initializer_kwargs = {} 113 | if dataset_name in ["MoviA", "MoviC"]: 114 | videos, _, all_reps = batch_data 115 | masks = all_reps["masks"] 116 | initializer_kwargs["instance_masks"] = all_reps["masks"] 117 | initializer_kwargs["com_coords"] = all_reps["com_coords"] 118 | initializer_kwargs["bbox_coords"] = all_reps["bbox_coords"] 119 | else: 120 | raise ValueError(f"Only {dbs} support object-based mask evaluation. Given {dataset_name = }") 121 | return videos, masks, initializer_kwargs 122 | 123 | 124 | # 125 | -------------------------------------------------------------------------------- /src/base/baseEvaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base evaluator from which all backbone evaluator modules inherit. 3 | 4 | Basically it removes the scaffolding that is repeat across all evaluation modules 5 | """ 6 | 7 | import os 8 | from tqdm import tqdm 9 | import torch 10 | 11 | from lib.config import Config 12 | from lib.logger import log_function, for_all_methods 13 | from lib.metrics import MetricTracker 14 | import lib.setup_model as setup_model 15 | import lib.utils as utils 16 | import data 17 | 18 | 19 | @for_all_methods(log_function) 20 | class BaseEvaluator: 21 | """ 22 | Base Class for evaluating a model 23 | 24 | Args: 25 | ----- 26 | exp_path: string 27 | Path to the experiment directory from which to read the experiment parameters, 28 | and where to store logs, plots and checkpoints 29 | checkpoint: string/None 30 | Name of a model checkpoint to evaluate. 31 | It must be stored in the models/ directory of the experiment directory. 32 | """ 33 | 34 | def __init__(self, exp_path, checkpoint): 35 | """ 36 | Initializing the trainer object 37 | """ 38 | self.exp_path = exp_path 39 | self.cfg = Config(exp_path) 40 | self.exp_params = self.cfg.load_exp_config_file() 41 | self.checkpoint = checkpoint 42 | model_name = checkpoint.split(".")[0] 43 | self.results_name = f"{model_name}" 44 | 45 | self.plots_path = os.path.join(self.exp_path, "plots") 46 | utils.create_directory(self.plots_path) 47 | self.models_path = os.path.join(self.exp_path, "models") 48 | utils.create_directory(self.models_path) 49 | return 50 | 51 | def set_metric_tracker(self): 52 | """ 53 | Initializing the metric tracker with evaluation metrics to track 54 | """ 55 | self.metric_tracker = MetricTracker( 56 | self.exp_path, 57 | metrics=["segmentation_ari"] 58 | ) 59 | 60 | def load_data(self): 61 | """ 62 | Loading test-set and fitting data-loader for iterating in a batch-like fashion 63 | """ 64 | batch_size = 1 # self.exp_params["training"]["batch_size"] 65 | shuffle_eval = self.exp_params["dataset"]["shuffle_eval"] 66 | self.test_set = data.load_data( 67 | exp_params=self.exp_params, 68 | split="test" 69 | ) 70 | self.test_loader = data.build_data_loader( 71 | dataset=self.test_set, 72 | batch_size=batch_size, 73 | shuffle=shuffle_eval 74 | ) 75 | return 76 | 77 | def setup_model(self): 78 | """ 79 | Initializing model and loading pretrained parameters given checkpoint 80 | """ 81 | torch.backends.cudnn.fastest = True 82 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 83 | 84 | # loading model 85 | self.model = setup_model.setup_model(model_params=self.exp_params["model"]) 86 | self.model = self.model.eval().to(self.device) 87 | 88 | # loading pretrained paramters 89 | checkpoint_path = os.path.join(self.models_path, self.checkpoint) 90 | self.model = setup_model.load_checkpoint( 91 | checkpoint_path=checkpoint_path, 
92 | model=self.model, 93 | only_model=True 94 | ) 95 | self.set_metric_tracker() 96 | return 97 | 98 | @torch.no_grad() 99 | def evaluate(self, save_results=True): 100 | """ 101 | Evaluating model 102 | """ 103 | self.model = self.model.eval() 104 | progress_bar = tqdm(enumerate(self.test_loader), total=len(self.test_loader)) 105 | 106 | # iterating test set and accumulating the results 107 | for i, batch_data in progress_bar: 108 | self.forward_eval(batch_data=batch_data) 109 | progress_bar.set_description(f"Iter {i}/{len(self.test_loader)}") 110 | 111 | # computing average results and saving to results file 112 | self.metric_tracker.aggregate() 113 | self.results = self.metric_tracker.summary() 114 | if save_results: 115 | self.metric_tracker.save_results(exp_path=self.exp_path, fname=self.results_name) 116 | return 117 | 118 | def forward_eval(self, batch_data, **kwargs): 119 | """ 120 | Making a forwad pass through the model and computing the evaluation metrics 121 | 122 | Args: 123 | ----- 124 | batch_data: dict 125 | Dictionary containing the information for the current batch, including images, poses, 126 | actions, or metadata, among others. 127 | 128 | Returns: 129 | -------- 130 | pred_data: dict 131 | Predictions from the model for the current batch of data 132 | """ 133 | raise NotImplementedError("Base Evaluator Module does not implement 'forward_eval'...") 134 | 135 | # 136 | -------------------------------------------------------------------------------- /src/base/baseFigGenerator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base module for generating images from a pretrained model. 3 | All other Figure Generator modules inherit from here. 4 | 5 | Basically it removes the scaffolding that is repeat across all Figure Generation modules 6 | """ 7 | 8 | import os 9 | from tqdm import tqdm 10 | import torch 11 | 12 | from lib.config import Config 13 | from lib.logger import log_function, for_all_methods 14 | import lib.setup_model as setup_model 15 | import lib.utils as utils 16 | import data 17 | from models.model_utils import freeze_params 18 | 19 | 20 | @for_all_methods(log_function) 21 | class BaseFigGenerator: 22 | """ 23 | Base Class for figure generation 24 | 25 | Args: 26 | ----- 27 | exp_path: string 28 | Path to the experiment directory from which to read the experiment parameters, 29 | and where to store logs, plots and checkpoints 30 | savi_model: string/None 31 | Name of SAVI model checkpoint to use when generating figures. 32 | It must be stored in the models/ directory of the experiment directory. 
33 | num_seqs: int 34 | Number of sequences to process and save 35 | """ 36 | 37 | def __init__(self, exp_path, savi_model, num_seqs=10): 38 | """ 39 | Initializing the figure generator object 40 | """ 41 | self.exp_path = os.path.join(exp_path) 42 | self.cfg = Config(self.exp_path) 43 | self.exp_params = self.cfg.load_exp_config_file() 44 | self.savi_model = savi_model 45 | self.num_seqs = num_seqs 46 | 47 | model_name = savi_model.split('.')[0] 48 | self.plots_path = os.path.join( 49 | self.exp_path, 50 | "plots", 51 | f"figGeneration_SaVIModel_{model_name}" 52 | ) 53 | self.models_path = os.path.join(self.exp_path, "models") 54 | utils.create_directory(self.models_path) 55 | return 56 | 57 | def load_data(self): 58 | """ 59 | Loading dataset and fitting data-loader for iterating in a batch-like fashion 60 | """ 61 | batch_size = 1 62 | shuffle_eval = self.exp_params["dataset"]["shuffle_eval"] 63 | # test_set = data.load_data(exp_params=self.exp_params, split="test") 64 | test_set = data.load_data(exp_params=self.exp_params, split="valid") 65 | self.test_set = test_set 66 | self.test_loader = data.build_data_loader( 67 | dataset=test_set, 68 | batch_size=batch_size, 69 | shuffle=shuffle_eval 70 | ) 71 | return 72 | 73 | def load_model(self, exp_path=None): 74 | """ 75 | Load pretraiened SAVi model from checkpoint 76 | 77 | Args: 78 | ----- 79 | exp_path: sting/None 80 | If None, 'self.exp_path' is used. Otherwise, the given path is used to load 81 | the pretrained model paramters 82 | """ 83 | # to use the same function for SAVi and Predictor figure generation 84 | exp_path = exp_path if exp_path is not None else self.exp_path 85 | 86 | torch.backends.cudnn.fastest = True 87 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 88 | 89 | # loading model 90 | self.model = setup_model.setup_model(model_params=self.exp_params["model"]) 91 | self.model = self.model.eval().to(self.device) 92 | 93 | checkpoint_path = os.path.join(exp_path, "models", self.savi_model) 94 | self.model = setup_model.load_checkpoint( 95 | checkpoint_path=checkpoint_path, 96 | model=self.model, 97 | only_model=True 98 | ) 99 | freeze_params(self.model) 100 | return 101 | 102 | def load_predictor(self): 103 | """ 104 | Load pretrained predictor model from the corresponding model checkpoint 105 | """ 106 | # loading model 107 | predictor = setup_model.setup_predictor(exp_params=self.exp_params) 108 | predictor = predictor.eval().to(self.device) 109 | 110 | # loading pretrained predictor 111 | predictor = setup_model.load_checkpoint( 112 | checkpoint_path=os.path.join(self.models_path, self.checkpoint), 113 | model=predictor, 114 | only_model=True, 115 | ) 116 | self.predictor = predictor 117 | return 118 | 119 | @torch.no_grad() 120 | def generate_figs(self): 121 | """ 122 | Computing and saving visualizations 123 | """ 124 | progress_bar = tqdm(enumerate(self.test_loader), total=self.num_seqs) 125 | for i, batch_data in progress_bar: 126 | if i >= self.num_seqs: 127 | break 128 | self.compute_visualization(batch_data=batch_data, img_idx=i) 129 | return 130 | 131 | def compute_visualization(self, batch_data, img_idx, **kwargs): 132 | """ 133 | Making a forwad pass through the model and computing the evaluation metrics 134 | 135 | Args: 136 | ----- 137 | batch_data: dict 138 | Dictionary containing the information for the current batch, including images, poses, 139 | actions, or metadata, among others. 
140 | img_idx: int 141 | Index of the visualization to compute and save 142 | """ 143 | raise NotImplementedError("Base FigGenerator does not implement 'compute_visualization'...") 144 | 145 | # 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Object-Centric Video Prediction via Decoupling
of Object Dynamics and Interactions 2 | 3 | 4 |

5 | 6 |      7 | 8 |

9 | 10 | ![](assets/gif1_gt.gif) 11 | ![](assets/gif1_pred.gif) 12 | ![](assets/gif1_seg.gif) 13 | ![](assets/gif2_gt.gif) 14 | ![](assets/gif2_pred.gif) 15 | ![](assets/gif2_seg.gif) 16 | 17 | 18 | 19 | Official implementation of: **Object-Centric Video Prediction via Decoupling of Object Dynamics and Interactions** by Villar-Corrales et al. ICIP 2023. [[Paper](http://www.angelvillarcorrales.com/templates/others/Publications/2023_ObjectCentricVideoPrediction_ICIP.pdf)] [[Project Page](https://sites.google.com/view/ocvp-vp)] 20 | 21 | 22 | 23 | ## Installation 24 | 25 | We refer to [docs/INSTALL.md](https://github.com/AIS-Bonn/OCVP-object-centric-video-prediction/blob/master/assets/docs/INSTALL.md) for detailed installation and preparation instructions. 26 | 27 | 28 | 29 | ## Training 30 | 31 | We refer to [docs/TRAIN.md](https://github.com/AIS-Bonn/OCVP-object-centric-video-prediction/blob/master/assets/docs/TRAIN.md) for detailed instructions for training your own Object-Centric Video Decomposition model. Additionally, we report the required training time for both the SAVi scene decomposition model and the OCVP-Seq predictor module. 32 | 33 | 34 | 35 | ## Evaluation and Figure Generation 36 | 37 | To reproduce the results provided in our paper, you can download our pretrained models, including checkpoints for the SAVi decomposition and prediction modules, by running the `download_pretrained` bash script: 38 | 39 | ``` 40 | chmod +x download_pretrained.sh 41 | ./download_pretrained.sh 42 | ``` 43 | 44 | ### Evaluate SAVi for Image Decomposition 45 | 46 | You can evaluate a SAVi video decomposition model using the `src/03_evaluate_savi_noMasks.py` and `src/03_evaluate_savi.py` scripts. The former measures the quality of the reconstructed frames, whereas the latter measures the fidelity of the object masks. 47 | 48 | **Example:** 49 | ``` 50 | python src/03_evaluate_savi_noMasks.py \ 51 | -d experiments/MOViA/ \ 52 | --checkpoint savi_movia.pth 53 | 54 | python src/03_evaluate_savi.py \ 55 | -d experiments/MOViA/ \ 56 | --checkpoint savi_movia.pth 57 | ``` 58 | 59 | ### Evaluate Object-Centric Video Prediction 60 | 61 | To evaluate an object-centric video predictor module (i.e., LSTM, Transformer, OCVP-Seq, or OCVP-Par), you can use the `src/05_evaluate_predictor.py` script. 62 | 63 | 64 | ``` 65 | usage: 05_evaluate_predictor.py [-h] -d EXP_DIRECTORY -m SAVI_MODEL --name_predictor_experiment NAME_PREDICTOR_EXPERIMENT --checkpoint CHECKPOINT [--num_preds NUM_PREDS] 66 | 67 | arguments: 68 | -d EXP_DIRECTORY, --exp_directory EXP_DIRECTORY 69 | Path to the father exp. directory 70 | -m SAVI_MODEL, --savi_model SAVI_MODEL 71 | Name of the SAVi checkpoint to load 72 | --name_predictor_experiment NAME_PREDICTOR_EXPERIMENT 73 | Name to the directory inside the exp_directory corresponding to a predictor experiment. 
74 | --checkpoint CHECKPOINT 75 | Checkpoint with predictor pre-trained parameters to load for evaluation 76 | --num_preds NUM_PREDS 77 | Number of rollout frames to predict for 78 | ``` 79 | 80 | **Example 1:** Reproduce LSTM predictor results on the Obj3D dataset: 81 | ``` 82 | python src/05_evaluate_predictor.py \ 83 | -d experiments/Obj3D/ \ 84 | --savi_model savi_obj3d.pth \ 85 | --name_predictor_experiment Predictor_LSTM \ 86 | --checkpoint lstm_obj3d.pth \ 87 | --num_preds 25 88 | ``` 89 | 90 | **Example 2:** Reproduce OCVP-Seq predictor results on the MOVi-A dataset: 91 | ``` 92 | python src/05_evaluate_predictor.py \ 93 | -d experiments/MOViA/ \ 94 | --savi_model savi_movia.pth \ 95 | --name_predictor_experiment Predictor_OCVPSeq \ 96 | --checkpoint OCVPSeq_movia.pth \ 97 | --num_preds 18 98 | ``` 99 | 100 | ### Generate Figures and Animations 101 | 102 | To generate video prediction, object prediction and segmentation figures and animations, you can use the 103 | `src/06_generate_figs_pred.py` script. 104 | 105 | **Example:** 106 | ``` 107 | python src/06_generate_figs_pred.py \ 108 | -d experiments/Obj3D/ \ 109 | --savi_model savi_obj3d.pth \ 110 | --name_predictor_experiment Predictor_OCVPSeq \ 111 | --checkpoint OCVPSeq_obj3d.pth \ 112 | --num_seqs 10 \ 113 | --num_preds 25 114 | ``` 115 | 116 | 117 | ## Acknowledgement 118 | 119 | Our work is inspired by and uses resources from the following repositories: 120 | - [SAVi-pytorch](https://github.com/junkeun-yi/SAVi-pytorch) 121 | - [slot-attention-video](https://github.com/google-research/slot-attention-video/) 122 | - [G-SWM](https://github.com/zhixuan-lin/G-SWM) 123 | 124 | 125 | ## Contact and Citation 126 | 127 | This repository is maintained by [Angel Villar-Corrales](http://angelvillarcorrales.com/templates/home.php). 128 | 129 | 130 | Please consider citing our paper if you find our work or our repository helpful. 131 | 132 | ``` 133 | @inproceedings{villar_ObjectCentricVideoPrediction_2023, 134 | title={Object-Centric Video Prediction via Decoupling of Object Dynamics and Interactions}, 135 | author={Villar-Corrales, Angel and Wahdan, Ismail and Behnke, Sven}, 136 | booktitle={International Conference on Image Processing (ICIP)}, 137 | year={2023} 138 | } 139 | ``` 140 | 141 | In case of any questions or problems regarding the project or repository, do not hesitate to contact the authors at villar@ais.uni-bonn.de. 142 | -------------------------------------------------------------------------------- /src/models/initializers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modules for the initialization of the slots in SlotAttention and SAVi 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from math import sqrt 8 | 9 | from CONFIG import INITIALIZERS 10 | 11 | 12 | ENCODER_RESOLUTION = (8, 14) 13 | 14 | 15 | def get_initalizer(mode, slot_dim, num_slots, encoder_resolution=None): 16 | """ 17 | Fetching the initializer module of the slots 18 | 19 | Args: 20 | ----- 21 | mode: string 22 | Type of initializer to use. Valid modes are {INITIALIZERS} 23 | slot_dim: int 24 | Dimensionality of the slots 25 | num_slots: int 26 | Number of slots to initialize 27 | """ 28 | encoder_resolution = encoder_resolution if encoder_resolution is not None else ENCODER_RESOLUTION 29 | if mode not in INITIALIZERS: 30 | raise ValueError(f"Unknown initializer {mode = }. 
Available modes are {INITIALIZERS}") 31 | 32 | if mode == "Random": 33 | intializer = Random(slot_dim=slot_dim, num_slots=num_slots) 34 | elif mode == "LearnedRandom": 35 | intializer = LearnedRandom(slot_dim=slot_dim, num_slots=num_slots) 36 | elif mode == "Masks": 37 | raise NotImplementedError("'Masks' initialization is not supported...") 38 | elif mode == "CoM": 39 | intializer = CoordInit(slot_dim=slot_dim, num_slots=num_slots, mode="CoM") 40 | elif mode == "BBox": 41 | intializer = CoordInit(slot_dim=slot_dim, num_slots=num_slots, mode="BBox") 42 | else: 43 | raise ValueError(f"UPSI, {mode = } should not have reached here...") 44 | 45 | return intializer 46 | 47 | 48 | class Random(nn.Module): 49 | """ 50 | Gaussian random slot initialization 51 | """ 52 | 53 | def __init__(self, slot_dim, num_slots): 54 | """ 55 | Module intializer 56 | """ 57 | super().__init__() 58 | self.slot_dim = slot_dim 59 | self.num_slots = num_slots 60 | 61 | def forward(self, batch_size, **kwargs): 62 | """ 63 | Sampling random Gaussian slots 64 | """ 65 | slots = torch.randn(batch_size, self.num_slots, self.slot_dim) 66 | return slots 67 | 68 | 69 | class LearnedRandom(nn.Module): 70 | """ 71 | Learned random intialization. This is the default mode used in SlotAttention. 72 | Slots are randomly sampled from a Gaussian distribution. However, the statistics of this 73 | distribution (mean vector and diagonal of covariance) are learned via backpropagation 74 | """ 75 | 76 | def __init__(self, slot_dim, num_slots): 77 | """ Module intializer """ 78 | super().__init__() 79 | self.slot_dim = slot_dim 80 | self.num_slots = num_slots 81 | 82 | self.slots_mu = nn.Parameter(torch.randn(1, 1, slot_dim)) 83 | self.slots_sigma = nn.Parameter(torch.randn(1, 1, slot_dim)) 84 | 85 | with torch.no_grad(): 86 | limit = sqrt(6.0 / (1 + slot_dim)) 87 | torch.nn.init.uniform_(self.slots_mu, -limit, limit) 88 | torch.nn.init.uniform_(self.slots_sigma, -limit, limit) 89 | return 90 | 91 | def forward(self, batch_size, **kwargs): 92 | """ 93 | Sampling random slots from the learned gaussian distribution 94 | """ 95 | mu = self.slots_mu.expand(batch_size, self.num_slots, -1) 96 | sigma = self.slots_sigma.expand(batch_size, self.num_slots, -1) 97 | slots = mu + sigma * torch.randn(mu.shape, device=self.slots_mu.device) 98 | return slots 99 | 100 | 101 | class CoordInit(nn.Module): 102 | """ 103 | Slots are initalized by encoding, for each object, the coordinates of one of the following: 104 | - the CoM of the instance segmentation of each object, represented as [y, x] 105 | - the BBox containing each object, represented as [y_min, x_min, y_max, x_max] 106 | """ 107 | 108 | MODES = ["CoM", "BBox"] 109 | MODE_REP = { 110 | "CoM": "com_coords", 111 | "BBox": "bbox_coords" 112 | } 113 | IN_FEATS = { 114 | "CoM": 2, 115 | "BBox": 4 116 | } 117 | 118 | def __init__(self, slot_dim, num_slots, mode): 119 | """ 120 | Module intializer 121 | """ 122 | assert mode in CoordInit.MODES, f"Unknown {mode = }. 
Use one of {CoordInit.MODES}" 123 | super().__init__() 124 | self.slot_dim = slot_dim 125 | self.num_slots = num_slots 126 | self.mode = mode 127 | self.coord_encoder = nn.Sequential( 128 | nn.Linear(CoordInit.IN_FEATS[self.mode], 256), 129 | nn.ReLU(), 130 | nn.Linear(256, slot_dim), 131 | ) 132 | self.dummy_parameter = nn.Parameter(torch.tensor([0.])) 133 | return 134 | 135 | def forward(self, batch_size, **kwargs): 136 | """ 137 | Encoding BBox or CoM coordinates into slots using an MLP 138 | """ 139 | device = self.dummy_parameter.device 140 | rep_name = CoordInit.MODE_REP[self.mode] 141 | in_feats = CoordInit.IN_FEATS[self.mode] 142 | 143 | coords = kwargs.get(rep_name, None) 144 | if coords is None or coords.sum() == 0: 145 | raise ValueError(f"{self.mode} Initializer requires having '{rep_name}'...") 146 | if len(coords.shape) == 4: # getting only coords corresponding to time-step t=0 147 | coords = coords[:, 0] 148 | coords = coords.to(device) 149 | 150 | # obtaining -1-vectors for filling the slots that currently do not have an object 151 | num_coords = coords.shape[1] 152 | if num_coords > self.num_slots: 153 | raise ValueError(f"There shouldnt be more {num_coords = } than {self.num_slots = }! ") 154 | if num_coords < self.num_slots: 155 | remaining_masks = self.num_slots - num_coords 156 | pad_zeros = -1 * torch.ones((coords.shape[0], remaining_masks, in_feats), device=device) 157 | coords = torch.cat([coords, pad_zeros], dim=2) 158 | 159 | slots = self.coord_encoder(coords) 160 | return slots 161 | 162 | 163 | # 164 | -------------------------------------------------------------------------------- /experiments/MOViA/Predictor_OCVPSeq/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 4282752 2 | 3 | Params: 4282752 4 | OCVTransformerV1Predictor( 5 | (mlp_in): Linear(in_features=128, out_features=256, bias=True) 6 | (mlp_out): Linear(in_features=256, out_features=128, bias=True) 7 | (transformer_encoders): Sequential( 8 | (0): ObjectCentricTransformerLayerV1( 9 | (object_encoder_block): TransformerEncoderLayer( 10 | (self_attn): MultiheadAttention( 11 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 12 | ) 13 | (linear1): Linear(in_features=256, out_features=512, bias=True) 14 | (dropout): Dropout(p=0.1, inplace=False) 15 | (linear2): Linear(in_features=512, out_features=256, bias=True) 16 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 17 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 18 | (dropout1): Dropout(p=0.1, inplace=False) 19 | (dropout2): Dropout(p=0.1, inplace=False) 20 | ) 21 | (time_encoder_block): TransformerEncoderLayer( 22 | (self_attn): MultiheadAttention( 23 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 24 | ) 25 | (linear1): Linear(in_features=256, out_features=512, bias=True) 26 | (dropout): Dropout(p=0.1, inplace=False) 27 | (linear2): Linear(in_features=512, out_features=256, bias=True) 28 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 29 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 30 | (dropout1): Dropout(p=0.1, inplace=False) 31 | (dropout2): Dropout(p=0.1, inplace=False) 32 | ) 33 | ) 34 | (1): ObjectCentricTransformerLayerV1( 35 | (object_encoder_block): TransformerEncoderLayer( 36 | (self_attn): MultiheadAttention( 37 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 38 | ) 39 | 
(linear1): Linear(in_features=256, out_features=512, bias=True) 40 | (dropout): Dropout(p=0.1, inplace=False) 41 | (linear2): Linear(in_features=512, out_features=256, bias=True) 42 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 43 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 44 | (dropout1): Dropout(p=0.1, inplace=False) 45 | (dropout2): Dropout(p=0.1, inplace=False) 46 | ) 47 | (time_encoder_block): TransformerEncoderLayer( 48 | (self_attn): MultiheadAttention( 49 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 50 | ) 51 | (linear1): Linear(in_features=256, out_features=512, bias=True) 52 | (dropout): Dropout(p=0.1, inplace=False) 53 | (linear2): Linear(in_features=512, out_features=256, bias=True) 54 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 55 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 56 | (dropout1): Dropout(p=0.1, inplace=False) 57 | (dropout2): Dropout(p=0.1, inplace=False) 58 | ) 59 | ) 60 | (2): ObjectCentricTransformerLayerV1( 61 | (object_encoder_block): TransformerEncoderLayer( 62 | (self_attn): MultiheadAttention( 63 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 64 | ) 65 | (linear1): Linear(in_features=256, out_features=512, bias=True) 66 | (dropout): Dropout(p=0.1, inplace=False) 67 | (linear2): Linear(in_features=512, out_features=256, bias=True) 68 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 69 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 70 | (dropout1): Dropout(p=0.1, inplace=False) 71 | (dropout2): Dropout(p=0.1, inplace=False) 72 | ) 73 | (time_encoder_block): TransformerEncoderLayer( 74 | (self_attn): MultiheadAttention( 75 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 76 | ) 77 | (linear1): Linear(in_features=256, out_features=512, bias=True) 78 | (dropout): Dropout(p=0.1, inplace=False) 79 | (linear2): Linear(in_features=512, out_features=256, bias=True) 80 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 81 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 82 | (dropout1): Dropout(p=0.1, inplace=False) 83 | (dropout2): Dropout(p=0.1, inplace=False) 84 | ) 85 | ) 86 | (3): ObjectCentricTransformerLayerV1( 87 | (object_encoder_block): TransformerEncoderLayer( 88 | (self_attn): MultiheadAttention( 89 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 90 | ) 91 | (linear1): Linear(in_features=256, out_features=512, bias=True) 92 | (dropout): Dropout(p=0.1, inplace=False) 93 | (linear2): Linear(in_features=512, out_features=256, bias=True) 94 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 95 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 96 | (dropout1): Dropout(p=0.1, inplace=False) 97 | (dropout2): Dropout(p=0.1, inplace=False) 98 | ) 99 | (time_encoder_block): TransformerEncoderLayer( 100 | (self_attn): MultiheadAttention( 101 | (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True) 102 | ) 103 | (linear1): Linear(in_features=256, out_features=512, bias=True) 104 | (dropout): Dropout(p=0.1, inplace=False) 105 | (linear2): Linear(in_features=512, out_features=256, bias=True) 106 | (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 107 | (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True) 108 | (dropout1): Dropout(p=0.1, inplace=False) 109 | 
(dropout2): Dropout(p=0.1, inplace=False) 110 | ) 111 | ) 112 | ) 113 | (pe): PositionalEncoding( 114 | (dropout): Dropout(p=0.1, inplace=False) 115 | ) 116 | ) -------------------------------------------------------------------------------- /experiments/Obj3D/Predictor_OCVPSeq/model_architecture.txt: -------------------------------------------------------------------------------- 1 | Total Params: 1092864 2 | 3 | Params: 1092864 4 | OCVTransformerV1Predictor( 5 | (mlp_in): Linear(in_features=128, out_features=128, bias=True) 6 | (mlp_out): Linear(in_features=128, out_features=128, bias=True) 7 | (transformer_encoders): Sequential( 8 | (0): ObjectCentricTransformerLayerV1( 9 | (object_encoder_block): TransformerEncoderLayer( 10 | (self_attn): MultiheadAttention( 11 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 12 | ) 13 | (linear1): Linear(in_features=128, out_features=256, bias=True) 14 | (dropout): Dropout(p=0.1, inplace=False) 15 | (linear2): Linear(in_features=256, out_features=128, bias=True) 16 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 17 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 18 | (dropout1): Dropout(p=0.1, inplace=False) 19 | (dropout2): Dropout(p=0.1, inplace=False) 20 | ) 21 | (time_encoder_block): TransformerEncoderLayer( 22 | (self_attn): MultiheadAttention( 23 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 24 | ) 25 | (linear1): Linear(in_features=128, out_features=256, bias=True) 26 | (dropout): Dropout(p=0.1, inplace=False) 27 | (linear2): Linear(in_features=256, out_features=128, bias=True) 28 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 29 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 30 | (dropout1): Dropout(p=0.1, inplace=False) 31 | (dropout2): Dropout(p=0.1, inplace=False) 32 | ) 33 | ) 34 | (1): ObjectCentricTransformerLayerV1( 35 | (object_encoder_block): TransformerEncoderLayer( 36 | (self_attn): MultiheadAttention( 37 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 38 | ) 39 | (linear1): Linear(in_features=128, out_features=256, bias=True) 40 | (dropout): Dropout(p=0.1, inplace=False) 41 | (linear2): Linear(in_features=256, out_features=128, bias=True) 42 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 43 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 44 | (dropout1): Dropout(p=0.1, inplace=False) 45 | (dropout2): Dropout(p=0.1, inplace=False) 46 | ) 47 | (time_encoder_block): TransformerEncoderLayer( 48 | (self_attn): MultiheadAttention( 49 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 50 | ) 51 | (linear1): Linear(in_features=128, out_features=256, bias=True) 52 | (dropout): Dropout(p=0.1, inplace=False) 53 | (linear2): Linear(in_features=256, out_features=128, bias=True) 54 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 55 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 56 | (dropout1): Dropout(p=0.1, inplace=False) 57 | (dropout2): Dropout(p=0.1, inplace=False) 58 | ) 59 | ) 60 | (2): ObjectCentricTransformerLayerV1( 61 | (object_encoder_block): TransformerEncoderLayer( 62 | (self_attn): MultiheadAttention( 63 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 64 | ) 65 | (linear1): Linear(in_features=128, out_features=256, bias=True) 66 | (dropout): Dropout(p=0.1, inplace=False) 67 | 
(linear2): Linear(in_features=256, out_features=128, bias=True) 68 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 69 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 70 | (dropout1): Dropout(p=0.1, inplace=False) 71 | (dropout2): Dropout(p=0.1, inplace=False) 72 | ) 73 | (time_encoder_block): TransformerEncoderLayer( 74 | (self_attn): MultiheadAttention( 75 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 76 | ) 77 | (linear1): Linear(in_features=128, out_features=256, bias=True) 78 | (dropout): Dropout(p=0.1, inplace=False) 79 | (linear2): Linear(in_features=256, out_features=128, bias=True) 80 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 81 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 82 | (dropout1): Dropout(p=0.1, inplace=False) 83 | (dropout2): Dropout(p=0.1, inplace=False) 84 | ) 85 | ) 86 | (3): ObjectCentricTransformerLayerV1( 87 | (object_encoder_block): TransformerEncoderLayer( 88 | (self_attn): MultiheadAttention( 89 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 90 | ) 91 | (linear1): Linear(in_features=128, out_features=256, bias=True) 92 | (dropout): Dropout(p=0.1, inplace=False) 93 | (linear2): Linear(in_features=256, out_features=128, bias=True) 94 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 95 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 96 | (dropout1): Dropout(p=0.1, inplace=False) 97 | (dropout2): Dropout(p=0.1, inplace=False) 98 | ) 99 | (time_encoder_block): TransformerEncoderLayer( 100 | (self_attn): MultiheadAttention( 101 | (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True) 102 | ) 103 | (linear1): Linear(in_features=128, out_features=256, bias=True) 104 | (dropout): Dropout(p=0.1, inplace=False) 105 | (linear2): Linear(in_features=256, out_features=128, bias=True) 106 | (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 107 | (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True) 108 | (dropout1): Dropout(p=0.1, inplace=False) 109 | (dropout2): Dropout(p=0.1, inplace=False) 110 | ) 111 | ) 112 | ) 113 | (pe): PositionalEncoding( 114 | (dropout): Dropout(p=0.1, inplace=False) 115 | ) 116 | ) -------------------------------------------------------------------------------- /src/CONFIG.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global configurations 3 | """ 4 | 5 | import os 6 | 7 | CONFIG = { 8 | "random_seed": 14, 9 | "epsilon_min": 1e-16, 10 | "epsilon_max": 1e16, 11 | "num_workers": 8, 12 | "paths": { 13 | "data_path": os.path.join(os.getcwd(), "datasets"), 14 | "experiments_path": os.path.join(os.getcwd(), "experiments"), 15 | "configs_path": os.path.join(os.getcwd(), "src", "configs"), 16 | } 17 | } 18 | 19 | 20 | # Supported datasets, models, metrics, and so on 21 | DATASETS = ["OBJ3D", "MoviA", "MoviC"] 22 | LOSSES = [ 23 | "mse", "l2", # standard losses 24 | "pred_img_mse", # ||pred_img - target_img||^2 when predicting future images 25 | "pred_slot_mse", # ||pred_slot - target_slot||^2 when (also) predicting future slots 26 | ] 27 | METRICS = [ 28 | "segmentation_ari", "IoU", # object-centric metrics 29 | "mse", "psnr", "ssim", "lpips" # video predicition metrics 30 | ] 31 | MODELS = ["SAVi"] 32 | PREDICTORS = ["LSTM", "Transformer", "OCVP-Seq", "OCVP-Par"] 33 | INITIALIZERS = ["Random", "LearnedRandom", "Masks", "CoM", "BBox"] 34 | 35 | DEFAULTS = 
{ 36 | "dataset": { 37 | "dataset_name": "OBJ3D", 38 | "shuffle_train": True, 39 | "shuffle_eval": False, 40 | "use_segmentation": True, 41 | "target": "rgb", 42 | "random_start": True 43 | }, 44 | "model": { 45 | "model_name": "SAVi", 46 | "SAVi": { 47 | "num_slots": 6, 48 | "slot_dim": 64, 49 | "in_channels": 3, 50 | "encoder_type": "ConvEncoder", 51 | "num_channels": (32, 32, 32, 32), 52 | "mlp_encoder_dim": 64, 53 | "mlp_hidden": 128, 54 | "num_channels_decoder": (32, 32, 32, 32), 55 | "kernel_size": 5, 56 | "num_iterations_first": 3, 57 | "num_iterations": 1, 58 | "resolution": (64, 64), 59 | "downsample_encoder": False, 60 | "downsample_decoder": False, 61 | "decoder_resolution": (64, 64), # set it according to size after downsampling if necessary 62 | "upsample": 2, 63 | "use_predictor": True, 64 | "initializer": "LearnedRandom" 65 | }, 66 | "predictor": 67 | { 68 | "predictor_name": "Transformer", 69 | "LSTM": { 70 | "num_cells": 2, 71 | "hidden_dim": 64, 72 | "residual": True, 73 | }, 74 | "Transformer": { 75 | "token_dim": 128, 76 | "hidden_dim": 256, 77 | "num_layers": 2, 78 | "n_heads": 4, 79 | "residual": True, 80 | "input_buffer_size": None 81 | }, 82 | "OCVP-Seq": { 83 | "token_dim": 128, 84 | "hidden_dim": 256, 85 | "num_layers": 2, 86 | "n_heads": 4, 87 | "residual": True, 88 | "input_buffer_size": None 89 | }, 90 | "OCVP-Par": { 91 | "token_dim": 128, 92 | "hidden_dim": 256, 93 | "num_layers": 2, 94 | "n_heads": 4, 95 | "residual": True, 96 | "input_buffer_size": None 97 | } 98 | } 99 | }, 100 | "loss": [ 101 | { 102 | "type": "mse", 103 | "weight": 1 104 | } 105 | ], 106 | "predictor_loss": [ 107 | { 108 | "type": "pred_img_mse", 109 | "weight": 1 110 | }, 111 | { 112 | "type": "pred_slot_mse", 113 | "weight": 1 114 | } 115 | ], 116 | "training_slots": { # training related parameters 117 | "num_epochs": 1000, # number of epochs to train for 118 | "save_frequency": 10, # saving a checkpoint after these iterations () 119 | "log_frequency": 25, # logging stats after this amount of updates 120 | "image_log_frequency": 100, # logging stats after this amount of updates 121 | "batch_size": 64, 122 | "lr": 1e-4, 123 | "optimizer": "adam", # optimizer parameters: name, L2-reg, momentum 124 | "momentum": 0, 125 | "weight_decay": 0, 126 | "nesterov": False, 127 | "scheduler": "", # learning rate scheduler parameters 128 | "lr_factor": 0.8, # Meaning depends on scheduler. 
See lib/model_setup.py 129 | "patience": 10, 130 | "scheduler_steps": 1e6, 131 | "lr_warmup": False, # learning rate warmup parameters (2 epochs or 200 iters default) 132 | "warmup_steps": 2000, 133 | "warmup_epochs": 2, 134 | "gradient_clipping": True, 135 | "clipping_max_value": 0.05 # according to SAVI paper 136 | }, 137 | "training_prediction": { # training related parameters 138 | "num_context": 5, 139 | "num_preds": 5, 140 | "teacher_force": False, 141 | "skip_first_slot": False, # if True, slots from first image are not considered 142 | "num_epochs": 1500, # number of epochs to train for 143 | "train_iters_per_epoch": 1e10, # max number of iterations per epoch 144 | "save_frequency": 10, # saving a checkpoint after these iterations () 145 | "save_frequency_iters": 1000, # saving a checkpoint after these iterations 146 | "log_frequency": 25, # logging stats after this amount of updates 147 | "image_log_frequency": 100, # logging stats after this amount of updates 148 | "batch_size": 64, 149 | "sample_length": 10, 150 | "gradient_clipping": True, 151 | "clipping_max_value": 0.05 # according to SAVI paper 152 | } 153 | } 154 | 155 | 156 | COLORS = ["white", "blue", "green", "olive", "red", "yellow", "purple", "orange", "cyan", 157 | "brown", "pink", "darkorange", "goldenrod", "darkviolet", "springgreen", 158 | "aqua", "royalblue", "navy", "forestgreen", "plum", "magenta", "slategray", 159 | "maroon", "gold", "peachpuff", "silver", "aquamarine", "indianred", "greenyellow", 160 | "darkcyan", "sandybrown"] 161 | 162 | 163 | # 164 | -------------------------------------------------------------------------------- /assets/docs/TRAIN.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | We provide our entire pipeline for training a SAVi model for object-centric video decomposition, as well as for training our object-centric video predictor modules. 4 | 5 | 6 | ## Train SAVi Video Decomposition Model 7 | 8 | **1.** Create a new experiment using the `src/01_create_experiment.py` script. This will create a new experiments folder in the `/experiments` directory. 9 | 10 | ``` 11 | usage 01_create_experiment.py [-h] -d EXP_DIRECTORY [--name NAME] [--config CONFIG] 12 | 13 | optional arguments: 14 | -d EXP_DIRECTORY, --exp_directory EXP_DIRECTORY Directory where the experiment folder will be created 15 | --name NAME Name to give to the experiment 16 | ``` 17 | 18 | 19 | 20 | **2.** Modify the experiment parameters located in `experiments/YOUR_EXP_DIR/YOUR_EXP_NAME/experiment_params.json` to adapt to your dataset and training needs. 21 | You provide two examples for training SAVi on the [Obj3D](https://github.com/AIS-Bonn/OCVP-object-centric-video-prediction/blob/master/experiments/Obj3D/experiment_params.json) and [MOVi-A](https://github.com/AIS-Bonn/OCVP-object-centric-video-prediction/blob/master/experiments/MOViA/experiment_params.json) datasets. 
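The available parameters and their default values are listed in the `DEFAULTS` dictionary of `src/CONFIG.py`. As a rough orientation (an illustrative excerpt of those defaults, not a complete parameter file), the fields you will most often want to adapt are:

```
"dataset":        {"dataset_name": "OBJ3D", ...}
"model":          {"model_name": "SAVi", "SAVi": {"num_slots": 6, "slot_dim": 64, "initializer": "LearnedRandom", ...}}
"training_slots": {"num_epochs": 1000, "batch_size": 64, "lr": 1e-4, ...}
```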
22 | 23 | 24 | 25 | **3.** Train SAVi given the specified experiment parameters: 26 | 27 | ``` 28 | usage: 02_train_savi.py [-h] -d EXP_DIRECTORY [--checkpoint CHECKPOINT] [--resume_training] 29 | 30 | optional arguments: 31 | -d EXP_DIRECTORY, --exp_directory EXP_DIRECTORY 32 | Path to the experiment directory 33 | --checkpoint CHECKPOINT 34 | Checkpoint with pretrained parameters to load 35 | --resume_training For resuming training 36 | ``` 37 | 38 | 39 | #### Example: SAVi Training 40 | 41 | Below we provide an example of how to train a new SAVi model: 42 | 43 | ``` 44 | python src/01_create_experiment.py -d new_exps --name my_exp 45 | python src/02_train_savi.py -d experiments/new_exps/my_exp 46 | ``` 47 | 48 | 49 | ## Train an Object-Centric Video Prediction Model 50 | 51 | Training an object-centric video predictor requires a pretrained SAVi model. You can either use one of our provided pretrained models, or train your own SAVi video decomposition model. 52 | 53 | 54 | **1.** Create a new predictor experiment using the `src/01_create_predictor_experiment.py` script. This will create a new predictor folder in the specified experiment directory. 55 | 56 | ``` 57 | usage: 01_create_predictor_experiment.py [-h] -d EXP_DIRECTORY --name NAME --predictor_name PREDICTOR_NAME 58 | 59 | optional arguments: 60 | -d EXP_DIRECTORY, --exp_directory EXP_DIRECTORY 61 | Directory where the predictor experiment will be created 62 | --name NAME Name to give to the predictor experiment 63 | --predictor_name PREDICTOR_NAME 64 | Name of the predictor module to use: ['LSTM', 'Transformer', 'OCVP-Seq', 'OCVP-Par'] 65 | ``` 66 | 67 | 68 | **2.** Modify the experiment parameters located in `experiments/YOUR_EXP_DIR/YOUR_EXP_NAME/YOUR_PREDICTOR_NAME/experiment_params.json` to adapt the predictor training parameters to your dataset and training needs. 69 | We provide examples for each predictor module on the Obj3D and MOVi-A datasets. For instance: 70 | - [LSTM on Obj3D](https://github.com/AIS-Bonn/OCVP-object-centric-video-prediction/blob/master/experiments/Obj3D/Predictor_LSTM/experiment_params.json) 71 | - [Transformer on Obj3D](https://github.com/AIS-Bonn/OCVP-object-centric-video-prediction/blob/master/experiments/Obj3D/Predictor_Transformer/experiment_params.json) 72 | - [OCVP-Par on MOVi-A](https://github.com/AIS-Bonn/OCVP-object-centric-video-prediction/blob/master/experiments/MOViA/Predictor_OCVPPar/experiment_params.json) 73 | 74 | 75 | **3.** Train your predictor given the specified experiment parameters and a pretrained SAVi model: 76 | 77 | ``` 78 | usage: 04_train_predictor.py [-h] -d EXP_DIRECTORY [--checkpoint CHECKPOINT] [--resume_training] -m SAVI_MODEL --name_predictor_experiment 79 | NAME_PREDICTOR_EXPERIMENT 80 | 81 | optional arguments: 82 | -h, --help show this help message and exit 83 | -d EXP_DIRECTORY, --exp_directory EXP_DIRECTORY 84 | Path to the parent exp. directory 85 | --checkpoint CHECKPOINT 86 | Checkpoint with predictor pretrained parameters to load 87 | --resume_training Resuming training 88 | -m SAVI_MODEL, --savi_model SAVI_MODEL 89 | Path to the SAVi checkpoint to be used during training or validation, from inside the experiments directory 90 | --name_predictor_experiment NAME_PREDICTOR_EXPERIMENT 91 | Name of the directory inside the exp_directory corresponding to a predictor experiment. 92 | ``` 93 | #### Example: Predictor Training 94 | 95 | Below we provide an example of how to train an object-centric predictor given a pretrained SAVi model.
This example continues the SAVi training example above. 97 | 98 | ``` 99 | python src/01_create_predictor_experiment.py \ 100 | -d new_exps/my_exp \ 101 | --name my_OCVPSeq_model \ 102 | --predictor_name OCVP-Seq 103 | 104 | python src/04_train_predictor.py \ 105 | -d experiments/new_exps/my_exp \ 106 | --savi_model checkpoint_epoch_final.pth \ 107 | --name_predictor_experiment my_OCVPSeq_model 108 | ``` 109 | 110 | ## Training Time 111 | 112 | The table below summarizes the number of iterations (batches), epochs and hours required to train both the SAVi scene parsing module and our OCVP-Seq predictor. 113 | These values correspond to experiments trained on an NVIDIA A6000 GPU with 48 GB of memory. 114 | 115 | | Dataset | Model | Iters. | Epochs | Time | 116 | | --- | --- | --- | --- | --- | 117 | | Obj3D | SAVi | 100k | 2000 | 26h | 118 | | Obj3D | OCVP-Seq | 70k | 1500 | 34h | 119 | | MOVi-A | SAVi | 150k | 2000 | 120h | 120 | | MOVi-A | OCVP-Seq | 50k | 300 | 18h | 121 | 122 | 123 | ## Further Comments 124 | 125 | - You can find examples of some experiment directories under the `experiments` directory. 126 | 127 | - The training can be monitored using TensorBoard. 128 | To launch TensorBoard, run: 129 | ``` 130 | tensorboard --logdir experiments/EXP_DIR/EXP_NAME --port 8888 131 | ``` 132 | 133 | - In case of questions, do not hesitate to open an issue or contact the authors at `villar@ais.uni-bonn.de` 134 | -------------------------------------------------------------------------------- /src/05_evaluate_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluating an object-centric predictor model checkpoint. 3 | This module supports two different evaluations: 4 | - Visual quality of the predicted frames, pretty much video prediction. 5 | - Prediction quality of object dynamics: how well object segmentation masks are forecasted. 6 | """ 7 | 8 | import torch 9 | 10 | from data import unwrap_batch_data, unwrap_batch_data_masks 11 | from lib.arguments import get_predictor_evaluation_arguments 12 | from lib.logger import Logger, print_ 13 | import lib.utils as utils 14 | 15 | from base.basePredictorEvaluator import BasePredictorEvaluator 16 | 17 | 18 | class Evaluator(BasePredictorEvaluator): 19 | """ 20 | Evaluating an object-centric predictor model checkpoint. 21 | This module supports two different evaluations: 22 | - Visual quality of the predicted frames, pretty much video prediction. 23 | - Prediction quality of object dynamics: how well object segmentation masks are forecasted. 24 | """ 25 | 26 | MODES = ["VideoPred", "Masks"] 27 | 28 | @torch.no_grad() 29 | def forward_eval(self, batch_data, **kwargs): 30 | """ 31 | Making a forward pass through the model and computing the evaluation metrics 32 | 33 | Args: 34 | ----- 35 | batch_data: dict 36 | Dictionary containing the information for the current batch, including images, poses, 37 | actions, or metadata, among others.
38 | 39 | Returns: 40 | -------- 41 | pred_data: dict 42 | Predictions from the model for the current batch of data 43 | """ 44 | num_context = self.exp_params["training_prediction"]["num_context"] 45 | num_preds = self.exp_params["training_prediction"]["num_preds"] 46 | video_length = self.exp_params["training_prediction"]["sample_length"] 47 | num_slots = self.model.num_slots 48 | slot_dim = self.model.slot_dim 49 | 50 | # fetching and preparing data 51 | videos, targets, initializer_data = self.unwrap_function(self.exp_params, batch_data) 52 | videos, targets = videos.to(self.device), targets.to(self.device) 53 | B, L, C, H, W = videos.shape 54 | if L < num_context + num_preds: 55 | raise ValueError(f"Seq. length {L} smaller that #seed {num_context} + #preds {num_preds}") 56 | 57 | # encoding images into object-centric slots, and temporally aligning slots 58 | out_model = self.model(videos, num_imgs=video_length, **initializer_data) 59 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model 60 | # predicting future slots 61 | pred_slots = self.predictor(slot_history) 62 | # decoding predicted slots into predicted frames 63 | pred_slots_decode = pred_slots.clone().reshape(B * num_preds, num_slots, slot_dim) 64 | img_recons, (pred_recons, pred_masks) = self.model.decode(pred_slots_decode) 65 | 66 | # selecting predictions and targets given evaluation mode, and computing evaluation metrics 67 | if self.evaluation_mode == "VideoPred": 68 | preds_eval = img_recons.view(B, num_preds, C, H, W).clamp(0, 1) 69 | targets_eval = targets[:, num_context:num_context+num_preds, :, :].clamp(0, 1) 70 | elif self.evaluation_mode == "Masks": 71 | pred_masks = pred_masks.reshape(B, num_preds, -1, H, W) 72 | preds_eval = torch.argmax(pred_masks, dim=2).squeeze(2) 73 | targets_eval = targets[:, num_context:num_context+num_preds, :, :] 74 | else: 75 | raise ValueError(f"{self.evaluation_mode = } not recognized in {Evaluator.MODES = }...") 76 | self.metric_tracker.accumulate( 77 | preds=preds_eval, 78 | targets=targets_eval 79 | ) 80 | return 81 | 82 | def set_evaluation_mode(self, evaluation_mode): 83 | """ 84 | Toggling functions depending on the current evaluation model 85 | """ 86 | if evaluation_mode not in Evaluator.MODES: 87 | raise ValueError(f"{evaluation_mode = } not recognized in {Evaluator.MODES = }...") 88 | print_(f"Setting evaluation to mode: {evaluation_mode}") 89 | 90 | if evaluation_mode == "VideoPred": 91 | self.evaluation_mode = "VideoPred" 92 | self.set_metric_tracker_video_pred() 93 | self.unwrap_function = unwrap_batch_data 94 | elif evaluation_mode == "Masks": 95 | self.evaluation_mode = "Masks" 96 | self.set_metric_tracker_object_pred() 97 | self.unwrap_function = unwrap_batch_data_masks 98 | return 99 | 100 | 101 | if __name__ == "__main__": 102 | utils.clear_cmd() 103 | all_args = get_predictor_evaluation_arguments() 104 | exp_path, savi_model, checkpoint, name_predictor_experiment, args = all_args 105 | 106 | logger = Logger(exp_path=f"{exp_path}/{name_predictor_experiment}") 107 | logger.log_info("Starting object-centric predictor evaluation procedure", message_type="new_exp") 108 | print_("Initializing Evaluator...") 109 | print_("Args:") 110 | print_("-----") 111 | for k, v in vars(args).items(): 112 | print_(f" --> {k} = {v}") 113 | 114 | evaluator = Evaluator( 115 | name_predictor_experiment=name_predictor_experiment, 116 | exp_path=exp_path, 117 | savi_model=savi_model, 118 | checkpoint=args.checkpoint, 119 | num_preds=args.num_preds 120 | 
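        # If --num_preds is given, it overrides 'num_preds' (and, consequently, 'sample_length')
        # from experiment_params.json inside BasePredictorEvaluator.__init__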
) 121 | print_("Loading dataset...") 122 | evaluator.load_data() 123 | print_("Setting up model and predictor and loading pretrained parameters") 124 | evaluator.load_model() 125 | evaluator.setup_predictor() 126 | 127 | # VIDEO PREDICTION EVALUATION 128 | print_("Starting video predictor evaluation") 129 | evaluator.set_evaluation_mode(evaluation_mode="VideoPred") 130 | evaluator.evaluate() 131 | 132 | # OBJECT DYNAMICS EVALUATION (only on datasets with segmentation labels) 133 | db_name = evaluator.exp_params["dataset"]["dataset_name"] 134 | if db_name not in ["MoviA", "MoviC"]: 135 | print_(f"Dataset {db_name} does not support 'Masks' evaluation...\n Finishing execution") 136 | exit() 137 | print_("Starting object-centric evaluation") 138 | evaluator.set_evaluation_mode(evaluation_mode="Masks") 139 | evaluator.evaluate() 140 | 141 | 142 | # 143 | -------------------------------------------------------------------------------- /src/02_train_savi.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training and Validating a SAVi video decomposition model 3 | """ 4 | 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | matplotlib.use('Agg') # for avoiding memory leak 8 | import torch 9 | 10 | from data.load_data import unwrap_batch_data 11 | from lib.arguments import get_directory_argument 12 | from lib.logger import Logger, print_ 13 | import lib.utils as utils 14 | from lib.visualizations import visualize_decomp, visualize_recons 15 | 16 | from base.baseTrainer import BaseTrainer 17 | 18 | 19 | class Trainer(BaseTrainer): 20 | """ 21 | Class for training a SAVi model for object-centric video 22 | """ 23 | 24 | def forward_loss_metric(self, batch_data, training=False, inference_only=False, **kwargs): 25 | """ 26 | Computing a forwad pass through the model, and (if necessary) the loss values and metrics 27 | 28 | Args: 29 | ----- 30 | batch_data: dict 31 | Dictionary containing the information for the current batch, including images, poses, 32 | actions, or metadata, among others. 
33 | training: bool 34 | If True, model is in training mode 35 | inference_only: bool 36 | If True, only forward pass through the model is performed 37 | 38 | Returns: 39 | -------- 40 | pred_data: dict 41 | Predictions from the model for the current batch of data 42 | loss: torch.Tensor 43 | Total loss for the current batch 44 | """ 45 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data) 46 | 47 | # forward pass 48 | videos, targets = videos.to(self.device), targets.to(self.device) 49 | out_model = self.model(videos, num_imgs=videos.shape[1], **initializer_kwargs) 50 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model 51 | 52 | if inference_only: 53 | return out_model, None 54 | 55 | # if necessary, doing loss computation, backward pass, optimization, and computing metrics 56 | self.loss_tracker( 57 | pred_imgs=reconstruction_history.clamp(0, 1), 58 | target_imgs=targets.clamp(0, 1) 59 | ) 60 | 61 | loss = self.loss_tracker.get_last_losses(total_only=True) 62 | if training: 63 | self.optimizer.zero_grad() 64 | loss.backward() 65 | if self.exp_params["training_slots"]["gradient_clipping"]: 66 | torch.nn.utils.clip_grad_norm_( 67 | self.model.parameters(), 68 | self.exp_params["training_slots"]["clipping_max_value"] 69 | ) 70 | self.optimizer.step() 71 | 72 | return out_model, loss 73 | 74 | @torch.no_grad() 75 | def visualizations(self, batch_data, epoch, iter_): 76 | """ 77 | Making a visualization of some ground-truth, targets and predictions from the current model. 78 | """ 79 | if(iter_ % self.exp_params["training_slots"]["image_log_frequency"] != 0): 80 | return 81 | 82 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data) 83 | out_model, _ = self.forward_loss_metric( 84 | batch_data=batch_data, 85 | training=False, 86 | inference_only=True 87 | ) 88 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model 89 | N = min(10, videos.shape[1]) # max of 10 frames for sleeker figures 90 | 91 | # output reconstructions versus targets 92 | visualize_recons( 93 | imgs=targets[0][:N], 94 | recons=reconstruction_history[0][:N].clamp(0, 1), 95 | tag="target", 96 | savepath=None, 97 | tb_writer=self.writer, 98 | iter=iter_ 99 | ) 100 | 101 | # output reconstructions and input images 102 | visualize_recons( 103 | imgs=videos[0][:N], 104 | recons=reconstruction_history[0][:N].clamp(0, 1), 105 | savepath=None, 106 | tb_writer=self.writer, 107 | iter=iter_ 108 | ) 109 | 110 | # Rendered individual objects 111 | fig, _, _ = visualize_decomp( 112 | individual_recons_history[0][:N].clamp(0, 1), 113 | savepath=None, 114 | tag="objects_decomposed", 115 | vmin=0, 116 | vmax=1, 117 | tb_writer=self.writer, 118 | iter=iter_ 119 | ) 120 | plt.close(fig) 121 | 122 | # Rendered individual object masks 123 | fig, _, _ = visualize_decomp( 124 | masks_history[0][:N].clamp(0, 1), 125 | savepath=None, 126 | tag="masks", 127 | cmap="gray", 128 | vmin=0, 129 | vmax=1, 130 | tb_writer=self.writer, 131 | iter=iter_, 132 | ) 133 | plt.close(fig) 134 | 135 | # Rendered individual combination of an object with its masks 136 | recon_combined = masks_history[0][:N] * individual_recons_history[0][:N] 137 | fig, _, _ = visualize_decomp( 138 | recon_combined.clamp(0, 1), 139 | savepath=None, 140 | tag="reconstruction_combined", 141 | vmin=0, 142 | vmax=1, 143 | tb_writer=self.writer, 144 | iter=iter_ 145 | ) 146 | plt.close(fig) 147 | return 148 | 149 | 150 | if __name__ == "__main__": 
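    # Script entry point: read the experiment directory from the command line, set up the logger,
    # instantiate the SAVi Trainer, and launch the training loop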
151 | utils.clear_cmd() 152 | exp_path, args = get_directory_argument() 153 | logger = Logger(exp_path=exp_path) 154 | logger.log_info("Starting SAVi training procedure", message_type="new_exp") 155 | 156 | print_("Initializing SAVi Trainer...") 157 | print_("Args:") 158 | print_("-----") 159 | for k, v in vars(args).items(): 160 | print_(f" --> {k} = {v}") 161 | trainer = Trainer( 162 | exp_path=exp_path, 163 | checkpoint=args.checkpoint, 164 | resume_training=args.resume_training 165 | ) 166 | print_("Setting up model and optimizer") 167 | trainer.setup_model() 168 | print_("Loading dataset...") 169 | trainer.load_data() 170 | print_("Starting to train") 171 | trainer.training_loop() 172 | 173 | 174 | # 175 | -------------------------------------------------------------------------------- /src/lib/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions and loss-related utils 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from lib.logger import log_info 8 | from CONFIG import LOSSES 9 | 10 | 11 | class LossTracker: 12 | """ 13 | Class for computing, weighting and tracking several loss functions 14 | 15 | Args: 16 | ----- 17 | loss_params: dict 18 | Loss section of the experiment paramteres JSON file 19 | """ 20 | 21 | def __init__(self, loss_params): 22 | """ 23 | Loss tracker initializer 24 | """ 25 | assert isinstance(loss_params, list), f"Loss_params must be a list, not {type(loss_params)}" 26 | for loss in loss_params: 27 | if loss["type"] not in LOSSES: 28 | raise NotImplementedError(f"Loss {loss['type']} not implemented. Use one of {LOSSES}") 29 | 30 | self.loss_computers = {} 31 | for loss in loss_params: 32 | loss_type, loss_weight = loss["type"], loss["weight"] 33 | self.loss_computers[loss_type] = {} 34 | self.loss_computers[loss_type]["metric"] = get_loss(loss_type, **loss) 35 | self.loss_computers[loss_type]["weight"] = loss_weight 36 | self.reset() 37 | return 38 | 39 | def reset(self): 40 | """ 41 | Reseting loss tracker 42 | """ 43 | self.loss_values = {loss: [] for loss in self.loss_computers.keys()} 44 | self.loss_values["_total"] = [] 45 | return 46 | 47 | def __call__(self, **kwargs): 48 | """ 49 | Wrapper for calling accumulate 50 | """ 51 | self.accumulate(**kwargs) 52 | 53 | def accumulate(self, **kwargs): 54 | """ 55 | Computing the different metrics, weigting them according to their multiplier, 56 | and adding them to the results list. 
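        The weighted sum of all losses is additionally tracked under the '_total' key.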
57 | """ 58 | total_loss = 0 59 | for loss in self.loss_computers: 60 | loss_val = self.loss_computers[loss]["metric"](**kwargs) 61 | self.loss_values[loss].append(loss_val) 62 | total_loss = total_loss + loss_val * self.loss_computers[loss]["weight"] 63 | self.loss_values["_total"].append(total_loss) 64 | return 65 | 66 | def aggregate(self): 67 | """ 68 | Aggregating the results for each metric 69 | """ 70 | self.loss_values["mean_loss"] = {} 71 | for loss in self.loss_computers: 72 | self.loss_values["mean_loss"][loss] = torch.stack(self.loss_values[loss]).mean() 73 | self.loss_values["mean_loss"]["_total"] = torch.stack(self.loss_values["_total"]).mean() 74 | return 75 | 76 | def get_last_losses(self, total_only=False): 77 | """ 78 | Fetching the last computed loss value for each loss function 79 | """ 80 | if total_only: 81 | last_losses = self.loss_values["_total"][-1] 82 | else: 83 | last_losses = {loss: loss_vals[-1] for loss, loss_vals in self.loss_values.items()} 84 | return last_losses 85 | 86 | def summary(self, log=True, get_results=True): 87 | """ 88 | Printing and fetching the results 89 | """ 90 | if log: 91 | log_info("LOSS VALUES:") 92 | log_info("--------") 93 | for loss, loss_value in self.loss_values["mean_loss"].items(): 94 | log_info(f" {loss}: {round(loss_value.item(), 5)}") 95 | 96 | return_val = self.loss_values["mean_loss"] if get_results else None 97 | return return_val 98 | 99 | 100 | def get_loss(loss_type="mse", **kwargs): 101 | """ 102 | Loading a function of object for computing a loss 103 | """ 104 | if loss_type not in LOSSES: 105 | raise NotImplementedError(f"Loss {loss_type} not available. Use one of {LOSSES}") 106 | 107 | print(f"creating loss function of type: {loss_type}") 108 | if loss_type in ["mse", "l2"]: 109 | loss = MSELoss() 110 | elif loss_type in ["pred_img_mse"]: 111 | loss = PredImgMSELoss() 112 | elif loss_type in ["pred_slot_mse"]: 113 | loss = PredSlotMSELoss() 114 | return loss 115 | 116 | 117 | class MSELoss(nn.Module): 118 | """ 119 | Overriding MSE Loss 120 | """ 121 | 122 | def __init__(self): 123 | """ 124 | Module initializer 125 | """ 126 | super().__init__() 127 | self.mse = nn.MSELoss() 128 | 129 | def forward(self, **kwargs): 130 | """ 131 | Computing loss 132 | """ 133 | if "pred_imgs" not in kwargs: 134 | raise ValueError("'pred_imgs' must be given to LossTracker to compute 'MSELoss'") 135 | if "target_imgs" not in kwargs: 136 | raise ValueError("'target_imgs' must be given to LossTracker to compute 'MSELoss'") 137 | preds, targets = kwargs.get("pred_imgs"), kwargs.get("target_imgs") 138 | loss = self.mse(preds, targets) 139 | return loss 140 | 141 | 142 | class PredImgMSELoss(nn.Module): 143 | """ 144 | Pretty much the same MSE Loss. 
145 | Use this loss on predicted images, while still enforcing MSELoss on predicted slots 146 | """ 147 | 148 | def __init__(self): 149 | """ 150 | Module initializer 151 | """ 152 | super().__init__() 153 | self.mse = nn.MSELoss() 154 | 155 | def forward(self, **kwargs): 156 | """ 157 | Computing loss 158 | """ 159 | if "pred_imgs" not in kwargs: 160 | raise ValueError("'pred_imgs' must be given to LossTracker to compute 'PredImgMSELoss'") 161 | if "target_imgs" not in kwargs: 162 | raise ValueError("'target_imgs' must be given to LossTracker to compute 'PredImgMSELoss'") 163 | preds, targets = kwargs.get("pred_imgs"), kwargs.get("target_imgs") 164 | loss = self.mse(preds, targets) 165 | return loss 166 | 167 | 168 | class PredSlotMSELoss(nn.Module): 169 | """ 170 | MSE Loss used on slot-like representations. This can be used when forecasting future slots. 171 | """ 172 | 173 | def __init__(self): 174 | """ 175 | Module initializer 176 | """ 177 | super().__init__() 178 | self.mse = nn.MSELoss() 179 | 180 | def forward(self, **kwargs): 181 | """ 182 | Computing loss 183 | """ 184 | if "preds" not in kwargs: 185 | raise ValueError("'pred' must be given to LossTracker to compute 'PredSlotMSELoss'") 186 | if "targets" not in kwargs: 187 | raise ValueError("'target_slots' must be given to LossTracker to compute 'PredSlotMSELoss'") 188 | preds, targets = kwargs.get("preds"), kwargs.get("targets") 189 | loss = self.mse(preds, targets) 190 | return loss 191 | 192 | # 193 | -------------------------------------------------------------------------------- /src/data/MoviConvert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataclass and loading of the MOVI dataset from the Tensorflow files. 3 | 4 | https://github.com/google-research/kubric/tree/main/challenges/movi 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | import tensorflow_datasets as tfds 12 | # from itertools import islice 13 | import tensorflow as tf 14 | import lib.visualizations as visualizations 15 | 16 | from CONFIG import CONFIG 17 | PATH = CONFIG["paths"]["data_path"] 18 | 19 | # Hide GPU from visible devices 20 | tf.config.set_visible_devices([], 'GPU') 21 | 22 | 23 | class _MOVI(Dataset): 24 | """ 25 | DataClass for the MOVI dataset. 
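    Each __getitem__ call returns the resized RGB frames of a sequence, together with a
    dictionary holding the segmentation masks, BBox and center-of-mass coordinates, and the
    RGB-encoded forward optical flow.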
26 | 27 | Args: 28 | ----- 29 | movi_type: string 30 | Type of MOVI dataset to use 31 | split: string 32 | Dataset split to load 33 | num_frames: int 34 | Desired length of the sequences to load 35 | img_size: tuple 36 | Images are resized to this resolution 37 | """ 38 | 39 | MAX_OBJS = 11 40 | 41 | def __init__(self, split, num_frames, img_size=(64, 64), slot_initializer="LearnedInit"): 42 | """ Dataset initializer """ 43 | assert split in ["train", "val", "valid", "validation", "test"] 44 | split = "validation" if split in ["val", "valid", "validation"] else split 45 | 46 | # dataset parameters 47 | self.split = split 48 | self.num_frames = num_frames 49 | self.img_size = img_size 50 | self.slot_initializer = slot_initializer 51 | self.get_masks = False 52 | self.get_bbox = False 53 | 54 | # resizer modules for the images and masks respectively 55 | self.resizer = transforms.Resize( 56 | self.img_size, 57 | interpolation=transforms.InterpolationMode.BILINEAR 58 | ) 59 | self.resizer_mask = transforms.Resize( 60 | self.img_size, 61 | interpolation=transforms.InterpolationMode.NEAREST 62 | ) 63 | 64 | # loading data 65 | self.db, self.len_db = self._load_data() 66 | return 67 | 68 | def __len__(self): 69 | """ Number of sequences in dataset """ 70 | return self.len_db 71 | 72 | def __getitem__(self, i): 73 | """ 74 | Sampling a sequence from the dataset 75 | """ 76 | all_data = next(self.db) 77 | # all_data = next(islice(self.db, i, i+1)) 78 | # all_data = self.db[i] 79 | 80 | # images 81 | imgs = torch.from_numpy(all_data["video"])[:self.num_frames].permute(0, 3, 1, 2) / 255 82 | imgs = self.resizer(imgs).float() 83 | 84 | # instance segmentations 85 | segmentation = torch.from_numpy(all_data["segmentations"][:self.num_frames, ..., 0]) 86 | segmentation = self.resizer_mask(segmentation) 87 | 88 | # coordinates 89 | bbox, com = self._get_bbox_com(all_data, imgs) 90 | 91 | # optical flow 92 | minv, maxv = all_data["metadata"]["forward_flow_range"] 93 | forward_flow = all_data["forward_flow"] / 65535 * (maxv - minv) + minv 94 | flow_rgb = visualizations.flow_to_rgb(forward_flow) 95 | flow_rgb = torch.from_numpy(flow_rgb).permute(0, 3, 1, 2) 96 | flow_rgb = self.resizer_mask(flow_rgb) 97 | 98 | data = { 99 | "frames": imgs, 100 | "masks": segmentation, 101 | "com_coords": com, 102 | "bbox_coords": bbox, 103 | "flow": flow_rgb 104 | } 105 | return imgs, data 106 | 107 | def _get_bbox_com(self, all_data, imgs): 108 | """ 109 | Obtaining BBox information 110 | """ 111 | bboxes = all_data["instances"]["bboxes"].numpy() 112 | bbox_frames = all_data["instances"]["bbox_frames"].numpy() 113 | num_frames, _, H, W = imgs.shape 114 | num_objects = bboxes.shape[0] 115 | com = torch.zeros(num_frames, num_objects, 2) 116 | bbox = torch.zeros(num_frames, num_objects, 4) 117 | for t in range(num_frames): 118 | for k in range(num_objects): 119 | if t in bbox_frames[k]: 120 | idx = np.nonzero(bbox_frames[k] == t)[0][0] 121 | min_y, min_x, max_y, max_x = bboxes[k][idx] 122 | min_y, min_x = max(1, min_y * H), max(1, min_x * W) 123 | max_y, max_x = min(H - 1, max_y * H), min(W - 1, max_x * W) 124 | bbox[t, k] = torch.tensor([min_x, min_y, max_x, max_y]) 125 | com[t, k] = torch.tensor([(max_x + min_x) / 2, (max_y + min_y) / 2]).round() 126 | else: 127 | bbox[t, k] = torch.ones(4) * -1 128 | com[t, k] = torch.ones(2) * -1 129 | 130 | # padding so as to batch BBoxes or CoMs 131 | if num_objects < self.MAX_OBJS: 132 | rest = self.MAX_OBJS - num_objects 133 | rest_bbox = torch.ones((bbox.shape[0], rest, 4), 
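            # the fill value -1 marks 'no object', consistent with the per-frame padding assigned above
            # to frames in which an object is not present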
device=imgs.device) * -1 134 | rest_com = torch.ones((bbox.shape[0], rest, 2), device=imgs.device) * -1 135 | bbox = torch.cat([bbox, rest_bbox], dim=1) 136 | com = torch.cat([com, rest_com], dim=1) 137 | 138 | return bbox, com 139 | 140 | 141 | class _MoviA(_MOVI): 142 | """ 143 | DataClass for the Movi-A dataset. 144 | It contains CLEVR-like objects on a grey environment. Objects collide with each other 145 | """ 146 | 147 | def _load_data(self): 148 | """ Loading MOVI-A data""" 149 | print(f"Loading MOVI-A {self.split} set...") 150 | dataset_builder = tfds.builder( 151 | "movi_a/128x128:1.0.0", 152 | data_dir="/home/nfs/inf6/data/datasets/MOVi" 153 | ) 154 | split = "train" if self.split == "train" else "validation" 155 | ds = tfds.as_numpy(dataset_builder.as_dataset(split=split)) 156 | len_ds = len(ds) 157 | 158 | new_ds = iter(ds) 159 | return new_ds, len_ds 160 | 161 | 162 | class _MoviC(_MOVI): 163 | """ 164 | DataClass for the Movi-A dataset. 165 | Complex objects on a grey environment. Objects collide with each other 166 | """ 167 | 168 | def _load_data(self): 169 | """ Loading MOVI-B data""" 170 | print(f"Loading MOVI-C {self.split} set...") 171 | dataset_builder = tfds.builder( 172 | "movi_c/128x128:1.0.0", 173 | data_dir="/home/nfs/inf6/data/datasets/MOVi" 174 | ) 175 | split = "train" if self.split == "train" else "validation" 176 | ds = tfds.as_numpy(dataset_builder.as_dataset(split=split)) 177 | len_ds = len(ds) 178 | 179 | new_ds = iter(ds) 180 | return new_ds, len_ds 181 | 182 | 183 | # 184 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: OCVP 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - anyio=3.5.0=py39h06a4308_0 9 | - argon2-cffi=20.1.0=py39h27cfd23_1 10 | - asttokens=2.0.5=pyhd3eb1b0_0 11 | - attrs=21.4.0=pyhd3eb1b0_0 12 | - babel=2.9.1=pyhd3eb1b0_0 13 | - backcall=0.2.0=pyhd3eb1b0_0 14 | - beautifulsoup4=4.11.1=py39h06a4308_0 15 | - blas=1.0=mkl 16 | - bleach=4.1.0=pyhd3eb1b0_0 17 | - brotlipy=0.7.0=py39h27cfd23_1003 18 | - bzip2=1.0.8=h7b6447c_0 19 | - ca-certificates=2022.07.19=h06a4308_0 20 | - certifi=2022.6.15=py39h06a4308_0 21 | - cffi=1.15.0=py39hd667e15_1 22 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 23 | - cryptography=36.0.0=py39h9ce1e76_0 24 | - cudatoolkit=11.3.1=h2bc3f7f_2 25 | - debugpy=1.5.1=py39h295c915_0 26 | - decorator=5.1.1=pyhd3eb1b0_0 27 | - defusedxml=0.7.1=pyhd3eb1b0_0 28 | - entrypoints=0.4=py39h06a4308_0 29 | - executing=0.8.3=pyhd3eb1b0_0 30 | - ffmpeg=4.3=hf484d3e_0 31 | - freetype=2.11.0=h70c0345_0 32 | - giflib=5.2.1=h7b6447c_0 33 | - gmp=6.2.1=h2531618_2 34 | - gnutls=3.6.15=he1e5248_0 35 | - idna=3.3=pyhd3eb1b0_0 36 | - intel-openmp=2021.4.0=h06a4308_3561 37 | - ipykernel=6.9.1=py39h06a4308_0 38 | - ipython=8.4.0=py39h06a4308_0 39 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 40 | - jedi=0.18.1=py39h06a4308_1 41 | - jinja2=3.0.3=pyhd3eb1b0_0 42 | - jpeg=9d=h7f8727e_0 43 | - json5=0.9.6=pyhd3eb1b0_0 44 | - jsonschema=4.4.0=py39h06a4308_0 45 | - jupyter_client=7.1.2=pyhd3eb1b0_0 46 | - jupyter_core=4.10.0=py39h06a4308_0 47 | - jupyter_server=1.18.1=py39h06a4308_0 48 | - jupyterlab=3.4.4=py39h06a4308_0 49 | - jupyterlab_pygments=0.1.2=py_0 50 | - jupyterlab_server=2.12.0=py39h06a4308_0 51 | - lame=3.100=h7b6447c_0 52 | - lcms2=2.12=h3be6417_0 53 | - ld_impl_linux-64=2.35.1=h7274673_9 54 | - libffi=3.3=he6710b0_2 55 | - 
libgcc-ng=9.3.0=h5101ec6_17 56 | - libgomp=9.3.0=h5101ec6_17 57 | - libiconv=1.15=h63c8f33_5 58 | - libidn2=2.3.2=h7f8727e_0 59 | - libpng=1.6.37=hbc83047_0 60 | - libsodium=1.0.18=h7b6447c_0 61 | - libstdcxx-ng=9.3.0=hd4cf53a_17 62 | - libtasn1=4.16.0=h27cfd23_0 63 | - libtiff=4.2.0=h85742a9_0 64 | - libunistring=0.9.10=h27cfd23_0 65 | - libuv=1.40.0=h7b6447c_0 66 | - libwebp=1.2.2=h55f646e_0 67 | - libwebp-base=1.2.2=h7f8727e_0 68 | - lz4-c=1.9.3=h295c915_1 69 | - markupsafe=2.1.1=py39h7f8727e_0 70 | - matplotlib-inline=0.1.6=py39h06a4308_0 71 | - mistune=0.8.4=py39h27cfd23_1000 72 | - mkl=2021.4.0=h06a4308_640 73 | - mkl-service=2.4.0=py39h7f8727e_0 74 | - mkl_fft=1.3.1=py39hd3c417c_0 75 | - mkl_random=1.2.2=py39h51133e4_0 76 | - nbclassic=0.3.5=pyhd3eb1b0_0 77 | - nbclient=0.5.13=py39h06a4308_0 78 | - nbconvert=6.4.4=py39h06a4308_0 79 | - nbformat=5.3.0=py39h06a4308_0 80 | - ncurses=6.3=h7f8727e_2 81 | - nest-asyncio=1.5.5=py39h06a4308_0 82 | - nettle=3.7.3=hbbd107a_1 83 | - notebook=6.4.12=py39h06a4308_0 84 | - numpy=1.21.5=py39he7a7128_1 85 | - numpy-base=1.21.5=py39hf524024_1 86 | - openh264=2.1.1=h4ff587b_0 87 | - openssl=1.1.1q=h7f8727e_0 88 | - packaging=21.3=pyhd3eb1b0_0 89 | - pandocfilters=1.5.0=pyhd3eb1b0_0 90 | - parso=0.8.3=pyhd3eb1b0_0 91 | - pexpect=4.8.0=pyhd3eb1b0_3 92 | - pickleshare=0.7.5=pyhd3eb1b0_1003 93 | - pillow=9.0.1=py39h22f2fdc_0 94 | - pip=21.2.4=py39h06a4308_0 95 | - prometheus_client=0.14.1=py39h06a4308_0 96 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 97 | - ptyprocess=0.7.0=pyhd3eb1b0_2 98 | - pure_eval=0.2.2=pyhd3eb1b0_0 99 | - pycparser=2.21=pyhd3eb1b0_0 100 | - pygments=2.11.2=pyhd3eb1b0_0 101 | - pyopenssl=22.0.0=pyhd3eb1b0_0 102 | - pyrsistent=0.18.0=py39heee7806_0 103 | - pysocks=1.7.1=py39h06a4308_0 104 | - python=3.9.12=h12debd9_0 105 | - python-dateutil=2.8.2=pyhd3eb1b0_0 106 | - python-fastjsonschema=2.16.2=py39h06a4308_0 107 | - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0 108 | - pytorch-mutex=1.0=cuda 109 | - pytz=2022.1=py39h06a4308_0 110 | - pyzmq=22.3.0=py39h295c915_2 111 | - readline=8.1.2=h7f8727e_1 112 | - requests=2.27.1=pyhd3eb1b0_0 113 | - send2trash=1.8.0=pyhd3eb1b0_1 114 | - setuptools=61.2.0=py39h06a4308_0 115 | - six=1.16.0=pyhd3eb1b0_1 116 | - sniffio=1.2.0=py39h06a4308_1 117 | - soupsieve=2.3.1=pyhd3eb1b0_0 118 | - sqlite=3.38.2=hc218d9a_0 119 | - stack_data=0.2.0=pyhd3eb1b0_0 120 | - terminado=0.13.1=py39h06a4308_0 121 | - testpath=0.6.0=py39h06a4308_0 122 | - tk=8.6.11=h1ccaba5_0 123 | - torchaudio=0.11.0=py39_cu113 124 | - torchvision=0.12.0=py39_cu113 125 | - tornado=6.1=py39h27cfd23_0 126 | - traitlets=5.1.1=pyhd3eb1b0_0 127 | - typing_extensions=4.1.1=pyh06a4308_0 128 | - tzdata=2022a=hda174b7_0 129 | - urllib3=1.26.8=pyhd3eb1b0_0 130 | - wcwidth=0.2.5=pyhd3eb1b0_0 131 | - webencodings=0.5.1=py39h06a4308_1 132 | - websocket-client=0.58.0=py39h06a4308_4 133 | - wheel=0.37.1=pyhd3eb1b0_0 134 | - xz=5.2.5=h7b6447c_0 135 | - zeromq=4.3.4=h2531618_0 136 | - zlib=1.2.12=h7f8727e_2 137 | - zstd=1.4.9=haebb681_0 138 | - pip: 139 | - absl-py==1.0.0 140 | - astunparse==1.6.3 141 | - cachetools==5.0.0 142 | - click==8.1.3 143 | - cycler==0.11.0 144 | - dill==0.3.6 145 | - dm-tree==0.1.8 146 | - etils==1.0.0 147 | - flatbuffers==23.1.21 148 | - fonttools==4.32.0 149 | - fvcore==0.1.5.post20220512 150 | - gast==0.4.0 151 | - gitdb==4.0.9 152 | - gitpython==3.1.27 153 | - google-auth==2.6.5 154 | - google-auth-oauthlib==0.4.6 155 | - google-pasta==0.2.0 156 | - googleapis-common-protos==1.58.0 157 | - grpcio==1.44.0 158 | - h5py==3.7.0 159 | 
- imageio==2.18.0 160 | - importlib-metadata==4.11.3 161 | - importlib-resources==5.10.2 162 | - install==1.3.5 163 | - iopath==0.1.10 164 | - joblib==1.1.0 165 | - keras==2.11.0 166 | - kiwisolver==1.4.2 167 | - kornia==0.6.8 168 | - libclang==15.0.6.1 169 | - lpips==0.1.4 170 | - markdown==3.3.6 171 | - matplotlib==3.5.1 172 | - mediapy==1.1.4 173 | - networkx==2.8.7 174 | - oauthlib==3.2.0 175 | - opencv-python==4.6.0.66 176 | - opt-einsum==3.3.0 177 | - piqa==1.2.2 178 | - portalocker==2.5.1 179 | - promise==2.3 180 | - protobuf==3.19.6 181 | - psutil==5.9.4 182 | - pyasn1==0.4.8 183 | - pyasn1-modules==0.2.8 184 | - pyparsing==3.0.8 185 | - python-version==0.0.2 186 | - pywavelets==1.4.1 187 | - pyyaml==6.0 188 | - requests-oauthlib==1.3.1 189 | - rsa==4.8 190 | - scikit-image==0.19.3 191 | - scikit-learn==1.1.1 192 | - scipy==1.8.1 193 | - smmap==5.0.0 194 | - tabulate==0.8.10 195 | - tensorboard==2.11.2 196 | - tensorboard-data-server==0.6.1 197 | - tensorboard-plugin-wit==1.8.1 198 | - tensorflow==2.11.0 199 | - tensorflow-datasets==4.8.2 200 | - tensorflow-estimator==2.11.0 201 | - tensorflow-io-gcs-filesystem==0.30.0 202 | - tensorflow-metadata==1.12.0 203 | - termcolor==1.1.0 204 | - threadpoolctl==3.1.0 205 | - tifffile==2022.10.10 206 | - toml==0.10.2 207 | - torchfile==0.1.0 208 | - torchmetrics==0.9.3 209 | - tqdm==4.64.0 210 | - webcolors==1.12 211 | - werkzeug==2.1.1 212 | - wrapt==1.14.1 213 | - yacs==0.1.8 214 | - zipp==3.8.0 215 | prefix: /home/user/villar/anaconda3/envs/OCVP 216 | -------------------------------------------------------------------------------- /src/04_train_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training and Validation of an object-centric predictor module using a frozen and pretrained 3 | SAVI video decomposition model 4 | """ 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | matplotlib.use('Agg') # for avoiding memory leak 8 | import torch 9 | 10 | from data.load_data import unwrap_batch_data 11 | from lib.arguments import get_predictor_training_arguments 12 | from lib.logger import Logger, print_ 13 | import lib.utils as utils 14 | from lib.visualizations import visualize_decomp, visualize_qualitative_eval 15 | 16 | from base.basePredictorTrainer import BasePredictorTrainer 17 | 18 | 19 | class Trainer(BasePredictorTrainer): 20 | """ 21 | Training and Validation of an object-centric predictor module using a frozen and pretrained 22 | SAVI video decomposition model 23 | """ 24 | 25 | def forward_loss_metric(self, batch_data, training=False, inference_only=False, **kwargs): 26 | """ 27 | Computing a forwad pass through the model, and (if necessary) the loss values and metrics 28 | 29 | Args: 30 | ----- 31 | batch_data: dict 32 | Dictionary containing the information for the current batch, including images, poses, 33 | actions, or metadata, among others. 
34 | training: bool 35 | If True, model is in training mode 36 | inference_only: bool 37 | If True, only forward pass through the model is performed 38 | 39 | Returns: 40 | -------- 41 | pred_data: dict 42 | Predictions from the model for the current batch of data 43 | loss: torch.Tensor 44 | Total loss for the current batch 45 | """ 46 | num_context = self.exp_params["training_prediction"]["num_context"] 47 | num_preds = self.exp_params["training_prediction"]["num_preds"] 48 | video_length = self.exp_params["training_prediction"]["sample_length"] 49 | num_slots = self.model.num_slots 50 | slot_dim = self.model.slot_dim 51 | 52 | # fetching and checking data 53 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data) 54 | videos, targets = videos.to(self.device), videos.to(self.device) 55 | B, L, C, H, W = videos.shape 56 | if L < num_context + num_preds: 57 | raise ValueError(f"Seq. length {L} smaller that #seed {num_context} + #preds {num_preds}") 58 | 59 | # encoding frames into object slots usign pretrained SAVi 60 | with torch.no_grad(): 61 | out_model = self.model(videos, num_imgs=video_length, **initializer_kwargs) 62 | slot_history, reconstruction_history, individual_recons_history, masks_history = out_model 63 | # predicting future slots 64 | pred_slots = self.predictor(slot_history) 65 | # rendering future objects and frames from predicted object slots 66 | pred_slots_decode = pred_slots.clone().reshape(B * num_preds, num_slots, slot_dim) 67 | img_recons, (pred_recons, pred_masks) = self.model.decode(pred_slots_decode) 68 | pred_imgs = img_recons.view(B, num_preds, C, H, W) 69 | 70 | # Generating only model outputs 71 | out_model = (pred_imgs, pred_recons, pred_masks) 72 | if inference_only: 73 | return out_model, None 74 | 75 | # if necessary, doing loss computation, backward pass, optimization, and computing metrics 76 | target_slots = slot_history[:, num_context:num_context+num_preds, :, :] 77 | target_imgs = targets[:, num_context:num_context+num_preds, :, :] 78 | self.loss_tracker( 79 | preds=pred_slots, 80 | targets=target_slots, 81 | pred_imgs=pred_imgs, 82 | target_imgs=target_imgs 83 | ) 84 | loss = self.loss_tracker.get_last_losses(total_only=True) 85 | if training: 86 | self.optimizer.zero_grad() 87 | loss.backward() 88 | if self.exp_params["training_slots"]["gradient_clipping"]: 89 | torch.nn.utils.clip_grad_norm_( 90 | self.predictor.parameters(), 91 | self.exp_params["training_slots"]["clipping_max_value"] 92 | ) 93 | self.optimizer.step() 94 | 95 | return out_model, loss 96 | 97 | @torch.no_grad() 98 | def visualizations(self, batch_data, epoch): 99 | """ 100 | Making a visualization of some ground-truth, targets and predictions from the current model. 
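        Three equispaced sequences from the batch are visualized, and the resulting figures are
        logged to the Tensorboard writer.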
101 | """ 102 | num_context = self.exp_params["training_prediction"]["num_context"] 103 | num_preds = self.exp_params["training_prediction"]["num_preds"] 104 | 105 | # forward pass 106 | videos, targets, initializer_kwargs = unwrap_batch_data(self.exp_params, batch_data) 107 | out_model, _ = self.forward_loss_metric( 108 | batch_data=batch_data, 109 | training=False, 110 | inference_only=True 111 | ) 112 | pred_imgs, pred_recons, pred_masks = out_model 113 | target_imgs = targets[:, num_context:num_context+num_preds, :, :] 114 | 115 | # visualitations 116 | ids = torch.linspace(0, videos.shape[0]-1, 3).round().int() # equispaced videos in batch 117 | for idx in range(3): 118 | k = ids[idx] 119 | fig, ax = visualize_qualitative_eval( 120 | context=videos[k, :num_context], 121 | targets=target_imgs[k], 122 | preds=pred_imgs[k], 123 | savepath=None 124 | ) 125 | self.writer.add_figure(tag=f"Qualitative Eval {k+1}", figure=fig, step=epoch + 1) 126 | plt.close(fig) 127 | 128 | objs = pred_masks[k*num_preds:(k+1)*num_preds] * pred_recons[k*num_preds:(k+1)*num_preds] 129 | fig, _, _ = visualize_decomp( 130 | objs.clamp(0, 1), 131 | savepath=None, 132 | tag=f"Pred. Object Recons. {k+1}", 133 | tb_writer=self.writer, 134 | iter=epoch 135 | ) 136 | plt.close(fig) 137 | return 138 | 139 | 140 | if __name__ == "__main__": 141 | utils.clear_cmd() 142 | exp_path, savi_model, checkpoint, name_predictor_experiment, args = get_predictor_training_arguments() 143 | logger = Logger(exp_path=f"{exp_path}/{name_predictor_experiment}") 144 | logger.log_info("Starting object-centric predictor training procedure", message_type="new_exp") 145 | 146 | print_("Initializing Trainer...") 147 | print_("Args:") 148 | print_("-----") 149 | for k, v in vars(args).items(): 150 | print_(f" --> {k} = {v}") 151 | trainer = Trainer( 152 | name_predictor_experiment=name_predictor_experiment, 153 | exp_path=exp_path, 154 | savi_model=savi_model, 155 | checkpoint=args.checkpoint, 156 | resume_training=args.resume_training 157 | ) 158 | print_("Loading dataset...") 159 | trainer.load_data() 160 | print_("Setting up model, predictor and optimizer") 161 | trainer.load_model() 162 | trainer.setup_predictor() 163 | print_("Starting to train") 164 | trainer.training_loop() 165 | 166 | 167 | # 168 | -------------------------------------------------------------------------------- /src/base/basePredictorEvaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base predictor evaluator from which all predictor evaluator classes inherit. 3 | Basically it removes the scaffolding that is repeat across all predictor evaluator modules 4 | """ 5 | 6 | import os 7 | from tqdm import tqdm 8 | import torch 9 | 10 | from lib.config import Config 11 | from lib.logger import print_, log_function, for_all_methods 12 | from lib.metrics import MetricTracker 13 | import lib.setup_model as setup_model 14 | import lib.utils as utils 15 | import data as datalib 16 | from models.model_utils import freeze_params 17 | 18 | 19 | @for_all_methods(log_function) 20 | class BasePredictorEvaluator: 21 | """ 22 | Base Class for evaluating a slot predictor model 23 | 24 | Args: 25 | ----- 26 | exp_path: string 27 | Path to the experiment directory from which to read the experiment parameters, 28 | and where to store logs, plots and checkpoints 29 | name_predictor_experiment: string 30 | Name of the predictor experiment (subdirectory in parent directory) to train. 
31 | savi_model: string 32 | Name of the pretrained SAVI model used to extract object representation from frames 33 | and to decode the predicted slots back to images 34 | checkpoint: string/None 35 | Name of a model checkpoint stored in the models/ directory of the experiment directory. 36 | If given, the model is initialized with the parameters of such checkpoint. 37 | This can be used to continue training or for transfer learning. 38 | num_preds: int 39 | Number of predictions to make and evaluate 40 | """ 41 | 42 | def __init__(self, name_predictor_experiment, exp_path, savi_model, checkpoint, 43 | num_preds=None, **kwargs): 44 | """ 45 | Initializing the predictor evaluator object 46 | """ 47 | self.parent_exp_path = exp_path 48 | self.exp_path = os.path.join(exp_path, name_predictor_experiment) 49 | self.cfg = Config(self.exp_path) 50 | self.exp_params = self.cfg.load_exp_config_file() 51 | self.savi_model = savi_model 52 | self.checkpoint = checkpoint 53 | self.num_preds_args = num_preds 54 | 55 | # overriding 'num_preds' if given as argument 56 | if num_preds is not None: 57 | num_seed = self.exp_params["training_prediction"]["num_context"] 58 | self.exp_params["training_prediction"]["num_preds"] = num_preds 59 | self.exp_params["training_prediction"]["sample_length"] = num_seed + num_preds 60 | print_(f" --> Overriding 'num_preds' to {num_preds}") 61 | print_(f" --> New 'sample_length' is {num_seed + num_preds}") 62 | 63 | self.plots_path = os.path.join(self.exp_path, "plots") 64 | utils.create_directory(self.plots_path) 65 | self.models_path = os.path.join(self.exp_path, "models") 66 | utils.create_directory(self.models_path) 67 | 68 | return 69 | 70 | def set_metric_tracker(self): 71 | """ 72 | Initializing the metric tracker with the Video Prediction tracker by default 73 | """ 74 | self.set_metric_tracker_video_pred() 75 | 76 | def set_metric_tracker_video_pred(self): 77 | """ 78 | Initializing the metric tracker with Video Prediction metrics 79 | """ 80 | self.metric_tracker = MetricTracker( 81 | self.exp_path, 82 | metrics=["psnr", "ssim", "lpips"] 83 | ) 84 | 85 | def set_metric_tracker_object_pred(self): 86 | """ 87 | Initializing the metric tracker with Object-centric metrics 88 | """ 89 | self.test_set.get_masks = True 90 | self.metric_tracker = MetricTracker( 91 | self.exp_path, 92 | metrics=["segmentation_ari", "IoU"] 93 | ) 94 | 95 | def load_data(self): 96 | """ 97 | Loading test dataset and fitting data-loader for iterating in a batch-like fashion 98 | """ 99 | batch_size = self.exp_params["training_prediction"]["batch_size"] 100 | shuffle_eval = self.exp_params["dataset"]["shuffle_eval"] 101 | self.test_set = datalib.load_data( 102 | exp_params=self.exp_params, 103 | split="test" 104 | ) 105 | self.test_loader = datalib.build_data_loader( 106 | dataset=self.test_set, 107 | batch_size=batch_size, 108 | shuffle=shuffle_eval 109 | ) 110 | return 111 | 112 | def load_model(self): 113 | """ 114 | Load pretrained SAVi model from checkpoint 115 | """ 116 | torch.backends.cudnn.fastest = True 117 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 118 | 119 | # loading model 120 | self.model = setup_model.setup_model(model_params=self.exp_params["model"]) 121 | self.model = self.model.eval().to(self.device) 122 | 123 | # loading pretrained parameters and freezing 124 | checkpoint_path = os.path.join(self.parent_exp_path, "models", self.savi_model) 125 | self.model = setup_model.load_checkpoint( 126 | checkpoint_path=checkpoint_path, 127 | 
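            # only_model=True presumably restores just the network weights; optimizer or scheduler
            # state is not needed at evaluation time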
127 |             model=self.model,
128 |             only_model=True
129 |         )
130 |         freeze_params(self.model)
131 |         return
132 |
133 |     def setup_predictor(self):
134 |         """
135 |         Load pretrained predictor model from the corresponding model checkpoint
136 |         """
137 |         predictor = setup_model.setup_predictor(exp_params=self.exp_params)
138 |         predictor = predictor.eval().to(self.device)
139 |
140 |         print_(f"Loading pretrained parameters from checkpoint {self.checkpoint}...")
141 |         predictor = setup_model.load_checkpoint(
142 |             checkpoint_path=os.path.join(self.models_path, self.checkpoint),
143 |             model=predictor,
144 |             only_model=True,
145 |         )
146 |         self.predictor = predictor
147 |         self.set_metric_tracker()
148 |         return
149 |
150 |     @torch.no_grad()
151 |     def evaluate(self, save_results=True):
152 |         """
153 |         Evaluating the model over the complete test set
154 |         """
155 |         num_context = self.exp_params["training_prediction"]["num_context"]
156 |         num_preds = self.exp_params["training_prediction"]["num_preds"]
157 |         self.model = self.model.eval()
158 |         progress_bar = tqdm(enumerate(self.test_loader), total=len(self.test_loader))
159 |
160 |         # iterating test set and accumulating the results
161 |         for i, batch_data in progress_bar:
162 |             self.forward_eval(batch_data=batch_data)
163 |             progress_bar.set_description(f"Iter {i}/{len(self.test_loader)}")
164 |
165 |         self.metric_tracker.aggregate()
166 |         _ = self.metric_tracker.summary()
167 |         fname = f"{self.checkpoint[:-4]}_NumPreds={num_preds}"
168 |         self.metric_tracker.save_results(exp_path=self.exp_path, fname=fname)
169 |         self.metric_tracker.make_plots(
170 |             start_idx=num_context,
171 |             savepath=os.path.join(self.exp_path, "results", fname)
172 |         )
173 |         return
174 |
175 |     def forward_eval(self, batch_data, **kwargs):
176 |         """
177 |         Making a forward pass through the model and computing the evaluation metrics
178 |
179 |         Args:
180 |         -----
181 |         batch_data: dict
182 |             Dictionary containing the information for the current batch, including images, poses,
183 |             actions, or metadata, among others.
184 |         """
185 |         raise NotImplementedError("Base Evaluator Module does not implement 'forward_eval'...")
186 |
187 |
188 | #
189 |
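# NOTE: The commented-out snippet below is an illustrative sketch of how a concrete evaluator
# might implement the abstract 'forward_eval' hook defined above. It is not part of the
# original module: helper names such as 'encode_context', 'rollout', 'decode' and
# 'metric_tracker.accumulate' are assumptions for illustration only, and the actual
# SAVi/predictor/metric APIs used in this repository may differ.
#
# class MyPredictorEvaluator(BasePredictorEvaluator):
#
#     @torch.no_grad()
#     def forward_eval(self, batch_data, **kwargs):
#         num_context = self.exp_params["training_prediction"]["num_context"]
#         num_preds = self.exp_params["training_prediction"]["num_preds"]
#         videos = batch_data[0].to(self.device)                 # assumed (B, T, C, H, W) layout
#         slots = self.encode_context(videos[:, :num_context])   # hypothetical helper using self.model
#         pred_slots = self.rollout(slots, num_preds)            # hypothetical helper using self.predictor
#         pred_imgs = self.decode(pred_slots)                    # hypothetical helper using the SAVi decoder
#         targets = videos[:, num_context:num_context + num_preds]
#         self.metric_tracker.accumulate(preds=pred_imgs, targets=targets)   # assumed MetricTracker API
#         return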
--------------------------------------------------------------------------------
/src/lib/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility methods for a bunch of purposes, including
3 |     - Reading/writing files
4 |     - Creating directories
5 |     - Timestamp
6 |     - Handling tensorboard
7 | """
8 |
9 | import os
10 | import pickle
11 | import shutil
12 | import random
13 | import datetime
14 | import numpy as np
15 | import torch
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | from lib.logger import log_function
19 | from CONFIG import CONFIG
20 |
21 |
22 | def set_random_seed(random_seed=None):
23 |     """
24 |     Using random seed for numpy and torch
25 |     """
26 |     if(random_seed is None):
27 |         random_seed = CONFIG["random_seed"]
28 |     os.environ['PYTHONHASHSEED'] = str(random_seed)
29 |     random.seed(random_seed)
30 |     np.random.seed(random_seed)
31 |     torch.manual_seed(random_seed)
32 |     torch.cuda.manual_seed_all(random_seed)
33 |     return
34 |
35 |
36 | def load_pickle_file(path):
37 |     """ Loading pickle file """
38 |     with open(path, "rb") as a_file:
39 |         data = pickle.load(a_file)
40 |     return data
41 |
42 |
43 | def save_pickle_file(path, data):
44 |     """ Saving pickle file """
45 |     with open(path, "wb") as file:
46 |         pickle.dump(data, file)
47 |     return
48 |
49 |
50 | def clear_cmd():
51 |     """ Clearing the command line window """
52 |     os.system('cls' if os.name == 'nt' else 'clear')
53 |     return
54 |
55 |
56 | @log_function
57 | def create_directory(dir_path, dir_name=None):
58 |     """
59 |     Creating a folder in a given path.
60 |     """
61 |     if(dir_name is not None):
62 |         dir_path = os.path.join(dir_path, dir_name)
63 |     if(not os.path.exists(dir_path)):
64 |         os.makedirs(dir_path)
65 |     return
66 |
67 |
68 | def delete_directory(dir_path):
69 |     """
70 |     Deleting a directory and all its contents
71 |     """
72 |     if os.path.exists(dir_path):
73 |         shutil.rmtree(dir_path)
74 |     return
75 |
76 |
77 | def split_path(path):
78 |     """ Splitting a path into a list containing the names of all directories along the path """
79 |     allparts = []
80 |     while 1:
81 |         parts = os.path.split(path)
82 |         if parts[0] == path:  # sentinel for absolute paths
83 |             allparts.insert(0, parts[0])
84 |             break
85 |         elif parts[1] == path:  # sentinel for relative paths
86 |             allparts.insert(0, parts[1])
87 |             break
88 |         else:
89 |             path = parts[0]
90 |             allparts.insert(0, parts[1])
91 |     return allparts
92 |
93 |
94 | def timestamp():
95 |     """
96 |     Obtaining the current timestamp in a human-readable way
97 |     """
98 |     timestamp = str(datetime.datetime.now()).split('.')[0].replace(' ', '_').replace(':', '-')
99 |     return timestamp
100 |
101 |
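# NOTE: Illustrative example, not part of the original module. Behaviour of 'split_path'
# on POSIX-style paths, as implied by the sentinel logic above:
#
#   split_path("/src/lib/utils.py")   ->  ['/', 'src', 'lib', 'utils.py']
#   split_path("experiments/Obj3D")   ->  ['experiments', 'Obj3D']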
102 | @log_function
103 | def log_architecture(model, exp_path, fname="model_architecture.txt"):
104 |     """
105 |     Printing architecture modules into a txt file
106 |     """
107 |     assert fname[-4:] == ".txt", "ERROR! 'fname' must be a .txt file"
108 |     savepath = os.path.join(exp_path, fname)
109 |
110 |     # getting all_params
111 |     with open(savepath, "w") as f:
112 |         num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
113 |         f.write(f"Total Params: {num_params}")
114 |
115 |     for i, layer in enumerate(model.children()):
116 |         if(isinstance(layer, torch.nn.Module)):
117 |             log_module(module=layer, exp_path=exp_path, fname=fname)
118 |     return
119 |
120 |
121 | def log_module(module, exp_path, fname="model_architecture.txt", append=True):
122 |     """
123 |     Printing architecture modules into a txt file
124 |     """
125 |     assert fname[-4:] == ".txt", "ERROR! 'fname' must be a .txt file"
126 |     savepath = os.path.join(exp_path, fname)
127 |
128 |     # writing from scratch or appending to existing file
129 |     if (append is False):
130 |         with open(savepath, "w") as f:
131 |             f.write("")
132 |     else:
133 |         with open(savepath, "a") as f:
134 |             f.write("\n\n")
135 |
136 |     # writing info
137 |     with open(savepath, "a") as f:
138 |         num_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
139 |         f.write(f"Params: {num_params}")
140 |         f.write("\n")
141 |         f.write(str(module))
142 |     return
143 |
144 |
145 | def press_yes_to_continue(message, key="y"):
146 |     """ Asking the user for input to continue """
147 |     if isinstance(message, (list, tuple)):
148 |         for m in message:
149 |             print(m)
150 |     else:
151 |         print(message)
152 |     val = input(f"Press '{key}' to continue...")
153 |     if(val != key):
154 |         print("Exiting...")
155 |         exit()
156 |     return
157 |
158 |
159 | def press_yes_or_no(message):
160 |     """ Asking the user for a yes or no input """
161 |     if isinstance(message, (list, tuple)):
162 |         for m in message:
163 |             print(m)
164 |     else:
165 |         print(message)
166 |     val = ""
167 |     while val not in ["y", "n"]:
168 |         val = input("Press 'y' for yes or 'n' for no...")
169 |         if val not in ["y", "n"]:
170 |             print(f"  Key {val} is not valid...")
171 |     return val
172 |
173 |
174 | def get_from_dict(params, key_list):
175 |     """ Getting a value from a dictionary given a list with the keys to get there """
176 |     for key in key_list:
177 |         params = params[key]
178 |     return params
179 |
180 |
181 | def set_in_dict(params, key_list, value):
182 |     """ Updating a dictionary value, indexed by a list of keys to get there """
183 |     for key in key_list[:-1]:
184 |         # params = params.setdefault(key, {})
185 |         params = params[key]
186 |     params[key_list[-1]] = value
187 |     return
188 |
189 |
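# NOTE: Illustrative example, not part of the original module. 'get_from_dict' and
# 'set_in_dict' address nested configuration entries via a list of keys; the dictionary
# below is made up for illustration:
#
#   exp_params = {"training_prediction": {"num_context": 5, "num_preds": 15}}
#   get_from_dict(exp_params, ["training_prediction", "num_preds"])    # -> 15
#   set_in_dict(exp_params, ["training_prediction", "num_preds"], 25)
#   exp_params["training_prediction"]["num_preds"]                     # -> 25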
190 | class TensorboardWriter:
191 |     """
192 |     Class for handling the tensorboard logger
193 |
194 |     Args:
195 |     -----
196 |     logdir: string
197 |         path where the tensorboard logs will be stored
198 |     """
199 |
200 |     def __init__(self, logdir):
201 |         """ Initializing tensorboard writer """
202 |         self.logdir = logdir
203 |         self.writer = SummaryWriter(logdir)
204 |         return
205 |
206 |     def add_scalar(self, name, val, step):
207 |         """ Adding a scalar for plot """
208 |         self.writer.add_scalar(name, val, step)
209 |         return
210 |
211 |     def add_scalars(self, plot_name, val_names, vals, step):
212 |         """ Adding several values in one plot """
213 |         val_dict = {val_name: val for (val_name, val) in zip(val_names, vals)}
214 |         self.writer.add_scalars(plot_name, val_dict, step)
215 |         return
216 |
217 |     def add_image(self, fig_name, img_grid, step):
218 |         """ Adding a new step image to a figure """
219 |         self.writer.add_image(fig_name, img_grid, global_step=step)
220 |         return
221 |
222 |     def add_images(self, fig_name, img_grid, step):
223 |         """ Adding a batch of new step images to a figure """
224 |         self.writer.add_images(fig_name, img_grid, global_step=step)
225 |         return
226 |
227 |     def add_figure(self, tag, figure, step):
228 |         """ Adding a whole new figure to the tensorboard """
229 |         self.writer.add_figure(tag=tag, figure=figure, global_step=step)
230 |         return
231 |
232 |     def add_graph(self, model, input):
233 |         """ Logging model graph to tensorboard """
234 |         self.writer.add_graph(model, input_to_model=input)
235 |         return
236 |
237 |     def log_full_dictionary(self, dict, step, plot_name="Losses", dir=None):
238 |         """
239 |         Logging a bunch of losses into the Tensorboard. Logging each of them into
240 |         its independent plot and into a joined plot
241 |         """
242 |         if dir is not None:
243 |             dict = {f"{dir}/{key}": val for key, val in dict.items()}
244 |         else:
245 |             dict = {key: val for key, val in dict.items()}
246 |
247 |         for key, val in dict.items():
248 |             self.add_scalar(name=key, val=val, step=step)
249 |
250 |         plot_name = f"{dir}/{plot_name}" if dir is not None else plot_name
251 |         self.add_scalars(plot_name=plot_name, val_names=dict.keys(), vals=dict.values(), step=step)
252 |         return
253 |
254 | #
255 |
--------------------------------------------------------------------------------
/src/models/model_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Model utils
3 | """
4 |
5 | from time import time
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis
11 |
12 | from lib.logger import print_, log_info
13 |
14 |
15 | def build_grid(resolution, vmin=-1., vmax=1., device=None):
16 |     """
17 |     Building four grids with gradients in [vmin, vmax] in directions (x, -x, y, -y).
18 |     This can be used as a positional encoding.
19 |
20 |     Args:
21 |     -----
22 |     resolution: list/tuple of integers
23 |         number of elements in each of the gradients
24 |
25 |     Returns:
26 |     -------
27 |     torch_grid: torch Tensor
28 |         Grid gradients in 4 directions. Shape is [1, R, R, 4]
29 |     """
30 |     ranges = [np.linspace(vmin, vmax, num=res) for res in resolution]
31 |     grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
32 |     grid = np.stack(grid, axis=-1)
33 |     grid = np.reshape(grid, [resolution[0], resolution[1], -1])
34 |     grid = np.expand_dims(grid, axis=0)
35 |     grid = grid.astype(np.float32)
36 |     torch_grid = torch.from_numpy(np.concatenate([grid, 1.0 - grid], axis=-1)).to(device)
37 |     return torch_grid
38 |
39 |
40 | def conv_transpose_out_shape(in_size, stride, padding, kernel_size, out_padding, dilation=1):
41 |     """
42 |     Calculating the output shape of a Transposed Conv. Decoder
43 |     """
44 |     return (in_size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + out_padding + 1
45 |
46 |
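# NOTE: Illustrative example, not part of the original module. Worked instance of
# 'conv_transpose_out_shape': a transposed convolution with kernel_size=5, stride=2,
# padding=2 and out_padding=1 exactly doubles the spatial size, a common setting in
# convolutional decoders:
#
#   conv_transpose_out_shape(in_size=8, stride=2, padding=2, kernel_size=5, out_padding=1)
#       = (8 - 1) * 2 - 2 * 2 + 1 * (5 - 1) + 1 + 1
#       = 14 - 4 + 4 + 1 + 1 = 16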
47 | def count_model_params(model, verbose=False):
48 |     """
49 |     Counting number of learnable parameters
50 |     """
51 |     num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
52 |     if verbose:
53 |         print_(f" --> Number of learnable parameters: {num_params}")
54 |     return num_params
55 |
56 |
57 | def compute_flops(model, dummy_input, verbose=True, detailed=False):
58 |     """
59 |     Computing the number of activations and flops in a forward pass
60 |     """
61 |     func = print_ if verbose else log_info
62 |
63 |     # benchmarking
64 |     print(f"model device {next(model.parameters()).device}")
65 |     print(f"dummy input device: {dummy_input.device}")
66 |     fca = FlopCountAnalysis(model, dummy_input)
67 |     print("fca ok")
68 |     act = ActivationCountAnalysis(model, dummy_input)
69 |     print("act ok")
70 |     if detailed:
71 |         fcs = flop_count_str(fca)
72 |         print("detailed ok")
73 |         func(fcs)
74 |     total_flops = fca.total()
75 |     print("total 1 ok")
76 |     total_act = act.total()
77 |     print("total 2 ok")
78 |
79 |     # logging
80 |     func(" --> Number of FLOPS in a forward pass:")
81 |     func(f" --> FLOPS = {total_flops}")
82 |     func(f" --> FLOPS = {round(total_flops / 1e9, 3)}G")
83 |     func(" --> Number of activations in a forward pass:")
84 |     func(f" --> Activations = {total_act}")
85 |     func(f" --> Activations = {round(total_act / 1e6, 3)}M")
86 |     return total_flops, total_act
87 |
88 |
89 | def compute_throughput(model, dataset, device, num_imgs=500, use_tqdm=True, verbose=True):
90 |     """
91 |     Computing the throughput of a model in imgs/s
92 |     """
93 |     times = []
94 |     N = min(num_imgs, len(dataset))
95 |     iterator = tqdm(range(N)) if use_tqdm else range(N)
96 |     model = model.to(device)
97 |
98 |     # benchmarking by averaging over N images
99 |     for i in iterator:
100 |         img = dataset[i][0].unsqueeze(0).to(device)
101 |         torch.cuda.synchronize()
102 |         start = time()
103 |         _ = model(img)
104 |         torch.cuda.synchronize()
105 |         times.append(time() - start)
106 |     avg_time_per_img = np.mean(times)
107 |     throughput = 1 / avg_time_per_img
108 |
109 |     # logging
110 |     func = print_ if verbose else log_info
111 |     func(f" --> Average time per image: {round(avg_time_per_img, 3)}s")
112 |     func(f" --> Throughput: {round(throughput)} imgs/s")
113 |     return throughput, avg_time_per_img
114 |
115 |
116 | def freeze_params(model):
117 |     """
118 |     Freezing model params to avoid updates in backward pass
119 |     """
120 |     for param in model.parameters():
121 |         param.requires_grad = False
122 |     return model
123 |
124 |
125 | def unfreeze_params(model):
126 |     """
127 |     Unfreezing model params to allow for updates during backward pass
128 |     """
129 |     for param in model.parameters():
130 |         param.requires_grad = True
131 |     return model
132 |
133 |
134 | class GradientInspector:
135 |     """
136 |     Module that computes some statistics from the gradients of selected layers,
137 |     and logs the stats into the Tensorboard
138 |
139 |     Args:
140 |     -----
141 |     writer: TensorboardWriter
142 |         TensorboardWriter object used to log into the Tensorboard
143 |     layers: list of nn.Module
144 |         Layers whose gradients are processed and logged into the Tensorboard
145 |     names: list of strings
146 |         Name given to each of the layers to track
147 |     stats: list
148 |         List with the stats to track. Possible stats are: ['Min', 'Max', 'Mean', 'Var', 'Norm']
149 |     """
150 |
151 |     STATS = ["Min", "Max", "Mean", "Var", "Norm"]
152 |     FUNCS = {
153 |         "Min": torch.min,
154 |         "Max": torch.max,
155 |         "Mean": torch.mean,
156 |         "Var": torch.var,
157 |         "Norm": torch.norm,
158 |     }
159 |
160 |     def __init__(self, writer, layers, names, stats=None):
161 |         """ Module initializer """
162 |         stats = stats if stats is not None else GradientInspector.STATS
163 |         for stat in stats:
164 |             assert stat in GradientInspector.STATS, f"{stat = } not included in {self.STATS = }"
165 |         assert isinstance(layers, list), f"Layers is not list, but {type(layers)}..."
166 |         assert len(layers) == len(names), f"{len(layers) = } and {len(names) = } must be the same..."
167 |         for layer in layers:
168 |             assert isinstance(layer, torch.nn.Module), f"Layer is not nn.Module, but {type(layer)}..."
169 |             assert hasattr(layer, "weight"), "Layer does not have attribute 'weight'"
170 |
171 |         self.writer = writer
172 |         self.layers = layers
173 |         self.names = names
174 |         self.stats = stats
175 |
176 |         print_("Initializing Gradient-Inspector:")
177 |         print_(f" --> Tracking stats {stats} of gradients in the following layers")
178 |         for name, layer in zip(names, layers):
179 |             print_(f"  --> {name}: {layer}")
180 |         return
181 |
182 |     def __call__(self, step):
183 |         """ Computing gradient stats and logging into Tensorboard """
184 |         for layer, name in zip(self.layers, self.names):
185 |             grad = layer.weight.grad
186 |             for stat in self.stats:
187 |                 func = self.FUNCS[stat]
188 |                 self.writer.add_scalar(f"Grad Stats {name}/{stat} Grad", func(grad).item(), step)
189 |         return
190 |
191 |
192 | def get_norm_layer(norm="batch"):
193 |     """
194 |     Selecting norm layer by name
195 |     """
196 |     assert norm in ["batch", "instance", "group", "layer", "", None]
197 |     if norm == "batch":
198 |         norm_layer = nn.BatchNorm2d
199 |     elif norm == "instance":
200 |         norm_layer = nn.InstanceNorm2d
201 |     elif norm == "group":
202 |         norm_layer = nn.GroupNorm
203 |     elif norm == "layer":
204 |         norm_layer = nn.LayerNorm
205 |     elif norm == "" or norm is None:
206 |         norm_layer = nn.Identity
207 |     return norm_layer
208 |
209 |
210 | @torch.no_grad()
211 | def init_xavier_(model: nn.Module):
212 |     """
213 |     Initializes (in-place) a model's weights with xavier uniform, and its biases to zero.
214 |     All parameters with name containing "bias" are initialized to zero.
215 |     All other parameters are initialized with xavier uniform with default parameters,
216 |     unless they have dimensionality <= 1.
217 |     """
218 |     for name, tensor in model.named_parameters():
219 |         if name.endswith(".bias"):
220 |             tensor.zero_()
221 |         elif len(tensor.shape) <= 1:
222 |             pass  # silent
223 |         else:
224 |             torch.nn.init.xavier_uniform_(tensor)
225 |
226 |
227 | #
228 |
--------------------------------------------------------------------------------
/src/models/model_blocks.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic building blocks for neural nets
3 | """
4 |
5 | import math
6 | import torch
7 | import torch.nn as nn
8 |
9 | from models.model_utils import build_grid
10 |
11 | __all__ = ["ConvBlock", "ConvTransposeBlock", "SoftPositionEmbed", "PositionalEncoding"]
12 |
13 |
14 | class ConvBlock(nn.Module):
15 |     """
16 |     Simple convolutional block for conv. encoders
17 |
18 |     Args:
19 |     -----
20 |     in_channels: int
21 |         Number of channels in the input feature maps.
22 |     out_channels: int
23 |         Number of convolutional kernels in the conv layer
24 |     kernel_size: int
25 |         Size of the kernel for the conv layer
26 |     stride: int
27 |         Amount of stride applied in the convolution
28 |     padding: int/None
29 |         Whether to pad the input feature maps, and how much padding to use.
30 |     batch_norm: bool
31 |         If True, Batch Norm is applied after the convolutional layer
32 |     max_pool: int/tuple/None
33 |         If not None, output feature maps are downsampled by this amount via max pooling
34 |     activation: bool
35 |         If True, output feature maps are activated via a ReLU nonlinearity.
36 |     """
37 |
38 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None,
39 |                  batch_norm=False, max_pool=None, activation=True):
40 |         """
41 |         Module initializer
42 |         """
43 |         super().__init__()
44 |         padding = padding if padding is not None else kernel_size // 2
45 |
46 |         # adding conv-(bn)-(pool)-act layer
47 |         layers = []
48 |         layers.append(
49 |             nn.Conv2d(
50 |                 in_channels=in_channels,
51 |                 out_channels=out_channels,
52 |                 kernel_size=kernel_size,
53 |                 stride=stride,
54 |                 padding=padding
55 |             )
56 |         )
57 |         if batch_norm:
58 |             layers.append(nn.BatchNorm2d(num_features=out_channels))
59 |         if max_pool:
60 |             assert isinstance(max_pool, (int, tuple, list))
61 |             layers.append(nn.MaxPool2d(kernel_size=max_pool, stride=max_pool))
62 |         if activation:
63 |             layers.append(nn.ReLU())
64 |
65 |         self.block = nn.Sequential(*layers)
66 |         return
67 |
68 |     def forward(self, x):
69 |         """
70 |         Forward pass
71 |         """
72 |         y = self.block(x)
73 |         return y
74 |
75 |
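# NOTE: Illustrative example, not part of the original module. With stride=2 and the
# default 'same'-style padding (kernel_size // 2), ConvBlock halves the spatial
# resolution while mapping the input to 'out_channels' feature maps:
#
#   block = ConvBlock(in_channels=3, out_channels=32, kernel_size=5, stride=2, batch_norm=True)
#   y = block(torch.randn(4, 3, 64, 64))    # -> shape (4, 32, 32, 32)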
76 | class ConvTransposeBlock(nn.Module):
77 |     """
78 |     Simple transposed-convolutional block for conv. decoders
79 |
80 |     Args:
81 |     -----
82 |     in_channels: int
83 |         Number of channels in the input feature maps.
84 |     out_channels: int
85 |         Number of convolutional kernels in the conv layer
86 |     kernel_size: int
87 |         Size of the kernel for the conv layer
88 |     stride: int
89 |         Amount of stride applied in the convolution
90 |     padding: int/None
91 |         Whether to pad the input feature maps, and how much padding to use.
92 |     batch_norm: bool
93 |         If True, Batch Norm is applied after the convolutional layer
94 |     upsample: int/tuple/None
95 |         If not None, output feature maps are upsampled by this amount via (nn.) Upsampling
96 |     activation: bool
97 |         If True, output feature maps are activated via a ReLU nonlinearity.
98 |     conv_transpose_2d: bool
99 |         If True, Transposed convolutional layers are used.
100 |         Otherwise, standard convolutions (combined with Upsampling) are applied.
101 |     """
102 |
103 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None,
104 |                  batch_norm=False, upsample=None, activation=True, conv_transpose_2d=True):
105 |         """ Module initializer """
106 |         super().__init__()
107 |         padding = padding if padding is not None else kernel_size // 2
108 |
109 |         # adding conv-(bn)-(upsample)-act layer
110 |         layers = []
111 |         if conv_transpose_2d:
112 |             layers.append(
113 |                 nn.ConvTranspose2d(
114 |                     in_channels=in_channels,
115 |                     out_channels=out_channels,
116 |                     kernel_size=kernel_size,
117 |                     stride=stride,
118 |                     padding=padding,
119 |                 )
120 |             )
121 |         else:
122 |             layers.append(
123 |                 nn.Conv2d(
124 |                     in_channels=in_channels,
125 |                     out_channels=out_channels,
126 |                     kernel_size=kernel_size,
127 |                     stride=stride,
128 |                     padding=padding)
129 |             )
130 |         if batch_norm:
131 |             layers.append(nn.BatchNorm2d(num_features=out_channels))
132 |         if upsample:
133 |             assert isinstance(upsample, (int, tuple, list))
134 |             layers.append(nn.Upsample(scale_factor=upsample))
135 |         if activation:
136 |             layers.append(nn.ReLU())
137 |
138 |         self.block = nn.Sequential(*layers)
139 |         return
140 |
141 |     def forward(self, x):
142 |         """
143 |         Forward pass
144 |         """
145 |         y = self.block(x)
146 |         return y
147 |
148 |
149 | class SoftPositionEmbed(nn.Module):
150 |     """
151 |     Soft positional embedding with learnable linear projection.
152 |       1. The positional encoding corresponds to a 4-channel grid with coords [-1, ..., 1] and
153 |          [1, ..., -1] in the vertical and horizontal directions
154 |       2. The 4 channels are projected into a hidden_dimension via a linear layer (or Conv-1D)
155 |
156 |
157 |     Args:
158 |     -----
159 |     hidden_size: int
160 |         Number of output channels
161 |     resolution: list/tuple of integers
162 |         Number of elements in the positional embedding. Corresponds to a spatial size
163 |     vmin, vmax: int
164 |         Minimum and maximum values in the grids. By default vmin=-1 and vmax=1
165 |     """
166 |
167 |     def __init__(self, hidden_size, resolution, vmin=-1., vmax=1.):
168 |         """
169 |         Soft positional encoding
170 |         """
171 |         super().__init__()
172 |         self.projection = nn.Conv2d(4, hidden_size, kernel_size=1)
173 |         self.grid = build_grid(resolution, vmin=vmin, vmax=vmax).permute(0, 3, 1, 2)
174 |         return
175 |
176 |     def forward(self, inputs, channels_last=True):
177 |         """
178 |         Projecting grid and adding to inputs
179 |         """
180 |         b_size = inputs.shape[0]
181 |         if self.grid.device != inputs.device:
182 |             self.grid = self.grid.to(inputs.device)
183 |         grid = self.grid.repeat(b_size, 1, 1, 1)
184 |         emb_proj = self.projection(grid)
185 |         if channels_last:
186 |             emb_proj = emb_proj.permute(0, 2, 3, 1)
187 |         return inputs + emb_proj
188 |
189 |
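# NOTE: Illustrative example, not part of the original module. 'SoftPositionEmbed' adds a
# learned projection of the fixed 4-channel coordinate grid to a feature map, e.g. the
# encoder features fed to Slot Attention. The shapes follow from the implementation above;
# the 64x64 resolution and batch size are arbitrary examples:
#
#   pos_emb = SoftPositionEmbed(hidden_size=32, resolution=(64, 64))
#   feats = torch.randn(4, 64, 64, 32)           # channels-last feature maps
#   feats = pos_emb(feats, channels_last=True)   # same shape, now position-aware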
190 | class PositionalEncoding(nn.Module):
191 |     """
192 |     Positional encoding to be added to the input tokens of the transformer predictor.
193 |
194 |     Our positional encoding only informs about the time-step, i.e., all slots extracted
195 |     from the same input frame share the same positional embedding. This allows our predictor
196 |     model to maintain the permutation equivariance properties.
197 |
198 |     Args:
199 |     -----
200 |     batch_size: int
201 |         Number of elements in the batch.
202 |     num_slots: int
203 |         Number of slots extracted per frame. Positional encoding will be repeated for each of these.
204 |     d_model: int
205 |         Dimensionality of the slots/tokens
206 |     dropout: float
207 |         Percentage of dropout to apply after adding the positional encoding. Default is 0.1
208 |     max_len: int
209 |         Length of the sequence.
210 |     """
211 |
212 |     def __init__(self, d_model, dropout=0.1, max_len=50):
213 |         """
214 |         Initializing the positional encoding
215 |         """
216 |         super().__init__()
217 |         self.dropout = nn.Dropout(p=dropout)
218 |
219 |         # initializing sinusoidal positional embedding
220 |         position = torch.arange(max_len).unsqueeze(1)
221 |         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
222 |         pe = torch.zeros(max_len, 1, d_model)
223 |         pe[:, 0, 0::2] = torch.sin(position * div_term)
224 |         pe[:, 0, 1::2] = torch.cos(position * div_term)
225 |         pe = pe.view(1, max_len, 1, d_model)
226 |         self.pe = pe
227 |         return
228 |
229 |     def forward(self, x, batch_size, num_slots):
230 |         """
231 |         Adding the positional encoding to the input tokens of the transformer
232 |
233 |         Args:
234 |         -----
235 |         x: torch Tensor
236 |             Tokens to enhance with positional encoding. Shape is (B, Seq_len, Num_Slots, Token_Dim)
237 |         batch_size: int
238 |             Given batch size to repeat the positional encoding for
239 |         num_slots: int
240 |             Number of slots to repeat the positional encoding for
241 |         """
242 |         if x.device != self.pe.device:
243 |             self.pe = self.pe.to(x.device)
244 |         cur_seq_len = x.shape[1]
245 |         cur_pe = self.pe.repeat(batch_size, 1, num_slots, 1)[:, :cur_seq_len]
246 |         x = x + cur_pe
247 |         y = self.dropout(x)
248 |         return y
249 |
250 |
251 | #
252 |
--------------------------------------------------------------------------------
/src/lib/schedulers.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of learning rate schedulers, early stopping and other utils
3 | for improving optimization
4 | """
5 |
6 | from lib.logger import print_
7 |
8 |
9 | def update_scheduler(scheduler, exp_params, control_metric=None, iter=-1, end_epoch=False):
10 |     """
11 |     Updating the learning rate scheduler by performing the scheduler step.
12 |
13 |     Args:
14 |     -----
15 |     scheduler: torch.optim
16 |         scheduler to update
17 |     exp_params: dictionary
18 |         dictionary containing the experiment parameters
19 |     control_metric: float/torch Tensor
20 |         Last computed validation metric.
21 |         Needed for plateau scheduler
22 |     iter: float
23 |         index of the current optimization step.
24 |         Needed for cyclic, cosine and exponential schedulers
25 |     end_epoch: boolean
26 |         True after finishing a validation epoch or certain number of iterations.
27 |         Triggers schedulers such as plateau or fixed-step
28 |     """
29 |     scheduler_type = exp_params["training_slots"]["scheduler"]
30 |     if(scheduler_type == "plateau" and end_epoch):
31 |         scheduler.step(control_metric)
32 |     elif(scheduler_type in ["step", "multi_step"] and end_epoch):
33 |         scheduler.step()
34 |     elif(scheduler_type == "exponential" and not end_epoch):
35 |         scheduler.step(iter)
36 |     elif(scheduler_type == "cosine_annealing" and not end_epoch):
37 |         scheduler.step()
38 |     else:
39 |         pass
40 |     return
41 |
42 |
43 | class ExponentialLRSchedule:
44 |     """
45 |     Exponential LR Scheduler that decreases the learning rate by multiplying it
46 |     by an exponentially decreasing decay factor:
47 |         LR = init_LR * gamma ^ (step / total_steps)
48 |
49 |     Args:
50 |     -----
51 |     optimizer: torch.optim
52 |         Optimizer to schedule
53 |     init_lr: float
54 |         base learning rate to decrease with the exponential scheduler
55 |     gamma: float
56 |         exponential decay factor
57 |     total_steps: int/float
58 |         number of optimization steps to optimize for. Once this is reached,
Once this is reached, 59 | lr is not decreased anymore 60 | """ 61 | 62 | def __init__(self, optimizer, init_lr, gamma=0.5, total_steps=1_000_000): 63 | """ 64 | Module initializer 65 | """ 66 | self.optimizer = optimizer 67 | self.init_lr = init_lr 68 | self.gamma = gamma 69 | self.total_steps = total_steps 70 | return 71 | 72 | def update_lr(self, step): 73 | """ 74 | Computing exponential lr update 75 | """ 76 | new_lr = self.init_lr * self.gamma ** (step / self.total_steps) 77 | return new_lr 78 | 79 | def step(self, iter): 80 | """ 81 | Scheduler step 82 | """ 83 | if(iter < self.total_steps): 84 | for params in self.optimizer.param_groups: 85 | params["lr"] = self.update_lr(iter) 86 | elif(iter == self.total_steps): 87 | print_(f"Finished exponential decay due to reach of {self.total_steps} steps") 88 | return 89 | 90 | def state_dict(self): 91 | """ 92 | State dictionary 93 | """ 94 | state_dict = {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 95 | return state_dict 96 | 97 | def load_state_dict(self, state_dict): 98 | """ 99 | Loading state dictinary 100 | """ 101 | self.init_lr = state_dict["init_lr"] 102 | self.gamma = state_dict["gamma"] 103 | self.total_steps = state_dict["total_steps"] 104 | return 105 | 106 | 107 | class LRWarmUp: 108 | """ 109 | Class for performing learning rate warm-ups. We increase the learning rate 110 | during the first few iterations until it reaches the standard LR 111 | 112 | Args: 113 | ----- 114 | init_lr: float 115 | initial learning rate 116 | warmup_steps: integer 117 | number of optimization steps to warm up for 118 | max_epochs: integer 119 | maximum number of epochs to warmup. It overrides 'warmup_step' 120 | """ 121 | 122 | def __init__(self, init_lr, warmup_steps, max_epochs=1): 123 | """ 124 | Initializer 125 | """ 126 | self.init_lr = init_lr 127 | self.warmup_steps = warmup_steps 128 | self.max_epochs = max_epochs 129 | self.active = True 130 | self.final_step = -1 131 | 132 | def __call__(self, iter, epoch, optimizer): 133 | """ 134 | Computing actual learning rate and updating optimizer 135 | """ 136 | if(iter > self.warmup_steps): 137 | if(self.active): 138 | self.final_step = iter 139 | self.active = False 140 | lr = self.init_lr 141 | print_("Finished learning rate warmup period...") 142 | print_(f" --> Reached iter {iter} >= {self.warmup_steps}") 143 | print_(f" --> Reached at epoch {epoch}") 144 | elif(epoch >= self.max_epochs): 145 | if(self.active): 146 | self.final_step = iter 147 | self.active = False 148 | lr = self.init_lr 149 | print_("Finished learning rate warmup period:") 150 | print_(f" --> Reached epoch {epoch} >= {self.max_epochs}") 151 | print_(f" --> Reached at iter {iter}") 152 | else: 153 | if iter >= 0: 154 | lr = self.init_lr * (iter / self.warmup_steps) 155 | for params in optimizer.param_groups: 156 | params["lr"] = lr 157 | return 158 | 159 | def state_dict(self): 160 | """ 161 | State dictionary 162 | """ 163 | state_dict = {key: value for key, value in self.__dict__.items()} 164 | return state_dict 165 | 166 | def load_state_dict(self, state_dict): 167 | """ 168 | Loading state dictinary 169 | """ 170 | self.init_lr = state_dict.init_lr 171 | self.warmup_steps = state_dict.warmup_steps 172 | self.max_epochs = state_dict.max_epochs 173 | self.active = state_dict.active 174 | self.final_step = state_dict.final_step 175 | return 176 | 177 | 178 | class EarlyStop: 179 | """ 180 | Implementation of an early stop criterion 181 | 182 | Args: 183 | ----- 184 | mode: string ['min', 'max'] 
178 | class EarlyStop:
179 |     """
180 |     Implementation of an early stop criterion
181 |
182 |     Args:
183 |     -----
184 |     mode: string ['min', 'max']
185 |         whether we validate based on maximizing or minimizing a metric
186 |     delta: float
187 |         threshold to consider improvements
188 |     patience: integer
189 |         number of epochs without improvement to trigger early stopping
190 |     """
191 |
192 |     def __init__(self, mode="min", delta=1e-6, patience=7):
193 |         """
194 |         Early stopper initializer
195 |         """
196 |         assert mode in ["min", "max"]
197 |         self.mode = mode
198 |         self.delta = delta
199 |         self.patience = patience
200 |         self.counter = 0
201 |
202 |         if(mode == "min"):
203 |             self.best = 1e15
204 |             self.criterion = lambda x: x < (self.best - self.delta)
205 |         elif(mode == "max"):
206 |             self.best = -1e15
207 |             self.criterion = lambda x: x > (self.best + self.delta)
208 |
209 |         return
210 |
211 |     def __call__(self, value):
212 |         """
213 |         Comparing the current metric against the best past results and computing whether we
214 |         should early stop or not
215 |
216 |         Args:
217 |         -----
218 |         value: float
219 |             validation metric measured by the early stopping criterion
220 |
221 |         Returns:
222 |         --------
223 |         stop_training: boolean
224 |             If True, we should early stop. Otherwise, metric is still improving
225 |         """
226 |         are_we_better = self.criterion(value)
227 |         if(are_we_better):
228 |             self.counter = 0
229 |             self.best = value
230 |         else:
231 |             self.counter = self.counter + 1
232 |
233 |         stop_training = True if(self.counter >= self.patience) else False
234 |
235 |         return stop_training
236 |
237 |
238 | class WarmupVSScehdule:
239 |     """
240 |     Orchestrator module that calls the LR-Warmup module during the warmup iterations,
241 |     and calls the LR Scheduler once warmup is finished.
242 |     """
243 |
244 |     def __init__(self, optimizer, lr_warmup, scheduler):
245 |         """
246 |         Initializer of the Warmup-Scheduler orchestrator
247 |         """
248 |         self.optimizer = optimizer
249 |         self.lr_warmup = lr_warmup
250 |         self.scheduler = scheduler
251 |         return
252 |
253 |     def __call__(self, iter, epoch, exp_params, end_epoch, control_metric=None):
254 |         """
255 |         Calling either LR-Warmup or LR-Scheduler
256 |         """
257 |         if self.lr_warmup.active:
258 |             self.lr_warmup(iter=iter, epoch=epoch, optimizer=self.optimizer)
259 |         else:
260 |             update_scheduler(
261 |                 scheduler=self.scheduler,
262 |                 exp_params=exp_params,
263 |                 iter=iter - self.lr_warmup.final_step - 1,
264 |                 end_epoch=end_epoch,
265 |                 control_metric=control_metric
266 |             )
267 |         return
268 |
269 |
270 | #
271 |
--------------------------------------------------------------------------------