├── scripts ├── seeds.txt ├── new_seeds.txt ├── magnet_gnn_b1.sh ├── magnet_gnn_b2.sh ├── mpnn_test_uniform.sh ├── mpnn_test_condensed.sh ├── magnet_gnn │ ├── magnet_gnn_2d_b1_64_irregular.sh │ ├── magnet_gnn_2d_b1_128_irregular.sh │ ├── magnet_gnn_2d_b1_256_irregular.sh │ ├── magnet_gnn_2d_b1_512_irregular.sh │ ├── magnet_gnn_2d_b1_64_irregular_concentrated.sh │ ├── magnet_gnn_2d_b1_128_irregular_concentrated.sh │ ├── magnet_gnn_2d_b1_256_irregular_concentrated.sh │ ├── magnet_gnn_2d_b1_512_irregular_concentrated.sh │ ├── magnet_gnn_2d_b1_64_regular.sh │ └── magnet_gnn_2d_b2_64_regular.sh ├── fno_2d │ ├── fno_2d_b1_64_regular.sh │ └── fno_2d_b2_64_regular.sh ├── mpnn_2d │ ├── mpnn_2d_b1_64_regular.sh │ ├── mpnn_2d_b2_64_regular.sh │ ├── mpnn_2d_b1_64_irregular.sh │ ├── mpnn_2d_b1_128_irregular.sh │ ├── mpnn_2d_b1_256_irregular.sh │ ├── mpnn_2d_b1_512_irregular.sh │ ├── new_seeds │ │ ├── mpnn_2d_b1_64_irregular.sh │ │ ├── mpnn_2d_b1_128_irregular.sh │ │ ├── mpnn_2d_b1_256_irregular.sh │ │ ├── mpnn_2d_b1_512_irregular.sh │ │ ├── mpnn_2d_b1_64_irregular_concentrated.sh │ │ ├── mpnn_2d_b1_128_irregular_concentrated.sh │ │ ├── mpnn_2d_b1_256_irregular_concentrated.sh │ │ └── mpnn_2d_b1_512_irregular_concentrated.sh │ ├── mpnn_2d_b1_64_irregular_concentrated.sh │ ├── mpnn_2d_b1_128_irregular_concentrated.sh │ ├── mpnn_2d_b1_256_irregular_concentrated.sh │ └── mpnn_2d_b1_512_irregular_concentrated.sh └── magnet_cnn_2d │ ├── magnet_cnn_2d_b1_64_regular.sh │ └── magnet_cnn_2d_b2_64_regular.sh ├── .gitignore ├── assets ├── magnet.jpg └── predictions.JPG ├── configs ├── trainer │ └── default.yaml ├── model │ ├── mpnn.yaml │ ├── mpnn_2d.yaml │ ├── fno_1d.yaml │ ├── fno_2d.yaml │ ├── magnet_gnn.yaml │ ├── magnet_cnn.yaml │ ├── magnet_cnn_2d.yaml │ └── magnet_cnn_no_interaction.yaml ├── datamodule │ ├── h5_datamodule.yaml │ ├── h5_datamodule_graph.yaml │ ├── h5_datamodule_2d.yaml │ ├── h5_datamodule_implicit.yaml │ ├── h5_datamodule_implicit_gnn.yaml │ ├── 
h5_datamodule_implicit_2d.yaml │ ├── h5_datamodule_graph_2d.yaml │ └── h5_datamodule_implicit_gnn_2d.yaml ├── config.yaml └── callbacks │ └── default.yaml ├── requirements.txt ├── models ├── factory.py ├── backbones │ ├── mlp.py │ └── edsr.py ├── fno_2d.py ├── fno_1d.py ├── magnet_cnn_no_interaction.py ├── mpnn.py ├── mpnn_2d.py ├── magnet_cnn.py └── magnet_gnn.py ├── tune.py ├── run.py ├── utils.py ├── README.md └── datamodule ├── dataset.py ├── dataset_2d.py ├── h5_datamodule.py └── h5_datamodule_2d.py /scripts/seeds.txt: -------------------------------------------------------------------------------- 1 | 42 2 | 21 3 | 10 4 | 5 5 | 2022 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | *.ipynb 4 | *.out 5 | *.svg -------------------------------------------------------------------------------- /scripts/new_seeds.txt: -------------------------------------------------------------------------------- 1 | 23564 2 | 65978945 3 | 313165 4 | 8796 5 | 325987 -------------------------------------------------------------------------------- /assets/magnet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaggbow/magnet/HEAD/assets/magnet.jpg -------------------------------------------------------------------------------- /assets/predictions.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaggbow/magnet/HEAD/assets/predictions.JPG -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | track_grad_norm: -1 3 | max_epochs: 100 4 | gpus: 1 5 | precision: 32 6 | strategy: null 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | comet_ml==3.25.0 2 | h5py==3.6.0 3 | hydra-core==1.2.0 4 | numpy==1.22.2 5 | omegaconf==2.1.1 6 | pytorch_lightning==1.5.9 7 | torch==1.10.2 8 | torch_geometric==2.0.3 9 | rich==11.1.0 -------------------------------------------------------------------------------- /configs/model/mpnn.yaml: -------------------------------------------------------------------------------- 1 | name: mpnn 2 | params: 3 | # Model hyperparameters 4 | hidden_features: 128 5 | hidden_layer: 5 6 | time_window: 16 7 | teacher_forcing: False 8 | neighbors: 3 9 | # Optimization hyperparameters 10 | factor: 0.3 11 | step_size: 50 12 | loss: l1 13 | lr: 0.001 14 | weight_decay: 0 -------------------------------------------------------------------------------- /configs/model/mpnn_2d.yaml: -------------------------------------------------------------------------------- 1 | name: mpnn_2d 2 | params: 3 | # Model hyperparameters 4 | hidden_features: 128 5 | hidden_layer: 5 6 | time_window: 10 7 | teacher_forcing: False 8 | neighbors: 4 9 | # Optimization hyperparameters 10 | factor: 0.3 11 | step_size: 50 12 | loss: l1 13 | lr: 0.001 14 | weight_decay: 0 -------------------------------------------------------------------------------- /configs/model/fno_1d.yaml: -------------------------------------------------------------------------------- 1 | name: fno_1d 2 | params: 3 | # Model hyperparameters 4 | modes: 12 5 | width: 256 6 | num_layers: 5 7 | time_history: 25 8 | time_future: 25 9 | teacher_forcing: True 10 | # Optimization hyperparameters 11 | factor: 0.3 12 | step_size: 50 13 | loss: l1 14 | lr: 0.001 15 | weight_decay: 0 -------------------------------------------------------------------------------- /configs/model/fno_2d.yaml: -------------------------------------------------------------------------------- 1 | 
name: fno_2d 2 | params: 3 | # Model hyperparameters 4 | modes_1: 12 5 | modes_2: 12 6 | width: 256 7 | num_layers: 5 8 | time_history: 10 9 | time_future: 10 10 | teacher_forcing: True 11 | # Optimization hyperparameters 12 | factor: 0.3 13 | step_size: 50 14 | loss: l1 15 | lr: 0.001 16 | weight_decay: 0 -------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule.HDF5Datamodule 2 | name: h5_datamodule 3 | train_path: /home/mila/o/oussama.boussif/pde_oned/data/CE_train_E3.h5 4 | val_path: /home/mila/o/oussama.boussif/pde_oned/data/CE_valid_E3.h5 5 | test_path: /home/mila/o/oussama.boussif/pde_oned/data/CE_test_E3.h5 6 | nt_train: 250 7 | nx_train: 50 8 | nt_val: 250 9 | nx_val: 50 10 | nt_test: 250 11 | nx_test: 50 12 | 13 | num_workers: 0 14 | batch_size: 32 -------------------------------------------------------------------------------- /configs/model/magnet_gnn.yaml: -------------------------------------------------------------------------------- 1 | name: magnet_gnn 2 | params: 3 | # Model hyperparameters 4 | time_slice: 25 5 | latent_dim: 128 6 | num_message_passing_steps: 5 7 | mlp_layers: 4 8 | mlp_hidden: 128 9 | radius: 0.08 10 | n_chan: 128 11 | teacher_forcing: True 12 | codec_neighbors: 4 13 | noise: 0 14 | interpolation: area 15 | # Optimization hyperparameters 16 | factor: 0.3 17 | step_size: 50 18 | loss: l1 19 | lr: 0.001 20 | weight_decay: 0 -------------------------------------------------------------------------------- /scripts/magnet_gnn_b1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_b1 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=3:00:00 9 | #SBATCH -o 
/network/scratch/o/oussama.boussif/slurms/magnet_gnn_b1-slurm-%j.out 10 | 11 | # 1. Load the required modules 12 | module --quiet load anaconda/3 13 | conda activate dedalus 14 | 15 | python test_reg_b1.py -------------------------------------------------------------------------------- /scripts/magnet_gnn_b2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_b2 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=3:00:00 9 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_b2-slurm-%j.out 10 | 11 | # 1. Load the required modules 12 | module --quiet load anaconda/3 13 | conda activate dedalus 14 | 15 | python test_reg_b2.py -------------------------------------------------------------------------------- /configs/model/magnet_cnn.yaml: -------------------------------------------------------------------------------- 1 | name: magnet_cnn 2 | params: 3 | # Model hyperparameters 4 | time_slice: 16 5 | latent_dim: 32 6 | num_message_passing_steps: 10 7 | mlp_layers: 4 8 | mlp_hidden: 64 9 | radius: 0.08 10 | scales: 1 11 | n_chan: 128 12 | kernel_size: 3 13 | res_scale: 1 14 | res_layers: 4 15 | teacher_forcing: True 16 | interpolation: area 17 | # Optimization hyperparameters 18 | factor: 0.3 19 | step_size: 40 20 | loss: l1 21 | lr: 0.001 22 | weight_decay: 0.0000001 -------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule_graph.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule.HDF5DatamoduleGraph 2 | name: h5_datamodule_graph 3 | train_path: /home/mila/o/oussama.boussif/pde_oned/data/CE_train_E3.h5 4 | val_path: /home/mila/o/oussama.boussif/pde_oned/data/CE_valid_E3.h5 5 | test_path: /home/mila/o/oussama.boussif/pde_oned/data/CE_test_E3.h5 6 | nt_train: 250 
7 | nx_train: 50 8 | nt_val: 250 9 | nx_val: 50 10 | nt_test: 250 11 | nx_test: 50 12 | radius: 1 13 | in_timesteps: 25 14 | 15 | num_workers: 0 16 | batch_size: 32 -------------------------------------------------------------------------------- /configs/model/magnet_cnn_2d.yaml: -------------------------------------------------------------------------------- 1 | name: magnet_cnn_2d 2 | params: 3 | # Model hyperparameters 4 | time_slice: 16 5 | latent_dim: 32 6 | num_message_passing_steps: 10 7 | mlp_layers: 4 8 | mlp_hidden: 64 9 | radius: 0.1 10 | scales: 1 11 | n_chan: 128 12 | kernel_size: 3 13 | res_scale: 1 14 | res_layers: 16 15 | teacher_forcing: True 16 | interpolation: area 17 | # Optimization hyperparameters 18 | factor: 0.3 19 | step_size: 40 20 | loss: l1 21 | lr: 0.001 22 | weight_decay: 0.0000001 -------------------------------------------------------------------------------- /scripts/mpnn_test_uniform.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_test_uniform 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=4:30:00 9 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_test_uniform-slurm-%j.out 10 | 11 | # 1. 
Load the required modules 12 | module --quiet load anaconda/3 13 | conda activate dedalus 14 | 15 | python test_irr_uniform.py -------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule_2d.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule_2d.HDF5Datamodule_2d 2 | name: h5_datamodule_2d 3 | train_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5 4 | val_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5 5 | test_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5 6 | nt_train: 50 7 | res_train: 64 8 | nt_val: 50 9 | res_val: 64 10 | nt_test: 50 11 | res_test: 64 12 | 13 | num_workers: 3 14 | batch_size: 32 -------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule_implicit.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule.HDF5DatamoduleImplicit 2 | name: h5_datamodule_implicit 3 | train_path: /home/mila/o/oussama.boussif/pde_oned/data/KS_train.h5 4 | val_path: /home/mila/o/oussama.boussif/pde_oned/data/KS_valid.h5 5 | test_path: /home/mila/o/oussama.boussif/pde_oned/data/Heat_test.h5 6 | nt_train: 128 7 | nx_train: 256 8 | nt_val: 128 9 | nx_val: 256 10 | nt_test: 256 11 | nx_test: 256 12 | sampling: uniform 13 | samples: 32 14 | 15 | num_workers: 0 16 | batch_size: 32 -------------------------------------------------------------------------------- /scripts/mpnn_test_condensed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_test_condensed 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=4:30:00 9 | #SBATCH -o 
/network/scratch/o/oussama.boussif/slurms/mpnn_test_condensed-slurm-%j.out 10 | 11 | # 1. Load the required modules 12 | module --quiet load anaconda/3 13 | conda activate dedalus 14 | 15 | python test_irr_condensed.py -------------------------------------------------------------------------------- /configs/model/magnet_cnn_no_interaction.yaml: -------------------------------------------------------------------------------- 1 | name: magnet_cnn_no_interaction 2 | params: 3 | # Model hyperparameters 4 | time_slice: 16 5 | use_lstm: True 6 | lstm_hidden: 256 7 | lstm_layers: 4 8 | mlp_layers: 1 9 | mlp_hidden: 32 10 | scales: 1 11 | n_chan: 128 12 | kernel_size: 3 13 | teacher_forcing: False 14 | res_scale: 1 15 | res_layers: 16 16 | interpolation: area 17 | # Optimization hyperparameters 18 | factor: 0.6 19 | step_size: 50 20 | loss: l1 21 | lr: 0.0005 22 | weight_decay: 0.0001 -------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule_implicit_gnn.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule.HDF5DatamoduleImplicitGNN 2 | name: h5_datamodule_implicit_gnn 3 | train_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/KS_train.h5 4 | val_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/KS_valid.h5 5 | test_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/Heat_test.h5 6 | nt_train: 128 7 | nx_train: 256 8 | nt_val: 128 9 | nx_val: 256 10 | nt_test: 256 11 | nx_test: 256 12 | sampling: uniform 13 | samples: 32 14 | 15 | num_workers: 0 16 | batch_size: 32 -------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule_implicit_2d.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule_2d.HDF5DatamoduleImplicit_2d 2 | name: h5_datamodule_implicit_2d 3 | train_path: 
/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5 4 | val_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5 5 | test_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5 6 | nt_train: 50 7 | res_train: 64 8 | nt_val: 50 9 | res_val: 64 10 | nt_test: 50 11 | res_test: 64 12 | samples: 32 13 | 14 | num_workers: 3 15 | batch_size: 32 -------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule_graph_2d.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule_2d.HDF5DatamoduleGraph_2d 2 | name: h5_datamodule_graph_2d 3 | train_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5 4 | val_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5 5 | test_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5 6 | nt_train: 50 7 | res_train: 64 8 | nt_val: 50 9 | res_val: 64 10 | nt_test: 50 11 | res_test: 64 12 | train_regular: True 13 | val_regular: True 14 | test_regular: True 15 | 16 | num_workers: 3 17 | batch_size: 4 -------------------------------------------------------------------------------- /models/factory.py: -------------------------------------------------------------------------------- 1 | from .fno_1d import FNO1d 2 | from .fno_2d import FNO2d 3 | from .magnet_cnn import MAgNetCNN 4 | from .magnet_cnn_2d import MAgNetCNN_2d 5 | from .magnet_cnn_no_interaction import MAgNetCNN_no_interaction 6 | from .magnet_gnn import MAgNetGNN 7 | from .mpnn import MPNN 8 | from .mpnn_2d import MPNN_2d 9 | 10 | FACTORY = { 11 | 'fno_1d': FNO1d, 12 | 'fno_2d': FNO2d, 13 | 'mpnn': MPNN, 14 | 'mpnn_2d': MPNN_2d, 15 | 'magnet_cnn_no_interaction': MAgNetCNN_no_interaction, 16 | 'magnet_cnn': MAgNetCNN, 17 | 'magnet_cnn_2d': MAgNetCNN_2d, 18 | 'magnet_gnn': MAgNetGNN 19 | } 
-------------------------------------------------------------------------------- /configs/datamodule/h5_datamodule_implicit_gnn_2d.yaml: -------------------------------------------------------------------------------- 1 | _target_: datamodule.h5_datamodule_2d.HDF5DatamoduleImplicitGNN_2d 2 | name: h5_datamodule_implicit_gnn_2d 3 | train_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_512.h5 4 | val_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5 5 | test_path: /home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5 6 | nt_train: 50 7 | res_train: 64 8 | nt_val: 50 9 | res_val: 32 10 | nt_test: 50 11 | res_test: 32 12 | samples: 32 13 | train_regular: False 14 | val_regular: True 15 | test_regular: True 16 | 17 | num_workers: 3 18 | batch_size: 32 -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: fno_1d 3 | - datamodule: h5_datamodule 4 | - trainer: default 5 | - callbacks: default.yaml 6 | - override hydra/sweeper: optuna 7 | - override hydra/sweeper/sampler: tpe 8 | - override hydra/hydra_logging: colorlog 9 | - override hydra/job_logging: colorlog 10 | seed: 42 11 | name: fno_1d 12 | ckpt_path: null 13 | hydra: 14 | sweep: 15 | dir: /network/scratch/o/oussama.boussif/pdeone/logs/multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} 16 | subdir: ${hydra.job.num} 17 | sweeper: 18 | sampler: 19 | seed: 42 20 | direction: minimize 21 | study_name: fno_1d 22 | storage: null 23 | n_trials: 15 24 | n_jobs: 2 25 | run: 26 | dir: /network/scratch/o/oussama.boussif/pdeone/logs/experiments/${name}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | 
model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val_mae_loss" # name of the logged metric which determines when model is improving 4 | mode: "min" 5 | save_top_k: 1 # save k best models (determined by above metric) 6 | save_last: True # additionally always save model from last epoch 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "epoch_{epoch:03d}" 10 | auto_insert_metric_name: False 11 | 12 | early_stopping: 13 | _target_: pytorch_lightning.callbacks.EarlyStopping 14 | monitor: "val_mae_loss" # name of the logged metric which determines when model is improving 15 | mode: "min" 16 | patience: 35 # how many validation epochs of not improving until training stops 17 | min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement -------------------------------------------------------------------------------- /models/backbones/mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | activations = { 5 | 'relu': nn.ReLU(), 6 | 'tanh': nn.Tanh(), 7 | 'gelu': nn.GELU() 8 | } 9 | class MLP(nn.Module): 10 | def __init__(self, in_dim, hidden_list, out_dim, activation='relu'): 11 | 12 | super().__init__() 13 | assert activation in ['relu', 'tanh', 'gelu'] 14 | 15 | self.layers = nn.ModuleList() 16 | self.layers.append(nn.Linear(in_dim, hidden_list[0])) 17 | self.layers.append(activations[activation]) 18 | 19 | for i in range(len(hidden_list)-1): 20 | self.layers.append(nn.Linear(hidden_list[i], hidden_list[i+1])) 21 | self.layers.append(activations[activation]) 22 | self.layers.append(nn.Linear(hidden_list[-1],out_dim)) 23 | 24 | def forward(self, x): 25 | out = x 26 | for layer in self.layers: 27 | out = layer(out) 28 | return out -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_64_irregular.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_64_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_64_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=32 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_128_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_128_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_128_irregular-slurm-%A_%a.out 11 | 12 | 
param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_128.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=128 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=64 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_256_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_256_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_256_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_256.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=256 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=128 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_512_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_512_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=2:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_512_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_512.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=512 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=256 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/fno_2d/fno_2d_b1_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=fno_2d_b1_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=2:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/fno_2d_b1_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=fno_2d \ 21 | name=fno_2d \ 22 | datamodule=h5_datamodule_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | model.params.time_history=10 \ 33 | model.params.time_future=10 \ 34 | model.params.teacher_forcing=False \ 35 | model.params.modes_1=12 \ 36 | model.params.modes_2=12 \ 37 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/fno_2d/fno_2d_b2_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=fno_2d_b2_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=2:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/fno_2d_b2_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=fno_2d \ 21 | name=fno_2d \ 22 | datamodule=h5_datamodule_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_train_B2_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | model.params.time_history=10 \ 33 | model.params.time_future=10 \ 34 | model.params.teacher_forcing=False \ 35 | model.params.modes_1=12 \ 36 | model.params.modes_2=12 \ 37 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_64_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_64_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_64_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=32 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:2 7 | #SBATCH --mem=60G 8 | #SBATCH --time=13:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt 13 | seed=$(cat $param_store | awk -v var=$SLURM_ARRAY_TASK_ID 'NR==var {print $1}') 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | datamodule.batch_size=4 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.gpus=2 \ 37 | trainer.strategy='ddp' \ 38 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b2_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b2_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:2 7 | #SBATCH --mem=60G 8 | #SBATCH --time=13:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b2_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_train_B2_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | datamodule.batch_size=4 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.gpus=2 \ 37 | trainer.strategy='ddp' \ 38 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_cnn_2d/magnet_cnn_2d_b1_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_cnn_2d_b1_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:2 7 | #SBATCH --mem=60G 8 | #SBATCH --time=3:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_cnn_2d_b1_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_cnn_2d \ 21 | name=magnet_cnn_2d \ 22 | datamodule=h5_datamodule_implicit_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | datamodule.samples=256 \ 33 | model.params.time_slice=10 \ 34 | model.params.teacher_forcing=True \ 35 | trainer.gpus=2 \ 36 | trainer.strategy='ddp' \ 37 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_cnn_2d/magnet_cnn_2d_b2_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_cnn_2d_b2_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:2 7 | #SBATCH --mem=60G 8 | #SBATCH --time=3:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_cnn_2d_b2_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_cnn_2d \ 21 | name=magnet_cnn_2d \ 22 | datamodule=h5_datamodule_implicit_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_train_B2_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | datamodule.samples=256 \ 33 | model.params.time_slice=10 \ 34 | model.params.teacher_forcing=True \ 35 | trainer.gpus=2 \ 36 | trainer.strategy='ddp' \ 37 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_128_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_128_irregular_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_128_irregular_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_128.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=128 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=64 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_256_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_256_irregular_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_256_irregular_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_256.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=256 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=128 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_512_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_512_irregular_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=2:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_512_irregular_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_512.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=512 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | datamodule.samples=256 \ 34 | model.params.time_slice=10 \ 35 | trainer.max_epochs=250 -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_64_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_64_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_64_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_128_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_128_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_128_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_128.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=128 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_256_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_256_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=2:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_256_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_256.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=256 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_512_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_512_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=3:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_512_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_512.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=512 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b1_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b1_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:2 7 | #SBATCH --mem=60G 8 | #SBATCH --time=9:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b1_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_train_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | datamodule.batch_size=8 \ 33 | datamodule.samples=256 \ 34 | datamodule.train_regular=True \ 35 | model.params.time_slice=10 \ 36 | trainer.max_epochs=250 \ 37 | trainer.gpus=2 \ 38 | trainer.strategy='ddp' -------------------------------------------------------------------------------- /scripts/magnet_gnn/magnet_gnn_2d_b2_64_regular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=magnet_gnn_2d_b2_64_regular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:2 7 | #SBATCH --mem=60G 8 | #SBATCH --time=9:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/magnet_gnn_2d_b2_64_regular-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=magnet_gnn \ 21 | name=magnet_gnn \ 22 | datamodule=h5_datamodule_implicit_gnn_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_train_B2_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B2/burgers_test_B2_64.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=64 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=64 \ 32 | datamodule.batch_size=8 \ 33 | datamodule.samples=256 \ 34 | datamodule.train_regular=True \ 35 | model.params.time_slice=10 \ 36 | trainer.max_epochs=250 \ 37 | trainer.gpus=2 \ 38 | trainer.strategy='ddp' -------------------------------------------------------------------------------- /scripts/mpnn_2d/new_seeds/mpnn_2d_b1_64_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_64_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_64_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/new_seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/new_seeds/mpnn_2d_b1_128_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_128_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_128_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/new_seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_128.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=128 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/new_seeds/mpnn_2d_b1_256_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_256_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=2:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_256_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/new_seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_256.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=256 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/new_seeds/mpnn_2d_b1_512_irregular.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_512_irregular 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=3:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_512_irregular-slurm-%A_%a.out 11 | 12 | param_store=scripts/new_seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/uniform/burgers_train_irregular_B1_512.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=512 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_64_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_64_irregular_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_64_irregular_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_64.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=64 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_128_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_128_irregular_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=1:30:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_128_irregular_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_128.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=128 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_256_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_256_irregular_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=2:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_256_irregular_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
Load the required modules 15 | module --quiet load anaconda/3 16 | conda activate dedalus 17 | 18 | python run.py \ 19 | seed=$seed \ 20 | model=mpnn_2d \ 21 | name=mpnn_2d \ 22 | datamodule=h5_datamodule_graph_2d \ 23 | datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_256.h5' \ 24 | datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 25 | datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \ 26 | datamodule.nt_train=50 \ 27 | datamodule.res_train=256 \ 28 | datamodule.nt_val=50 \ 29 | datamodule.res_val=32 \ 30 | datamodule.nt_test=50 \ 31 | datamodule.res_test=32 \ 32 | datamodule.batch_size=32 \ 33 | model.params.time_window=10 \ 34 | model.params.neighbors=4 \ 35 | model.params.teacher_forcing=False \ 36 | trainer.max_epochs=250 \ 37 | datamodule.train_regular=False -------------------------------------------------------------------------------- /scripts/mpnn_2d/mpnn_2d_b1_512_irregular_concentrated.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=mpnn_2d_b1_512_irregular_concentrated 4 | #SBATCH --partition=long 5 | #SBATCH --cpus-per-task=6 6 | #SBATCH --gres=gpu:rtx8000:1 7 | #SBATCH --mem=60G 8 | #SBATCH --time=3:00:00 9 | #SBATCH --array=1-5 10 | #SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_512_irregular_concentrated-slurm-%A_%a.out 11 | 12 | param_store=scripts/seeds.txt  # seed list, one value per line 13 | seed=$(awk -v var="$SLURM_ARRAY_TASK_ID" 'NR==var {print $1}' "$param_store")  # row number == array task id selects this run's seed 14 | # 1. 
#!/bin/bash

#SBATCH --job-name=mpnn_2d_b1_64_irregular_concentrated
#SBATCH --partition=long
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:rtx8000:1
#SBATCH --mem=60G
#SBATCH --time=1:30:00
#SBATCH --array=1-5
#SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_64_irregular_concentrated-slurm-%A_%a.out

# Trains the 2D MPNN baseline on the concentrated irregular B1 grid (64 points)
# using the second seed list. Task i of the array uses line i of the seed file.
param_store=scripts/new_seeds.txt
seed=$(awk -v var="${SLURM_ARRAY_TASK_ID}" 'NR==var {print $1}' "${param_store}")

# Fail fast instead of launching a run with an empty `seed=` override, which
# happens when the seed file is missing or the array index exceeds its length.
if [[ -z "${seed}" ]]; then
    echo "No seed found at line '${SLURM_ARRAY_TASK_ID}' of ${param_store}" >&2
    exit 1
fi

# 1. Load the required modules
module --quiet load anaconda/3
conda activate dedalus

# 2. Launch training; every key=value pair below is a hydra override.
python run.py \
seed="${seed}" \
model=mpnn_2d \
name=mpnn_2d \
datamodule=h5_datamodule_graph_2d \
datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_64.h5' \
datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \
datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \
datamodule.nt_train=50 \
datamodule.res_train=64 \
datamodule.nt_val=50 \
datamodule.res_val=32 \
datamodule.nt_test=50 \
datamodule.res_test=32 \
datamodule.batch_size=32 \
model.params.time_window=10 \
model.params.neighbors=4 \
model.params.teacher_forcing=False \
trainer.max_epochs=250 \
datamodule.train_regular=False
#!/bin/bash

#SBATCH --job-name=mpnn_2d_b1_256_irregular_concentrated
#SBATCH --partition=long
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:rtx8000:1
#SBATCH --mem=60G
#SBATCH --time=2:00:00
#SBATCH --array=1-5
#SBATCH -o /network/scratch/o/oussama.boussif/slurms/mpnn_2d_b1_256_irregular_concentrated-slurm-%A_%a.out

# Trains the 2D MPNN baseline on the concentrated irregular B1 grid (256 points)
# using the second seed list. Task i of the array uses line i of the seed file.
param_store=scripts/new_seeds.txt
seed=$(awk -v var="${SLURM_ARRAY_TASK_ID}" 'NR==var {print $1}' "${param_store}")

# Fail fast instead of launching a run with an empty `seed=` override, which
# happens when the seed file is missing or the array index exceeds its length.
if [[ -z "${seed}" ]]; then
    echo "No seed found at line '${SLURM_ARRAY_TASK_ID}' of ${param_store}" >&2
    exit 1
fi

# 1. Load the required modules
module --quiet load anaconda/3
conda activate dedalus

# 2. Launch training; every key=value pair below is a hydra override.
python run.py \
seed="${seed}" \
model=mpnn_2d \
name=mpnn_2d \
datamodule=h5_datamodule_graph_2d \
datamodule.train_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/concentrated/burgers_train_irregular_B1_256.h5' \
datamodule.val_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \
datamodule.test_path='/home/mila/o/oussama.boussif/scratch/pdeone/data/B1/burgers_test_B1_32.h5' \
datamodule.nt_train=50 \
datamodule.res_train=256 \
datamodule.nt_val=50 \
datamodule.res_val=32 \
datamodule.nt_test=50 \
datamodule.res_test=32 \
datamodule.batch_size=32 \
model.params.time_window=10 \
model.params.neighbors=4 \
model.params.teacher_forcing=False \
trainer.max_epochs=250 \
datamodule.train_regular=False
@hydra.main(config_path="configs", config_name="config.yaml")
def main(cfg: DictConfig):
    """Run one tuning trial and return its validation MAE.

    Seeds everything from ``cfg.seed``, builds the Comet logger, the
    datamodule, the callbacks and the trainer from the hydra config, then
    fits the model looked up in ``FACTORY`` by ``cfg.model.name``.

    Returns:
        The ``val_mae_loss`` entry of ``trainer.callback_metrics`` after
        fitting, converted to a Python scalar.
    """
    dataset_name = cfg.datamodule.name
    model_name = cfg.model.name

    print("This run will tune the model", model_name, "on the", dataset_name, "dataset")
    seed_everything(cfg.seed, workers=True)

    # Experiment tracking.
    comet_logger = CometLogger(
        project_name=f"{model_name}-tune-{dataset_name}",
        experiment_name=f"{model_name}_seed_{cfg.seed}_{dataset_name}",
    )

    # Data.
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)

    # Callbacks: only config entries that declare a `_target_` are instantiated.
    callbacks: List[Callback] = []
    if "callbacks" in cfg:
        for _, cb_conf in cfg.callbacks.items():
            if "_target_" not in cb_conf:
                continue
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=comet_logger,
        _convert_="partial",
    )

    # Resolve the model class by name and build it from its parameter node.
    network = FACTORY[model_name](cfg.model.params)

    trainer.fit(network, datamodule)
    return trainer.callback_metrics['val_mae_loss'].item()
class ResBlock(nn.Module):
    """Residual block: conv -> activation -> conv, plus the identity skip.

    Args:
        n_chan: number of input and output channels of both convolutions.
        kernel_size: convolution kernel size; padding is ``kernel_size // 2``.
        bias: whether the two convolutions carry a bias term.
        act: activation module applied between the two convolutions.
        res_scale: factor applied to the block output.
        mode: ``'1d'`` for ``nn.Conv1d`` layers, ``'2d'`` for ``nn.Conv2d``.
    """

    def __init__(self, n_chan, kernel_size,
                 bias=True, act=nn.ReLU(True), res_scale=1, mode='1d'):

        super().__init__()

        assert mode in ["1d", "2d"]
        self.res_scale = res_scale
        self.mode = mode

        conv = nn.Conv2d if mode == '2d' else nn.Conv1d
        pad = kernel_size // 2
        # `bias` was previously accepted but ignored; the default (True)
        # matches the layers that were built before.
        self.conv_1 = conv(n_chan, n_chan, kernel_size, padding=pad, bias=bias)
        self.act = act
        self.conv_2 = conv(n_chan, n_chan, kernel_size, padding=pad, bias=bias)

    def forward(self, x):
        out = self.conv_2(self.act(self.conv_1(x)))
        out += x
        # NOTE(review): the scale is applied after the skip is added, so it
        # scales the identity path too. Kept as-is; confirm this is intended.
        return out.mul(self.res_scale)


class EDSR(nn.Module):
    def __init__(self, in_chan, n_chan=64, res_layers=16, kernel_size=3, res_scale=1, mode='1d'):
        '''
        EDSR model without upsampling

        Args:
            in_chan: number of input channels.
            n_chan: number of feature channels; exposed as ``self.out_dim``.
            res_layers: number of residual blocks in the body.
            kernel_size: kernel size of every convolution.
            res_scale: output scaling factor of every residual block.
            mode: ``'1d'`` or ``'2d'``, selecting the convolution type.
        '''
        super().__init__()
        assert mode in ["1d", "2d"]
        self.mode = mode

        conv = nn.Conv2d if mode == '2d' else nn.Conv1d
        pad = kernel_size // 2

        self.head_conv = conv(in_chan, n_chan, kernel_size, padding=pad)
        # `res_scale` must be passed by keyword: positionally it lands in
        # ResBlock's `bias` slot and the blocks silently keep res_scale=1.
        self.res_layers = nn.Sequential(*[
            ResBlock(n_chan, kernel_size, res_scale=res_scale, mode=mode)
            for _ in range(res_layers)
        ])
        self.tail_conv = conv(n_chan, n_chan, kernel_size, padding=pad)

        self.out_dim = n_chan

    def forward(self, x):
        """Apply head conv, residual body and tail conv, plus a global skip."""
        x = self.head_conv(x)
        res = self.res_layers(x)
        res = self.tail_conv(res)
        res += x

        return res
def make_coord(shape, ranges=None, flatten=True):
    """
    Make coordinates at grid centers.

    Args:
        shape: number of cells along each axis.
        ranges: optional per-axis ``(v0, v1)`` bounds; each axis defaults
            to ``(-1, 1)``.
        flatten: if True, collapse the grid axes so the result has shape
            ``[prod(shape), len(shape)]``; otherwise ``[*shape, len(shape)]``.
    Returns:
        torch.Tensor: the cell-center coordinates.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # r is half a cell width, so centers sit at v0 + r, v0 + 3r, ...
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret


def get_logger(name=__name__):
    """
    Initializes multi-GPU-friendly python command line logger.
    https://github.com/ashleve/lightning-hydra-template/blob/8b62eef9d0d9c863e88c0992595688d6289d954f/src/utils/utils.py#L12
    """

    logger = logging.getLogger(name)

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    return logger


def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (C, L) or (C, H, W)

    Returns:
        (coord, rgb): ``coord`` holds one cell-center coordinate per pixel,
        shape ``[L, 1]`` or ``[H*W, 2]``; ``rgb`` holds the matching channel
        values, shape ``[L, C]`` or ``[H*W, C]``.
    Raises:
        NotImplementedError: if ``img`` is neither 2-D nor 3-D.
    """
    if img.dim() not in (2, 3):
        # The original evaluated the bare name `NotImplementedError` without
        # raising it and then failed on an unbound `coord`.
        raise NotImplementedError(
            f"to_pixel_samples expects a (C, L) or (C, H, W) tensor, got shape {tuple(img.shape)}"
        )
    # Use every spatial axis so coordinates and values pair up one-to-one;
    # the (C, H, W) case previously built coordinates for the W axis only.
    coord = make_coord(img.shape[1:])
    rgb = img.reshape(img.shape[0], -1).permute(1, 0)
    return coord, rgb
Oh}, 16 | pages = {31972--31985}, 17 | publisher = {Curran Associates, Inc.}, 18 | title = {MAgNet: Mesh Agnostic Neural PDE Solver}, 19 | url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/cf4c7ee0734cdfe09a099cf6cd7b117a-Paper-Conference.pdf}, 20 | volume = {35}, 21 | year = {2022} 22 | } 23 | ``` 24 | # Requirements 25 | 26 | Start by installing the required modules: 27 | ``` 28 | pip install -r requirements.txt 29 | ``` 30 | # Dataset 31 | The dataset is available for download at the following link: [magnet dataset](https://drive.google.com/drive/folders/1hZ67IOFr8XwErpXYZnDC9WRcFmb-BeBb?usp=sharing) and contains two folders: ``1d`` and ``2d`` for the 1D and 2D PDE datasets respectively. 32 | 33 | The structure of the 1D dataset is as follows: 34 | ``` 35 | ├───E1 36 | │ ├───irregular 37 | │ │ CE_test_E1_graph_100.h5 38 | │ │ CE_test_E1_graph_200.h5 39 | │ │ CE_test_E1_graph_40.h5 40 | │ │ CE_test_E1_graph_50.h5 41 | │ │ CE_train_E1_graph_30.h5 42 | │ │ CE_train_E1_graph_50.h5 43 | │ │ CE_train_E1_graph_70.h5 44 | │ │ 45 | │ └───regular 46 | │ CE_test_E1_100.h5 47 | │ CE_test_E1_200.h5 48 | │ CE_test_E1_40.h5 49 | │ CE_test_E1_50.h5 50 | │ CE_train_E1_50.h5 51 | │ 52 | ├───E2 53 | │ └───regular 54 | │ CE_train_E2_50.h5 55 | │ CE_test_E2_100.h5 56 | │ CE_test_E2_200.h5 57 | │ CE_test_E2_40.h5 58 | │ CE_test_E2_50.h5 59 | │ 60 | └───E3 61 | └───regular 62 | CE_test_E3_100.h5 63 | CE_test_E3_200.h5 64 | CE_test_E3_40.h5 65 | CE_test_E3_50.h5 66 | CE_train_E3_50.h5 67 | ``` 68 | 69 | Each file is formatted as follows: `CE_{mode}_{dataset}_{resolution}.h5` where `mode` can be `train` or `test` and `dataset` can be `E1`, `E2` or `E3` and `resolution` denotes the resolution of the dataset. The folder `regular` contains simulations on a regular grid and `irregular` contains simulations on an irregular grid. 
70 | 71 | --------- 72 | 73 | For the 2D dataset, it is structured as follows: 74 | ``` 75 | ├── B1 76 | │ ├── burgers_test_B1_128.h5 77 | │ ├── burgers_test_B1_256.h5 78 | │ ├── burgers_test_B1_32.h5 79 | │ ├── burgers_test_B1_64.h5 80 | │ ├── burgers_train_B1_128.h5 81 | │ ├── burgers_train_B1_256.h5 82 | │ ├── burgers_train_B1_32.h5 83 | │ ├── burgers_train_B1_64.h5 84 | │ ├── concentrated 85 | │ │ ├── burgers_train_irregular_B1_128.h5 86 | │ │ ├── burgers_train_irregular_B1_256.h5 87 | │ │ ├── burgers_train_irregular_B1_512.h5 88 | │ │ └── burgers_train_irregular_B1_64.h5 89 | │ └── uniform 90 | │ ├── burgers_train_irregular_B1_128.h5 91 | │ ├── burgers_train_irregular_B1_256.h5 92 | │ ├── burgers_train_irregular_B1_512.h5 93 | │ └── burgers_train_irregular_B1_64.h5 94 | └── B2 95 | ├── burgers_test_B2_128.h5 96 | ├── burgers_test_B2_256.h5 97 | ├── burgers_test_B2_32.h5 98 | ├── burgers_test_B2_64.h5 99 | ├── burgers_train_B2_128.h5 100 | ├── burgers_train_B2_256.h5 101 | ├── burgers_train_B2_32.h5 102 | └── burgers_train_B2_64.h5 103 | ``` 104 | Each file is formatted as follows: `burgers_{mode}_{dataset}_{resolution}.h5` where `mode` can be `train` or `test` and `dataset` can be `B1` or `B2` and `resolution` is the resolution of the dataset. The folder `concentrated` contains simulations on an irregular grid where points are sampled around a region in the grid while `uniform` contains simulations on a uniform irregular grid. 105 | # Experiments 106 | We use `hydra` for config management and command line parsing so it's straightforward to run experiments using our code-base. 
Below is an example command for training the **MAgNet[CNN]** model on the **E1** dataset for 250 epochs on four GPUs: 107 | ``` 108 | python run.py \ 109 | model=magnet_cnn \ 110 | name=magnet_cnn \ 111 | datamodule=h5_datamodule_implicit \ 112 | datamodule.train_path={train_path} \ 113 | datamodule.val_path={val_path} \ 114 | datamodule.test_path={test_path} \ 115 | datamodule.nt_train=250 \ 116 | datamodule.nx_train={train_resolution} \ 117 | datamodule.nt_val=250 \ 118 | datamodule.nx_val={val_resolution} \ 119 | datamodule.nt_test=250 \ 120 | datamodule.nx_test={test_resolution} \ 121 | datamodule.samples=16 \ 122 | model.params.time_slice=25 \ 123 | trainer.max_epochs=250 \ 124 | trainer.gpus=4 \ 125 | trainer.strategy='ddp' 126 | ``` 127 | You can find the relevant scripts that were used to run experiments under the ``scripts`` folder. 128 | -------------------------------------------------------------------------------- /models/fno_2d.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from: https://github.com/zongyi-li/fourier_neural_operator/blob/master/fourier_2d_time.py 3 | ''' 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | import pytorch_lightning as pl 9 | 10 | 11 | class SpectralConv2d(nn.Module): 12 | def __init__(self, in_channels, out_channels, modes1, modes2): 13 | super(SpectralConv2d, self).__init__() 14 | 15 | """ 16 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT.
class FNO2d(pl.LightningModule):
    """Autoregressive 2D Fourier Neural Operator trained with Lightning.

    The network maps ``time_history`` frames plus three constant channels
    (dx, dy, dt) to ``time_future`` frames, and is unrolled over the future
    part of each trajectory during training and validation.
    """

    def __init__(self, hparams):
        """
        Args:
            hparams: config object exposing ``lr``, ``weight_decay``,
                ``factor``, ``step_size``, ``loss``, ``modes_1``, ``modes_2``,
                ``width``, ``time_history``, ``time_future``, ``num_layers``
                and ``teacher_forcing``.
        Raises:
            ValueError: if ``hparams.loss`` is not 'l1', 'l2' or 'smooth_l1'.
        """
        super().__init__()

        self.save_hyperparameters()

        # Training parameters
        self.lr = hparams.lr
        self.weight_decay = hparams.weight_decay
        self.factor = hparams.factor
        self.step_size = hparams.step_size
        self.loss = hparams.loss
        # Model parameters
        self.modes_1 = hparams.modes_1
        self.modes_2 = hparams.modes_2
        self.width = hparams.width
        self.time_history = hparams.time_history
        self.time_future = hparams.time_future
        self.num_layers = hparams.num_layers
        self.teacher_forcing = hparams.teacher_forcing

        # Reject an unknown loss here; previously `self.criterion` was simply
        # left undefined and the failure surfaced in the first training step.
        criteria = {'l1': nn.L1Loss, 'l2': nn.MSELoss, 'smooth_l1': nn.SmoothL1Loss}
        if self.loss not in criteria:
            raise ValueError(
                f"Unsupported loss '{self.loss}'; expected one of {sorted(criteria)}"
            )
        self.criterion = criteria[self.loss]()

        self.mse_criterion = nn.MSELoss()
        self.mae_criterion = nn.L1Loss()

        # +3 input channels for the dx, dy and dt planes added in forward().
        self.fc0 = nn.Linear(self.time_history + 3, self.width)
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, self.time_future)

        # Layers are created interleaved (fourier_i, conv_i) so the parameter
        # initialisation order, and thus seeded weights, stays unchanged.
        fourier_layers = []
        conv_layers = []
        for _ in range(self.num_layers):
            fourier_layers.append(SpectralConv2d(self.width, self.width, self.modes_1, self.modes_2))
            conv_layers.append(nn.Conv2d(self.width, self.width, 1))
        self.fourier_layers = nn.ModuleList(fourier_layers)
        self.conv_layers = nn.ModuleList(conv_layers)

    def forward(self, u: torch.Tensor, dx: torch.Tensor, dy: torch.Tensor, dt: torch.Tensor):
        """
        Forward pass of FNO network.
        The input to the forward pass has the shape [batch, time_history, H, W].
        1. Add dx, dy and dt as channel dimension to the time_history
        2. Lift the input to the desired channel dimension by self.fc0
        3. ``num_layers`` FNO layers (spectral conv + 1x1 conv, then GELU)
        4. Project from the channel space to the output space by self.fc1 and self.fc2.
        The output has the shape [batch, time_future, H, W].
        Args:
            u (torch.Tensor): input tensor of shape [batch, time_history, H, W]
            dx (torch.Tensor): spatial distances, one value per sample ([batch])
            dy (torch.Tensor): spatial distances, one value per sample ([batch])
            dt (torch.Tensor): temporal distances, one value per sample ([batch])
        Returns: torch.Tensor: output has the shape [batch, time_future, H, W]
        """
        B, T, H, W = u.shape
        # Broadcast each per-sample scalar to a constant [B, 1, H, W] plane.
        x = torch.cat((
            u,
            dx[:, None, None, None].to(u.device).repeat(1, 1, H, W),
            dy[:, None, None, None].to(u.device).repeat(1, 1, H, W),
            dt[:, None, None, None].repeat(1, 1, H, W).to(u.device)), 1)

        x = x.permute(0, 2, 3, 1)
        x = self.fc0(x)  # B, H, W, C
        x = x.permute(0, 3, 1, 2)  # B, C, H, W

        for fourier, conv in zip(self.fourier_layers, self.conv_layers):
            x1 = fourier(x)
            x2 = conv(x)
            x = x1 + x2
            x = F.gelu(x)

        x = x.permute(0, 2, 3, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = x.permute(0, 3, 1, 2)
        return x

    def configure_optimizers(self):
        """Adam with a StepLR schedule (``step_size`` / ``factor``)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler
            },
        }

    def _rollout(self, u_history, u_future, dx, dy, dt, teacher_forcing):
        """Unroll the model over the future frames and concatenate the chunks.

        With ``teacher_forcing`` the next input is the ground-truth chunk,
        otherwise the model's own prediction is fed back.
        """
        # NOTE(review): feeding a `time_future`-frame chunk back into fc0
        # presumably requires time_future == time_history - confirm.
        step = self.time_future
        predictions = []
        inp = u_history
        for t in range(u_future.shape[1] // step):
            y_hat = self.forward(inp, dx, dy, dt)
            predictions.append(y_hat)
            if teacher_forcing:
                inp = u_future[:, t * step:(t + 1) * step]
            else:
                inp = y_hat
        return torch.cat(predictions, dim=1)

    def _shared_step(self, batch, stage, teacher_forcing):
        """Compute and log the criterion and MAE for one batch."""
        u, dx, dy, dt = batch
        u = u.float()
        dx = dx.float()
        dy = dy.float()
        dt = dt.float()

        u_history = u[:, :self.time_history]  # B, T_history, H, W
        u_future = u[:, self.time_history:]  # B, T_future, H, W

        u_hat = self._rollout(u_history, u_future, dx, dy, dt, teacher_forcing)

        loss = self.criterion(u_hat, u_future)
        mae_loss = self.mae_criterion(u_hat, u_future)

        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_mae_loss', mae_loss, prog_bar=True)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Rollout on a training batch; logs ``train_loss`` / ``train_mae_loss``."""
        return self._shared_step(train_batch, 'train', self.teacher_forcing)

    def validation_step(self, val_batch, batch_idx):
        """Free-running rollout (never teacher-forced); logs ``val_loss`` / ``val_mae_loss``."""
        return self._shared_step(val_batch, 'val', False)
class FNO1d(pl.LightningModule):
    """Autoregressive 1D Fourier Neural Operator trained with Lightning.

    The network maps ``time_history`` frames plus two constant channels
    (dx, dt) to ``time_future`` frames, and is unrolled over the future part
    of each trajectory during training and validation.
    """

    def __init__(self, hparams):
        """
        Args:
            hparams: config object exposing ``lr``, ``weight_decay``,
                ``factor``, ``step_size``, ``loss``, ``modes``, ``width``,
                ``time_history``, ``time_future``, ``num_layers`` and
                ``teacher_forcing``.
        Raises:
            ValueError: if ``hparams.loss`` is not 'l1', 'l2' or 'smooth_l1'.
        """
        super().__init__()

        self.save_hyperparameters()

        # Training parameters
        self.lr = hparams.lr
        self.weight_decay = hparams.weight_decay
        self.factor = hparams.factor
        self.step_size = hparams.step_size
        self.loss = hparams.loss
        # Model parameters
        self.modes = hparams.modes
        self.width = hparams.width
        self.time_history = hparams.time_history
        self.time_future = hparams.time_future
        self.num_layers = hparams.num_layers
        self.teacher_forcing = hparams.teacher_forcing

        # Reject an unknown loss here; previously `self.criterion` was simply
        # left undefined and the failure surfaced in the first training step.
        criteria = {'l1': nn.L1Loss, 'l2': nn.MSELoss, 'smooth_l1': nn.SmoothL1Loss}
        if self.loss not in criteria:
            raise ValueError(
                f"Unsupported loss '{self.loss}'; expected one of {sorted(criteria)}"
            )
        self.criterion = criteria[self.loss]()

        self.mse_criterion = nn.MSELoss()
        self.mae_criterion = nn.L1Loss()

        # +2 input channels for the dx and dt columns added in forward().
        self.fc0 = nn.Linear(self.time_history + 2, self.width)
        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, self.time_future)

        # Layers are created interleaved (fourier_i, conv_i) so the parameter
        # initialisation order, and thus seeded weights, stays unchanged.
        fourier_layers = []
        conv_layers = []
        for _ in range(self.num_layers):
            fourier_layers.append(SpectralConv1d(self.width, self.width, self.modes))
            conv_layers.append(nn.Conv1d(self.width, self.width, 1))
        self.fourier_layers = nn.ModuleList(fourier_layers)
        self.conv_layers = nn.ModuleList(conv_layers)

    def forward(self, u: torch.Tensor, dx: torch.Tensor, dt: torch.Tensor):
        """
        Forward pass of FNO network.
        The input to the forward pass is channel-last: [batch, x, time_history]
        (the training/validation steps permute before and after calling it).
        1. Add dx and dt as channel dimension to the time_history, repeat for every x
        2. Lift the input to the desired channel dimension by self.fc0
        3. ``num_layers`` FNO layers (spectral conv + 1x1 conv, then GELU)
        4. Project from the channel space to the output space by self.fc1 and self.fc2.
        The output has the shape [batch, x, time_future].
        Args:
            u (torch.Tensor): input tensor of shape [batch, x, time_history]
            dx (torch.Tensor): spatial distances, one value per sample ([batch])
            dt (torch.Tensor): temporal distances, one value per sample ([batch])
        Returns: torch.Tensor: output has the shape [batch, x, time_future]
        """
        #TODO: rewrite training method and forward pass without permutation
        # [b, x, c] = [b, x, t+2]
        nx = u.shape[1]
        x = torch.cat((u, dx[:, None, None].to(u.device).repeat(1, nx, 1),
                       dt[:, None, None].repeat(1, nx, 1).to(u.device)), -1)

        x = self.fc0(x)
        # [b, x, c] -> [b, c, x]
        x = x.permute(0, 2, 1)

        for fourier, conv in zip(self.fourier_layers, self.conv_layers):
            x1 = fourier(x)
            x2 = conv(x)
            x = x1 + x2
            x = F.gelu(x)

        # [b, c, x] -> [b, x, c]
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

    def configure_optimizers(self):
        """Adam with a StepLR schedule (``step_size`` / ``factor``)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.factor)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler
            },
        }

    def _rollout(self, u_history, u_future, dx, dt, teacher_forcing):
        """Unroll the model over the future frames and concatenate the chunks.

        Tensors here are time-first ([B, T, L]); they are permuted to the
        channel-last layout forward() expects and back again.
        """
        # NOTE(review): feeding a `time_future`-frame chunk back into fc0
        # presumably requires time_future == time_history - confirm.
        step = self.time_future
        predictions = []
        inp = u_history
        for t in range(u_future.shape[1] // step):
            y_hat = self.forward(inp.permute(0, 2, 1), dx, dt).permute(0, 2, 1)
            predictions.append(y_hat)
            if teacher_forcing:
                inp = u_future[:, t * step:(t + 1) * step]
            else:
                inp = y_hat
        return torch.cat(predictions, dim=1)

    def _shared_step(self, batch, stage, teacher_forcing):
        """Compute and log the criterion and MAE for one batch."""
        u, dx, dt = batch
        u = u.float()
        dx = dx.float()
        dt = dt.float()

        u_history = u[:, :self.time_history]  # B, T_history, L
        u_future = u[:, self.time_history:]  # B, T_future, L

        u_hat = self._rollout(u_history, u_future, dx, dt, teacher_forcing)

        loss = self.criterion(u_hat, u_future)
        mae_loss = self.mae_criterion(u_hat, u_future)

        self.log(f'{stage}_loss', loss, prog_bar=True)
        self.log(f'{stage}_mae_loss', mae_loss, prog_bar=True)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Rollout on a training batch; logs ``train_loss`` / ``train_mae_loss``."""
        return self._shared_step(train_batch, 'train', self.teacher_forcing)

    def validation_step(self, val_batch, batch_idx):
        """Free-running rollout (never teacher-forced); logs ``val_loss`` / ``val_mae_loss``."""
        return self._shared_step(val_batch, 'val', False)
class HDF5DatasetImplicitGNN(Dataset):
    """Super-resolution samples on a 1D grid for the implicit GNN model.

    Each item is one trajectory. Even-indexed grid points form the
    low-resolution input; odd-indexed points are the query targets (a random
    subset of ``samples`` of them in ``'train'`` mode, all of them otherwise).

    Args:
        path: HDF5 file holding one group per split.
        nt: temporal resolution, used to build the key ``pde_{nt}-{nx}``.
        nx: spatial resolution, used to build the key ``pde_{nt}-{nx}``.
        sampling: stored on the instance; not read by this class.
        mode: split to read, one of ``'train'``, ``'valid'``, ``'test'``.
        load_all: copy the split into memory and close the file.
        samples: number of query points drawn per item in ``'train'`` mode.
    """

    def __init__(self,
                 path,
                 nt,
                 nx,
                 sampling='uniform',
                 mode='train',
                 load_all=False,
                 samples=256):

        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"

        f = h5py.File(path, 'r')
        self.mode = mode
        self.data = f[self.mode]
        self.dataset = f'pde_{nt}-{nx}'
        self.samples = samples
        self.sampling = sampling

        if load_all:
            # Copy every array __getitem__ reads; keeping only the solution
            # made the later lookups of 'x' and 't' fail with KeyError.
            data = {key: self.data[key][:] for key in (self.dataset, 'x', 't')}
            f.close()
            self.data = data

    def __len__(self):
        return self.data[self.dataset].shape[0]

    def __getitem__(self, idx):
        """Return the tensors of trajectory ``idx`` as a dict.

        Keys: ``t``, ``lr_frames`` (T, 1, ceil(L/2)), ``hr_frames`` (T, 1, L),
        ``hr_points`` (T, Q, 1), ``coords_hr`` (Q,), ``coords_lr``, and in
        ``'train'`` mode ``sample_idx`` with the Q sampled grid indices.
        """
        x = self.data['x'][idx]
        # Rescale the spatial coordinates to [-1, 1]
        x = 2*(x-x.min())/(x.max()-x.min())-1

        t = self.data['t'][idx]
        u_hr = torch.from_numpy(self.data[self.dataset][idx]).unsqueeze(1)  # T, 1, L
        L = u_hr.shape[-1]

        # Even grid points are the observed low-resolution signal
        u_lr = u_hr[:, :, ::2]
        lr_coord = x[::2]

        # Odd grid points are the ones left to predict
        query_idx = np.arange(1, L, 2)

        return_tensors = {'t': t}
        if self.mode == 'train':
            query_idx = np.sort(np.random.choice(query_idx, self.samples, replace=False))
            return_tensors['sample_idx'] = torch.from_numpy(query_idx)

        return_tensors['lr_frames'] = u_lr
        return_tensors['hr_frames'] = u_hr
        return_tensors['hr_points'] = u_hr[:, :, torch.from_numpy(query_idx)].permute(0, 2, 1)
        return_tensors['coords_hr'] = x[query_idx]
        return_tensors['coords_lr'] = lr_coord

        return return_tensors
class HDF5Dataset(Dataset):
    """
    Load samples of a PDE dataset stored in an HDF5 file.
    """
    def __init__(self,
                 path: str,
                 mode: str,
                 nt: int,
                 nx: int,
                 dtype=torch.float32,
                 load_all: bool=False):
        """Initialize the dataset object.
        Args:
            path: path to dataset
            mode: [train, valid, test]
            nt: temporal resolution
            nx: spatial resolution
            dtype: floating precision of the returned tensors
            load_all: load all the data into memory
        """
        super().__init__()
        f = h5py.File(path, 'r')
        self.mode = mode
        self.dtype = dtype
        self.data = f[self.mode]
        self.dataset = f'pde_{nt}-{nx}'

        if load_all:
            # Copy every array __getitem__ reads; keeping only the solution
            # made the later lookups of 'x' and 't' fail with KeyError.
            data = {key: self.data[key][:] for key in (self.dataset, 'x', 't')}
            f.close()
            self.data = data

    def __len__(self):
        return self.data[self.dataset].shape[0]

    def __getitem__(self, idx: int):
        """
        Returns data items for batched training/validation/testing.
        Args:
            idx: data index
        Returns:
            torch.Tensor: data trajectory used for training/validation/testing
            torch.Tensor: dx, spacing between the first two grid points
            torch.Tensor: dt, spacing between the first two time stamps
        """
        u = torch.from_numpy(self.data[self.dataset][idx])
        x = torch.from_numpy(self.data['x'][idx])
        t = torch.from_numpy(self.data['t'][idx])

        # Only the first increment is returned, so compute just that one.
        dx = torch.diff(x[:2])[0]
        dt = torch.diff(t[:2])[0]

        # Honour the configured precision (float32 by default, as before).
        return u.to(self.dtype), dx.to(self.dtype), dt.to(self.dtype)
class HDF5DatasetImplicitGNN_2d(Dataset):
    """Super-resolution samples on 2D point sets for the implicit GNN model.

    Each item is one trajectory flattened to N points. Even-indexed points
    form the low-resolution input; odd-indexed points are the query targets
    (a random subset of ``samples`` of them in ``'train'`` mode, all of them
    otherwise).

    Args:
        path: HDF5 file holding one group per split.
        nt: temporal resolution, used to build the key ``pde_{nt}-{res}``.
        res: spatial resolution, used to build the key ``pde_{nt}-{res}``.
        mode: split to read, one of ``'train'``, ``'valid'``, ``'test'``.
        regular: build coordinates from the ``'x'``/``'y'`` axes (True) or
            read them from ``'coords'`` (False).
        load_all: copy the split into memory and close the file.
        samples: number of query points drawn per item in ``'train'`` mode.
    """

    def __init__(self,
                 path,
                 nt,
                 res,
                 mode='train',
                 regular=True,
                 load_all=False,
                 samples=256):

        assert mode in ['train', 'valid', 'test'], "mode must belong to one of these ['train', 'valid', 'test']"

        f = h5py.File(path, 'r')
        self.mode = mode
        self.data = f[self.mode]
        self.dataset = f'pde_{nt}-{res}'
        self.samples = samples
        self.regular = regular

        if load_all:
            # Copy every array __getitem__ reads; keeping only the solution
            # made the later coordinate/time lookups fail with KeyError.
            coord_keys = ('x', 'y') if self.regular else ('coords',)
            data = {key: self.data[key][:] for key in (self.dataset, 't') + coord_keys}
            f.close()
            self.data = data

    def __len__(self):
        return self.data[self.dataset].shape[0]

    def __getitem__(self, idx):
        """Return the tensors of trajectory ``idx`` as a dict.

        Keys: ``t``, ``lr_frames`` (T, 1, ceil(N/2)), ``hr_frames`` (T, 1, N),
        ``hr_points`` (T, Q, 1), ``coords_hr`` (Q, 2), ``coords_lr``, and in
        ``'train'`` mode ``sample_idx`` with the Q sampled point indices.
        """
        u_hr = torch.from_numpy(self.data[self.dataset][idx]).unsqueeze(1)
        if self.regular:
            x = self.data['x'][idx]
            y = self.data['y'][idx]
            # NOTE(review): np.meshgrid defaults to 'xy' indexing, so flattened
            # point i*W+j gets (x[j], y[i]) - confirm this matches how the
            # solution array is laid out.
            coords = np.stack(np.meshgrid(x, y), axis=-1)
            coords = coords.reshape(-1, coords.shape[-1])
            # T, 1, W, W -> T, 1, W*W
            u_hr = u_hr.reshape(u_hr.shape[0], 1, -1)
        else:
            coords = self.data['coords'][idx]  # N, 2
        # Rescale each coordinate axis to [-1, 1]
        coords = 2*(coords-coords.min(0))/(coords.max(0)-coords.min(0))-1

        t = self.data['t'][idx]

        N = u_hr.shape[-1]
        # Even points are the observed low-resolution signal
        u_lr = u_hr[:, :, ::2]
        lr_coord = coords[::2]

        # Odd points are the ones left to predict
        query_idx = np.arange(1, N, 2)

        return_tensors = {'t': t}
        if self.mode == 'train':
            query_idx = np.sort(np.random.choice(query_idx, self.samples, replace=False))
            return_tensors['sample_idx'] = torch.from_numpy(query_idx)

        return_tensors['lr_frames'] = u_lr
        return_tensors['hr_frames'] = u_hr
        return_tensors['hr_points'] = u_hr[:, :, torch.from_numpy(query_idx)].permute(0, 2, 1)
        return_tensors['coords_hr'] = coords[query_idx]
        return_tensors['coords_lr'] = lr_coord

        return return_tensors
class HDF5Dataset(Dataset):
    """
    Load samples of a 2D PDE dataset stored in an HDF5 file.
    """
    def __init__(self,
                 path: str,
                 mode: str,
                 nt: int,
                 res: int,
                 dtype=torch.float32,
                 load_all: bool=False):
        """Initialize the dataset object.
        Args:
            path: path to dataset
            mode: [train, valid, test]
            nt: temporal resolution
            res: spatial resolution
            dtype: stored on the instance; tensors are returned in the
                precision they have on disk
            load_all: load all the data into memory
        """
        super().__init__()
        f = h5py.File(path, 'r')
        self.mode = mode
        self.dtype = dtype
        self.data = f[self.mode]
        self.dataset = f'pde_{nt}-{res}'

        if load_all:
            # Copy every array __getitem__ reads; keeping only the solution
            # made the later lookups of 'dx', 'dy' and 'dt' fail with KeyError.
            data = {key: self.data[key][:] for key in (self.dataset, 'dx', 'dy', 'dt')}
            f.close()
            self.data = data

    def __len__(self):
        return self.data[self.dataset].shape[0]

    def __getitem__(self, idx: int):
        """
        Returns data items for batched training/validation/testing.
        Args:
            idx: data index
        Returns:
            torch.Tensor: data trajectory used for training/validation/testing
            torch.Tensor: dx, first entry of the stored 'dx' row
            torch.Tensor: dy, first entry of the stored 'dy' row
            torch.Tensor: dt, first entry of the stored 'dt' row
        """
        u = torch.from_numpy(self.data[self.dataset][idx])
        dx, dy, dt = (torch.from_numpy(self.data[key][idx])[0] for key in ('dx', 'dy', 'dt'))

        return u, dx, dy, dt
class HDF5DatamoduleImplicit(pl.LightningDataModule):
    """Lightning datamodule serving ``HDF5DatasetImplicit`` for each split.

    Args:
        name: identifier stored on the instance.
        train_path, val_path, test_path: HDF5 file of each split.
        nt_train, nx_train, nt_val, nx_val, nt_test, nx_test: temporal and
            spatial resolutions forwarded to the dataset of each split.
        samples: number of sampled points forwarded to the datasets.
        sampling: sampling strategy forwarded to the datasets.
        num_workers: worker processes per dataloader.
        batch_size: batch size of every dataloader.
    """

    def __init__(
            self,
            name='h5_datamodule_implicit',
            train_path="/content/drive/MyDrive/MILA/snapshots.h5",
            val_path="/content/drive/MyDrive/MILA/snapshots.h5",
            test_path="/content/drive/MyDrive/MILA/snapshots.h5",
            nt_train=128,
            nx_train=256,
            nt_val=128,
            nx_val=256,
            nt_test=256,
            nx_test=256,
            samples=32,
            sampling='uniform',
            num_workers=2,
            batch_size=32):
        super().__init__()

        self.save_hyperparameters()

        self.name = name
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.nt_train = nt_train
        self.nx_train = nx_train
        self.sampling = sampling
        self.nt_val = nt_val
        self.nx_val = nx_val
        self.nt_test = nt_test
        self.nx_test = nx_test
        self.samples = samples

        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Instantiate the train, validation and test datasets."""
        self.train_dataset = HDF5DatasetImplicit(
            path=self.train_path,
            mode='train',
            sampling=self.sampling,
            nt=self.nt_train,
            nx=self.nx_train,
            samples=self.samples)

        self.val_dataset = HDF5DatasetImplicit(
            path=self.val_path,
            mode='valid',
            sampling=self.sampling,
            nt=self.nt_val,
            nx=self.nx_val,
            samples=self.samples)

        self.test_dataset = HDF5DatasetImplicit(
            path=self.test_path,
            mode='test',
            sampling=self.sampling,
            nt=self.nt_test,
            nx=self.nx_test,
            samples=self.samples)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader with the module's batch size and worker count."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, pin_memory=True)

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        # Evaluation order is kept fixed so runs are comparable.
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        # Evaluation order is kept fixed so runs are comparable.
        return self._loader(self.test_dataset, shuffle=False)
class HDF5DatamoduleImplicitGNN(pl.LightningDataModule):
    """Lightning datamodule serving ``HDF5DatasetImplicitGNN`` for each split.

    Args:
        name: identifier stored on the instance.
        train_path, val_path, test_path: HDF5 file of each split.
        nt_train, nx_train, nt_val, nx_val, nt_test, nx_test: temporal and
            spatial resolutions forwarded to the dataset of each split.
        samples: number of sampled points forwarded to the datasets.
        sampling: sampling strategy forwarded to the datasets.
        num_workers: worker processes per dataloader.
        batch_size: batch size of every dataloader.
    """

    def __init__(
            self,
            name='h5_datamodule_implicit_gnn',
            train_path="/content/drive/MyDrive/MILA/snapshots.h5",
            val_path="/content/drive/MyDrive/MILA/snapshots.h5",
            test_path="/content/drive/MyDrive/MILA/snapshots.h5",
            nt_train=128,
            nx_train=256,
            nt_val=128,
            nx_val=256,
            nt_test=256,
            nx_test=256,
            samples=32,
            sampling='uniform',
            num_workers=2,
            batch_size=32):
        super().__init__()

        self.save_hyperparameters()

        self.name = name
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.nt_train = nt_train
        self.nx_train = nx_train
        self.nt_val = nt_val
        self.nx_val = nx_val
        self.nt_test = nt_test
        self.nx_test = nx_test
        self.samples = samples
        self.sampling = sampling

        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Instantiate the train, validation and test datasets."""
        self.train_dataset = HDF5DatasetImplicitGNN(
            path=self.train_path,
            nt=self.nt_train,
            nx=self.nx_train,
            sampling=self.sampling,
            mode='train',
            samples=self.samples)

        self.val_dataset = HDF5DatasetImplicitGNN(
            path=self.val_path,
            nt=self.nt_val,
            nx=self.nx_val,
            sampling=self.sampling,
            mode='valid',
            samples=self.samples)

        self.test_dataset = HDF5DatasetImplicitGNN(
            path=self.test_path,
            nt=self.nt_test,
            nx=self.nx_test,
            sampling=self.sampling,
            mode='test',
            samples=self.samples)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader with the module's batch size and worker count."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, pin_memory=True)

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        # Evaluation order is kept fixed so runs are comparable.
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        # Evaluation order is kept fixed so runs are comparable.
        return self._loader(self.test_dataset, shuffle=False)
class HDF5DatamoduleImplicit_2d(pl.LightningDataModule):
    """Lightning datamodule serving ``HDF5DatasetImplicit_2d`` for each split.

    Args:
        name: identifier stored on the instance.
        train_path, val_path, test_path: HDF5 file of each split.
        nt_train, res_train, nt_val, res_val, nt_test, res_test: temporal and
            spatial resolutions forwarded to the dataset of each split.
        samples: number of sampled points forwarded to the datasets.
        num_workers: worker processes per dataloader.
        batch_size: batch size of every dataloader.
    """

    def __init__(
            self,
            name='h5_datamodule_implicit_2d',
            train_path="/content/drive/MyDrive/MILA/snapshots.h5",
            val_path="/content/drive/MyDrive/MILA/snapshots.h5",
            test_path="/content/drive/MyDrive/MILA/snapshots.h5",
            nt_train=128,
            res_train=256,
            nt_val=128,
            res_val=256,
            nt_test=256,
            res_test=256,
            samples=32,
            num_workers=2,
            batch_size=32):
        super().__init__()

        self.save_hyperparameters()

        self.name = name
        self.train_path = train_path
        self.val_path = val_path
        self.test_path = test_path
        self.nt_train = nt_train
        self.res_train = res_train
        self.nt_val = nt_val
        self.res_val = res_val
        self.nt_test = nt_test
        self.res_test = res_test
        self.samples = samples

        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Instantiate the train, validation and test datasets."""
        self.train_dataset = HDF5DatasetImplicit_2d(
            path=self.train_path,
            mode='train',
            nt=self.nt_train,
            res=self.res_train,
            samples=self.samples)

        # NOTE(review): validation reads the 'test' split, not 'valid' -
        # presumably the 2D files ship no 'valid' group; confirm.
        self.val_dataset = HDF5DatasetImplicit_2d(
            path=self.val_path,
            mode='test',
            nt=self.nt_val,
            res=self.res_val,
            samples=self.samples)

        self.test_dataset = HDF5DatasetImplicit_2d(
            path=self.test_path,
            mode='test',
            nt=self.nt_test,
            res=self.res_test,
            samples=self.samples)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader with the module's batch size and worker count."""
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, pin_memory=True)

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        # Evaluation order is kept fixed so runs are comparable.
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        # Evaluation order is kept fixed so runs are comparable.
        return self._loader(self.test_dataset, shuffle=False)
res_test=256, 158 | train_regular=True, 159 | val_regular=True, 160 | test_regular=True, 161 | num_workers=2, 162 | batch_size=32): 163 | super().__init__() 164 | 165 | self.save_hyperparameters() 166 | 167 | self.name = name 168 | self.train_path = train_path 169 | self.val_path = val_path 170 | self.test_path = test_path 171 | self.nt_train = nt_train 172 | self.res_train = res_train 173 | self.nt_val = nt_val 174 | self.res_val = res_val 175 | self.nt_test = nt_test 176 | self.res_test = res_test 177 | self.train_regular = train_regular 178 | self.val_regular = val_regular 179 | self.test_regular = test_regular 180 | 181 | 182 | self.batch_size = batch_size 183 | self.num_workers = num_workers 184 | 185 | def setup(self, stage = None): 186 | 187 | self.train_dataset = HDF5DatasetGraph_2d( 188 | path=self.train_path, 189 | mode='train', 190 | regular=self.train_regular, 191 | nt=self.nt_train, 192 | res=self.res_train) 193 | 194 | self.val_dataset = HDF5DatasetGraph_2d( 195 | path=self.val_path, 196 | mode='test', 197 | regular=self.val_regular, 198 | nt=self.nt_val, 199 | res=self.res_val) 200 | 201 | self.test_dataset = HDF5DatasetGraph_2d( 202 | path=self.test_path, 203 | mode='test', 204 | regular=self.test_regular, 205 | nt=self.nt_test, 206 | res=self.res_test) 207 | 208 | def train_dataloader(self): 209 | return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, pin_memory=True) 210 | 211 | def val_dataloader(self): 212 | return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, pin_memory=True) 213 | 214 | def test_dataloader(self): 215 | return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, pin_memory=True) 216 | 217 | 218 | class HDF5DatamoduleImplicitGNN_2d(pl.LightningDataModule): 219 | def __init__( 220 | self, 221 | name='h5_datamodule_implicit_gnn', 222 | 
def setup(self, stage=None):
    """Create the train, validation and test ``HDF5DatasetImplicitGNN_2d``.

    The splits are built in train -> val -> test order. Only the training
    split uses ``mode='train'``; the other two use ``mode='test'``. All three
    share ``self.samples``. ``stage`` is accepted but not used.
    """
    # (attribute to fill, path, nt, res, mode, regular)
    split_specs = (
        ('train_dataset', self.train_path, self.nt_train, self.res_train,
         'train', self.train_regular),
        ('val_dataset', self.val_path, self.nt_val, self.res_val,
         'test', self.val_regular),
        ('test_dataset', self.test_path, self.nt_test, self.res_test,
         'test', self.test_regular),
    )
    for attr_name, path, nt, res, mode, regular in split_specs:
        dataset = HDF5DatasetImplicitGNN_2d(
            path=path,
            nt=nt,
            res=res,
            mode=mode,
            regular=regular,
            samples=self.samples)
        setattr(self, attr_name, dataset)
def __init__(self, hparams):
    """Build the model from a hyper-parameter namespace.

    Args:
        hparams: object exposing the attributes read below (``lr``,
            ``weight_decay``, ``factor``, ``step_size``, ``loss``,
            ``time_slice``, ``use_lstm``, ``lstm_hidden``, ``lstm_layers``,
            ``mlp_layers``, ``mlp_hidden``, ``scales``, ``teacher_forcing``,
            ``res_layers``, ``n_chan``, ``kernel_size``, ``res_scale``,
            ``interpolation``).

    Raises:
        ValueError: if ``hparams.loss`` is not ``'l1'``, ``'l2'`` or
            ``'smooth_l1'``. Previously an unknown value left
            ``self.criterion`` undefined and only failed at the first
            training step.
    """
    super().__init__()

    self.save_hyperparameters()

    # Training parameters
    self.lr = hparams.lr
    self.weight_decay = hparams.weight_decay
    self.factor = hparams.factor
    self.step_size = hparams.step_size
    self.loss = hparams.loss
    # Model parameters
    self.time_slice = hparams.time_slice
    self.use_lstm = hparams.use_lstm
    self.lstm_hidden = hparams.lstm_hidden
    self.lstm_layers = hparams.lstm_layers
    self.mlp_layers = hparams.mlp_layers
    self.mlp_hidden = hparams.mlp_hidden
    self.scales = hparams.scales
    self.teacher_forcing = hparams.teacher_forcing
    self.res_layers = hparams.res_layers
    self.n_chan = hparams.n_chan
    self.kernel_size = hparams.kernel_size
    self.res_scale = hparams.res_scale
    self.interpolation = hparams.interpolation

    criteria = {
        'l1': nn.L1Loss,
        'l2': nn.MSELoss,
        'smooth_l1': nn.SmoothL1Loss,
    }
    if self.loss not in criteria:
        raise ValueError(
            f"Unsupported loss {self.loss!r}; expected one of {sorted(criteria)}")
    self.criterion = criteria[self.loss]()

    self.mse_criterion = nn.MSELoss()
    self.mae_criterion = nn.L1Loss()

    self.encoder = EDSR(
        res_layers=self.res_layers,
        n_chan=self.n_chan,
        kernel_size=self.kernel_size,
        res_scale=self.res_scale,
        mode="1d",
        in_chan=self.time_slice)

    # Input width: encoder features + 3 + 1 extra scalars + previous latent.
    self.proj_head = nn.Linear(self.encoder.out_dim+3+1+self.lstm_hidden, self.lstm_hidden)

    if self.use_lstm:
        # Encoder input is the latent plus 2 positional-encoding channels.
        self.lstm_encoder = nn.LSTM(2+self.lstm_hidden, self.lstm_hidden, self.lstm_layers, batch_first=True)
        # Decoder input is [previous output, attention context].
        self.lstm_decoder = nn.LSTM(2*self.lstm_hidden, self.lstm_hidden, self.lstm_layers, batch_first=True)
        self.attn = nn.Sequential(
            nn.Linear(3*self.lstm_hidden, self.lstm_hidden),
            nn.Tanh(),
            nn.Linear(self.lstm_hidden, 1, bias=False))

        self.layernorm = nn.LayerNorm(self.lstm_hidden)

    # The decoder head was configured identically in both branches of the
    # original ``use_lstm`` conditional, so it is built once here.
    self.decoder = MLP(
        in_dim=self.lstm_hidden,
        hidden_list=[self.mlp_hidden]*self.mlp_layers,
        out_dim=1
    )
def pos_encoding(self, coords):
    '''
    Sine/cosine encoding of coordinates at a single frequency.

    Args:
        coords: tensor whose last dimension is 1 (it is multiplied by a
            1x1 identity matrix).

    Returns:
        Tensor with the same leading dimensions and a last dimension of 2:
        ``[sin(2*pi*coords), cos(2*pi*coords)]``.
    '''
    identity = torch.eye(1).to(coords.device)
    scaled = 2. * np.pi * coords
    projected = scaled @ identity
    sin_part = torch.sin(projected)
    cos_part = torch.cos(projected)
    return torch.cat((sin_part, cos_part), dim=-1)
def forward(
    self,
    x_t,
    coords,
    cell,
    t,
    hr_last,
    hiddens=None):
    '''
    Predict the values at ``coords`` for the time steps that follow ``x_t``.

    Args:
        x_t: input frames; the first two dimensions are read as (B, T).
            ``continuous_decoder`` documents this input as [B, T, C, L].
        coords: query coordinates; dimension 1 is the number of points N.
        cell: per-point cell sizes, passed through to ``continuous_decoder``.
        t: time coordinates; the last dimension holds the T input steps
            followed by the steps to predict.
        hr_last: values at ``coords`` for the last input step; each
            prediction is ``hr_last + delta_t * decoded``.
        hiddens: accepted for interface compatibility; not used.

    Returns:
        ``(outputs, None)`` where ``outputs`` stacks one prediction per
        future step along dimension 1.

    Raises:
        NotImplementedError: if the model was built with ``use_lstm=False``.
            That path never produced a decoded sequence and used to fail
            with an UnboundLocalError further down.
    '''
    if not self.use_lstm:
        raise NotImplementedError(
            "MAgNetCNN_no_interaction.forward requires use_lstm=True")

    B, T = x_t.shape[:2]
    N_coords = coords.shape[1]
    T_out = t.shape[-1] - T

    # Sum the per-scale latents produced by the continuous decoder.
    z = 0
    for s in range(1, self.scales+1):
        feat, x_lr = self.feature_encoding(x_t, scale=s)
        z += self.continuous_decoder(x_lr, feat, cell, coords, t)
    # Append the positional encoding of every query point to each time step.
    z = torch.cat([z, self.pos_encoding(coords).reshape(B*N_coords, -1).unsqueeze(1).repeat(1,T,1)], dim=-1)

    out, _ = self.seq2seq_attention(z, future_step=T_out)
    ret = self.layernorm(out)
    ret = self.decoder(ret)

    outputs = []
    tt = t.unsqueeze(1).repeat(1,cell.shape[1],1)

    for i in range(T_out):
        # Time elapsed between the last input step and future step i.
        delta_t = tt[:,:,T+i:T+i+1]-tt[:,:,T-1:T]
        op = ret[:,i].view(B, N_coords, -1)
        outputs.append(hr_last+delta_t*op)
    outputs = torch.stack(outputs, dim=1)
    return outputs, None
def validation_step(self, val_batch, batch_idx):
    """Roll the model out autoregressively and log validation losses.

    The first ``time_slice`` frames seed the rollout. Each later window is
    predicted from the previous prediction (no teacher forcing). Logs
    ``val_loss`` (configured criterion) and ``val_mae_loss``.
    """
    window = self.time_slice
    t, u, u_values, coords, cells = (
        val_batch[key].float()
        for key in ('t', 'hr_frames', 'hr_points', 'coords', 'cells')
    )

    target = u_values[:, window:]  # B, T_future, N, 1
    n_windows = target.shape[1] // window

    frames = u[:, :window]
    hr_last = u_values[:, window - 1]
    predictions = []
    for w in range(n_windows):
        times = t[:, w * window:(w + 2) * window]
        y_hat, _ = self.forward(frames, coords, cells, times, hr_last)
        predictions.append(y_hat)

        # Feed the prediction back in as the next input window.
        frames = y_hat.permute(0, 1, 3, 2)
        hr_last = y_hat[:, -1]

    u_values_hat = torch.cat(predictions, dim=1)
    loss = self.criterion(u_values_hat, target)
    mae_loss = self.mae_criterion(u_values_hat, target)

    self.log('val_loss', loss, prog_bar=True)
    self.log('val_mae_loss', mae_loss, prog_bar=True)
class Swish(nn.Module):
    """
    Swish activation: ``x * sigmoid(beta * x)`` with a fixed (non-learned) ``beta``.
    """
    def __init__(self, beta=1):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        gate = torch.sigmoid(self.beta * x)
        return x * gate
def __init__(self, hparams):
    """Build the message-passing model from a hyper-parameter namespace.

    Args:
        hparams: object exposing ``lr``, ``weight_decay``, ``factor``,
            ``step_size``, ``loss``, ``time_window``, ``hidden_features``,
            ``hidden_layer``, ``teacher_forcing`` and ``neighbors``.

    Raises:
        ValueError: if ``hparams.time_window`` has no decoder configuration
            (supported: 10, 16, 20, 25, 50) or ``hparams.loss`` is not
            ``'l1'``, ``'l2'`` or ``'smooth_l1'``. Previously these cases
            left ``output_mlp`` / ``criterion`` undefined and only failed
            later with an AttributeError.
    """
    super().__init__()

    self.save_hyperparameters()

    # Training parameters
    self.lr = hparams.lr
    self.weight_decay = hparams.weight_decay
    self.factor = hparams.factor
    self.step_size = hparams.step_size
    self.loss = hparams.loss
    # Model parameters
    self.out_features = hparams.time_window
    self.hidden_features = hparams.hidden_features
    self.hidden_layer = hparams.hidden_layer
    self.time_window = hparams.time_window
    self.teacher_forcing = hparams.teacher_forcing
    self.n = hparams.neighbors

    def make_layer():
        # Every processor layer maps hidden -> hidden; the single extra
        # variable is the normalised time.
        return GNN_Layer(
            in_features=self.hidden_features,
            hidden_features=self.hidden_features,
            out_features=self.hidden_features,
            time_window=self.time_window,
            n_variables=1)

    self.gnn_layers = torch.nn.ModuleList(
        modules=(make_layer() for _ in range(self.hidden_layer - 1)))
    # The last layer keeps a fixed output size for the 1D-CNN decoder.
    self.gnn_layers.append(make_layer())

    self.embedding_mlp = nn.Sequential(
        nn.Linear(self.time_window + 2, self.hidden_features),
        Swish(),
        nn.Linear(self.hidden_features, self.hidden_features),
        Swish()
    )

    # Decoder CNN (temporal bundling), keyed by time_window:
    # (first kernel, first stride, second kernel, Swish in between).
    # NOTE(review): the time_window == 10 decoder has no activation between
    # its two convolutions, unlike every other entry; kept as-is so the
    # module layout (and checkpoint keys) does not change. Confirm whether
    # that is intentional.
    decoder_specs = {
        10: (16, 6, 10, False),
        16: (16, 5, 8, True),
        20: (15, 4, 10, True),
        25: (16, 3, 14, True),
        50: (12, 2, 10, True),
    }
    if self.time_window not in decoder_specs:
        raise ValueError(
            f"Unsupported time_window {self.time_window!r}; "
            f"expected one of {sorted(decoder_specs)}")
    kernel_1, stride_1, kernel_2, with_activation = decoder_specs[self.time_window]
    decoder_layers = [nn.Conv1d(1, 8, kernel_1, stride=stride_1)]
    if with_activation:
        decoder_layers.append(Swish())
    decoder_layers.append(nn.Conv1d(8, 1, kernel_2, stride=1))
    self.output_mlp = nn.Sequential(*decoder_layers)

    criteria = {
        'l1': nn.L1Loss,
        'l2': nn.MSELoss,
        'smooth_l1': nn.SmoothL1Loss,
    }
    if self.loss not in criteria:
        raise ValueError(
            f"Unsupported loss {self.loss!r}; expected one of {sorted(criteria)}")
    self.criterion = criteria[self.loss]()

    self.mse_criterion = nn.MSELoss()
    self.mae_criterion = nn.L1Loss()
def _build_graph(self,
                 data: torch.Tensor,
                 t: torch.Tensor,
                 x: torch.Tensor,
                 steps: list):
    """
    Pack a batch of trajectories into one disconnected graph.

    Args:
        data: node signals, indexed [B, T, N]; becomes node features [B*N, T].
        t: time stamps, indexed as ``t[b, step]``.
        x: node coordinates; only ``x[0]`` is used and is shared by every
            sample in the batch.
        steps: one time index per sample, selecting the time stamp stored
            with that sample's nodes.

    Returns:
        A ``Data`` object with ``x`` (node features), ``edge_index`` (radius
        graph within each sample), ``pos`` (columns: time, coordinate) and
        ``batch`` (sample index of every node).
    """
    nx = data.shape[-1]
    # The original implementation zipped ``data`` with ``steps`` and so
    # stopped at the shorter of the two.
    n_graphs = min(data.shape[0], len(steps))

    # [B, T, N] -> [B*N, T], nodes of one sample stored contiguously.
    u = data[:n_graphs].transpose(1, 2).reshape(-1, data.shape[1])
    x_pos = x[0].repeat(n_graphs)

    sample_idx = torch.arange(n_graphs, device=t.device)
    step_idx = torch.as_tensor(list(steps[:n_graphs]), dtype=torch.long, device=t.device)
    t_pos = t[sample_idx, step_idx].to(data.device).repeat_interleave(nx)
    batch = torch.arange(n_graphs, device=data.device).repeat_interleave(nx)

    # Connect every node to its neighbours within ``self.n`` grid spacings;
    # the spacing is taken from the first two coordinates of ``x[0]``.
    dx = x[0][1] - x[0][0]
    radius = self.n * dx + 0.0001
    edge_index = radius_graph(x_pos, r=radius, batch=batch, loop=False)

    graph = Data(x=u, edge_index=edge_index)
    graph.pos = torch.cat((t_pos[:, None], x_pos[:, None]), 1)
    graph.batch = batch

    return graph
class Swish(nn.Module):
    """
    Swish activation function, ``x * sigmoid(beta * x)``.
    """
    def __init__(self, beta=1):
        super().__init__()
        # Fixed slope of the sigmoid gate (stored as a plain attribute).
        self.beta = beta

    def forward(self, x):
        scaled = self.beta * x
        return x * torch.sigmoid(scaled)
def update(self, message, x, variables):
    """
    Node update: feed ``[x, message, variables]`` through the two update
    networks. When the layer keeps the feature width
    (``in_features == out_features``) the result is added to ``x`` as a
    residual; otherwise it is returned directly.
    """
    stacked = torch.cat((x, message, variables), dim=-1)
    delta = self.update_net_2(self.update_net_1(stacked))
    if self.in_features != self.out_features:
        return delta
    return x + delta
def forward(self, data, L, tmax, dt):
    """
    Predict the next ``time_window`` values for every node of ``data``.

    Args:
        data: graph with ``x`` (node signals, last column = latest step),
            ``pos`` (column 0 = time, remaining columns = spatial
            coordinates), ``edge_index`` and ``batch``.
        L: normalisation for the spatial coordinates; a scalar or one value
            per spatial column.
        tmax: normalisation for the time column.
        dt: time-step size (tensor); its device is used for the step grid.

    Returns:
        Tensor of shape [n_nodes, time_window]: the latest signal plus the
        decoded increment scaled by the cumulative time step.
    """
    u = data.x
    # Encode and normalize coordinate information.
    pos = data.pos
    # Keep ALL spatial columns. The previous ``pos[:, 1][:, None] / L``
    # broadcast the x column against a 2-element ``L``, producing
    # (x/Lx, x/Ly) and silently dropping the y coordinate.
    pos_x = pos[:, 1:] / L
    pos_t = pos[:, 0][:, None] / tmax
    edge_index = data.edge_index
    batch = data.batch

    variables = pos_t  # time is treated as equation variable

    # Encoder and processor (message passing)
    node_input = torch.cat((u, pos_x, variables), -1)
    h = self.embedding_mlp(node_input)
    for i in range(self.hidden_layer):
        h = self.gnn_layers[i](h, u, pos_x, variables, edge_index, batch)

    # Decoder (formula 10 in the paper): cumulative time offsets
    # dt, 2*dt, ..., time_window*dt.
    dt = torch.ones(1, self.time_window, device=dt.device) * dt
    dt = torch.cumsum(dt, dim=1)
    # [n_nodes, hidden_dim] -> 1DCNN([n_nodes, 1, hidden_dim]) -> [n_nodes, time_window]
    diff = self.output_mlp(h[:, None]).squeeze(1)
    # Broadcast the latest value over the time_window columns.
    out = u[:, -1:] + dt * diff

    return out
def training_step(self, train_batch, batch_idx):
    """Unroll the model over the target horizon and return the training loss.

    The first ``time_window`` steps of ``u`` seed the rollout. Each later
    window is predicted from the ground-truth window when
    ``teacher_forcing`` is set, otherwise from the previous prediction.
    Logs ``train_loss`` and ``train_mae_loss`` and returns the former.
    """
    tw = self.time_window
    u = train_batch['u'].float().permute(0, 2, 1)
    x = train_batch['x'].float().squeeze(-1)
    t = train_batch['t'].float()  # B, T
    B, _, N = u.shape
    dt = t[0][1] - t[0][0]

    target = u[:, tw:, :]
    n_windows = target.shape[1] // tw

    # Each graph is stamped with the time of the last step of its window.
    graph = self._build_graph(u[:, :tw, :], t, x, steps=[tw - 1] * B)

    predictions = []
    for w in range(n_windows):
        window_hat = self.forward(graph, x[0, -1], t[0, -1], dt)
        window_hat = window_hat.reshape(B, N, -1).permute(0, 2, 1)
        predictions.append(window_hat)

        last_step = (w + 2) * tw - 1
        if self.teacher_forcing:
            next_input = u[:, (w + 1) * tw:last_step + 1, :]
        else:
            next_input = window_hat
        graph = self._build_graph(next_input, t, x, steps=[last_step] * B)

    u_hat = torch.cat(predictions, dim=1)

    loss = self.criterion(u_hat, target)
    mae_loss = self.mae_criterion(u_hat, target)

    self.log('train_loss', loss, prog_bar=True)
    self.log('train_mae_loss', mae_loss, prog_bar=True)

    return loss
validation_step(self, val_batch, batch_idx): 300 | u = val_batch['u'].float().permute(0,2,1) 301 | x = val_batch['x'].float().squeeze(-1) 302 | B, T_in, N = u.shape 303 | t = val_batch['t'].float() # B, T 304 | dt = t[0][1] - t[0][0] 305 | 306 | graph = self._build_graph( 307 | u[:,:self.time_window,:], 308 | t, 309 | x, 310 | steps=[self.time_window-1]*B) 311 | 312 | target = u[:,self.time_window:,:] 313 | T_out = target.shape[1] 314 | 315 | u_hat = [] 316 | for i in range(T_out//self.time_window): 317 | y_hat = self.forward(graph, x[0,-1], t[0,-1], dt) 318 | y_hat = y_hat.reshape(B, N, -1).permute(0,2,1) 319 | u_hat.append(y_hat) 320 | 321 | graph = self._build_graph( 322 | y_hat, 323 | t, 324 | x, 325 | steps=[(i+2)*self.time_window-1]*B) 326 | 327 | u_hat = torch.cat(u_hat, dim=1) 328 | 329 | loss = self.criterion(u_hat, target) 330 | mae_loss = self.mae_criterion(u_hat, target) 331 | 332 | self.log('val_loss', loss, prog_bar=True) 333 | self.log('val_mae_loss', mae_loss, prog_bar=True) 334 | -------------------------------------------------------------------------------- /models/magnet_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | import pytorch_lightning as pl 6 | 7 | from torch_geometric.nn import MessagePassing, radius_graph 8 | 9 | from models.backbones.edsr import EDSR 10 | from models.backbones.mlp import MLP 11 | from utils import * 12 | 13 | class Encoder(nn.Module): 14 | def __init__( 15 | self, 16 | node_in, 17 | node_out, 18 | edge_in, 19 | edge_out, 20 | mlp_layers, 21 | mlp_hidden, 22 | ): 23 | super(Encoder, self).__init__() 24 | self.node_fn = nn.Sequential( 25 | MLP( 26 | in_dim=node_in, 27 | hidden_list=[mlp_hidden]*mlp_layers, 28 | out_dim=node_out), 29 | nn.LayerNorm(node_out) 30 | ) 31 | self.edge_fn = nn.Sequential( 32 | MLP( 33 | in_dim=edge_in, 34 | hidden_list=[mlp_hidden]*mlp_layers, 35 | out_dim=edge_out 36 
class InteractionNetwork(MessagePassing):
    """Message-passing layer that updates node and edge features with residuals.

    For every edge, ``edge_fn`` maps ``[x_i, x_j, e]`` to a new edge feature.
    The new edge features are mean-aggregated per receiving node and
    ``node_fn`` maps ``[aggregated, x]`` to a new node feature. Both outputs
    are added to their inputs, so ``node_out`` must equal ``node_in`` and
    ``edge_out`` must equal ``edge_in``.

    ``propagate`` feeds ``update`` from the keyword arguments it was called
    with, not from the output of ``message``. The per-edge messages are
    therefore carried through ``aggregate`` so that ``forward`` returns the
    updated edge features rather than the unchanged inputs.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        mlp_layers,
        mlp_hidden,
    ):
        super(InteractionNetwork, self).__init__(aggr='mean')
        self.node_fn = nn.Sequential(
            MLP(
                in_dim=node_in+edge_out,
                hidden_list=[mlp_hidden]*mlp_layers,
                out_dim=node_out),
            nn.LayerNorm(node_out)
        )
        self.edge_fn = nn.Sequential(
            MLP(
                in_dim=node_in+node_in+edge_in,
                hidden_list=[mlp_hidden]*mlp_layers,
                out_dim=edge_out
            ),
            nn.LayerNorm(edge_out)
        )

    def forward(self, x, edge_index, e_features):
        # x: (n_nodes, node_in)
        # edge_index: (2, E)
        # e_features: (E, edge_in)
        x_residual = x
        e_features_residual = e_features
        x, e_features = self.propagate(edge_index=edge_index, x=x, e_features=e_features)
        return x+x_residual, e_features+e_features_residual

    def message(self, edge_index, x_i, x_j, e_features):
        # Per-edge update computed from both endpoints and the current edge feature.
        return self.edge_fn(torch.cat([x_i, x_j, e_features], dim=-1))

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        # Keep the per-edge messages next to their pooled version; `update`
        # would otherwise never see them.
        pooled = super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        return inputs, pooled

    def update(self, aggr_out, x):
        # aggr_out: (messages (E, edge_out), pooled messages (n_nodes, edge_out))
        # x: (n_nodes, node_in)
        e_updated, pooled = aggr_out
        x_updated = self.node_fn(torch.cat([pooled, x], dim=-1))
        return x_updated, e_updated
def __init__(self, hparams):
    """Build the MAgNet-CNN model from a hyper-parameter container.

    Args:
        hparams: object exposing the attributes read below (optimizer
            settings, loss name, and architecture sizes).

    Raises:
        ValueError: if ``hparams.loss`` is not ``'l1'``, ``'l2'`` or
            ``'smooth_l1'``. Without this check ``self.criterion`` would be
            left undefined and the first training/validation step would fail
            with an ``AttributeError``.
    """
    super().__init__()

    self.save_hyperparameters()

    # Training parameters
    self.lr = hparams.lr
    self.weight_decay = hparams.weight_decay
    self.factor = hparams.factor
    self.step_size = hparams.step_size
    self.loss = hparams.loss
    # Model parameters
    self.time_slice = hparams.time_slice
    self.num_message_passing_steps = hparams.num_message_passing_steps
    self.latent_dim = hparams.latent_dim
    self.mlp_layers = hparams.mlp_layers
    self.mlp_hidden = hparams.mlp_hidden
    self.scales = hparams.scales
    self.res_layers = hparams.res_layers
    self.n_chan = hparams.n_chan
    self.kernel_size = hparams.kernel_size
    self.res_scale = hparams.res_scale
    self.interpolation = hparams.interpolation
    self.radius = hparams.radius
    self.teacher_forcing = hparams.teacher_forcing

    # Training criterion selected by name; fail fast on an unknown name.
    loss_factories = {
        'l1': nn.L1Loss,
        'l2': nn.MSELoss,
        'smooth_l1': nn.SmoothL1Loss,
    }
    if self.loss not in loss_factories:
        raise ValueError(
            f"Unknown loss '{self.loss}'; expected one of {sorted(loss_factories)}.")
    self.criterion = loss_factories[self.loss]()

    self.mse_criterion = nn.MSELoss()
    self.mae_criterion = nn.L1Loss()

    # Convolutional encoder over the low-resolution frames.
    self.encoder = EDSR(
        res_layers=self.res_layers,
        n_chan=self.n_chan,
        kernel_size=self.kernel_size,
        res_scale=self.res_scale,
        mode="1d",
        in_chan=self.time_slice)

    # Heads used by the continuous decoder to produce values at query points.
    self.proj_head = nn.Sequential(
        MLP(
            in_dim=self.encoder.out_dim+3+1,
            hidden_list=[self.mlp_hidden]*self.mlp_layers,
            out_dim=self.n_chan),
        nn.LayerNorm(self.n_chan)
    )
    self.projector = MLP(
        in_dim=self.n_chan,
        hidden_list=[self.mlp_hidden]*self.mlp_layers,
        out_dim=1)

    # Encode-process-decode graph network run over all points.
    self._encoder = Encoder(
        node_in=self.time_slice+2,
        node_out=self.latent_dim,
        edge_in=self.time_slice+1,
        edge_out=self.latent_dim,
        mlp_layers=self.mlp_layers,
        mlp_hidden=self.mlp_hidden,
    )
    self._processor = Processor(
        node_in=self.latent_dim,
        node_out=self.latent_dim,
        edge_in=self.latent_dim,
        edge_out=self.latent_dim,
        num_message_passing_steps=self.num_message_passing_steps,
        mlp_num_layers=self.mlp_layers,
        mlp_hidden_dim=self.mlp_hidden,
    )
    self._decoder = Decoder(
        node_in=self.latent_dim,
        node_out=self.time_slice,
        mlp_layers=self.mlp_layers,
        mlp_hidden=self.mlp_hidden,
    )
def _build_graph(self, u, x, t):
    """Build the radius graph over all points of a batch and its input features.

    Args:
        u: [B, N, F] per-point values.
        x: point coordinates, N per sample; flattened to [B*N, -1].
        t: [B, T] time stamps; only the last column is used.

    Returns:
        tuple: ``(node_features, edge_index, edge_features)`` where
        node features are ``[u, x, t_last]`` per point, ``edge_index`` is
        [2, E], and edge features are the sender-minus-receiver differences
        of ``u`` and ``x``.
    """
    B, N, _ = u.shape

    u_ = u.reshape(B*N, -1)
    x_ = x.reshape(B*N, -1)

    # Sample id of every flattened point: 0,...,0,1,...,1,... (N of each).
    batch_ids = torch.arange(B, device=x_.device).repeat_interleave(N)
    edges = radius_graph(x_, batch=batch_ids, r=self.radius, loop=True)  # (2, n_edges)
    receivers = edges[0, :]
    senders = edges[1, :]
    edge_index = torch.stack([senders, receivers])

    # Points are flattened batch-major, so each sample's last time stamp must
    # be repeated N times in a row. `repeat(N, 1)` would tile whole batches
    # and attach the wrong sample's time to most nodes.
    t_last = t[:, -1:].repeat_interleave(N, dim=0)
    node_features = torch.cat([u_, x_, t_last], dim=-1)

    edge_features = torch.cat(
        [u_[senders]-u_[receivers], x_[senders]-x_[receivers]],
        dim=-1)

    return node_features, edge_index, edge_features
def training_step(self, train_batch, batch_idx):
    """Run one autoregressive training rollout and return the loss.

    The first ``self.time_slice`` frames seed the rollout. After each
    predicted slice the next input is either the ground truth (teacher
    forcing) or the model's own low-resolution prediction.

    Args:
        train_batch: mapping holding tensors under ``'t'``, ``'lr_frames'``
            (indexed as [B, T, C, L]), ``'hr_points'`` (indexed as
            [B, T, N, 1]), ``'coords'`` and ``'cells'``. No other key is
            required.
        batch_idx: unused.

    Returns:
        The prediction loss plus the interpolation loss on the input slices.
    """
    t = train_batch['t'].float()
    u = train_batch['lr_frames'].float()
    u_values = train_batch['hr_points'].float()
    coords = train_batch['coords'].float()
    cells = train_batch['cells'].float()
    ts = self.time_slice

    u_values_future = u_values[:, ts:]  # B, T_future, N, 1
    T_future = u_values_future.shape[1]

    u_values_hat = []
    hr_values_hat = []

    inp = u[:, :ts]
    hr_last = u_values[:, ts-1]

    for i in range(T_future//ts):
        # The time slice covers the current input window and the one to predict.
        out_hr, out_lr, hr_points = self.forward(
            inp, coords, cells, t[:, i*ts:(i+2)*ts], hr_last)
        u_values_hat.append(torch.cat([out_hr, out_lr], dim=2))
        hr_values_hat.append(hr_points)

        if self.teacher_forcing:
            inp = u[:, (i+1)*ts:(i+2)*ts]  # B, T, C, L
            hr_last = u_values[:, (i+2)*ts-1]
        else:
            inp = out_lr.permute(0, 1, 3, 2)
            hr_last = out_hr[:, -1]

    u_values_hat = torch.cat(u_values_hat, dim=1)  # B, T_out, (N+L), 1
    hr_values_hat = torch.cat(hr_values_hat, dim=1)  # B, T_in, N, 1

    target = torch.cat([u_values_future, u[:, ts:].permute(0, 1, 3, 2)], dim=2)
    # The interpolated values are compared with the frames used as inputs.
    interp_target = u_values[:, :-ts]

    loss = self.criterion(u_values_hat, target)+self.criterion(hr_values_hat, interp_target)
    mae_loss = self.mae_criterion(u_values_hat, target)
    interp_loss = self.mae_criterion(hr_values_hat, interp_target)

    self.log('train_loss', loss, prog_bar=True)
    self.log('train_mae_loss', mae_loss, prog_bar=True)
    self.log('train_interp_loss', interp_loss, prog_bar=True)

    return loss
class InteractionNetwork(MessagePassing):
    """Message-passing layer that updates node and edge features with residuals.

    For every edge, ``edge_fn`` maps ``[x_i, x_j, e]`` to a new edge feature.
    The new edge features are mean-aggregated per receiving node and
    ``node_fn`` maps ``[aggregated, x]`` to a new node feature. Both outputs
    are added to their inputs, so ``node_out`` must equal ``node_in`` and
    ``edge_out`` must equal ``edge_in``.

    ``propagate`` feeds ``update`` from the keyword arguments it was called
    with, not from the output of ``message``. The per-edge messages are
    therefore carried through ``aggregate`` so that ``forward`` returns the
    updated edge features rather than the unchanged inputs.
    """

    def __init__(
        self,
        node_in,
        node_out,
        edge_in,
        edge_out,
        mlp_layers,
        mlp_hidden,
    ):
        super(InteractionNetwork, self).__init__(aggr='mean')
        self.node_fn = nn.Sequential(
            MLP(
                in_dim=node_in+edge_out,
                hidden_list=[mlp_hidden]*mlp_layers,
                out_dim=node_out),
            nn.LayerNorm(node_out))
        self.edge_fn = nn.Sequential(
            MLP(
                in_dim=node_in+node_in+edge_in,
                hidden_list=[mlp_hidden]*mlp_layers,
                out_dim=edge_out
            ),
            nn.LayerNorm(edge_out)
        )

    def forward(self, x, edge_index, e_features):
        # x: (n_nodes, node_in)
        # edge_index: (2, E)
        # e_features: (E, edge_in)
        x_residual = x
        e_features_residual = e_features
        x, e_features = self.propagate(edge_index=edge_index, x=x, e_features=e_features)
        return x+x_residual, e_features+e_features_residual

    def message(self, edge_index, x_i, x_j, e_features):
        # Per-edge update computed from both endpoints and the current edge feature.
        return self.edge_fn(torch.cat([x_i, x_j, e_features], dim=-1))

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        # Keep the per-edge messages next to their pooled version; `update`
        # would otherwise never see them.
        pooled = super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
        return inputs, pooled

    def update(self, aggr_out, x):
        # aggr_out: (messages (E, edge_out), pooled messages (n_nodes, edge_out))
        # x: (n_nodes, node_in)
        e_updated, pooled = aggr_out
        x_updated = self.node_fn(torch.cat([pooled, x], dim=-1))
        return x_updated, e_updated
121 | self, 122 | node_in, 123 | node_out, 124 | mlp_layers, 125 | mlp_hidden, 126 | ): 127 | super(Decoder, self).__init__() 128 | 129 | self.node_fn = MLP( 130 | in_dim=node_in, 131 | hidden_list=[mlp_hidden]*mlp_layers, 132 | out_dim=node_out) 133 | 134 | 135 | def forward(self, x): 136 | # x: (E, node_in) 137 | return self.node_fn(x) 138 | 139 | class MAgNetGNN(pl.LightningModule): 140 | def __init__(self,hparams): 141 | 142 | super().__init__() 143 | 144 | self.save_hyperparameters() 145 | 146 | # Training parameters 147 | self.lr = hparams.lr 148 | self.weight_decay = hparams.weight_decay 149 | self.factor = hparams.factor 150 | self.step_size = hparams.step_size 151 | self.loss = hparams.loss 152 | # Model parameters 153 | self.time_slice = hparams.time_slice 154 | self.num_message_passing_steps = hparams.num_message_passing_steps 155 | self.latent_dim = hparams.latent_dim 156 | self.mlp_layers = hparams.mlp_layers 157 | self.mlp_hidden = hparams.mlp_hidden 158 | self.n_chan = hparams.n_chan 159 | self.radius = hparams.radius 160 | self.codec_neighbors = hparams.codec_neighbors 161 | self.teacher_forcing = hparams.teacher_forcing 162 | self.noise = hparams.noise 163 | self.interpolation = hparams.interpolation 164 | 165 | if self.loss == 'l1': 166 | self.criterion = nn.L1Loss() 167 | elif self.loss == 'l2': 168 | self.criterion = nn.MSELoss() 169 | elif self.loss == 'smooth_l1': 170 | self.criterion = nn.SmoothL1Loss() 171 | 172 | self.mse_criterion = nn.MSELoss() 173 | self.mae_criterion = nn.L1Loss() 174 | 175 | self.encoder = Encoder( 176 | node_in=self.time_slice+3, 177 | node_out=self.latent_dim, 178 | edge_in=self.time_slice+2, 179 | edge_out=self.latent_dim, 180 | mlp_layers=self.mlp_layers, 181 | mlp_hidden=self.mlp_hidden, 182 | ) 183 | self.processor = Processor( 184 | node_in=self.latent_dim, 185 | node_out=self.latent_dim, 186 | edge_in=self.latent_dim, 187 | edge_out=self.latent_dim, 188 | 
def continuous_decoder(
        self,
        x_lr,
        lr_encoded,
        lr_coords,
        hr_coords,
        t):
    '''
    Interpolate a latent code at every query point from its nearest
    low-resolution nodes, once per input time step.

    Args:
        x_lr, [B, T, C, L]: low-resolution frames
        lr_encoded: encoded low-resolution nodes, flattened to [B*L, -1]
        lr_coords: low-resolution coordinates, flattened to [B*L, -1]
        hr_coords: query coordinates, N per sample, flattened to [B*N, -1]
        t, [B, >=T]: time stamps; column i is paired with frame i

    Returns:
        Tensor [B*N, T, C_out], with C_out the output width of
        ``self.proj_head``.

    Raises:
        ValueError: for an unknown ``self.interpolation`` or when ``'area'``
            is requested with fewer than two neighbours.

    Notes:
        ``'area'`` blends the two nearest neighbours, each weighted by the
        squared distance to the *other* one. ``'knn'`` and ``'sph'`` take a
        weighted average over all ``self.codec_neighbors`` neighbours.
    '''
    k = self.codec_neighbors
    mode = self.interpolation
    if mode not in ('area', 'knn', 'sph'):
        raise ValueError(
            f"Unknown interpolation '{mode}'; expected 'area', 'knn' or 'sph'.")
    if mode == 'area' and k < 2:
        raise ValueError("'area' interpolation needs codec_neighbors >= 2.")

    B, T, _, L = x_lr.shape
    N = hr_coords.shape[1]

    # Find nearest k low-res neighbors for each high-res coordinate
    flat_lr_coords = lr_coords.reshape(B*L, -1)
    flat_hr_coords = hr_coords.reshape(B*N, -1)
    batch_lr = torch.arange(B, device=flat_lr_coords.device).repeat_interleave(L)
    batch_hr = torch.arange(B, device=flat_hr_coords.device).repeat_interleave(N)
    assign_index = knn(flat_lr_coords, flat_hr_coords, k, batch_lr, batch_hr)

    lr_encoded_flat = lr_encoded.reshape(B*L, -1)
    timesteps = t.unsqueeze(1).repeat(1, N, 1).reshape(B*N, -1)  # B*N, T

    # Neighbour indices, latent codes, offsets and weights do not depend on
    # the time step, so they are gathered once.
    neighbors = []
    for j in range(k):
        idx = assign_index[1, j::k]  # j-th nearest low-res node of each query
        offset = flat_lr_coords[idx]-flat_hr_coords
        sq_dist = torch.norm(offset, 2, dim=-1)**2
        if mode == 'area':
            weight = sq_dist
        elif mode == 'knn':
            # Clamp so a query coinciding with a node yields a large finite
            # weight instead of inf (which turned the blend into NaN).
            weight = 1/sq_dist.clamp_min(1e-12)
        else:  # 'sph'
            weight = torch.pow(1 - (L*sq_dist), 3)
        neighbors.append((idx, lr_encoded_flat[idx], offset, weight.unsqueeze(-1)))

    if mode == 'area':
        # Cross weighting: neighbour 0 uses the weight of neighbour 1 and
        # vice versa; only the two nearest neighbours take part.
        mix = [neighbors[1][3], neighbors[0][3]]
    else:
        mix = [entry[3] for entry in neighbors]
    total = mix[0]
    for w in mix[1:]:
        total = total+w

    out = []
    for i in range(T):
        x_lr_flat = x_lr[:, i].permute(0, 2, 1).reshape(B*L, -1)
        timestep = timesteps[:, i:i+1]
        blended = None
        for (idx, q_feat, offset, _), w in zip(neighbors, mix):
            final_input = torch.cat([q_feat, x_lr_flat[idx], offset, timestep], dim=-1)
            term = self.proj_head(final_input)*w  # B*N, C
            blended = term if blended is None else blended+term
        out.append(blended/total)

    return torch.stack(out, dim=1)  # B*N, T, C
def forward(
        self,
        x_lr,
        lr_coords,
        hr_coords,
        t,
        hr_last):
    '''
    Predict the next time slice at the low-resolution nodes and at the query points.

    Args:
        x_lr: tensor of shape [B, T, C, L] holding the low-resolution input frames
        lr_coords: coordinates of the L low-resolution nodes of every sample
        hr_coords: coordinates of the N query points of every sample
        t: tensor of shape [B, T + T_out]; the first T columns belong to the
           inputs, the remaining ones to the frames to predict
        hr_last: values at the query points for the last input frame

    Returns:
        out_hr: predictions at the query points, stacked over the T_out steps
        out_lr: predictions at the low-resolution nodes, stacked over the T_out steps
        hr_points: interpolated query-point values for the T input frames, [B, T, N, -1]
    '''
    B, T, C, L = x_lr.shape
    N = hr_coords.shape[1]
    T_out = t.shape[1] - T
    t_in = t[:, :T]

    # Low-resolution nodes with all their input frames stacked as features.
    lr_points = x_lr.permute(0, 3, 1, 2).reshape(B, L, -1)  # B, L, T*C

    # Stage 1: encode the low-resolution graph.
    nodes, edge_index, edges = self._build_graph(lr_points, lr_coords, t_in)
    nodes, edges = self.encoder(nodes, edge_index, edges)
    lr_encoded, _ = self.processor(nodes, edge_index, edges)

    # Stage 2: interpolate values at the query points.
    z = self.continuous_decoder(x_lr, lr_encoded, lr_coords, hr_coords, t)  # B*N, T, C
    hr_points = self.projector(z)  # B*N, T, 1
    hr_points = hr_points.reshape(B, N, T, -1).reshape(B, N, -1)  # B, N, T

    # Stage 3: run the second network over low-res nodes and query points together.
    all_coords = torch.cat([lr_coords, hr_coords], dim=1)  # B, (L+N), D
    all_feats = torch.cat([lr_points, hr_points], dim=1)  # B, (L+N), T
    nodes, edge_index, edges = self._build_graph(all_feats, all_coords, t_in)
    nodes, edges = self._encoder(nodes, edge_index, edges)
    nodes, _ = self._processor(nodes, edge_index, edges)
    nodes = self._decoder(nodes)  # B*(L+N), T_out
    ret = nodes.reshape(B, -1, nodes.shape[-1])  # B, (L+N), T_out

    # Stage 4: last known value plus elapsed time times the decoded output.
    tt = t.unsqueeze(1).repeat(1, L+N, 1)
    t_ref = tt[:, :, T-1:T]
    last_values = torch.cat([x_lr[:, -1].permute(0, 2, 1), hr_last], dim=1)  # B, (L+N), 1
    outputs = torch.stack(
        [
            last_values+(tt[:, :, T+i:T+i+1]-t_ref)*ret[..., i].unsqueeze(-1)
            for i in range(T_out)
        ],
        dim=1)  # B, T_out, (L+N), 1

    out_lr = outputs[:, :, :L]
    out_hr = outputs[:, :, L:]
    hr_points = hr_points.reshape(B, N, T, -1).permute(0, 2, 1, 3)

    return out_hr, out_lr, hr_points
def validation_step(self, val_batch, batch_idx):
    """Roll the model out on one validation batch and log its errors.

    The first ``self.time_slice`` frames seed the rollout. Each predicted
    slice is fed back as the next input; ground truth is never re-injected.

    Args:
        val_batch: mapping holding tensors under ``'t'``, ``'lr_frames'``,
            ``'hr_points'``, ``'coords_hr'`` and ``'coords_lr'``.
        batch_idx: unused.
    """
    ts = self.time_slice
    t = val_batch['t'].float()
    lr_frames = val_batch['lr_frames'].float()  # B, T, 1, L
    hr_values = val_batch['hr_points'].float()
    hr_coords = val_batch['coords_hr'].float()
    lr_coords = val_batch['coords_lr'].float()

    hr_future = hr_values[:, ts:]  # B, T_future, N, 1
    n_slices = hr_future.shape[1]//ts

    # Seed the rollout with the first slice of ground truth.
    inp = lr_frames[:, :ts]
    hr_last = hr_values[:, ts-1]

    predictions = []
    for k in range(n_slices):
        t_slice = t[:, k*ts:(k+2)*ts]
        out_hr, out_lr, _ = self.forward(inp, lr_coords, hr_coords, t_slice, hr_last)
        predictions.append(torch.cat([out_hr, out_lr], dim=2))

        # Feed the prediction back as the next input.
        inp = out_lr.permute(0, 1, 3, 2)
        hr_last = out_hr[:, -1]

    u_values_hat = torch.cat(predictions, dim=1)
    lr_future = lr_frames[:, ts:].permute(0, 1, 3, 2)
    target = torch.cat([hr_future, lr_future], dim=2)

    self.log('val_loss', self.criterion(u_values_hat, target), prog_bar=True)
    self.log('val_mae_loss', self.mae_criterion(u_values_hat, target), prog_bar=True)