├── .gitignore ├── .gitmodules ├── README.md ├── activate_data_proc ├── colab.yml ├── config ├── baseline_vae.yaml ├── data_preparation │ ├── human36m.yaml │ ├── iper.yaml │ ├── plants.yaml │ └── taichi.yaml ├── first_stage.yaml ├── img_encoder.yaml ├── model_names.txt ├── poke_encoder.yaml ├── posenet.yaml ├── pretrained_models │ ├── human36m_128.yaml │ ├── human36m_64.yaml │ ├── iper_128.yaml │ ├── iper_64.yaml │ ├── plants_128.yaml │ ├── plants_64.yaml │ ├── taichi_128.yaml │ └── taichi_64.yaml ├── second_stage.yaml └── test_config.yaml ├── data ├── __init__.py ├── base_dataset.py ├── checksums.txt ├── config.ini ├── datamodule.py ├── flow_dataset.py ├── helper_functions.py ├── human36m_preprocess.py ├── prepare_dataset.py └── samplers.py ├── data_proc.yml ├── experiments ├── __init__.py ├── experiment.py ├── first_stage_image.py ├── first_stage_video.py ├── poke_encoder.py └── second_stage_video.py ├── images ├── control_sensitivity.gif ├── fpp_final.png ├── gui_demo1.gif ├── gui_demo2.gif ├── gui_examples │ ├── iper_exmpl_1.png │ ├── iper_exmpl_2.png │ └── iper_exmpl_3.png ├── iper_exmpl_1.png ├── kinematics_transfer.gif ├── overview.gif └── paper.bib ├── ipoke.yml ├── main.py ├── models ├── conv_poke_encoder.py ├── first_stage_image_conv.py ├── first_stage_motion_model.py ├── modules │ ├── INN │ │ ├── INN.py │ │ ├── coupling_flow_alternative.py │ │ ├── flow_blocks.py │ │ ├── loss.py │ │ ├── macow.py │ │ ├── macow2.py │ │ ├── macow_utils.py │ │ └── modules.py │ ├── autoencoders │ │ ├── LPIPS.py │ │ ├── baseline_fc_models.py │ │ ├── big_ae.py │ │ ├── biggan.py │ │ ├── ckpt_util.py │ │ ├── distributions.py │ │ ├── fully_conv_models.py │ │ ├── util.py │ │ └── vgg16.py │ ├── discriminators │ │ ├── disc_utils.py │ │ ├── patchgan.py │ │ └── patchgan_3d.py │ └── motion_models │ │ ├── motion_encoder.py │ │ ├── motion_generator.py │ │ └── rnn.py ├── poke_vae.py ├── pretrained_models.py └── second_stage_video.py ├── testing ├── eval_models.py ├── evaluate_diversity.py ├── frechet_video_distance.py └── gui.py └── utils ├── callbacks.py ├── flownet_loader.py ├── general.py ├── logging.py ├── logging.yaml ├── losses.py ├── metrics.py └── posenet_wrapper.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | data/test_data/ 4 | train_f.py 5 | **/.ipynb_checkpoints 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "models/flownet2"] 2 | path = models/flownet2 3 | url = https://github.com/NVIDIA/flownet2-pytorch 4 | [submodule "models/pose_estimator"] 5 | path = models/pose_estimator 6 | url = git@github.com:ablattmann/pose_estimation_hrnet.git 7 | -------------------------------------------------------------------------------- /activate_data_proc: -------------------------------------------------------------------------------- 1 | # deactivates current environment and sets everything up for applying flownet2 2 | conda deactivate 3 | export PATH="${PATH}:/usr/local/cuda-10.0/bin/" 4 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda-10.0/lib64:/usr/local/cuda-10.0/lib" 5 | 6 | conda activate data_proc 7 | -------------------------------------------------------------------------------- /colab.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - 
cudatoolkit=11.0 8 | - imageio 9 | - imagesize 10 | - numpy 11 | - pandas 12 | - pathlib 13 | - pillow 14 | - pip=20.0.2 15 | - python=3.7.9 16 | - pytorch=1.7.1 17 | - torchvision 18 | - tqdm 19 | - yaml 20 | - pip: 21 | - coloredlogs 22 | - imageio-ffmpeg 23 | - ipykernel 24 | - ipython 25 | - ipython-genutils 26 | - jupyter-client 27 | - jupyter-core 28 | - kornia 29 | - lpips 30 | - moviepy==1.0.0 31 | - natsort 32 | - opencv-python==4.2.0.34 33 | - opt-einsum 34 | - pytorch-lightning==1.1.7 35 | - pyyaml 36 | - tensorflow-gan==2.0.0 37 | - tensorflow==2.5.0 38 | - tensorflow-hub==0.9.0 39 | - tensorflow-probability==0.11.1 40 | - silence-tensorflow==1.1.1 41 | - scikit-image 42 | - scikit-learn 43 | - scipy 44 | - tensorboard 45 | - umap-learn 46 | - wandb 47 | - seaborn 48 | - dotmap 49 | - streamlit 50 | - streamlit-drawable-canvas -------------------------------------------------------------------------------- /config/baseline_vae.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | experiment: first_stage 3 | #model_name: taichi-pokevae-ss64-bs16-mf10-mfdt12-bn32-wkl0.1-np1-stackpokemotion 4 | # h36m-pokevae-ss64-bs16-mf10-mfdt12-bn64-wkl0.1-np1-stackpokemotion 5 | # taichi-pokevae-ss64-bs16-mf10-mfdt12-bn32-wkl0.1-np1-stackpokemotion 6 | # plants-pokevae-ss64-bs16-mf10-mfdt12-bn64-wkl0.1-np1-stackpokemotion 7 | # iper-pokevae-ss64-bs16-mf10-mfdt12-bn32-wkl0.1-np5-stackpokemotion 8 | #taichi-pokevae-ss128-bs20-mf10-mfdt12-bn32-wkl0.1-np5 -stackpokemotion #h36m-motion_model-ss128-bs14-mf16-mfdt12-lessgp-spade_model-bn64 #plants-motion_model-ss128-bs14-mf16-mfdt12-lessgp-spade_model-bn64 9 | profiler: False 10 | debug: False 11 | base_dir: "logs" 12 | seed: 42 13 | 14 | data: 15 | dataset: TaichiDataset 16 | poke_size: 5 17 | # valid_lags: 1 18 | #subsample_step: 2 19 | max_frames: 10 20 | batch_size: 20 21 | n_workers: 20 22 | yield_videos: True 23 | spatial_size: !!python/tuple [128,128] 24 | p_col: .8 25 | p_geom: .8 26 | augment_b: 0.4 27 | augment_c: 0.5 28 | augment_h: 0.15 29 | augment_s: 0.4 30 | aug_deg: 15 # for iper use 0, for plants use 30°] 31 | # translation is (vertical, horizontal) 32 | aug_trans: !!python/tuple [0.1,0.1] 33 | split: official 34 | flow_weights: False 35 | filter_flow: False 36 | augment : True 37 | n_pokes: 5 38 | normalize_flows: False 39 | zero_poke: True 40 | zero_poke_amount: 12 41 | scale_poke_to_res: True 42 | 43 | testing: 44 | n_samples_fvd: 1000 45 | # for diversity measure 46 | n_samples_per_data_point: 5 47 | test_batch_size: 25 48 | n_samples_vis: 200 49 | n_samples_metrics: 1000 50 | verbose: True 51 | debug: False 52 | div_kp: False 53 | summarize_n_pokes: True 54 | 55 | 56 | 57 | training: 58 | lr: 0.0002 59 | weight_decay: 0.00001 60 | min_acc_batch_size: 3 61 | max_batches_per_epoch: 2000 62 | max_val_batches: 200 63 | profiler: False 64 | n_epochs: 1000 65 | w_kl: .1 #0.0000001 66 | w_l1: 10 67 | w_vgg: 10 68 | val_every: 1000 69 | gamma: 0.98 70 | vgg_1: False 71 | full_sequence: True 72 | kl_annealing: 5 73 | 74 | 75 | architecture: 76 | baseline: True 77 | ENC_M_channels: [64, 128, 256, 256, 256] 78 | decoder_factor: 64 79 | z_dim: 32 80 | norm: 'group' 81 | CN_content: 'spade' 82 | CN_motion: 'ADAIN' 83 | spectral_norm: True 84 | running_stats: False 85 | n_gru_layers: 4 86 | dec_channels: [256,256,256,128,64] 87 | min_spatial_size: 8 88 | stack_motion_and_poke: True 89 | 90 | d_t: 91 | use: True 92 | patch_temp_disc: False 93 | gp_weight: 1. 94 | fmap_weight: 1. 
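The augmentation entries that recur in every data: section (p_col, p_geom, augment_b/c/h/s, aug_deg, aug_trans) are plain color- and geometry-jitter parameters. The repository's actual augmentation code lives in data/base_dataset.py, which is not part of this excerpt, so the following is only a rough sketch of one plausible mapping onto standard torchvision transforms; the function name and the translate ordering are assumptions.

# Rough sketch (not repository code): one plausible mapping of the augmentation
# parameters onto torchvision transforms; data/base_dataset.py may differ.
import torchvision.transforms as tt

def build_augmentation(cfg_data):
    # color jitter, applied with probability p_col
    color = tt.RandomApply(
        [tt.ColorJitter(brightness=cfg_data["augment_b"],
                        contrast=cfg_data["augment_c"],
                        saturation=cfg_data["augment_s"],
                        hue=cfg_data["augment_h"])],
        p=cfg_data["p_col"],
    )
    # geometric jitter, applied with probability p_geom; the configs document
    # aug_trans as (vertical, horizontal), torchvision expects (horizontal, vertical)
    vert, horiz = cfg_data["aug_trans"]
    geom = tt.RandomApply(
        [tt.RandomAffine(degrees=cfg_data["aug_deg"], translate=(horiz, vert))],
        p=cfg_data["p_geom"],
    )
    return tt.Compose([color, geom])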
95 | gen_weight: 1. 96 | bce_loss: False 97 | num_classes: 1 98 | pretrain: 0 99 | max_frames: 12 100 | gp_1: False 101 | 102 | d_s: 103 | use: True 104 | bce_loss: False 105 | gp_weight: 0. 106 | fmap_weight: 1. 107 | pretrain: 0 108 | n_examples: 16 109 | gen_weight: 1. 110 | gp_1: False 111 | 112 | logging: 113 | n_val_img_batches: 4 114 | log_train_prog_at: 300 115 | n_saved_ckpt: 5 116 | n_samples_fvd: 1000 117 | bs_i3d: 8 118 | n_logged_img: 8 -------------------------------------------------------------------------------- /config/data_preparation/human36m.yaml: -------------------------------------------------------------------------------- 1 | # note: raw_dir is not required for the human36m dataset 2 | #raw_dir: '/export/scratch/ablattma/datasets/human36m/videos/*/*' 3 | processed_dir: '/export/scratch/ablattma/datasets/human36m/processed' 4 | rgb_max: 1.0 5 | fp16_scale: 1024.0 6 | flow_delta: 5 7 | flow_max: 10 8 | mode: prepare # should be in [all, extract, prepare] 9 | video_format: mp4 10 | spatial_size: 256 11 | input_size: 1024 12 | frames_discr: 8 13 | target_gpus: [0] 14 | num_workers: 12 15 | 16 | data: 17 | dataset: Human36mDataset 18 | spatial_size: !!python/tuple [64,64] 19 | poke_size: 3 20 | n_pokes: 1 21 | # whether to split by video or randomly (in a reproducible way) 22 | split: official #official 23 | num_workers: 20 24 | excluded_objects: [] # for iper leave empty, for plants, use [8] 25 | # list containing the lags; currently, its indices are allowed to be within the interval [0,5], otherwise an error is thrown 26 | # can be in [image, video] 27 | yield_videos: True 28 | # augmentation parameters for color transformations 29 | p_col: 0.8 30 | p_geom: 0.8 31 | augment_b: 0.4 32 | augment_c: 0.5 33 | augment_h: 0.15 34 | augment_s: 0.4 35 | # augmentation parameters for geometric transformations 36 | aug_deg: 15 # for iper use 0, for plants use 30° 37 | # translation is (vertical, horizontal) 38 | aug_trans: !!python/tuple [0.1,0.1] # for iper use [0.1,0.6], for plants use [0.1,0.1] 39 | # whether to get flow weight 40 | foreground_value: 10. 41 | background_weight: 1.
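A small worked example of how flow_delta and flow_max discretize the optical-flow targets in these preparation configs (the fuller explanation is in the plants.yaml comments further below); the helper function is illustrative only, not repository code.

# Illustrative sketch (not repository code): the flow targets implied by flow_delta / flow_max
# for every source frame t, i.e. the offsets d for which flow_t->t+d is precomputed.
def flow_offsets(flow_delta, flow_max):
    return list(range(flow_delta, flow_max + 1, flow_delta))

print(flow_offsets(5, 10))    # [5, 10]      -> human36m.yaml above and iper.yaml below
print(flow_offsets(10, 30))   # [10, 20, 30] -> matches the flow_0-->10/20/30 example in plants.yaml
print(flow_offsets(10, 20))   # [10, 20]     -> taichi.yaml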
42 | # filter 43 | filter: all 44 | fancy_aug: False 45 | # whether to load flow/images in ram 46 | flow_in_ram: False 47 | imgs_in_ram: False 48 | # sequence parameters 49 | max_frames: 10 50 | augment_wo_dis: True 51 | equal_poke_val: True 52 | subsample_step: 1 53 | flow_weights: False 54 | n_ref_frames: 10 -------------------------------------------------------------------------------- /config/data_preparation/iper.yaml: -------------------------------------------------------------------------------- 1 | raw_dir: '/export/scratch/compvis/dataset/iPER/raw_videos' 2 | processed_dir: '/export/scratch/ablattma/datasets/iPER' 3 | rgb_max: 1.0 4 | fp16_scale: 1024.0 5 | flow_delta: 5 6 | flow_max: 10 7 | mode: pose_estimation # should be in [all, extract, prepare, pose_estimation] 8 | video_format: mp4 9 | spatial_size: 256 10 | input_size: 1024 11 | frames_discr: 1 12 | target_gpus: [0] 13 | num_workers: 1 14 | 15 | 16 | data: 17 | dataset: IperDataset 18 | spatial_size: !!python/tuple [64,64] 19 | poke_size: 3 20 | n_pokes: 1 21 | # whether to split by video or randomly (in a reproducible way) 22 | split: official #official 23 | num_workers: 20 24 | excluded_objects: [] # for iper leave empty, for plants, use [8] 25 | # list containing the lags; currently, its indices are allowed to be within the interval [0,5], otherwise an error is thrown 26 | # can be in [image, video] 27 | yield_videos: True 28 | # augmentation parameters for color transformations 29 | p_col: 0.8 30 | p_geom: 0.8 31 | augment_b: 0.4 32 | augment_c: 0.5 33 | augment_h: 0.15 34 | augment_s: 0.4 35 | # augmentation parameters for geometric transformations 36 | aug_deg: 15 # for iper use 0, for plants use 30° 37 | # translation is (vertical, horizontal) 38 | aug_trans: !!python/tuple [0.1,0.1] # for iper use [0.1,0.6], for plants use [0.1,0.1] 39 | # whether to get flow weight 40 | foreground_value: 10. 41 | background_weight: 1. 42 | # filter 43 | filter: all 44 | fancy_aug: False 45 | # whether to load flow/images in ram 46 | flow_in_ram: False 47 | imgs_in_ram: False 48 | # sequence parameters 49 | max_frames: 10 50 | augment_wo_dis: True 51 | equal_poke_val: True 52 | subsample_step: 1 53 | flow_weights: False 54 | n_ref_frames: 10 -------------------------------------------------------------------------------- /config/data_preparation/plants.yaml: -------------------------------------------------------------------------------- 1 | raw_dir: '/export/scratch/compvis/datasets/plants/poking_plants' 2 | processed_dir: '/export/scratch/ablattma/datasets/plants' 3 | rgb_max: 1.0 # flownet2 parameter --> leave as is 4 | fp16_scale: 1024.0 # flownet2 parameter --> leave as is 5 | flow_max: 30 # maximum number of frames between which optical flow will be estimated 6 | flow_delta: 10 # discretization step for the optical flow estimates; in this example we get 3 flow maps per source frame, e.g. for frame 0 these are flow_0-->10, flow_0-->20 and flow_0-->30 7 | mode: prepare # should be in [all, extract, prepare] 8 | video_format: mkv 9 | spatial_size: 256 # the output size of the images and flow maps 10 | input_size: 1024 # the input size for the flow estimator, i.e. the spatial resolution of the processed videos
11 | frames_discr: 1 12 | target_gpus: [9] # ids of the gpus, among which the processes will be divided 13 | num_workers: 1 # number of parallel-working optical flow estimators to process the data 14 | 15 | 16 | data: 17 | dataset: PlantDataset 18 | spatial_size: !!python/tuple [64,64] 19 | poke_size: 3 20 | n_pokes: 1 21 | # whether to split by video or randomly (in a reproducible way) 22 | split: official #official 23 | num_workers: 20 24 | excluded_objects: [] # for iper leave empty, for plants, use [8] 25 | # list containing the lags; currently, its indices are allowed to be within the interval [0,5], otherwise an error is thrown 26 | # can be in [image, video] 27 | yield_videos: True 28 | # augmentation parameters for color transformations 29 | p_col: 0.8 30 | p_geom: 0.8 31 | augment_b: 0.4 32 | augment_c: 0.5 33 | augment_h: 0.15 34 | augment_s: 0.4 35 | # augmentation parameters for geometric transformations 36 | aug_deg: 15 # for iper use 0, for plants use 30° 37 | # translation is (vertical, horizontal) 38 | aug_trans: !!python/tuple [0.1,0.1] # for iper use [0.1,0.6], for plants use [0.1,0.1] 39 | # whether to get flow weight 40 | foreground_value: 10. 41 | background_weight: 1. 42 | # filter 43 | filter: all 44 | fancy_aug: False 45 | # whether to load flow/images in ram 46 | flow_in_ram: False 47 | imgs_in_ram: False 48 | # sequence parameters 49 | max_frames: 10 50 | augment_wo_dis: True 51 | equal_poke_val: True 52 | subsample_step: 2 53 | flow_weights: False 54 | n_ref_frames: 10 -------------------------------------------------------------------------------- /config/data_preparation/taichi.yaml: -------------------------------------------------------------------------------- 1 | raw_dir: 2 | processed_dir: /export/scratch/compvis/datasets/taichi/taichi 3 | rgb_max: 1.0 4 | fp16_scale: 1024.0 5 | flow_delta: 10 6 | flow_max: 20 7 | mode: prepare # should be in [all, extract, prepare] 8 | video_format: mp4 9 | spatial_size: 256 10 | input_size: 256 11 | frames_discr: 1 12 | target_gpus: [0] 13 | num_workers: 1 14 | 15 | 16 | data: 17 | dataset: TaichiDataset 18 | spatial_size: !!python/tuple [64,64] 19 | poke_size: 3 20 | n_pokes: 1 21 | # whether to split by video or randomly (in a reproducible way) 22 | split: official #official 23 | num_workers: 20 24 | excluded_objects: [] # for iper leave empty, for plants, use [8] 25 | # list containing the lags; currently, its indices are allowed to be within the interval [0,5], otherwise an error is thrown 26 | # can be in [image, video] 27 | yield_videos: True 28 | # augmentation parameters for color transformations 29 | p_col: 0.8 30 | p_geom: 0.8 31 | augment_b: 0.4 32 | augment_c: 0.5 33 | augment_h: 0.15 34 | augment_s: 0.4 35 | # augmentation parameters for geometric transformations 36 | aug_deg: 15 # for iper use 0, for plants use 30° 37 | # translation is (vertical, horizontal) 38 | aug_trans: !!python/tuple [0.1,0.1] # for iper use [0.1,0.6], for plants use [0.1,0.1] 39 | # whether to get flow weight 40 | foreground_value: 10. 41 | background_weight: 1.
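The foreground_value / background_weight pair above (together with the flow_weights flag) points to a per-pixel weighting derived from the optical flow. How the repository actually builds it (data/flow_dataset.py) is not shown in this excerpt, so the snippet below is only a hedged guess at the idea: pixels that move noticeably get the foreground value, static pixels the background weight.

# Hedged sketch (not repository code): a flow-magnitude based weight map in the spirit of
# foreground_value / background_weight; the real construction in data/flow_dataset.py may differ.
import numpy as np

def flow_weight_map(flow, foreground_value=10.0, background_weight=1.0, eps=1e-2):
    # flow: array of shape (2, H, W) with per-pixel displacements
    magnitude = np.linalg.norm(flow, axis=0)
    return np.where(magnitude > eps, foreground_value, background_weight)

dummy_flow = np.zeros((2, 64, 64), dtype=np.float32)
dummy_flow[:, 20:30, 20:30] = 1.5              # a small moving region
print(np.unique(flow_weight_map(dummy_flow)))  # [ 1. 10.]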
42 | # filter 43 | filter: all 44 | fancy_aug: False 45 | # whether to load flow/images in ram 46 | flow_in_ram: False 47 | imgs_in_ram: False 48 | # sequence parameters 49 | max_frames: 10 50 | augment_wo_dis: True 51 | equal_poke_val: True 52 | subsample_step: 2 53 | flow_weights: False 54 | n_ref_frames: 10 -------------------------------------------------------------------------------- /config/first_stage.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | experiment: first_stage 3 | profiler: False 4 | debug: False 5 | base_dir: "logs" 6 | seed: 42 7 | 8 | data: 9 | dataset: IperDataset # supported datasets are IperDataset, PlantDataset, Human36mDataset and TaichiDataset 10 | poke_size: 1 # not important for the first stage 11 | max_frames: 10 # the number of frames to predict 12 | batch_size: 20 13 | n_workers: 20 # data loading workers 14 | yield_videos: True # leave as is 15 | spatial_size: !!python/tuple [128,128] # spatial video size 16 | # data augmentation 17 | augment: True 18 | p_col: .8 19 | p_geom: .8 20 | augment_b: 0.4 21 | augment_c: 0.5 22 | augment_h: 0.15 23 | augment_s: 0.4 24 | aug_deg: 15 25 | aug_trans: !!python/tuple [0.1,0.1] 26 | split: official # see in the file data/flow_dataset.py 27 | filter_flow: False 28 | n_pokes: 1 # not important for the first stage 29 | zero_poke: True # whether or not to train with simulated zero pokes to force the model to learn foreground background separation 30 | zero_poke_amount: 12 #frequency, when zero pokes occur in the training (the amount of zeropokes per epoch is 1 / zero_poke_amount 31 | filter: action # some datasets have special filter procedures, see data/flow_dataset.py 32 | 33 | training: 34 | lr: 0.0002 35 | weight_decay: 0.00001 36 | min_acc_batch_size: 3 37 | max_batches_per_epoch: 2000 38 | max_val_batches: 200 39 | profiler: False 40 | n_epochs: 1000 41 | w_kl: 0.0000001 42 | w_l1: 10 43 | w_vgg: 10 44 | val_every: 1000 45 | gamma: 0.98 46 | vgg_1: False 47 | full_sequence: True 48 | 49 | 50 | architecture: 51 | ENC_M_channels: [64, 128, 256, 256, 256] # for models with spatial video size 64x64 remove the last entry 52 | decoder_factor: 32 53 | z_dim: 32 # number of channels for the video representation on which the invertible model will be trained later on 54 | norm: 'group' 55 | CN_content: 'spade' 56 | CN_motion: 'ADAIN' 57 | spectral_norm: True 58 | running_stats: False 59 | n_gru_layers: 4 # number of hidden layers in the latent GRU 60 | dec_channels: [256,256,256,128,64] # for models with spatial video size 64x64 remove the first entry 61 | min_spatial_size: 8 62 | motion_bias: True 63 | deterministic: False 64 | 65 | d_t: 66 | use: True 67 | patch_temp_disc: False 68 | gp_weight: 1. 69 | fmap_weight: 1. 70 | gen_weight: 1. 71 | bce_loss: False 72 | num_classes: 1 73 | pretrain: 0 74 | max_frames: 12 75 | gp_1: False 76 | 77 | d_s: 78 | use: True 79 | bce_loss: False 80 | gp_weight: 0. 81 | fmap_weight: 1. 82 | pretrain: 0 83 | n_examples: 16 84 | gen_weight: 1.
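The comments on ENC_M_channels and dec_channels above ("for models with spatial video size 64x64 remove the last/first entry") follow from the resolution schedule implied by min_spatial_size: each stage halves the spatial resolution until min_spatial_size is reached, which suggests one channel entry per resolution level. A short illustrative calculation (not repository code):

# Illustrative sketch (not repository code): resolution levels between the video size and
# min_spatial_size. At 128x128 there are five levels, matching the five-entry channel lists;
# at 64x64 there are four, hence one entry is dropped.
def resolution_levels(spatial_size, min_spatial_size=8):
    levels = [spatial_size]
    while levels[-1] > min_spatial_size:
        levels.append(levels[-1] // 2)
    return levels

print(resolution_levels(128))  # [128, 64, 32, 16, 8]
print(resolution_levels(64))   # [64, 32, 16, 8]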
85 | gp_1: False 86 | 87 | logging: 88 | n_val_img_batches: 4 89 | log_train_prog_at: 300 90 | n_saved_ckpt: 5 91 | n_samples_fvd: 1000 92 | bs_i3d: 8 93 | n_logged_img: 8 94 | 95 | testing: 96 | n_samples_fvd: 1000 97 | # for diversity measure 98 | n_samples_per_data_point: 50 99 | test_batch_size: 16 -------------------------------------------------------------------------------- /config/img_encoder.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | experiment: img_encoder 3 | #model_name: taichi-conv-64-bs128-n_cn64-gp1 #sharedmodel-aug-test-bigae-256-dis-logvar-bs96-moreaug 4 | profiler: False 5 | debug: False 6 | base_dir: "logs" 7 | seed: 42 8 | 9 | data: 10 | dataset: TaichiDataset 11 | poke_size: 1 12 | subsample_step: 10 13 | max_frames: 1 14 | batch_size: 128 15 | n_workers: 20 16 | yield_videos: False 17 | spatial_size: !!python/tuple [64,64] 18 | p_col: .8 19 | p_geom: .8 20 | augment_b: 0.4 21 | augment_c: 0.5 22 | augment_h: 0.15 23 | augment_s: 0.4 24 | aug_deg: 15 # for iper use 0, for plants use 30° 25 | # translation is (vertical, horizontal) 26 | aug_trans: !!python/tuple [0.1,0.1] 27 | split: official 28 | flow_weights: False 29 | filter_flow: False 30 | augment: True 31 | n_pokes: 1 32 | # only for faster data loading 33 | normalize_flows: True 34 | 35 | 36 | training: 37 | lr: 2.0e-4 38 | weight_decay: 0 39 | min_acc_batch_size: 3 40 | # max_batches_per_epoch: 2000 41 | max_val_batches: 200 42 | profiler: True 43 | n_epochs: 20 44 | pretrain: 2 45 | w_kl: 1.0e-6 46 | val_every: 1. 47 | forward_sample: True 48 | gp_weight: 1. 49 | 50 | 51 | architecture: 52 | conv: True 53 | nf_in: 3 54 | nf_max: 64 55 | min_spatial_size: 8 56 | deterministic: True 57 | 58 | 59 | logging: 60 | n_val_img_batches: 4 61 | log_train_prog_at: 300 62 | n_saved_ckpt: 5 63 | n_log_images: 8 64 | 65 | testing: 66 | seed: 42 67 | -------------------------------------------------------------------------------- /config/model_names.txt: -------------------------------------------------------------------------------- 1 | iper_128 2 | plants_128 3 | h36m_128 4 | taichi_128 5 | iper_64 6 | plants_64 7 | h36m_64 8 | taichi_64 -------------------------------------------------------------------------------- /config/poke_encoder.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | experiment: poke_encoder 3 | #model_name: h36m-conv-64-bs128-n_cn8x8x64-endpoint10frames-flow_ae #sharedmodel-aug-test-bigae-256-dis-logvar-bs96-moreaug 4 | profiler: False 5 | debug: False 6 | base_dir: "logs" 7 | seed: 42 8 | 9 | data: 10 | dataset: Human36mDataset 11 | poke_size: 5 12 | #subsample_step: 2 13 | max_frames: 10 14 | batch_size: 64 15 | n_workers: 20 16 | yield_videos: False 17 | spatial_size: !!python/tuple [128,128] 18 | p_col: 0.8 19 | p_geom: 0.8 20 | augment_b: 0.4 21 | augment_c: 0.5 22 | augment_h: 0.15 23 | augment_s: 0.4 24 | aug_deg: 15 # for iper use 0, for plants use 30° 25 | # translation is (vertical, horizontal) 26 | aug_trans: !!python/tuple [0.1,0.1] 27 | split: official 28 | flow_weights: False 29 | augment: True 30 | n_pokes: 5 31 | scale_poke_to_res: True 32 | zero_poke_amount: 12 33 | zero_poke: True 34 | normalize_flows: False 35 | #valid_lags: 0 36 | 37 | 38 | training: 39 | lr: 0.001 40 | weight_decay: 0 41 | min_acc_batch_size: 3 42 | max_batches_per_epoch: 2000 43 | max_val_batches: 200 44 | profiler: False 45 | n_epochs: 20 46 | w_kl: 1 47 | val_every: 1. 
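The encoder configs above (img_encoder.yaml and poke_encoder.yaml) only fix nf_in, nf_max and min_spatial_size; the 8x8x64 bottleneck encoded in the commented-out model_name (n_cn8x8x64) and in the pretrained poke-embedder names further below (bn8x8x64) appears to follow directly from those values. A hedged sketch of that relationship; the actual encoder architecture presumably lives under models/modules/autoencoders (e.g. fully_conv_models.py), which is not shown here.

# Hedged sketch (not repository code): the bottleneck shape suggested by the encoder configs.
# nf_in is 3 for the image encoder (RGB input) and 2 for the poke encoder (a two-channel
# poke/flow map); the spatial resolution is halved until min_spatial_size, with up to nf_max channels.
def bottleneck_shape(spatial_size, min_spatial_size=8, nf_max=64):
    side = spatial_size
    while side > min_spatial_size:
        side //= 2
    return (nf_max, side, side)

print(bottleneck_shape(64))   # (64, 8, 8) -> the 8x8x64 bottleneck in the model names
print(bottleneck_shape(128))  # (64, 8, 8) -> 128x128 inputs also end up at 8x8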
48 | forward_sample: True 49 | 50 | testing: 51 | n_samples_fvd: 1000 52 | # for diversity measure 53 | n_samples_per_data_point: 50 54 | test_batch_size: 16 55 | 56 | 57 | architecture: 58 | conv: True 59 | nf_in: 2 60 | nf_max: 64 61 | min_spatial_size: 8 62 | deterministic: True 63 | flow_ae: True 64 | poke_and_image: False 65 | 66 | 67 | logging: 68 | n_val_img_batches: 4 69 | log_train_prog_at: 300 70 | n_saved_ckpt: 5 71 | n_log_images: 8 72 | 73 | -------------------------------------------------------------------------------- /config/posenet.yaml: -------------------------------------------------------------------------------- 1 | 2 | 3 | MODEL: 4 | NAME: 'pose_resnet' 5 | PRETRAINED: 'models/pytorch/imagenet/resnet152-b121ed2d.pth' 6 | IMAGE_SIZE: 7 | - 256 8 | - 256 9 | HEATMAP_SIZE: 10 | - 64 11 | - 64 12 | SIGMA: 2 13 | NUM_JOINTS: 16 14 | TARGET_TYPE: 'gaussian' 15 | EXTRA: 16 | FINAL_CONV_KERNEL: 1 17 | DECONV_WITH_BIAS: false 18 | NUM_DECONV_LAYERS: 3 19 | NUM_DECONV_FILTERS: 20 | - 256 21 | - 256 22 | - 256 23 | NUM_DECONV_KERNELS: 24 | - 4 25 | - 4 26 | - 4 27 | NUM_LAYERS: 152 28 | 29 | TEST: 30 | BATCH_SIZE_PER_GPU: 128 31 | COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json' 32 | BBOX_THRE: 1.0 33 | IMAGE_THRE: 0.0 34 | IN_VIS_THRE: 0.2 35 | MODEL_FILE: '' 36 | NMS_THRE: 1.0 37 | OKS_THRE: 0.9 38 | FLIP_TEST: true 39 | POST_PROCESS: true 40 | SHIFT_HEATMAP: true 41 | USE_GT_BBOX: true 42 | -------------------------------------------------------------------------------- /config/pretrained_models/human36m_128.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 32 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: h36m-ss128-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | batch_size: 20 70 | dataset: Human36mDataset 71 | flow_weights: false 72 | max_frames: 10 73 | n_pokes: 1 74 | n_workers: 26 75 | normalize_flows: false 76 | object_weighting: false 77 | p_col: 0.8 78 | p_geom: 0.8 79 | poke_size: 5 80 | scale_poke_to_res: true 81 | spatial_size: !!python/tuple 82 | - 128 83 | - 128 84 | split: official 85 | val_obj_weighting: false 86 | yield_videos: true 87 | zero_poke: true 88 | zero_poke_amount: 12 89 | first_stage: 90 | name: h36m-ss128-bn64-mf10 91 | general: 92 | base_dir: logs 93 | debug: false 94 | experiment: second_stage 95 | profiler: false 96 | seed: 42 97 | logging: 98 | log_train_prog_at: 200 99 | n_fvd_samples: 1000 100 | n_log_images: 8 101 | n_samples: 4 102 | n_samples_umap: 1000 103 | n_saved_ckpt: 5 104 
| n_val_img_batches: 3 105 | poke_embedder: 106 | name: h36m-ss128-bn64-endpoint10f-np5 107 | use: true 108 | testing: 109 | n_samples_fvd: 1000 110 | n_samples_metrics: 2000 111 | n_samples_per_data_point: 60 112 | n_samples_vis: 200 113 | test_batch_size: 25 114 | verbose: true 115 | training: 116 | clip_grad_norm: 0.0 117 | custom_lr_decrease: true 118 | full_seq: true 119 | lr: 0.001 120 | lr_scaling: true 121 | lr_scaling_max_it: 500 122 | max_batches_per_epoch: 2000 123 | max_val_batches: 100 124 | min_acc_batch_size: 3 125 | mixed_prec: false 126 | n_epochs: 100 127 | spatial_mean: false 128 | use_adabelief: false 129 | use_logp_loss: false 130 | val_every: 0.5 131 | weight_decay: 1.0e-05 132 | ui: 133 | debug: false 134 | display_size: 128 135 | fixed_length: true 136 | fixed_seed: false 137 | fps: 5 138 | ids: [] 139 | interactive: false 140 | seq_length_to_generate: 10 -------------------------------------------------------------------------------- /config/pretrained_models/human36m_64.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 32 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: h36m-ss64-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | batch_size: 20 70 | dataset: Human36mDataset 71 | filter: all 72 | flow_weights: false 73 | max_frames: 10 74 | n_pokes: 1 75 | n_workers: 26 76 | normalize_flows: false 77 | object_weighting: true 78 | p_col: 0.8 79 | p_geom: 0.8 80 | poke_size: 5 81 | scale_poke_to_res: true 82 | spatial_size: !!python/tuple 83 | - 64 84 | - 64 85 | split: official 86 | val_obj_weighting: false 87 | yield_videos: true 88 | zero_poke: true 89 | zero_poke_amount: 12 90 | first_stage: 91 | name: h36m-ss64-bn64-mf10 92 | general: 93 | base_dir: logs 94 | debug: false 95 | experiment: second_stage 96 | profiler: false 97 | model_name: h36m_64 98 | seed: 42 99 | logging: 100 | log_train_prog_at: 200 101 | n_fvd_samples: 1000 102 | n_log_images: 8 103 | n_samples: 4 104 | n_samples_umap: 1000 105 | n_saved_ckpt: 5 106 | n_val_img_batches: 3 107 | poke_embedder: 108 | name: h36m-ss64-bn8x8x64-endpoint10f-np5 109 | use: true 110 | testing: 111 | debug: false 112 | n_samples_fvd: 1000 113 | n_samples_metrics: 1000 114 | n_samples_per_data_point: 50 115 | n_samples_vis: 200 116 | test_batch_size: 25 117 | verbose: true 118 | training: 119 | clip_grad_norm: 0.0 120 | custom_lr_decrease: true 121 | full_seq: true 122 | lr: 0.001 123 | lr_scaling: true 124 | lr_scaling_max_it: 500 125 | max_batches_per_epoch: 2000 126 | max_val_batches: 100 127 | 
min_acc_batch_size: 3 128 | mixed_prec: false 129 | n_epochs: 100 130 | spatial_mean: false 131 | use_adabelief: false 132 | use_logp_loss: false 133 | val_every: 0.5 134 | weight_decay: 1.0e-05 135 | ui: 136 | debug: false 137 | display_size: 256 138 | fixed_length: true 139 | fixed_seed: false 140 | fps: 5 141 | ids: [] 142 | interactive: false 143 | n_gt_pokes: 5 144 | model_name: iper-16_10d1-bs96-lr1e-3-bn128-fullseq-ss128-mf10-endpoint-np5-mweight 145 | seq_length_to_generate: 10 -------------------------------------------------------------------------------- /config/pretrained_models/iper_128.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 64 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: iper-ss128-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | batch_size: 40 70 | dataset: IperDataset 71 | filter: all 72 | flow_weights: false 73 | max_frames: 10 74 | n_pokes: 5 75 | n_workers: 26 76 | normalize_flows: false 77 | object_weighting: false 78 | p_col: 0.8 79 | p_geom: 0.8 80 | poke_size: 5 81 | scale_poke_to_res: true 82 | spatial_size: !!python/tuple 83 | - 128 84 | - 128 85 | split: official 86 | val_obj_weighting: false 87 | yield_videos: true 88 | zero_poke: true 89 | zero_poke_amount: 12 90 | first_stage: 91 | name: iper-ss128-bn32-mf10-complex 92 | general: 93 | base_dir: logs 94 | debug: false 95 | experiment: second_stage 96 | profiler: false 97 | seed: 42 98 | logging: 99 | log_train_prog_at: 200 100 | n_fvd_samples: 1000 101 | n_log_images: 8 102 | n_samples: 4 103 | n_samples_umap: 1000 104 | n_saved_ckpt: 5 105 | n_val_img_batches: 3 106 | poke_embedder: 107 | name: iper-ss128-bn64-endpoint10f-np5 108 | use: true 109 | testing: 110 | debug: false 111 | div_kp: false 112 | n_samples_fvd: 1000 113 | n_samples_metrics: 1000 114 | n_samples_per_data_point: 100 115 | n_samples_vis: 100 116 | n_test_pokes: 1 117 | seed: 31 118 | summarize_n_pokes: true 119 | test_batch_size: 10 120 | verbose: true 121 | training: 122 | clip_grad_norm: 0.0 123 | custom_lr_decrease: true 124 | full_seq: true 125 | lr: 0.001 126 | lr_scaling: true 127 | lr_scaling_max_it: 500 128 | max_batches_per_epoch: 2000 129 | max_val_batches: 100 130 | min_acc_batch_size: 3 131 | mixed_prec: false 132 | n_epochs: 100 133 | spatial_mean: false 134 | use_adabelief: false 135 | use_logp_loss: false 136 | val_every: 0.5 137 | weight_decay: 1.0e-05 138 | ui: 139 | debug: false 140 | display_size: 256 141 | fixed_length: true 142 | fixed_seed: false 143 | fps: 5 144 | ids: [] 145 | interactive: false 
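The files under config/pretrained_models/ (such as iper_128.yaml above) appear to be the frozen training configs of the released checkpoints; config/model_names.txt lists the keys under which they are addressed, and the comments in config/second_stage.yaml below note that the checkpoint lookup itself lives in models/pretrained_models.py, which is not part of this excerpt. Because these files contain !!python/tuple tags, yaml.safe_load refuses them; a minimal sketch of inspecting one with PyYAML:

# Minimal sketch (not repository code): inspecting a pretrained-model config. The
# !!python/tuple entries (spatial_size, aug_trans) are rejected by yaml.safe_load(),
# so an unsafe loader is used here.
import yaml

with open("config/pretrained_models/iper_128.yaml") as f:
    cfg = yaml.load(f, Loader=yaml.UnsafeLoader)

print(cfg["data"]["spatial_size"])   # (128, 128)
print(cfg["first_stage"]["name"])    # iper-ss128-bn32-mf10-complex
print(cfg["conditioner"]["name"])    # iper-ss128-bn64
print(cfg["poke_embedder"]["name"])  # iper-ss128-bn64-endpoint10f-np5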
146 | n_gt_pokes: 5 147 | model_name: iper-16_10d1-bs96-lr1e-3-bn32-fmcf64-fullseq-ss128-mf10-endpoint-np5-complex 148 | save_fps: 3 -------------------------------------------------------------------------------- /config/pretrained_models/iper_64.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 64 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: iper-ss64-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | batch_size: 40 70 | dataset: IperDataset 71 | flow_weights: false 72 | max_frames: 10 73 | n_pokes: 5 74 | n_workers: 26 75 | normalize_flows: false 76 | object_weighting: true 77 | p_col: 0.8 78 | p_geom: 0.8 79 | poke_size: 5 80 | scale_poke_to_res: true 81 | spatial_size: !!python/tuple 82 | - 64 83 | - 64 84 | split: official 85 | val_obj_weighting: false 86 | yield_videos: true 87 | zero_poke: true 88 | zero_poke_amount: 12 89 | first_stage: 90 | name: iper-ss64-bn32-mf10 91 | general: 92 | base_dir: logs 93 | debug: false 94 | experiment: second_stage 95 | profiler: false 96 | seed: 42 97 | logging: 98 | log_train_prog_at: 200 99 | n_fvd_samples: 1000 100 | n_log_images: 8 101 | n_samples: 4 102 | n_samples_umap: 1000 103 | n_saved_ckpt: 5 104 | n_val_img_batches: 3 105 | poke_embedder: 106 | name: iper-ss64-bn8x8x64-endpoint10f-np5 107 | use: true 108 | testing: 109 | debug: false 110 | n_samples_fvd: 1000 111 | n_samples_metrics: 100 112 | n_samples_per_data_point: 5 113 | n_samples_vis: 200 114 | test_batch_size: 25 115 | verbose: true 116 | training: 117 | clip_grad_norm: 0.0 118 | custom_lr_decrease: true 119 | full_seq: true 120 | lr: 0.001 121 | lr_scaling: true 122 | lr_scaling_max_it: 500 123 | max_batches_per_epoch: 2000 124 | max_val_batches: 100 125 | min_acc_batch_size: 3 126 | mixed_prec: false 127 | n_epochs: 100 128 | spatial_mean: false 129 | use_adabelief: false 130 | use_logp_loss: false 131 | val_every: 0.5 132 | weight_decay: 1.0e-05 133 | ui: 134 | debug: false 135 | display_size: 256 136 | fixed_length: true 137 | fixed_seed: false 138 | fps: 5 139 | ids: [] 140 | interactive: false 141 | model_name: plants-16_10d1-bs20-lr1e-3-bn64-fullseq-mfc32-ss128-mf10-endpoint-np5 142 | seq_length_to_generate: 10 -------------------------------------------------------------------------------- /config/pretrained_models/plants_128.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | 
cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 32 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: plants-ss128-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | batch_size: 20 70 | dataset: PlantDataset 71 | flow_weights: false 72 | max_frames: 10 73 | n_pokes: 5 74 | n_workers: 26 75 | normalize_flows: false 76 | object_weighting: true 77 | p_col: 0.8 78 | p_geom: 0.8 79 | poke_size: 5 80 | scale_poke_to_res: true 81 | spatial_size: !!python/tuple 82 | - 128 83 | - 128 84 | split: official 85 | val_obj_weighting: false 86 | yield_videos: true 87 | zero_poke: true 88 | zero_poke_amount: 12 89 | first_stage: 90 | name: plants-ss128-bn64-mf10 91 | general: 92 | base_dir: logs 93 | debug: false 94 | experiment: second_stage 95 | profiler: false 96 | seed: 42 97 | logging: 98 | log_train_prog_at: 200 99 | n_fvd_samples: 1000 100 | n_log_images: 8 101 | n_samples: 4 102 | n_samples_umap: 1000 103 | n_saved_ckpt: 5 104 | n_val_img_batches: 3 105 | poke_embedder: 106 | name: plants-ss128-bn64-endpoint10f-np5 107 | use: true 108 | testing: 109 | n_samples_fvd: 1000 110 | n_samples_metrics: 2000 111 | n_samples_per_data_point: 60 112 | n_samples_vis: 200 113 | test_batch_size: 25 114 | verbose: true 115 | training: 116 | clip_grad_norm: 0.0 117 | custom_lr_decrease: true 118 | full_seq: true 119 | lr: 0.001 120 | lr_scaling: true 121 | lr_scaling_max_it: 500 122 | max_batches_per_epoch: 2000 123 | max_val_batches: 100 124 | min_acc_batch_size: 3 125 | mixed_prec: false 126 | n_epochs: 100 127 | spatial_mean: false 128 | use_adabelief: false 129 | use_logp_loss: false 130 | val_every: 0.5 131 | weight_decay: 1.0e-05 132 | ui: 133 | debug: false 134 | display_size: 256 135 | fixed_length: true 136 | fixed_seed: false 137 | fps: 5 138 | ids: [] 139 | interactive: false 140 | model_name: iper-16_10d1-bs96-lr1e-3-bn128-fullseq-ss128-mf10-endpoint-np5-mweight 141 | seq_length_to_generate: 10 -------------------------------------------------------------------------------- /config/pretrained_models/plants_64.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 64 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 
2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: plants-ss64-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | batch_size: 40 70 | dataset: PlantDataset 71 | flow_weights: false 72 | max_frames: 10 73 | n_pokes: 5 74 | n_workers: 26 75 | normalize_flows: false 76 | object_weighting: false 77 | p_col: 0.8 78 | p_geom: 0.8 79 | poke_size: 5 80 | scale_poke_to_res: true 81 | spatial_size: !!python/tuple 82 | - 64 83 | - 64 84 | split: official 85 | val_obj_weighting: false 86 | yield_videos: true 87 | zero_poke: true 88 | zero_poke_amount: 12 89 | first_stage: 90 | name: plants-ss64-bn32-mf10 91 | general: 92 | base_dir: logs 93 | debug: false 94 | experiment: second_stage 95 | profiler: false 96 | seed: 42 97 | logging: 98 | log_train_prog_at: 200 99 | n_fvd_samples: 1000 100 | n_log_images: 8 101 | n_samples: 4 102 | n_samples_umap: 1000 103 | n_saved_ckpt: 5 104 | n_val_img_batches: 3 105 | poke_embedder: 106 | name: plants-ss64-bn8x8x64-endpoint10f-np5 107 | use: true 108 | testing: 109 | debug: false 110 | n_samples_fvd: 1000 111 | n_samples_metrics: 100 112 | n_samples_per_data_point: 5 113 | n_samples_vis: 200 114 | test_batch_size: 25 115 | verbose: true 116 | training: 117 | clip_grad_norm: 0.0 118 | custom_lr_decrease: true 119 | full_seq: true 120 | lr: 0.001 121 | lr_scaling: true 122 | lr_scaling_max_it: 500 123 | max_batches_per_epoch: 2000 124 | max_val_batches: 100 125 | min_acc_batch_size: 3 126 | mixed_prec: false 127 | n_epochs: 100 128 | spatial_mean: false 129 | use_adabelief: false 130 | use_logp_loss: false 131 | val_every: 0.5 132 | weight_decay: 1.0e-05 133 | ui: 134 | debug: false 135 | display_size: 256 136 | fixed_length: true 137 | fixed_seed: false 138 | fps: 5 139 | ids: [] 140 | interactive: false 141 | model_name: plants-16_10d1-bs20-lr1e-3-bn64-fullseq-mfc32-ss128-mf10-endpoint-np5 142 | seq_length_to_generate: 10 -------------------------------------------------------------------------------- /config/pretrained_models/taichi_128.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 64 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: taichi-ss128-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | 
batch_size: 40 70 | dataset: TaichiDataset 71 | flow_weights: false 72 | max_frames: 10 73 | n_pokes: 5 74 | n_workers: 26 75 | normalize_flows: false 76 | object_weighting: false 77 | p_col: 0.8 78 | p_geom: 0.8 79 | poke_size: 5 80 | scale_poke_to_res: true 81 | spatial_size: !!python/tuple 82 | - 128 83 | - 128 84 | split: official 85 | val_obj_weighting: false 86 | yield_videos: true 87 | zero_poke: true 88 | zero_poke_amount: 12 89 | first_stage: 90 | name: taichi-ss128-bn32-mf10 91 | general: 92 | base_dir: logs 93 | debug: false 94 | experiment: second_stage 95 | profiler: false 96 | seed: 42 97 | logging: 98 | log_train_prog_at: 200 99 | n_fvd_samples: 1000 100 | n_log_images: 8 101 | n_samples: 4 102 | n_samples_umap: 1000 103 | n_saved_ckpt: 5 104 | n_val_img_batches: 3 105 | poke_embedder: 106 | name: taichi-ss128-bn8x8x64-endpoint10f-np5 107 | use: true 108 | testing: 109 | n_samples_fvd: 1000 110 | n_samples_metrics: 1000 111 | n_samples_per_data_point: 50 112 | n_samples_vis: 200 113 | test_batch_size: 25 114 | verbose: true 115 | training: 116 | clip_grad_norm: 0.0 117 | custom_lr_decrease: true 118 | full_seq: true 119 | lr: 0.001 120 | lr_scaling: true 121 | lr_scaling_max_it: 500 122 | max_batches_per_epoch: 2000 123 | max_val_batches: 100 124 | min_acc_batch_size: 3 125 | mixed_prec: false 126 | n_epochs: 100 127 | spatial_mean: false 128 | use_adabelief: false 129 | use_logp_loss: false 130 | val_every: 0.5 131 | weight_decay: 1.0e-05 132 | ui: 133 | debug: false 134 | display_size: 256 135 | fixed_length: true 136 | fixed_seed: false 137 | fps: 5 138 | ids: [] 139 | interactive: false 140 | model_name: plants-16_10d1-f64-bs40lr1e-3-bn64-fullseq-ss128-mf10-endpoint-np5-poke_scale 141 | seq_length_to_generate: 10 -------------------------------------------------------------------------------- /config/pretrained_models/taichi_64.yaml: -------------------------------------------------------------------------------- 1 | architecture: 2 | activation: elu 3 | attention: false 4 | augment_channels: 32 5 | augmented_input: false 6 | cond_conv: false 7 | cond_conv_hidden_channels: 256 8 | condition_nice: false 9 | coupling_type: conv 10 | factor: 16 11 | factors: 12 | - 4 13 | - 8 14 | flow_attn_heads: 4 15 | flow_mid_channels_factor: 64 16 | kernel_size: 17 | - 2 18 | - 3 19 | levels: 20 | - - 15 21 | - 10 22 | - 5 23 | - - 15 24 | - 12 25 | - 10 26 | - 8 27 | - 6 28 | - 4 29 | - 2 30 | multistack: false 31 | n_blocks: 2 32 | n_flows: 20 33 | num_steps: 34 | - 10 35 | - 5 36 | - 5 37 | - 4 38 | - 4 39 | - 4 40 | - 3 41 | - 3 42 | - 3 43 | - 2 44 | - 2 45 | - 2 46 | - 1 47 | - 1 48 | - 1 49 | p_dropout: 0.0 50 | prior_transform: affine 51 | reshape: none 52 | scale: false 53 | scale_augmentation: true 54 | shift_augmentation: true 55 | transform: affine 56 | conditioner: 57 | name: taichi-ss64-bn64 58 | use: true 59 | data: 60 | aug_deg: 15 61 | aug_trans: !!python/tuple 62 | - 0.1 63 | - 0.1 64 | augment_b: 0.4 65 | augment_c: 0.5 66 | augment_h: 0.15 67 | augment_s: 0.4 68 | augment_wo_dis: true 69 | batch_size: 40 70 | dataset: TaichiDataset 71 | flow_weights: false 72 | max_frames: 10 73 | n_pokes: 5 74 | n_workers: 26 75 | normalize_flows: false 76 | object_weighting: false 77 | p_col: 0.8 78 | p_geom: 0.8 79 | poke_size: 5 80 | scale_poke_to_res: true 81 | spatial_size: !!python/tuple 82 | - 64 83 | - 64 84 | split: official 85 | val_obj_weighting: false 86 | yield_videos: true 87 | zero_poke: true 88 | zero_poke_amount: 12 89 | first_stage: 90 | name: 
taichi-ss64-bn32-mf10 91 | general: 92 | base_dir: logs 93 | debug: false 94 | experiment: second_stage 95 | profiler: false 96 | seed: 42 97 | logging: 98 | log_train_prog_at: 200 99 | n_fvd_samples: 1000 100 | n_log_images: 8 101 | n_samples: 4 102 | n_samples_umap: 1000 103 | n_saved_ckpt: 5 104 | n_val_img_batches: 3 105 | poke_embedder: 106 | name: taichi-ss64-bn8x8x64-endpoint10f-np5 107 | use: true 108 | testing: 109 | debug: true 110 | n_samples_fvd: 1000 111 | n_samples_metrics: 50 112 | n_samples_per_data_point: 4 113 | n_samples_vis: 200 114 | test_batch_size: 25 115 | verbose: true 116 | training: 117 | clip_grad_norm: 0.0 118 | custom_lr_decrease: true 119 | full_seq: true 120 | lr: 0.001 121 | lr_scaling: true 122 | lr_scaling_max_it: 500 123 | max_batches_per_epoch: 2000 124 | max_val_batches: 100 125 | min_acc_batch_size: 3 126 | mixed_prec: false 127 | n_epochs: 100 128 | spatial_mean: false 129 | use_adabelief: false 130 | use_logp_loss: false 131 | val_every: 0.5 132 | weight_decay: 1.0e-05 133 | ui: 134 | debug: false 135 | display_size: 256 136 | fixed_length: true 137 | fixed_seed: false 138 | fps: 5 139 | ids: [] 140 | interactive: false 141 | model_name: iper-16_10d1-bs96-lr1e-3-bn128-fullseq-ss128-mf10-endpoint-np5-mweight 142 | seq_length_to_generate: 10 -------------------------------------------------------------------------------- /config/second_stage.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | experiment: second_stage 3 | profiler: False 4 | debug: False 5 | base_dir: "logs" 6 | seed: 42 7 | 8 | # name of the pretrained first stage model (video autoencoding framework), the pretrained models are listed in the respective dictionary in models/pretrained_models.py . 9 | # When having trained your own model, you can add it under a new key and specifiy its model name and the path to the checkpoint which shall be used 10 | first_stage: 11 | name: iper-ss128-bn32-mf10-complex 12 | 13 | # name of the pretrained image encoder model (\phi_{x_0} in the paper), the pretrained models are listed in the respective dictionary in models/pretrained_models.py . 14 | # When having trained your own model, you can add it under a new key and specifiy its model name and the path to the checkpoint which shall be used 15 | conditioner: 16 | use: True 17 | name: iper-ss128-bn64 18 | 19 | # name of the pretrained poke embedder model (\phi_{c} in the paper), the pretrained models are listed in the respective dictionary in models/pretrained_models.py. 20 | # When having trained your own model, you can add it under a new key and specifiy its model name and the path to the checkpoint which shall be used 21 | poke_embedder: 22 | use: True 23 | name: iper-ss128-bn64-endpoint10f-np5 24 | # 25 | 26 | data: 27 | dataset: IperDataset # supported datasets are IperDataset, PlantDataset, Human3.6mDataset and TaichiDataset 28 | # a window will be used for the poke with from (x-poke_size/2:x+poke_size/2,y-poke_size/2:y+poke_size/2) where (x,y) is the poked location. 
29 | #All the values therein will be initialized with the respective pixel shift value 30 | poke_size: 5 31 | max_frames: 10 # number of predicted frames 32 | batch_size: 40 33 | n_workers: 26 34 | yield_videos: True # leave as is 35 | spatial_size: !!python/tuple [128,128] # spatial video resolution 36 | # data augmentation 37 | augment: True 38 | p_col: 0.8 39 | p_geom: 0.8 40 | augment_b: 0.4 41 | augment_c: 0.5 42 | augment_h: 0.15 43 | augment_s: 0.4 44 | aug_deg: 15 45 | aug_trans: !!python/tuple [0.1,0.1] 46 | split: official # see in the file data/flow_dataset.py 47 | n_pokes: 5 # the maximum number of pokes for a given training example. The actual number will be randomly chosen from within [1,n_pokes] 48 | zero_poke: True # whether or not to train with simulated zero pokes to force the model to learn foreground background separation 49 | zero_poke_amount: 12 #frequency, when zero pokes occur in the training (the amount of zeropokes per epoch is 1 / zero_poke_amount 50 | scale_poke_to_res: True # whether or not to scale the flow magnitudes according to the spatial downsampling of the videos 51 | filter: all # some datasets have special filter procedures, see data/flow_dataset.py 52 | 53 | architecture: 54 | attention: False 55 | n_blocks: 2 56 | flow_mid_channels_factor: 64 57 | flow_attn_heads: 4 58 | kernel_size: [2,3] 59 | coupling_type: "conv" 60 | scale: False 61 | n_flows: 20 62 | num_steps: [10,5,5,4,4,4,3,3,3,2,2,2,1,1,1] 63 | factor: 16 64 | levels: [[15,10, 5], [15, 12, 10, 8, 6, 4, 2]] 65 | factors: [4,8] 66 | activation: "elu" 67 | transform: "affine" 68 | prior_transform: "affine" 69 | condition_nice: False 70 | augmented_input: False 71 | augment_channels: 32 72 | scale_augmentation: True 73 | shift_augmentation: True 74 | multistack: False 75 | cond_conv: False 76 | cond_conv_hidden_channels: 256 77 | reshape: none 78 | p_dropout: 0. 79 | 80 | testing: 81 | n_samples_fvd: 1000 82 | # for diversity measure 83 | n_samples_per_data_point: 5 84 | test_batch_size: 20 85 | n_samples_vis: 100 86 | n_samples_metrics: 1000 87 | verbose: True 88 | debug: False 89 | div_kp: False 90 | summarize_n_pokes: False 91 | n_test_pokes: 1 92 | seed: 42 93 | n_control_sensitivity_pokes: 32 94 | 95 | 96 | training: 97 | lr: 1.0e-3 98 | weight_decay: 1.0e-5 99 | min_acc_batch_size: 3 100 | max_batches_per_epoch: 2000 101 | max_val_batches: 100 102 | use_logp_loss: False 103 | n_epochs: 100 104 | val_every: .5 105 | clip_grad_norm: 0. 
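The poke_size comment near the top of this file describes the poke as a window around the poked pixel that is filled with the corresponding pixel-shift (flow) values. The repository builds these pokes in data/flow_dataset.py, which is not part of this excerpt; the snippet below is only a hedged sketch of that description applied to a dense flow field.

# Hedged sketch (not repository code): a two-channel poke map that is zero everywhere except
# for a poke_size window around the poked location (x, y), where it copies the flow values,
# following the poke_size comment above. data/flow_dataset.py may construct it differently.
import numpy as np

def make_poke(flow, x, y, poke_size=5):
    # flow: (2, H, W) dense optical flow; (x, y): poked location
    poke = np.zeros_like(flow)
    half = poke_size // 2
    window = (slice(None), slice(x - half, x + half + 1), slice(y - half, y + half + 1))
    poke[window] = flow[window]
    return poke

flow = np.random.randn(2, 128, 128).astype(np.float32)
poke = make_poke(flow, x=60, y=40)
print(np.count_nonzero(poke))  # 50 = 2 channels * 25 pixels inside the 5x5 window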
106 | lr_scaling: True 107 | lr_scaling_max_it: 500 108 | custom_lr_decrease: True 109 | mixed_prec: False 110 | full_seq: True 111 | spatial_mean: False 112 | use_adabelief: False 113 | # logdet_weight: .5 114 | 115 | 116 | logging: 117 | n_val_img_batches: 3 118 | log_train_prog_at: 200 119 | n_saved_ckpt: 5 120 | n_log_images: 8 121 | n_samples: 4 122 | n_samples_umap: 1000 123 | n_fvd_samples: 1000 124 | 125 | 126 | ui: 127 | display_size: 256 128 | debug: False 129 | fixed_length: True 130 | #seq_length_to_generate: 10 131 | fps: 5 132 | save_fps: 3 133 | fixed_seed: False 134 | interactive: False 135 | ids: [] 136 | n_gt_pokes: 5 137 | -------------------------------------------------------------------------------- /config/test_config.yaml: -------------------------------------------------------------------------------- 1 | general: 2 | experiment: dummy 3 | model_name: test 4 | 5 | 6 | poke_coords: False 7 | max_samples: 100 8 | fix_seed: True 9 | n_logged: 6 10 | gpu: 5 11 | n_exmpls_pose_metric: 15 12 | nn_computation: False 13 | 14 | training: 15 | lr: 0.0001 16 | 17 | data: 18 | dataset: IperDataset 19 | poke_size: 5 20 | # subsample_step: 1 21 | max_frames: 10 22 | batch_size: 6 23 | n_workers: 20 24 | yield_videos: True 25 | spatial_size: !!python/tuple [128,128] 26 | p_col: 0.8 27 | p_geom: 0.8 28 | augment_b: 0.4 29 | augment_c: 0.5 30 | augment_h: 0.15 31 | augment_s: 0.4 32 | aug_deg: 15 # for iper use 0, for plants use 30° 33 | # translation is (vertical, horizontal) 34 | aug_trans: !!python/tuple [0.1,0.1] 35 | split: official 36 | flow_weights: False 37 | augment_wo_dis: True 38 | n_pokes: 1 39 | zero_poke: False 40 | zero_poke_amount: 8 41 | normalize_flows: False 42 | object_weighting: False 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset 2 | from torchvision import transforms as tt 3 | from data.flow_dataset import PlantDataset, IperDataset,Human36mDataset, TaichiDataset 4 | from data.samplers import SequenceSampler,FixedLengthSampler,SequenceLengthSampler 5 | 6 | # add key value pair for datasets here, all datasets should inherit from base_dataset 7 | __datasets__ = {"IperDataset": IperDataset, 8 | "PlantDataset": PlantDataset, 9 | "Human36mDataset": Human36mDataset, 10 | "TaichiDataset": TaichiDataset,} 11 | 12 | __samplers__ = {"fixed_length": FixedLengthSampler, 13 | } 14 | 15 | 16 | # returns only the class, not yet an instance 17 | def get_transforms(config): 18 | return { 19 | "PlantDataset": tt.Compose( 20 | [ 21 | tt.ToTensor(), 22 | tt.Lambda(lambda x: (x * 2.0) - 1.0), 23 | ] 24 | ), 25 | "IperDataset": tt.Compose( 26 | [ 27 | tt.ToTensor(), 28 | tt.Lambda(lambda x: (x * 2.0) - 1.0), 29 | ] 30 | ), 31 | "Human36mDataset": tt.Compose( 32 | [ 33 | tt.ToTensor(), 34 | tt.Lambda(lambda x: (x * 2.0) - 1.0), 35 | ] 36 | ), 37 | "TaichiDataset": tt.Compose( 38 | [ 39 | tt.ToTensor(), 40 | tt.Lambda(lambda x: (x * 2.0) - 1.0), 41 | ] 42 | ), 43 | } 44 | 45 | 46 | def get_dataset(config, custom_transforms=None): 47 | dataset = __datasets__[config["dataset"]] 48 | if custom_transforms is not None: 49 | print("Returning dataset with custom transform") 50 | transforms = custom_transforms 51 | else: 52 | transforms = get_transforms(config)[config["dataset"]] 53 | return dataset, transforms 54 | 55 | 
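A short usage sketch of the registry above, mirroring how data/datamodule.py (shown below) obtains and instantiates a dataset: get_dataset returns the dataset class together with its default transforms, and the class is then called as dataset(transforms, datakeys, config, train=..., debug=...). The config dictionary and the datakeys list here are placeholders; the real values come from the YAML configs and the experiment classes.

# Usage sketch (not repository code), following the call pattern in data/datamodule.py below.
from data import get_dataset

config = {"dataset": "IperDataset"}  # placeholder: real configs carry many more keys
datakeys = ["images", "poke"]        # placeholder: the experiments define the actual datakeys

dataset_cls, transforms = get_dataset(config)
dset_train = dataset_cls(transforms, datakeys, config, train=True, debug=False)
print(type(dset_train).__name__)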
-------------------------------------------------------------------------------- /data/checksums.txt: -------------------------------------------------------------------------------- 1 | d517e6c0b1112427b2a39fcbd732281c archives/Videos_S1.tgz 2 | 02ef041813c3a37b137f86df24419e5a archives/Videos_S5.tgz 3 | a4b8690e5320c5854f99f60bf31cbabc archives/Videos_S6.tgz 4 | 79caf93c6ec31b1c14cd1d31d5f292e0 archives/Videos_S7.tgz 5 | 18818148e68fcd80fce1efa82f98126d archives/Videos_S8.tgz 6 | 3e7d923d5c573ac833334a31b5f8a797 archives/Videos_S9.tgz 7 | 25c25250be8f75d7991dcbf74bb9d339 archives/Videos_S11.tgz -------------------------------------------------------------------------------- /data/config.ini: -------------------------------------------------------------------------------- 1 | [General] 2 | 3 | # Get your PHPSESSID by logging into http://vision.imar.ro/human3.6m/ and inspecting the cookies 4 | # with your web browser. 5 | PHPSESSID=fv2r3aas6blltpibdm9v3er9r5 6 | TARGETDIR=/export/scratch/ablattma/datasets/human36m -------------------------------------------------------------------------------- /data/datamodule.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from torch.utils.data import DataLoader,WeightedRandomSampler 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | from data import get_dataset 7 | from data.samplers import FixedLengthSampler 8 | 9 | class StaticDataModule(LightningDataModule): 10 | 11 | def __init__(self,config, datakeys,debug=False): 12 | from data.flow_dataset import IperDataset 13 | super().__init__() 14 | self.config = config 15 | self.datakeys = datakeys 16 | self.batch_size = self.config["batch_size"] 17 | self.num_workers = self.config["n_workers"] 18 | self.zero_poke = "zero_poke" in self.config and self.config["zero_poke"] 19 | self.dset, self.transforms = get_dataset(self.config) 20 | self.dset_train = self.dset(self.transforms,self.datakeys,self.config, train=True, debug=debug) 21 | if isinstance(self.dset_train,IperDataset) and self.dset_train.yield_videos: 22 | self.test_datakeys = self.datakeys + ['keypoints_rel','keypoints_abs','keypoint_poke' ,'nn'] # 23 | self.val_datakeys = self.datakeys + ['keypoints_rel', 'keypoints_abs', 'keypoint_poke'] 24 | else: 25 | self.test_datakeys = self.val_datakeys = self.datakeys 26 | 27 | if self.config['filter'] != 'all': 28 | self.test_config = deepcopy(self.config) 29 | else: 30 | self.test_config = self.config 31 | self.dset_val = self.dset(self.transforms, self.val_datakeys, self.test_config, train=False,debug=debug) 32 | 33 | self.dset_test = self.dset_val if not isinstance(self,IperDataset) else self.dset(self.transforms, self.test_datakeys, self.test_config, train=False,debug=debug) 34 | self.val_obj_weighting = self.config['object_weighting'] if 'object_weighting' in self.config else self.dset_val.obj_weighting 35 | def w_init_fn(worker_id): 36 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 37 | 38 | self.winit_fn = w_init_fn 39 | 40 | 41 | def train_dataloader(self): 42 | if self.zero_poke: 43 | sampler = FixedLengthSampler(self.dset_train,self.batch_size,shuffle=True, 44 | drop_last=True,weighting=self.dset_train.obj_weighting, 45 | zero_poke=self.zero_poke,zero_poke_amount=self.config["zero_poke_amount"]) 46 | return DataLoader(self.dset_train,batch_sampler=sampler,num_workers=self.num_workers,worker_init_fn=self.winit_fn) 47 | else: 48 | if self.dset_train.obj_weighting: 49 | sampler = 
WeightedRandomSampler(weights=self.dset_train.datadict["weights"], num_samples=self.dset_train.datadict["img_path"].shape[0]) 50 | return DataLoader(self.dset_train, batch_size=self.batch_size, num_workers=self.num_workers, sampler=sampler) 51 | else: 52 | return DataLoader(self.dset_train,batch_size=self.batch_size, num_workers=self.num_workers,shuffle=True) 53 | 54 | def val_dataloader(self): 55 | if self.zero_poke: 56 | sampler = FixedLengthSampler(self.dset_val, self.batch_size, shuffle=True, 57 | drop_last=True, weighting=self.val_obj_weighting, 58 | zero_poke=self.zero_poke, zero_poke_amount=self.config["zero_poke_amount"]) 59 | return DataLoader(self.dset_val, batch_sampler=sampler, num_workers=self.num_workers, worker_init_fn=self.winit_fn) 60 | else: 61 | if self.val_obj_weighting: 62 | sampler = WeightedRandomSampler(weights=self.dset_val.datadict["weights"],num_samples=self.dset_val.datadict["img_path"].shape[0]) 63 | return DataLoader(self.dset_val,batch_size=self.batch_size,num_workers=self.num_workers,sampler=sampler) 64 | else: 65 | return DataLoader(self.dset_val,batch_size=self.batch_size, num_workers=self.num_workers,shuffle=True) 66 | 67 | def test_dataloader(self): 68 | return DataLoader(self.dset_test,batch_size=self.config['test_batch_size'], num_workers=self.num_workers,shuffle=True) 69 | -------------------------------------------------------------------------------- /data/helper_functions.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def preprocess_image(img,swap_channels=False): 5 | """ 6 | 7 | :param img: numpy array of shape (H,W,3) 8 | :param swap_channels: True, if channelorder is BGR 9 | :return: 10 | """ 11 | if swap_channels: 12 | img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 13 | 14 | # this seems to be possible as flownet2 outputs only images which can be divided by 64 15 | shape = img.shape 16 | img = img[:int(shape[0] / 64) * 64,:int(shape[1] / 64) * 64] 17 | 18 | return img -------------------------------------------------------------------------------- /data/human36m_preprocess.py: -------------------------------------------------------------------------------- 1 | #code heaviliy borrowed from https://github.com/anibali/h36m-fetch 2 | 3 | from subprocess import call 4 | from os import path, makedirs 5 | import hashlib 6 | from tqdm import tqdm 7 | import configparser 8 | import requests 9 | import tarfile 10 | from glob import glob 11 | 12 | 13 | BASE_URL = 'http://vision.imar.ro/human3.6m/filebrowser.php' 14 | 15 | subjects = [ 16 | ('S1', 1), 17 | ('S5', 6), 18 | ('S6', 7), 19 | ('S7', 2), 20 | ('S8', 3), 21 | ('S9', 4), 22 | ('S11', 5), 23 | ] 24 | 25 | 26 | def md5(filename): 27 | hash_md5 = hashlib.md5() 28 | with open(filename, 'rb') as f: 29 | for chunk in iter(lambda: f.read(4096), b''): 30 | hash_md5.update(chunk) 31 | return hash_md5.hexdigest() 32 | 33 | 34 | def download_file(url, dest_file, phpsessid): 35 | call(['axel', 36 | '-a', 37 | '-n', '24', 38 | '-H', 'COOKIE: PHPSESSID=' + phpsessid, 39 | '-o', dest_file, 40 | url]) 41 | 42 | def get_config(): 43 | dirpath = path.dirname(path.realpath(__file__)) 44 | config = configparser.ConfigParser() 45 | config.read(path.join(dirpath,'config.ini')) 46 | return config 47 | 48 | 49 | def get_phpsessid(config): 50 | 51 | try: 52 | phpsessid = config['General']['PHPSESSID'] 53 | except (KeyError, configparser.NoSectionError): 54 | print('Could not read PHPSESSID from `config.ini`.') 55 | phpsessid = input('Enter PHPSESSID: ') 56 | 
return phpsessid 57 | 58 | 59 | def verify_phpsessid(phpsessid): 60 | requests.packages.urllib3.disable_warnings() 61 | test_url = 'http://vision.imar.ro/human3.6m/filebrowser.php' 62 | resp = requests.get(test_url, verify=False, cookies=dict(PHPSESSID=phpsessid)) 63 | fail_message = 'Failed to verify your PHPSESSID. Please ensure that you ' \ 64 | 'are currently logged in at http://vision.imar.ro/human3.6m/ ' \ 65 | 'and that you have copied the PHPSESSID cookie correctly.' 66 | assert resp.url == test_url, fail_message 67 | 68 | 69 | def download_all(phpsessid, out_dir): 70 | checksums = {} 71 | dirpath = path.dirname(path.realpath(__file__)) 72 | with open(path.join(dirpath,'checksums.txt'), 'r') as f: 73 | for line in f.read().splitlines(keepends=False): 74 | v, k = line.split(' ') 75 | checksums[k] = v 76 | 77 | files = [] 78 | for subject_id, id in subjects: 79 | files += [ 80 | ('Videos_{}.tgz'.format(subject_id), 81 | 'download=1&filepath=Videos&filename=SubjectSpecific_{}.tgz'.format(id)), 82 | ] 83 | 84 | # out_dir = 'video_download' 85 | # makedirs(out_dir, exist_ok=True) 86 | 87 | for filename, query in tqdm(files, ascii=True): 88 | out_file = path.join(out_dir, filename) 89 | 90 | if path.isfile(out_file): 91 | continue 92 | 93 | if path.isfile(out_file): 94 | checksum = md5(out_file) 95 | if checksums.get(out_file, None) == checksum: 96 | continue 97 | 98 | download_file(BASE_URL + '?' + query, out_file, phpsessid) 99 | 100 | # https://stackoverflow.com/a/6718435 101 | def commonprefix(m): 102 | s1 = min(m) 103 | s2 = max(m) 104 | for i, c in enumerate(s1): 105 | if c != s2[i]: 106 | return s1[:i] 107 | return s1 108 | 109 | def extract_tgz(tgz_file, dest): 110 | # if path.exists(dest): 111 | # return 112 | with tarfile.open(tgz_file, 'r:gz') as tar: 113 | members = [m for m in tar.getmembers() if m.isreg()] 114 | member_dirs = [path.dirname(m.name).split(path.sep) for m in members] 115 | base_path = path.sep.join(commonprefix(member_dirs)) 116 | for m in members: 117 | m.name = path.relpath(m.name, base_path) 118 | tar.extractall(dest) 119 | 120 | def extract(out_dir,tgzs): 121 | out_dir = path.join(out_dir,'videos') 122 | 123 | for tgz in tqdm(tgzs,desc='Extracting tgz archives'): 124 | subject_id = tgz.split('_')[-1].split('.')[0] 125 | videodir = path.join(out_dir,subject_id) 126 | makedirs(videodir,exist_ok=True) 127 | 128 | extract_tgz(tgz,videodir) 129 | 130 | 131 | if __name__ == '__main__': 132 | config = get_config() 133 | phpsessid = get_phpsessid(config) 134 | verify_phpsessid(phpsessid) 135 | out_dir = config['General']['TARGETDIR'] 136 | download_dir = path.join(out_dir,'video_download') 137 | makedirs(download_dir,exist_ok=True) 138 | download_all(phpsessid,out_dir=download_dir) 139 | tgzs = glob(path.join(download_dir,'*.tgz')) 140 | extract(out_dir,tgzs) 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import BatchSampler,RandomSampler,SequentialSampler, WeightedRandomSampler 3 | 4 | from data.base_dataset import BaseDataset 5 | from data.flow_dataset import PlantDataset 6 | 7 | class SequenceSampler(BatchSampler): 8 | def __init__(self, dataset:BaseDataset, batch_size, shuffle, drop_last): 9 | assert isinstance(dataset, BaseDataset), "The used dataset in Sequence Sampler must inherit from BaseDataset" 10 | if shuffle: 11 | sampler = RandomSampler(dataset) 
12 | else: 13 | sampler = SequentialSampler(dataset) 14 | super().__init__(sampler, batch_size, drop_last) 15 | 16 | 17 | self.dataset = dataset 18 | #self.max_lag = self.dataset.datadict["flow_paths"].shape[1] 19 | 20 | 21 | def __iter__(self): 22 | batch = [] 23 | 24 | # sample sequence length 25 | lag = int(np.random.choice(self.dataset.valid_lags, 1)) 26 | 27 | for idx in self.sampler: 28 | batch.append((idx, lag)) 29 | if len(batch) == self.batch_size: 30 | yield batch 31 | batch = [] 32 | 33 | # sample sequence length 34 | lag = int(np.random.choice(self.dataset.valid_lags, 1)) 35 | 36 | if len(batch) > 0 and not self.drop_last: 37 | yield batch 38 | 39 | 40 | class FixedLengthSampler(BatchSampler): 41 | 42 | def __init__(self, dataset:PlantDataset,batch_size,shuffle,drop_last, weighting, zero_poke,zero_poke_amount=None): 43 | if shuffle: 44 | if weighting: 45 | sampler = WeightedRandomSampler(weights=dataset.datadict["weights"],num_samples=len(dataset)) 46 | else: 47 | sampler = RandomSampler(dataset) 48 | else: 49 | sampler = SequentialSampler(dataset) 50 | super().__init__(sampler, batch_size, drop_last) 51 | self.shuffle = shuffle 52 | self.dataset = dataset 53 | self.zero_poke = zero_poke 54 | self.zero_poke_amount = zero_poke_amount 55 | if self.zero_poke: 56 | assert self.zero_poke_amount is not None 57 | 58 | 59 | def __iter__(self): 60 | batch = [] 61 | if self.zero_poke: 62 | # sample a certain proportion to be zero pokes 63 | zero_poke_ids = np.random.choice(np.arange(self.dataset.__len__()),size=int(self.dataset.__len__()/ self.zero_poke_amount),replace=False).tolist() 64 | self.dataset.logger.info(f"Sampling {len(zero_poke_ids)} zeropokes for next epoch") 65 | else: 66 | zero_poke_ids = [] 67 | 68 | for idx in self.sampler: 69 | if idx in zero_poke_ids: 70 | batch.append(-1) 71 | else: 72 | batch.append(idx) 73 | if len(batch) == self.batch_size: 74 | yield batch 75 | batch = [] 76 | 77 | 78 | if len(batch) > 0 and not self.drop_last: 79 | yield batch 80 | 81 | 82 | 83 | class SequenceLengthSampler(BatchSampler): 84 | def __init__(self, dataset:BaseDataset, batch_size, shuffle, drop_last, n_frames=None, zero_poke = False,): 85 | assert isinstance(dataset, BaseDataset), "The used dataset in Sequence Sampler must inherit from BaseDataset" 86 | assert dataset.var_sequence_length and dataset.yield_videos, "The dataset has to be run in sequence mode and has to output variable sequence lengths" 87 | sampler = SequentialSampler(dataset) 88 | super().__init__(sampler, batch_size, drop_last) 89 | self.dataset = dataset 90 | self.shuffle = shuffle 91 | if n_frames is not None: 92 | assert n_frames >= self.dataset.min_frames and n_frames <=(self.dataset.min_frames + self.dataset.max_frames) 93 | self.n_frames = (n_frames-self.dataset.min_frames) 94 | else: 95 | self.n_frames = n_frames 96 | self.start_n_frames = -1 if zero_poke else 0 97 | if zero_poke: 98 | if self.dataset.train: 99 | self.len_p = np.asarray([self.dataset.zeropoke_weight] + [1.] * self.dataset.max_frames) 100 | else: 101 | self.len_p = np.asarray([1.] * (self.dataset.max_frames + 1)) 102 | else: 103 | self.len_p = np.asarray([1.] 
* self.dataset.max_frames) 104 | 105 | if self.dataset.longest_seq_weight != None and self.dataset.train: 106 | self.len_p[-1] = self.dataset.longest_seq_weight 107 | if zero_poke: 108 | # to keep sufficient outside pokes for the model to learn foreground and background 109 | self.len_p[0] = self.dataset.longest_seq_weight / 2 110 | self.len_p = self.len_p /self.len_p.sum() 111 | 112 | def __iter__(self): 113 | batch = [] 114 | 115 | # sample sequence length 116 | if self.shuffle: 117 | # -1 corresponds to 118 | n_frames = int(np.random.choice(np.arange(self.start_n_frames,self.dataset.max_frames), 1, p=self.len_p)) 119 | 120 | else: 121 | last_n = self.start_n_frames 122 | n_frames = last_n 123 | 124 | if n_frames == -1: 125 | n_frames_actual = int(np.random.choice(np.arange(self.dataset.max_frames), 1)) 126 | appended = (n_frames, n_frames_actual) 127 | else: 128 | appended = (n_frames, None) 129 | for idx in self.sampler: 130 | appended = (appended[0] if self.n_frames is None else self.n_frames,appended[1]) 131 | batch.append(appended) 132 | if len(batch) == self.batch_size: 133 | yield batch 134 | batch = [] 135 | 136 | # sample sequence length 137 | if self.shuffle: 138 | n_frames = int(np.random.choice(np.arange(self.start_n_frames,self.dataset.max_frames), 1,p=self.len_p)) 139 | else: 140 | n_frames = last_n+1 if last_n 0 and not self.drop_last: 150 | yield batch -------------------------------------------------------------------------------- /data_proc.yml: -------------------------------------------------------------------------------- 1 | name: data_proc 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=10.0.130 8 | - imagesize=1.2.0 9 | - numpy 10 | - pillow=7.1.2 11 | - pip=20.0.2 12 | - python=3.7.7 13 | - pytorch=1.2.0 14 | - torchvision=0.4.0 15 | - tqdm=4.48.0 16 | - pip: 17 | - coloredlogs==14.0 18 | - matplotlib==3.2.1 19 | - opencv-python==4.2.0.34 20 | - pyyaml==5.3.1 21 | - scipy 22 | - imageio-ffmpeg 23 | - yacs 24 | - Cython 25 | - json_tricks 26 | - typing_extensions 27 | 28 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from experiments.first_stage_image import FirstStageImageModel 2 | from experiments.poke_encoder import PokeEncoderModel 3 | from experiments.first_stage_video import FirstStageSequenceModel 4 | from experiments.second_stage_video import SecondStageVideoModel 5 | 6 | 7 | 8 | __experiments__ = { 9 | "img_encoder": FirstStageImageModel, 10 | "poke_encoder": PokeEncoderModel, 11 | "first_stage": FirstStageSequenceModel, 12 | "second_stage": SecondStageVideoModel, 13 | } 14 | 15 | 16 | def select_experiment(config,dirs, devices): 17 | experiment = config["general"]["experiment"] 18 | model_name = config["general"]["model_name"] 19 | if experiment not in __experiments__: 20 | raise NotImplementedError(f"No such experiment! {experiment}") 21 | if config["general"]["restart"]: 22 | print(f"Restarting run \"{model_name}\" of type \"{experiment}\". Device: {devices}") 23 | else: 24 | print(f"New run \"{model_name}\" of type \"{experiment}\". 
Device: {devices}") 25 | return __experiments__[experiment](config, dirs, devices) 26 | -------------------------------------------------------------------------------- /experiments/experiment.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torch 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.loggers import WandbLogger 5 | from pytorch_lightning.profiler import PassThroughProfiler,AdvancedProfiler, SimpleProfiler 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from functools import partial 8 | import yaml 9 | import math 10 | from glob import glob 11 | from os import path 12 | import numpy as np 13 | 14 | from utils.general import get_logger 15 | from utils.callbacks import BestCkptsToYaml 16 | 17 | WANDB_DISABLE_CODE = True 18 | 19 | class Experiment: 20 | def __init__(self, config:dict, dirs: dict, devices): 21 | super().__init__() 22 | self.config = config 23 | self.dirs = dirs 24 | self.devices = devices 25 | #self.logger = get_logger() 26 | 27 | ########## seed setting ########## 28 | 29 | seed = self.config['testing']['seed'] if 'seed' in self.config['testing'] and self.config['general']['test'] != 'none' else self.config['general']['seed'] 30 | torch.manual_seed(seed) 31 | torch.cuda.manual_seed(seed) 32 | np.random.seed(seed) 33 | # random.seed(opt.seed) 34 | torch.backends.cudnn.deterministic = True 35 | torch.manual_seed(seed) 36 | rng = np.random.RandomState(seed) 37 | 38 | 39 | 40 | self.is_debug = self.config["general"]["debug"] 41 | if self.is_debug: 42 | self.config["data"]["n_workers"] = 0 43 | self.config["logging"]["n_samples_umap"] = 20 44 | self.config["logging"]["log_train_prog_at"] = 10 45 | self.config["data"]["batch_size"] = 2 46 | 47 | is_target_version = "target_version" in self.config["general"] and path.isdir(path.join(self.dirs["ckpt"],str(self.config["general"]["target_version"]))) 48 | 49 | # wandb logging 50 | self.current_version = 0 51 | if path.isdir(path.join(self.dirs["ckpt"])): 52 | runs = glob(path.join(self.dirs["ckpt"],"*")) 53 | if len(runs) > 0: 54 | self.current_version = max([int(r.split("/")[-1]) for r in runs]) 55 | if self.config["general"]["test"] == 'none': 56 | self.current_version+=1 57 | 58 | 59 | if self.config["general"]["test"] != 'none' and is_target_version: 60 | self.current_version = self.config["general"]["target_version"] 61 | 62 | if self.config['general']['test'] == 'none': 63 | logger = WandbLogger(name=self.config["general"]["model_name"], save_dir=self.dirs["log"], 64 | project="poking_inn",group=self.config["general"]["experiment"],tags=[self.config["data"]["dataset"]], 65 | version=self.config["general"]["experiment"]+ "-" +self.config["general"]["model_name"]+ "-" + str(self.current_version), 66 | save_code=True,entity='inn_poking') 67 | else: 68 | logger = False 69 | 70 | self.config["general"].update({"version":self.current_version}) 71 | if self.config["general"]["restart"] or self.config["general"]["test"] != 'none': 72 | if is_target_version: 73 | self.ckpt_load_dir = path.join(self.dirs["ckpt"],str(self.config["general"]["target_version"])) 74 | else: 75 | if self.config["general"]["test"] != 'none': 76 | self.ckpt_load_dir = path.join(self.dirs["ckpt"],str(self.current_version)) 77 | else: 78 | self.ckpt_load_dir = self.__get_latest_ckpt_dir() 79 | 80 | 81 | acc_batches = int(math.ceil(self.config["training"]["min_acc_batch_size"] / self.config["data"]["batch_size"])) \ 82 | if 
self.config["training"]["min_acc_batch_size"] > self.config["data"]["batch_size"] else 1 83 | 84 | prof_file = path.join(self.dirs['log'],'profile.log') 85 | profiler =AdvancedProfiler(output_filename=prof_file) if self.config["general"]["profiler"] else None 86 | self.basic_trainer = partial(pl.Trainer, deterministic=True, gpus=devices, logger=logger, 87 | progress_bar_refresh_rate=1,profiler=profiler, 88 | accumulate_grad_batches=acc_batches, 89 | max_epochs=self.config["training"]["n_epochs"]) 90 | if self.is_debug: 91 | self.basic_trainer = partial(self.basic_trainer,limit_train_batches=10,limit_val_batches=2, 92 | limit_test_batches=5,weights_summary="top",log_every_n_steps=2,num_sanity_val_steps=2) 93 | 94 | else: 95 | self.basic_trainer = partial(self.basic_trainer,val_check_interval=self.config["training"]["val_every"],num_sanity_val_steps=0) 96 | 97 | self.ckpt_callback = partial(ModelCheckpoint,dirpath=path.join(self.dirs["ckpt"],str(self.current_version)), period= 1,save_last=True) 98 | 99 | if 'test_batch_size' not in self.config['data']: 100 | self.config['data']['test_batch_size'] = 16 101 | 102 | if self.config['general']['test'] == 'none': 103 | logger.log_hyperparams(self.config) 104 | 105 | # signal.signal(signal.SIGINT,self.ckpt_to_yaml) 106 | 107 | def _get_checkpoint(self): 108 | 109 | 110 | ckpt_name = glob(path.join(self.ckpt_load_dir, "*.yaml")) 111 | last_ckpt = path.join(self.ckpt_load_dir, "last.ckpt") 112 | if self.config["general"]["last_ckpt"] and path.isfile(last_ckpt): 113 | print('Using last ckpt...') 114 | ckpt_name = [last_ckpt] 115 | elif self.config["general"]["last_ckpt"] and not path.isfile(last_ckpt): 116 | raise ValueError("intending to load last ckpt, but no last ckpt found. Aborting....") 117 | 118 | if len(ckpt_name) == 1: 119 | ckpt_name = ckpt_name[0] 120 | else: 121 | msg = "Not enough" if len(ckpt_name) < 1 else "Too many" 122 | raise ValueError(msg + f" checkpoint files found! 
Aborting...") 123 | 124 | if ckpt_name.endswith(".yaml"): 125 | with open(ckpt_name, "r") as f: 126 | ckpts = yaml.load(f, Loader=yaml.FullLoader) 127 | 128 | has_files = len(ckpts) > 0 129 | while has_files: 130 | best_val = min([ckpts[key] for key in ckpts]) 131 | ckpt_name = {ckpts[key]:key for key in ckpts}[best_val] 132 | if path.isfile(ckpt_name): 133 | break 134 | else: 135 | del ckpts[ckpt_name] 136 | has_files = len(ckpts) > 0 137 | 138 | if not has_files: 139 | raise ValueError(f'No valid files contained in ckpt-name-holding file "{ckpt_name}"') 140 | 141 | ckpt_file_name = ckpt_name.split('/')[-1] 142 | print(f'********************* Loading checkpoint for run_version #{self.current_version} with name "{ckpt_file_name}" *******************************') 143 | return ckpt_name 144 | 145 | def add_ckpt_file(self): 146 | assert isinstance(self.ckpt_callback,ModelCheckpoint) 147 | return BestCkptsToYaml(self.ckpt_callback) 148 | 149 | def __get_latest_ckpt_dir(self): 150 | start_version = self.current_version -1 151 | ckpt_dir = None 152 | for v in range(start_version,-1,-1): 153 | act_dir = path.join(self.dirs["ckpt"],str(v)) 154 | if not path.isdir(act_dir): 155 | continue 156 | if self.config["general"]["last_ckpt"]: 157 | ckpt_dir = act_dir if path.isfile(path.join(act_dir,"last.ckpt")) else None 158 | print('Using last ckpt...') 159 | else: 160 | ckpt_dir = act_dir if (path.isfile(path.join(act_dir,"best_k_models.yaml")) and len(glob(path.join(act_dir,"*.ckpt"))) > 0) else None 161 | 162 | if ckpt_dir is not None: 163 | break 164 | 165 | if ckpt_dir is None: 166 | raise NotADirectoryError("NO valid checkpoint dir found but model shall be restarted! Aborting....") 167 | 168 | # self.logger.info(f'load checkpoint from file: "{ckpt_dir}"') 169 | 170 | return ckpt_dir 171 | 172 | @abstractmethod 173 | def train(self): 174 | """ 175 | Here, the experiment shall be run 176 | :return: 177 | """ 178 | pass 179 | 180 | @abstractmethod 181 | def test(self): 182 | """ 183 | Here the prediction shall be run 184 | :param ckpt_path: The path where the checkpoint file to load can be found 185 | :return: 186 | """ 187 | pass 188 | -------------------------------------------------------------------------------- /experiments/first_stage_image.py: -------------------------------------------------------------------------------- 1 | from experiments.experiment import Experiment# 2 | from functools import partial 3 | 4 | # from models.first_stage_image_fc import AEModel 5 | from models.first_stage_image_conv import ConvAEModel 6 | from data.datamodule import StaticDataModule 7 | 8 | 9 | class FirstStageImageModel(Experiment): 10 | 11 | def __init__(self,config,dirs,devices): 12 | super().__init__(config,dirs,devices) 13 | 14 | # intiliaze models 15 | self.datakeys = ["images"] 16 | 17 | 18 | 19 | self.config["architecture"].update({"in_size": self.config["data"]["spatial_size"][0]}) 20 | 21 | model = ConvAEModel 22 | 23 | if self.config["general"]["restart"]: 24 | ckpt_path = self._get_checkpoint() 25 | self.ae = model.load_from_checkpoint(ckpt_path,map_location="cpu",config=self.config) 26 | else: 27 | self.ae = model(self.config) 28 | # basic trainer is initialized in parent class 29 | # self.logger.info( 30 | # f"Number of trainable parameters in model is {sum(p.numel() for p in self.ae.parameters())}" 31 | # ) 32 | 33 | self.ckpt_callback = self.ckpt_callback(filename='{epoch}-{lpips-val:.3f}',monitor='lpips-val', 34 | save_top_k=self.config["logging"]["n_saved_ckpt"], mode='min') 35 | 36 | 
to_yaml_cb = self.add_ckpt_file() 37 | callbacks = [self.ckpt_callback,to_yaml_cb] 38 | if self.config["general"]["restart"] and ckpt_path is not None: 39 | self.basic_trainer = partial(self.basic_trainer,resume_from_checkpoint=ckpt_path,callbacks=callbacks) 40 | else: 41 | self.basic_trainer = partial(self.basic_trainer,callbacks=callbacks) 42 | 43 | self.basic_trainer = partial(self.basic_trainer,automatic_optimization=False) 44 | 45 | 46 | 47 | 48 | def train(self): 49 | # prepare data 50 | datamod = StaticDataModule(self.config["data"],datakeys=self.datakeys) 51 | datamod.setup() 52 | n_batches_complete_train = len(datamod.train_dataloader()) 53 | n_batches_complete_val = len(datamod.val_dataloader()) 54 | #n_train_batches = self.config["training"]["max_batches_per_epoch"] if n_batches_complete_train > self.config["training"]["max_batches_per_epoch"] else n_batches_complete_train 55 | n_val_batches = self.config["training"]["max_val_batches"] if n_batches_complete_val > self.config["training"]["max_val_batches"] else n_batches_complete_val 56 | 57 | if not self.is_debug: 58 | trainer = self.basic_trainer(limit_val_batches=n_val_batches, 59 | limit_test_batches=n_val_batches,replace_sampler_ddp=datamod.dset_train.obj_weighting) 60 | else: 61 | trainer = self.basic_trainer() 62 | 63 | trainer.fit(self.ae,datamodule=datamod) 64 | 65 | 66 | 67 | 68 | 69 | def test(self): 70 | pass 71 | -------------------------------------------------------------------------------- /experiments/first_stage_video.py: -------------------------------------------------------------------------------- 1 | from experiments.experiment import Experiment 2 | from functools import partial 3 | 4 | from models.first_stage_motion_model import SpadeCondMotionModel, FCBaseline 5 | from models.poke_vae import PokeVAE 6 | from data.datamodule import StaticDataModule 7 | 8 | 9 | class FirstStageSequenceModel(Experiment): 10 | 11 | def __init__(self,config,dirs,devices): 12 | super().__init__(config,dirs,devices) 13 | 14 | # intiliaze models 15 | self.datakeys = ["images"] 16 | self.is_baseline = 'baseline' in self.config['architecture'] and self.config['architecture']['baseline'] 17 | self.fc_baseline = 'fc_baseline' in self.config['architecture'] and self.config['architecture']['fc_baseline'] 18 | if self.is_baseline: 19 | self.datakeys.extend(['poke','flow','sample_ids']) 20 | 21 | model = PokeVAE if self.is_baseline else SpadeCondMotionModel 22 | if self.fc_baseline: 23 | model = FCBaseline 24 | 25 | if self.config["general"]["restart"] or self.config['general']['test'] != 'none': 26 | ckpt_path = self._get_checkpoint() 27 | self.ae = model.load_from_checkpoint(ckpt_path,map_location="cpu", strict=False,config=self.config,dirs= self.dirs) 28 | else: 29 | self.ae = model(self.config,dirs=self.dirs) 30 | # basic trainer is initialized in parent class 31 | # self.logger.info( 32 | # f"Number of trainable parameters in model is {sum(p.numel() for p in self.ae.parameters())}" 33 | # ) 34 | 35 | self.ckpt_callback = self.ckpt_callback(filename='{epoch}-{FVD-val:.3f}',monitor='FVD-val', 36 | save_top_k=self.config["logging"]["n_saved_ckpt"], mode='min') 37 | to_yaml_cb = self.add_ckpt_file() 38 | 39 | callbacks = [self.ckpt_callback,to_yaml_cb] 40 | if self.config["general"]["restart"] and ckpt_path is not None: 41 | self.basic_trainer = partial(self.basic_trainer,resume_from_checkpoint=ckpt_path,callbacks=callbacks) 42 | else: 43 | self.basic_trainer = partial(self.basic_trainer,callbacks=callbacks) 44 | 45 | self.basic_trainer 
= partial(self.basic_trainer,automatic_optimization=False,terminate_on_nan=True) 46 | 47 | def train(self): 48 | # prepare data 49 | datamod = StaticDataModule(self.config["data"], datakeys=self.datakeys) 50 | datamod.setup() 51 | n_batches_complete_train = len(datamod.train_dataloader()) 52 | n_batches_complete_val = len(datamod.val_dataloader()) 53 | n_train_batches = self.config["training"]["max_batches_per_epoch"] if n_batches_complete_train > self.config["training"]["max_batches_per_epoch"] else n_batches_complete_train 54 | n_val_batches = self.config["training"]["max_val_batches"] if n_batches_complete_val > self.config["training"]["max_val_batches"] else n_batches_complete_val 55 | 56 | if not self.is_debug: 57 | trainer = self.basic_trainer(limit_val_batches=n_val_batches,limit_train_batches=n_train_batches, 58 | limit_test_batches=n_val_batches, replace_sampler_ddp=datamod.dset_train.obj_weighting or datamod.zero_poke) 59 | else: 60 | trainer = self.basic_trainer() 61 | 62 | trainer.fit(self.ae, datamodule=datamod) 63 | 64 | 65 | def test(self): 66 | import math 67 | import torch 68 | from tqdm import tqdm 69 | from copy import deepcopy 70 | from utils.logging import make_errorbar_plot 71 | import pandas as pd 72 | import os 73 | 74 | assert self.is_baseline or self.fc_baseline 75 | 76 | # test without zero poke 77 | self.config['data']['zero_poke'] = False 78 | self.config['data']['test_batch_size'] = self.config['testing']['test_batch_size'] 79 | if self.config['general']['test']=='motion_transfer': 80 | assert self.config['data']['dataset'] == 'IperDataset' 81 | self.config['data']['get_kp_nn'] = True 82 | if self.config['general']['test'] == 'diversity': 83 | n_test_batches = int( 84 | math.ceil(self.config['testing']['n_samples_metrics'] / self.config['data']['test_batch_size'])) 85 | # max_n_pokes = deepcopy(self.config['data']['n_pokes']) 86 | 87 | # if self.config['testing']['summarize_n_pokes']: 88 | self.ae.console_logger.info('***************************COMPUTING METRICS OVER SUMMARIZED POKES*******************************************************') 89 | datamod = StaticDataModule(self.config["data"], datakeys=self.datakeys, debug=self.config['testing']['debug']) 90 | datamod.setup() 91 | trainer = self.basic_trainer(limit_test_batches=n_test_batches) 92 | trainer.test(self.ae, datamodule=datamod) 93 | # else: 94 | # self.ae.console_logger.info('***************************COMPUTING METRICS FOR EACH INDIVIDUAL NUMBER OF POKES******************************************') 95 | # self.config['data']['fix_n_pokes'] = True 96 | # for count,n_pokes in enumerate(tqdm(reversed(range(max_n_pokes)),desc='Conducting metrics experiment....')): 97 | # self.config['data']['n_pokes'] = n_pokes +1 98 | # self.ae.console_logger.info(f'Instantiating {n_pokes + 1} pokes...') 99 | # 100 | # datamod = StaticDataModule(self.config["data"], datakeys=self.datakeys, debug=self.config['testing']['debug']) 101 | # datamod.setup() 102 | # trainer = self.basic_trainer(limit_test_batches=n_test_batches) 103 | # trainer.test(self.ae,datamodule=datamod) 104 | # 105 | # if self.config['general']['test'] == 'metrics': 106 | # 107 | # kps_dict =self.ae.metrics_dict['KPS'] 108 | # ssim_dict = self.ae.metrics_dict['SSIM'] 109 | # lpips_dict = self.ae.metrics_dict['LPIPS'] 110 | # # construcr dataframes 111 | # df_ssim = pd.DataFrame.from_dict(ssim_dict) 112 | # # df_psnr = pd.DataFrame.from_dict(psnr_dict) 113 | # df_lpips = pd.DataFrame.from_dict(lpips_dict) 114 | # n_samples_per_poke = 
self.config['testing']['n_samples_per_data_point'] 115 | # 116 | # # save data and plot stats 117 | # postfix = 'aggregated' if self.config['testing']['summarize_n_pokes'] else 'unique_pokes' 118 | # # metrics only reported when gt keypoints are available 119 | # if len(kps_dict) > 0: 120 | # df_kps = pd.DataFrame.from_dict(kps_dict) 121 | # fig_savename = os.path.join(self.ae.metrics_dir, f'keypoint_err_plot_{n_samples_per_poke}samples-{postfix}.pdf') 122 | # df_kps.to_csv(os.path.join(self.ae.metrics_dir,f'plot_data_{n_samples_per_poke}pokes_kps-{postfix}.csv')) 123 | # make_errorbar_plot(fig_savename,df_kps,xid='Time',yid='Mean MSE per Frame', 124 | # hueid='Number of Pokes',varid='Std per Frame') 125 | # df_kps_group = df_kps.groupby('Time', as_index=False).mean() 126 | # df_kps_group.to_csv(os.path.join(self.ae.metrics_dir, f'plot_data_kps_group.csv')) 127 | # 128 | # # image based metrics which can be reported for all datasets 129 | # fig_savename = os.path.join(self.ae.metrics_dir, f'ssim_plot_{n_samples_per_poke}samples-{postfix}.pdf') 130 | # df_ssim.to_csv(os.path.join(self.ae.metrics_dir, f'plot_data_{n_samples_per_poke}pokes_ssim-{postfix}.csv')) 131 | # make_errorbar_plot(fig_savename, df_ssim, xid='Time', yid='Mean SSIM per Frame', 132 | # hueid='Number of Pokes',varid='Std per Frame') 133 | # 134 | # fig_savename = os.path.join(self.ae.metrics_dir, f'lpips_plot_{n_samples_per_poke}samples-{postfix}.pdf') 135 | # df_lpips.to_csv(os.path.join(self.ae.metrics_dir, f'plot_data_{n_samples_per_poke}pokes_lpips-{postfix}.csv')) 136 | # make_errorbar_plot(fig_savename, df_lpips, xid='Time', yid='Mean LPIPS per Frame', 137 | # hueid='Number of Pokes', varid='Std per Frame') 138 | # 139 | # # fig_savename = os.path.join(self.ae.metrics_dir, f'psnr_plot_{n_samples_per_poke}samples.pdf') 140 | # # df_psnr.to_csv(os.path.join(self.ae.metrics_dir, f'plot_data_{n_samples_per_poke}pokes_psnr.csv')) 141 | # # make_errorbar_plot(fig_savename, df_psnr, xid='Time', yid='Mean PSNR per Frame', 142 | # # hueid='Number of Pokes', varid='Std per Frame') 143 | # #aggregate for all pokes 144 | # 145 | # df_ssim_group = df_ssim.groupby('Time', as_index=False).mean() 146 | # df_ssim_group.to_csv(os.path.join(self.ae.metrics_dir, f'plot_data_ssim_group.csv')) 147 | # self.ae.console_logger.info(f'Mean ssim value from sample metric is {df_ssim_group["Mean SSIM per Frame"]}') 148 | # # df_psnr_group=df_psnr.groupby('Number of Pokes', as_index=False).mean() 149 | # # df_psnr_group.to_csv(os.path.join(self.ae.metrics_dir, f'plot_data_psnr_group.csv')) 150 | # df_lpips_group = df_lpips.groupby('Time', as_index=False).mean() 151 | # df_lpips_group.to_csv(os.path.join(self.ae.metrics_dir, f'plot_data_lpips_group.csv')) 152 | # self.ae.console_logger.info(f'Mean lpips value from sample metric is {df_lpips_group["Mean LPIPS per Frame"]}') 153 | # else: 154 | mean_divscore = torch.mean(torch.tensor(self.ae.div_scores)) 155 | self.ae.console_logger.info(f'Diversity score averaged over all pokes is {mean_divscore}') 156 | 157 | else: 158 | 159 | datamod = StaticDataModule(self.config["data"], datakeys=self.datakeys) 160 | datamod.setup() 161 | if self.config['general']['test'] == 'fvd': 162 | self.config['data']['test_batch_size'] = 16 163 | n_test_batches = int(math.ceil(self.config['testing']['n_samples_fvd'] / self.config['data']['test_batch_size'])) 164 | elif self.config['general']['test'] == 'samples': 165 | n_test_batches = self.config['testing']['n_samples_vis'] // self.config['data']['test_batch_size'] 166 
| else: 167 | n_test_batches = int(math.ceil(self.config['testing']['n_samples_metrics'] / self.config['data']['test_batch_size'])) 168 | 169 | 170 | 171 | trainer = self.basic_trainer(limit_test_batches=n_test_batches) 172 | 173 | trainer.test(self.ae, datamodule=datamod) -------------------------------------------------------------------------------- /experiments/poke_encoder.py: -------------------------------------------------------------------------------- 1 | from experiments.experiment import Experiment# 2 | from functools import partial 3 | 4 | # from models.baselines.poke_encoder_fc import PokeAE 5 | from models.conv_poke_encoder import ConvPokeAE 6 | from data.datamodule import StaticDataModule 7 | 8 | 9 | class PokeEncoderModel(Experiment): 10 | 11 | 12 | def __init__(self,config,dirs,devices): 13 | super().__init__(config,dirs,devices) 14 | 15 | # intiliaze models 16 | self.datakeys = ["poke","flow","images","original_flow"] 17 | 18 | 19 | 20 | self.config["architecture"].update({"in_size": self.config["data"]["spatial_size"][0]}) 21 | 22 | model = ConvPokeAE 23 | 24 | if self.config["general"]["restart"]: 25 | ckpt_path = self._get_checkpoint() 26 | self.ae = model.load_from_checkpoint(ckpt_path,map_location="cpu",config=self.config) 27 | else: 28 | self.ae = model(self.config) 29 | # basic trainer is initialized in parent class 30 | # self.logger.info( 31 | # f"Number of trainable parameters in model is {sum(p.numel() for p in self.ae.parameters())}" 32 | # ) 33 | 34 | self.ckpt_callback = self.ckpt_callback(filename='{epoch}-{lpips-val:.3f}',monitor='lpips-val', 35 | save_top_k=self.config["logging"]["n_saved_ckpt"], mode='min') 36 | to_yaml_cb = self.add_ckpt_file() 37 | 38 | callbacks = [self.ckpt_callback,to_yaml_cb] 39 | if self.config["general"]["restart"] and ckpt_path is not None: 40 | self.basic_trainer = partial(self.basic_trainer,resume_from_checkpoint=ckpt_path,callbacks=callbacks) 41 | else: 42 | self.basic_trainer = partial(self.basic_trainer,callbacks=callbacks) 43 | 44 | 45 | 46 | 47 | def train(self): 48 | # prepare data 49 | datamod = StaticDataModule(self.config["data"],datakeys=self.datakeys) 50 | datamod.setup() 51 | n_batches_complete_train = len(datamod.train_dataloader()) 52 | n_batches_complete_val = len(datamod.val_dataloader()) 53 | n_train_batches = self.config["training"]["max_batches_per_epoch"] if n_batches_complete_train > self.config["training"]["max_batches_per_epoch"] else n_batches_complete_train 54 | n_val_batches = self.config["training"]["max_val_batches"] if n_batches_complete_val > self.config["training"]["max_val_batches"] else n_batches_complete_val 55 | 56 | if not self.is_debug: 57 | trainer = self.basic_trainer(limit_train_batches=n_train_batches, limit_val_batches=n_val_batches, limit_test_batches=n_val_batches) 58 | else: 59 | trainer = self.basic_trainer() 60 | 61 | trainer.fit(self.ae,datamodule=datamod) 62 | 63 | 64 | 65 | 66 | 67 | def test(self): 68 | pass 69 | -------------------------------------------------------------------------------- /images/control_sensitivity.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/control_sensitivity.gif -------------------------------------------------------------------------------- /images/fpp_final.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/fpp_final.png -------------------------------------------------------------------------------- /images/gui_demo1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/gui_demo1.gif -------------------------------------------------------------------------------- /images/gui_demo2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/gui_demo2.gif -------------------------------------------------------------------------------- /images/gui_examples/iper_exmpl_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/gui_examples/iper_exmpl_1.png -------------------------------------------------------------------------------- /images/gui_examples/iper_exmpl_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/gui_examples/iper_exmpl_2.png -------------------------------------------------------------------------------- /images/gui_examples/iper_exmpl_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/gui_examples/iper_exmpl_3.png -------------------------------------------------------------------------------- /images/iper_exmpl_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/iper_exmpl_1.png -------------------------------------------------------------------------------- /images/kinematics_transfer.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/kinematics_transfer.gif -------------------------------------------------------------------------------- /images/overview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/ipoke/84ceecfefd7bf0769cd35334f2eb824c44c4d56b/images/overview.gif -------------------------------------------------------------------------------- /images/paper.bib: -------------------------------------------------------------------------------- 1 | @misc{blattmann2021ipoke, 2 | title={iPOKE: Poking a Still Image for Controlled Stochastic Video Synthesis}, 3 | author={Andreas Blattmann and Timo Milbich and Michael Dorkenwald and Björn Ommer}, 4 | year={2021}, 5 | eprint={2107.02790}, 6 | archivePrefix={arXiv}, 7 | primaryClass={cs.CV} 8 | } -------------------------------------------------------------------------------- /ipoke.yml: -------------------------------------------------------------------------------- 1 | name: ipoke 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - cudatoolkit=11.0 8 | - gpustat 9 | - imageio 10 | - imagesize 11 | - matplotlib 12 | - matplotlib-base 13 | - numpy 14 | - numpy-base 15 | - pandas 16 | - pathlib 17 | - pillow 18 | - pip=20.0.2 19 | - pyqt 
20 | - python=3.7.9 21 | - pytorch=1.7.1 22 | - qt 23 | - torchvision 24 | - tqdm 25 | - yaml 26 | - pip: 27 | - coloredlogs 28 | - imageio-ffmpeg 29 | - ipykernel 30 | - ipython 31 | - ipython-genutils 32 | - jupyter-client 33 | - jupyter-core 34 | - kornia 35 | - lpips 36 | - moviepy==1.0.0 37 | - natsort 38 | - opencv-python==4.2.0.34 39 | - opt-einsum 40 | - pytorch-lightning==1.1.7 41 | - pyyaml 42 | - tensorflow-gan==2.0.0 43 | - tensorflow==2.5.0 44 | - tensorflow-hub==0.9.0 45 | - tensorflow-probability==0.11.1 46 | - silence-tensorflow==1.1.1 47 | - scikit-image 48 | - scikit-learn 49 | - scipy 50 | - tensorboard 51 | - umap-learn 52 | - wandb 53 | - seaborn 54 | - dotmap 55 | - yacs 56 | - Cython 57 | - json_tricks -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path, makedirs 3 | from experiments import select_experiment 4 | import torch 5 | import yaml 6 | import os 7 | import silence_tensorflow.auto 8 | 9 | 10 | def create_dir_structure(config, model_name): 11 | subdirs = ["ckpt", "config", "generated", "log"] 12 | 13 | 14 | # model_name = config['model_name'] if model_name is None else model_name 15 | structure = {subdir: path.join(config["base_dir"],config["experiment"],subdir,model_name) for subdir in subdirs} 16 | return structure 17 | 18 | def load_parameters(config_name, restart, model_name): 19 | with open(config_name,"r") as f: 20 | cdict_old = yaml.load(f,Loader=yaml.FullLoader) 21 | cdict_old['general']['model_name'] = model_name 22 | # if we just want to test if it runs 23 | dir_structure = create_dir_structure(cdict_old["general"], model_name) 24 | saved_config = path.join(dir_structure["config"], "config.yaml") 25 | if restart: 26 | if path.isfile(saved_config): 27 | with open(saved_config,"r") as f: 28 | cdict = yaml.load(f, Loader=yaml.FullLoader) 29 | else: 30 | raise FileNotFoundError("No saved config file found but model is intended to be restarted. Aborting....") 31 | 32 | [makedirs(dir_structure[d]) for d in dir_structure if not path.isdir(dir_structure[d])] 33 | 34 | cdict['testing'] = cdict_old['testing'] 35 | cdict['general']['model_name'] = model_name 36 | 37 | else: 38 | [makedirs(dir_structure[d],exist_ok=True) for d in dir_structure] 39 | if path.isfile(saved_config) and not cdict_old["general"]["debug"]: 40 | print(f"\033[93m" + "WARNING: Model has been started somewhen earlier: Resume training (y/n)?" + "\033[0m") 41 | while True: 42 | answer = input() 43 | if answer == "y" or answer == "yes": 44 | with open(saved_config,"r") as f: 45 | cdict = yaml.load(f, Loader=yaml.FullLoader) 46 | cdict['testing'] = cdict_old['testing'] 47 | restart = True 48 | break 49 | elif answer == "n" or answer == "no": 50 | with open(saved_config, "w") as f: 51 | yaml.dump(cdict_old, f, default_flow_style=False) 52 | cdict = cdict_old 53 | break 54 | else: 55 | print(f"\033[93m" + "Invalid answer! 
Try again!(y/n)" + "\033[0m") 56 | else: 57 | with open(saved_config, "w") as f: 58 | yaml.dump(cdict_old,f,default_flow_style=False) 59 | 60 | cdict = cdict_old 61 | 62 | 63 | return cdict, dir_structure, restart 64 | 65 | def check_ckpt_paths(config): 66 | if "DATAPATH" not in os.environ: 67 | return config 68 | 69 | for key in config: 70 | for k in config[key]: 71 | if k == "ckpt": 72 | config[key][k] = path.join(os.environ["DATAPATH"],config[key][k][1:]) 73 | 74 | 75 | return config 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("-c", "--config", type=str, 81 | default="config/latent_flow_net.yaml", 82 | help="Define config file") 83 | parser.add_argument("-m","--model_name", type=str, required=True,help="Run name for the project that shall be resumed for training or testing.") 84 | parser.add_argument("-r","--resume", default=False ,action="store_true", help='Whether or not to resume the training.') 85 | parser.add_argument("-g","--gpus",default=[0], type=int, 86 | nargs="+",help="GPU to use.") 87 | parser.add_argument("--test",default='none', type=str, choices=['none','fvd','accuracy','samples','diversity', 'kps_acc', 'transfer', 'control_sensitivity'],help="Whether to start in infer mode?") 88 | parser.add_argument("--last_ckpt",default=False,action="store_true",help="Whether to load the last checkpoint if resuming training.") 89 | parser.add_argument("--target_version",default=None,type=int,help="The target version for loading checkpoints from.") 90 | 91 | args = parser.parse_args() 92 | 93 | config, structure, restart = load_parameters(args.config, args.resume or args.test !='none', args.model_name) 94 | config["general"]["restart"] = restart 95 | config["general"]["last_ckpt"] = args.last_ckpt 96 | config["general"]["test"] = args.test 97 | if args.target_version is not None: 98 | config["general"]["target_version"] = args.target_version 99 | 100 | 101 | config = check_ckpt_paths(config) 102 | 103 | devices = ",".join([str(g) for g in args.gpus]) if isinstance(args.gpus,list) else str(args.gpus) 104 | os.environ["CUDA_VISIBLE_DEVICES"] = devices 105 | args.gpus = [i for i,_ in enumerate(args.gpus)] 106 | 107 | 108 | 109 | # if len(args.gpus) == 1: 110 | # gpus = int(args.gpus[0]) 111 | # else: 112 | # gpus = args.gpus 113 | 114 | experiment = select_experiment(config, structure, args.gpus) 115 | 116 | # start selected experiment 117 | 118 | if args.test != 'none': 119 | experiment.test() 120 | else: 121 | experiment.train() -------------------------------------------------------------------------------- /models/conv_poke_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.loggers import WandbLogger 5 | from pytorch_lightning.metrics.regression import SSIM, PSNR 6 | from torch.optim import Adam, lr_scheduler 7 | import wandb 8 | 9 | from models.modules.autoencoders.fully_conv_models import FirstStageWrapper 10 | from models.modules.autoencoders.LPIPS import LPIPS as PerceptualLoss 11 | from utils.metrics import LPIPS 12 | from lpips import LPIPS as lpips_net 13 | from utils.logging import batches2flow_grid 14 | 15 | 16 | class ConvPokeAE(pl.LightningModule): 17 | 18 | def __init__(self, config): 19 | super().__init__() 20 | # self.automatic_optimization=False 21 | self.config = config 22 | self.be_deterministic = self.config["architecture"]['deterministic'] 23 | self.kl_weight = 
self.config["training"]['w_kl'] 24 | self.register_buffer("disc_factor", torch.tensor(1.), persistent=True) 25 | self.register_buffer("disc_weight", torch.tensor(1.), persistent=True) 26 | self.register_buffer("perc_weight", torch.tensor(1.), persistent=True) 27 | self.logvar = nn.Parameter(torch.ones(size=()) * 0.0) 28 | self.n_logged_imgs = self.config["logging"]["n_log_images"] 29 | self.flow_ae = "flow_ae" in self.config["architecture"] and self.config["architecture"]["flow_ae"] 30 | self.poke_and_image = "poke_and_image" in self.config["architecture"] and self.config["architecture"]["poke_and_image"] 31 | 32 | 33 | self.vgg_loss = PerceptualLoss() 34 | 35 | # ae 36 | self.model = FirstStageWrapper(self.config) 37 | 38 | 39 | # metrics 40 | # self.ssim = SSIM() 41 | # self.psnr = PSNR() 42 | self.lpips_net = lpips_net() 43 | for param in self.lpips_net.parameters(): 44 | param.requires_grad = False 45 | 46 | self.lpips_metric = LPIPS() 47 | 48 | def setup(self, stage: str): 49 | assert isinstance(self.logger, WandbLogger) 50 | self.logger.experiment.watch(self, log="all") 51 | 52 | def training_step(self, batch, batch_idx): 53 | if isinstance(batch['poke'], list): 54 | poke = batch["poke"][0] 55 | poke_coords = batch["poke"][1] 56 | else: 57 | poke = batch['poke'] 58 | poke_coords = None 59 | flow = batch["flow"] 60 | 61 | if self.poke_and_image: 62 | img = batch["images"][:, 0] 63 | poke = torch.cat([poke, img], dim=1) 64 | 65 | poke_in = flow if self.flow_ae else poke 66 | 67 | rec = self.model(poke_in) 68 | rec_loss = torch.abs(flow.contiguous() - rec.contiguous()) 69 | 70 | zeros = torch.zeros((flow.size(0), 1, *flow.shape[-2:]), device=self.device) 71 | p_loss = self.vgg_loss(torch.cat([flow, zeros], 1).contiguous(), torch.cat([rec, zeros], 1).contiguous()) 72 | # equal weighting of l1 and perceptual loss 73 | rec_loss = rec_loss + self.perc_weight * p_loss 74 | 75 | 76 | 77 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 78 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 79 | 80 | loss = nll_loss 81 | 82 | loss_dict = {"train/loss": loss, "train/logvar": self.logvar.detach(), "train/nll_loss": nll_loss, } 83 | self.log_dict(loss_dict, logger=True, on_epoch=True, on_step=True) 84 | self.log("global step", self.global_step) 85 | self.log("learning rate", self.optimizers().param_groups[0]["lr"], on_step=True, logger=True) 86 | 87 | self.log("overall_loss", loss, prog_bar=True, logger=False) 88 | self.log("nll_loss", nll_loss, prog_bar=True, logger=False) 89 | 90 | if self.global_step % self.config["logging"]["log_train_prog_at"] == 0: 91 | flow_orig = batch["original_flow"].detach() 92 | img = batch["images"][:, 0].detach() 93 | poke = poke[:, :2].detach() 94 | flows = [poke.detach(), rec.detach(), flow.detach(), flow_orig.detach()] 95 | captions = ["Poke", "Flow-rec", "Flow-target", "Flow-orig"] 96 | if self.flow_ae: 97 | poke_coords = None 98 | train_grid_cmap = batches2flow_grid(flows, captions, n_logged=self.n_logged_imgs, img=img, poke=flows[0], 99 | poke_coords=poke_coords,poke_normalized=False) 100 | train_grid_quiver = batches2flow_grid(flows, captions, n_logged=self.n_logged_imgs, quiver=True, img=img, poke=flows[0], 101 | poke_coords=poke_coords,poke_normalized=False) 102 | self.logger.experiment.log({f"Train Batch Cmap": wandb.Image(train_grid_cmap, 103 | caption=f"Training Images @ it #{self.global_step}"), 104 | f"Train Batch Quiver": wandb.Image(train_grid_quiver, 105 | caption=f"Training Images @ it #{self.global_step}"), 106 | }, 
step=self.global_step) 107 | 108 | return loss 109 | 110 | def training_epoch_end(self, outputs): 111 | self.log("epoch", self.current_epoch) 112 | 113 | def validation_step(self, batch, batch_id): 114 | #poke = batch["poke"][0] if isinstance(batch["poke"], list) else batch["poke"] 115 | if isinstance(batch['poke'],list): 116 | poke = batch["poke"][0] 117 | poke_coords = batch["poke"][1] 118 | else: 119 | poke = batch['poke'] 120 | poke_coords=None 121 | 122 | flow = batch["flow"] 123 | 124 | if self.poke_and_image: 125 | img = batch["images"][:, 0] 126 | poke = torch.cat([poke, img], dim=1) 127 | 128 | poke_in = flow if self.flow_ae else poke 129 | with torch.no_grad(): 130 | rec = self.model(poke_in) 131 | rec_loss = torch.abs(flow.contiguous() - rec.contiguous()) 132 | 133 | zeros = torch.zeros((flow.size(0), 1, *flow.shape[-2:]), device=self.device) 134 | f3 = torch.cat([flow, zeros], 1).contiguous() 135 | r3 = torch.cat([rec, zeros], 1).contiguous() 136 | p_loss = self.vgg_loss(f3, r3) 137 | # equal weighting of l1 and perceptual loss 138 | rec_loss = rec_loss + self.perc_weight * p_loss 139 | 140 | 141 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 142 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 143 | loss = nll_loss 144 | 145 | log_dict = {"val/loss": loss, "val/logvar": self.logvar.detach(), 146 | "val/nll_loss": nll_loss, "val/rec_loss": rec_loss} 147 | 148 | self.log_dict(log_dict, logger=True, prog_bar=False, on_epoch=True) 149 | 150 | # self.log("ssim-val", self.ssim(rec, flow).cpu(), on_step=False, on_epoch=True, logger=True) 151 | # self.log("psnr-val", self.psnr(rec, flow).cpu(), on_step=False, on_epoch=True, logger=True) 152 | self.log("lpips-val", self.lpips_metric(self.lpips_net, r3, f3).cpu(), on_step=False, on_epoch=True, logger=True) 153 | 154 | if batch_id < self.config["logging"]["n_val_img_batches"]: 155 | flow_orig = batch["original_flow"].detach() 156 | img = batch["images"][:, 0].detach() 157 | poke = poke[:, :2].detach() 158 | flows = [poke, rec, flow, flow_orig] 159 | captions = ["Poke", "Flow-rec", "Flow-target", "Flow-orig"] 160 | if self.flow_ae: 161 | poke_coords = None 162 | val_grid_cmap = batches2flow_grid(flows, captions, n_logged=self.n_logged_imgs, img=img, poke=flows[0], 163 | poke_coords=poke_coords,poke_normalized=False) 164 | val_grid_quiver = batches2flow_grid(flows, captions, n_logged=self.n_logged_imgs, quiver=True, img=img, poke=flows[0], 165 | poke_coords=poke_coords,poke_normalized=False) 166 | self.logger.experiment.log({f"Validation Batch #{batch_id} Cmap Plot": wandb.Image(val_grid_cmap, 167 | caption=f"Validation Images @ it {self.global_step}"), 168 | f"Validation Batch #{batch_id} Quiver Plot": wandb.Image(val_grid_quiver, 169 | caption=f"Validation Images @ it {self.global_step}") 170 | }, step=self.global_step 171 | ) 172 | 173 | return log_dict, batch_id 174 | 175 | def configure_optimizers(self): 176 | # optimizers 177 | opt_g = Adam(self.parameters(), lr=self.config["training"]["lr"], weight_decay=self.config["training"]["weight_decay"]) 178 | # schedulers 179 | sched_g = lr_scheduler.ReduceLROnPlateau(opt_g, mode="min", factor=.5, patience=1, min_lr=1e-8, 180 | threshold=0.0001, threshold_mode='abs') 181 | return [opt_g], [{'scheduler': sched_g, 'monitor': "loss-val", "interval": 1, 'reduce_on_plateau': True, 'strict': True}, ] 182 | # return ({'optimizer': opt_g,'lr_scheduler':sched_g,'monitor':"loss-val","interval":1,'reduce_on_plateau':True,'strict':True}, 183 | # {'optimizer': 
opt_d,'lr_scheduler':sched_d,'monitor':"loss-val","interval":1,'reduce_on_plateau':True,'strict':True}) 184 | -------------------------------------------------------------------------------- /models/modules/INN/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Normal 4 | 5 | 6 | class FlowLoss(nn.Module): 7 | def __init__(self,spatial_mean=False, logdet_weight=1.): 8 | super().__init__() 9 | # self.config = config 10 | self.spatial_mean = spatial_mean 11 | self.logdet_weight = logdet_weight 12 | 13 | def forward(self, sample, logdet): 14 | nll_loss = torch.mean(nll(sample, spatial_mean=self.spatial_mean)) 15 | assert len(logdet.shape) == 1 16 | if self.spatial_mean: 17 | h,w = sample.shape[-2:] 18 | nlogdet_loss = -torch.mean(logdet) / (h*w) 19 | else: 20 | nlogdet_loss = -torch.mean(logdet) 21 | 22 | loss = nll_loss + self.logdet_weight*nlogdet_loss 23 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample),spatial_mean=self.spatial_mean)) 24 | log = { 25 | "flow_loss": loss, 26 | "reference_nll_loss": reference_nll_loss, 27 | "nlogdet_loss": nlogdet_loss, 28 | "nll_loss": nll_loss, 29 | 'logdet_weight': self.logdet_weight 30 | } 31 | return loss, log 32 | 33 | class FlowLossAlternative(nn.Module): 34 | def __init__(self): 35 | super().__init__() 36 | # self.config = config 37 | 38 | def forward(self, sample, logdet): 39 | nll_loss = torch.mean(torch.sum(0.5*torch.pow(sample, 2), dim=1)) 40 | nlogdet_loss = - logdet.mean() 41 | 42 | 43 | loss = nll_loss + nlogdet_loss 44 | reference_sample = torch.randn_like(sample) 45 | reference_nll_loss = torch.mean(torch.sum(0.5*torch.pow(reference_sample, 2), dim=1)) 46 | log = { 47 | "flow_loss": loss, 48 | "reference_nll_loss": reference_nll_loss, 49 | "nlogdet_loss": nlogdet_loss, 50 | "nll_loss": nll_loss 51 | } 52 | return loss, log 53 | 54 | class ExtendedFlowLoss(nn.Module): 55 | def __init__(self,): 56 | super().__init__() 57 | # self.config = config 58 | 59 | def forward(self, sample_x, sample_v, logdet): 60 | nll_loss_x = torch.mean(nll(sample_x)) 61 | nll_loss_v = torch.mean(nll(sample_v)) 62 | assert len(logdet.shape) == 1 63 | nlogdet_loss = -torch.mean(logdet) 64 | loss = nll_loss_x + nll_loss_v + nlogdet_loss 65 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample_x))) 66 | log = { 67 | "flow_loss": loss, 68 | "reference_nll_loss": reference_nll_loss, 69 | "nlogdet_loss": nlogdet_loss, 70 | "nll_loss_x": nll_loss_x, 71 | "nll_loss_v": nll_loss_v 72 | } 73 | return loss, log 74 | 75 | def nll(sample, spatial_mean= False): 76 | if spatial_mean: 77 | return 0.5 * torch.sum(torch.mean(torch.pow(sample, 2),dim=[2,3]), dim=1) 78 | else: 79 | return 0.5 * torch.sum(torch.pow(sample, 2), dim=[1, 2, 3]) 80 | 81 | 82 | class GaussianLogP(nn.Module): 83 | 84 | def __init__(self,mu=0.,sigma=1.): 85 | super().__init__() 86 | self.dist = Normal(loc=mu,scale=sigma) 87 | 88 | def forward(self,sample,logdet): 89 | nll_log_loss = torch.sum(self.dist.log_prob(sample)) / sample.size(0) 90 | nlogdet_loss = torch.mean(logdet) 91 | reference_nll_loss = torch.mean(nll(torch.randn_like(sample))) 92 | nll_loss = torch.mean(nll(sample)) 93 | loss = - (nll_log_loss + nlogdet_loss) 94 | log = {"flow_loss":loss, 95 | "reference_nll_loss":reference_nll_loss, 96 | "nlogdet_loss":-nlogdet_loss, 97 | "nll_loss": nll_loss, 98 | "nll_log_loss":-nll_log_loss} 99 | 100 | return loss, log 
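For reference, nll() above is the negative log-density of a standard normal up to the additive constant 0.5*D*log(2*pi) per sample, so FlowLoss optimizes the usual change-of-variables objective -log p(z) - logdet_weight * log|det J|. A small self-contained check of that correspondence (illustrative sketch, not part of the repository; the tensor shape is arbitrary):

    import math
    import torch

    z = torch.randn(4, 16, 8, 8)                         # e.g. a latent sample produced by the INN
    energy = 0.5 * torch.sum(z.pow(2), dim=[1, 2, 3])    # what nll(z, spatial_mean=False) returns
    D = z[0].numel()
    exact = energy + 0.5 * D * math.log(2 * math.pi)     # exact standard-normal NLL per sample
    ref = -torch.distributions.Normal(0., 1.).log_prob(z).sum(dim=[1, 2, 3])
    assert torch.allclose(exact, ref, atol=1e-3)

FlowLoss drops the constant (it does not affect gradients), averages the energy over the batch and adds -logdet_weight * mean(logdet); the reference_nll_loss it logs is the same energy evaluated on a fresh standard-normal sample, which serves as a sanity baseline during training.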
-------------------------------------------------------------------------------- /models/modules/autoencoders/LPIPS.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.modules.autoencoders.vgg16 import vgg16, normalize_tensor, spatial_average 7 | from models.modules.autoencoders.ckpt_util import get_ckpt_path 8 | 9 | 10 | class LPIPS(nn.Module): 11 | # Learned perceptual metric 12 | # Be careful about requires_grad and eval-mode 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name) 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | # if __name__ == "__main__": 77 | # # DataParallel test 78 | # ngpu = torch.cuda.device_count() 79 | # device_ids = [n for n in range(ngpu)] 80 | # lpips = LPIPS() 81 | # # lpips = torch.nn.DataParallel(lpips, device_ids=device_ids) 82 | # lpips.cuda() 83 | # x = 
torch.randn(16, 3, 128, 128).cuda() 84 | # y = torch.randn(16, 3, 128, 128).cuda() 85 | # loss = lpips(x, y) 86 | # print("test loss:", loss.mean()) 87 | # print("done.") -------------------------------------------------------------------------------- /models/modules/autoencoders/baseline_fc_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils import spectral_norm 4 | import numpy as np 5 | 6 | from models.modules.autoencoders.util import Conv2dBlock, ResBlock, Spade, NormConv2d 7 | from models.modules.autoencoders.fully_conv_models import ConvEncoder 8 | 9 | 10 | class FirstStageFCWrapper(nn.Module): 11 | 12 | def __init__(self,config): 13 | super().__init__() 14 | self.config = config 15 | self.encoder = BaselineFCEncoder(config=self.config) 16 | # no spade layer as these are only for encoder training 17 | self.config['architecture']['dec_channels'] = [self.config['architecture']['nf_max']] + self.encoder.depths 18 | self.config['architecture']['spectral_norm'] = True 19 | self.config['architecture'].update({'z_dim': self.config['architecture']['nf_max']}) 20 | self.decoder = BaselineFCGenerator(config=self.config['architecture'],use_spade=False) 21 | self.be_deterministic = True 22 | 23 | def forward(self,x): 24 | enc = self.encoder(x) 25 | return self.decoder([enc],None) 26 | 27 | class BaselineFCEncoder(ConvEncoder): 28 | 29 | def __init__(self,config): 30 | self.config = config 31 | n_stages = int( 32 | np.log2(self.config["data"]["spatial_size"][0] // 4)) 33 | nf_max = self.config['architecture']['nf_max'] 34 | nf_in = self.config['architecture']['nf_in'] 35 | #always determinstic 36 | self.deterministic = True 37 | super().__init__(nf_in,nf_max,n_stages,variational=not self.deterministic) 38 | 39 | self.make_fc = NormConv2d(nf_max,nf_max,4,padding=0) 40 | 41 | def forward(self, x): 42 | # onky use output as model is not varaitional 43 | out, *_ =super().forward(x,sample_prior=False) 44 | out = self.make_fc(out).squeeze(dim=-1).squeeze(dim=-1) 45 | return out 46 | 47 | 48 | 49 | 50 | class BaselineFCGenerator(nn.Module): 51 | 52 | def __init__(self,config, use_spade=True): 53 | super().__init__() 54 | channels = config['dec_channels'] 55 | snorm = config['spectral_norm'] 56 | latent_dim = config['z_dim'] 57 | nc_out = config['nc_out'] if 'nc_out' in config else 3 58 | self.use_spade = use_spade 59 | 60 | self.blocks = nn.ModuleList() 61 | self.spade_blocks = nn.ModuleList() 62 | self.first_conv_nf = channels[0] 63 | self.n_stages = len(channels)-1 64 | if snorm: 65 | self.start_block = spectral_norm(nn.Linear(in_features=latent_dim,out_features=self.first_conv_nf * 16,)) 66 | else: 67 | self.start_block = nn.Linear(in_features=latent_dim,out_features=channels[0] * 16,) 68 | nf = 0 69 | for i, nf in enumerate(channels[1:]): 70 | n_out = nf 71 | 72 | nf_in_dec = channels[i] 73 | self.blocks.append(ResBlock(nf_in_dec, n_out, norm='none' if self.use_spade else 'group', upsampling=True, snorm=config['spectral_norm'])) 74 | if self.use_spade: 75 | self.spade_blocks.append(Spade(n_out, config)) 76 | 77 | self.out_conv = Conv2dBlock(nf, nc_out, 3, 1, 1, norm="none", 78 | activation="tanh") 79 | 80 | 81 | def forward(self,actual_frame,start_frame,del_shape=True): 82 | x = self.start_block(actual_frame.pop() if del_shape else actual_frame[-1]) 83 | x = x.reshape(x.size(0),self.first_conv_nf,4,4) 84 | for n in range(self.n_stages): 85 | x = self.blocks[n](x) 86 | if self.use_spade: 87 | x = 
self.spade_blocks[n](x, start_frame) 88 | 89 | if del_shape: 90 | assert not actual_frame 91 | out = self.out_conv(x) 92 | return out 93 | 94 | # class BaseLineFCEncoder(nn.Module): 95 | 96 | -------------------------------------------------------------------------------- /models/modules/autoencoders/big_ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torchvision 4 | from torchvision import models as tvmodels 5 | 6 | from models.modules.autoencoders.util import ActNorm 7 | from models.modules.autoencoders.distributions import DiagonalGaussianDistribution 8 | from models.modules.autoencoders.biggan import load_variable_latsize_generator 9 | 10 | 11 | 12 | class BigAE(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | import torch.backends.cudnn as cudnn 16 | cudnn.benchmark = True 17 | self.be_deterministic = config['deterministic'] 18 | if "n_out_channels" not in config: 19 | config.update({"n_out_channels": 3}) 20 | self.encoder = ResnetEncoder(config) 21 | self.decoder = BigGANDecoderWrapper(config=config) 22 | 23 | def encode(self, input): 24 | h = input 25 | h = self.encoder(h) 26 | return DiagonalGaussianDistribution(h, deterministic=self.be_deterministic) 27 | 28 | def decode(self, input): 29 | h = input 30 | h = self.decoder(h.squeeze(-1).squeeze(-1)) 31 | return h 32 | 33 | def forward(self, input): 34 | p = self.encode(input) 35 | img = self.decode(p.mode()) 36 | return img, p.mode(), p 37 | 38 | def get_last_layer(self): 39 | return getattr(self.decoder.decoder.colorize.module, 'weight_bar') 40 | 41 | 42 | class ClassUp(nn.Module): 43 | def __init__(self, dim, depth, hidden_dim=256, use_sigmoid=False, out_dim=None): 44 | super().__init__() 45 | layers = [] 46 | layers.append(nn.Linear(dim, hidden_dim)) 47 | layers.append(nn.LeakyReLU()) 48 | for d in range(depth): 49 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 50 | layers.append(nn.LeakyReLU()) 51 | layers.append(nn.Linear(hidden_dim, dim if out_dim is None else out_dim)) 52 | if use_sigmoid: 53 | layers.append(nn.Sigmoid()) 54 | self.main = nn.Sequential(*layers) 55 | 56 | def forward(self, x): 57 | x = self.main(x.squeeze(-1).squeeze(-1)) 58 | x = torch.nn.functional.softmax(x, dim=1) 59 | return x 60 | 61 | 62 | class BigGANDecoderWrapper(nn.Module): 63 | """Wraps a BigGAN into our autoencoding framework""" 64 | def __init__(self, config): 65 | super().__init__() 66 | z_dim = config['z_dim'] 67 | self.do_pre_processing = config['pre_process'] 68 | image_size = config['in_size'] 69 | use_actnorm = config['use_actnorm_in_dec'] 70 | pretrained = config['pretrained'] 71 | class_embedding_dim = 1000 72 | use_adain = config["use_adain"] if "use_adain" in config else False 73 | 74 | self.map_to_class_embedding = ClassUp(z_dim, depth=2, hidden_dim=2*class_embedding_dim, 75 | use_sigmoid=False, out_dim=class_embedding_dim) 76 | self.decoder = load_variable_latsize_generator(image_size, z_dim, 77 | pretrained=pretrained, 78 | use_actnorm=use_actnorm, 79 | n_class=class_embedding_dim, 80 | n_channels=config["n_out_channels"], 81 | adain=use_adain) 82 | 83 | 84 | def forward(self, x, labels=None): 85 | emb = self.map_to_class_embedding(x) 86 | x = self.decoder(x, emb) 87 | return x 88 | 89 | class DenseEncoderLayer(nn.Module): 90 | def __init__(self, scale, spatial_size, out_size, in_channels=None, 91 | width_multiplier=1): 92 | super().__init__() 93 | self.scale = scale 94 | self.wm = width_multiplier 95 | 
self.in_channels = int(self.wm*64*min(2**(self.scale-1), 16)) 96 | if in_channels is not None: 97 | self.in_channels = in_channels 98 | self.out_channels = out_size 99 | self.kernel_size = spatial_size 100 | self.build() 101 | 102 | def forward(self, input): 103 | x = input 104 | for layer in self.sub_layers: 105 | x = layer(x) 106 | return x 107 | 108 | def build(self): 109 | self.sub_layers = nn.ModuleList([ 110 | nn.Conv2d( 111 | in_channels=self.in_channels, 112 | out_channels=self.out_channels, 113 | kernel_size=self.kernel_size, 114 | stride=1, 115 | padding=0, 116 | bias=True)]) 117 | 118 | 119 | _norm_options = { 120 | "in": nn.InstanceNorm2d, 121 | "bn": nn.BatchNorm2d, 122 | "an": ActNorm} 123 | 124 | rescale = lambda x: 0.5*(x+1) 125 | 126 | class ResnetEncoder(nn.Module): 127 | def __init__(self, config): 128 | super().__init__() 129 | __possible_resnets = { 130 | 'resnet18': tvmodels.resnet18, 131 | 'resnet34': tvmodels.resnet34, 132 | 'resnet50': tvmodels.resnet50, 133 | 'resnet101': tvmodels.resnet101 134 | } 135 | self.config = config 136 | self.do_pre_processing = config['pre_process'] 137 | self.in_channels = self.config["n_in_channels"] if "n_in_channels" in self.config else 3 138 | self.use_inconv = self.in_channels != 3 139 | if self.use_inconv: 140 | # input map 141 | self.in_conv = nn.Conv2d(in_channels=self.config["n_in_channels"],out_channels=3,kernel_size=1) 142 | assert not self.do_pre_processing 143 | 144 | z_dim = config['z_dim'] 145 | ipt_size =config['in_size'] 146 | type_ = config['type'] 147 | load_pretrained = config['pretrained'] 148 | norm_layer = _norm_options[config['norm']] 149 | 150 | self.type = type_ 151 | self.z_dim = z_dim 152 | self.model = __possible_resnets[type_](pretrained=load_pretrained, norm_layer=norm_layer) 153 | 154 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 155 | self.image_transform = torchvision.transforms.Compose( 156 | [torchvision.transforms.Lambda(lambda image: torch.stack([normalize(rescale(x)) for x in image]))] 157 | ) 158 | 159 | size_pre_fc = self._get_spatial_size(ipt_size) 160 | assert size_pre_fc[2]==size_pre_fc[3], 'Output spatial size is not quadratic' 161 | spatial_size = size_pre_fc[2] 162 | num_channels_pre_fc = size_pre_fc[1] 163 | # replace last fc 164 | self.model.fc = DenseEncoderLayer(0, 165 | spatial_size=spatial_size, 166 | out_size=2*z_dim, 167 | in_channels=num_channels_pre_fc) 168 | 169 | def forward(self, x): 170 | if self.do_pre_processing: 171 | x = self._pre_process(x) 172 | features = self.features(x) 173 | encoding = self.model.fc(features) 174 | return encoding 175 | 176 | def features(self, x): 177 | if self.do_pre_processing: 178 | x = self._pre_process(x) 179 | if self.use_inconv: 180 | x = self.in_conv(x) 181 | x = self.model.conv1(x) 182 | x = self.model.bn1(x) 183 | x = self.model.relu(x) 184 | x = self.model.maxpool(x) 185 | x = self.model.layer1(x) 186 | x = self.model.layer2(x) 187 | x = self.model.layer3(x) 188 | x = self.model.layer4(x) 189 | x = self.model.avgpool(x) 190 | return x 191 | 192 | def post_features(self, x): 193 | x = self.model.fc(x) 194 | return x 195 | 196 | def _pre_process(self, x): 197 | x = self.image_transform(x) 198 | return x 199 | 200 | def _get_spatial_size(self, ipt_size): 201 | x = torch.randn(1,self.in_channels , ipt_size, ipt_size) 202 | return self.features(x).size() 203 | 204 | @property 205 | def mean(self): 206 | return [0.485, 0.456, 0.406] 207 | 208 | @property 209 | def std(self): 210 | return [0.229, 0.224, 0.225] 211 
| 212 | @property 213 | def input_size(self): 214 | return [3, 224, 224] -------------------------------------------------------------------------------- /models/modules/autoencoders/ckpt_util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "bigae_animals": "https://heibox.uni-heidelberg.de/f/f0adb4d509ea4132b9ea/?dl=1", 7 | "bigae_animalfaces": "https://heibox.uni-heidelberg.de/f/3c0bf40a85a84e2a986e/?dl=1", 8 | "biggan_128": "https://heibox.uni-heidelberg.de/f/56ed256209fd40968864/?dl=1", 9 | "biggan_256": "https://heibox.uni-heidelberg.de/f/437b501944874bcc92a4/?dl=1", 10 | "dequant_vae": "https://heibox.uni-heidelberg.de/f/e7c8959b50a64f40826e/?dl=1", 11 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 12 | } 13 | 14 | CKPT_MAP = { 15 | "bigae_animals": "autoencoders/bigae/animals-1672855.ckpt", 16 | "bigae_animalfaces": "autoencoders/bigae/animalfaces-631606.ckpt", 17 | "biggan_128": "autoencoders/biggan/biggan-128.pth", 18 | "biggan_256": "autoencoders/biggan/biggan-256.pth", 19 | "dequant_vae": "autoencoders/dequant/dequantvae-20000.ckpt", 20 | "vgg_lpips": "autoencoders/lpips/vgg.pth" 21 | } 22 | 23 | MD5_MAP = { 24 | "bigae_animals": "6213882571854935226a041b8dcaecdd", 25 | "bigae_animalfaces": "7f379d6ebcbc03a710ef0605806f0b51", 26 | "biggan_128": "a2148cf64807444113fac5eede060d28", 27 | "biggan_256": "e23db3caa34ac4c4ae922a75258dcb8d", 28 | "dequant_vae": "5c2a6fe765142cbdd9f10f15d65a68b6", 29 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 30 | } 31 | 32 | 33 | def download(url, local_path, chunk_size=1024): 34 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 35 | with requests.get(url, stream=True) as r: 36 | total_size = int(r.headers.get("content-length", 0)) 37 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 38 | with open(local_path, "wb") as f: 39 | for data in r.iter_content(chunk_size=chunk_size): 40 | if data: 41 | f.write(data) 42 | pbar.update(chunk_size) 43 | 44 | 45 | def md5_hash(path): 46 | with open(path, "rb") as f: 47 | content = f.read() 48 | return hashlib.md5(content).hexdigest() 49 | 50 | 51 | def get_ckpt_path(name, root=None, check=False): 52 | assert name in URL_MAP 53 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 54 | root = root if root is not None else os.path.join(cachedir, "autoencoders") 55 | path = os.path.join(root, CKPT_MAP[name]) 56 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 57 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 58 | download(URL_MAP[name], path) 59 | md5 = md5_hash(path) 60 | assert md5 == MD5_MAP[name], md5 61 | return path 62 | -------------------------------------------------------------------------------- /models/modules/autoencoders/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class AbstractDistribution: 7 | def sample(self): 8 | raise NotImplementedError() 9 | 10 | def mode(self): 11 | raise NotImplementedError() 12 | 13 | 14 | class DiracDistribution(AbstractDistribution): 15 | def __init__(self, value): 16 | self.value = value 17 | 18 | def sample(self): 19 | return self.value 20 | 21 | def mode(self): 22 | return self.value 23 | 24 | 25 | class DiagonalGaussianDistribution(object): 26 | def 
__init__(self, parameters, deterministic=False): 27 | self.parameters = parameters 28 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 29 | self.logvar = torch.clamp(self.logvar, -30.0, 10.0) 30 | self.deterministic = deterministic 31 | self.std = torch.exp(0.5 * self.logvar) 32 | self.var = torch.exp(self.logvar) 33 | if self.deterministic: 34 | self.var = self.std = torch.zeros_like(self.mean) 35 | 36 | def sample(self): 37 | x = self.mean + self.std * torch.randn_like(self.mean) 38 | return x 39 | 40 | def kl(self, other=None): 41 | if self.deterministic: 42 | return torch.Tensor([0.]) 43 | else: 44 | if other is None: 45 | return torch.mean(0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])) 46 | else: 47 | return 0.5 * torch.sum( 48 | torch.pow(self.mean - other.mean, 2) / other.var 49 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 50 | dim=[1, 2, 3]) 51 | 52 | def nll(self, sample): 53 | if self.deterministic: 54 | return torch.Tensor([0.]) 55 | logtwopi = np.log(2.0 * np.pi) 56 | return 0.5 * torch.sum( 57 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 58 | dim=[1, 2, 3]) 59 | 60 | def mode(self): 61 | return self.mean 62 | -------------------------------------------------------------------------------- /models/modules/autoencoders/fully_conv_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from models.modules.autoencoders.util import Conv2dBlock, ResBlock, AdaINLinear, NormConv2d, Spade 7 | 8 | 9 | class FirstStageWrapper(nn.Module): 10 | 11 | def __init__(self,config): 12 | super().__init__() 13 | self.config = config 14 | self.be_deterministic = self.config["architecture"]["deterministic"] 15 | n_stages = int(np.log2(self.config["data"]["spatial_size"][0] // self.config["architecture"]["min_spatial_size"])) 16 | nf_in_enc = self.config["architecture"]["nf_in"] 17 | if "poke_and_image" in self.config["architecture"] and self.config["architecture"]["poke_and_image"]: 18 | nf_in_enc+=3 19 | self.encoder = ConvEncoder(nf_in=nf_in_enc, nf_max=self.config["architecture"]["nf_max"], 20 | n_stages=n_stages, variational=not self.be_deterministic) 21 | decoder_channels = [self.config["architecture"]["nf_max"]] + self.encoder.depths 22 | self.decoder = ConvDecoder(self.config["architecture"]["nf_max"], decoder_channels, out_channels=self.config["architecture"]["nf_in"]) 23 | 24 | def forward(self,x): 25 | enc, *_ = self.encoder(x) 26 | return self.decoder([enc],del_shape=False) 27 | 28 | class ConvEncoder(nn.Module): 29 | def __init__(self, nf_in, nf_max, n_stages, variational=False, norm_layer = "group", layers=None, spectral_norm=True): 30 | super().__init__() 31 | 32 | self.variational = variational 33 | self.depths = [] 34 | 35 | act = "elu" #if self.variational else "relu" 36 | 37 | blocks = [] 38 | bottleneck = [] 39 | nf = 32 if layers is None else layers[0] 40 | blocks.append( 41 | Conv2dBlock( 42 | nf_in, nf, 3, 2, norm=norm_layer, activation=act, padding=1,snorm=spectral_norm 43 | ) 44 | ) 45 | self.depths.append(nf) 46 | n_stages = n_stages if layers is None else len(layers) 47 | for n in range(n_stages - 1): 48 | blocks.append( 49 | ResBlock( 50 | nf, 51 | min(nf * 2, nf_max) if layers is None else layers[n+1], 52 | stride = 2, 53 | norm=norm_layer, 54 | activation=act, 55 | snorm=spectral_norm 56 | ) 57 | ) 58 | nf = min(nf * 2, nf_max) if 
layers is None else layers[n+1] 59 | self.depths.insert(0,nf) 60 | 61 | self.nf_in_bn = nf 62 | bottleneck.append(ResBlock(nf, nf_max,activation=act, norm=norm_layer)) 63 | # if layers is None: 64 | # bottleneck.append(ResBlock(nf_max, nf_max,activation=act, norm=norm_layer)) 65 | 66 | 67 | if self.variational: 68 | self.make_mu = NormConv2d(nf_max,nf_max,3, padding=1) 69 | self.make_sigma = NormConv2d(nf_max,nf_max,3, padding=1) 70 | self.squash = nn.Sigmoid() 71 | 72 | self.model = nn.Sequential(*blocks) 73 | self.bottleneck = nn.Sequential(*bottleneck) 74 | 75 | def forward(self, input, sample_prior=False): 76 | out = self.model(input) 77 | mean = out 78 | out = self.bottleneck(out) 79 | logstd = None 80 | if self.variational: 81 | mean = self.make_mu(out) 82 | # normalize sigma in between 83 | logstd = self.squash(self.make_sigma(out)) 84 | if sample_prior: 85 | out = torch.randn_like(mean) 86 | else: 87 | out = self.reparametrize(mean,logstd) 88 | 89 | return out, mean, logstd 90 | 91 | def reparametrize(self,mean,logstd): 92 | std = torch.exp(logstd) 93 | eps = torch.randn_like(std) 94 | return eps.mul(std) + mean 95 | 96 | class ConvDecoder(nn.Module): 97 | """ 98 | Fully convolutional decoder consisting of resnet blocks, with optional skip connections (default no skip connections; if these 99 | shall be used, set n_skip_stages > 0 100 | """ 101 | def __init__(self,nf_in, in_channels, n_skip_stages=0, spectral_norm=True, norm_layer="group",layers=None,out_channels=3): 102 | super().__init__() 103 | self.n_stages = len(in_channels)-1 104 | self.n_skip_stages = n_skip_stages 105 | 106 | self.blocks = nn.ModuleList() 107 | 108 | nf = nf_in 109 | self.in_block = ResBlock(nf,in_channels[0], snorm=spectral_norm, norm=norm_layer) 110 | 111 | for i,nf in enumerate(in_channels[1:]): 112 | if layers is None: 113 | n_out = nf 114 | 115 | nf_in_dec = 2 * nf if i < self.n_skip_stages else in_channels[i] 116 | # if layers is not None: 117 | # nf_in_dec = 2 * nf 118 | # n_out = in_channels[i+1] if i < len(in_channels) -1 else nf 119 | self.blocks.append(ResBlock(nf_in_dec, n_out , norm=norm_layer, upsampling=True,snorm=spectral_norm)) 120 | 121 | self.out_conv = Conv2dBlock(nf,out_channels,3,1,1,norm="none",activation="tanh" if out_channels==3 else "none") 122 | 123 | def forward(self,shape, del_shape=True): 124 | x = self.in_block(shape.pop() if del_shape else shape[-1]) 125 | for n in range(self.n_stages): 126 | if n < self.n_skip_stages: 127 | x = torch.cat([x,shape.pop() if del_shape else shape[self.n_skip_stages-1-n]],1) 128 | x = self.blocks[n](x) 129 | 130 | if del_shape: 131 | assert not shape 132 | out = self.out_conv(x) 133 | return out 134 | 135 | class SpadeCondConvDecoder(nn.Module): 136 | 137 | def __init__(self,config,stacked_input=False): 138 | super().__init__() 139 | 140 | in_channels = config['dec_channels'] 141 | 142 | self.n_stages = len(in_channels) - 1 143 | self.n_skip_stages = config['n_skip_stages'] if 'n_skip_stages' in config else 0 144 | out_channels = config['out_channels'] if 'out_channels' in config else 3 145 | 146 | self.blocks = nn.ModuleList() 147 | self.spade_blocks = nn.ModuleList() 148 | 149 | 150 | nf = 2*config['z_dim'] if stacked_input else config['z_dim'] 151 | self.in_block = ResBlock(nf, in_channels[0], snorm=config['spectral_norm'], norm=config['norm']) 152 | 153 | for i, nf in enumerate(in_channels[1:]): 154 | n_out = nf 155 | 156 | nf_in_dec = 2 * nf if i < self.n_skip_stages else in_channels[i] 157 | # if layers is not None: 158 | # nf_in_dec = 2 * 
nf 159 | # n_out = in_channels[i+1] if i < len(in_channels) -1 else nf 160 | self.blocks.append(ResBlock(nf_in_dec, n_out, norm='none', upsampling=True, snorm=config['spectral_norm'])) 161 | self.spade_blocks.append(Spade(n_out,config)) 162 | 163 | self.out_conv = Conv2dBlock(nf, out_channels, 3, 1, 1, norm="none", 164 | activation="tanh" if out_channels == 3 else "none") 165 | 166 | def forward(self, actual_frame ,start_frame, del_shape=True): 167 | x = self.in_block(actual_frame.pop() if del_shape else actual_frame[-1]) 168 | for n in range(self.n_stages): 169 | if n < self.n_skip_stages: 170 | x = torch.cat([x, actual_frame.pop() if del_shape else actual_frame[self.n_skip_stages - 1 - n]], 1) 171 | x = self.blocks[n](x) 172 | x = self.spade_blocks[n](x,start_frame) 173 | 174 | if del_shape: 175 | assert not actual_frame 176 | out = self.out_conv(x) 177 | return out -------------------------------------------------------------------------------- /models/modules/autoencoders/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | from collections import namedtuple 4 | 5 | 6 | class vgg16(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(vgg16, self).__init__() 9 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.N_slices = 5 16 | for x in range(4): 17 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 18 | for x in range(4, 9): 19 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(9, 16): 21 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(16, 23): 23 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(23, 30): 25 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 26 | if not requires_grad: 27 | for param in self.parameters(): 28 | param.requires_grad = False 29 | 30 | def forward(self, X): 31 | h = self.slice1(X) 32 | h_relu1_2 = h 33 | h = self.slice2(h) 34 | h_relu2_2 = h 35 | h = self.slice3(h) 36 | h_relu3_3 = h 37 | h = self.slice4(h) 38 | h_relu4_3 = h 39 | h = self.slice5(h) 40 | h_relu5_3 = h 41 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 42 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 43 | return out 44 | 45 | 46 | def normalize_tensor(x,eps=1e-10): 47 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 48 | return x/(norm_factor+eps) 49 | 50 | 51 | def spatial_average(x, keepdim=True): 52 | return x.mean([2,3],keepdim=keepdim) -------------------------------------------------------------------------------- /models/modules/discriminators/disc_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def KLDLoss(mu, logvar): 6 | return -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1, 2, 3])) 7 | 8 | 9 | def calculate_adaptive_weight(nll_loss, g_loss, discriminator_weight, last_layer=None): 10 | if last_layer is not None: 11 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 12 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 13 | else: 14 | 
nll_grads = torch.autograd.grad(nll_loss, last_layer[0], retain_graph=True)[0] 15 | g_grads = torch.autograd.grad(g_loss, last_layer[0], retain_graph=True)[0] 16 | 17 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 18 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 19 | d_weight = d_weight * discriminator_weight 20 | return d_weight 21 | 22 | 23 | def hinge_d_loss(logits_real, logits_fake): 24 | loss_real = torch.mean(F.relu(1. - logits_real)) 25 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 26 | d_loss = 0.5 * (loss_real + loss_fake) 27 | return d_loss 28 | 29 | 30 | def adopt_weight(weight, epoch, threshold=0, value=0.): 31 | if epoch < threshold: 32 | weight = value 33 | return weight 34 | -------------------------------------------------------------------------------- /models/modules/motion_models/motion_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn, torch 2 | import numpy as np 3 | 4 | 5 | ###################################################################################################### 6 | ###3D-ConvNet Implementation from https://github.com/tomrunia/PyTorchConv3D ########################## 7 | 8 | def resnet10(**kwargs): 9 | """Constructs a ResNet-10 model. 10 | """ 11 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 12 | return model 13 | 14 | 15 | def resnet18(**kwargs): 16 | """Constructs a ResNet-18 model. 17 | """ 18 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 19 | return model 20 | 21 | def resnet18_alternative(**kwargs): 22 | model = ResNetMotionEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 23 | return model 24 | 25 | 26 | 27 | def resnet34(**kwargs): 28 | """Constructs a ResNet-34 model. 29 | """ 30 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 31 | return model 32 | 33 | 34 | def conv3x3x3(in_planes, out_planes, stride=1): 35 | # 3x3x3 convolution with padding 36 | return nn.Conv3d( 37 | in_planes, 38 | out_planes, 39 | kernel_size=3, 40 | stride=stride, 41 | padding=1, 42 | bias=False) 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | expansion = 1 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(BasicBlock, self).__init__() 50 | self.conv1 = conv3x3x3(inplanes, planes, stride) 51 | self.bn1 = nn.GroupNorm(num_groups=16, num_channels=planes) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.conv2 = conv3x3x3(planes, planes) 54 | self.bn2 = nn.GroupNorm(num_groups=16, num_channels=planes) 55 | self.downsample = downsample 56 | self.stride = stride 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class ResNet(nn.Module): 78 | 79 | def __init__(self, block, layers, dic): 80 | self.inplanes = 64 81 | super(ResNet, self).__init__() 82 | self.spatial_size = dic['img_size'] 83 | channels = dic['ENC_M_channels'] 84 | # currently no deterministic motion encoder required 85 | self.be_determinstic = False 86 | self.conv1 = nn.Conv3d(3, channels[0], kernel_size=(3, 7, 7), stride=(2, 2, 2), padding=(1, 3, 3), bias=False) 87 | self.bn1 = nn.GroupNorm(num_groups=16, num_channels=channels[0]) 88 | self.relu = nn.ReLU(inplace=True) 89 | # self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 90 | self.layer1 = 
self._make_layer(block, channels[1], layers[0]) 91 | self.layer2 = self._make_layer(block, channels[2], layers[1], stride=2) 92 | self.layer3 = self._make_layer(block, channels[3], layers[2], stride=2) 93 | last_channels = channels[3] 94 | if self.spatial_size // 2**3 > 4: 95 | self.layer4 = self._make_layer(block, channels[4], layers[3], stride=2) 96 | last_channels = channels[4] 97 | if self.spatial_size // 2**4 > 4: 98 | self.layer5 = self._make_layer(block, channels[5], layers[3], stride=2) 99 | last_channels = channels[5] 100 | 101 | self.conv_mu = nn.Conv2d(last_channels, dic['z_dim'], 4, 1, 0) 102 | self.conv_var = nn.Conv2d(last_channels, dic['z_dim'], 4, 1, 0) 103 | 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv3d): 106 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 107 | 108 | def _make_layer(self, block, planes, blocks, stride=1): 109 | downsample = None 110 | if stride != 1 or self.inplanes != planes * block.expansion: 111 | downsample = nn.Sequential( 112 | nn.Conv3d( 113 | self.inplanes, 114 | planes * block.expansion, 115 | kernel_size=1, 116 | stride=stride, 117 | bias=False), 118 | nn.GroupNorm(num_channels=planes * block.expansion, num_groups=16)) 119 | 120 | layers = [block(self.inplanes, planes, stride, downsample)] 121 | self.inplanes = planes * block.expansion 122 | for _ in range(1, blocks): 123 | layers.append(block(self.inplanes, planes)) 124 | 125 | return nn.Sequential(*layers) 126 | 127 | def reparameterize(self, emb): 128 | mu, logvar = self.conv_mu(emb).reshape(emb.size(0), -1), self.conv_var(emb).reshape(emb.size(0), -1) 129 | eps = torch.FloatTensor(logvar.size()).normal_().cuda() 130 | std = logvar.mul(0.5).exp_() 131 | return eps.mul(std).add_(mu), mu, logvar 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.bn1(x) 136 | x = self.relu(x) 137 | # x = self.maxpool(x) 138 | 139 | x = self.layer1(x) 140 | x = self.layer2(x) 141 | x = self.layer3(x) 142 | if self.spatial_size // 2 ** 3 > 4: 143 | x = self.layer4(x) 144 | if self.spatial_size // 2 ** 4 > 4: 145 | x = self.layer5(x) 146 | return self.reparameterize(x.squeeze(2)) 147 | 148 | 149 | 150 | class ResNetMotionEncoder(nn.Module): 151 | 152 | def __init__(self, block, layers, dic): 153 | super().__init__() 154 | self.be_determinstic = 'deterministic' in dic and dic['deterministic'] 155 | channels = dic['ENC_M_channels'] 156 | self.inplanes = channels[0] 157 | self.spatial_size = dic['img_size'] 158 | max_frames = dic['max_frames'] 159 | self.min_ssize = dic['min_spatial_size'] if 'min_spatial_size' in dic else 8 160 | 161 | self.conv1 = nn.Conv3d(3, channels[0], kernel_size=(3, 7, 7), stride=(2, 2, 2), padding=(1, 3, 3), bias=False) 162 | self.bn1 = nn.GroupNorm(num_groups=16, num_channels=channels[0]) 163 | self.relu = nn.ReLU(inplace=True) 164 | # self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 165 | 166 | test = np.log2(max_frames) 167 | first_block_down = len(channels)-1 < int(np.ceil(test)) or dic['full_seq'] 168 | stride1 = (2,1,1) if first_block_down else 1 169 | self.layer1 = self._make_layer(block, channels[1], layers[0],stride=stride1) 170 | self.layer2 = self._make_layer(block, channels[2], layers[1], stride=2) 171 | self.layer3 = self._make_layer(block, channels[3], layers[2], stride=2) 172 | last_channels = channels[3] 173 | 174 | self.stride4 = (2,1,1) if dic['full_seq'] and max_frames >= 16 else None 175 | 176 | if self.spatial_size // 2**3 > self.min_ssize: 177 | self.stride4 = 2 178 | 179 | 180 | if 
self.stride4 is not None: 181 | if len(channels)<5: 182 | channels.append(channels[-1]) 183 | print(f"Warning: adding one additional layer to motion encoder with channels={channels[-1]}") 184 | self.layer4 = self._make_layer(block, channels[4], layers[3], stride=self.stride4) 185 | last_channels = channels[4] 186 | if self.spatial_size // 2**4 > self.min_ssize: 187 | self.layer5 = self._make_layer(block, channels[5], layers[3], stride=2) 188 | last_channels = channels[5] 189 | 190 | 191 | self.conv_mu = nn.Conv2d(last_channels, dic['z_dim'], 3, 1, 1) 192 | self.conv_var = nn.Conv2d(last_channels, dic['z_dim'], 3, 1, 1) 193 | 194 | for m in self.modules(): 195 | if isinstance(m, nn.Conv3d): 196 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 197 | 198 | def _make_layer(self, block, planes, blocks, stride=1): 199 | downsample = None 200 | 201 | if stride != 1 or self.inplanes != planes * block.expansion: 202 | downsample = nn.Sequential( 203 | nn.Conv3d( 204 | self.inplanes, 205 | planes * block.expansion, 206 | kernel_size=1, 207 | stride=stride, 208 | bias=False), 209 | nn.GroupNorm(num_channels=planes * block.expansion, num_groups=16)) 210 | 211 | layers = [block(self.inplanes, planes, stride, downsample)] 212 | self.inplanes = planes * block.expansion 213 | for _ in range(1, blocks): 214 | layers.append(block(self.inplanes, planes)) 215 | 216 | return nn.Sequential(*layers) 217 | 218 | def reparameterize(self, emb): 219 | mu, logvar = self.conv_mu(emb), self.conv_var(emb) 220 | eps = torch.FloatTensor(logvar.size()).normal_().cuda() 221 | std = logvar.mul(0.5).exp_() 222 | return eps.mul(std).add_(mu), mu, logvar 223 | 224 | def forward(self, x): 225 | x = self.conv1(x) 226 | x = self.bn1(x) 227 | x = self.relu(x) 228 | # x = self.maxpool(x) 229 | 230 | x = self.layer1(x) 231 | x = self.layer2(x) 232 | x = self.layer3(x) 233 | if self.stride4 is not None: 234 | x = self.layer4(x) 235 | if self.spatial_size // 2 ** 4 > self.min_ssize: 236 | x = self.layer5(x) 237 | if self.be_determinstic: 238 | _, out, _ = self.reparameterize(x.squeeze(2)) 239 | return out, out , out 240 | else: 241 | return self.reparameterize(x.squeeze(2)) 242 | 243 | 244 | 245 | 246 | if __name__ == '__main__': 247 | ## Test 3dconvnet with dummy input 248 | import os 249 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 250 | 251 | config = {'ENC_M_channels': [32,64,128,128], 'z_dim': 64, 'img_size': 64, 'max_frames': 7} 252 | 253 | model = resnet18_alternative(dic=config ).cuda() 254 | 255 | print(f'model has {sum(p.numel() for p in model.parameters())} parameters') 256 | 257 | dummy = torch.rand((2, 3, config['max_frames'], 64, 64)).cuda() 258 | out, *_= model(dummy) 259 | print(out.shape) 260 | 261 | 262 | -------------------------------------------------------------------------------- /models/modules/motion_models/motion_generator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn, torch 2 | import torch.nn.functional as F 3 | from models.modules.autoencoders.util import Spade, ADAIN, Norm3D 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | class generator_block(nn.Module): 8 | 9 | def __init__(self, n_in, n_out, pars): 10 | super().__init__() 11 | self.learned_shortcut = (n_in != n_out) 12 | n_middle = min(n_in, n_out) 13 | 14 | self.conv_0 = nn.Conv3d(n_in, n_middle, 3, 1, 1) 15 | self.conv_1 = nn.Conv3d(n_middle, n_out, 3, 1, 1) 16 | 17 | if self.learned_shortcut: 18 | self.conv_s = nn.Conv3d(n_in, n_out, 1, bias=False) 19 | 20 | if 
pars['spectral_norm']: 21 | self.conv_0 = spectral_norm(self.conv_0) 22 | self.conv_1 = spectral_norm(self.conv_1) 23 | 24 | if self.learned_shortcut: 25 | self.conv_s = spectral_norm(self.conv_s) 26 | 27 | self.norm_0 = Spade(n_in, pars) if pars['CN_content'] == 'spade' else Norm3D(n_in, pars) 28 | self.norm_1 = ADAIN(n_middle, pars) if pars['CN_motion'] == 'ADAIN' else Norm3D(n_middle, pars) 29 | 30 | if self.learned_shortcut: 31 | self.norm_s = Norm3D(n_in, pars) 32 | 33 | def forward(self, x, cond1, cond2): 34 | 35 | x_s = self.shortcut(x) 36 | 37 | dx = self.conv_0(self.actvn(self.norm_0(x, cond2))) 38 | dx = self.conv_1(self.actvn(self.norm_1(dx, cond1))) 39 | 40 | out = x_s + dx 41 | 42 | return out 43 | 44 | def shortcut(self, x): 45 | if self.learned_shortcut: 46 | x_s = self.conv_s(self.norm_s(x)) 47 | else: 48 | x_s = x 49 | return x_s 50 | 51 | def actvn(self, x): 52 | return F.leaky_relu(x, 2e-1) 53 | 54 | 55 | class Generator(nn.Module): 56 | def __init__(self, dic): 57 | super().__init__() 58 | 59 | self.img_size = dic["img_size"] 60 | nf = dic['decoder_factor'] 61 | self.z_dim = dic['z_dim'] 62 | self.fmap_start = 16*nf 63 | 64 | self.fc = nn.Linear(dic['z_dim'], 4*4*16*nf) 65 | self.head_0 = generator_block(16*nf, 16*nf, dic) 66 | 67 | self.g_0 = generator_block(16*nf, 16*nf, dic) 68 | self.g_1 = generator_block(16*nf, 8*nf, dic) 69 | self.g_2 = generator_block(8*nf, 4*nf, dic) 70 | self.g_3 = generator_block(4*nf, 2*nf, dic) 71 | self.g_4 = generator_block(2*nf, 1*nf, dic) 72 | 73 | self.conv_img = nn.Conv3d(nf, 3, 3, padding=1) 74 | 75 | self.reset_params() 76 | 77 | @staticmethod 78 | def weight_init(m): 79 | if isinstance(m, nn.Conv2d): 80 | nn.init.xavier_uniform_(m.weight.data, gain=0.02) 81 | # nn.init.orthogonal_(m.weight.data, gain=0.02) 82 | if not isinstance(m.bias, type(None)): 83 | nn.init.constant_(m.bias.data, 0) 84 | 85 | def reset_params(self): 86 | for _, m in enumerate(self.modules()): 87 | self.weight_init(m) 88 | 89 | def forward(self, img, motion): 90 | 91 | x = self.fc(motion).reshape(img.size(0), -1, 1, 4, 4) 92 | # x = torch.ones(img.size(0), self.fmap_start, 1, 4, 4).cuda() 93 | 94 | x = self.head_0(x, motion, img) 95 | 96 | x = F.interpolate(x, scale_factor=2) 97 | x = self.g_0(x, motion, img) 98 | 99 | x = F.interpolate(x, scale_factor=2) 100 | x = self.g_1(x, motion, img) 101 | 102 | x = F.interpolate(x, scale_factor=2) 103 | x = self.g_2(x, motion, img) 104 | 105 | x = F.interpolate(x, scale_factor=(2, 2, 2)) 106 | x = self.g_3(x, motion, img) 107 | 108 | if self.img_size > 64: 109 | x = F.interpolate(x, scale_factor=(1, 2, 2)) 110 | x = self.g_4(x, motion, img) 111 | 112 | x = self.conv_img(F.leaky_relu(x, 2e-1)) 113 | x = torch.tanh(x) 114 | 115 | return x.transpose(1, 2) 116 | 117 | -------------------------------------------------------------------------------- /models/modules/motion_models/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class ConvGRUCell(nn.Module): 5 | """ 6 | Generate a convolutional GRU cell 7 | """ 8 | 9 | def __init__(self, input_size, hidden_size, kernel_size,upsample=False): 10 | super().__init__() 11 | padding = kernel_size // 2 12 | self.input_size = input_size 13 | self.upsample = upsample 14 | self.hidden_size = hidden_size 15 | self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 16 | self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, 
padding=padding) 17 | self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding) 18 | if self.upsample: 19 | self.up_gate = nn.ConvTranspose2d(input_size,input_size,kernel_size,2,padding=padding, output_padding=padding) 20 | 21 | 22 | nn.init.orthogonal_(self.reset_gate.weight) 23 | nn.init.orthogonal_(self.update_gate.weight) 24 | nn.init.orthogonal_(self.out_gate.weight) 25 | nn.init.constant_(self.reset_gate.bias, 0.) 26 | nn.init.constant_(self.update_gate.bias, 0.) 27 | nn.init.constant_(self.out_gate.bias, 0.) 28 | if self.upsample: 29 | nn.init.orthogonal_(self.up_gate.weight) 30 | nn.init.constant_(self.up_gate.bias, 0.) 31 | 32 | def forward(self, input_, prev_state): 33 | 34 | if self.upsample: 35 | input_ = self.up_gate(input_) 36 | 37 | # get batch and spatial sizes 38 | batch_size = input_.data.size()[0] 39 | spatial_size = input_.data.size()[2:] 40 | 41 | # generate empty prev_state, if None is provided 42 | if prev_state is None: 43 | state_size = [batch_size, self.hidden_size] + list(spatial_size) 44 | if torch.cuda.is_available(): 45 | prev_state = torch.zeros(state_size).cuda() 46 | else: 47 | prev_state = torch.zeros(state_size) 48 | 49 | # data size is [batch, channel, height, width] 50 | stacked_inputs = torch.cat([input_, prev_state], dim=1) 51 | update = torch.sigmoid(self.update_gate(stacked_inputs)) 52 | reset = torch.sigmoid(self.reset_gate(stacked_inputs)) 53 | out_inputs = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1))) 54 | new_state = prev_state * (1 - update) + out_inputs * update 55 | 56 | return new_state 57 | 58 | 59 | class ConvGRU(nn.Module): 60 | 61 | def __init__(self, input_size, hidden_sizes, kernel_sizes, n_layers, upsampling:list=None): 62 | ''' 63 | Generates a multi-layer convolutional GRU. 64 | Preserves spatial dimensions across cells, only altering depth. 65 | Parameters 66 | ---------- 67 | input_size : integer. depth dimension of input tensors. 68 | hidden_sizes : integer or list. depth dimensions of hidden state. 69 | if integer, the same hidden size is used for all cells. 70 | kernel_sizes : integer or list. sizes of Conv2d gate kernels. 71 | if integer, the same kernel size is used for all cells. 72 | n_layers : integer. number of chained `ConvGRUCell`. 73 | ''' 74 | 75 | super(ConvGRU, self).__init__() 76 | if upsampling is None: 77 | upsampling = [False]*n_layers 78 | 79 | self.input_size = input_size 80 | if type(hidden_sizes) != list: 81 | self.hidden_sizes = [hidden_sizes]*n_layers 82 | else: 83 | assert len(hidden_sizes) == n_layers, '`hidden_sizes` must have the same length as n_layers' 84 | self.hidden_sizes = hidden_sizes 85 | if type(kernel_sizes) != list: 86 | self.kernel_sizes = [kernel_sizes]*n_layers 87 | else: 88 | assert len(kernel_sizes) == n_layers, '`kernel_sizes` must have the same length as n_layers' 89 | self.kernel_sizes = kernel_sizes 90 | 91 | self.n_layers = n_layers 92 | 93 | self.cells = [] 94 | for i in range(self.n_layers): 95 | if i == 0: 96 | input_dim = self.input_size 97 | else: 98 | input_dim = self.hidden_sizes[i - 1] 99 | 100 | self.cells.append(ConvGRUCell(input_dim, self.hidden_sizes[i], self.kernel_sizes[i],upsample=upsampling[i])) 101 | 102 | self.cells = nn.Sequential(*self.cells) 103 | 104 | def forward(self, x, hidden=None): 105 | ''' 106 | Parameters 107 | ---------- 108 | x : 4D input tensor. (batch, channels, height, width). 109 | hidden : list of 4D hidden state representations. (layer, batch, channels, height, width). 
110 | Returns 111 | ------- 112 | upd_hidden : 5D hidden representation. (layer, batch, channels, height, width). 113 | ''' 114 | 115 | if hidden is None: 116 | hidden = [None]*self.n_layers 117 | 118 | input_ = x 119 | 120 | upd_hidden = [] 121 | 122 | for layer_idx in range(self.n_layers): 123 | cell = self.cells[layer_idx] 124 | cell_hidden = hidden[layer_idx] 125 | 126 | # pass through layer 127 | upd_cell_hidden = cell(input_, cell_hidden) 128 | upd_hidden.append(upd_cell_hidden) 129 | # update input_ to the last updated hidden layer for next pass 130 | input_ = upd_cell_hidden 131 | 132 | # retain tensors in list to allow different hidden sizes 133 | return upd_hidden 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /models/pretrained_models.py: -------------------------------------------------------------------------------- 1 | poke_embedder_models ={ 2 | 'iper-ss128-bn64-endpoint10f-np5': { 3 | 'ckpt': 'logs/poke_encoder/ckpt/iper_128/0/epoch=17-lpips-val=0.298.ckpt', 4 | 'model_name': 'iper_128', 5 | 'tgt_name': 'iper_128' 6 | }, 7 | 'h36m-ss128-bn64-endpoint10f-np5': { 8 | 'ckpt': 'logs/poke_encoder/ckpt/h36m_128/0/epoch=19-lpips-val=0.109.ckpt', 9 | 'model_name': 'h36m_128', 10 | 'tgt_name' : 'h36m_128' 11 | }, 12 | 'plants-ss128-bn64-endpoint10f-np5': { 13 | 'ckpt': 'logs/poke_encoder/ckpt/plants_128/0/epoch=79-lpips-val=0.301.ckpt', 14 | 'model_name': 'plants_128', 15 | 'tgt_name': 'plants_128' 16 | }, 17 | 'iper-ss64-bn8x8x64-endpoint10f-np5': { 18 | 'ckpt': 'logs/poke_encoder/ckpt/iper_64/0/epoch=16-lpips-val=0.172.ckpt', 19 | 'model_name': 'iper_64', 20 | 'tgt_name': 'iper_64' 21 | }, 22 | 'taichi-ss128-bn8x8x64-endpoint10f-np5': { 23 | 'ckpt': 'logs/poke_encoder/ckpt/taichi_128/0/epoch=31-lpips-val=0.314.ckpt', 24 | 'model_name': 'taichi_128', 25 | 'tgt_name': 'taichi_128' 26 | }, 27 | 'taichi-ss64-bn8x8x64-endpoint10f-np5': { 28 | 'ckpt': 'logs/poke_encoder/ckpt/taichi_64/0/epoch=14-lpips-val=0.229.ckpt', 29 | 'model_name': 'taichi_64', 30 | 'tgt_name': 'taichi_64' 31 | }, 32 | 'plants-ss64-bn8x8x64-endpoint10f-np5': { 33 | 'ckpt': 'logs/poke_encoder/ckpt/plants_64/0/epoch=60-lpips-val=0.183.ckpt', 34 | 'model_name': 'plants_64', 35 | 'tgt_name': 'plants_64' 36 | }, 37 | 'h36m-ss64-bn8x8x64-endpoint10f-np5': { 38 | 'ckpt': 'logs/poke_encoder/ckpt/h36m_64/0/epoch=16-lpips-val=0.073.ckpt', 39 | 'model_name': 'h36m_64' 40 | }, 41 | } 42 | first_stage_models = { 43 | 'plants-ss128-bn64-mf10' : { 44 | 'ckpt': 'logs/first_stage/ckpt/plants_128/0/epoch=17-FVD-val=65.191.ckpt', 45 | 'model_name': 'plants_128', 46 | 'tgt_name':'plants_128' 47 | }, 48 | 'h36m-ss128-bn64-mf10' : { 49 | 'ckpt': 'logs/first_stage/ckpt/h36m_128/0/epoch=13-FVD-val=109.079.ckpt', 50 | 'model_name': 'h36m_128', 51 | 'tgt_name':'h36m_128' 52 | }, 53 | 'taichi-ss128-bn32-mf10': { 54 | 'ckpt': 'logs/first_stage/ckpt/taichi_128/0/epoch=10-FVD-val=157.258.ckpt', 55 | 'model_name': 'taichi_128', 56 | 'tgt_name': 'taichi_128' 57 | }, 58 | 'plants-ss64-bn32-mf10': { 59 | 'ckpt': 'logs/first_stage/ckpt/plants_64/0/epoch=18-FVD-val=61.761.ckpt', 60 | 'model_name': 'plants_64', 61 | 'tgt_name': 'plants_64' 62 | }, 63 | 'h36m-ss64-bn64-mf10': { 64 | 'ckpt': 'logs/first_stage/ckpt/h36m_64/0/epoch=18-FVD-val=108.995.ckpt', 65 | 'model_name': 'h36m_64' 66 | }, 67 | 'iper-ss64-bn32-mf10': { 68 | # run name is false here, model was trained with z_dim = 32, as indicated in the dict key 69 | 'ckpt': 'logs/first_stage/ckpt/iper_64/0/epoch=28-FVD-val=67.734.ckpt', 70 | 
'model_name': 'iper_64', 71 | 'tgt_name': 'iper_64' 72 | }, 73 | 'taichi-ss64-bn32-mf10': { 74 | # run name is false here, model was trained with z_dim = 32, as indicated in the dict key 75 | 'ckpt': 'logs/first_stage/ckpt/taichi_64/0/epoch=20-FVD-val=113.079.ckpt', 76 | 'model_name': 'taichi_64', 77 | 'tgt_name': 'taichi_64' 78 | }, 79 | 'iper-ss128-bn32-mf10-complex': { 80 | # run name is false here, model was trained with z_dim = 32, as indicated in the dict key 81 | 'ckpt': 'logs/first_stage/ckpt/iper_128/0/epoch=17-FVD-val=61.491.ckpt', 82 | 'model_name': 'iper_128', 83 | 'tgt_name': 'iper_128' 84 | }, 85 | } 86 | conditioner_models = { 87 | 'plants-ss128-bn64': { 88 | 'ckpt': 'logs/img_encoder/ckpt/plants_128/0/epoch=71-lpips-val=0.051.ckpt', 89 | 'model_name': 'plants_128', 90 | 'tgt_name': 'plants_128' 91 | }, 92 | 'iper-ss128-bn64': { 93 | 'ckpt': 'logs/img_encoder/ckpt/iper_128/0/epoch=12-lpips-val=0.026.ckpt', 94 | 'model_name': 'iper_128', 95 | 'tgt_name': 'iper_128' 96 | }, 97 | 'h36m-ss128-bn64': { 98 | 'ckpt': 'logs/img_encoder/ckpt/h36m_128/0/epoch=12-lpips-val=0.067.ckpt', 99 | 'model_name': 'h36m_128', 100 | 'tgt_name': 'h36m_128' 101 | }, 102 | 'plants-ss64-bn64': { 103 | 'ckpt': 'logs/img_encoder/ckpt/plants_64/0/last.ckpt', 104 | 'model_name': 'plants_64', 105 | 'tgt_name': 'plants_64' 106 | }, 107 | 'iper-ss64-bn64': { 108 | 'ckpt': 'logs/img_encoder/ckpt/iper_64/0/last.ckpt', 109 | 'model_name': 'iper_64', 110 | 'tgt_name': 'iper_64' 111 | }, 112 | 'h36m-ss64-bn64': { 113 | 'ckpt': 'logs/img_encoder/ckpt/h36m_64/0/last.ckpt', 114 | 'model_name': 'h36m_64' 115 | }, 116 | 'taichi-ss128-bn64': { 117 | 'ckpt': 'logs/img_encoder/ckpt/taichi_128/0/epoch=8-lpips-val=0.110.ckpt', 118 | 'model_name': 'taichi_128', 119 | 'tgt_name': 'taichi_128' 120 | }, 121 | 'taichi-ss64-bn64': { 122 | 'ckpt': 'logs/img_encoder/ckpt/taichi_64/0/epoch=14-lpips-val=0.006.ckpt', 123 | 'model_name': 'taichi_64', 124 | 'tgt_name': 'taichi_64' 125 | }, 126 | } 127 | 128 | flow_conditioner_models ={} -------------------------------------------------------------------------------- /testing/eval_models.py: -------------------------------------------------------------------------------- 1 | from utils.general import get_logger_old 2 | import os 3 | from os import path 4 | import argparse 5 | 6 | 7 | 8 | 9 | if __name__ == '__main__': 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--gpu", type=int, required=True, help="The target device.") 13 | parser.add_argument('-t','--test',required=True, type=str, choices=['fvd','accuracy','diversity', 'kps_acc'],help="Which test to conduct.") 14 | args = parser.parse_args() 15 | 16 | 17 | 18 | with open("config/model_names.txt", "r") as f: 19 | model_names = f.readlines() 20 | model_names = [m for m in model_names if not m.startswith("#")] 21 | file = path.basename(__file__) 22 | logger = get_logger_old(file) 23 | 24 | 25 | gpu = args.gpu 26 | 27 | for n in model_names: 28 | n = n.rstrip() 29 | logger.info(f'Conducting experiment "{args.test}" for model {n}') 30 | 31 | try: 32 | test_cmd = f"python -W ignore main.py --config config/second_stage.yaml --gpus {gpu} --model_name {n} --test {args.test}" 33 | if args.test == 'fvd' and "LD_LIBRARY_PATH" in os.environ: 34 | test_cmd = f'LD_LIBRARY_PATH={os.environ["LD_LIBRARY_PATH"]} ' + test_cmd 35 | os.system(test_cmd) 36 | except Exception as e: 37 | logger.error(e) 38 | logger.info("next model") 39 | continue 40 | 41 | logger.info("finished") 
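Usage note for testing/eval_models.py above, as a sketch (it assumes the command is issued from the repository root so that config/model_names.txt resolves): the script reads one model name per line from config/model_names.txt, skips lines starting with '#', and for each remaining name shells out to main.py with the selected test, e.g.

python testing/eval_models.py --gpu 0 --test fvd

which runs, per model, a command of the form

python -W ignore main.py --config config/second_stage.yaml --gpus 0 --model_name <name> --test fvd

and, for the FVD test, prefixes it with the current LD_LIBRARY_PATH if that variable is set.
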
-------------------------------------------------------------------------------- /testing/evaluate_diversity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import argparse 4 | from os import path 5 | from tqdm import tqdm 6 | import os 7 | 8 | from utils.metrics import metric_vgg16, compute_div_score, compute_div_score_mse, compute_div_score_lpips 9 | from utils.posenet_wrapper import PoseNetWrapper 10 | from utils.general import get_logger_old 11 | 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument("-p", "--path", type=str, 18 | default="/export/data/ablattma/visual_poking/savp/h36m_kth_params/samples_divscore/fake_samples.npy", 19 | help="PAth to the samples file.") 20 | parser.add_argument("--gpu", type=int, required=True, help="The target device.") 21 | parser.add_argument('-r','--repr',type=str,default='vgg_features',choices=['keypoints','vgg_features'],help='The representation which shall be used for diversity calculation.') 22 | 23 | args = parser.parse_args() 24 | 25 | # name='' plans_test 26 | # path = '/export/data/ablattma/visual_poking/savp/h36m_kth_params/samples_divscore/fake_samples.npy' 27 | #'/export/scratch/mdorkenw/results/ICCV/diversity_srvp_iPER.npy' 28 | #'/export/data/ablattma/visual_poking/savp/plans_test /samples_divscore/fake_samples.npy' 29 | # device = 5 30 | if'DATAPATH' in os.environ: 31 | args.path = path.join(os.environ['DATAPATH'],args.path[1:]) 32 | 33 | file = path.basename(__file__) 34 | logger = get_logger_old(file) 35 | 36 | 37 | videos = np.load(args.path) 38 | print(f'Range check before possible normalization! max: {videos.max()}; min: {videos.min()}') 39 | # videos shape is assumed to be (n_examples,n_samples_per_exmpl,sequence_length,channels, h,w) 40 | if videos.shape[0] < videos.shape[1]: 41 | videos = np.swapaxes(videos,0,1) 42 | 43 | if videos.max()>1.: 44 | videos = (videos.astype(float) / 127.5) - 1. 45 | 46 | if videos.shape[-1] == 3: 47 | videos = np.moveaxis(videos,(0,1,2,3,4,5),(0,1,2,4,5,3)) 48 | 49 | assert videos.ndim == 6 50 | print(f'Range check after possible normalization! 
max: {videos.max()}; min: {videos.min()}') 51 | 52 | dev = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') 53 | 54 | videos = torch.from_numpy(videos).to(torch.float32) 55 | if args.repr == 'vgg_features': 56 | 57 | 58 | 59 | logger.info('Using vgg features as similarity representation') 60 | vgg = metric_vgg16().to(dev) 61 | divl = compute_div_score(videos,vgg,device=dev) 62 | 63 | divl = np.asarray(divl).mean() 64 | logger.info(f'Average cosine distance in vgg features space {divl}') 65 | 66 | else: 67 | config = {'data':{'spatial_size': (videos.shape[-2],videos.shape[-1])}} 68 | posenet = PoseNetWrapper(config) 69 | posenet.eval() 70 | posenet.to(dev) 71 | logger.info('Using keypoints as similarity representation') 72 | 73 | n_ex, n_samples, seq_length, c, h, w = videos.shape 74 | 75 | divl = [] 76 | with torch.no_grad(): 77 | for video in tqdm(videos, f'Computing diversity score for {n_ex} examples with {n_samples} samples.'): 78 | 79 | video = video.to(dev).reshape(-1,*video.shape[2:]) 80 | kps_raw = posenet(video) 81 | kps_abs, kps_rel = posenet.postprocess(kps_raw) 82 | 83 | for j in range(n_samples): 84 | for k in range(n_samples): 85 | if j != k: 86 | f = kps_rel.reshape(n_samples, seq_length, *kps_rel.shape[1:]) 87 | divl.append(np.linalg.norm((f[j]-f[k]).reshape(-1,2)).mean()) 88 | 89 | divl = np.asarray(divl).mean() 90 | logger.info(f'Average euclidean distance in keypoint space {divl}') 91 | 92 | div_score_mse = compute_div_score_mse(videos, device=dev) 93 | div_score_lpips = compute_div_score_lpips(videos, device=dev) 94 | 95 | text = f'Similarity measure_vgg: {divl}; similarity measure mse: {div_score_mse}; similarity measure lpips: {div_score_lpips}\n' 96 | 97 | print(text) 98 | 99 | -------------------------------------------------------------------------------- /testing/frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Minimal Reference implementation for the Frechet Video Distance (FVD). 18 | 19 | FVD is a metric for the quality of video generation models. It is inspired by 20 | the FID (Frechet Inception Distance) used for images, but uses a different 21 | embedding to be better suitable for videos. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | 27 | from __future__ import print_function 28 | 29 | 30 | import six 31 | import tensorflow.compat.v1 as tf 32 | import tensorflow_gan as tfgan 33 | import tensorflow_hub as hub 34 | 35 | 36 | def preprocess(videos, target_resolution): 37 | """Runs some preprocessing on the videos for I3D model. 38 | 39 | Args: 40 | videos: [batch_size, num_frames, height, width, depth] The videos to be 41 | preprocessed. 
We don't care about the specific dtype of the videos, it can 42 | be anything that tf.image.resize_bilinear accepts. Values are expected to 43 | be in the range 0-255. 44 | target_resolution: (width, height): target video resolution 45 | 46 | Returns: 47 | videos: [batch_size, num_frames, height, width, depth] 48 | """ 49 | videos_shape = videos.shape.as_list() 50 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 51 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) 52 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 53 | output_videos = tf.reshape(resized_videos, target_shape) 54 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 55 | return scaled_videos 56 | 57 | 58 | def _is_in_graph(tensor_name): 59 | """Checks whether a given tensor does exists in the graph.""" 60 | try: 61 | tf.get_default_graph().get_tensor_by_name(tensor_name) 62 | except KeyError: 63 | return False 64 | return True 65 | 66 | class Embedder: 67 | 68 | def __init__(self, videos): 69 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 70 | videos.name).replace(":", "_") 71 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 72 | self.model = hub.Module(module_spec, name=module_name) 73 | 74 | def create_id3_embedding(self,videos): 75 | """Embeds the given videos using the Inflated 3D Convolution network. 76 | 77 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the 78 | first call. 79 | 80 | Args: 81 | videos: [batch_size, num_frames, height=224, width=224, depth=3]. 82 | Expected range is [-1, 1]. 83 | 84 | Returns: 85 | embedding: [batch_size, embedding_size]. embedding_size depends 86 | on the model used. 87 | 88 | Raises: 89 | ValueError: when a provided embedding_layer is not supported. 90 | """ 91 | 92 | batch_size = 16 93 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 94 | 95 | 96 | # Making sure that we import the graph separately for 97 | # each different input video tensor. 98 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 99 | videos.name).replace(":", "_") 100 | 101 | assert_ops = [ 102 | tf.Assert( 103 | tf.reduce_max(videos) <= 1.001, 104 | ["max value in frame is > 1", videos]), 105 | tf.Assert( 106 | tf.reduce_min(videos) >= -1.001, 107 | ["min value in frame is < -1", videos]), 108 | tf.assert_equal( 109 | tf.shape(videos)[0], 110 | batch_size, ["invalid frame batch size: ", 111 | tf.shape(videos)], 112 | summarize=6), 113 | ] 114 | with tf.control_dependencies(assert_ops): 115 | videos = tf.identity(videos) 116 | 117 | module_scope = "%s_apply_default/" % module_name 118 | 119 | # To check whether the module has already been loaded into the graph, we look 120 | # for a given tensor name. If this tensor name exists, we assume the function 121 | # has been called before and the graph was imported. Otherwise we import it. 122 | # Note: in theory, the tensor could exist, but have wrong shapes. 123 | # This will happen if create_id3_embedding is called with a frames_placehoder 124 | # of wrong size/batch size, because even though that will throw a tf.Assert 125 | # on graph-execution time, it will insert the tensor (with wrong shape) into 126 | # the graph. This is why we need the following assert. 
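# In practice this means create_id3_embedding must always be fed chunks of exactly `batch_size` (16) videos, already resized to 224x224 and scaled to [-1, 1] (e.g. via `preprocess` above).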
127 | video_batch_size = int(videos.shape[0]) 128 | assert video_batch_size in [batch_size, -1, None], "Invalid batch size" 129 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 130 | if not _is_in_graph(tensor_name): 131 | # apply the hub module created in __init__ so that the embedding tensor is added to the graph 132 | self.model(videos) 133 | 134 | # gets the kinetics-i3d-400-logits layer 135 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 136 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) 137 | 138 | return tensor 139 | 140 | 141 | def calculate_fvd(real_activations, 142 | generated_activations): 143 | """Returns an op that computes the FVD from the given activations. 144 | 145 | Args: 146 | real_activations: [num_samples, embedding_size] 147 | generated_activations: [num_samples, embedding_size] 148 | 149 | Returns: 150 | A scalar that contains the requested FVD. 151 | """ 152 | return tfgan.eval.frechet_classifier_distance_from_activations( 153 | real_activations, generated_activations) 154 | -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning import Callback 3 | from pytorch_lightning.callbacks import ModelCheckpoint 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | import wandb 7 | import umap 8 | from os import path 9 | 10 | class BestCkptsToYaml(Callback): 11 | def __init__(self,ckpt_callback:ModelCheckpoint): 12 | super().__init__() 13 | assert isinstance(ckpt_callback,ModelCheckpoint) 14 | self.ckpt_cb = ckpt_callback 15 | 16 | 17 | def on_train_end(self,trainer, pl_module): 18 | if path.isdir(self.ckpt_cb.dirpath): 19 | self.ckpt_cb.to_yaml() 20 | 21 | def on_validation_epoch_end(self, trainer, pl_module): 22 | if path.isdir(self.ckpt_cb.dirpath): 23 | self.ckpt_cb.to_yaml() 24 | 25 | 26 | 27 | class UMAP(Callback): 28 | def __init__(self, batch_frequency, n_samples): 29 | super().__init__() 30 | self.batch_freq = batch_frequency 31 | self.n_samples = n_samples 32 | 33 | def log_umap(self, pl_module, first_stage_data, split="train"): 34 | 35 | is_train = pl_module.training 36 | if is_train: 37 | pl_module.eval() 38 | 39 | # number of batches needed to collect self.n_samples embeddings (the batch is a dict holding the sequence under 'seq') 40 | n_samples = self.n_samples // first_stage_data['seq'].size(0) 41 | dloader = pl_module.train_dataloader() if split == 'train' else pl_module.val_dataloader() 42 | z, z_m, z_p = [], [], [] 43 | while len(z) < n_samples: 44 | for batch_idx, batch in enumerate(dloader): 45 | if len(z) >= n_samples: 46 | break 47 | with torch.no_grad(): 48 | seq = batch['seq'].to(pl_module.device) 49 | ## Create embeddings from first stage model 50 | posterior = pl_module.first_stage_model.encode(seq) 51 | z_m.append(posterior.mode().squeeze(-1).squeeze(-1).detach().cpu().numpy()) 52 | z_p.append(posterior.sample().squeeze(-1).squeeze(-1).detach().cpu().numpy()) 53 | ## Create embeddings from flow by reversing direction; the noise is shaped like the posterior mode 54 | gaussian = torch.randn_like(posterior.mode()) 55 | embed = pl_module.flow(gaussian, seq[:, 0], reverse=True).squeeze(-1).squeeze(-1) 56 | z.append(embed.detach().cpu().numpy()) 57 | z = np.concatenate(z) 58 | z_m = np.concatenate(z_m) 59 | z_p = np.concatenate(z_p) 60 | umap_transform = umap.UMAP() 61 | transformation = umap_transform.fit(z_m) 62 | transformed_z = transformation.transform(z_m) 63 | plt.scatter(transformed_z[:, 0], transformed_z[:, 1], c='blue', s=1, marker='o', label="mean", alpha=.3, rasterized=True) 64 | plt.scatter(np.mean(transformed_z[:, 0]),
np.mean(transformed_z[:, 1]), c='blue', s=20, marker='o', label="mean mean", alpha=.3) 65 | transformed_z = transformation.transform(z) 66 | plt.scatter(transformed_z[:, 0], transformed_z[:, 1], c='red', s=1, marker='v', label="INN samples", alpha=.3, rasterized=True) 67 | plt.scatter(np.mean(transformed_z[:, 0]), np.mean(transformed_z[:, 1]), c='red', s=20, marker='o', label="INN samples mean", alpha=.3) 68 | transformed_z = transformation.transform(z_p) 69 | plt.scatter(transformed_z[:, 0], transformed_z[:, 1], c='green', s=1, marker='s', label="posterior", alpha=.3, rasterized=True) 70 | plt.scatter(np.mean(transformed_z[:, 0]), np.mean(transformed_z[:, 1]), c='green', s=20, marker='o', label="posterior mean", alpha=.3) 71 | plt.legend() 72 | plt.axis('off') 73 | plt.ioff() 74 | pl_module.logger.experiment.log({"Umap plot " + split: wandb.Image(plt, caption="Umap plot")}) 75 | plt.close() 76 | if is_train: 77 | pl_module.train() 78 | 79 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 80 | # log only every batch_freq-th training batch 81 | if batch_idx % self.batch_freq == 0: 82 | self.log_umap(pl_module, batch, split="train") 83 | 84 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 85 | # log only every batch_freq-th validation batch 86 | if batch_idx % self.batch_freq == 0: 87 | self.log_umap(pl_module, batch, split="val") -------------------------------------------------------------------------------- /utils/flownet_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from PIL import Image 4 | from models.flownet2.models import * 5 | 6 | from torchvision import transforms 7 | import matplotlib.pyplot as plt 8 | import argparse 9 | 10 | from utils.general import get_gpu_id_with_lowest_memory 11 | 12 | 13 | class FlownetPipeline: 14 | def __init__(self): 15 | super(FlownetPipeline, self).__init__() 16 | 17 | def load_flownet(self, args, device): 18 | """ 19 | :param args: args from argparser 20 | :param device: the torch.device the model is moved to 21 | :return: The flownet pytorch model 22 | """ 23 | # load model savefile 24 | save = torch.load( 25 | "/export/scratch/compvis/datasets/plants/pretrained_models/FlowNet2_checkpoint.pth.tar") 26 | model = FlowNet2(args, batchNorm=False) 27 | 28 | untrained_statedict = model.state_dict() 29 | 30 | # load it into proper clean model 31 | model.load_state_dict(save["state_dict"]) 32 | model.eval() 33 | return model.to(device) 34 | 35 | def preprocess_image(self, img, img2, channelOrder="RGB",spatial_size= None): 36 | """ This preprocesses the images for FlowNet input. Preserves the height and width order!
37 | 38 | :param channelOrder: RGB(A) or BGR 39 | :param img: The first image in form of (W x H x RGBA) or (H x W x RGBA) 40 | :param img2: The second image in form of (W x H x RGBA) or (H x W x RGBA) 41 | :return: The preprocessed input for the prediction (BGR x Img# x W x H) or (BGR x Img# x H x W) 42 | """ 43 | # ToTensor transforms from (H x W x C) => (C x H x W) 44 | # also automatically casts into range [0, 1] 45 | if spatial_size is None: 46 | img, img2 = transforms.ToTensor()(img)[:3], transforms.ToTensor()(img2)[:3] 47 | else: 48 | ts = transforms.Compose([transforms.ToPILImage(), transforms.Resize(size=spatial_size,interpolation=Image.BILINEAR),transforms.ToTensor()]) 49 | img, img2 = ts(img)[:3],ts(img2)[:3] 50 | if channelOrder == "RGB": 51 | img, img2 = img[[2, 1, 0]], img2[[2, 1, 0]] 52 | 53 | # Cast to proper shape (Batch x BGR x #Img x H x W) 54 | s = img.shape 55 | img, img2 = img[:, :int(s[1] / 64) * 64, :int(s[2] / 64) * 64], \ 56 | img2[:, :int(s[1] / 64) * 64,:int(s[2] / 64) * 64] 57 | stacked = torch.cat([img[:, None], img2[:, None]], dim=1) 58 | return stacked 59 | 60 | def predict(self, model, stacked, spatial_size=None): 61 | """ 62 | 63 | :param stacked: The two input images. (Batch x BGR x Img# x H x W) 64 | :return: The flow result (2 x W x H) 65 | """ 66 | # predict 67 | model.eval() 68 | prediction = model(stacked) 69 | out_size = float(prediction.shape[-1]) 70 | if spatial_size is not None: 71 | prediction = F.interpolate( 72 | prediction.cpu(), size=(spatial_size,spatial_size), mode="bilinear" 73 | ) 74 | # rescale the flow magnitudes to the new resolution (not critical if skipped, since the flow is normalized later anyway) 75 | prediction = prediction / (out_size / spatial_size) 76 | flow = prediction[0] 77 | return flow 78 | 79 | def show_results(self, prediction, with_ampl=False): 80 | """ 81 | 82 | prediction (Tensor): The predicted flow (2 x W x H) 83 | :return: plots 84 | """ 85 | 86 | zeros = torch.zeros((1, prediction.shape[1], prediction.shape[2])) 87 | if with_ampl: 88 | ampl = torch.sum(prediction * prediction, dim=0) 89 | ampl = ampl.squeeze() 90 | else: 91 | ampl = torch.cat([prediction, zeros], dim=0) 92 | ampl -= ampl.min() 93 | ampl /= ampl.max() 94 | 95 | # show image 96 | im = transforms.ToPILImage()(ampl) 97 | if with_ampl: 98 | plt.imshow(im, cmap='gray') 99 | else: 100 | plt.imshow(im) 101 | 102 | 103 | if __name__ == "__main__": 104 | # parse args 105 | parser = argparse.ArgumentParser(description='Compute and visualize FlowNet2 optical flow for two example frames.') 106 | # always 1.0, because pytorch toTensor automatically converts into range [0.0, 1.0] 107 | parser.add_argument("--rgb_max", type=float, default=1.)
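# rgb_max and the fp16 flags below are read from this args namespace by the FlowNet2 constructor from the flownet2-pytorch submodule (via load_flownet), so they must be defined even if unused here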
108 | parser.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).') 109 | parser.add_argument('--fp16_scale', type=float, default=1024., 110 | help='Loss scaling, positive power of 2 values can improve fp16 convergence.') 111 | args = parser.parse_args() 112 | 113 | # load test images in BGR mode 114 | img, img2 = np.asarray(Image.open(f"/export/data/ablattma/Datasets/plants/processed/hoch_misc1/frame_0.png")), \ 115 | np.asarray(Image.open(f"/export/data/ablattma/Datasets/plants/processed/hoch_misc1/frame_100.png")) 116 | 117 | # load Flownet 118 | pipeline = FlownetPipeline() 119 | flownet_device = get_gpu_id_with_lowest_memory() 120 | flownet = pipeline.load_flownet(args, flownet_device) 121 | 122 | # process to show flow 123 | stacked = pipeline.preprocess_image(img, img2).to(flownet_device) 124 | prediction = pipeline.predict(flownet, stacked[None]).cpu() 125 | pipeline.show_results(prediction) 126 | plt.show() -------------------------------------------------------------------------------- /utils/general.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import subprocess 4 | import logging 5 | import yaml 6 | import logging.config 7 | import inspect 8 | from os import walk 9 | import numpy as np 10 | import coloredlogs 11 | import multiprocessing as mp 12 | from threading import Thread 13 | from queue import Queue 14 | from collections import abc 15 | import cv2 16 | from torch import nn 17 | # import kornia 18 | 19 | 20 | def get_member(model, name): 21 | if isinstance(model, nn.DataParallel): 22 | module = model.module 23 | else: 24 | module = model 25 | 26 | return getattr(module, name) 27 | 28 | def preprocess_image(img,swap_channels=False): 29 | """ 30 | 31 | :param img: numpy array of shape (H,W,3) 32 | :param swap_channels: True, if channelorder is BGR 33 | :return: 34 | """ 35 | if swap_channels: 36 | img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 37 | 38 | # this seems to be possible as flownet2 outputs only images which can be divided by 64 39 | shape = img.shape 40 | img = img[:int(shape[0] / 64) * 64,:int(shape[1] / 64) * 64] 41 | 42 | return img 43 | 44 | class LoggingParent: 45 | def __init__(self): 46 | super().__init__() 47 | # find project root 48 | mypath = inspect.getfile(self.__class__) 49 | mypath = "/".join(mypath.split("/")[:-1]) 50 | found = False 51 | while mypath!="" and not found: 52 | f = [] 53 | for (dirpath, dirnames, filenames) in walk(mypath): 54 | f.extend(filenames) 55 | break 56 | if ".gitignore" in f: 57 | found = True 58 | continue 59 | mypath = "/".join(mypath.split("/")[:-1]) 60 | project_root = mypath+"/" 61 | # Put it together 62 | file = inspect.getfile(self.__class__).replace(project_root, "").replace("/", ".").split(".py")[0] 63 | cls = str(self.__class__)[8:-2] 64 | cls = str(cls).replace("__main__.", "").split(".")[-1] 65 | self.logger = get_logger() 66 | 67 | def get_gpu_id_with_lowest_memory(index=0, target_gpus:list=None): 68 | # get info from nvidia-smi 69 | result = subprocess.check_output( 70 | [ 71 | 'nvidia-smi', '--query-gpu=memory.free', 72 | '--format=csv,nounits,noheader' 73 | ], encoding='utf-8') 74 | gpu_memory = [int(x) for x in result.strip().split('\n')] 75 | 76 | # get the one with the lowest usage 77 | if target_gpus is None: 78 | indices = np.argsort(gpu_memory) 79 | else: 80 | indices = [i for i in np.argsort(gpu_memory) if i in target_gpus] 81 | return torch.device(f"cuda:{indices[-index-1]}") 82 | 83 
| 84 | iuhihfie_logger_loaded = False 85 | def get_logger_old(name): 86 | # setup logging 87 | global iuhihfie_logger_loaded 88 | if not iuhihfie_logger_loaded: 89 | with open(f'{os.path.dirname(os.path.abspath(__file__))}/logging.yaml', 'r') as f: 90 | log_cfg = yaml.load(f.read(), Loader=yaml.FullLoader) 91 | logging.config.dictConfig(log_cfg) 92 | iuhihfie_logger_loaded = True 93 | logger = logging.getLogger(name) 94 | coloredlogs.install(logger=logger, level="DEBUG") 95 | return logger 96 | 97 | 98 | def get_logger(): 99 | logger = logging.getLogger("pytorch_lightning.core") 100 | logger.setLevel('DEBUG') 101 | # coloredlogs.install(logger=logger, level='DEBUG') 102 | return logger 103 | 104 | 105 | def save_model_to_disk(path, models, epoch): 106 | for i, model in enumerate(models): 107 | tmp_path = path 108 | if not os.path.exists(path): 109 | os.makedirs(path) 110 | tmp_path = tmp_path + f"model_{i}-epoch{epoch}" 111 | torch.save(model.state_dict(), tmp_path) 112 | 113 | 114 | def _do_parallel_data_prefetch(func, Q, data, idx): 115 | # create dummy dataset instance 116 | 117 | # run prefetching 118 | 119 | res = func(*data) 120 | Q.put([idx, res]) 121 | Q.put("Done") 122 | 123 | 124 | def parallel_data_prefetch( 125 | func: callable, data, n_proc, target_data_type="ndarray",cpu_intensive=True 126 | ): 127 | static_args = None 128 | if target_data_type not in ["ndarray", "list"]: 129 | raise ValueError( 130 | "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 131 | ) 132 | if isinstance(data, np.ndarray) and target_data_type == "list": 133 | raise ValueError("list expected but function got ndarray.") 134 | elif isinstance(data, abc.Iterable) and not isinstance(data,tuple): 135 | if isinstance(data, dict): 136 | print( 137 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 138 | ) 139 | data = list(data.values()) 140 | if target_data_type == "ndarray": 141 | data = np.asarray(data) 142 | else: 143 | data = list(data) 144 | elif isinstance(data,tuple): 145 | static_args = data[1:] 146 | data = data[0] 147 | print('Using static args.') 148 | else: 149 | raise TypeError( 150 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
151 | ) 152 | 153 | if cpu_intensive: 154 | Q = mp.Queue(1000) 155 | proc = mp.Process 156 | else: 157 | Q = Queue(1000) 158 | proc = Thread 159 | # spawn processes 160 | if target_data_type == "ndarray": 161 | arguments = [ 162 | [func, Q, (part,) if static_args is None else (part,*static_args), i] 163 | for i, part in enumerate(np.array_split(data, n_proc)) 164 | ] 165 | else: 166 | step = ( 167 | int(len(data) / n_proc + 1) 168 | if len(data) % n_proc != 0 169 | else int(len(data) / n_proc) 170 | ) 171 | arguments = [ 172 | [func, Q, (part,) if static_args is None else (part,*static_args), i] 173 | for i, part in enumerate( 174 | [data[i : i + step] for i in range(0, len(data), step)] 175 | ) 176 | ] 177 | 178 | 179 | processes = [] 180 | for i in range(n_proc): 181 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 182 | processes += [p] 183 | 184 | # start processes 185 | print("Start prefetching...") 186 | import time 187 | 188 | start = time.time() 189 | gather_res = [[] for _ in range(n_proc)] 190 | try: 191 | for p in processes: 192 | p.start() 193 | 194 | 195 | k = 0 196 | while k < n_proc: 197 | # get result 198 | res = Q.get() 199 | if res == "Done": 200 | k += 1 201 | else: 202 | gather_res[res[0]] = res[1] 203 | 204 | except Exception as e: 205 | print("Exception: ", e) 206 | for p in processes: 207 | p.terminate() 208 | 209 | raise e 210 | finally: 211 | for p in processes: 212 | p.join() 213 | print(f"Prefetching complete. [{time.time() - start} sec.]") 214 | 215 | if not isinstance(gather_res[0], np.ndarray): 216 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 217 | 218 | # order outputs 219 | return np.concatenate(gather_res, axis=0) 220 | 221 | def linear_var( 222 | act_it, start_it, end_it, start_val, end_val, clip_min, clip_max 223 | ): 224 | act_val = ( 225 | float(end_val - start_val) / (end_it - start_it) * (act_it - start_it) 226 | + start_val 227 | ) 228 | return np.clip(act_val, a_min=clip_min, a_max=clip_max) 229 | 230 | def get_patches(seq_batch,weights,config,fg_value, logger = None): 231 | """ 232 | 233 | :param seq_batch: Batch of videos 234 | :param weights: batch of flow weights for the videos 235 | :param config: config, containing spatial_size 236 | :param fg_value: foreground value of the weight map 237 | :return: the cropped and resized video patches 238 | """ 239 | import kornia 240 | weights_as_bool = torch.eq(weights,fg_value) 241 | cropped = [] 242 | for vid,weight in zip(seq_batch,weights_as_bool): 243 | vid_old = vid 244 | weight_ids = torch.nonzero(weight,as_tuple=True) 245 | try: 246 | min_y = weight_ids[0].min() 247 | max_y = weight_ids[0].max() 248 | min_x = weight_ids[1].min() 249 | max_x = weight_ids[1].max() 250 | vid = vid[...,min_y:max_y,min_x:max_x] 251 | if len(vid.shape) < 4: 252 | data_4d = vid[None,...] 253 | vid = kornia.transform.resize(data_4d, config["spatial_size"]) 254 | cropped.append(vid.squeeze(0)) 255 | else: 256 | vid = kornia.transform.resize(vid,config["spatial_size"]) 257 | cropped.append(vid) 258 | except Exception as e: 259 | if logger is None: 260 | print(e) 261 | else: 262 | logger.warning(f'Caught the following exception in "get_patches": {e.__class__.__name__}: {e}.
Skip patching this sample...') 263 | cropped.append(vid_old) 264 | 265 | 266 | return torch.stack(cropped,dim=0) 267 | 268 | 269 | if __name__ == "__main__": 270 | print(get_gpu_id_with_lowest_memory()) 271 | -------------------------------------------------------------------------------- /utils/logging.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | 3 | formatters: 4 | standard: 5 | format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 6 | error: 7 | format: "%(levelname)s %(name)s.%(funcName)s(): %(message)s" 8 | 9 | handlers: 10 | console: 11 | class: logging.StreamHandler 12 | level: DEBUG 13 | formatter: standard 14 | 15 | root: 16 | level: DEBUG 17 | handlers: [console] 18 | propagate: no 19 | 20 | loggers: 21 | 22 | matplotlib.legend: 23 | level: ERROR 24 | handlers: [] 25 | propagate: yes 26 | 27 | pytorch_lightning: 28 | level: WARN 29 | handlers: [] 30 | propagate: yes 31 | 32 | PIL.PngImagePlugin: 33 | level: ERROR 34 | handlers: [] 35 | propagate: yes 36 | 37 | ignite: 38 | level: WARN 39 | handlers: [] 40 | propagate: yes -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | import torchvision 5 | 6 | class VGG(torch.nn.Module): 7 | def __init__(self, requires_grad=False): 8 | super().__init__() 9 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 10 | self.mean = [0.485, 0.456, 0.406] 11 | self.std = [0.229, 0.224, 0.225] 12 | self.slice1 = torch.nn.Sequential() 13 | self.slice2 = torch.nn.Sequential() 14 | self.slice3 = torch.nn.Sequential() 15 | self.slice4 = torch.nn.Sequential() 16 | self.slice5 = torch.nn.Sequential() 17 | for x in range(2): 18 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 19 | for x in range(2, 7): 20 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(7, 12): 22 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(12, 21): 24 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(21, 30): 26 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 27 | if not requires_grad: 28 | for param in self.parameters(): 29 | param.requires_grad = False 30 | 31 | def forward(self, X): 32 | # X = self.normalize(X) 33 | h_relu1 = self.slice1(X) 34 | h_relu2 = self.slice2(h_relu1) 35 | h_relu3 = self.slice3(h_relu2) 36 | h_relu4 = self.slice4(h_relu3) 37 | h_relu5 = self.slice5(h_relu4) 38 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 39 | return out 40 | 41 | def normalize(self, x): 42 | x = x.permute(1, 0, 2, 3) 43 | for i in range(3): 44 | x[i] = x[i] * self.std[i] + self.mean[i] 45 | return x.permute(1, 0, 2, 3) 46 | 47 | def KL(mu, logvar): 48 | return -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), axis=1)) 49 | 50 | def kl_conv(mu,logvar): 51 | mu = mu.reshape(mu.size(0),-1) 52 | logvar = logvar.reshape(logvar.size(0),-1) 53 | 54 | var = torch.exp(logvar) 55 | 56 | return torch.mean(0.5 * torch.sum(torch.pow(mu, 2) + var - 1.0 - logvar, dim=-1)) 57 | 58 | def fmap_loss(fmap1, fmap2, loss): 59 | recp_loss = 0 60 | for idx in range(len(fmap1)): 61 | if loss == 'l1': 62 | recp_loss += torch.mean(torch.abs((fmap1[idx] - fmap2[idx]))) 63 | if loss == 'l2': 64 | recp_loss += torch.mean((fmap1[idx] - fmap2[idx]) ** 2) 65 | return recp_loss / len(fmap1) 66 | 67 | class 
VGGLoss(nn.Module): 68 | def __init__(self, weighted=False): 69 | super(VGGLoss, self).__init__() 70 | self.vgg = VGG().cuda() 71 | self.weighted = weighted 72 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 73 | self.criterion = nn.L1Loss() 74 | 75 | def forward(self, x, y): 76 | fmap1, fmap2 = self.vgg(x), self.vgg(y) 77 | if self.weighted: 78 | recp_loss = 0 79 | for idx in range(len(fmap1)): 80 | recp_loss += self.weights[idx] * self.criterion(fmap2[idx], fmap1[idx]) 81 | return recp_loss 82 | else: 83 | return fmap_loss(fmap1, fmap2, loss='l1') -------------------------------------------------------------------------------- /utils/posenet_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from kornia.geometry.transform import Resize 4 | from kornia.enhance import Normalize 5 | import yaml 6 | from dotmap import DotMap 7 | import os 8 | import numpy as np 9 | 10 | from models.pose_estimator.lib.models.pose_resnet import get_pose_net 11 | from models.pose_estimator.lib.core.inference import get_max_preds 12 | 13 | class PoseNetWrapper(nn.Module): 14 | def __init__(self,config): 15 | super().__init__() 16 | self.model_path = 'logs/pose_estimator/pose_resnet_152_256x256.pth' 17 | fp = os.path.dirname(os.path.realpath(__file__)) 18 | configpath = os.path.abspath(os.path.join(fp, "../config/posenet.yaml")) 19 | with open(configpath,'r') as f: 20 | cfg = yaml.load(f,Loader=yaml.FullLoader) 21 | 22 | self.cfg = DotMap(cfg) 23 | 24 | #self.cfg = '../config/posenet.yaml' 25 | self.config = config 26 | self.input_size = self.config['data']['spatial_size'] 27 | self.resize = Resize((256,256)) 28 | self.normalize = Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])) 29 | 30 | self.posenet = get_pose_net(self.cfg,is_train=False) 31 | self.posenet.load_state_dict(torch.load(self.model_path,map_location='cpu'),strict=False) 32 | 33 | def prepocess_image(self,x): 34 | out = self.resize(x) 35 | out = (out + 1.) 
/ 2 36 | out = self.normalize(out) 37 | 38 | return out 39 | 40 | def forward(self,x): 41 | self.posenet.eval() 42 | x = self.prepocess_image(x) 43 | out = self.posenet(x) 44 | if isinstance(out,list): 45 | out = out[-1] 46 | 47 | return out 48 | 49 | def postprocess(self,x): 50 | if not isinstance(x,np.ndarray): 51 | x = x.detach().cpu().numpy() 52 | out, _ = get_max_preds(x) 53 | 54 | #resize abs kps to input spatial size 55 | out_abs = out * (self.input_size[0] / 64) 56 | out_rel = out / 64 57 | 58 | return out_abs, out_rel 59 | 60 | 61 | 62 | 63 | 64 | if __name__ == '__main__': 65 | from torch.utils.data import DataLoader 66 | import numpy as np 67 | from os import path 68 | from tqdm import tqdm 69 | from os import makedirs 70 | import cv2 71 | 72 | from data import get_dataset 73 | from data.samplers import FixedLengthSampler 74 | from models.pose_estimator.tools.infer import save_batch_image_with_joints 75 | from utils.metrics import KPSMetric 76 | from utils.general import get_logger_old 77 | from utils.logging import put_text_to_video_row 78 | 79 | # load config 80 | fpath = path.dirname(path.realpath(__file__)) 81 | logger = get_logger_old(fpath) 82 | configpath = path.abspath(path.join(fpath, "../config/test_config.yaml")) 83 | with open(configpath, "r") as f: 84 | config = yaml.load(f, Loader=yaml.FullLoader) 85 | 86 | if config["fix_seed"]: 87 | seed = 42 88 | torch.manual_seed(42) 89 | torch.cuda.manual_seed(42) 90 | np.random.seed(42) 91 | # random.seed(opt.seed) 92 | torch.backends.cudnn.deterministic = True 93 | torch.manual_seed(42) 94 | rng = np.random.RandomState(42) 95 | 96 | dset, transforms = get_dataset(config["data"]) 97 | 98 | datakeys = ['images','keypoints_rel','keypoints_abs', 'sample_ids'] 99 | 100 | test_dataset = dset(transforms, datakeys, config["data"],train=False) 101 | 102 | def init_fn(worker_id): 103 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 104 | 105 | 106 | sampler = FixedLengthSampler(test_dataset, config['data']['batch_size'], shuffle=True, 107 | drop_last=True, weighting=test_dataset.obj_weighting, 108 | zero_poke=config['data']['zero_poke'], zero_poke_amount=config['data']["zero_poke_amount"]) 109 | loader = DataLoader(test_dataset, batch_sampler=sampler, num_workers=config['data']['n_workers'], worker_init_fn=init_fn) 110 | 111 | dev = torch.device(f'cuda:{config["gpu"]}' if torch.cuda.is_available() else 'cpu') 112 | 113 | model = PoseNetWrapper(config) 114 | model.to(dev) 115 | 116 | save_dir = f"../data/test_data/{test_dataset.__class__.__name__}" 117 | save_dir = path.abspath(path.join(fpath,save_dir)) 118 | 119 | n_expls_metric = config['n_exmpls_pose_metric'] 120 | 121 | kps_metric = KPSMetric(logger,n_samples=n_expls_metric,savedir=save_dir) 122 | kps_metric.to(dev) 123 | 124 | for id, batch in enumerate(tqdm(loader)): 125 | if id > n_expls_metric: 126 | break 127 | imgs = batch['images'].to(dev) 128 | kps_abs = batch['keypoints_abs'] 129 | kps_rel = batch['keypoints_rel'].to(dev) 130 | 131 | original_shape = imgs.shape 132 | if imgs.ndim == 5: 133 | imgs = imgs.reshape(-1,*imgs.shape[2:]) 134 | 135 | 136 | with torch.no_grad(): 137 | out_raw = model(imgs) 138 | pred_abs, pred_rel = model.postprocess(out_raw) 139 | 140 | 141 | imgs_with_gt = save_batch_image_with_joints(imgs,kps_abs.reshape(-1,*kps_abs.shape[2:]),[],None,nrow=1,return_image=True) 142 | imgs_with_pred = save_batch_image_with_joints(imgs,pred_abs,[],None,nrow=1,return_image=True) 143 | 144 | grid = np.concatenate([imgs_with_gt,imgs_with_pred],axis=0) 
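# cv2.imwrite expects BGR channel order, hence the RGB -> BGR conversion before saving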
145 | grid = cv2.cvtColor(grid,cv2.COLOR_RGB2BGR) 146 | cv2.imwrite(path.join(save_dir,f'pose_exmpl_{id}.png'),grid) 147 | 148 | 149 | # restore time axis 150 | pred_abs = pred_abs.reshape(*original_shape[:2],*pred_abs.shape[1:]) 151 | pred_rel = torch.from_numpy(pred_rel.reshape(*original_shape[:2], *pred_rel.shape[1:])).to(dev) 152 | 153 | kps_metric.update(pred_rel[:,None],kps_rel[:,None]) 154 | 155 | 156 | mean_nn_kps = kps_metric.compute() 157 | logger.info(f'mean nn kps is {mean_nn_kps}') 158 | 159 | 160 | 161 | --------------------------------------------------------------------------------
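Usage note for testing/frechet_video_distance.py: the helpers there follow the TF1 graph/session style, so preprocess, Embedder and calculate_fvd have to be wired together inside one graph and evaluated in a session. A minimal sketch, assuming two uint8 video batches of shape [16, 15, 64, 64, 3] with values in 0-255 (create_id3_embedding asserts a batch size of 16) and TF2 running with v1 compatibility; the placeholder names and random inputs are illustrative only:

import numpy as np
import tensorflow.compat.v1 as tf
from testing.frechet_video_distance import preprocess, Embedder, calculate_fvd

tf.disable_v2_behavior()  # the module relies on graph mode (placeholders, sessions)

# stand-in data: 16 videos, 15 frames, 64x64 RGB, values in [0, 255]
real_np = np.random.randint(0, 255, size=(16, 15, 64, 64, 3), dtype=np.uint8)
fake_np = np.random.randint(0, 255, size=(16, 15, 64, 64, 3), dtype=np.uint8)

with tf.Graph().as_default():
    real_ph = tf.placeholder(tf.uint8, real_np.shape)
    fake_ph = tf.placeholder(tf.uint8, fake_np.shape)
    # resize to 224x224 and rescale to [-1, 1], as expected by the I3D module
    real_pre = preprocess(real_ph, (224, 224))
    fake_pre = preprocess(fake_ph, (224, 224))
    # one Embedder per input tensor, since the hub module name is derived from the tensor name
    real_emb = Embedder(real_pre).create_id3_embedding(real_pre)
    fake_emb = Embedder(fake_pre).create_id3_embedding(fake_pre)
    fvd = calculate_fvd(real_emb, fake_emb)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())
        print('FVD:', sess.run(fvd, feed_dict={real_ph: real_np, fake_ph: fake_np}))

For more than 16 videos one would typically run the embedding op over 16-video chunks and concatenate the resulting activations before computing the distance.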