├── checkpoints
│   └── PLACE PRE-TRAINED CHECKPOINTS HERE.txt
├── assets
│   ├── dlp_tut.PNG
│   └── dlp_example_output.jpg
├── requirements.txt
├── accel_conf.yml
├── configs
│   ├── diffuse_ddlp.json
│   ├── balls.json
│   ├── phyre.json
│   ├── clevrer.json
│   ├── obj3d128.json
│   ├── obj3d.json
│   ├── traffic.json
│   ├── obj3d_img.json
│   ├── shapes.json
│   ├── traffic_img.json
│   ├── obj3d128_img.json
│   └── generate_config_file.py
├── docs
│   ├── README.md
│   ├── installation.md
│   ├── gui.md
│   ├── example_usage.py
│   └── hyperparameters.md
├── LICENSE
├── datasets
│   ├── shapes_ds.py
│   ├── get_dataset.py
│   ├── balls_ds.py
│   ├── obj3d_ds.py
│   ├── traffic_ds.py
│   ├── phyre_ds.py
│   └── clevrer_ds.py
├── generate_ddlp_video_prediction.py
├── environment.yml
├── generate_diffuse_ddlp_video_generation.py
├── eval
│   └── eval_gen_metrics.py
├── utils
│   └── loss_functions.py
├── train_diffuse_ddlp.py
└── train_dlp.py
/checkpoints/PLACE PRE-TRAINED CHECKPOINTS HERE.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/dlp_tut.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/ddlp/main/assets/dlp_tut.PNG -------------------------------------------------------------------------------- /assets/dlp_example_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taldatech/ddlp/main/assets/dlp_example_output.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.19.0 2 | einops==0.6.1 3 | h5py==3.1.0 4 | imageio==2.28.1 5 | matplotlib==3.7.1 6 | numpy==1.24.3 7 | opencv_python==3.4.18.65 8 | piqa==1.3.1 9 | scikit_image==0.19.2 10 | torch==2.0.1 11 | torchvision==0.15.2 12 | tqdm==4.65.0 13 | ttkthemes==3.2.2 14 | ttkwidgets==0.13.0 15 | -------------------------------------------------------------------------------- /accel_conf.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | fsdp_config: {} 6 | machine_rank: 0 7 | main_process_ip: null 8 | main_process_port: null 9 | main_training_function: main 10 | mixed_precision: 'no' 11 | num_machines: 1 12 | num_processes: 4 13 | rdzv_backend: static 14 | same_network: true 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /configs/diffuse_ddlp.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "phyre", 3 | "ds_root": "/media/newhd/data/phyre/full", 4 | "device": "cuda", 5 | "batch_size": 32, 6 | "lr": 0.00008, 7 | "train_num_steps": 700000, 8 | "diffusion_num_steps": 1000, 9 | "loss_type": "l1", 10 | "particle_norm": "minmax", 11 | "diffuse_frames": 4, 12 | "ddlp_dir": "./checkpoints/ddlp-phyre", 13 | "ddlp_ckpt": "./checkpoints/ddlp-phyre/phyre_ddlp.pth", 14 | "result_dir": null 15 | } -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | We provide documentation for a better understanding of the models.
4 | 5 | | File | Content | 6 | |-----------------------|-----------------------------------------------------| 7 | | `installation.md` | Manual instructions to install packages with `conda` | 8 | | `gui.md` | Instructions for using the GUI: loading models and examples, interacting with the particles | 9 | | `hyperparameters.md` | Explanations of the various hyperparameters of the models and recommended values | 10 | | `example_usage.py` | Overview of the models' functionality: forward output, loss calculation and sampling | 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tal Daniel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /datasets/shapes_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple Random Colored Shapes Dataset 3 | """ 4 | # imports 5 | import numpy as np 6 | from skimage.draw import random_shapes 7 | from tqdm.auto import tqdm 8 | import torch 9 | 10 | 11 | def generate_shape_dataset(img_size=64, min_shapes=2, max_shapes=5, min_size=10, max_size=12, allow_overlap=False, 12 | num_images=10_000): 13 | images = [] 14 | for i in tqdm(range(num_images)): 15 | img, _ = random_shapes((img_size, img_size), min_shapes=min_shapes, max_shapes=max_shapes, 16 | intensity_range=((0, 200),), min_size=min_size, max_size=max_size, 17 | allow_overlap=allow_overlap, num_trials=100) 18 | img[:, :, 0][img[:, :, 0] == 255] = 0 19 | img[:, :, 1][img[:, :, 1] == 255] = 255 20 | img[:, :, 2][img[:, :, 2] == 255] = 255 21 | img = img / 255.0 22 | images.append(img) 23 | images = np.stack(images, axis=0) # [num_mages, H, W, 3] 24 | return images 25 | 26 | 27 | def generate_shape_dataset_torch(img_size=64, min_shapes=2, max_shapes=5, min_size=11, max_size=13, allow_overlap=False, 28 | num_images=10_000): 29 | images = generate_shape_dataset(img_size=img_size, min_shapes=min_shapes, max_shapes=max_shapes, min_size=min_size, 30 | max_size=max_size, 31 | allow_overlap=allow_overlap, num_images=num_images) 32 | # create torch dataset 33 | img_data_torch = images.transpose(0, 3, 1, 2) # [num_images, 3, H, W] 34 | img_ds = torch.utils.data.TensorDataset(torch.tensor(img_data_torch, dtype=torch.float), torch.arange(num_images)) 35 | return img_ds -------------------------------------------------------------------------------- /configs/balls.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "balls", 3 | "root": "/mnt/data/tal/gswm_balls/BALLS_INTERACTION", 4 | "device": "cuda", 5 | "batch_size": 32, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "mse", 14 | "sigma": 1.0, 15 | "beta_kl": 0.1, 16 | "beta_rec": 1.0, 17 | "patch_size": 8, 18 | "topk": 6, 19 | "n_kp_enc": 6, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 3, 22 | "n_kp_prior": 12, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 64, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 128 39 | ], 40 | "prior_channels": [ 41 | 16, 42 | 32, 43 | 64 44 | ], 45 | "timestep_horizon": 10, 46 | "predict_delta": true, 47 | "beta_dyn": 0.1, 48 | "scale_std": 0.3, 49 | "offset_std": 0.2, 50 | "obj_on_alpha": 0.1, 51 | "obj_on_beta": 0.1, 52 | "beta_dyn_rec": 1.0, 53 | "num_static_frames": 4, 54 | "pint_layers": 6, 55 | "pint_heads": 8, 56 | "pint_dim": 256, 57 | "run_prefix": "", 58 | "animation_horizon": 100, 59 | "eval_im_metrics": true, 60 | "use_resblock": false, 61 | "scheduler_gamma": 0.95, 62 | "adam_betas": [ 63 | 0.9, 64 | 0.999 65 | ], 66 | "adam_eps": 0.0001, 67 | "train_enc_prior": true, 68 | "start_dyn_epoch": 0, 69 | "cond_steps": 10, 70 | "animation_fps": 0.06, 71 | "use_correlation_heatmaps": true, 72 | "enable_enc_attn": false, 73 | "filtering_heuristic": "variance" 74 | } -------------------------------------------------------------------------------- 
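The JSON files under `configs/` hold the full set of training hyperparameters for DLP/DDLP on each dataset, and the scripts in this repository read them as plain JSON (for example, `generate_ddlp_video_prediction.py` further below loads `hparams.json` with `json.load`). A minimal sketch of loading and inspecting one of these configs, assuming it is run from the repository root:

import json

# Load a training config, e.g. the bouncing-balls config shown above.
with open('configs/balls.json', 'r') as f:
    config = json.load(f)

# A few of the fields the training/generation scripts read from the config.
print(config['ds'])                # dataset name, e.g. 'balls'
print(config['image_size'])        # input image resolution
print(config['n_kp_enc'])          # number of posterior particles
print(config['timestep_horizon'])  # training horizon; also the default number of
                                   # conditioning frames in generate_ddlp_video_prediction.py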
/configs/phyre.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "phyre", 3 | "root": "/mnt/data/tal/phyre/", 4 | "device": "cuda", 5 | "batch_size": 3, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "mse", 14 | "sigma": 1.0, 15 | "beta_kl": 0.15, 16 | "beta_rec": 1.0, 17 | "patch_size": 16, 18 | "topk": 10, 19 | "n_kp_enc": 25, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 4, 22 | "n_kp_prior": 30, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.125, 32 | "kl_balance": 0.001, 33 | "image_size": 128, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 128, 39 | 256 40 | ], 41 | "prior_channels": [ 42 | 16, 43 | 32, 44 | 64 45 | ], 46 | "timestep_horizon": 15, 47 | "predict_delta": true, 48 | "beta_dyn": 0.15, 49 | "scale_std": 0.3, 50 | "offset_std": 0.2, 51 | "obj_on_alpha": 0.1, 52 | "obj_on_beta": 0.1, 53 | "beta_dyn_rec": 1.0, 54 | "num_static_frames": 4, 55 | "pint_layers": 6, 56 | "pint_heads": 8, 57 | "pint_dim": 512, 58 | "run_prefix": "", 59 | "animation_horizon": 60, 60 | "eval_im_metrics": true, 61 | "use_resblock": false, 62 | "scheduler_gamma": 0.95, 63 | "adam_betas": [ 64 | 0.9, 65 | 0.999 66 | ], 67 | "adam_eps": 0.0001, 68 | "train_enc_prior": true, 69 | "start_dyn_epoch": 0, 70 | "cond_steps": 10, 71 | "animation_fps": 0.05, 72 | "use_correlation_heatmaps": true, 73 | "enable_enc_attn": false, 74 | "filtering_heuristic": "variance" 75 | } -------------------------------------------------------------------------------- /configs/clevrer.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "clevrer", 3 | "root": "/datadrive/clevrer/", 4 | "device": "cuda", 5 | "batch_size": 8, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "vgg", 14 | "sigma": 1.0, 15 | "beta_kl": 100.0, 16 | "beta_rec": 1.0, 17 | "patch_size": 16, 18 | "topk": 10, 19 | "n_kp_enc": 12, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 8, 22 | "n_kp_prior": 16, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 128, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 128, 39 | 256 40 | ], 41 | "prior_channels": [ 42 | 16, 43 | 32, 44 | 64 45 | ], 46 | "timestep_horizon": 20, 47 | "predict_delta": true, 48 | "beta_dyn": 100.0, 49 | "scale_std": 0.3, 50 | "offset_std": 0.2, 51 | "obj_on_alpha": 0.1, 52 | "obj_on_beta": 0.1, 53 | "beta_dyn_rec": 1.0, 54 | "num_static_frames": 4, 55 | "pint_layers": 6, 56 | "pint_heads": 8, 57 | "pint_dim": 256, 58 | "run_prefix": "", 59 | "animation_horizon": 100, 60 | "eval_im_metrics": true, 61 | "use_resblock": false, 62 | "scheduler_gamma": 0.95, 63 | "adam_betas": [ 64 | 0.9, 65 | 0.999 66 | ], 67 | "adam_eps": 0.0001, 68 | "train_enc_prior": true, 69 | "start_dyn_epoch": 0, 70 | "cond_steps": 10, 71 | "animation_fps": 0.06, 72 | "use_correlation_heatmaps": true, 73 | "enable_enc_attn": false, 74 | "filtering_heuristic": "variance" 75 | } 
-------------------------------------------------------------------------------- /configs/obj3d128.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "obj3d128", 3 | "root": "/mnt/data/tal/obj3d/", 4 | "device": "cuda", 5 | "batch_size": 4, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "vgg", 14 | "sigma": 1.0, 15 | "beta_kl": 100.0, 16 | "beta_rec": 1.0, 17 | "patch_size": 16, 18 | "topk": 10, 19 | "n_kp_enc": 12, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 8, 22 | "n_kp_prior": 16, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 128, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 128, 39 | 256 40 | ], 41 | "prior_channels": [ 42 | 16, 43 | 32, 44 | 64 45 | ], 46 | "timestep_horizon": 10, 47 | "predict_delta": true, 48 | "beta_dyn": 100.0, 49 | "scale_std": 0.3, 50 | "offset_std": 0.2, 51 | "obj_on_alpha": 0.1, 52 | "obj_on_beta": 0.1, 53 | "beta_dyn_rec": 1.0, 54 | "num_static_frames": 4, 55 | "pint_layers": 6, 56 | "pint_heads": 8, 57 | "pint_dim": 256, 58 | "run_prefix": "", 59 | "animation_horizon": 100, 60 | "eval_im_metrics": true, 61 | "use_resblock": false, 62 | "scheduler_gamma": 0.95, 63 | "adam_betas": [ 64 | 0.9, 65 | 0.999 66 | ], 67 | "adam_eps": 0.0001, 68 | "train_enc_prior": true, 69 | "start_dyn_epoch": 0, 70 | "cond_steps": 10, 71 | "animation_fps": 0.06, 72 | "use_correlation_heatmaps": true, 73 | "enable_enc_attn": false, 74 | "filtering_heuristic": "variance" 75 | } -------------------------------------------------------------------------------- /configs/obj3d.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "obj3d", 3 | "root": "/mnt/data/tal/obj3d/", 4 | "device": "cuda", 5 | "batch_size": 10, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "vgg", 14 | "sigma": 1.0, 15 | "beta_kl": 30.0, 16 | "beta_rec": 1.0, 17 | "patch_size": 8, 18 | "topk": 8, 19 | "n_kp_enc": 8, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 10, 22 | "n_kp_prior": 16, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 64, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 128 39 | ], 40 | "prior_channels": [ 41 | 16, 42 | 32, 43 | 64 44 | ], 45 | "timestep_horizon": 10, 46 | "predict_delta": true, 47 | "beta_dyn": 30.0, 48 | "scale_std": 0.3, 49 | "offset_std": 0.2, 50 | "obj_on_alpha": 0.1, 51 | "obj_on_beta": 0.1, 52 | "beta_dyn_rec": 1.0, 53 | "num_static_frames": 4, 54 | "pint_layers": 6, 55 | "pint_heads": 8, 56 | "pint_dim": 256, 57 | "run_prefix": "", 58 | "animation_horizon": 100, 59 | "eval_im_metrics": true, 60 | "use_resblock": false, 61 | "scheduler_gamma": 0.95, 62 | "adam_betas": [ 63 | 0.9, 64 | 0.999 65 | ], 66 | "adam_eps": 0.0001, 67 | "train_enc_prior": true, 68 | "start_dyn_epoch": 0, 69 | "cond_steps": 10, 70 | "animation_fps": 0.06, 71 | "use_correlation_heatmaps": true, 72 | "enable_enc_attn": false, 73 | "filtering_heuristic": 
"variance", 74 | "use_tracking": true 75 | } -------------------------------------------------------------------------------- /configs/traffic.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "traffic", 3 | "root": "/home/tal/data/traffic/img128np.npy", 4 | "device": "cuda", 5 | "batch_size": 8, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "vgg", 14 | "sigma": 1.0, 15 | "beta_kl": 40.0, 16 | "beta_rec": 1.0, 17 | "patch_size": 16, 18 | "topk": 10, 19 | "n_kp_enc": 25, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 8, 22 | "n_kp_prior": 30, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 3, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 128, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 128, 39 | 256 40 | ], 41 | "prior_channels": [ 42 | 16, 43 | 32, 44 | 64 45 | ], 46 | "timestep_horizon": 10, 47 | "predict_delta": true, 48 | "beta_dyn": 40.0, 49 | "scale_std": 0.3, 50 | "offset_std": 0.2, 51 | "obj_on_alpha": 0.1, 52 | "obj_on_beta": 0.1, 53 | "beta_dyn_rec": 1.0, 54 | "num_static_frames": 4, 55 | "pint_layers": 6, 56 | "pint_heads": 8, 57 | "pint_dim": 256, 58 | "run_prefix": "", 59 | "animation_horizon": 100, 60 | "eval_im_metrics": true, 61 | "use_resblock": false, 62 | "scheduler_gamma": 0.95, 63 | "adam_betas": [ 64 | 0.9, 65 | 0.999 66 | ], 67 | "adam_eps": 0.0001, 68 | "train_enc_prior": true, 69 | "start_dyn_epoch": 0, 70 | "cond_steps": 10, 71 | "animation_fps": 0.06, 72 | "use_correlation_heatmaps": true, 73 | "enable_enc_attn": false, 74 | "filtering_heuristic": "variance" 75 | } -------------------------------------------------------------------------------- /configs/obj3d_img.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "obj3d", 3 | "root": "/media/newhd/data/obj3d/", 4 | "device": "cuda", 5 | "batch_size": 32, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "vgg", 14 | "sigma": 1.0, 15 | "beta_kl": 40.0, 16 | "beta_rec": 1.0, 17 | "patch_size": 16, 18 | "topk": 8, 19 | "n_kp_enc": 12, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 8, 22 | "n_kp_prior": 16, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 64, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 128 39 | ], 40 | "prior_channels": [ 41 | 16, 42 | 32, 43 | 64 44 | ], 45 | "timestep_horizon": 10, 46 | "predict_delta": true, 47 | "beta_dyn": 30.0, 48 | "scale_std": 0.3, 49 | "offset_std": 0.2, 50 | "obj_on_alpha": 0.1, 51 | "obj_on_beta": 0.1, 52 | "beta_dyn_rec": 1.0, 53 | "num_static_frames": 4, 54 | "pint_layers": 6, 55 | "pint_heads": 8, 56 | "pint_dim": 256, 57 | "run_prefix": "", 58 | "animation_horizon": 100, 59 | "eval_im_metrics": true, 60 | "use_resblock": false, 61 | "scheduler_gamma": 0.95, 62 | "adam_betas": [ 63 | 0.9, 64 | 0.999 65 | ], 66 | "adam_eps": 0.0001, 67 | "train_enc_prior": true, 68 | "start_dyn_epoch": 0, 69 | "cond_steps": 10, 70 | "animation_fps": 0.06, 71 | "use_correlation_heatmaps": 
true, 72 | "enable_enc_attn": false, 73 | "filtering_heuristic": "variance", 74 | "use_tracking": false 75 | } -------------------------------------------------------------------------------- /configs/shapes.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "shapes", 3 | "root": null, 4 | "device": "cuda", 5 | "batch_size": 32, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "mse", 14 | "sigma": 1.0, 15 | "beta_kl": 0.05, 16 | "beta_rec": 1.0, 17 | "patch_size": 8, 18 | "topk": 6, 19 | "n_kp_enc": 10, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 6, 22 | "context_dim": 3, 23 | "n_kp_prior": 64, 24 | "weight_decay": 0.0, 25 | "kp_range": [ 26 | -1, 27 | 1 28 | ], 29 | "warmup_epoch": 1, 30 | "dropout": 0.0, 31 | "iou_thresh": 0.2, 32 | "anchor_s": 0.25, 33 | "kl_balance": 0.001, 34 | "image_size": 64, 35 | "ch": 3, 36 | "enc_channels": [ 37 | 32, 38 | 32, 39 | 64, 40 | 64 41 | ], 42 | "prior_channels": [ 43 | 32, 44 | 32, 45 | 64 46 | ], 47 | "timestep_horizon": 1, 48 | "predict_delta": true, 49 | "beta_dyn": 0.1, 50 | "scale_std": 0.3, 51 | "offset_std": 0.2, 52 | "obj_on_alpha": 0.1, 53 | "obj_on_beta": 0.1, 54 | "beta_dyn_rec": 1.0, 55 | "num_static_frames": 4, 56 | "pint_layers": 6, 57 | "pint_heads": 8, 58 | "pint_dim": 256, 59 | "run_prefix": "_2_schedulers_ln", 60 | "animation_horizon": 100, 61 | "eval_im_metrics": true, 62 | "use_resblock": false, 63 | "scheduler_gamma": 0.95, 64 | "adam_betas": [ 65 | 0.9, 66 | 0.999 67 | ], 68 | "adam_eps": 0.0001, 69 | "train_enc_prior": true, 70 | "start_dyn_epoch": 0, 71 | "cond_steps": 10, 72 | "animation_fps": 0.06, 73 | "use_correlation_heatmaps": false, 74 | "enable_enc_attn": false, 75 | "filtering_heuristic": "none", 76 | "use_tracking": false 77 | } -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation & Setting Up the Environment 2 | 3 | * For your convenience, we provide an `environment.yml` file which installs the required packages in a `conda` 4 | environment named `dlp`. Alternatively, you can use `pip` to install `requirements.txt`. 5 | * Use the terminal or an Anaconda Prompt and run the following command: `conda env create -f environment.yml`. 6 | 7 | If you prefer to set up an environment manually, we provide the steps required to set it up. 8 | We assume Anaconda or Miniconda is installed for environment management. 9 | 10 | 1. Create a new environment: `conda create -n dlp python=3.8` 11 | 2. Install PyTorch and CUDA (the command may vary depending on your system; adjust as needed, see https://pytorch.org/get-started/locally/): 12 | 13 | `conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.8 -c pytorch -c nvidia` 14 | 15 | 3.
Run the following commands to install remaining `conda` libraries: 16 | 17 | `conda install -c conda-forge numpy` (should already be installed from (1)) 18 | `conda install -c conda-forge matplotlib` 19 | `conda install -c conda-forge tqdm` 20 | `conda install -c conda-forge scipy` 21 | `conda install -c conda-forge scikit-image` 22 | `conda install -c conda-forge imageio` 23 | `conda install -c conda-forge h5py` 24 | `conda install -c conda-forge notebook` (if you want to be able to run Jupyter Notebooks) 25 | `conda update ffmpeg` 26 | 27 | 4. Install `pip` packages: 28 | 29 | `pip install opencv-python==3.4.18.65` 30 | `pip install accelerate` 31 | `pip install piqa` 32 | `pip install einops` 33 | `pip install ttkthemes` (for the GUI) 34 | `pip install ttkwidgets` (for the GUI) 35 | 36 | 5. (OPTIONAL) Clean `conda` cache: `conda clean --all` 37 | -------------------------------------------------------------------------------- /configs/traffic_img.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "traffic", 3 | "root": "/media/newhd/data/traffic_data/img128np.npy", 4 | "device": "cuda", 5 | "batch_size": 32, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "vgg", 14 | "sigma": 1.0, 15 | "beta_kl": 40.0, 16 | "beta_rec": 1.0, 17 | "patch_size": 16, 18 | "topk": 10, 19 | "n_kp_enc": 25, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 8, 22 | "n_kp_prior": 64, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | "iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 128, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 32, 37 | 64, 38 | 64, 39 | 128, 40 | 128 41 | ], 42 | "prior_channels": [ 43 | 32, 44 | 32, 45 | 64 46 | ], 47 | "timestep_horizon": 10, 48 | "predict_delta": true, 49 | "beta_dyn": 40.0, 50 | "scale_std": 1.0, 51 | "offset_std": 1.0, 52 | "obj_on_alpha": 0.1, 53 | "obj_on_beta": 0.1, 54 | "beta_dyn_rec": 1.0, 55 | "num_static_frames": 4, 56 | "pint_layers": 6, 57 | "pint_heads": 8, 58 | "pint_dim": 256, 59 | "run_prefix": "", 60 | "animation_horizon": 100, 61 | "eval_im_metrics": true, 62 | "use_resblock": false, 63 | "scheduler_gamma": 0.95, 64 | "adam_betas": [ 65 | 0.9, 66 | 0.999 67 | ], 68 | "adam_eps": 0.0001, 69 | "train_enc_prior": true, 70 | "start_dyn_epoch": 0, 71 | "cond_steps": 10, 72 | "animation_fps": 0.06, 73 | "use_correlation_heatmaps": true, 74 | "enable_enc_attn": false, 75 | "filtering_heuristic": "variance", 76 | "use_tracking": false 77 | } -------------------------------------------------------------------------------- /configs/obj3d128_img.json: -------------------------------------------------------------------------------- 1 | { 2 | "ds": "obj3d128", 3 | "root": "/media/newhd/data/obj3d/", 4 | "device": "cuda:0", 5 | "batch_size": 32, 6 | "lr": 0.0002, 7 | "kp_activation": "tanh", 8 | "pad_mode": "replicate", 9 | "load_model": false, 10 | "pretrained_path": null, 11 | "num_epochs": 150, 12 | "n_kp": 1, 13 | "recon_loss_type": "vgg", 14 | "sigma": 1.0, 15 | "beta_kl": 40.0, 16 | "beta_rec": 1.0, 17 | "patch_size": 16, 18 | "topk": 8, 19 | "n_kp_enc": 12, 20 | "eval_epoch_freq": 1, 21 | "learned_feature_dim": 8, 22 | "n_kp_prior": 64, 23 | "weight_decay": 0.0, 24 | "kp_range": [ 25 | -1, 26 | 1 27 | ], 28 | "warmup_epoch": 1, 29 | "dropout": 0.0, 30 | 
"iou_thresh": 0.2, 31 | "anchor_s": 0.25, 32 | "kl_balance": 0.001, 33 | "image_size": 128, 34 | "ch": 3, 35 | "enc_channels": [ 36 | 16, 37 | 32, 38 | 64, 39 | 64, 40 | 128 41 | ], 42 | "prior_channels": [ 43 | 16, 44 | 32, 45 | 64 46 | ], 47 | "timestep_horizon": 10, 48 | "predict_delta": true, 49 | "beta_dyn": 40.0, 50 | "scale_std": 1.0, 51 | "offset_std": 1.0, 52 | "obj_on_alpha": 0.1, 53 | "obj_on_beta": 0.1, 54 | "beta_dyn_rec": 1.0, 55 | "num_static_frames": 4, 56 | "pint_layers": 6, 57 | "pint_heads": 8, 58 | "pint_dim": 256, 59 | "run_prefix": "_first_filters_12_bkl_40", 60 | "animation_horizon": 100, 61 | "eval_im_metrics": true, 62 | "use_resblock": false, 63 | "scheduler_gamma": 0.95, 64 | "adam_betas": [ 65 | 0.9, 66 | 0.999 67 | ], 68 | "adam_eps": 0.0001, 69 | "train_enc_prior": true, 70 | "start_dyn_epoch": 0, 71 | "cond_steps": 10, 72 | "animation_fps": 0.06, 73 | "use_correlation_heatmaps": false, 74 | "enable_enc_attn": false, 75 | "filtering_heuristic": "variance", 76 | "use_tracking": false 77 | } -------------------------------------------------------------------------------- /datasets/get_dataset.py: -------------------------------------------------------------------------------- 1 | # datasets 2 | from datasets.traffic_ds import TrafficDataset, TrafficDatasetImage 3 | from datasets.clevrer_ds import CLEVREREpDataset, CLEVREREpDatasetImage 4 | from datasets.shapes_ds import generate_shape_dataset_torch 5 | from datasets.balls_ds import Balls, BallsImage 6 | from datasets.obj3d_ds import Obj3D, Obj3DImage 7 | from datasets.phyre_ds import PhyreDataset, PhyreDatasetImage 8 | from datasets.langtable_ds import LanguageTableDataset, LanguageTableDatasetImage 9 | 10 | 11 | def get_video_dataset(ds, root, seq_len=1, mode='train', image_size=128): 12 | # load data 13 | if ds == "traffic": 14 | dataset = TrafficDataset(path_to_npy=root, image_size=image_size, mode=mode, sample_length=seq_len) 15 | elif ds == 'clevrer': 16 | dataset = CLEVREREpDataset(root=root, mode=mode, sample_length=seq_len) 17 | elif ds == 'balls': 18 | dataset = Balls(root=root, mode=mode, sample_length=seq_len) 19 | elif ds == 'obj3d': 20 | dataset = Obj3D(root=root, mode=mode, sample_length=seq_len) 21 | elif ds == 'obj3d128': 22 | image_size = 128 23 | dataset = Obj3D(root=root, mode=mode, sample_length=seq_len, res=image_size) 24 | elif ds == 'phyre': 25 | dataset = PhyreDataset(root=root, mode=mode, sample_length=seq_len, image_size=image_size) 26 | elif ds == 'langtable': 27 | dataset = LanguageTableDataset(root=root, mode=mode, sample_length=seq_len, image_size=image_size) 28 | else: 29 | raise NotImplementedError 30 | return dataset 31 | 32 | 33 | def get_image_dataset(ds, root, mode='train', image_size=128, seq_len=1): 34 | # set seq_len > 1 when training with use_tracking 35 | # load data 36 | if ds == "traffic": 37 | dataset = TrafficDatasetImage(path_to_npy=root, image_size=image_size, mode=mode, sample_length=seq_len) 38 | elif ds == 'clevrer': 39 | dataset = CLEVREREpDatasetImage(root=root, mode=mode, sample_length=seq_len) 40 | elif ds == 'balls': 41 | dataset = BallsImage(root=root, mode=mode, sample_length=seq_len) 42 | elif ds == 'obj3d': 43 | dataset = Obj3DImage(root=root, mode=mode, sample_length=seq_len) 44 | elif ds == 'obj3d128': 45 | image_size = 128 46 | dataset = Obj3DImage(root=root, mode=mode, sample_length=seq_len, res=image_size) 47 | elif ds == 'phyre': 48 | dataset = PhyreDatasetImage(root=root, mode=mode, sample_length=seq_len, image_size=image_size) 49 | elif ds == 
'shapes': 50 | if mode == 'train': 51 | dataset = generate_shape_dataset_torch(img_size=image_size, num_images=40_000) 52 | else: 53 | dataset = generate_shape_dataset_torch(img_size=image_size, num_images=2_000) 54 | elif ds == 'langtable': 55 | dataset = LanguageTableDatasetImage(root=root, mode=mode, sample_length=seq_len, image_size=image_size) 56 | else: 57 | raise NotImplementedError 58 | return dataset 59 | -------------------------------------------------------------------------------- /docs/gui.md: -------------------------------------------------------------------------------- 1 | # (D)DLP Graphical User Interface (GUI) Usage Instructions 2 | 3 | 4 | **IMPORTANT NOTE**: DDLP only works with small latent modifications, and cannot handle modifications that result in out-of-distribution examples . 5 | 6 | General usage: 7 | 1. Choosing a pre-trained model: the GUI looks for models inside the `checkpoints`" directory. The GUI supports 3 types 8 | of models: [`dlp`, `ddlp`, `diffuse-ddlp`], and the pre-trained paths should be organized as follows: 9 | `checkpoints/{model-type}-{ds}/[hparams.json, {ds}_{model-type}.pth]` for `dlp`/`ddlp` 10 | and `checkpoints/diffuse-ddlp-{ds}/[ddlp_hparams.json,diffusion_hparams.json, {ds}_ddlp.pth, latent_stats.pth, /saves/model.pth]` 11 | for `diffuse-ddlp`. For example: `checkpoints/ddlp-obj3d128/[hparams.json, obj3d128_ddlp.pth]`. 12 | 13 | 14 | ``` 15 | checkpoints 16 | ├── ddlp-obj3d128 17 | │ ├── obj3d128_ddlp.pth 18 | │ ├── hparams.json 19 | ├── dlp-traffic 20 | ├── diffuse-ddlp-obj3d128 21 | │ ├── diffusion_hparams.json 22 | │ ├── ddlp_hparams.json 23 | │ ├── latent_stats.pth 24 | │ ├── saves 25 | │ │ ├── model.pth 26 | └── ... 27 | ``` 28 | 29 | 3. Choosing/generating an example: For `dlp/ddlp` the GUI looks for examples for a dataset in the `assets` directory, 30 | where each example is a directory with an integer number as its name. 31 | Under each example directory there should be images, where for `ddlp` at least 4 consecutive are required, 32 | numbered by their order. For example, `assets/obj3d128/428/[1.png, 2.png, ...]`. 33 | For `diffuse-ddlp`, press the `Generate button` to generate a sequence of 4 (latent) frames. 34 | 35 | 4. Choosing a device: the GUI can run everything on the CPU, but if CUDA is available, 36 | you can switch to a GPU to perform computations. 37 | 38 | 5. Animating latent transitions: if the `animate` checkbox is marked, the GUI will animate the latent interpolation 39 | between modifications after the `Update` button is pressed (naturally, this is slower). 40 | 41 | 6. Hiding particles: you can temporarily hide the particles to view the current image 42 | by marking the `hide particles` checkbox. Removing the check will restore the particles. 43 | 44 | 7. Latent modifications: the GUI supports the following modifications: 45 | * Moving particles by dragging. Use the selection tool to select multiple particles at once and then drag 46 | them all together. This is useful when objects are assigned multiple particles. 47 | * Changing scale/transparency: once a particle is pressed on, a modifications menu will open. 48 | You can change the scale and transparency for multiple particles at once by first pressing a particle, 49 | and then using the selection tool to select multiple particles and all changes applied to the pressed particle 50 | will be applied to all of them. 
51 | * Changing visual appearance: when an example is selected/generated, the visual features of all particles are saved 52 | in a dictionary where the key is the particle number. You can choose between these available features. 53 | Similarly, you can change the features of multiple features at once by first pressing on a particle and 54 | then use the selection tool to pick all the particles that will be changed. 55 | 56 | 8. Video prediction: when using a `ddlp`-based model, you can unroll the latent particles and generate a video in 57 | a new window by pressing the `Play` button. 58 | Note for DDLP, to make things simpler, we only allow changes to particles at t=0 and t=3 and interpolate all 59 | particles in-between. 60 | 61 | Note that DDLP is quite sensitive to out-of-distribution modifications. 62 | -------------------------------------------------------------------------------- /datasets/balls_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/zhixuan-lin/G-SWM 3 | """ 4 | 5 | from torch.utils.data import Dataset, DataLoader 6 | import h5py 7 | import os 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class Balls(Dataset): 13 | def __init__(self, root, mode, ep_len=100, sample_length=20): 14 | """ 15 | Args: 16 | root: dataset root 17 | mode: one of ['train', 'val', 'test'] 18 | ep_len: episode length of the dataset file 19 | sample_length: the actual maximum episode length you want 20 | """ 21 | assert mode in ['train', 'val', 'valid', 'test'] 22 | if mode == 'valid': 23 | mode = 'val' 24 | self.root = root 25 | file = os.path.join(self.root, f'{mode}.hdf5') 26 | assert os.path.exists(file), 'Path {} does not exist'.format(file) 27 | self.file = file 28 | 29 | self.mode = mode 30 | self.sample_length = sample_length 31 | 32 | self.ep_len = ep_len 33 | self.seq_per_episode = self.ep_len - self.sample_length + 1 34 | 35 | def __getitem__(self, index): 36 | 37 | with h5py.File(self.file, 'r') as f: 38 | self.imgs = f['imgs'] 39 | self.positions = f['positions'] 40 | self.sizes = f['sizes'] 41 | self.ids = f['ids'] 42 | self.in_camera = f['in_camera'] 43 | 44 | if self.mode == 'train': 45 | # Implement continuous indexing 46 | ep = index // self.seq_per_episode 47 | offset = index % self.seq_per_episode 48 | end = offset + self.sample_length 49 | img = self.imgs[ep][offset:end] 50 | pos = self.positions[ep][offset:end] 51 | size = self.sizes[ep][offset:end] 52 | id = self.ids[ep][offset:end] 53 | in_camera = self.in_camera[ep][offset:end] 54 | else: 55 | img = self.imgs[index] 56 | pos = self.positions[index] 57 | size = self.sizes[index] 58 | id = self.ids[index] 59 | in_camera = self.in_camera[index] 60 | assert img.shape[0] == self.ep_len 61 | 62 | img = torch.from_numpy(img).permute(0, 3, 1, 2) 63 | img = img.float() / 255.0 64 | 65 | return img, pos, size, id, in_camera 66 | 67 | def __len__(self): 68 | with h5py.File(self.file, 'r') as f: 69 | length = len(f['imgs']) 70 | if self.mode == 'train': 71 | return length * self.seq_per_episode 72 | else: 73 | return length 74 | 75 | 76 | class BallsImage(Dataset): 77 | def __init__(self, root, mode, ep_len=100, sample_length=20): 78 | """ 79 | Args: 80 | root: dataset root 81 | mode: one of ['train', 'val', 'test'] 82 | ep_len: episode length of the dataset file 83 | sample_length: the actual maximum episode length you want 84 | """ 85 | assert mode in ['train', 'val', 'valid', 'test'] 86 | if mode == 'valid': 87 | mode = 'val' 88 | self.root = root 89 | file 
= os.path.join(self.root, f'{mode}.hdf5') 90 | assert os.path.exists(file), 'Path {} does not exist'.format(file) 91 | self.file = file 92 | 93 | self.mode = mode 94 | self.sample_length = sample_length 95 | 96 | self.ep_len = ep_len 97 | self.seq_per_episode = self.ep_len - self.sample_length + 1 98 | 99 | def __getitem__(self, index): 100 | 101 | with h5py.File(self.file, 'r') as f: 102 | self.imgs = f['imgs'] 103 | self.positions = f['positions'] 104 | self.sizes = f['sizes'] 105 | self.ids = f['ids'] 106 | self.in_camera = f['in_camera'] 107 | 108 | # Implement continuous indexing 109 | ep = index // self.seq_per_episode 110 | offset = index % self.seq_per_episode 111 | end = offset + self.sample_length 112 | img = self.imgs[ep][offset:end] 113 | pos = self.positions[ep][offset:end] 114 | size = self.sizes[ep][offset:end] 115 | id = self.ids[ep][offset:end] 116 | in_camera = self.in_camera[ep][offset:end] 117 | 118 | img = torch.from_numpy(img).permute(0, 3, 1, 2) 119 | img = img.float() / 255.0 120 | 121 | return img, pos, size, id, in_camera 122 | 123 | def __len__(self): 124 | with h5py.File(self.file, 'r') as f: 125 | length = len(f['imgs']) 126 | return length * self.seq_per_episode 127 | 128 | 129 | if __name__ == '__main__': 130 | import matplotlib.pyplot as plt 131 | 132 | root = '/media/newhd/data/gswm_balls/BALLS_INTERACTION' 133 | mode = 'train' 134 | ds = Balls(root, mode, ep_len=100, sample_length=10) 135 | dl = DataLoader(ds, shuffle=True, batch_size=2) 136 | 137 | batch = next(iter(dl)) 138 | img, _, _, _, _ = batch 139 | print(f'image shape: {img.shape}') 140 | fig = plt.figure(dpi=300) 141 | for i in range(len(img[0])): 142 | im = img[0][i].permute(1, 2, 0).data.cpu().numpy() 143 | ax = fig.add_subplot(1, len(img[0]), i + 1) 144 | ax.imshow(im) 145 | ax.set_axis_off() 146 | ax.set_title(f'{i}') 147 | plt.show() 148 | -------------------------------------------------------------------------------- /datasets/obj3d_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | OBJ3D from G-SWM 3 | https://github.com/zhixuan-lin/G-SWM/blob/master/src/dataset/obj3d.py 4 | """ 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | import glob 9 | import os 10 | import os.path as osp 11 | import torch 12 | from PIL import Image, ImageFile 13 | 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | 16 | 17 | class Obj3D(Dataset): 18 | def __init__(self, root, mode, ep_len=100, sample_length=20, res=64): 19 | # path = os.path.join(root, mode) 20 | assert mode in ['train', 'val', 'valid', 'test'] 21 | if mode == 'valid': 22 | mode = 'val' 23 | self.root = os.path.join(root, mode) 24 | self.res = res 25 | 26 | self.mode = mode 27 | self.sample_length = sample_length 28 | 29 | # Get all numbers 30 | self.folders = [] 31 | for file in os.listdir(self.root): 32 | try: 33 | self.folders.append(int(file)) 34 | except ValueError: 35 | continue 36 | self.folders.sort() 37 | 38 | self.epsisodes = [] 39 | self.EP_LEN = ep_len 40 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1 41 | 42 | for f in self.folders: 43 | dir_name = os.path.join(self.root, str(f)) 44 | paths = list(glob.glob(osp.join(dir_name, 'test_*.png'))) 45 | # if len(paths) != self.EP_LEN: 46 | # continue 47 | # assert len(paths) == self.EP_LEN, 'len(paths): {}'.format(len(paths)) 48 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0].partition('_')[-1]) 49 | paths.sort(key=get_num) 50 | self.epsisodes.append(paths) 51 | 52 | def 
__getitem__(self, index): 53 | 54 | imgs = [] 55 | if self.mode == 'train': 56 | # Implement continuous indexing 57 | ep = index // self.seq_per_episode 58 | offset = index % self.seq_per_episode 59 | end = offset + self.sample_length 60 | 61 | e = self.epsisodes[ep] 62 | for image_index in range(offset, end): 63 | img = Image.open(osp.join(e[image_index])) 64 | img = img.resize((self.res, self.res)) 65 | img = transforms.ToTensor()(img)[:3] 66 | imgs.append(img) 67 | else: 68 | for path in self.epsisodes[index]: 69 | img = Image.open(path) 70 | img = img.resize((self.res, self.res)) 71 | img = transforms.ToTensor()(img)[:3] 72 | imgs.append(img) 73 | 74 | img = torch.stack(imgs, dim=0).float() 75 | pos = torch.zeros(0) 76 | size = torch.zeros(0) 77 | id = torch.zeros(0) 78 | in_camera = torch.zeros(0) 79 | 80 | return img, pos, size, id, in_camera 81 | 82 | def __len__(self): 83 | length = len(self.epsisodes) 84 | if self.mode == 'train': 85 | return length * self.seq_per_episode 86 | else: 87 | return length 88 | 89 | 90 | class Obj3DImage(Dataset): 91 | def __init__(self, root, mode, ep_len=100, sample_length=20, res=64): 92 | # path = os.path.join(root, mode) 93 | assert mode in ['train', 'val', 'valid', 'test'] 94 | if mode == 'valid': 95 | mode = 'val' 96 | self.root = os.path.join(root, mode) 97 | self.res = res 98 | 99 | self.mode = mode 100 | self.sample_length = sample_length 101 | 102 | # Get all numbers 103 | self.folders = [] 104 | for file in os.listdir(self.root): 105 | try: 106 | self.folders.append(int(file)) 107 | except ValueError: 108 | continue 109 | self.folders.sort() 110 | 111 | self.epsisodes = [] 112 | self.EP_LEN = ep_len 113 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1 114 | 115 | for f in self.folders: 116 | dir_name = os.path.join(self.root, str(f)) 117 | paths = list(glob.glob(osp.join(dir_name, 'test_*.png'))) 118 | # if len(paths) != self.EP_LEN: 119 | # continue 120 | # assert len(paths) == self.EP_LEN, 'len(paths): {}'.format(len(paths)) 121 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0].partition('_')[-1]) 122 | paths.sort(key=get_num) 123 | self.epsisodes.append(paths) 124 | 125 | def __getitem__(self, index): 126 | 127 | imgs = [] 128 | # Implement continuous indexing 129 | ep = index // self.seq_per_episode 130 | offset = index % self.seq_per_episode 131 | end = offset + self.sample_length 132 | 133 | e = self.epsisodes[ep] 134 | for image_index in range(offset, end): 135 | img = Image.open(osp.join(e[image_index])) 136 | img = img.resize((self.res, self.res)) 137 | img = transforms.ToTensor()(img)[:3] 138 | imgs.append(img) 139 | 140 | img = torch.stack(imgs, dim=0).float() 141 | pos = torch.zeros(0) 142 | size = torch.zeros(0) 143 | id = torch.zeros(0) 144 | in_camera = torch.zeros(0) 145 | 146 | return img, pos, size, id, in_camera 147 | 148 | def __len__(self): 149 | length = len(self.epsisodes) 150 | return length * self.seq_per_episode 151 | -------------------------------------------------------------------------------- /generate_ddlp_video_prediction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to generate conditional video prediction from a pre-trained DDLP 3 | """ 4 | # imports 5 | import os 6 | import argparse 7 | import json 8 | # torch 9 | import torch 10 | # modules 11 | from models import ObjectDynamicsDLP 12 | # util functions 13 | from eval.eval_model import animate_trajectory_ddlp 14 | 15 | if __name__ == '__main__': 16 | parser = 
argparse.ArgumentParser(description="DDLP Video Prediction") 17 | parser.add_argument("-d", "--dataset", type=str, default='balls', 18 | help="dataset of to train the model on: ['traffic', 'clevrer', 'shapes']") 19 | parser.add_argument("-p", "--path", type=str, 20 | help="path to model directory, e.g. ./310822_141959_balls_ddlp") 21 | parser.add_argument("--checkpoint", type=str, 22 | help="direct path to model checkpoint, e.g. ./checkpoints/ddlp-obj3d128/obj3d_ddlp.pth", 23 | default="") 24 | parser.add_argument("--use_last", action='store_true', 25 | help="use the last checkpoint instead of best") 26 | parser.add_argument("--use_train", action='store_true', 27 | help="use the train set for the predictions") 28 | parser.add_argument("--sample", action='store_true', 29 | help="use stochastic (non-deterministic) predictions") 30 | parser.add_argument("--cpu", action='store_true', 31 | help="use cpu for inference") 32 | parser.add_argument("--horizon", type=int, help="frame horizon to generate", default=50) 33 | parser.add_argument("-n", "--num_predictions", type=int, help="number of animations to generate", default=5) 34 | parser.add_argument("-c", "--cond_steps", type=int, help="the initial number of frames for predictions", default=-1) 35 | parser.add_argument("--prefix", type=str, default='', 36 | help="prefix used for model saving") 37 | args = parser.parse_args() 38 | # parse input 39 | dir_path = args.path 40 | checkpoint_path = args.checkpoint 41 | ds = args.dataset 42 | use_train = args.use_train 43 | generation_horizon = args.horizon 44 | num_predictions = args.num_predictions 45 | cond_steps = args.cond_steps 46 | use_cpu = args.cpu 47 | deterministic = not args.sample 48 | prefix = args.prefix 49 | # load model config 50 | model_ckpt_name = f'{ds}_ddlp{prefix}.pth' 51 | model_best_ckpt_name = f'{ds}_ddlp{prefix}_best_lpips.pth' 52 | # model_best_ckpt_name = f'{ds}_ddlp{prefix}_best.pth' # can also be used 53 | use_last = args.use_last if os.path.exists(os.path.join(dir_path, f'saves/{model_best_ckpt_name}')) else True 54 | conf_path = os.path.join(dir_path, 'hparams.json') 55 | with open(conf_path, 'r') as f: 56 | config = json.load(f) 57 | if use_cpu: 58 | device = torch.device("cpu") 59 | else: 60 | device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") 61 | image_size = config['image_size'] 62 | ch = config['ch'] 63 | enc_channels = config['enc_channels'] 64 | prior_channels = config['prior_channels'] 65 | use_correlation_heatmaps = config['use_correlation_heatmaps'] 66 | enable_enc_attn = config['enable_enc_attn'] 67 | filtering_heuristic = config['filtering_heuristic'] 68 | 69 | model = ObjectDynamicsDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 70 | image_size=image_size, n_kp=config['n_kp'], 71 | learned_feature_dim=config['learned_feature_dim'], 72 | pad_mode=config['pad_mode'], 73 | sigma=config['sigma'], 74 | dropout=config['dropout'], patch_size=config['patch_size'], 75 | n_kp_enc=config['n_kp_enc'], 76 | n_kp_prior=config['n_kp_prior'], kp_range=config['kp_range'], 77 | kp_activation=config['kp_activation'], 78 | anchor_s=config['anchor_s'], 79 | use_resblock=config['use_resblock'], 80 | timestep_horizon=config['timestep_horizon'], predict_delta=config['predict_delta'], 81 | scale_std=config['scale_std'], 82 | offset_std=config['offset_std'], obj_on_alpha=config['obj_on_alpha'], 83 | obj_on_beta=config['obj_on_beta'], pint_heads=config['pint_heads'], 84 | pint_layers=config['pint_layers'], pint_dim=config['pint_dim'], 85 | 
use_correlation_heatmaps=use_correlation_heatmaps, 86 | enable_enc_attn=enable_enc_attn, filtering_heuristic=filtering_heuristic).to(device) 87 | if checkpoint_path.endswith('.pth'): 88 | ckpt_path = checkpoint_path 89 | else: 90 | ckpt_path = os.path.join(dir_path, f'saves/{model_ckpt_name if use_last else model_best_ckpt_name}') 91 | model.load_state_dict(torch.load(ckpt_path, map_location=device)) 92 | model.eval() 93 | model.requires_grad_(False) 94 | print(f"loaded model from {ckpt_path}") 95 | 96 | # create dir for videos 97 | pred_dir = os.path.join(dir_path, 'animations') 98 | os.makedirs(pred_dir, exist_ok=True) 99 | 100 | # conditional frames 101 | cond_steps = cond_steps if cond_steps > 0 else config['timestep_horizon'] 102 | print(f'conditional input frames: {cond_steps}') 103 | print(f'deterministic predictions (use only mu): {deterministic}') 104 | # generate 105 | print('generating animations...') 106 | animate_trajectory_ddlp(model, config, epoch=0, device=device, fig_dir=pred_dir, 107 | timestep_horizon=generation_horizon, 108 | num_trajetories=num_predictions, accelerator=None, train=use_train, prefix='', 109 | cond_steps=cond_steps, deterministic=deterministic) 110 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dlp 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - _openmp_mutex=5.1 10 | - anyio=3.7.0 11 | - argon2-cffi=21.3.0 12 | - argon2-cffi-bindings=21.2.0 13 | - asttokens=2.2.1 14 | - attrs=23.1.0 15 | - backcall=0.2.0 16 | - backports=1.0 17 | - backports.functools_lru_cache=1.6.4 18 | - beautifulsoup4=4.12.2 19 | - blas=1.0 20 | - bleach=6.0.0 21 | - brotli=1.0.9 22 | - brotli-bin=1.0.9 23 | - brotlipy=0.7.0 24 | - bzip2=1.0.8 25 | - ca-certificates=2023.5.7 26 | - cached-property=1.5.2 27 | - cached_property=1.5.2 28 | - certifi=2023.5.7 29 | - cffi=1.15.1 30 | - charset-normalizer=2.0.4 31 | - click=8.1.3 32 | - cloudpickle=2.2.1 33 | - colorama=0.4.6 34 | - contourpy=1.0.5 35 | - cryptography=39.0.1 36 | - cuda-cudart=11.8.89 37 | - cuda-cupti=11.8.87 38 | - cuda-libraries=11.8.0 39 | - cuda-nvrtc=11.8.89 40 | - cuda-nvtx=11.8.86 41 | - cuda-runtime=11.8.0 42 | - cycler=0.11.0 43 | - cytoolz=0.12.0 44 | - dask-core=2023.3.2 45 | - dbus=1.13.18 46 | - decorator=5.1.1 47 | - defusedxml=0.7.1 48 | - entrypoints=0.4 49 | - enum34=1.1.10 50 | - exceptiongroup=1.1.1 51 | - executing=1.2.0 52 | - expat=2.4.9 53 | - ffmpeg=4.2.2 54 | - filelock=3.9.0 55 | - flit-core=3.9.0 56 | - fontconfig=2.14.1 57 | - fonttools=4.25.0 58 | - freetype=2.12.1 59 | - fsspec=2023.5.0 60 | - giflib=5.2.1 61 | - glib=2.69.1 62 | - gmp=6.2.1 63 | - gmpy2=2.1.2 64 | - gnutls=3.6.15 65 | - gst-plugins-base=1.14.1 66 | - gstreamer=1.14.1 67 | - h5py=3.1.0 68 | - hdf5=1.10.6 69 | - icu=58.2 70 | - idna=3.4 71 | - imagecodecs-lite=2019.12.3 72 | - imageio=2.28.1 73 | - importlib-metadata=6.6.0 74 | - importlib_metadata=6.6.0 75 | - importlib_resources=5.12.0 76 | - intel-openmp=2023.1.0 77 | - ipykernel=5.5.5 78 | - ipython=8.12.0 79 | - ipython_genutils=0.2.0 80 | - jedi=0.18.2 81 | - jinja2=3.1.2 82 | - jpeg=9e 83 | - jsonschema=4.17.3 84 | - jupyter_client=7.0.6 85 | - jupyter_core=4.12.0 86 | - jupyter_server=1.23.6 87 | - jupyterlab_pygments=0.2.2 88 | - kiwisolver=1.4.4 89 | - krb5=1.19.4 90 | - lame=3.100 91 | - lcms2=2.12 92 | - ld_impl_linux-64=2.38 93 | - lerc=3.0 94 | - 
libblas=3.9.0 95 | - libbrotlicommon=1.0.9 96 | - libbrotlidec=1.0.9 97 | - libbrotlienc=1.0.9 98 | - libcblas=3.9.0 99 | - libclang=14.0.6 100 | - libclang13=14.0.6 101 | - libcublas=11.11.3.6 102 | - libcufft=10.9.0.58 103 | - libcufile=1.6.1.9 104 | - libcurand=10.3.2.106 105 | - libcusolver=11.4.1.48 106 | - libcusparse=11.7.5.86 107 | - libdeflate=1.17 108 | - libedit=3.1.20221030 109 | - libevent=2.1.12 110 | - libffi=3.4.4 111 | - libgcc-ng=11.2.0 112 | - libgfortran-ng=13.1.0 113 | - libgfortran5=13.1.0 114 | - libgomp=11.2.0 115 | - libiconv=1.16 116 | - libidn2=2.3.4 117 | - liblapack=3.9.0 118 | - libllvm14=14.0.6 119 | - libnpp=11.8.0.86 120 | - libnvjpeg=11.9.0.86 121 | - libpng=1.6.39 122 | - libpq=12.9 123 | - libsodium=1.0.18 124 | - libstdcxx-ng=11.2.0 125 | - libtasn1=4.19.0 126 | - libtiff=4.5.0 127 | - libunistring=0.9.10 128 | - libuuid=1.41.5 129 | - libwebp=1.2.4 130 | - libwebp-base=1.2.4 131 | - libxcb=1.15 132 | - libxkbcommon=1.0.1 133 | - libxml2=2.10.3 134 | - libxslt=1.1.37 135 | - locket=1.0.0 136 | - lz4-c=1.9.4 137 | - markupsafe=2.1.1 138 | - matplotlib=3.7.1 139 | - matplotlib-base=3.7.1 140 | - matplotlib-inline=0.1.6 141 | - mistune=2.0.5 142 | - mkl=2023.1.0 143 | - mkl-service=2.4.0 144 | - mkl_fft=1.3.6 145 | - mkl_random=1.2.2 146 | - mpc=1.1.0 147 | - mpfr=4.0.2 148 | - mpmath=1.2.1 149 | - munkres=1.1.4 150 | - nbclassic=1.0.0 151 | - nbclient=0.8.0 152 | - nbconvert-core=7.4.0 153 | - nbformat=5.9.0 154 | - ncurses=6.4 155 | - nest-asyncio=1.5.6 156 | - nettle=3.7.3 157 | - networkx=2.8.4 158 | - notebook=6.5.4 159 | - notebook-shim=0.2.3 160 | - nspr=4.35 161 | - nss=3.89.1 162 | - numpy=1.24.3 163 | - numpy-base=1.24.3 164 | - openh264=2.1.1 165 | - openssl=1.1.1t 166 | - packaging=23.1 167 | - pandocfilters=1.5.0 168 | - parso=0.8.3 169 | - partd=1.4.0 170 | - pathlib=1.0.1 171 | - pcre=8.45 172 | - pexpect=4.8.0 173 | - pickleshare=0.7.5 174 | - pillow=9.4.0 175 | - pip=23.0.1 176 | - pkgutil-resolve-name=1.3.10 177 | - ply=3.11 178 | - prometheus_client=0.17.0 179 | - prompt-toolkit=3.0.38 180 | - prompt_toolkit=3.0.38 181 | - ptyprocess=0.7.0 182 | - pure_eval=0.2.2 183 | - pycparser=2.21 184 | - pygments=2.15.1 185 | - pyopenssl=23.0.0 186 | - pyparsing=3.0.9 187 | - pyqt=5.15.7 188 | - pyqt5-sip=12.11.0 189 | - pyrsistent=0.18.0 190 | - pysocks=1.7.1 191 | - python=3.8.16 192 | - python-dateutil=2.8.2 193 | - python-fastjsonschema 194 | - python_abi=3.8 195 | - pytorch=2.0.1 196 | - pytorch-cuda=11.8 197 | - pytorch-mutex=1.0 198 | - pywavelets=1.4.1 199 | - pyyaml=6.0 200 | - pyzmq=19.0.2 201 | - qt-main=5.15.2 202 | - qt-webengine=5.15.9 203 | - qtwebkit=5.212 204 | - readline=8.2 205 | - requests=2.29.0 206 | - scikit-image=0.19.2 207 | - scipy=1.8.1 208 | - send2trash=1.8.2 209 | - setuptools=67.8.0 210 | - sip=6.6.2 211 | - six=1.16.0 212 | - sniffio=1.3.0 213 | - soupsieve=2.3.2.post1 214 | - sqlite=3.41.2 215 | - stack_data=0.6.2 216 | - sympy=1.11.1 217 | - tbb=2021.8.0 218 | - terminado=0.17.1 219 | - tifffile=2019.7.26.2 220 | - tinycss2=1.2.1 221 | - tk=8.6.12 222 | - toml=0.10.2 223 | - toolz=0.12.0 224 | - torchtriton=2.0.0 225 | - torchvision=0.15.2 226 | - tornado=6.1 227 | - tqdm=4.65.0 228 | - traitlets=5.9.0 229 | - typing_extensions=4.5.0 230 | - urllib3=1.26.15 231 | - wcwidth=0.2.6 232 | - webencodings=0.5.1 233 | - websocket-client=1.5.2 234 | - wheel=0.38.4 235 | - xz=5.4.2 236 | - yaml=0.2.5 237 | - zeromq=4.3.4 238 | - zipp=3.15.0 239 | - zlib=1.2.13 240 | - zstd=1.5.5 241 | - pip: 242 | - accelerate==0.19.0 243 | 
- einops==0.6.1 244 | - opencv-python==3.4.18.65 245 | - piqa==1.3.1 246 | - psutil==5.9.5 247 | - ttkthemes==3.2.2 248 | - ttkwidgets==0.13.0 249 | -------------------------------------------------------------------------------- /datasets/traffic_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | functions and classes to process the Traffic dataset 3 | """ 4 | 5 | import os 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | from tqdm import tqdm 12 | 13 | 14 | def list_images_in_dir(path): 15 | valid_images = [".jpg", ".gif", ".png"] 16 | img_list = [] 17 | for f in os.listdir(path): 18 | ext = os.path.splitext(f)[1] 19 | if ext.lower() not in valid_images: 20 | continue 21 | img_list.append(os.path.join(path, f)) 22 | return img_list 23 | 24 | 25 | def prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1): 26 | # path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/' 27 | img_list = list_images_in_dir(path_to_image_dir) 28 | img_list = sorted(img_list, key=lambda x: int(x.split('/')[-1].split('.')[0])) 29 | print(f'img_list: {len(img_list)}, 0: {img_list[0]}, -1: {img_list[-1]}') 30 | img_np_list = [] 31 | for i in tqdm(range(len(img_list))): 32 | if i % frameskip != 0: 33 | continue 34 | img = Image.open(img_list[i]) 35 | img = img.convert('RGB') 36 | img = img.crop((60, 0, 480, 420)) 37 | img = img.resize((image_size, image_size), Image.BICUBIC) 38 | img_np = np.asarray(img) 39 | img_np_list.append(img_np) 40 | img_np_array = np.stack(img_np_list, axis=0) 41 | print(f'img_np_array: {img_np_array.shape}') 42 | save_path = os.path.join(path_to_image_dir, f'img{image_size}np_fs{frameskip}.npy') 43 | np.save(save_path, img_np_array) 44 | print(f'file save at @ {save_path}') 45 | 46 | 47 | class TrafficDataset(Dataset): 48 | def __init__(self, path_to_npy, mode, ep_len=50, sample_length=20, image_size=128, transform=None): 49 | super(TrafficDataset, self).__init__() 50 | assert mode in ['train', 'val', 'valid', 'test'] 51 | if mode == 'valid': 52 | mode = 'val' 53 | self.mode = mode 54 | self.horizon = sample_length 55 | self.ep_len = ep_len 56 | data = np.load(path_to_npy) 57 | train_size = int(0.8 * data.shape[0]) 58 | valid_size = int(0.1 * data.shape[0]) 59 | test_size = int(0.1 * data.shape[0]) 60 | if mode == 'train': 61 | print(f'loaded data with shape: {data.shape}, train_size: {train_size}, valid_size: {valid_size}') 62 | self.data = data[:train_size] 63 | elif mode == 'val': 64 | self.data = data[train_size:train_size + valid_size] 65 | elif mode == 'test': 66 | self.data = data[train_size + valid_size:] 67 | else: 68 | raise SystemError('unrecognized ds mode: {mode}') 69 | self.image_size = image_size 70 | self.num_episodes = len(self.data) // self.ep_len 71 | if transform is None: 72 | self.input_transform = transforms.Compose([ 73 | transforms.ToPILImage(), 74 | transforms.Resize(image_size), 75 | transforms.ToTensor() 76 | ]) 77 | else: 78 | self.input_transform = transform 79 | 80 | def __getitem__(self, index): 81 | images = [] 82 | if self.mode == 'train': 83 | length = self.data.shape[0] 84 | horizon = self.horizon if self.mode == 'train' else self.ep_len 85 | if (index + horizon) >= length: 86 | slack = index + horizon - length 87 | index = index - slack 88 | for i in range(horizon): 89 | t = index + i 90 | images.append(self.input_transform(self.data[t])) 91 | else: 92 | # episode i, get 
the starting index 93 | first_frame = index * self.ep_len 94 | length = self.data.shape[0] 95 | horizon = self.ep_len 96 | if (first_frame + horizon) >= length: 97 | slack = first_frame + horizon - length 98 | first_frame = first_frame - slack 99 | for i in range(horizon): 100 | t = first_frame + i 101 | images.append(self.input_transform(self.data[t])) 102 | 103 | images = torch.stack(images, dim=0) 104 | pos = torch.zeros(0) 105 | size = torch.zeros(0) 106 | id = torch.zeros(0) 107 | in_camera = torch.zeros(0) 108 | return images, pos, size, id, in_camera 109 | 110 | def __len__(self): 111 | if self.mode == 'train': 112 | return self.data.shape[0] 113 | else: 114 | return self.num_episodes 115 | 116 | 117 | class TrafficDatasetImage(Dataset): 118 | def __init__(self, path_to_npy, mode, ep_len=50, sample_length=20, image_size=128, transform=None): 119 | super(TrafficDatasetImage, self).__init__() 120 | assert mode in ['train', 'val', 'valid', 'test'] 121 | if mode == 'valid': 122 | mode = 'val' 123 | self.mode = mode 124 | self.horizon = sample_length 125 | self.ep_len = ep_len 126 | data = np.load(path_to_npy) 127 | train_size = int(0.8 * data.shape[0]) 128 | valid_size = int(0.1 * data.shape[0]) 129 | test_size = int(0.1 * data.shape[0]) 130 | if mode == 'train': 131 | print(f'loaded data with shape: {data.shape}, train_size: {train_size}, valid_size: {valid_size}') 132 | self.data = data[:train_size] 133 | elif mode == 'val': 134 | self.data = data[train_size:train_size + valid_size] 135 | elif mode == 'test': 136 | self.data = data[train_size + valid_size:] 137 | else: 138 | raise SystemError('unrecognized ds mode: {mode}') 139 | self.image_size = image_size 140 | self.num_episodes = len(self.data) // self.ep_len 141 | if transform is None: 142 | self.input_transform = transforms.Compose([ 143 | transforms.ToPILImage(), 144 | transforms.Resize(image_size), 145 | transforms.ToTensor() 146 | ]) 147 | else: 148 | self.input_transform = transform 149 | 150 | def __getitem__(self, index): 151 | images = [] 152 | length = self.data.shape[0] 153 | horizon = self.horizon 154 | if (index + horizon) >= length: 155 | slack = index + horizon - length 156 | index = index - slack 157 | for i in range(horizon): 158 | t = index + i 159 | images.append(self.input_transform(self.data[t])) 160 | images = torch.stack(images, dim=0) 161 | pos = torch.zeros(0) 162 | size = torch.zeros(0) 163 | id = torch.zeros(0) 164 | in_camera = torch.zeros(0) 165 | return images, pos, size, id, in_camera 166 | 167 | def __len__(self): 168 | return self.data.shape[0] 169 | -------------------------------------------------------------------------------- /generate_diffuse_ddlp_video_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to generate unconditional videos/images from a pre-trained DiffuseDDLP 3 | """ 4 | 5 | # imports 6 | import os 7 | import argparse 8 | 9 | # torch 10 | import torch 11 | 12 | # utils 13 | from utils.util_func import get_config 14 | 15 | # models 16 | from modules.diffusion_modules import TrainerDiffuseDDLP, GaussianDiffusionPINT, PINTDenoiser 17 | from train_diffuse_ddlp import ParticleNormalization 18 | from models import ObjectDynamicsDLP 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser(description="Diffusion DDLP Video Generation") 22 | parser.add_argument("-c", "--config", type=str, 23 | help="json file name of config file in the pre-trained model dir") 24 | parser.add_argument("-b", "--batch_size", 
type=int, help="batch size", default=10) 25 | parser.add_argument("-n", "--num_samples", type=int, help="num samples to generate", default=30) 26 | parser.add_argument("--cpu", action='store_true', help="use cpu for inference") 27 | parser.add_argument("--image", action='store_true', help="generate image plot instead of video") 28 | 29 | args = parser.parse_args() 30 | # parse input 31 | conf_path = args.config 32 | diffusion_config = get_config(conf_path) 33 | batch_size = args.batch_size 34 | num_samples = args.num_samples 35 | use_cpu = args.cpu 36 | gen_img = args.image 37 | result_dir = diffusion_config['result_dir'] 38 | ds = diffusion_config['ds'] 39 | ds_root = diffusion_config['ds_root'] # dataset root 40 | diffuse_frames = diffusion_config['diffuse_frames'] # number of particle frames to generate 41 | lr = diffusion_config['lr'] 42 | train_num_steps = diffusion_config['train_num_steps'] 43 | diffusion_num_steps = diffusion_config['diffusion_num_steps'] 44 | loss_type = diffusion_config['loss_type'] 45 | particle_norm = diffusion_config['particle_norm'] 46 | device = "cpu" if use_cpu else diffusion_config['device'] 47 | 48 | if 'cuda' in device: 49 | device = torch.device(f'{device}' if torch.cuda.is_available() else 'cpu') 50 | else: 51 | device = torch.device('cpu') 52 | 53 | """ 54 | load pre-trained DDLP 55 | """ 56 | ddlp_dir = diffusion_config['ddlp_dir'] 57 | ddlp_ckpt = diffusion_config['ddlp_ckpt'] 58 | ddlp_conf = os.path.join(ddlp_dir, 'hparams.json') 59 | ddlp_config = get_config(ddlp_conf) 60 | # load model 61 | image_size = ddlp_config['image_size'] 62 | ch = ddlp_config['ch'] 63 | enc_channels = ddlp_config['enc_channels'] 64 | prior_channels = ddlp_config['prior_channels'] 65 | use_correlation_heatmaps = ddlp_config['use_correlation_heatmaps'] 66 | enable_enc_attn = ddlp_config['enable_enc_attn'] 67 | filtering_heuristic = ddlp_config['filtering_heuristic'] 68 | animation_fps = ddlp_config["animation_fps"] 69 | 70 | model = ObjectDynamicsDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 71 | image_size=image_size, n_kp=ddlp_config['n_kp'], 72 | learned_feature_dim=ddlp_config['learned_feature_dim'], 73 | pad_mode=ddlp_config['pad_mode'], 74 | sigma=ddlp_config['sigma'], 75 | dropout=ddlp_config['dropout'], patch_size=ddlp_config['patch_size'], 76 | n_kp_enc=ddlp_config['n_kp_enc'], 77 | n_kp_prior=ddlp_config['n_kp_prior'], kp_range=ddlp_config['kp_range'], 78 | kp_activation=ddlp_config['kp_activation'], 79 | anchor_s=ddlp_config['anchor_s'], 80 | use_resblock=ddlp_config['use_resblock'], 81 | timestep_horizon=ddlp_config['timestep_horizon'], 82 | predict_delta=ddlp_config['predict_delta'], 83 | scale_std=ddlp_config['scale_std'], 84 | offset_std=ddlp_config['offset_std'], obj_on_alpha=ddlp_config['obj_on_alpha'], 85 | obj_on_beta=ddlp_config['obj_on_beta'], pint_heads=ddlp_config['pint_heads'], 86 | pint_layers=ddlp_config['pint_layers'], pint_dim=ddlp_config['pint_dim'], 87 | use_correlation_heatmaps=use_correlation_heatmaps, 88 | enable_enc_attn=enable_enc_attn, filtering_heuristic=filtering_heuristic).to(device) 89 | model.load_state_dict(torch.load(ddlp_ckpt, map_location=device)) 90 | model.eval() 91 | model.requires_grad_(False) 92 | print(f"loaded ddlp model from {ddlp_ckpt}") 93 | 94 | features_dim = 2 + 2 + 1 + 1 + ddlp_config['learned_feature_dim'] 95 | # features: xy, scale_xy, depth, obj_on, particle features 96 | # total particles: n_kp + 1 for bg 97 | ddpm_feat_dim = features_dim 98 | 99 | denoiser_model = 
PINTDenoiser(features_dim, hidden_dim=ddlp_config['pint_dim'], 100 | projection_dim=ddlp_config['pint_dim'], 101 | n_head=ddlp_config['pint_heads'], n_layer=ddlp_config['pint_layers'], 102 | block_size=diffuse_frames, dropout=0.1, 103 | predict_delta=False, positional_bias=True, max_particles=ddlp_config['n_kp_enc'] + 1, 104 | self_condition=False, 105 | learned_sinusoidal_cond=False, random_fourier_features=False, 106 | learned_sinusoidal_dim=16).to(device) 107 | 108 | diffusion = GaussianDiffusionPINT( 109 | denoiser_model, 110 | seq_length=diffuse_frames, 111 | timesteps=diffusion_num_steps, # number of steps 112 | sampling_timesteps=diffusion_num_steps, 113 | loss_type=loss_type, # L1 or L2 114 | objective='pred_x0', 115 | ).to(device) 116 | 117 | particle_normalizer = ParticleNormalization(diffusion_config, mode=particle_norm).to(device) 118 | 119 | # expects input: [batch_size, feature_dim, seq_len] 120 | 121 | trainer = TrainerDiffuseDDLP( 122 | diffusion, 123 | ddlp_model=model, 124 | diffusion_config=diffusion_config, 125 | particle_norm=particle_normalizer, 126 | train_batch_size=batch_size, 127 | train_lr=lr, 128 | train_num_steps=train_num_steps, # total training steps 129 | gradient_accumulate_every=1, # gradient accumulation steps 130 | ema_decay=0.995, # exponential moving average decay 131 | amp=False, # turn on mixed precision 132 | seq_len=diffuse_frames, 133 | save_and_sample_every=1000, 134 | results_folder=result_dir, animation_fps=animation_fps 135 | ) 136 | 137 | trainer.load() 138 | if gen_img: 139 | trainer.sample_image(num_samples) 140 | else: 141 | trainer.sample(num_samples) 142 | -------------------------------------------------------------------------------- /datasets/phyre_ds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import torchvision.transforms as transforms 6 | import glob 7 | from PIL import Image, ImageFile 8 | 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | 12 | def list_images_in_dir(path): 13 | valid_images = [".jpg", ".gif", ".png"] 14 | img_list = [] 15 | for f in os.listdir(path): 16 | ext = os.path.splitext(f)[1] 17 | if ext.lower() not in valid_images: 18 | continue 19 | img_list.append(os.path.join(path, f)) 20 | return img_list 21 | 22 | 23 | # --- new preprocessing functions for the episodic setting --- # 24 | class PhyreDataset(Dataset): 25 | def __init__(self, root, mode, ep_len=100, sample_length=20, image_size=128, start_idx=0, fps=10): 26 | # path = os.path.join(root, mode) 27 | # assume 10 frames-per-second 28 | # the data is generated such that a task is completed if the completion condition is met for 3 seconds or more. 
29 | # that means that we can cut off 3 seconds (=30 frames) from the end of the episode) 30 | assert mode in ['train', 'val', 'valid', 'test'] 31 | if mode == 'val': 32 | mode = 'valid' 33 | self.root = os.path.join(root, mode) 34 | self.image_size = image_size 35 | self.start_idx = start_idx 36 | self.fps = fps 37 | self.cutoff = 3 * self.fps # 3 seconds off the end is just idle stuff 38 | 39 | self.mode = mode 40 | self.sample_length = sample_length 41 | 42 | # Get all numbers 43 | # print(os.listdir(self.root)) 44 | get_dir_num = lambda x: int(x) 45 | 46 | self.folders = [d for d in os.listdir(self.root) if osp.isdir(osp.join(self.root, d))] 47 | self.folders.sort(key=get_dir_num) 48 | # print(f'folders: {len(self.folders)}') 49 | 50 | self.episodes = [] 51 | self.episodes_len = [] 52 | self.EP_LEN = ep_len 53 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1 54 | # self.seq_per_episode = [] 55 | 56 | for f in self.folders: 57 | dir_name = os.path.join(self.root, str(f)) 58 | paths = list(glob.glob(osp.join(dir_name, '*.png'))) 59 | # ep_len = len(paths) 60 | # pad 61 | # if len(paths) < self.EP_LEN: 62 | # continue 63 | # self.episodes_len.append(ep_len) 64 | # self.episodes_len.append(self.EP_LEN) 65 | # self.seq_per_episode.append(self.EP_LEN - self.sample_length + 1) 66 | # assert len(paths) == self.EP_LEN, 'len(paths): {}'.format(len(paths)) 67 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0]) 68 | paths.sort(key=get_num) 69 | paths = paths[self.start_idx:-self.cutoff] 70 | if len(paths) < self.EP_LEN: 71 | # continue 72 | self.episodes_len.append(len(paths)) 73 | else: 74 | self.episodes_len.append(self.EP_LEN) 75 | while len(paths) < self.EP_LEN: 76 | paths.append(paths[-1]) 77 | self.episodes.append(paths[:self.EP_LEN]) 78 | # self.episodes_len_cumsum = np.cumsum(self.episodes_len) 79 | # print(f'episodes: {len(self.episodes)}, min: {min(self.episodes_len)}, max: {max(self.episodes_len)}') 80 | 81 | def __getitem__(self, index): 82 | 83 | imgs = [] 84 | if self.mode == 'train': 85 | # Implement continuous indexing 86 | ep = index // self.seq_per_episode 87 | # ep = np.argmax((index < self.episodes_len_cumsum)) 88 | offset = index % self.seq_per_episode 89 | # offset = index % self.seq_per_episode[ep] 90 | end = offset + self.sample_length 91 | # if `end` is after the episode ended, move backwards 92 | ep_len = self.episodes_len[ep] 93 | if end > ep_len: 94 | # print(f'before: offset: {offset}, end: {end}, ep_len: {ep_len}') 95 | if self.sample_length > ep_len: 96 | offset = 0 97 | end = offset + self.sample_length 98 | else: 99 | offset = ep_len - self.sample_length 100 | end = ep_len 101 | # print(f'after: offset: {offset}, end: {end}, ep_len: {ep_len}') 102 | 103 | e = self.episodes[ep] 104 | for image_index in range(offset, end): 105 | img = Image.open(osp.join(e[image_index])) 106 | # img.point(lambda x: 215.0 if x >= 253 else x) 107 | img = img.resize((self.image_size, self.image_size)) 108 | img = transforms.ToTensor()(img)[:3] 109 | imgs.append(img) 110 | else: 111 | for path in self.episodes[index]: 112 | img = Image.open(path) 113 | img = img.resize((self.image_size, self.image_size)) 114 | img = transforms.ToTensor()(img)[:3] 115 | imgs.append(img) 116 | 117 | img = torch.stack(imgs, dim=0).float() 118 | # invert colors 119 | img = 1.0 - img 120 | pos = torch.zeros(0) 121 | size = torch.zeros(0) 122 | id = torch.zeros(0) 123 | in_camera = torch.zeros(0) 124 | 125 | return img, pos, size, id, in_camera 126 | 127 | def __len__(self): 128 | length = 
len(self.episodes) 129 | if self.mode == 'train': 130 | return length * self.seq_per_episode 131 | # return sum(self.episodes_len) 132 | else: 133 | return length 134 | 135 | 136 | class PhyreDatasetImage(Dataset): 137 | def __init__(self, root, mode, ep_len=100, sample_length=20, image_size=128, start_idx=0, fps=10): 138 | # path = os.path.join(root, mode) 139 | # assume 10 frames-per-second 140 | # the data is generated such that a task is completed if the completion condition is met for 3 seconds or more. 141 | # that means that we can cut off 3 seconds (=30 frames) from the end of the episode) 142 | assert mode in ['train', 'val', 'valid', 'test'] 143 | if mode == 'val': 144 | mode = 'valid' 145 | self.root = os.path.join(root, mode) 146 | self.image_size = image_size 147 | self.start_idx = start_idx 148 | self.fps = fps 149 | self.cutoff = 3 * self.fps # 3 seconds off the end is just idle stuff 150 | 151 | self.mode = mode 152 | self.sample_length = sample_length 153 | 154 | # Get all numbers 155 | # print(os.listdir(self.root)) 156 | get_dir_num = lambda x: int(x) 157 | 158 | self.folders = [d for d in os.listdir(self.root) if osp.isdir(osp.join(self.root, d))] 159 | self.folders.sort(key=get_dir_num) 160 | # print(f'folders: {len(self.folders)}') 161 | 162 | self.episodes = [] 163 | self.episodes_len = [] 164 | self.EP_LEN = ep_len 165 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1 166 | # self.seq_per_episode = [] 167 | 168 | for f in self.folders: 169 | dir_name = os.path.join(self.root, str(f)) 170 | paths = list(glob.glob(osp.join(dir_name, '*.png'))) 171 | # ep_len = len(paths) 172 | # pad 173 | # if len(paths) < self.EP_LEN: 174 | # continue 175 | # self.episodes_len.append(ep_len) 176 | # self.episodes_len.append(self.EP_LEN) 177 | # self.seq_per_episode.append(self.EP_LEN - self.sample_length + 1) 178 | # assert len(paths) == self.EP_LEN, 'len(paths): {}'.format(len(paths)) 179 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0]) 180 | paths.sort(key=get_num) 181 | paths = paths[self.start_idx:-self.cutoff] 182 | if len(paths) < self.EP_LEN: 183 | # continue 184 | self.episodes_len.append(len(paths)) 185 | else: 186 | self.episodes_len.append(self.EP_LEN) 187 | while len(paths) < self.EP_LEN: 188 | paths.append(paths[-1]) 189 | self.episodes.append(paths[:self.EP_LEN]) 190 | # self.episodes_len_cumsum = np.cumsum(self.episodes_len) 191 | # print(f'episodes: {len(self.episodes)}, min: {min(self.episodes_len)}, max: {max(self.episodes_len)}') 192 | 193 | def __getitem__(self, index): 194 | 195 | imgs = [] 196 | # Implement continuous indexing 197 | ep = index // self.seq_per_episode 198 | # ep = np.argmax((index < self.episodes_len_cumsum)) 199 | offset = index % self.seq_per_episode 200 | # offset = index % self.seq_per_episode[ep] 201 | end = offset + self.sample_length 202 | # if `end` is after the episode ended, move backwards 203 | ep_len = self.episodes_len[ep] 204 | if end > ep_len: 205 | # print(f'before: offset: {offset}, end: {end}, ep_len: {ep_len}') 206 | if self.sample_length > ep_len: 207 | offset = 0 208 | end = offset + self.sample_length 209 | else: 210 | offset = ep_len - self.sample_length 211 | end = ep_len 212 | # print(f'after: offset: {offset}, end: {end}, ep_len: {ep_len}') 213 | 214 | e = self.episodes[ep] 215 | for image_index in range(offset, end): 216 | img = Image.open(osp.join(e[image_index])) 217 | # img.point(lambda x: 215.0 if x >= 253 else x) 218 | img = img.resize((self.image_size, self.image_size)) 219 | img = 
transforms.ToTensor()(img)[:3] 220 | imgs.append(img) 221 | 222 | img = torch.stack(imgs, dim=0).float() 223 | # invert colors 224 | img = 1.0 - img 225 | pos = torch.zeros(0) 226 | size = torch.zeros(0) 227 | id = torch.zeros(0) 228 | in_camera = torch.zeros(0) 229 | 230 | return img, pos, size, id, in_camera 231 | 232 | def __len__(self): 233 | length = len(self.episodes) 234 | return length * self.seq_per_episode 235 | 236 | 237 | if __name__ == '__main__': 238 | test_epochs = True 239 | # --- episodic setting --- # 240 | root = '/media/newhd/data/phyre' 241 | # root = '/mnt/data/tal/phyre' 242 | phyre_ds = PhyreDataset(root=root, ep_len=100, sample_length=10, mode='train', image_size=128, start_idx=0) 243 | phyre_dl = DataLoader(phyre_ds, shuffle=True, pin_memory=False, batch_size=32, num_workers=4) 244 | batch = next(iter(phyre_dl)) 245 | im = batch[0] 246 | print(im.shape) 247 | # img_np = im.permute(1, 2, 0).data.cpu().numpy() 248 | # fig = plt.figure(figsize=(5, 5)) 249 | # ax = fig.add_subplot(111) 250 | # ax.imshow(img_np) 251 | # plt.show() 252 | 253 | if test_epochs: 254 | from tqdm import tqdm 255 | 256 | pbar = tqdm(iterable=phyre_dl) 257 | for batch in pbar: 258 | pass 259 | pbar.close() 260 | -------------------------------------------------------------------------------- /eval/eval_gen_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate image metrics such as LPIPS, PSNR and SSIM using PIQA, 3 | """ 4 | 5 | # set workdir 6 | import os 7 | import sys 8 | 9 | sys.path.append(os.getcwd()) 10 | # sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) 11 | import argparse 12 | import json 13 | from tqdm import tqdm 14 | from models import ObjectDynamicsDLP 15 | # datasets 16 | from datasets.get_dataset import get_video_dataset, get_image_dataset 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.utils.data import DataLoader 21 | 22 | try: 23 | from piqa import PSNR, LPIPS, SSIM 24 | except ImportError: 25 | print("piqa library required to compute image metrics") 26 | raise SystemExit 27 | 28 | 29 | class ImageMetrics(nn.Module): 30 | """ 31 | A class to calculate visual metrics between generated and ground-truth images 32 | """ 33 | 34 | def __init__(self, metrics=('ssim', 'psnr', 'lpips')): 35 | super().__init__() 36 | self.metrics = metrics 37 | self.ssim = SSIM(reduction='none') if 'ssim' in self.metrics else None 38 | self.psnr = PSNR(reduction='none') if 'psnr' in self.metrics else None 39 | self.lpips = LPIPS(network='vgg', reduction='none') if 'lpips' in self.metrics else None 40 | 41 | @torch.no_grad() 42 | def forward(self, x, y): 43 | # x, y: [batch_size, 3, im_size, im_size] in [0,1] 44 | results = {} 45 | if self.ssim is not None: 46 | results['ssim'] = self.ssim(x, y) 47 | if self.psnr is not None: 48 | results['psnr'] = self.psnr(x, y) 49 | if self.lpips is not None: 50 | results['lpips'] = self.lpips(x, y) 51 | return results 52 | 53 | 54 | def eval_ddlp_im_metric(model, device, config, timestep_horizon=50, val_mode='val', eval_dir='./', 55 | cond_steps=10, 56 | metrics=('ssim', 'psnr', 'lpips'), batch_size=32, verbose=False, accelerator=None): 57 | if isinstance(model, torch.nn.DataParallel): 58 | model = model.module 59 | model.eval() 60 | ds = config['ds'] 61 | ch = config['ch'] # image channels 62 | image_size = config['image_size'] 63 | root = config['root'] # dataset root 64 | dataset = get_video_dataset(ds, root, seq_len=timestep_horizon, 
mode=val_mode, image_size=image_size) 65 | 66 | dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, num_workers=0, drop_last=False) 67 | model_timestep_horizon = model.timestep_horizon 68 | cond_steps = model_timestep_horizon if cond_steps is None else cond_steps 69 | 70 | # image metric instance 71 | evaluator = ImageMetrics(metrics=metrics).to(device) 72 | 73 | results = {} 74 | ssims = [] 75 | psnrs = [] 76 | lpipss = [] 77 | for i, batch in enumerate(tqdm(dataloader)): 78 | x = batch[0][:, :timestep_horizon].to(device) 79 | with torch.no_grad(): 80 | generated = model.sample(x, cond_steps=cond_steps, num_steps=timestep_horizon - cond_steps) 81 | generated = generated.clamp(0, 1) 82 | assert x.shape[1] == generated.shape[1], "prediction and gt frames shape don't match" 83 | results = evaluator(x[:, cond_steps:].reshape(-1, *x.shape[2:]), 84 | generated[:, cond_steps:].reshape(-1, *generated.shape[2:])) 85 | # [batch_size * T] 86 | if 'ssim' in metrics: 87 | ssims.append(results['ssim']) 88 | if 'psnr' in metrics: 89 | psnrs.append(results['psnr']) 90 | if 'lpips' in metrics: 91 | lpipss.append(results['lpips']) 92 | 93 | if 'ssim' in metrics: 94 | ssims = torch.cat(ssims, dim=0) 95 | mean_ssim = ssims.mean().data.cpu().item() 96 | std_ssim = ssims.std().data.cpu().item() 97 | results['ssim'] = mean_ssim 98 | results['ssim_std'] = std_ssim 99 | if 'psnr' in metrics: 100 | psnrs = torch.cat(psnrs, dim=0) 101 | mean_psnr = psnrs.mean().data.cpu().item() 102 | std_psnr = psnrs.std().data.cpu().item() 103 | results['psnr'] = mean_psnr 104 | results['psnr_std'] = std_psnr 105 | if 'lpips' in metrics: 106 | lpipss = torch.cat(lpipss, dim=0) 107 | mean_lpips = lpipss.mean().data.cpu().item() 108 | std_lpips = lpipss.std().data.cpu().item() 109 | results['lpips'] = mean_lpips 110 | results['lpips_std'] = std_lpips 111 | 112 | # save results 113 | path_to_conf = os.path.join(eval_dir, 'last_val_image_metrics.json') 114 | with open(path_to_conf, "w") as outfile: 115 | json.dump(results, outfile, indent=2) 116 | 117 | del evaluator # clear memory 118 | 119 | return results 120 | 121 | 122 | def eval_dlp_im_metric(model, device, config, val_mode='val', eval_dir='./', 123 | metrics=('ssim', 'psnr', 'lpips'), batch_size=32, verbose=False, accelerator=None): 124 | if isinstance(model, torch.nn.DataParallel): 125 | model = model.module 126 | model.eval() 127 | ds = config['ds'] 128 | ch = config['ch'] # image channels 129 | image_size = config['image_size'] 130 | root = config['root'] # dataset root 131 | use_tracking = config['use_tracking'] 132 | dataset = get_image_dataset(ds, root, mode=val_mode, image_size=image_size) 133 | 134 | dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size, num_workers=0, drop_last=False) 135 | 136 | # image metric instance 137 | evaluator = ImageMetrics(metrics=metrics).to(device) 138 | 139 | results = {} 140 | ssims = [] 141 | psnrs = [] 142 | lpipss = [] 143 | for i, batch in enumerate(tqdm(dataloader)): 144 | x = batch[0].to(device) 145 | if len(x.shape) == 5 and not use_tracking: 146 | # [bs, T, ch, h, w] 147 | x = x.view(-1, *x.shape[2:]) 148 | elif len(x.shape) == 4 and use_tracking: 149 | # [bs, ch, h, w] 150 | x = x.unsqueeze(1) 151 | x_prior = x 152 | with torch.no_grad(): 153 | output = model(x, x_prior=x_prior, deterministic=True) 154 | generated = output['rec'].clamp(0, 1) 155 | if len(x.shape) == 5: 156 | # [bs, T, ch, h, w] 157 | x = x.view(-1, *x.shape[2:]) 158 | results = evaluator(x, generated) 159 | # [batch_size * T] 
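        # each metric is computed per image; the per-batch tensors collected below are
        # concatenated after the loop and reduced to a mean and std over the whole evaluation set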
160 | if 'ssim' in metrics: 161 | ssims.append(results['ssim']) 162 | if 'psnr' in metrics: 163 | psnrs.append(results['psnr']) 164 | if 'lpips' in metrics: 165 | lpipss.append(results['lpips']) 166 | 167 | if 'ssim' in metrics: 168 | ssims = torch.cat(ssims, dim=0) 169 | mean_ssim = ssims.mean().data.cpu().item() 170 | std_ssim = ssims.std().data.cpu().item() 171 | results['ssim'] = mean_ssim 172 | results['ssim_std'] = std_ssim 173 | if 'psnr' in metrics: 174 | psnrs = torch.cat(psnrs, dim=0) 175 | mean_psnr = psnrs.mean().data.cpu().item() 176 | std_psnr = psnrs.std().data.cpu().item() 177 | results['psnr'] = mean_psnr 178 | results['psnr_std'] = std_psnr 179 | if 'lpips' in metrics: 180 | lpipss = torch.cat(lpipss, dim=0) 181 | mean_lpips = lpipss.mean().data.cpu().item() 182 | std_lpips = lpipss.std().data.cpu().item() 183 | results['lpips'] = mean_lpips 184 | results['lpips_std'] = std_lpips 185 | 186 | # save results 187 | path_to_conf = os.path.join(eval_dir, 'last_val_image_metrics.json') 188 | with open(path_to_conf, "w") as outfile: 189 | json.dump(results, outfile, indent=2) 190 | 191 | del evaluator # clear memory 192 | 193 | return results 194 | 195 | 196 | if __name__ == '__main__': 197 | parser = argparse.ArgumentParser(description="DDLP Video Prediction Evaluation") 198 | parser.add_argument("-d", "--dataset", type=str, default='balls', 199 | help="dataset to use: ['balls', 'traffic', 'clevrer', 'obj3d128', ...]") 200 | parser.add_argument("-p", "--path", type=str, 201 | help="path to model directory, e.g. ./310822_141959_balls_ddlp") 202 | parser.add_argument("--checkpoint", type=str, 203 | help="direct path to model checkpoint, e.g. ./checkpoints/ddlp-obj3d128/obj3d_ddlp.pth", 204 | default="") 205 | parser.add_argument("--use_last", action='store_true', 206 | help="use the last checkpoint instead of best") 207 | parser.add_argument("--use_train", action='store_true', 208 | help="use the train set for the predictions") 209 | parser.add_argument("--sample", action='store_true', 210 | help="use stochastic (non-deterministic) predictions") 211 | parser.add_argument("--cpu", action='store_true', 212 | help="use cpu for inference") 213 | parser.add_argument("-c", "--cond_steps", type=int, help="the initial number of frames for predictions", default=-1) 214 | parser.add_argument("-b", "--batch_size", type=int, help="batch size", default=10) 215 | parser.add_argument("--horizon", type=int, help="timestep horizon for prediction", default=50) 216 | parser.add_argument("--prefix", type=str, default='', 217 | help="prefix used for model saving") 218 | args = parser.parse_args() 219 | # parse input 220 | dir_path = args.path 221 | checkpoint_path = args.checkpoint 222 | ds = args.dataset 223 | use_train = args.use_train 224 | cond_steps = args.cond_steps 225 | timestep_horizon = args.horizon 226 | batch_size = args.batch_size 227 | use_cpu = args.cpu 228 | deterministic = not args.sample 229 | prefix = args.prefix 230 | # load model config 231 | model_ckpt_name = f'{ds}_ddlp{prefix}.pth' 232 | # model_best_ckpt_name = f'{ds}_ddlp{prefix}_best.pth' 233 | model_best_ckpt_name = f'{ds}_ddlp{prefix}_best_lpips.pth' 234 | use_last = args.use_last if os.path.exists(os.path.join(dir_path, f'saves/{model_best_ckpt_name}')) else True 235 | conf_path = os.path.join(dir_path, 'hparams.json') 236 | with open(conf_path, 'r') as f: 237 | config = json.load(f) 238 | if use_cpu: 239 | device = torch.device("cpu") 240 | else: 241 | device = torch.device('cuda:0' if torch.cuda.is_available() else 
"cpu") 242 | 243 | ds = config['ds'] 244 | image_size = config['image_size'] 245 | ch = config['ch'] 246 | enc_channels = config['enc_channels'] 247 | prior_channels = config['prior_channels'] 248 | use_correlation_heatmaps = config['use_correlation_heatmaps'] # use heatmaps for tracking 249 | enable_enc_attn = config['enable_enc_attn'] # enable attention between patches in the particle encoder 250 | filtering_heuristic = config["filtering_heuristic"] # filtering heuristic to filter prior keypoints 251 | 252 | model = ObjectDynamicsDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 253 | image_size=image_size, n_kp=config['n_kp'], 254 | learned_feature_dim=config['learned_feature_dim'], 255 | pad_mode=config['pad_mode'], 256 | sigma=config['sigma'], 257 | dropout=config['dropout'], patch_size=config['patch_size'], 258 | n_kp_enc=config['n_kp_enc'], 259 | n_kp_prior=config['n_kp_prior'], kp_range=config['kp_range'], 260 | kp_activation=config['kp_activation'], 261 | anchor_s=config['anchor_s'], 262 | use_resblock=config['use_resblock'], 263 | timestep_horizon=config['timestep_horizon'], predict_delta=config['predict_delta'], 264 | scale_std=config['scale_std'], 265 | offset_std=config['offset_std'], obj_on_alpha=config['obj_on_alpha'], 266 | obj_on_beta=config['obj_on_beta'], pint_heads=config['pint_heads'], 267 | pint_layers=config['pint_layers'], pint_dim=config['pint_dim'], 268 | use_correlation_heatmaps=use_correlation_heatmaps, 269 | enable_enc_attn=enable_enc_attn, filtering_heuristic=filtering_heuristic).to(device) 270 | if checkpoint_path.endswith('.pth'): 271 | ckpt_path = checkpoint_path 272 | else: 273 | ckpt_path = os.path.join(dir_path, f'saves/{model_ckpt_name if use_last else model_best_ckpt_name}') 274 | model.load_state_dict(torch.load(ckpt_path, map_location=device)) 275 | model.eval() 276 | print(f"loaded model from {ckpt_path}") 277 | 278 | # create dir for results 279 | pred_dir = os.path.join(dir_path, 'eval') 280 | os.makedirs(pred_dir, exist_ok=True) 281 | 282 | # conditional frames 283 | cond_steps = cond_steps if cond_steps > 0 else config['timestep_horizon'] 284 | val_mode = 'train' if use_train else 'test' 285 | results = eval_ddlp_im_metric(model, device, timestep_horizon=timestep_horizon, val_mode=val_mode, config=config, 286 | eval_dir=pred_dir, cond_steps=cond_steps, metrics=('ssim', 'psnr', 'lpips'), 287 | batch_size=batch_size) 288 | print(f'results: {results}') 289 | -------------------------------------------------------------------------------- /docs/example_usage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example usage of DLPv2 and DDLP 3 | """ 4 | # imports 5 | import os 6 | import sys 7 | 8 | sys.path.append(os.getcwd()) 9 | # torch 10 | import torch 11 | # modules 12 | from models import ObjectDLP, ObjectDynamicsDLP 13 | 14 | torch.backends.cudnn.benchmark = False 15 | torch.backends.cudnn.deterministic = True 16 | 17 | if __name__ == '__main__': 18 | # example hyper-parameters 19 | batch_size = 32 20 | beta_kl = 0.1 21 | beta_rec = 1.0 22 | kl_balance = 0.001 # balance between spatial attributes (x, y, scale, depth) and visual features 23 | n_kp_enc = 12 24 | n_kp_prior = 15 25 | patch_size = 8 # patch size for the prior to generate prior proposals 26 | learned_feature_dim = 6 # visual features 27 | anchor_s = 0.25 # effective patch size for the posterior: anchor_s * image_size 28 | 29 | image_size = 64 30 | ch = 3 31 | enc_channels = [32, 64, 128] 32 | prior_channels = (32, 
32, 64) 33 | 34 | use_correlation_heatmaps = False # for tracking, set True to use correlation heatmaps between patches 35 | 36 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 37 | device = torch.device("cpu") 38 | 39 | print("--- DLPv2 ---") 40 | 41 | # create model 42 | model = ObjectDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 43 | image_size=image_size, learned_feature_dim=learned_feature_dim, 44 | patch_size=patch_size, n_kp_enc=n_kp_enc, n_kp_prior=n_kp_prior, 45 | anchor_s=anchor_s, use_correlation_heatmaps=use_correlation_heatmaps).to(device) 46 | print(f'model.info():') 47 | print(model.info()) 48 | print("----------------------------------") 49 | # dummy data 50 | x = torch.rand(batch_size, ch, image_size, image_size, device=device) 51 | # complete forward 52 | model_output = model(x) 53 | # let's see what's inside 54 | print(f'model(x) output:') 55 | for k in model_output.keys(): 56 | print(f'{k}: {model_output[k].shape}') 57 | print("----------------------------------") 58 | """ 59 | output: 60 | kp_p: torch.Size([32, 15, 2]) # prior proposals 61 | rec: torch.Size([32, 3, 64, 64]) # full reconstructions 62 | mu: torch.Size([32, 12, 2]) # position mu 63 | logvar: torch.Size([32, 12, 2]) # position logvar 64 | z: torch.Size([32, 12, 2]) # position z 65 | z_base: torch.Size([32, 12, 2]) # position anchors (mu = z_base + mu_offset) 66 | z_kp_bg: torch.Size([32, 1, 2]) # constants (0.0, 0.0) for the bg kp 67 | mu_offset: torch.Size([32, 12, 2]) # position offset mu 68 | logvar_offset: torch.Size([32, 12, 2]) # position offset logvar 69 | mu_features: torch.Size([32, 12, 6]) # visual features mu 70 | logvar_features: torch.Size([32, 12, 6]) # visual features logvar 71 | z_features: torch.Size([32, 12, 6]) # visual features z 72 | bg: torch.Size([32, 3, 64, 64]) # bg reconstructions 73 | mu_bg: torch.Size([32, 6]) # bg visual features mu 74 | logvar_bg: torch.Size([32, 6]) # bg visual features logvar 75 | z_bg: torch.Size([32, 6]) # bg visual features z 76 | cropped_objects_original: torch.Size([32, 12, 3, 16, 16]) # extracted patches from the original image 77 | obj_on_a: torch.Size([32, 12]) # transparency beta distribution "a" parameter 78 | obj_on_b: torch.Size([32, 12]) # transparency beta distribution "b" parameter 79 | obj_on: torch.Size([32, 12]) # transparency sample per particle 80 | dec_objects_original: torch.Size([32, 12, 4, 16, 16]) # decoded glimpses (rgb + alpha channel) 81 | dec_objects: torch.Size([32, 3, 64, 64]) # decoded foreground (no bg) 82 | mu_depth: torch.Size([32, 12, 1]) # depth mu 83 | logvar_depth: torch.Size([32, 12, 1]) # depth logvar 84 | z_depth: torch.Size([32, 12, 1]) # depth z 85 | mu_scale: torch.Size([32, 12, 2]) # scale mu 86 | logvar_scale: torch.Size([32, 12, 2]) # scale logvar 87 | z_scale: torch.Size([32, 12, 2]) # scale z 88 | alpha_masks: torch.Size([32, 12, 1, 64, 64]) # objects masks 89 | """ 90 | 91 | # loss calculation 92 | all_losses = model.calc_elbo(x, model_output, beta_kl=beta_kl, 93 | beta_rec=beta_rec, kl_balance=kl_balance, 94 | recon_loss_type="mse") 95 | # let's see what's inside 96 | print(f'model.calc_elbo(): model losses:') 97 | for k in all_losses.keys(): 98 | print(f'{k}: {all_losses[k]}') 99 | print("----------------------------------") 100 | """ 101 | output: 102 | loss: the complete loss (for loss.backward()) 103 | psnr: mean PSNR 104 | kl: complete kl-divergence (of all components) 105 | loss_rec: reconstruction loss 106 | obj_on_l1: if all particles are "on" then 
obj_on_l1=n_particles, effective # of visible particles 107 | loss_kl_kp: kl of the position 108 | loss_kl_feat: kl of the visual features 109 | loss_kl_obj_on: kl of the transparency 110 | loss_kl_scale: kl of the scale 111 | loss_kl_depth: kl of the depth 112 | """ 113 | 114 | # only encoding: 115 | model_output = model.encode_all(x, deterministic=True) # deterministic=True -> z = mu 116 | # let's see what's inside 117 | print(f'model.encode_all(): model encoder output:') 118 | for k in model_output.keys(): 119 | out_print = model_output[k].shape if model_output[k] is not None else None 120 | print(f'{k}: {out_print}') 121 | print("----------------------------------") 122 | """ 123 | output: 124 | mu: torch.Size([32, 12, 2]) 125 | logvar: torch.Size([32, 12, 2]) 126 | z: torch.Size([32, 12, 2]) 127 | z_base: torch.Size([32, 12, 2]) 128 | kp_heatmap: None # this is not used in this model, it was used in the non-object DLP model 129 | mu_features: torch.Size([32, 12, 6]) 130 | logvar_features: torch.Size([32, 12, 6]) 131 | z_features: torch.Size([32, 12, 6]) 132 | obj_on_a: torch.Size([32, 12]) 133 | obj_on_b: torch.Size([32, 12]) 134 | obj_on: torch.Size([32, 12]) 135 | mu_depth: torch.Size([32, 12, 1]) 136 | logvar_depth: torch.Size([32, 12, 1]) 137 | z_depth: torch.Size([32, 12, 1]) 138 | cropped_objects: torch.Size([32, 12, 3, 16, 16]) 139 | bg_mask: torch.Size([32, 1, 64, 64]) 140 | mu_scale: torch.Size([32, 12, 2]) 141 | logvar_scale: torch.Size([32, 12, 2]) 142 | z_scale: torch.Size([32, 12, 2]) 143 | mu_offset: torch.Size([32, 12, 2]) 144 | logvar_offset: torch.Size([32, 12, 2]) 145 | z_offset: torch.Size([32, 12, 2]) 146 | mu_bg: torch.Size([32, 6]) 147 | logvar_bg: torch.Size([32, 6]) 148 | z_bg: torch.Size([32, 6]) 149 | z_kp_bg: torch.Size([32, 1, 2]) 150 | """ 151 | 152 | # only decoding: 153 | z = model_output['z'] 154 | z_scale = model_output['z_scale'] 155 | z_depth = model_output['z_depth'] 156 | z_obj_on = model_output['obj_on'] 157 | z_features = model_output['z_features'] 158 | z_bg = model_output['z_bg'] 159 | 160 | decode_output = model.decode_all(z=z, z_features=z_features, z_bg=z_bg, obj_on=z_obj_on, z_scale=z_scale, 161 | z_depth=z_depth) 162 | # let's see what's inside 163 | print(f'model.decode_all(): model decode output:') 164 | for k in decode_output.keys(): 165 | out_print = decode_output[k].shape if decode_output[k] is not None else None 166 | print(f'{k}: {out_print}') 167 | print("----------------------------------") 168 | """ 169 | output: 170 | rec: torch.Size([32, 3, 64, 64]) 171 | dec_objects: torch.Size([32, 12, 4, 16, 16]) # decoded glimpses (rgb + alpha channel) 172 | dec_objects_trans: torch.Size([32, 3, 64, 64]) # decoded foreground (no bg) 173 | bg: torch.Size([32, 3, 64, 64]) 174 | alpha_masks: torch.Size([32, 12, 1, 64, 64]) 175 | """ 176 | 177 | print("--- DLPv2 with Tracking ---") 178 | # tracking 179 | use_correlation_heatmaps = True 180 | use_tracking = True 181 | model = ObjectDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 182 | image_size=image_size, learned_feature_dim=learned_feature_dim, 183 | patch_size=patch_size, n_kp_enc=n_kp_enc, n_kp_prior=n_kp_prior, 184 | anchor_s=anchor_s, 185 | use_correlation_heatmaps=use_correlation_heatmaps, use_tracking=use_tracking).to(device) 186 | num_frames = 3 187 | x = torch.rand(1, num_frames, ch, image_size, image_size, device=device) 188 | model_output = model(x) 189 | # let's see what's inside 190 | print(f'model(x) tracking output:') 191 | for k in model_output.keys(): 
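        # with use_tracking=True the outputs are flattened, so dim 0 is batch_size * num_frames (here 1 * 3 = 3)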
192 | print(f'{k}: {model_output[k].shape}') 193 | print("----------------------------------") 194 | """ 195 | output: 196 | similar to before, but the first dimension is batch_size * num_frames for all 197 | """ 198 | 199 | print("--- DDLP ---") 200 | # example additional hyper-parameters 201 | use_correlation_heatmaps = True 202 | pint_layers = 6 # transformer-based dynamics module number of layers 203 | pint_heads = 8 # transformer-based dynamics module attention heads 204 | pint_dim = 256 # transformer-based dynamics module inner dimension (+projection dim) 205 | beta_dyn = 0.1 # beta-kl for the dynamics loss 206 | num_static_frames = 4 # "burn-in frames", number of initial frames with kl w.r.t. constant prior (as in DLPv2) 207 | 208 | model = ObjectDynamicsDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 209 | image_size=image_size, learned_feature_dim=learned_feature_dim, 210 | patch_size=patch_size, n_kp_enc=n_kp_enc, n_kp_prior=n_kp_prior, 211 | anchor_s=anchor_s, use_correlation_heatmaps=use_correlation_heatmaps, 212 | pint_layers=pint_layers, pint_heads=pint_heads, pint_dim=pint_dim).to(device) 213 | print(f'model.info():') 214 | print(model.info()) 215 | print("----------------------------------") 216 | timestep_horizon = 10 217 | x = torch.rand(batch_size, timestep_horizon + 1, ch, image_size, image_size, device=device) 218 | model_output = model(x) 219 | # let's see what's inside 220 | print(f'model(x) output:') 221 | for k in model_output.keys(): 222 | print(f'{k}: {model_output[k].shape}') 223 | print("----------------------------------") 224 | """ 225 | output: similar to before, but the first dimension is batch_size * num_frames for all 226 | kp_p: torch.Size([352, 15, 2]) 227 | rec: torch.Size([352, 3, 64, 64]) 228 | mu: torch.Size([352, 12, 2]) 229 | logvar: torch.Size([352, 12, 2]) 230 | z_base: torch.Size([352, 12, 2]) 231 | z: torch.Size([352, 12, 2]) 232 | z_kp_bg: torch.Size([352, 1, 2]) 233 | mu_offset: torch.Size([352, 12, 2]) 234 | logvar_offset: torch.Size([352, 12, 2]) 235 | mu_features: torch.Size([352, 12, 6]) 236 | logvar_features: torch.Size([352, 12, 6]) 237 | z_features: torch.Size([352, 12, 6]) 238 | bg: torch.Size([352, 3, 64, 64]) 239 | mu_bg: torch.Size([352, 6]) 240 | logvar_bg: torch.Size([352, 6]) 241 | z_bg: torch.Size([352, 6]) 242 | cropped_objects_original: torch.Size([352, 12, 3, 16, 16]) 243 | obj_on_a: torch.Size([352, 12]) 244 | obj_on_b: torch.Size([352, 12]) 245 | obj_on: torch.Size([352, 12]) 246 | dec_objects_original: torch.Size([352, 12, 4, 16, 16]) 247 | dec_objects: torch.Size([352, 3, 64, 64]) 248 | mu_depth: torch.Size([352, 12, 1]) 249 | logvar_depth: torch.Size([352, 12, 1]) 250 | z_depth: torch.Size([352, 12, 1]) 251 | mu_scale: torch.Size([352, 12, 2]) 252 | logvar_scale: torch.Size([352, 12, 2]) 253 | z_scale: torch.Size([352, 12, 2]) 254 | alpha_masks: torch.Size([352, 12, 1, 64, 64]) 255 | mu_dyn: torch.Size([32, 10, 12, 2]) # dynamics-prior position for t=1->T-1 given t=0->T-2 256 | logvar_dyn: torch.Size([32, 10, 12, 2]) # dynamics-prior position for t=1->T-1 given t=0->T-2 257 | mu_features_dyn: torch.Size([32, 10, 12, 6]) # dynamics-prior visual appearance for t=1->T-1 given t=0->T-2 258 | logvar_features_dyn: torch.Size([32, 9, 12, 6]) # dynamics-prior visual appearance for t=1->T-1 given t=0->T-2 259 | obj_on_a_dyn: torch.Size([32, 10, 12]) # dynamics-prior transparency for t=1->T-1 given t=0->T-2 260 | obj_on_b_dyn: torch.Size([32, 10, 12]) # dynamics-prior transparency for t=1->T-1 given 
t=0->T-2 261 | mu_depth_dyn: torch.Size([32, 10, 12, 1]) # dynamics-prior depth for t=1->T-1 given t=0->T-2 262 | logvar_depth_dyn: torch.Size([32, 10, 12, 1]) # dynamics-prior depth for t=1->T-1 given t=0->T-2 263 | mu_scale_dyn: torch.Size([32, 10, 12, 2]) # dynamics-prior scale for t=1->T-1 given t=0->T-2 264 | logvar_scale_dyn: torch.Size([32, 10, 12, 2]) # dynamics-prior scale for t=1->T-1 given t=0->T-2 265 | mu_bg_dyn: torch.Size([32, 10, 6]) # dynamics-prior background appearance for t=1->T-1 given t=0->T-2 266 | logvar_bg_dyn: torch.Size([32, 10, 6]) # dynamics-prior background appearance for t=1->T-1 given t=0->T-2 267 | """ 268 | 269 | # loss calculation 270 | all_losses = model.calc_elbo(x, model_output, beta_kl=beta_kl, 271 | beta_rec=beta_rec, kl_balance=kl_balance, beta_dyn=beta_dyn, 272 | num_static=num_static_frames, 273 | recon_loss_type="mse") 274 | # let's see what's inside 275 | print(f'model.calc_elbo(): model losses:') 276 | for k in all_losses.keys(): 277 | print(f'{k}: {all_losses[k]}') 278 | print("----------------------------------") 279 | """ 280 | output: 281 | model.calc_elbo(): model losses: 282 | loss: 465.488525390625 283 | psnr: 10.788384437561035 284 | kl: 2579.010986328125 285 | kl_dyn: 406.3466491699219 # <---- dynamics kl 286 | loss_rec: 4821.837890625 287 | obj_on_l1: 5.660502910614014 288 | loss_kl_kp: 997.2069702148438 289 | loss_kl_feat: 0.212782084941864 290 | loss_kl_obj_on: 56.031768798828125 291 | loss_kl_scale: 1525.77197265625 292 | loss_kl_depth: 0.08919306844472885 293 | """ 294 | 295 | # sampling 296 | num_steps = 15 297 | cond_steps = 5 298 | x = torch.rand(1, num_steps + cond_steps, ch, image_size, image_size, device=device) 299 | sample_out, sample_z_out = model.sample(x, cond_steps=cond_steps, num_steps=num_steps, deterministic=True, 300 | return_z=True) 301 | # let's see what's inside 302 | print(f'model.sample(): model dynamics unrolling:') 303 | print(f'sample_out: {sample_out.shape}') 304 | print(f'sample_z_out:') 305 | for k in sample_z_out.keys(): 306 | print(f'{k}: {sample_z_out[k].shape}') 307 | print("----------------------------------") 308 | """ 309 | output: 310 | sample_out: torch.Size([1, 20, 3, 64, 64]) # generated frames 311 | sample_z_out: # latent unrolls 312 | z_pos: torch.Size([1, 20, 12, 2]) 313 | z_scale: torch.Size([1, 20, 12, 2]) 314 | z_obj_on: torch.Size([1, 20, 12]) 315 | z_depth: torch.Size([1, 20, 12, 1]) 316 | z_features: torch.Size([1, 20, 12, 6]) 317 | z_bg_features: torch.Size([1, 20, 6]) 318 | z_ids: torch.Size([1, 20, 12]) # this is only used for the balls-interaction dataset, each particle gets an id 319 | """ 320 | -------------------------------------------------------------------------------- /utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions implementations used in the optimization of DLP. 
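Includes reconstruction losses (MSE/L1/BCE and a VGG perceptual distance), KL-divergence terms for
Gaussian, Bernoulli and Beta variables, and Chamfer-style set losses for matching particle sets.

A minimal usage sketch of the standalone functions (tensor names and shapes are illustrative only):

    recon = calc_reconstruction_loss(x, recon_x, loss_type='mse', reduction='mean')
    kld = calc_kl(logvar, mu, reduce='sum')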
3 | """ 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models, transforms 9 | 10 | 11 | # functions 12 | def batch_pairwise_kl(mu_x, logvar_x, mu_y, logvar_y, reverse_kl=False): 13 | """ 14 | Calculate batch-wise KL-divergence 15 | mu_x, logvar_x: [batch_size, n_x, points_dim] 16 | mu_y, logvar_y: [batch_size, n_y, points_dim] 17 | kl = -0.5 * Σ_points_dim (1 + logvar_x - logvar_y - exp(logvar_x)/exp(logvar_y) 18 | - ((mu_x - mu_y) ** 2)/exp(logvar_y)) 19 | """ 20 | if reverse_kl: 21 | mu_a, logvar_a = mu_y, logvar_y 22 | mu_b, logvar_b = mu_x, logvar_x 23 | else: 24 | mu_a, logvar_a = mu_x, logvar_x 25 | mu_b, logvar_b = mu_y, logvar_y 26 | bs, n_a, points_dim = mu_a.size() 27 | _, n_b, _ = mu_b.size() 28 | logvar_aa = logvar_a.unsqueeze(2).expand(-1, -1, n_b, -1) # [batch_size, n_a, n_b, points_dim] 29 | logvar_bb = logvar_b.unsqueeze(1).expand(-1, n_a, -1, -1) # [batch_size, n_a, n_b, points_dim] 30 | mu_aa = mu_a.unsqueeze(2).expand(-1, -1, n_b, -1) # [batch_size, n_a, n_b, points_dim] 31 | mu_bb = mu_b.unsqueeze(1).expand(-1, n_a, -1, -1) # [batch_size, n_a, n_b, points_dim] 32 | p_kl = -0.5 * (1 + logvar_aa - logvar_bb - logvar_aa.exp() / logvar_bb.exp() 33 | - ((mu_aa - mu_bb) ** 2) / logvar_bb.exp()).sum(-1) # [batch_size, n_x, n_y] 34 | return p_kl 35 | 36 | 37 | def batch_pairwise_dist(x, y, metric='l2'): 38 | assert metric in ['l2', 'l2_simple', 'l1', 'cosine'], f'metric {metric} unrecognized' 39 | bs, num_points_x, points_dim = x.size() 40 | _, num_points_y, _ = y.size() 41 | if metric == 'cosine': 42 | dist_func = torch.nn.functional.cosine_similarity 43 | P = -dist_func(x.unsqueeze(2), y.unsqueeze(1), dim=-1, eps=1e-8) 44 | elif metric == 'l1': 45 | P = torch.abs(x.unsqueeze(2) - y.unsqueeze(1)).sum(-1) 46 | elif metric == 'l2_simple': 47 | P = ((x.unsqueeze(2) - y.unsqueeze(1)) ** 2).sum(-1) 48 | else: 49 | xx = torch.bmm(x, x.transpose(2, 1)) 50 | yy = torch.bmm(y, y.transpose(2, 1)) 51 | zz = torch.bmm(x, y.transpose(2, 1)) 52 | diag_ind_x = torch.arange(0, num_points_x, device=x.device) 53 | diag_ind_y = torch.arange(0, num_points_y, device=y.device) 54 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as(zz.transpose(2, 1)) 55 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz) 56 | P = rx.transpose(2, 1) + ry - 2 * zz 57 | return P 58 | 59 | 60 | def calc_reconstruction_loss(x, recon_x, loss_type='mse', reduction='sum'): 61 | """ 62 | 63 | :param x: original inputs 64 | :param recon_x: reconstruction of the VAE's input 65 | :param loss_type: "mse", "l1", "bce" 66 | :param reduction: "sum", "mean", "none" 67 | :return: recon_loss 68 | """ 69 | if reduction not in ['sum', 'mean', 'none']: 70 | raise NotImplementedError 71 | recon_x = recon_x.view(recon_x.size(0), -1) 72 | x = x.view(x.size(0), -1) 73 | if loss_type == 'mse': 74 | recon_error = F.mse_loss(recon_x, x, reduction='none') 75 | recon_error = recon_error.sum(1) 76 | if reduction == 'sum': 77 | recon_error = recon_error.sum() 78 | elif reduction == 'mean': 79 | recon_error = recon_error.mean() 80 | elif loss_type == 'l1': 81 | recon_error = F.l1_loss(recon_x, x, reduction=reduction) 82 | elif loss_type == 'bce': 83 | recon_error = F.binary_cross_entropy(recon_x, x, reduction=reduction) 84 | else: 85 | raise NotImplementedError 86 | return recon_error 87 | 88 | 89 | def calc_kl(logvar, mu, mu_o=0.0, logvar_o=0.0, reduce='sum', balance=0.5): 90 | """ 91 | Calculate kl-divergence 92 | :param logvar: log-variance 
from the encoder 93 | :param mu: mean from the encoder 94 | :param mu_o: negative mean for outliers (hyper-parameter) 95 | :param logvar_o: negative log-variance for outliers (hyper-parameter) 96 | :param reduce: type of reduce: 'sum', 'none' 97 | :param balance: balancing coefficient between posterior and prior 98 | :return: kld 99 | """ 100 | if not isinstance(mu_o, torch.Tensor): 101 | mu_o = torch.tensor(mu_o).to(mu.device) 102 | if not isinstance(logvar_o, torch.Tensor): 103 | logvar_o = torch.tensor(logvar_o).to(mu.device) 104 | if balance == 0.5: 105 | kl = -0.5 * (1 + logvar - logvar_o - logvar.exp() / torch.exp(logvar_o) - (mu - mu_o).pow(2) / torch.exp( 106 | logvar_o)).sum(1) 107 | else: 108 | # detach post 109 | mu_post = mu.detach() 110 | logvar_post = logvar.detach() 111 | mu_prior = mu_o 112 | logvar_prior = logvar_o 113 | kl_a = -0.5 * (1 + logvar_post - logvar_prior - logvar_post.exp() / torch.exp(logvar_prior) - ( 114 | mu_post - mu_prior).pow(2) / torch.exp(logvar_prior)).sum(1) 115 | # detach prior 116 | mu_post = mu 117 | logvar_post = logvar 118 | mu_prior = mu_o.detach() 119 | logvar_prior = logvar_o.detach() 120 | kl_b = -0.5 * (1 + logvar_post - logvar_prior - logvar_post.exp() / torch.exp(logvar_prior) - ( 121 | mu_post - mu_prior).pow(2) / torch.exp(logvar_prior)).sum(1) 122 | kl = (1 - balance) * kl_a + balance * kl_b 123 | if reduce == 'sum': 124 | kl = torch.sum(kl) 125 | elif reduce == 'mean': 126 | kl = torch.mean(kl) 127 | return kl 128 | 129 | 130 | def calc_kl_bern(post_prob, prior_prob, eps=1e-15, reduce='none'): 131 | """ 132 | Compute kl divergence of Bernoulli variable 133 | :param post_prob [batch_size, 1], in [0,1] 134 | :param prior_prob [batch_size, 1], in [0,1] 135 | :return: kl divergence, (B, ...) 136 | """ 137 | kl = post_prob * (torch.log(post_prob + eps) - torch.log(prior_prob + eps)) + (1 - post_prob) * ( 138 | torch.log(1 - post_prob + eps) - torch.log(1 - prior_prob + eps)) 139 | if reduce == 'sum': 140 | kl = kl.sum() 141 | elif reduce == 'mean': 142 | kl = kl.mean() 143 | else: 144 | kl = kl.squeeze(-1) 145 | return kl 146 | 147 | 148 | def log_beta_function(alpha, beta, eps=1e-5): 149 | """ 150 | B(alpha, beta) = gamma(alpha) * gamma(beta) / gamma(alpha + beta) 151 | logB = loggamma(alpha) + loggamma(beta) - loggamaa(alpha + beta) 152 | """ 153 | # return torch.special.gammaln(alpha) + torch.special.gammaln(beta) - torch.special.gammaln(alpha + beta) 154 | return torch.lgamma(alpha + eps) + torch.lgamma(beta + eps) - torch.lgamma(alpha + beta + eps) 155 | 156 | 157 | def calc_kl_beta_dist(alpha_post, beta_post, alpha_prior, beta_prior, reduce='none', eps=1e-5, balance=0.5): 158 | """ 159 | Compute kl divergence of Beta variable 160 | https://en.wikipedia.org/wiki/Beta_distribution 161 | :param alpha_post, beta_post [batch_size, 1] 162 | :param alpha_prior, beta_prior [batch_size, 1] 163 | :param balance kl balance between posterior and prior 164 | :return: kl divergence, (B, ...) 
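    Note: when balance != 0.5, the divergence is computed twice - once with the posterior parameters
    detached (kl_a) and once with the prior parameters detached (kl_b) - and mixed as
    (1 - balance) * kl_a + balance * kl_b (a balanced-KL variant of the same formula).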
165 | """ 166 | if balance == 0.5: 167 | log_bettas = log_beta_function(alpha_prior, beta_prior) - log_beta_function(alpha_post, beta_post) 168 | alpha = (alpha_post - alpha_prior) * torch.digamma(alpha_post + eps) 169 | beta = (beta_post - beta_prior) * torch.digamma(beta_post + eps) 170 | alpha_beta = (alpha_prior - alpha_post + beta_prior - beta_post) * torch.digamma(alpha_post + beta_post + eps) 171 | kl = log_bettas + alpha + beta + alpha_beta 172 | else: 173 | # detach post 174 | log_bettas = log_beta_function(alpha_prior, beta_prior) - log_beta_function(alpha_post.detach(), 175 | beta_post.detach()) 176 | alpha = (alpha_post - alpha_prior) * torch.digamma(alpha_post.detach() + eps) 177 | beta = (beta_post.detach() - beta_prior) * torch.digamma(beta_post.detach() + eps) 178 | alpha_beta = (alpha_prior - alpha_post.detach() + beta_prior - beta_post.detach()) * torch.digamma( 179 | alpha_post.detach() + beta_post.detach() + eps) 180 | kl_a = log_bettas + alpha + beta + alpha_beta 181 | 182 | # detach prior 183 | log_bettas = log_beta_function(alpha_prior.detach(), beta_prior.detach()) - log_beta_function(alpha_post, 184 | beta_post) 185 | alpha = (alpha_post - alpha_prior.detach()) * torch.digamma(alpha_post + eps) 186 | beta = (beta_post - beta_prior.detach()) * torch.digamma(beta_post + eps) 187 | alpha_beta = (alpha_prior.detach() - alpha_post + beta_prior.detach() - beta_post) * torch.digamma( 188 | alpha_post + beta_post + eps) 189 | kl_b = log_bettas + alpha + beta + alpha_beta 190 | kl = (1 - balance) * kl_a + balance * kl_b 191 | if reduce == 'sum': 192 | kl = kl.sum() 193 | elif reduce == 'mean': 194 | kl = kl.mean() 195 | else: 196 | kl = kl.squeeze(-1) 197 | return kl 198 | 199 | 200 | # classes 201 | class ChamferLossKL(nn.Module): 202 | """ 203 | Calculates the KL-divergence between two sets of (R.V.) particle coordinates. 
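    For each particle in one set, the closest particle (lowest pairwise KL) in the other set is found, and the
    two directional sums are added, i.e., a Chamfer distance whose pairwise cost is a KL term; an optional
    posterior_mask can zero out the contribution of selected posterior particles.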
204 | """ 205 | 206 | def __init__(self, use_reverse_kl=False): 207 | super(ChamferLossKL, self).__init__() 208 | self.use_reverse_kl = use_reverse_kl 209 | 210 | def forward(self, mu_preds, logvar_preds, mu_gts, logvar_gts, posterior_mask=None): 211 | """ 212 | mu_preds, logvar_preds: [bs, n_x, feat_dim] 213 | mu_gts, logvar_gts: [bs, n_y, feat_dim] 214 | posterior_mask: [bs, n_x] 215 | """ 216 | p_kl = batch_pairwise_kl(mu_preds, logvar_preds, mu_gts, logvar_gts, reverse_kl=False) 217 | # [bs, n_x, n_y] 218 | if self.use_reverse_kl: 219 | p_rkl = batch_pairwise_kl(mu_preds, logvar_preds, mu_gts, logvar_gts, reverse_kl=True) 220 | p_kl = 0.5 * (p_kl + p_rkl.transpose(2, 1)) 221 | mins, _ = torch.min(p_kl, 1) # [bs, n_y] 222 | loss_1 = torch.sum(mins, 1) 223 | mins, _ = torch.min(p_kl, 2) # [bs, n_x] 224 | if posterior_mask is not None: 225 | mins = mins * posterior_mask 226 | loss_2 = torch.sum(mins, 1) 227 | return loss_1 + loss_2 228 | 229 | 230 | class NetVGGFeatures(nn.Module): 231 | 232 | def __init__(self, layer_ids): 233 | super().__init__() 234 | 235 | self.vggnet = models.vgg16(pretrained=True) 236 | self.vggnet.eval() 237 | self.vggnet.requires_grad_(False) 238 | self.layer_ids = layer_ids 239 | 240 | def forward(self, x): 241 | output = [] 242 | for i in range(self.layer_ids[-1] + 1): 243 | x = self.vggnet.features[i](x) 244 | 245 | if i in self.layer_ids: 246 | output.append(x) 247 | 248 | return output 249 | 250 | 251 | class VGGDistance(nn.Module): 252 | 253 | def __init__(self, layer_ids=(2, 7, 12, 21, 30), accumulate_mode='sum', device=torch.device("cpu"), 254 | normalize=True, use_loss_scale=False, vgg_coeff=0.12151): 255 | super().__init__() 256 | 257 | self.vgg = NetVGGFeatures(layer_ids).to(device) 258 | self.layer_ids = layer_ids 259 | self.accumulate_mode = accumulate_mode 260 | self.device = device 261 | self.use_normalization = normalize 262 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 263 | std=[0.229, 0.224, 0.225]) 264 | self.use_loss_scale = use_loss_scale 265 | self.vgg_coeff = vgg_coeff 266 | 267 | def forward(self, I1, I2, reduction='sum', only_image=False): 268 | b_sz = I1.size(0) 269 | num_ch = I1.size(1) 270 | 271 | if self.accumulate_mode == 'sum': 272 | loss = ((I1 - I2) ** 2).view(b_sz, -1).sum(1) 273 | # if normalized, effectively: (1 / (std ** 2)) * (I_1 - I_2) ** 2 274 | elif self.accumulate_mode == 'ch_mean': 275 | loss = ((I1 - I2) ** 2).view(b_sz, I1.shape[1], -1).mean(1).sum(-1) 276 | else: 277 | loss = ((I1 - I2) ** 2).view(b_sz, -1).mean(1) 278 | 279 | if self.use_normalization: 280 | I1, I2 = self.normalize(I1), self.normalize(I2) 281 | 282 | if num_ch == 1: 283 | I1 = I1.repeat(1, 3, 1, 1) 284 | I2 = I2.repeat(1, 3, 1, 1) 285 | 286 | f1 = self.vgg(I1) 287 | f2 = self.vgg(I2) 288 | 289 | if not only_image: 290 | for i in range(len(self.layer_ids)): 291 | if self.accumulate_mode == 'sum': 292 | layer_loss = ((f1[i] - f2[i]) ** 2).view(b_sz, -1).sum(1) 293 | elif self.accumulate_mode == 'ch_mean': 294 | layer_loss = ((f1[i] - f2[i]) ** 2).view(b_sz, f1[i].shape[1], -1).mean(1).sum(-1) 295 | else: 296 | layer_loss = ((f1[i] - f2[i]) ** 2).view(b_sz, -1).mean(1) 297 | c = self.vgg_coeff if self.use_normalization else 1.0 298 | loss = loss + c * layer_loss 299 | 300 | if self.use_loss_scale: 301 | # by using `sum` for the features, and using scaling instead of `mean` we maintain the weight 302 | # of each dimension contribution to the loss 303 | max_dim = max([np.product(f.shape[1:]) for f in f1]) 304 | scale = 1 / max_dim 
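            # note: max_dim is taken over the flattened VGG feature maps only, so the resulting scale
            # depends on the input resolution (larger inputs yield a smaller scale)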
305 | loss = scale * loss 306 | if reduction == 'mean': 307 | return loss.mean() 308 | elif reduction == 'sum': 309 | return loss.sum() 310 | else: 311 | return loss 312 | 313 | def get_dimensions(self, device=torch.device("cpu")): 314 | dims = [] 315 | dummy_input = torch.zeros(1, 3, 128, 128).to(device) 316 | dims.append(dummy_input.view(1, -1).size(1)) 317 | f = self.vgg(dummy_input) 318 | for i in range(len(self.layer_ids)): 319 | dims.append(f[i].view(1, -1).size(1)) 320 | return dims 321 | 322 | 323 | class ChamferLoss(nn.Module): 324 | 325 | def __init__(self): 326 | super(ChamferLoss, self).__init__() 327 | # self.use_cuda = torch.cuda.is_available() 328 | 329 | def forward(self, preds, gts): 330 | P = self.batch_pairwise_dist(gts, preds) 331 | mins, _ = torch.min(P, 1) 332 | loss_1 = torch.sum(mins, 1) 333 | mins, _ = torch.min(P, 2) 334 | loss_2 = torch.sum(mins, 1) 335 | return loss_1 + loss_2 336 | 337 | def batch_pairwise_dist(self, x, y): 338 | bs, num_points_x, points_dim = x.size() 339 | _, num_points_y, _ = y.size() 340 | xx = torch.bmm(x, x.transpose(2, 1)) 341 | yy = torch.bmm(y, y.transpose(2, 1)) 342 | zz = torch.bmm(x, y.transpose(2, 1)) 343 | diag_ind_x = torch.arange(0, num_points_x, device=x.device, dtype=torch.long) 344 | diag_ind_y = torch.arange(0, num_points_y, device=y.device, dtype=torch.long) 345 | rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as( 346 | zz.transpose(2, 1)) 347 | ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz) 348 | P = rx.transpose(2, 1) + ry - 2 * zz 349 | return P 350 | 351 | 352 | if __name__ == '__main__': 353 | bs = 32 354 | n_points_x = 10 355 | n_points_y = 15 356 | dim = 8 357 | x = torch.randn(bs, n_points_x, dim) 358 | y = torch.randn(bs, n_points_y, dim) 359 | for metric in ['cosine', 'l1', 'l2', 'l2_simple']: 360 | P = batch_pairwise_dist(x, y, metric) 361 | print(f'metric: {metric}, P: {P.shape}, max: {P.max()}, min: {P.min()}') 362 | -------------------------------------------------------------------------------- /train_diffuse_ddlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main training function of DiffuseDDLP 3 | """ 4 | 5 | import argparse 6 | import os 7 | import shutil 8 | import json 9 | from utils.util_func import prepare_logdir, get_config 10 | from tqdm.auto import tqdm 11 | 12 | # torch 13 | import torch 14 | from torch.utils.data import DataLoader 15 | 16 | # datasets 17 | from datasets.get_dataset import get_video_dataset 18 | 19 | # models 20 | from modules.diffusion_modules import TrainerDiffuseDDLP, GaussianDiffusionPINT, PINTDenoiser 21 | from models import ObjectDynamicsDLP 22 | 23 | """ 24 | Particle Normalization 25 | Calculate and save the latent statistics of particles for normalization/standardization purposes. 26 | Denoisers' input is usually normalized, thus, we need to calculate the statistics of the particles. 
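The statistics (per-dimension mean/std and min/max of the foreground and background particle latents) are
computed once over the training set with the pre-trained DDLP encoder and cached as `latent_stats.pth` inside
the DDLP directory, so subsequent runs simply reload them.

A minimal usage sketch (`config` is the diffusion config loaded via `get_config`, assumed to contain keys such as
'ddlp_dir', 'ddlp_ckpt', 'ds', 'ds_root', 'device', 'diffuse_frames' and 'particle_norm'):

    normalizer = ParticleNormalization(config, mode=config['particle_norm'])
    z_norm, z_bg_norm = normalizer.normalize(z, z_bg)  # maps to [-1, 1] (minmax) or standardizes (std)
    z_rec, z_bg_rec = normalizer.unnormalize(z_norm, z_bg_norm)  # assumed to mirror normalize()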
27 | """ 28 | 29 | 30 | class ParticleNormalization(torch.nn.Module): 31 | def __init__(self, config, mode='minmax', eps=1e-5): 32 | super().__init__() 33 | assert mode in ["minmax", "std"], f'mode: {mode} not supported' 34 | self.diffusion_config = config 35 | self.root = config['ddlp_dir'] 36 | self.eps = eps 37 | self.ds = config['ds'] 38 | device = config['device'] 39 | if 'cuda' in device: 40 | device = torch.device(f'{device}' if torch.cuda.is_available() else 'cpu') 41 | else: 42 | device = torch.device('cpu') 43 | self.device = device 44 | self.mode = mode 45 | self.ddlp_dir = config['ddlp_dir'] 46 | self.ddlp_ckpt = config['ddlp_ckpt'] 47 | ddlp_conf = os.path.join(self.ddlp_dir, 'hparams.json') 48 | ddlp_config = get_config(ddlp_conf) 49 | self.config = ddlp_config 50 | self.particle_feature_dim = self.config['learned_feature_dim'] 51 | self.fg_total_dim = 2 + 2 + 2 + self.particle_feature_dim # (x, y), (scale_x, scale_y), depth, transparency 52 | self.bg_total_dim = self.particle_feature_dim 53 | mu = torch.zeros(self.fg_total_dim) 54 | self.register_buffer('mu', mu) 55 | mu_bg = torch.zeros(self.bg_total_dim) 56 | self.register_buffer('mu_bg', mu_bg) 57 | std = torch.ones(self.fg_total_dim) 58 | self.register_buffer('std', std) 59 | std_bg = torch.ones(self.bg_total_dim) 60 | self.register_buffer('std_bg', std_bg) 61 | min_val = torch.zeros(self.fg_total_dim) 62 | self.register_buffer('min_val', min_val) 63 | max_val = torch.zeros(self.fg_total_dim) 64 | self.register_buffer('max_val', max_val) 65 | min_val_bg = torch.zeros(self.bg_total_dim) 66 | self.register_buffer('min_val_bg', min_val_bg) 67 | max_val_bg = torch.zeros(self.bg_total_dim) 68 | self.register_buffer('max_val_bg', max_val_bg) 69 | # get statistics 70 | self.get_latent_statistics() 71 | print(f'mu: {self.mu}, std: {self.std}, min: {self.min_val}, max: {self.max_val}') 72 | 73 | def get_latent_statistics(self): 74 | stats_path = os.path.join(self.root, 'latent_stats.pth') 75 | if os.path.exists(stats_path): 76 | params = torch.load(stats_path) 77 | self.load_state_dict(params) 78 | print(f'latent stats loaded from {stats_path}') 79 | else: 80 | # calculate stats 81 | print(f'latent stats not found, calculating stats...') 82 | self.calc_latent_stats() 83 | 84 | def calc_latent_stats(self, ): 85 | # load model 86 | ddlp_config = self.config 87 | ddlp_ckpt = self.ddlp_ckpt 88 | device = self.device 89 | # load model 90 | image_size = ddlp_config['image_size'] 91 | ch = ddlp_config['ch'] 92 | enc_channels = ddlp_config['enc_channels'] 93 | prior_channels = ddlp_config['prior_channels'] 94 | use_correlation_heatmaps = ddlp_config['use_correlation_heatmaps'] 95 | enable_enc_attn = ddlp_config['enable_enc_attn'] 96 | filtering_heuristic = ddlp_config['filtering_heuristic'] 97 | 98 | model = ObjectDynamicsDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 99 | image_size=image_size, n_kp=ddlp_config['n_kp'], 100 | learned_feature_dim=ddlp_config['learned_feature_dim'], 101 | pad_mode=ddlp_config['pad_mode'], 102 | sigma=ddlp_config['sigma'], 103 | dropout=ddlp_config['dropout'], patch_size=ddlp_config['patch_size'], 104 | n_kp_enc=ddlp_config['n_kp_enc'], 105 | n_kp_prior=ddlp_config['n_kp_prior'], kp_range=ddlp_config['kp_range'], 106 | kp_activation=ddlp_config['kp_activation'], 107 | anchor_s=ddlp_config['anchor_s'], 108 | use_resblock=ddlp_config['use_resblock'], 109 | timestep_horizon=ddlp_config['timestep_horizon'], 110 | predict_delta=ddlp_config['predict_delta'], 111 | 
scale_std=ddlp_config['scale_std'], 112 | offset_std=ddlp_config['offset_std'], obj_on_alpha=ddlp_config['obj_on_alpha'], 113 | obj_on_beta=ddlp_config['obj_on_beta'], pint_heads=ddlp_config['pint_heads'], 114 | pint_layers=ddlp_config['pint_layers'], pint_dim=ddlp_config['pint_dim'], 115 | use_correlation_heatmaps=use_correlation_heatmaps, 116 | enable_enc_attn=enable_enc_attn, filtering_heuristic=filtering_heuristic).to(device) 117 | model.load_state_dict(torch.load(ddlp_ckpt, map_location=device)) 118 | model.eval() 119 | model.requires_grad_(False) 120 | print(f"loaded ddlp model from {ddlp_ckpt}") 121 | print(f"particle normalizer: loaded ddlp model from {ddlp_ckpt}") 122 | seq_len = 50 if self.ds == 'traffic' else 100 123 | ds = get_video_dataset(self.ds, root=self.diffusion_config['ds_root'], mode='train', seq_len=seq_len) 124 | dl = DataLoader(ds, batch_size=32, shuffle=False, pin_memory=True, num_workers=4) 125 | pbar = tqdm(iterable=dl) 126 | z_all = [] 127 | z_bg_all = [] 128 | for i, batch in enumerate(pbar): 129 | x = batch[0][:, :self.diffusion_config['diffuse_frames']].to(device) 130 | x_prior = x 131 | batch_size, timesteps, ch, h, w = x.shape 132 | fg_dict = model.fg_sequential_opt(x, deterministic=True, x_prior=x, reshape=True) 133 | # encoder 134 | z = fg_dict['z'] 135 | z_features = fg_dict['z_features'] 136 | z_obj_on = fg_dict['obj_on'] 137 | z_depth = fg_dict['z_depth'] 138 | z_scale = fg_dict['z_scale'] 139 | 140 | # decoder 141 | bg_mask = fg_dict['bg_mask'] 142 | 143 | x_in = x.view(-1, *x.shape[2:]) # [bs * T, ...] 144 | bg_dict = model.bg_module(x_in, bg_mask, deterministic=True) 145 | z_bg = bg_dict['z_bg'] 146 | z_kp_bg = bg_dict['z_kp'] 147 | 148 | # collect and pad 149 | z_fg = torch.cat([z, z_scale, z_depth, z_obj_on.unsqueeze(-1), z_features], dim=-1) 150 | # [batch_size * timesteps, n_kp, features] 151 | z_fg = z_fg.view(-1, *z_fg.shape[2:]) 152 | # [batch_size * timesteps * n_kp, features] 153 | z_all.append(z_fg.data.cpu()) 154 | z_bg_all.append(z_bg.data.cpu()) 155 | 156 | pbar.close() 157 | z_all = torch.cat(z_all, dim=0) 158 | z_bg_all = torch.cat(z_bg_all, dim=0) 159 | self.mu = z_all.mean(0) 160 | self.std = z_all.std(0) 161 | self.min_val = z_all.min(0)[0] 162 | self.max_val = z_all.max(0)[0] 163 | 164 | self.mu_bg = z_bg_all.mean(0) 165 | self.std_bg = z_bg_all.std(0) 166 | self.min_val_bg = z_bg_all.min(0)[0] 167 | self.max_val_bg = z_bg_all.max(0)[0] 168 | stats_path = os.path.join(self.root, 'latent_stats.pth') 169 | torch.save(self.state_dict(), stats_path) 170 | print(f'saved statistics @ {stats_path}') 171 | 172 | def normalize(self, z=None, z_bg=None): 173 | if self.mode == 'minmax': 174 | if z is not None: 175 | z = (z - self.min_val) / (self.max_val - self.min_val + self.eps) # [0, 1] 176 | z = 2 * z - 1 # [-1, 1] 177 | if z_bg is not None: 178 | z_bg = (z_bg - self.min_val_bg) / (self.max_val_bg - self.min_val_bg + self.eps) # [0, 1] 179 | z_bg = 2 * z_bg - 1 # [-1, 1] 180 | else: 181 | # std 182 | if z is not None: 183 | z = (z - self.mu) / (self.std + self.eps) 184 | if z_bg is not None: 185 | z_bg = (z_bg - self.mu_bg) / (self.std_bg + self.eps) 186 | 187 | return z, z_bg 188 | 189 | def unnormalize(self, z=None, z_bg=None): 190 | if self.mode == 'minmax': 191 | if z is not None: 192 | z = (z + 1) / 2 # [0, 1] 193 | z = z * (self.max_val - self.min_val + self.eps) + self.min_val 194 | if z_bg is not None: 195 | z_bg = (z_bg + 1) / 2 # [0, 1] 196 | z_bg = z_bg * (self.max_val_bg - self.min_val_bg + self.eps) + self.min_val_bg 197 | 
else: 198 | # std 199 | if z is not None: 200 | z = z * (self.std + self.eps) + self.mu 201 | if z_bg is not None: 202 | z_bg = z_bg * (self.std_bg + self.eps) + self.mu_bg 203 | 204 | return z, z_bg 205 | 206 | def forward(self, z=None, z_bg=None, normalize=True): 207 | if normalize: 208 | return self.normalize(z, z_bg) 209 | else: 210 | return self.unnormalize(z, z_bg) 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = argparse.ArgumentParser(description="Diffusion DDLP Trainer") 215 | parser.add_argument("-c", "--config", type=str, default='diffuse_ddlp', 216 | help="json file name of config file in './configs'") 217 | args = parser.parse_args() 218 | # parse input 219 | conf = args.config 220 | if conf.endswith('json'): 221 | conf_path = os.path.join('./configs', conf) 222 | else: 223 | conf_path = os.path.join('./configs', f'{conf}.json') 224 | diffusion_config = get_config(conf_path) 225 | ds = diffusion_config['ds'] 226 | ds_root = diffusion_config['ds_root'] # dataset root 227 | batch_size = diffusion_config['batch_size'] 228 | diffuse_frames = diffusion_config['diffuse_frames'] # number of particle frames to generate 229 | lr = diffusion_config['lr'] 230 | train_num_steps = diffusion_config['train_num_steps'] 231 | diffusion_num_steps = diffusion_config['diffusion_num_steps'] 232 | loss_type = diffusion_config['loss_type'] 233 | particle_norm = diffusion_config['particle_norm'] 234 | device = diffusion_config['device'] 235 | if 'cuda' in device: 236 | device = torch.device(f'{device}' if torch.cuda.is_available() else 'cpu') 237 | else: 238 | device = torch.device('cpu') 239 | """ 240 | load pre-trained DDLP 241 | """ 242 | ddlp_dir = diffusion_config['ddlp_dir'] 243 | ddlp_ckpt = diffusion_config['ddlp_ckpt'] 244 | ddlp_conf = os.path.join(ddlp_dir, 'hparams.json') 245 | ddlp_config = get_config(ddlp_conf) 246 | # load model 247 | image_size = ddlp_config['image_size'] 248 | ch = ddlp_config['ch'] 249 | enc_channels = ddlp_config['enc_channels'] 250 | prior_channels = ddlp_config['prior_channels'] 251 | use_correlation_heatmaps = ddlp_config['use_correlation_heatmaps'] 252 | enable_enc_attn = ddlp_config['enable_enc_attn'] 253 | filtering_heuristic = ddlp_config['filtering_heuristic'] 254 | animation_fps = ddlp_config["animation_fps"] 255 | 256 | model = ObjectDynamicsDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 257 | image_size=image_size, n_kp=ddlp_config['n_kp'], 258 | learned_feature_dim=ddlp_config['learned_feature_dim'], 259 | pad_mode=ddlp_config['pad_mode'], 260 | sigma=ddlp_config['sigma'], 261 | dropout=ddlp_config['dropout'], patch_size=ddlp_config['patch_size'], 262 | n_kp_enc=ddlp_config['n_kp_enc'], 263 | n_kp_prior=ddlp_config['n_kp_prior'], kp_range=ddlp_config['kp_range'], 264 | kp_activation=ddlp_config['kp_activation'], 265 | anchor_s=ddlp_config['anchor_s'], 266 | use_resblock=ddlp_config['use_resblock'], 267 | timestep_horizon=ddlp_config['timestep_horizon'], 268 | predict_delta=ddlp_config['predict_delta'], 269 | scale_std=ddlp_config['scale_std'], 270 | offset_std=ddlp_config['offset_std'], obj_on_alpha=ddlp_config['obj_on_alpha'], 271 | obj_on_beta=ddlp_config['obj_on_beta'], pint_heads=ddlp_config['pint_heads'], 272 | pint_layers=ddlp_config['pint_layers'], pint_dim=ddlp_config['pint_dim'], 273 | use_correlation_heatmaps=use_correlation_heatmaps, 274 | enable_enc_attn=enable_enc_attn, filtering_heuristic=filtering_heuristic).to(device) 275 | model.load_state_dict(torch.load(ddlp_ckpt, map_location=device)) 276 | 
model.eval() 277 | model.requires_grad_(False) 278 | print(f"loaded ddlp model from {ddlp_ckpt}") 279 | 280 | features_dim = 2 + 2 + 1 + 1 + ddlp_config['learned_feature_dim'] 281 | # features: xy, scale_xy, depth, obj_on, particle features 282 | # total particles: n_kp + 1 for bg 283 | ddpm_feat_dim = features_dim 284 | 285 | denoiser_model = PINTDenoiser(features_dim, hidden_dim=ddlp_config['pint_dim'], 286 | projection_dim=ddlp_config['pint_dim'], 287 | n_head=ddlp_config['pint_heads'], n_layer=ddlp_config['pint_layers'], 288 | block_size=diffuse_frames, dropout=0.1, 289 | predict_delta=False, positional_bias=True, max_particles=ddlp_config['n_kp_enc'] + 1, 290 | self_condition=False, 291 | learned_sinusoidal_cond=False, random_fourier_features=False, 292 | learned_sinusoidal_dim=16).to(device) 293 | 294 | diffusion = GaussianDiffusionPINT( 295 | denoiser_model, 296 | seq_length=diffuse_frames, 297 | timesteps=diffusion_num_steps, # number of steps 298 | sampling_timesteps=diffusion_num_steps, 299 | loss_type=loss_type, # L1 or L2 300 | objective='pred_x0', 301 | ).to(device) 302 | 303 | particle_normalizer = ParticleNormalization(diffusion_config, mode=particle_norm).to(device) 304 | result_dir = diffusion_config.get('result_dir') 305 | if result_dir is None: 306 | run_name = f'{ds}_diffuse_ddlp' 307 | result_dir = prepare_logdir(run_name, src_dir='./') 308 | diffusion_config['result_dir'] = result_dir 309 | 310 | # make copy of configs 311 | path_to_conf = os.path.join(result_dir, 'ddlp_hparams.json') 312 | with open(path_to_conf, "w") as outfile: 313 | json.dump(ddlp_config, outfile, indent=2) 314 | path_to_conf = os.path.join(result_dir, 'diffusion_hparams.json') 315 | with open(path_to_conf, "w") as outfile: 316 | json.dump(diffusion_config, outfile, indent=2) 317 | latent_stats_path = os.path.join(ddlp_dir, 'latent_stats.pth') # make a copy of latent stats just in case 318 | latent_stats_path_target = os.path.join(result_dir, 'latent_stats.pth') 319 | shutil.copy(latent_stats_path, latent_stats_path_target) 320 | 321 | # expects input: [batch_size, feature_dim, seq_len] 322 | 323 | trainer = TrainerDiffuseDDLP( 324 | diffusion, 325 | ddlp_model=model, 326 | diffusion_config=diffusion_config, 327 | particle_norm=particle_normalizer, 328 | train_batch_size=batch_size, 329 | train_lr=lr, 330 | train_num_steps=train_num_steps, # total training steps 331 | gradient_accumulate_every=1, # gradient accumulation steps 332 | ema_decay=0.995, # exponential moving average decay 333 | amp=False, # turn on mixed precision 334 | seq_len=diffuse_frames, 335 | save_and_sample_every=1000, 336 | results_folder=result_dir, animation_fps=animation_fps 337 | ) 338 | 339 | trainer.train() 340 | -------------------------------------------------------------------------------- /docs/hyperparameters.md: -------------------------------------------------------------------------------- 1 | # Hyperparameters 2 | We provide details for every adjustable hyperparameter in this project. 3 | 4 | This list might look long, but in practice very few hyperparameters need to be tuned, we just wanted to be thorough :) 5 | 6 | Important hyperparameters: 7 | 8 | 9 | | Hyperparameter | Description | Recommended/Legal Values | 10 | |------------------------|-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------| 11 | | `ds` | Dataset name | "obj3d128", "traffic", "phyre", etc. 
| 12 | | `root` | Root directory for dataset | Example: `/mnt/data/tal/obj3d/` | 13 | | `device` | Device to run the model on | "cuda", "cpu" | 14 | | `batch_size` | Number of samples in each mini-batch | Integer value | 15 | | `lr` | Learning rate | Floating-point value, `0.0002` | 16 | | `num_epochs` | Number of training epochs | Integer value | 17 | | `recon_loss_type` | Type of reconstruction loss | "vgg", "mse" | 18 | | `beta_kl` | Weight for the KL divergence loss | Floating-point value, recommended: 0.15 for "mse", 40.0 for "vgg" | 19 | | `patch_size` | Patch size for the prior keypoint proposals network | Integer value, recommended: [8 ,16, 32] | 20 | | `n_kp_enc` | Number of posterior keypoints to be learned | Integer value, we used: [10, 12, 25] | 21 | | `learned_feature_dim` | Dimension of latent visual features extracted from glimpses | Integer value, best: [4, 10] | 22 | | `n_kp_prior` | Number of keypoints to filter from the set of prior keypoints | Integer value, in practice, we don't filter (=number of prior patches) | 23 | | `anchor_s` | Glimpse size defined as a ratio of image size, effectively the posterior patch size (e.g., 0.25 for image_size=128 -> glimpse_size=32) | Floating-point value, best: 0.125 for `phyre`, 0.25 for all others | 24 | | `enc_channels` | Number of channels for the posterior CNN | List of integer values, (32, 32, 64, 64, 128, 128) | 25 | | `prior_channels` | Number of channels for the prior CNN | List of integer values (32, 32, 64) | 26 | | `timestep_horizon` | Number of timesteps to train DDLP on (DDLP) | Integer value, typical values: [10, 15, 20] | 27 | | `beta_dyn` | Weight for the KL dynamics loss (DDLP) | Floating-point value, recommended: =`beta_kl` | 28 | | `pint_dim` | Dimension of the transformer model (DDLP) | Integer value, best: [256, 512] | 29 | 30 | 31 | Full list: 32 | 33 | | Hyperparameter | Description | Recommended/Legal Values | 34 | |------------------------|-----------------------------------------------------------------------------------------------------------|-----------------------------------------------------| 35 | | `ds` | Dataset name | "obj3d128", "traffic", "phyre", etc. 
| 36 | | `root` | Root directory for dataset | Example: `/mnt/data/tal/obj3d/` | 37 | | `device` | Device to run the model on | "cuda", "cpu" | 38 | | `batch_size` | Number of samples in each mini-batch | Integer value | 39 | | `lr` | Learning rate | Floating-point value, `0.0002` | 40 | | `kp_activation` | Activation function for keypoints | **"tanh"** (use that), "sigmoid" | 41 | | `pad_mode` | Padding mode for the CNNs | **"replicate"** (best), "zeros " | 42 | | `load_model` | Flag to load pre-trained model | true, false | 43 | | `pretrained_path` | Path to the pre-trained model | String or null | 44 | | `num_epochs` | Number of training epochs | Integer value | 45 | | `n_kp` | Number of keypoints to extract from each patch | Integer value, recommended: 1 | 46 | | `recon_loss_type` | Type of reconstruction loss | "vgg", "mse" | 47 | | `sigma` | Prior standard deviation for the keypoints in Chamfer-KL | Floating-point value, unused (leave as is) | 48 | | `beta_kl` | Weight for the KL divergence loss | Floating-point value, recommended: 0.15 for "mse", 40.0 for "vgg" | 49 | | `beta_rec` | Weight for the reconstruction loss | Floating-point value, recommended: 1.0 | 50 | | `patch_size` | Patch size for the prior keypoint proposals network | Integer value, recommended: [8 ,16, 32] | 51 | | `topk` | Top-k value for plotting keypoints (used only for keypoints ploting) | Integer value, default: 10 | 52 | | `n_kp_enc` | Number of posterior keypoints to be learned | Integer value, we used: [10, 12, 25] | 53 | | `eval_epoch_freq` | Frequency of evaluation during training | Integer value, default: 1 | 54 | | `learned_feature_dim` | Dimension of latent visual features extracted from glimpses | Integer value, best: [4, 10] | 55 | | `bg_learned_feature_dim` | Dimension of latent visual features extracted from the background | Integer value, default: same as `learned_feature_dim` | 56 | | `n_kp_prior` | Number of keypoints to filter from the set of prior keypoints | Integer value, in practice, we don't filter (=number of prior patches) | 57 | | `weight_decay` | Weight decay for the optimizer | Floating-point value, default: 0.0 | 58 | | `kp_range` | Range of keypoints | **[-1, 1]** (use that), [0, 1] | 59 | | `warmup_epoch` | Number of warm-up epochs for DLP, where only the patch encoder/decoder is trained | Integer value, default: 1 | 60 | | `dropout` | Dropout rate for the CNNs | Floating-point value, default: 0.0 | 61 | | `iou_thresh` | IoU threshold for object bounding boxes (only for plotting) | Floating-point value, default: 0.2 | 62 | | `anchor_s` | Glimpse size defined as a ratio of image size, effectively the posterior patch size (e.g., 0.25 for image_size=128 -> glimpse_size=32) | Floating-point value, best: 0.125 for `phyre`, 0.25 for all others | 63 | | `kl_balance` | Balance parameter between attributes and visual appearance features for the KL divergence loss | Floating-point value, best: 0.001 | 64 | | `image_size` | Size of the input image | Integer value, e.g., 128 | 65 | | `ch` | Number of channels in the input image | Integer value (=3) | 66 | | `enc_channels` | Number of channels for the posterior CNN | List of integer values, (32, 32, 64, 64, 128, 128) | 67 | | `prior_channels` | Number of channels for the prior CNN | List of integer values (32, 32, 64) | 68 | | `timestep_horizon` | Number of timesteps to train DDLP on (DDLP) | Integer value, typical values: [10, 15, 20] | 69 | | `predict_delta` | Flag to predict the delta between consecutive frames instead of absolute coordinates 
(DDLP) | true (use that), false |
70 | | `beta_dyn` | Weight for the KL dynamics loss (DDLP) | Floating-point value, recommended: =`beta_kl` |
71 | | `scale_std` | Prior standard deviation for scale | Floating-point value, recommended: [0.3, 1.0] |
72 | | `offset_std` | Prior standard deviation for offset | Floating-point value, recommended: [0.2, 1.0] |
73 | | `obj_on_alpha` | Prior alpha (Beta distribution) for obj_on (transparency) | Floating-point value, recommended: 0.1 |
74 | | `obj_on_beta` | Prior beta (Beta distribution) for obj_on (transparency) | Floating-point value, recommended: 0.1 |
75 | | `beta_dyn_rec` | Weight for the dynamics reconstruction loss (DDLP) | Floating-point value, recommended: 1.0 |
76 | | `num_static_frames` | Number of static frames (="burn-in frames") in the sequence whose KL is optimized w.r.t. a constant prior (DDLP) | Integer value, best: 4 |
77 | | `pint_layers` | Number of transformer layers in the dynamics module (DDLP) | Integer value, best: 6 |
78 | | `pint_heads` | Number of transformer heads in the dynamics module (DDLP) | Integer value, best: 8 |
79 | | `pint_dim` | Dimension of the transformer model (DDLP) | Integer value, best: [256, 512] |
80 | | `run_prefix` | Prefix for the run directory name | String |
81 | | `animation_horizon` | Number of frames to animate into the future, only at inference time (DDLP) | Integer value, default: 50 |
82 | | `eval_im_metrics` | Flag to enable evaluation of image metrics | true, false |
83 | | `use_resblock` | Flag to use residual blocks in the CNNs | true, **false** (use that) |
84 | | `scheduler_gamma` | Learning rate scheduler gamma parameter | Floating-point value, default: 0.95 |
85 | | `adam_betas` | Beta values for the Adam optimizer | List of floating-point values, default: (0.9, 0.999) |
86 | | `adam_eps` | Epsilon value for the Adam optimizer | Floating-point value, default: 0.0001 |
87 | | `train_enc_prior` | Flag to train the encoder prior model | **true**, false |
88 | | `start_dyn_epoch` | Epoch at which to start training the dynamics model (DDLP) | Integer value, default: 0 |
89 | | `cond_steps` | Number of consecutive steps to condition on in the dynamics model at inference time (DDLP) | Integer value |
90 | | `animation_fps` | Frames per second for generating animations (DDLP) | Floating-point value, default: 0.06 |
91 | | `use_correlation_heatmaps` | Flag to use correlation heatmaps for tracking | **true**, false |
92 | | `enable_enc_attn` | Flag to enable attention between patches in the particle encoder | true, **false** |
93 | | `filtering_heuristic` | Filtering heuristic to filter posterior anchors from prior keypoints | "distance", **"variance"**, "random", "none" |
94 | | `use_tracking` | Flag to enable object tracking | true (DDLP), false (DLP) |
95 |
96 |
--------------------------------------------------------------------------------
/configs/generate_config_file.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 |
5 | def save_config(src_dir, fname, hparams):
6 |     path_to_conf = os.path.join(src_dir, fname)
7 |     with open(path_to_conf, "w") as outfile:
8 |         json.dump(hparams, outfile, indent=2)
9 |
10 |
11 | def gen_conf_file(ds, fname='default.json'):
12 |     device = 'cuda'
13 |     lr = 2e-4
14 |     batch_size = 32
15 |     num_epochs = 150
16 |     load_model = False
17 |     pretrained_path = None
18 |     eval_epoch_freq = 1
19 |     n_kp = 1  # num kp per patch
20 |     iou_thresh = 0.2
21 |     kp_range = (-1, 1)
22 |     weight_decay = 
0.0 23 | run_prefix = "" 24 | pad_mode = 'replicate' 25 | sigma = 1.0 # default sigma for the gaussian maps 26 | dropout = 0.0 27 | kp_activation = "tanh" 28 | ch = 3 # image channels 29 | topk = 5 # top-k particles to plot 30 | use_resblock = False 31 | adam_betas = (0.9, 0.999) 32 | adam_eps = 1e-4 33 | # filtering heuristic to filter prior keypoints 34 | filtering_heuristic = 'variance' # ['distance', 'variance', 'random', 'none'] 35 | 36 | timestep_horizon = 10 37 | predict_delta = True 38 | use_tracking = True 39 | use_correlation_heatmaps = True # use correlation heatmap between patches for tracking 40 | enable_enc_attn = False # rnable attention between patches in the particle encoder 41 | 42 | beta_kl = 0.1 43 | beta_dyn = 0.1 44 | beta_rec = 1.0 45 | beta_dyn_rec = 1.0 46 | kl_balance = 0.001 47 | 48 | num_static_frames = 4 49 | 50 | pint_layers = 6 # transformer layers in the dynamics module 51 | pint_dim = 256 52 | pint_heads = 8 53 | 54 | # priors 55 | scale_std = 0.3 56 | offset_std = 0.2 57 | obj_on_alpha = 0.1 58 | obj_on_beta = 0.1 59 | 60 | animation_horizon = 100 61 | eval_im_metrics = True 62 | scheduler_gamma = 0.95 63 | train_enc_prior = True # train the SSM prior or leave random 64 | start_dyn_epoch = 0 # epoch from which to start training the dynamics module 65 | cond_steps = 10 # conditional steps for the dynamics module during inference 66 | animation_fps = 3 / 50 67 | 68 | if ds == 'traffic': 69 | beta_kl = 40.0 70 | beta_dyn = 40.0 71 | beta_rec = 1.0 72 | beta_dyn_rec = 1.0 73 | # n_kp_enc = 16 # total kp to output from the encoder / filter from prior 74 | n_kp_enc = 25 75 | # n_kp_prior = 20 76 | n_kp_prior = 30 77 | patch_size = 16 78 | learned_feature_dim = 8 # additional features than x,y for each kp 79 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 80 | recon_loss_type = "vgg" 81 | # warmup_epoch = 1 82 | warmup_epoch = 3 83 | anchor_s = 0.25 84 | kl_balance = 0.001 85 | exclusive_patches = False 86 | # batch_size = 2 87 | batch_size = 8 # a100 88 | eval_epoch_freq = 1 89 | # timestep_horizon = 20 90 | timestep_horizon = 10 91 | lr = 2e-4 92 | # lr = 5e-4 93 | image_size = 128 94 | enc_channels = [32, 64, 128, 256] 95 | prior_channels = (16, 32, 64) 96 | # root = '/mnt/data/tal/traffic_dataset/img128np_fs3.npy' 97 | root = '/home/tal/data/traffic/img128np.npy' 98 | elif ds == 'clevrer': 99 | beta_kl = 100.0 100 | beta_dyn = 100.0 101 | beta_rec = 1.0 102 | beta_dyn_rec = 1.0 103 | # n_kp_enc = 10 # total kp to output from the encoder / filter from prior 104 | n_kp_enc = 12 # total kp to output from the encoder / filter from prior 105 | # n_kp_prior = 20 106 | n_kp_prior = 16 # orig 107 | # n_kp_prior = 10 108 | patch_size = 16 109 | learned_feature_dim = 8 # additional features than x,y for each kp 110 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 111 | recon_loss_type = "vgg" 112 | # recon_loss_type = "mse" 113 | warmup_epoch = 1 114 | # warmup_epoch = 0 115 | anchor_s = 0.25 116 | kl_balance = 0.001 117 | # batch_size = 4 118 | batch_size = 8 119 | eval_epoch_freq = 1 120 | timestep_horizon = 20 121 | lr = 2e-4 122 | image_size = 128 123 | # image_size = 64 124 | enc_channels = [32, 64, 128, 256] # 128x128 125 | # enc_channels = [32, 64, 128] # 64x64 126 | prior_channels = (16, 32, 64) 127 | # root = '/mnt/data/tal/clevrer_ep/' 128 | root = '/datadrive/clevrer/' 129 | elif ds == 'balls': 130 | beta_kl = 0.1 # original 0.05 131 | beta_dyn = 0.1 # original 0.1 132 | beta_rec = 1.0 133 | beta_dyn_rec = 1.0 134 | 
# beta_rec = 1 / 11 135 | # n_kp_enc = 3 # total kp to output from the encoder / filter from prior 136 | n_kp_enc = 6 137 | n_kp_prior = 12 # original 138 | # n_kp_prior = 15 139 | patch_size = 8 # original 140 | # patch_size = 16 141 | learned_feature_dim = 3 # additional features than x,y for each kp 142 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 143 | recon_loss_type = "mse" 144 | warmup_epoch = 1 145 | # warmup_epoch = -2 146 | anchor_s = 0.25 147 | kl_balance = 0.001 148 | # override manually 149 | # lr = 2e-4 150 | batch_size = 32 151 | # batch_size = 20 152 | eval_epoch_freq = 1 153 | predict_delta = True 154 | image_size = 64 155 | ch = 3 156 | enc_channels = (32, 64, 128) 157 | prior_channels = (16, 32, 64) 158 | root = '/mnt/data/tal/gswm_balls/BALLS_INTERACTION' 159 | elif ds == 'bair': 160 | beta_kl = 40.0 161 | beta_dyn = 800.0 162 | beta_rec = 1.0 163 | beta_dyn_rec = 1.0 164 | # n_kp_enc = 3 # total kp to output from the encoder / filter from prior 165 | n_kp_enc = 15 166 | n_kp_prior = 20 # original 167 | # n_kp_prior = 15 168 | patch_size = 8 # original 169 | # patch_size = 16 170 | learned_feature_dim = 10 # additional features than x,y for each kp 171 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 172 | recon_loss_type = "vgg" 173 | warmup_epoch = 1 174 | # warmup_epoch = 0 175 | anchor_s = 0.25 176 | kl_balance = 0.001 177 | # override manually 178 | lr = 2e-4 179 | batch_size = 8 180 | eval_epoch_freq = 1 181 | timestep_horizon = 15 182 | image_size = 64 183 | enc_channels = (32, 64, 128) 184 | prior_channels = (16, 32, 64) 185 | root = '/mnt/data/tal/bair/processed/' 186 | # root = '/media/newhd/data/bair/processed/' 187 | # dataset = BAIRDataset(root=root, train=True, horizon=timestep_horizon + 1) 188 | animation_horizon = 16 189 | cond_steps = 1 190 | elif ds == 'obj3d': 191 | # mse: 192 | # beta_kl = 0.01 # original: 0.05, worked good: 0.01 193 | # beta_dyn = 0.01 194 | # beta_rec = 1.0 195 | # vgg: 196 | beta_kl = 30.0 197 | beta_dyn = 30.0 198 | beta_rec = 1.0 199 | beta_dyn_rec = 1.0 200 | # n_kp_enc = 3 # total kp to output from the encoder / filter from prior 201 | n_kp_enc = 8 202 | n_kp_prior = 16 # original 203 | # n_kp_prior = 15 204 | patch_size = 8 # original 205 | # patch_size = 16 206 | learned_feature_dim = 10 # additional features than x,y for each kp 207 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 208 | # recon_loss_type = "mse" 209 | recon_loss_type = "vgg" 210 | warmup_epoch = 1 211 | # warmup_epoch = 0 212 | anchor_s = 0.25 213 | kl_balance = 0.001 214 | # override manually 215 | lr = 2e-4 216 | # batch_size = 20 # mse 217 | batch_size = 10 # vgg 218 | eval_epoch_freq = 1 219 | image_size = 64 220 | enc_channels = (32, 64, 128) 221 | prior_channels = (16, 32, 64) 222 | root = '/mnt/data/tal/obj3d/' 223 | elif ds == 'obj3d128': 224 | # mse: 225 | # beta_kl = 0.1 # original: 0.05, worked good: 0.01 226 | # beta_dyn = 0.1 227 | # beta_rec = 1.0 228 | # vgg: 229 | beta_kl = 100.0 230 | beta_dyn = 100.0 231 | beta_rec = 1.0 232 | beta_dyn_rec = 1.0 233 | 234 | # recon_loss_type = "mse" 235 | recon_loss_type = "vgg" 236 | 237 | # beta_rec = 1 / 11 238 | # n_kp_enc = 3 # total kp to output from the encoder / filter from prior 239 | n_kp_enc = 12 240 | n_kp_prior = 16 # original 241 | # n_kp_prior = 64 242 | # patch_size = 8 # original 243 | patch_size = 16 244 | # learned_feature_dim = 10 # additional features than x,y for each kp 245 | learned_feature_dim = 8 246 | topk = 
min(10, n_kp_enc) # display top-10 kp with smallest variance 247 | 248 | warmup_epoch = 1 249 | # warmup_epoch = 0 250 | anchor_s = 0.25 251 | kl_balance = 0.001 252 | # override manually 253 | lr = 2e-4 254 | # batch_size = 6 # mse 255 | batch_size = 4 # vgg 256 | eval_epoch_freq = 1 257 | timestep_horizon = 10 258 | # sigma = 1.0 # deterministic chamfer 259 | image_size = 128 260 | enc_channels = [32, 64, 128, 256] # 128x128 261 | prior_channels = (16, 32, 64) 262 | root = '/mnt/data/tal/obj3d/' 263 | elif ds == 'sketchy': 264 | # mse: 265 | # beta_kl = 0.01 # original: 0.05, worked good: 0.01 266 | # beta_dyn = 0.01 267 | # beta_rec = 1.0 268 | # vgg: 269 | beta_kl = 40.0 270 | beta_dyn = 40.0 271 | beta_rec = 1.0 272 | beta_dyn_rec = 1.0 273 | # n_kp_enc = 3 # total kp to output from the encoder / filter from prior 274 | n_kp_enc = 8 275 | n_kp_prior = 16 # original 276 | # n_kp_prior = 15 277 | # patch_size = 8 # original 278 | patch_size = 16 279 | learned_feature_dim = 10 # additional features than x,y for each kp 280 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 281 | # recon_loss_type = "mse" 282 | recon_loss_type = "vgg" 283 | warmup_epoch = 1 284 | # warmup_epoch = 0 285 | anchor_s = 0.25 286 | kl_balance = 0.001 287 | # override manually 288 | lr = 2e-4 289 | # batch_size = 20 # mse 290 | batch_size = 4 # vgg 291 | eval_epoch_freq = 1 292 | sigma = 1.0 # deterministic chamfer 293 | image_size = 128 294 | enc_channels = [32, 64, 128, 256] # 128x128 295 | prior_channels = (16, 32, 64) 296 | root = '/mnt/data/tal/sketchy/' 297 | elif ds == 'phyre': 298 | # mse: 299 | beta_kl = 0.15 # orig: 0.1 300 | beta_dyn = 0.15 # prig: 0.1 301 | beta_rec = 1.0 302 | beta_dyn_rec = 1.0 303 | # vgg: 304 | # beta_kl = 20.0 305 | # beta_dyn = 40.0 306 | # beta_rec = 1.0 307 | # beta_dyn_rec = 0.1 # 0.1 308 | n_kp_enc = 25 # anchor_s:0.125 309 | n_kp_prior = 30 # anchor_s:0.125 310 | # n_kp_enc = 15 # anchor_s:0.25 311 | # n_kp_prior = 20 # anchor_s:0.25 312 | # n_kp_prior = 16 # original 313 | # patch_size = 8 # original 314 | patch_size = 16 # TODO: 8? 
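    # with the phyre settings below (image_size=128, anchor_s=0.125), the posterior glimpse is
    # 128 * 0.125 = 16 pixels per side, i.e., the same size as the prior patch_size chosen here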
315 | learned_feature_dim = 4 # additional features than x,y for each kp 316 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 317 | recon_loss_type = "mse" 318 | # recon_loss_type = "vgg" 319 | warmup_epoch = 1 320 | # warmup_epoch = -2 321 | anchor_s = 0.125 322 | # anchor_s = 0.25 323 | kl_balance = 0.001 324 | # override manually 325 | # lr = 2e-4 326 | # batch_size = 6 # mse, t10 327 | batch_size = 3 # mse, t20 328 | # batch_size = 4 # vgg 64 329 | # batch_size = 2 # vgg 128 330 | eval_epoch_freq = 1 331 | sigma = 1.0 # deterministic chamfer 332 | predict_delta = True 333 | # timestep_horizon = 10 334 | timestep_horizon = 15 335 | pint_dim = 512 336 | image_size = 128 337 | # image_size = 64 338 | enc_channels = [32, 64, 128, 256] # 128x128 339 | # enc_channels = [32, 64, 128] # 64x64 340 | prior_channels = (16, 32, 64) 341 | root = '/mnt/data/tal/phyre/' 342 | # root = '/media/newhd/data/phyre/' 343 | animation_fps = 2.5 / 50 344 | elif ds == 'mario': 345 | # mse: 346 | # beta_kl = 0.01 # original: 0.05, worked good: 0.01 347 | # beta_dyn = 0.01 348 | # beta_rec = 1.0 349 | # vgg: 350 | beta_kl = 80.0 351 | beta_dyn = 80.0 352 | beta_rec = 1.0 353 | beta_dyn_rec = 1.0 354 | # n_kp_enc = 3 # total kp to output from the encoder / filter from prior 355 | n_kp_enc = 15 356 | n_kp_prior = 20 # original 357 | # n_kp_prior = 15 358 | # patch_size = 8 # original 359 | patch_size = 16 360 | learned_feature_dim = 10 # additional features than x,y for each kp 361 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 362 | # recon_loss_type = "mse" 363 | recon_loss_type = "vgg" 364 | warmup_epoch = 1 365 | # warmup_epoch = 0 366 | anchor_s = 0.25 367 | kl_balance = 0.001 368 | # override manually 369 | lr = 2e-4 370 | # batch_size = 20 # mse 371 | batch_size = 4 # vgg 372 | eval_epoch_freq = 1 373 | image_size = 128 374 | enc_channels = [32, 64, 128, 256] 375 | prior_channels = (16, 32, 64) 376 | root = '/media/newhd/data/mario/' 377 | elif ds == 'shapes': 378 | beta_kl = 0.1 # original 379 | beta_rec = 1.0 380 | n_kp_enc = 10 # total kp to output from the encoder / filter from prior 381 | n_kp_prior = 64 382 | patch_size = 8 383 | learned_feature_dim = 6 # additional features than x,y for each kp 384 | topk = min(10, n_kp_enc) # display top-10 kp with smallest variance 385 | recon_loss_type = "mse" 386 | warmup_epoch = 1 387 | eval_epoch_freq = 1 388 | anchor_s = 0.25 389 | kl_balance = 0.001 390 | lr = 1e-3 391 | batch_size = 64 392 | image_size = 64 393 | enc_channels = (32, 64, 128) 394 | prior_channels = (16, 32, 64) 395 | root = None 396 | filtering_heuristic = 'none' 397 | use_correlation_heatmaps = False 398 | use_tracking = False 399 | else: 400 | raise NotImplementedError("unrecognized dataset, please implement it and add it to the train script") 401 | 402 | hparams = {'ds': ds, 'root': root, 'device': device, 'batch_size': batch_size, 'lr': lr, 403 | 'kp_activation': kp_activation, 404 | 'pad_mode': pad_mode, 'load_model': load_model, 'pretrained_path': pretrained_path, 405 | 'num_epochs': num_epochs, 'n_kp': n_kp, 'recon_loss_type': recon_loss_type, 406 | 'sigma': sigma, 'beta_kl': beta_kl, 'beta_rec': beta_rec, 407 | 'patch_size': patch_size, 'topk': topk, 'n_kp_enc': n_kp_enc, 408 | 'eval_epoch_freq': eval_epoch_freq, 'learned_feature_dim': learned_feature_dim, 409 | 'n_kp_prior': n_kp_prior, 'weight_decay': weight_decay, 'kp_range': kp_range, 410 | 'warmup_epoch': warmup_epoch, 'dropout': dropout, 411 | 'iou_thresh': iou_thresh, 'anchor_s': 
anchor_s, 'kl_balance': kl_balance, 412 | 'image_size': image_size, 'ch': ch, 'enc_channels': enc_channels, 413 | 'prior_channels': prior_channels, 414 | 'timestep_horizon': timestep_horizon, 'predict_delta': predict_delta, 'beta_dyn': beta_dyn, 415 | 'scale_std': scale_std, 'offset_std': offset_std, 'obj_on_alpha': obj_on_alpha, 416 | 'obj_on_beta': obj_on_beta, 'beta_dyn_rec': beta_dyn_rec, 'num_static_frames': num_static_frames, 417 | 'pint_layers': pint_layers, 'pint_heads': pint_heads, 'pint_dim': pint_dim, 'run_prefix': run_prefix, 418 | 'animation_horizon': animation_horizon, 'eval_im_metrics': eval_im_metrics, 'use_resblock': use_resblock, 419 | 'scheduler_gamma': scheduler_gamma, 'adam_betas': adam_betas, 'adam_eps': adam_eps, 420 | 'train_enc_prior': train_enc_prior, 'start_dyn_epoch': start_dyn_epoch, 'cond_steps': cond_steps, 421 | 'animation_fps': animation_fps, 'use_correlation_heatmaps': use_correlation_heatmaps, 422 | 'enable_enc_attn': enable_enc_attn, 'filtering_heuristic': filtering_heuristic, 423 | 'use_tracking': use_tracking} 424 | 425 | save_config('./', fname, hparams) 426 | 427 | 428 | if __name__ == '__main__': 429 | # dss = ['traffic', 'balls', 'clevrer', 'phyre', 'obj3d', 'obj3d128', 'mario', 'sketchy', 'bair'] 430 | dss = ['shapes'] 431 | for ds in dss: 432 | gen_conf_file(ds=ds, fname=f'{ds}.json') 433 | conf_path = os.path.join('', f'{ds}.json') 434 | with open(conf_path, 'r') as f: 435 | config = json.load(f) 436 | print(config) 437 | -------------------------------------------------------------------------------- /datasets/clevrer_ds.py: -------------------------------------------------------------------------------- 1 | """ 2 | functions and classes to process the CLEVRER dataset 3 | """ 4 | 5 | import os 6 | import os.path as osp 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import cv2 10 | # import utils.tps as tps 11 | import glob 12 | 13 | import torch 14 | from PIL import Image, ImageFile 15 | from torch.utils.data import Dataset, DataLoader 16 | import torchvision.transforms as transforms 17 | 18 | ImageFile.LOAD_TRUNCATED_IMAGES = True 19 | 20 | 21 | # --- old preprocessing functions for the single image setting --- # 22 | def list_images_in_dir(path): 23 | valid_images = [".jpg", ".gif", ".png"] 24 | img_list = [] 25 | for f in os.listdir(path): 26 | ext = os.path.splitext(f)[1] 27 | if ext.lower() not in valid_images: 28 | continue 29 | img_list.append(os.path.join(path, f)) 30 | return img_list 31 | 32 | 33 | def prepare_numpy_file(path_to_image_dir, image_size=128, frameskip=1, start_frame=1): 34 | # path_to_image_dir = '/media/newhd/data/traffic_data/rimon_frames/' 35 | img_list = list_images_in_dir(path_to_image_dir) 36 | img_list = sorted(img_list, key=lambda x: int(x.split('/')[-1].split('_')[-1].split('.')[0])) 37 | img_list = [img_list[i] for i in range(len(img_list)) if 38 | abs(int(img_list[i].split('/')[-1].split('_')[-1].split('.')[0])) % 1000 > start_frame] 39 | print(f'img_list: {len(img_list)}, 0: {img_list[0]}, -1: {img_list[-1]}') 40 | img_np_list = [] 41 | for i in tqdm(range(len(img_list))): 42 | if i % frameskip != 0: 43 | continue 44 | img = Image.open(img_list[i]) 45 | img = img.convert('RGB') 46 | # img = img.crop((60, 0, 480, 420)) 47 | img = img.resize((image_size, image_size), Image.BICUBIC) 48 | img_np = np.asarray(img) 49 | img_np_list.append(img_np) 50 | img_np_array = np.stack(img_np_list, axis=0) 51 | print(f'img_np_array: {img_np_array.shape}') 52 | save_path = os.path.join(path_to_image_dir, 
f'clevrer_img{image_size}np_fs{frameskip}.npy') 53 | np.save(save_path, img_np_array) 54 | print(f'file save at @ {save_path}') 55 | 56 | 57 | # --- end old preprocessing functions for the single image setting --- # 58 | 59 | # --- new preprocessing functions for the episodic setting --- # 60 | """ 61 | Instructions: 62 | 1. Download the CLEVRER dataset from here: http://clevrer.csail.mit.edu/ 63 | 2. Extract the directories 'train' and 'valid', they should contain directories named like: 'video_10000-11000' 64 | 3. In the directory containing the 'train' and 'valid', run `preprocess_clevrer(mode='train', ep_len=100, start_frame=18)` 65 | """ 66 | 67 | 68 | def extract_frames(path_to_video, path_to_save_frames, start_frame, end_frame, image_size=128): 69 | vidcap = cv2.VideoCapture(path_to_video) 70 | success, image = vidcap.read() 71 | count = 0 72 | curr_frame = 0 73 | while success: 74 | if count >= start_frame: 75 | path = os.path.join(path_to_save_frames, "%d.png" % curr_frame) 76 | resized = cv2.resize(image, (image_size, image_size), interpolation=cv2.INTER_AREA) 77 | cv2.imwrite(path, resized) 78 | curr_frame += 1 79 | success, image = vidcap.read() 80 | # print('Read a new frame: ', success) 81 | count += 1 82 | if count == end_frame: 83 | break 84 | vidcap.release() 85 | 86 | 87 | def preprocess_clevrer(mode='train', ep_len=100, start_frame=18): 88 | assert mode in ['train', 'valid'] 89 | end_frame = start_frame + ep_len 90 | path_to_dir = f'./ep_{mode}' 91 | os.makedirs(path_to_dir, exist_ok=True) 92 | path_to_video_dir = f'./{mode}' 93 | video_dirs = [d for d in os.listdir(path_to_video_dir) 94 | if os.path.isdir(os.path.join(path_to_video_dir, d)) and 'video' in d] 95 | video_dirs = sorted(video_dirs) 96 | episode = 0 97 | for i in range(len(video_dirs)): 98 | curr_dir = os.path.join(path_to_video_dir, video_dirs[i]) 99 | print(f'current dir: {curr_dir}') 100 | videos_curr_dir = sorted([v for v in os.listdir(curr_dir) if 'video' in v]) 101 | for j in range(len(videos_curr_dir)): 102 | curr_video = os.path.join(curr_dir, videos_curr_dir[j]) 103 | target_dir = os.path.join(path_to_dir, f'{episode}') 104 | os.makedirs(target_dir, exist_ok=True) 105 | # extract frames 106 | extract_frames(curr_video, target_dir, start_frame, end_frame) 107 | episode += 1 108 | 109 | 110 | # --- end new preprocessing functions for the episodic setting --- # 111 | 112 | 113 | class CLEVREREpDataset(Dataset): 114 | def __init__(self, root, mode, ep_len=100, sample_length=20, image_size=128): 115 | # path = os.path.join(root, mode) 116 | assert mode in ['train', 'val', 'valid', 'test'] 117 | if mode == 'val': 118 | mode = 'valid' 119 | self.root = os.path.join(root, mode) 120 | self.image_size = image_size 121 | 122 | self.mode = mode 123 | self.sample_length = sample_length 124 | 125 | # Get all numbers 126 | self.folders = [] 127 | for file in os.listdir(self.root): 128 | try: 129 | self.folders.append(int(file)) 130 | except ValueError: 131 | continue 132 | self.folders.sort() 133 | 134 | self.epsisodes = [] 135 | self.EP_LEN = ep_len 136 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1 137 | 138 | for f in self.folders: 139 | dir_name = os.path.join(self.root, str(f)) 140 | paths = list(glob.glob(osp.join(dir_name, '*.png'))) 141 | # if len(paths) != self.EP_LEN: 142 | # continue 143 | # assert len(paths) == self.EP_LEN, 'len(paths): {}'.format(len(paths)) 144 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0]) 145 | paths.sort(key=get_num) 146 | self.epsisodes.append(paths) 
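        # indexing scheme: in 'train' mode each episode contributes seq_per_episode =
        # EP_LEN - sample_length + 1 overlapping windows, and __getitem__ maps a flat index to an
        # (episode, offset) pair via index // seq_per_episode and index % seq_per_episode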
147 | 148 | def __getitem__(self, index): 149 | 150 | imgs = [] 151 | if self.mode == 'train': 152 | # Implement continuous indexing 153 | ep = index // self.seq_per_episode 154 | offset = index % self.seq_per_episode 155 | end = offset + self.sample_length 156 | 157 | e = self.epsisodes[ep] 158 | for image_index in range(offset, end): 159 | img = Image.open(osp.join(e[image_index])) 160 | img = img.resize((self.image_size, self.image_size)) 161 | img = transforms.ToTensor()(img)[:3] 162 | imgs.append(img) 163 | else: 164 | for path in self.epsisodes[index]: 165 | img = Image.open(path) 166 | img = img.resize((self.image_size, self.image_size)) 167 | img = transforms.ToTensor()(img)[:3] 168 | imgs.append(img) 169 | 170 | img = torch.stack(imgs, dim=0).float() 171 | pos = torch.zeros(0) 172 | size = torch.zeros(0) 173 | id = torch.zeros(0) 174 | in_camera = torch.zeros(0) 175 | 176 | return img, pos, size, id, in_camera 177 | 178 | def __len__(self): 179 | length = len(self.epsisodes) 180 | if self.mode == 'train': 181 | return length * self.seq_per_episode 182 | else: 183 | return length 184 | 185 | 186 | class CLEVREREpDatasetImage(Dataset): 187 | def __init__(self, root, mode, ep_len=100, sample_length=20, image_size=128): 188 | # path = os.path.join(root, mode) 189 | assert mode in ['train', 'val', 'valid', 'test'] 190 | if mode == 'val': 191 | mode = 'valid' 192 | self.root = os.path.join(root, mode) 193 | self.image_size = image_size 194 | 195 | self.mode = mode 196 | self.sample_length = sample_length 197 | 198 | # Get all numbers 199 | self.folders = [] 200 | for file in os.listdir(self.root): 201 | try: 202 | self.folders.append(int(file)) 203 | except ValueError: 204 | continue 205 | self.folders.sort() 206 | 207 | self.epsisodes = [] 208 | self.EP_LEN = ep_len 209 | self.seq_per_episode = self.EP_LEN - self.sample_length + 1 210 | 211 | for f in self.folders: 212 | dir_name = os.path.join(self.root, str(f)) 213 | paths = list(glob.glob(osp.join(dir_name, '*.png'))) 214 | # if len(paths) != self.EP_LEN: 215 | # continue 216 | # assert len(paths) == self.EP_LEN, 'len(paths): {}'.format(len(paths)) 217 | get_num = lambda x: int(osp.splitext(osp.basename(x))[0]) 218 | paths.sort(key=get_num) 219 | self.epsisodes.append(paths) 220 | 221 | def __getitem__(self, index): 222 | 223 | imgs = [] 224 | # Implement continuous indexing 225 | ep = index // self.seq_per_episode 226 | offset = index % self.seq_per_episode 227 | end = offset + self.sample_length 228 | 229 | e = self.epsisodes[ep] 230 | for image_index in range(offset, end): 231 | img = Image.open(osp.join(e[image_index])) 232 | img = img.resize((self.image_size, self.image_size)) 233 | img = transforms.ToTensor()(img)[:3] 234 | imgs.append(img) 235 | 236 | img = torch.stack(imgs, dim=0).float() 237 | pos = torch.zeros(0) 238 | size = torch.zeros(0) 239 | id = torch.zeros(0) 240 | in_camera = torch.zeros(0) 241 | 242 | return img, pos, size, id, in_camera 243 | 244 | def __len__(self): 245 | length = len(self.epsisodes) 246 | return length * self.seq_per_episode 247 | 248 | 249 | class CLEVRERDataset(Dataset): 250 | def __init__(self, path_to_npy, image_size=128, transform=None, mode='single', train=True, horizon=3, 251 | frames_per_video=34, video_as_index=False): 252 | super(CLEVRERDataset, self).__init__() 253 | assert mode in ['single', 'frames', 'tps', 'horizon'] 254 | self.mode = mode 255 | self.frames_per_video = frames_per_video 256 | self.horizon = horizon if (horizon > 0 and self.mode == 'horizon') else 
self.frames_per_video 257 | self.train_mode = train 258 | if train: 259 | print(f'clevrer dataset mode: {self.mode}') 260 | if self.mode == 'horizon': 261 | print(f'time steps horizon: {self.horizon}') 262 | if self.mode == 'tps': 263 | # self.warper = tps.Warper(H=image_size, W=image_size, warpsd_all=0.00001, 264 | # warpsd_subset=0.001, transsd=0.1, scalesd=0.1, 265 | # rotsd=2, im1_multiplier=0.1, im1_multiplier_aff=0.1) 266 | pass 267 | else: 268 | self.warper = None 269 | data = np.load(path_to_npy) 270 | # train_size = int(0.9 * data.shape[0]) 271 | # valid_size = data.shape[0] - train_size 272 | if train: 273 | self.data = data 274 | # self.data = data[:self.frames_per_video * 200] 275 | print(f'loaded data with shape: {self.data.shape}, size: {self.data.shape[0]}') 276 | else: 277 | self.data = data[:5000] 278 | self.image_size = image_size 279 | self.num_videos = self.data.shape[0] // self.frames_per_video 280 | self.video_as_index = video_as_index 281 | if transform is None: 282 | self.input_transform = transforms.Compose([ 283 | transforms.ToPILImage(), 284 | transforms.Resize(image_size), 285 | transforms.ToTensor() 286 | ]) 287 | else: 288 | self.input_transform = transform 289 | 290 | def __getitem__(self, index): 291 | if not self.video_as_index: 292 | video_num = int(index / self.frames_per_video) 293 | video_start_idx = video_num * self.frames_per_video 294 | curr_idx = index % self.frames_per_video 295 | max_idx = min(video_start_idx + self.frames_per_video - 1, self.data.shape[0] - 1) 296 | global_idx = video_start_idx + curr_idx 297 | if self.mode == 'single': 298 | return self.input_transform(self.data[index]) 299 | elif self.mode == 'frames': 300 | min_idx = min(video_start_idx, index - 1) 301 | if min_idx == video_start_idx: 302 | im1 = self.input_transform(self.data[min_idx + 1]) 303 | im2 = self.input_transform(self.data[min_idx]) 304 | else: 305 | im1 = self.input_transform(self.data[min_idx]) 306 | im2 = self.input_transform(self.data[min_idx - 1]) 307 | return im1, im2 308 | elif self.mode == 'horizon': 309 | images = [] 310 | length = max_idx 311 | if (index + self.horizon) >= length: 312 | slack = index + self.horizon - length 313 | index = index - slack 314 | for i in range(self.horizon): 315 | t = index + i 316 | images.append(self.input_transform(self.data[t])) 317 | images = torch.stack(images, dim=0) 318 | return images 319 | elif self.mode == 'tps': 320 | im = self.input_transform(self.data[index]) 321 | im = im * 255 322 | im2, im1, _, _, _, _ = self.warper(im) 323 | return im1 / 255, im2 / 255 324 | else: 325 | raise NotImplementedError 326 | else: 327 | video_num = index 328 | video_start_idx = video_num * self.frames_per_video 329 | max_idx = video_start_idx + self.frames_per_video - 1 330 | images = [] 331 | length = max_idx 332 | frame_idx = video_start_idx 333 | actual_horizon = self.frames_per_video if ((frame_idx + self.horizon) >= length) else self.horizon 334 | for i in range(actual_horizon): 335 | t = frame_idx + i 336 | images.append(self.input_transform(self.data[t])) 337 | images = torch.stack(images, dim=0) 338 | return images 339 | 340 | def __len__(self): 341 | if not self.video_as_index: 342 | return self.data.shape[0] 343 | else: 344 | return self.num_videos 345 | 346 | 347 | if __name__ == '__main__': 348 | # -- single image setting --- # 349 | path_to_img = '/media/newhd/data/clevrer/train/frames/' 350 | # prepare_numpy_file(path_to_img, image_size=128, frameskip=3, start_frame=26) 351 | test_epochs = True 352 | # load data 353 | # 
path_to_npy = '/media/newhd/data/clevrer/valid/clevrer_img128np_fs3_valid.npy' 354 | # mode = 'frames' 355 | # horizon = 4 356 | # train = True 357 | # clevrer_ds = CLEVRERDataset(path_to_npy, mode=mode, train=train, horizon=horizon) 358 | # clevrer_dl = DataLoader(clevrer_ds, shuffle=True, pin_memory=True, batch_size=5) 359 | # batch = next(iter(clevrer_dl)) 360 | # if mode == 'single': 361 | # im1 = batch[0] 362 | # elif mode == 'frames' or mode == 'tps': 363 | # im1 = batch[0][0] 364 | # im2 = batch[1][0] 365 | # 366 | # if mode == 'single': 367 | # print(im1.shape) 368 | # img_np = im1.permute(1, 2, 0).data.cpu().numpy() 369 | # fig = plt.figure(figsize=(5, 5)) 370 | # ax = fig.add_subplot(111) 371 | # ax.imshow(img_np) 372 | # elif mode == 'horizon': 373 | # print(f'batch shape: {batch.shape}') 374 | # images = batch[0] 375 | # print(f'images shape: {images.shape}') 376 | # fig = plt.figure(figsize=(8, 8)) 377 | # for i in range(images.shape[0]): 378 | # ax = fig.add_subplot(1, horizon, i + 1) 379 | # im = images[i] 380 | # im_np = im.permute(1, 2, 0).data.cpu().numpy() 381 | # ax.imshow(im_np) 382 | # ax.set_title(f'im {i + 1}') 383 | # else: 384 | # print(f'im1: {im1.shape}, im2: {im2.shape}') 385 | # im1_np = im1.permute(1, 2, 0).data.cpu().numpy() 386 | # im2_np = im2.permute(1, 2, 0).data.cpu().numpy() 387 | # fig = plt.figure(figsize=(8, 8)) 388 | # ax = fig.add_subplot(1, 2, 1) 389 | # ax.imshow(im1_np) 390 | # ax.set_title('im1') 391 | # 392 | # ax = fig.add_subplot(1, 2, 2) 393 | # ax.imshow(im2_np) 394 | # ax.set_title('im2 [t-1] or [tps]') 395 | # plt.show() 396 | # --- end single image --- # 397 | 398 | # --- episodic setting --- # 399 | root = '/media/newhd/data/clevrer/episodes' 400 | clevrer_ds = CLEVREREpDataset(root=root, ep_len=100, sample_length=30, mode='train') 401 | clevrer_dl = DataLoader(clevrer_ds, shuffle=True, pin_memory=True, batch_size=5) 402 | batch = next(iter(clevrer_dl)) 403 | im = batch[0][0][0] 404 | print(im.shape) 405 | img_np = im.permute(1, 2, 0).data.cpu().numpy() 406 | fig = plt.figure(figsize=(5, 5)) 407 | ax = fig.add_subplot(111) 408 | ax.imshow(img_np) 409 | plt.show() 410 | 411 | if test_epochs: 412 | from tqdm import tqdm 413 | 414 | pbar = tqdm(iterable=clevrer_dl) 415 | for batch in pbar: 416 | pass 417 | pbar.close() 418 | -------------------------------------------------------------------------------- /train_dlp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Single-GPU training of DLPv2 3 | """ 4 | # imports 5 | import numpy as np 6 | import os 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | import matplotlib 10 | import argparse 11 | # torch 12 | import torch 13 | import torch.nn.functional as F 14 | from utils.loss_functions import calc_reconstruction_loss, VGGDistance 15 | from torch.utils.data import DataLoader 16 | import torchvision.utils as vutils 17 | import torch.optim as optim 18 | # modules 19 | from models import ObjectDLP 20 | # datasets 21 | from datasets.get_dataset import get_image_dataset 22 | # util functions 23 | from utils.util_func import plot_keypoints_on_image_batch, prepare_logdir, save_config, log_line, \ 24 | plot_bb_on_image_batch_from_z_scale_nms, plot_bb_on_image_batch_from_masks_nms, get_config 25 | from eval.eval_model import evaluate_validation_elbo 26 | from eval.eval_gen_metrics import eval_dlp_im_metric 27 | 28 | matplotlib.use("Agg") 29 | torch.backends.cudnn.benchmark = False 30 | torch.backends.cudnn.deterministic = True 31 | 32 | 
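# minimal usage sketch (assuming a config file exists under ./configs, e.g., one generated by
# configs/generate_config_file.py):
#   from train_dlp import train_dlp
#   train_dlp(config_path='./configs/shapes.json')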
33 | def train_dlp(config_path='./configs/shapes.json'): 34 | # load config 35 | try: 36 | config = get_config(config_path) 37 | except FileNotFoundError: 38 | raise SystemExit("config file not found") 39 | hparams = config # to save a copy of the hyper-parameters 40 | # data and general 41 | ds = config['ds'] 42 | ch = config['ch'] # image channels 43 | image_size = config['image_size'] 44 | root = config['root'] # dataset root 45 | batch_size = config['batch_size'] 46 | lr = config['lr'] 47 | num_epochs = config['num_epochs'] 48 | topk = min(config['topk'], config['n_kp_enc']) # top-k particles to plot 49 | eval_epoch_freq = config['eval_epoch_freq'] 50 | weight_decay = config['weight_decay'] 51 | iou_thresh = config['iou_thresh'] # threshold for NMS for plotting bounding boxes 52 | run_prefix = config['run_prefix'] 53 | load_model = config['load_model'] 54 | pretrained_path = config['pretrained_path'] # path of pretrained model to load, if None, train from scratch 55 | adam_betas = config['adam_betas'] 56 | adam_eps = config['adam_eps'] 57 | scheduler_gamma = config['scheduler_gamma'] 58 | eval_im_metrics = config['eval_im_metrics'] 59 | device = config['device'] 60 | if 'cuda' in device: 61 | device = torch.device(f'{device}' if torch.cuda.is_available() else 'cpu') 62 | else: 63 | device = torch.device('cpu') 64 | # model 65 | kp_range = config['kp_range'] 66 | kp_activation = config['kp_activation'] 67 | enc_channels = config['enc_channels'] 68 | prior_channels = config['prior_channels'] 69 | pad_mode = config['pad_mode'] 70 | n_kp = config['n_kp'] # kp per patch in prior, best to leave at 1 71 | n_kp_prior = config['n_kp_prior'] # number of prior kp to filter for the kl 72 | n_kp_enc = config['n_kp_enc'] # total posterior kp 73 | patch_size = config['patch_size'] # prior patch size 74 | anchor_s = config['anchor_s'] # posterior patch/glimpse ratio of image size 75 | learned_feature_dim = config['learned_feature_dim'] 76 | dropout = config['dropout'] 77 | use_resblock = config['use_resblock'] 78 | use_correlation_heatmaps = config['use_correlation_heatmaps'] # use heatmaps for tracking 79 | use_tracking = config['use_tracking'] 80 | enable_enc_attn = config['enable_enc_attn'] # enable attention between patches in the particle encoder 81 | filtering_heuristic = config["filtering_heuristic"] # filtering heuristic to filter prior keypoints 82 | 83 | # optimization 84 | warmup_epoch = config['warmup_epoch'] 85 | recon_loss_type = config['recon_loss_type'] 86 | beta_kl = config['beta_kl'] 87 | beta_rec = config['beta_rec'] 88 | kl_balance = config['kl_balance'] # balance between visual features and the other particle attributes 89 | train_enc_prior = config['train_enc_prior'] 90 | 91 | # priors 92 | sigma = config['sigma'] # std for constant kp prior, leave at 1 for deterministic chamfer-kl 93 | scale_std = config['scale_std'] 94 | offset_std = config['offset_std'] 95 | obj_on_alpha = config['obj_on_alpha'] # transparency beta distribution "a" 96 | obj_on_beta = config['obj_on_beta'] # transparency beta distribution "b" 97 | 98 | # load data 99 | dataset = get_image_dataset(ds, root, mode='train', image_size=image_size) 100 | dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4, pin_memory=True, 101 | drop_last=True) 102 | # model 103 | model = ObjectDLP(cdim=ch, enc_channels=enc_channels, prior_channels=prior_channels, 104 | image_size=image_size, n_kp=n_kp, learned_feature_dim=learned_feature_dim, 105 | pad_mode=pad_mode, sigma=sigma, 106 | 
dropout=dropout, patch_size=patch_size, n_kp_enc=n_kp_enc, 107 | n_kp_prior=n_kp_prior, kp_range=kp_range, kp_activation=kp_activation, 108 | anchor_s=anchor_s, use_resblock=use_resblock, 109 | scale_std=scale_std, offset_std=offset_std, obj_on_alpha=obj_on_alpha, 110 | obj_on_beta=obj_on_beta, 111 | use_correlation_heatmaps=use_correlation_heatmaps, use_tracking=use_tracking, 112 | enable_enc_attn=enable_enc_attn, filtering_heuristic=filtering_heuristic).to(device) 113 | print(model.info()) 114 | # prepare saving location 115 | run_name = f'{ds}_dlp' + run_prefix 116 | log_dir = prepare_logdir(runname=run_name, src_dir='./') 117 | fig_dir = os.path.join(log_dir, 'figures') 118 | save_dir = os.path.join(log_dir, 'saves') 119 | save_config(log_dir, hparams) 120 | 121 | # prepare loss functions 122 | if recon_loss_type == "vgg": 123 | recon_loss_func = VGGDistance(device=device) 124 | else: 125 | recon_loss_func = calc_reconstruction_loss 126 | 127 | # optimizer and scheduler 128 | optimizer = optim.Adam(model.get_parameters(), lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay) 129 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=scheduler_gamma, verbose=True) 130 | 131 | if load_model and pretrained_path is not None: 132 | try: 133 | model.load_state_dict(torch.load(pretrained_path, map_location=device)) 134 | print("loaded model from checkpoint") 135 | except: 136 | print("model checkpoint not found") 137 | 138 | # log statistics 139 | losses = [] 140 | losses_rec = [] 141 | losses_kl = [] 142 | losses_kl_kp = [] 143 | losses_kl_feat = [] 144 | losses_kl_scale = [] 145 | losses_kl_depth = [] 146 | losses_kl_obj_on = [] 147 | 148 | # initialize validation statistics 149 | valid_loss = best_valid_loss = 1e8 150 | valid_losses = [] 151 | best_valid_epoch = 0 152 | 153 | # save PSNR values of the reconstruction 154 | psnrs = [] 155 | 156 | # image metrics 157 | if eval_im_metrics: 158 | val_lpipss = [] 159 | best_val_lpips_epoch = 0 160 | val_lpips = best_val_lpips = 1e8 161 | 162 | for epoch in range(num_epochs): 163 | model.train() 164 | batch_losses = [] 165 | batch_losses_rec = [] 166 | batch_losses_kl = [] 167 | batch_losses_kl_kp = [] 168 | batch_losses_kl_feat = [] 169 | batch_losses_kl_scale = [] 170 | batch_losses_kl_depth = [] 171 | batch_losses_kl_obj_on = [] 172 | batch_psnrs = [] 173 | 174 | pbar = tqdm(iterable=dataloader) 175 | for batch in pbar: 176 | x = batch[0].to(device) 177 | if len(x.shape) == 5 and not use_tracking: 178 | # [bs, T, ch, h, w] 179 | x = x.view(-1, *x.shape[2:]) 180 | elif len(x.shape) == 4 and use_tracking: 181 | # [bs, ch, h, w] 182 | x = x.unsqueeze(1) 183 | x_prior = x # the input image to the prior is the same as the posterior 184 | noisy = (epoch < (warmup_epoch + 1)) 185 | # forward pass 186 | model_output = model(x, x_prior=x_prior, warmup=(epoch < warmup_epoch), noisy=noisy, bg_masks_from_fg=False, 187 | train_enc_prior=train_enc_prior) 188 | # calculate loss 189 | all_losses = model.calc_elbo(x, model_output, warmup=(epoch < warmup_epoch), beta_kl=beta_kl, 190 | beta_rec=beta_rec, kl_balance=kl_balance, 191 | recon_loss_type=recon_loss_type, 192 | recon_loss_func=recon_loss_func, noisy=noisy) 193 | loss = all_losses['loss'] 194 | optimizer.zero_grad() 195 | loss.backward() 196 | optimizer.step() 197 | 198 | # output for logging and plotting 199 | mu_p = model_output['kp_p'] 200 | z_base = model_output['z_base'] 201 | mu_offset = model_output['mu_offset'] 202 | logvar_offset = model_output['logvar_offset'] 203 | 
rec_x = model_output['rec'] 204 | mu_scale = model_output['mu_scale'] 205 | mu_depth = model_output['mu_depth'] 206 | # object stuff 207 | dec_objects_original = model_output['dec_objects_original'] 208 | cropped_objects_original = model_output['cropped_objects_original'] 209 | obj_on = model_output['obj_on'] # [batch_size, n_kp] 210 | alpha_masks = model_output['alpha_masks'] # [batch_size, n_kp, 1, h, w] 211 | 212 | psnr = all_losses['psnr'] 213 | obj_on_l1 = all_losses['obj_on_l1'] 214 | loss_kl = all_losses['kl'] 215 | loss_rec = all_losses['loss_rec'] 216 | loss_kl_kp = all_losses['loss_kl_kp'] 217 | loss_kl_feat = all_losses['loss_kl_feat'] 218 | loss_kl_scale = all_losses['loss_kl_scale'] 219 | loss_kl_depth = all_losses['loss_kl_depth'] 220 | loss_kl_obj_on = all_losses['loss_kl_obj_on'] 221 | 222 | if use_tracking: 223 | x = x.view(-1, *x.shape[2:]) 224 | x_prior = x_prior.view(-1, *x_prior.shape[2:]) 225 | # for plotting, confidence calculation 226 | mu_tot = z_base + mu_offset 227 | logvar_tot = logvar_offset 228 | # log 229 | batch_psnrs.append(psnr.data.cpu().item()) 230 | batch_losses.append(loss.data.cpu().item()) 231 | batch_losses_rec.append(loss_rec.data.cpu().item()) 232 | batch_losses_kl.append(loss_kl.data.cpu().item()) 233 | batch_losses_kl_kp.append(loss_kl_kp.data.cpu().item()) 234 | batch_losses_kl_feat.append(loss_kl_feat.data.cpu().item()) 235 | batch_losses_kl_scale.append(loss_kl_scale.data.cpu().item()) 236 | batch_losses_kl_depth.append(loss_kl_depth.data.cpu().item()) 237 | batch_losses_kl_obj_on.append(loss_kl_obj_on.data.cpu().item()) 238 | # progress bar 239 | if epoch < warmup_epoch: 240 | pbar.set_description_str(f'epoch #{epoch} (warmup)') 241 | elif noisy: 242 | pbar.set_description_str(f'epoch #{epoch} (noisy)') 243 | else: 244 | pbar.set_description_str(f'epoch #{epoch}') 245 | pbar.set_postfix(loss=loss.data.cpu().item(), rec=loss_rec.data.cpu().item(), 246 | kl=loss_kl.data.cpu().item(), on_l1=obj_on_l1.cpu().item()) 247 | # break # for debug 248 | pbar.close() 249 | losses.append(np.mean(batch_losses)) 250 | losses_rec.append(np.mean(batch_losses_rec)) 251 | losses_kl.append(np.mean(batch_losses_kl)) 252 | losses_kl_kp.append(np.mean(batch_losses_kl_kp)) 253 | losses_kl_feat.append(np.mean(batch_losses_kl_feat)) 254 | losses_kl_scale.append(np.mean(batch_losses_kl_scale)) 255 | losses_kl_depth.append(np.mean(batch_losses_kl_depth)) 256 | losses_kl_obj_on.append(np.mean(batch_losses_kl_obj_on)) 257 | if len(batch_psnrs) > 0: 258 | psnrs.append(np.mean(batch_psnrs)) 259 | # scheduler 260 | scheduler.step() 261 | 262 | # epoch summary 263 | log_str = f'epoch {epoch} summary\n' 264 | log_str += f'loss: {losses[-1]:.3f}, rec: {losses_rec[-1]:.3f}, kl: {losses_kl[-1]:.3f}\n' 265 | log_str += f'kl_balance: {kl_balance:.3f}, kl_kp: {losses_kl_kp[-1]:.3f}, kl_feat: {losses_kl_feat[-1]:.3f}\n' 266 | log_str += f'kl_scale: {losses_kl_scale[-1]:.3f}, kl_depth: {losses_kl_depth[-1]:.3f}, kl_obj_on: {losses_kl_obj_on[-1]:.3f}\n' 267 | 268 | # log_str += f'mu max: {mu.max()}, mu min: {mu.min()}\n' 269 | log_str += f'mu max: {mu_tot.max()}, mu min: {mu_tot.min()}\n' 270 | log_str += f'mu offset max: {mu_offset.max()}, mu offset min: {mu_offset.min()}\n' 271 | log_str += f'val loss (freq: {eval_epoch_freq}): {valid_loss:.3f},' \ 272 | f' best: {best_valid_loss:.3f} @ epoch: {best_valid_epoch}\n' 273 | if obj_on is not None: 274 | log_str += f'obj_on max: {obj_on.max()}, obj_on min: {obj_on.min()}\n' 275 | log_str += f'scale max: {mu_scale.sigmoid().max()}, scale 
min: {mu_scale.sigmoid().min()}\n' 276 | log_str += f'depth max: {mu_depth.max()}, depth min: {mu_depth.min()}\n' 277 | if eval_im_metrics: 278 | log_str += f'val lpips (freq: {eval_epoch_freq}): {val_lpips:.3f},' \ 279 | f' best: {best_val_lpips:.3f} @ epoch: {best_val_lpips_epoch}\n' 280 | print(log_str) 281 | log_line(log_dir, log_str) 282 | 283 | if epoch % eval_epoch_freq == 0 or epoch == num_epochs - 1: 284 | # for plotting purposes 285 | mu_plot = mu_tot.clamp(min=kp_range[0], max=kp_range[1]) 286 | max_imgs = 8 287 | img_with_kp = plot_keypoints_on_image_batch(mu_plot, x, radius=3, 288 | thickness=1, max_imgs=max_imgs, kp_range=kp_range) 289 | img_with_kp_p = plot_keypoints_on_image_batch(mu_p, x_prior, radius=3, thickness=1, max_imgs=max_imgs, 290 | kp_range=kp_range) 291 | # top-k 292 | with torch.no_grad(): 293 | logvar_sum = logvar_tot.sum(-1) * obj_on # [bs, n_kp] 294 | logvar_topk = torch.topk(logvar_sum, k=topk, dim=-1, largest=False) 295 | indices = logvar_topk[1] # [batch_size, topk] 296 | batch_indices = torch.arange(mu_tot.shape[0]).view(-1, 1).to(mu_tot.device) 297 | topk_kp = mu_tot[batch_indices, indices] 298 | # bounding boxes 299 | bb_scores = -1 * logvar_sum 300 | hard_threshold = None 301 | 302 | kp_batch = mu_plot 303 | scale_batch = mu_scale 304 | img_with_masks_nms, nms_ind = plot_bb_on_image_batch_from_z_scale_nms(kp_batch, scale_batch, x, 305 | scores=bb_scores, 306 | iou_thresh=iou_thresh, 307 | thickness=1, max_imgs=max_imgs, 308 | hard_thresh=hard_threshold) 309 | alpha_masks = torch.where(alpha_masks < 0.05, 0.0, 1.0) 310 | img_with_masks_alpha_nms, _ = plot_bb_on_image_batch_from_masks_nms(alpha_masks, x, scores=bb_scores, 311 | iou_thresh=iou_thresh, thickness=1, 312 | max_imgs=max_imgs, 313 | hard_thresh=hard_threshold) 314 | # hard_thresh: a general threshold for bb scores (set None to not use it) 315 | bb_str = f'bb scores: max: {bb_scores.max():.2f}, min: {bb_scores.min():.2f},' \ 316 | f' mean: {bb_scores.mean():.2f}\n' 317 | print(bb_str) 318 | log_line(log_dir, bb_str) 319 | img_with_kp_topk = plot_keypoints_on_image_batch(topk_kp.clamp(min=kp_range[0], max=kp_range[1]), x, 320 | radius=3, thickness=1, max_imgs=max_imgs, 321 | kp_range=kp_range) 322 | dec_objects = model_output['dec_objects'] 323 | bg = model_output['bg'] 324 | vutils.save_image(torch.cat([x[:max_imgs, -3:], img_with_kp[:max_imgs, -3:].to(device), 325 | rec_x[:max_imgs, -3:], img_with_kp_p[:max_imgs, -3:].to(device), 326 | img_with_kp_topk[:max_imgs, -3:].to(device), 327 | dec_objects[:max_imgs, -3:], 328 | img_with_masks_nms[:max_imgs, -3:].to(device), 329 | img_with_masks_alpha_nms[:max_imgs, -3:].to(device), 330 | bg[:max_imgs, -3:]], 331 | dim=0).data.cpu(), '{}/image_{}.jpg'.format(fig_dir, epoch), 332 | nrow=8, pad_value=1) 333 | with torch.no_grad(): 334 | _, dec_objects_rgb = torch.split(dec_objects_original, [1, 3], dim=2) 335 | dec_objects_rgb = dec_objects_rgb.reshape(-1, *dec_objects_rgb.shape[2:]) 336 | cropped_objects_original = cropped_objects_original.clone().reshape(-1, 3, 337 | cropped_objects_original.shape[ 338 | -1], 339 | cropped_objects_original.shape[ 340 | -1]) 341 | if cropped_objects_original.shape[-1] != dec_objects_rgb.shape[-1]: 342 | cropped_objects_original = F.interpolate(cropped_objects_original, 343 | size=dec_objects_rgb.shape[-1], 344 | align_corners=False, mode='bilinear') 345 | vutils.save_image( 346 | torch.cat([cropped_objects_original[:max_imgs * 2, -3:], dec_objects_rgb[:max_imgs * 2, -3:]], 347 | dim=0).data.cpu(), 
'{}/image_obj_{}.jpg'.format(fig_dir, epoch), 348 | nrow=8, pad_value=1) 349 | 350 | torch.save(model.state_dict(), os.path.join(save_dir, f'{ds}_dlp{run_prefix}.pth')) 351 | print("validation step...") 352 | valid_loss = evaluate_validation_elbo(model, config, epoch, batch_size=batch_size, 353 | recon_loss_type=recon_loss_type, device=device, 354 | save_image=True, fig_dir=fig_dir, topk=topk, 355 | recon_loss_func=recon_loss_func, beta_rec=beta_rec, 356 | iou_thresh=iou_thresh, 357 | beta_kl=beta_kl, kl_balance=kl_balance) 358 | log_str = f'validation loss: {valid_loss:.3f}\n' 359 | print(log_str) 360 | log_line(log_dir, log_str) 361 | if best_valid_loss > valid_loss: 362 | log_str = f'validation loss updated: {best_valid_loss:.3f} -> {valid_loss:.3f}\n' 363 | print(log_str) 364 | log_line(log_dir, log_str) 365 | best_valid_loss = valid_loss 366 | best_valid_epoch = epoch 367 | torch.save(model.state_dict(), 368 | os.path.join(save_dir, 369 | f'{ds}_dlp{run_prefix}_best.pth')) 370 | torch.cuda.empty_cache() 371 | if eval_im_metrics and epoch > 0: 372 | valid_imm_results = eval_dlp_im_metric(model, device, config, 373 | val_mode='val', 374 | eval_dir=log_dir, 375 | batch_size=batch_size) 376 | log_str = f'validation: lpips: {valid_imm_results["lpips"]:.3f}, ' 377 | log_str += f'psnr: {valid_imm_results["psnr"]:.3f}, ssim: {valid_imm_results["ssim"]:.3f}\n' 378 | val_lpips = valid_imm_results['lpips'] 379 | print(log_str) 380 | log_line(log_dir, log_str) 381 | if (not torch.isinf(torch.tensor(val_lpips))) and (best_val_lpips > val_lpips): 382 | log_str = f'validation lpips updated: {best_val_lpips:.3f} -> {val_lpips:.3f}\n' 383 | print(log_str) 384 | log_line(log_dir, log_str) 385 | best_val_lpips = val_lpips 386 | best_val_lpips_epoch = epoch 387 | torch.save(model.state_dict(), 388 | os.path.join(save_dir, f'{ds}_dlp{run_prefix}_best_lpips.pth')) 389 | torch.cuda.empty_cache() 390 | valid_losses.append(valid_loss) 391 | if eval_im_metrics: 392 | val_lpipss.append(val_lpips) 393 | # plot graphs 394 | if epoch > 0: 395 | num_plots = 4 396 | fig = plt.figure() 397 | ax = fig.add_subplot(num_plots, 1, 1) 398 | ax.plot(np.arange(len(losses[1:])), losses[1:], label="loss") 399 | ax.set_title(run_name) 400 | ax.legend() 401 | 402 | ax = fig.add_subplot(num_plots, 1, 2) 403 | ax.plot(np.arange(len(losses_kl[1:])), losses_kl[1:], label="kl", color='red') 404 | if learned_feature_dim > 0: 405 | ax.plot(np.arange(len(losses_kl_kp[1:])), losses_kl_kp[1:], label="kl_kp", color='cyan') 406 | ax.plot(np.arange(len(losses_kl_feat[1:])), losses_kl_feat[1:], label="kl_feat", color='green') 407 | ax.legend() 408 | 409 | ax = fig.add_subplot(num_plots, 1, 3) 410 | ax.plot(np.arange(len(losses_rec[1:])), losses_rec[1:], label="rec", color='green') 411 | ax.legend() 412 | 413 | ax = fig.add_subplot(num_plots, 1, 4) 414 | ax.plot(np.arange(len(valid_losses[1:])), valid_losses[1:], label="valid_loss", color='magenta') 415 | ax.legend() 416 | plt.tight_layout() 417 | plt.savefig(f'{fig_dir}/{run_name}_graph.jpg') 418 | plt.close('all') 419 | return model 420 | 421 | 422 | if __name__ == "__main__": 423 | parser = argparse.ArgumentParser(description="DLP Single-GPU Training") 424 | parser.add_argument("-d", "--dataset", type=str, default='shapes', 425 | help="dataset of to train the model on: ['traffic', 'clevrer', 'obj3d128', 'phyre']") 426 | args = parser.parse_args() 427 | ds = args.dataset 428 | if ds.endswith('json'): 429 | conf_path = ds 430 | else: 431 | conf_path = os.path.join('./configs', f'{ds}.json') 
432 | 433 | train_dlp(conf_path) 434 | --------------------------------------------------------------------------------
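
A minimal usage sketch for train_dlp.py, assuming ./configs/shapes.json exists and its "root" field points to the dataset on disk: the script's entry point resolves a dataset name to ./configs/<name>.json (or accepts a direct path to a config file), and train_dlp(config_path) returns the trained model, so it can also be called programmatically.

# sketch: single-GPU DLP training launched programmatically instead of via the CLI
from train_dlp import train_dlp
model = train_dlp(config_path='./configs/shapes.json')

Equivalent command-line calls: `python train_dlp.py -d shapes`, or `python train_dlp.py -d /path/to/custom_config.json` (the custom path is a placeholder for any config file following the examples in configs/).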