├── .gitignore ├── .gitmodules ├── README.md ├── analysis_1d_evo.ipynb ├── analysis_2d_evo.ipynb ├── analysis_2d_full.py ├── analysis_2d_rl.ipynb ├── argparser.py ├── assets ├── gif-MeshGraphNets+gt remeshing.gif ├── gif-MeshGraphNets+heuristic remeshing.gif ├── gif-ground_truth.gif ├── gif-lamp.gif ├── gif-lamp_no_remeshing.gif ├── lamp_architecture.png └── lamp_poster.pdf ├── data ├── arcsimmesh_data │ ├── README.md │ └── __init__.py └── mppde1d_data │ └── __init__.py ├── datasets ├── README.md ├── arcsimmesh_dataset.py ├── datagen_square.ipynb ├── load_dataset.py └── mppde1d_dataset.py ├── gnns.py ├── license ├── models.py ├── requirements.txt ├── results └── README.md ├── train.py ├── utils.py └── utils_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Added by author 7 | data/* 8 | results/* 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pytorch_net"] 2 | path = pytorch_net 3 | url = git@github.com:tailintalent/pytorch_net.git 4 | [submodule "MP_Neural_PDE_Solvers"] 5 | path = MP_Neural_PDE_Solvers 6 | url = git@github.com:tailintalent/MP_Neural_PDE_Solvers.git 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LAMP: Learning Controllable Adaptive Simulation for Multi-resolution Physics (ICLR 2023 Notable-Top-25%) 2 | 3 | [Paper](https://openreview.net/forum?id=PbfgkZ2HdbE) | [arXiv](https://arxiv.org/abs/2305.01122) | [Poster](https://github.com/snap-stanford/lamp/blob/master/assets/lamp_poster.pdf) | [Slides](https://docs.google.com/presentation/d/1cMRGe2qNIrzSNRTUtbsVUod_PvyhDHcHzEa8wfxiQsw/edit?usp=sharing) | [Tweet](https://twitter.com/tailin_wu/status/1653253117671272448) | [Project Page](https://snap.stanford.edu/lamp/) 4 | 5 | Official repo for the paper [Learning Controllable Adaptive Simulation for Multi-resolution Physics](https://openreview.net/forum?id=PbfgkZ2HdbE)
6 | [Tailin Wu*](https://tailin.org/), [Takashi Maruyama*](https://sites.google.com/view/tmaruyama/home), [Qingqing Zhao*](https://cyanzhao42.github.io/), [Gordon Wetzstein](https://stanford.edu/~gordonwz/), [Jure Leskovec](https://cs.stanford.edu/people/jure/)
7 | ICLR 2023 **Notable-Top-25%**. 8 | 9 | LAMP is the first fully DL-based surrogate model that jointly learns the evolution model and optimizes spatial resolutions to reduce computational cost, with the remeshing policy learned via reinforcement learning. 10 | 11 | We demonstrate that LAMP is able to adaptively trade off computation to reduce long-term prediction error, by performing spatial refinement and coarsening of the mesh. LAMP outperforms state-of-the-art (SOTA) deep learning surrogate models, with an average of 33.7% error reduction for 1D nonlinear PDEs, and outperforms SOTA MeshGraphNets + Adaptive Mesh Refinement in 2D mesh-based simulations. 12 | 13 | 14 | 15 | Learned remeshing & evolution by LAMP: 16 | 17 | 18 | ## Installation 19 | 20 | 1. First clone the repository. Then run the following command to initialize the submodules: 21 | 22 | ```code 23 | git submodule init; git submodule update 24 | ``` 25 | 26 | (If you get a permission error, you first need to [add a new SSH key to your GitHub account](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/adding-a-new-ssh-key-to-your-github-account).) 27 | 28 | 2. Install dependencies. 29 | 30 | First, create a new environment using [conda](https://docs.conda.io/en/latest/miniconda.html) (with python >= 3.7). Then install pytorch, torch-geometric and the other dependencies as follows (the repository was run with the following dependencies; other versions of torch-geometric or deepsnap may work, but there is no guarantee). 31 | 32 | Install pytorch (replace "cu113" with the appropriate cuda version; for example, cuda 11.1 uses "cu111"): 33 | ```code 34 | pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/torch_stable.html 35 | ``` 36 | 37 | Install torch-geometric by running the following commands: 38 | ```code 39 | pip install torch-scatter==2.0.9 -f https://data.pyg.org/whl/torch-1.10.2+cu113.html 40 | pip install torch-sparse==0.6.12 -f https://data.pyg.org/whl/torch-1.10.2+cu113.html 41 | pip install torch-geometric==1.7.2 42 | pip install torch-cluster==1.5.9 -f https://data.pyg.org/whl/torch-1.10.2+cu113.html 43 | ``` 44 | 45 | Install the other dependencies: 46 | ```code 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | If you want to use wandb (--wandb=True), set it up following [this link](https://docs.wandb.ai/quickstart). 51 | 52 | To run the 2D mesh-based simulation, FEniCS needs to be installed: 53 | 54 | ```code 55 | conda install -c conda-forge fenics 56 | ``` 57 | 58 | 59 | ## Dataset 60 | 61 | The dataset files can be downloaded via [this link](https://drive.google.com/drive/folders/1ld5I86mPC7wWTxPhbCtG2AcH0vLW3o25?usp=share_link); the expected local layout after downloading is sketched below. 62 | * To run the 1D experiment, download the files under "mppde1d_data/" in the link into the "data/mppde1d_data/" folder of the local repo. 63 | * To run the 2D mesh-based experiment, download the files under "arcsimmesh_data/" in the link into the "data/arcsimmesh_data/" folder of the local repo. A script for data generation is also provided ("datasets/datagen_square.ipynb"). To run the script, compile [ARCSim v0.2.1](http://graphics.berkeley.edu/resources/ARCSim/#:~:text=ArcSim%20is%20a%20simulation,detail%20of%20the%20simulated%20objects) and place the script in the ARCSim folder. A detailed explanation of the attributes of the mesh-based dataset is provided under [datasets/README.md](https://github.com/snap-stanford/lamp/tree/master/datasets).
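After downloading, the `data/` folder should look roughly as follows (a sketch based on the repo tree above; the actual downloaded file names depend on the contents of the Google Drive folders):

```code
data/
├── mppde1d_data/        # downloaded files from "mppde1d_data/" in the link go here
│   └── __init__.py
└── arcsimmesh_data/     # downloaded files from "arcsimmesh_data/" in the link go here
    ├── README.md
    └── __init__.py
```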
64 | 65 | 66 | ## Training 67 | 68 | Below we provide example commands for training LAMP. 69 | 70 | ### 1D nonlinear PDE: 71 | 72 | First, **pre-train** the evolution model for 1D: 73 | 74 | ```code 75 | python train.py --exp_id=evo-1d --date_time=2023-01-01 --dataset=mppde1df-E2-100-nt-250-nx-200 --time_interval=1 --data_dropout=node:0-0.3:0.1 --latent_size=64 --n_train=-1 --save_interval=5 --test_interval=5 --algo=gnnremesher --rl_coefs=None --input_steps=1 --act_name=silu --multi_step=1^2:0.1^3:0.1^4:0.1 --temporal_bundle_steps=25 --use_grads=False --is_y_diff=False --loss_type=mse --batch_size=16 --val_batch_size=16 --epochs=50 --opt=adam --weight_decay=0 --seed=0 --id=0 --verbose=1 --n_workers=0 --gpuid=0 76 | ``` 77 | 78 | The learned model will be saved under `./results/{--exp_id}_{--date_time}/`, where `{--exp_id}` and `{--date_time}` are specified in the above command. The filename has the format `*{hash}_{machine_name}.p`; e.g., for "mppde1df-E2-100-nt-250-nx-200_train_-1_algo_gnnremesher_..._Hash_mhkVkAaz_ampere3.p", the `{hash}` is `mhkVkAaz` and the `{machine_name}` is `ampere3`. The `{hash}` is uniquely determined by **all** the argument settings in [argparser.py](https://github.com/snap-stanford/lamp/blob/master/argparser.py) (therefore, as long as any argument setting differs, the filenames will differ and will not overwrite each other). 79 | 80 | Then, **jointly train** the remeshing model via reinforcement learning (RL) and the evolution model. The `--load_dirname` below should use the folder name `{exp_id}_{date_time}` where the evolution model is located (as specified above), and the `--load_filename` should be a part of the filename that uniquely identifies this model file; it should include the `{hash}` of this model.
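For example, here is a minimal sketch (not part of this repo; `find_model_by_hash` is a hypothetical helper) for locating a saved checkpoint by its hash, assuming the `./results/{exp_id}_{date_time}/` layout and `*Hash_{hash}_{machine_name}.p` naming described above:

```code
import glob

def find_model_by_hash(exp_id, date_time, hash_str, results_dir="./results"):
    # Filenames end with "..._Hash_{hash}_{machine_name}.p", so matching on the
    # hash substring is enough to uniquely identify one checkpoint.
    matches = glob.glob(f"{results_dir}/{exp_id}_{date_time}/*Hash_{hash_str}*.p")
    assert len(matches) == 1, f"Expected exactly one file for hash {hash_str}, got {matches}"
    return matches[0]

# e.g., find_model_by_hash("evo-1d", "2023-01-01", "mhkVkAaz")
```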
81 | 82 | ```code 83 | python train.py --load_dirname=evo-1d_2023-01-01 --load_filename=Q66bz42y --exp_id=rl-1d --date_time=2023-01-02 --wandb_project_name=rl-1d_2023-01-02 --wandb=True --dataset=mppde1df-E2-100-nt-250-nx-200 --time_interval=1 --data_dropout=None --latent_size=64 --n_train=-1 --input_steps=1 --act_name=elu --multi_step=1^2:0.1^3:0.1^4:0.1 --temporal_bundle_steps=25 --use_grads=False --is_y_diff=False --loss_type=mse --batch_size=128 --val_batch_size=128 --epochs=30 --opt=adam --weight_decay=0 --seed=0 --verbose=1 --n_workers=0 --gpuid=7 --algo=srlgnnremesher --reward_mode=lossdiff+statediff --reward_beta=0-0.5 --rl_data_dropout=uniform:2 --min_edge_size=0.0014 --rl_horizon=4 --reward_loss_coef=5 --rl_eta=1e-2 --actor_lr=5e-4 --value_lr=1e-4 --value_num_pool=1 --value_pooling_type=global_mean_pool --value_latent_size=32 --value_batch_norm=False --actor_batch_norm=True --rescale=10 --edge_attr=True --rl_gamma=0.9 --value_loss_coef=0.5 --max_grad_norm=2 --is_single_action=False --value_target_mode=vanilla --wandb_step_plot=100 --wandb_step=20 --save_iteration=1000 --save_interval=1 --test_interval=1 --gpuid=3 --lr=1e-4 --actor_critic_step=200 --evolution_steps=200 --reward_condition=True --max_action=20 --rl_is_finetune_evolution=True --rl_finetune_evalution_mode=policy:fine --id=0 84 | ``` 85 | 86 | ### 2D mesh-based simulation: 87 | 88 | **Pre-train** the evolution model for 2D (FEniCS must be installed; see the "Installation" section): 89 | 90 | ```code 91 | export OMP_NUM_THREADS=6; python train.py --exp_id=evo-2d --date_time=2023-01-01 --dataset=arcsimmesh_square_annotated --time_interval=2 --data_dropout=None --n_train=-1 --save_interval=5 --algo=gnnremesher-evolution --rl_coefs=None --input_steps=2 --act_name=silu --multi_step=1 --temporal_bundle_steps=1 --edge_attr=True --use_grads=False --is_y_diff=False --loss_type=l2 --batch_size=10 --val_batch_size=10 --latent_size=56 --n_layers=8 --noise_amp=1e-2 --correction_rate=0.9 --epochs=100 --opt=adam --weight_decay=0  --is_mesh=True --seed=0 --id=0 --verbose=2 --test_interval=2 --n_workers=20 --gpuid=0 92 | ``` 93 | 94 | Then, **jointly train** the remeshing model via RL and the evolution model: 95 | 96 | ```code 97 | export OMP_NUM_THREADS=6; python train.py --exp_id=2d_rl --wandb_project_name=2d_rerun --wandb=True --date_time=2023-02-26 --dataset=arcsimmesh_square_annotated_coarse_minlen008_interp_500 --time_interval=2 --n_train=-1 --latent_size=64 --load_dirname=evo-2d_2023_02_18 --load_filename=9UQLIKKc_ampere1 --input_steps=2 --act_name=elu --temporal_bundle_steps=1 --use_grads=False --is_y_diff=True --loss_type=l2 --epochs=300 --opt=adam --weight_decay=0 --verbose=1 --algo=srlgnnremesher --reward_mode=lossdiff+statediff --rl_data_dropout=None --min_edge_size=0.04 --actor_lr=5e-4 --value_lr=1e-4 --value_num_pool=1 --value_pooling_type=global_mean_pool --value_latent_size=64 --value_batch_norm=False --actor_batch_norm=True --rescale=10 --edge_attr=True --rl_gamma=0.9 --value_loss_coef=0.5 --max_grad_norm=20 --is_single_action=False --value_target_mode=vanilla --wandb_step_plot=50 --wandb_step=2 --id=0 --save_iteration=500 --save_interval=1 --test_interval=1 --is_mesh=True --is_unittest=False --rl_horizon=6 --multi_step=6 --rl_eta=2e-2 --reward_beta=0 --reward_condition=True --max_action=20 --rl_is_finetune_evolution=True --lr=1e-4 --actor_critic_step=200 --evolution_steps=100 --rl_finetune_evalution_mode=policy:fine --wandb=True --batch_size=64 --val_batch_size=64 --n_workers=6 --reward_loss_coef=1000
--evl_stop_gradient=True --noise_amp=0.01 --gpuid=5 --is_eval_sample=True --seed=256 --n_train=:-1 --soft_update=False --fine_tune_gt_input=True --policy_input_feature=coords --skip_coarse=False --skip_flip=True --processor_aggr=mean --fix_alt_evolution_model=True 98 | ``` 99 | 100 | For commands for the **baseline** models in 1D, see the README in [./MP_Neural_PDE_Solvers/](https://github.com/tailintalent/MP_Neural_PDE_Solvers). 101 | 102 | We also provide pre-trained evolution models that can be used directly for RL training [here](https://drive.google.com/drive/folders/1ioR5gjYeQaNQMvqrdMYMZI-0n5k5UM8f). Put the folders from this Google Drive link (e.g., "evo-1d_2023-01-01" for the pre-trained 1D evolution model, "evo-2d_2023-02-18" for the pre-trained 2D evolution model) under the ./results/ folder; you can then use the RL commands above to perform joint training. 103 | 104 | ## Analysis 105 | 106 | * For 1D experiments, to analyze the pretrained evolution model for LAMP, use [analysis_1D_evo.ipynb](https://github.com/snap-stanford/lamp/blob/master/analysis_1d_evo.ipynb). 107 | 108 | * For 1D experiments, to analyze the full model for LAMP and the baselines, use [analysis_1D_full.py](https://github.com/snap-stanford/lamp/blob/master/analysis_1d_full.py). 109 | 110 | * For 1D experiments, to analyze the baseline models (MP-PDE, FNO, CNN), use [./MP_Neural_PDE_Solvers/analysis.ipynb](https://github.com/tailintalent/MP_Neural_PDE_Solvers/blob/master/analysis.ipynb). 111 | 112 | * For 2D experiments, to analyze the pretrained evolution model for LAMP, use [analysis_2D_evo.ipynb](https://github.com/snap-stanford/lamp/blob/master/analysis_2d_evo.ipynb). 113 | 114 | * For 2D experiments, to analyze the full model for LAMP and the baselines, use [analysis_2D_full.py](https://github.com/snap-stanford/lamp/blob/master/analysis_2d_full.py) and [analysis_2d_rl.ipynb](https://github.com/snap-stanford/lamp/blob/master/analysis_2d_rl.ipynb). 115 | 116 | ## Visualization: 117 | 118 | Example visualization of learned remeshing & evolution: 119 | 120 | LAMP: 121 | 122 | 123 | 124 | LAMP (no remeshing): 125 | 126 | 127 | 128 | MeshGraphNets + ground-truth remeshing: 129 | 130 | 131 | 132 | 133 | MeshGraphNets + heuristic remeshing: 134 | 135 | 136 | 137 | 138 | Ground-truth (fine-grained): 139 | 140 | 141 | 142 | ## Related Projects 143 | 144 | * [LE-PDE](https://github.com/snap-stanford/le_pde) (NeurIPS 2022): Accelerates the simulation and inverse optimization of PDEs. Compared to state-of-the-art deep learning-based surrogate models (e.g., FNO, MP-PDE), it achieves up to a 15x improvement in speed, while achieving competitive accuracy. 145 | 146 | * [CinDM](https://github.com/AI4Science-WestlakeU/cindm) (ICLR 2024 spotlight): We introduce a method that uses compositional generative models to design boundaries and initial states significantly more complex than the ones seen in training for physical simulations. 147 | 148 | * [BENO](https://github.com/AI4Science-WestlakeU/beno) (ICLR 2024): We introduce a boundary-embedded neural operator that incorporates complex boundary shapes and inhomogeneous boundary values into the solving of elliptic PDEs.
149 | 150 | ## Citation 151 | If you find our work and/or our code useful, please cite us via: 152 | 153 | ```bibtex 154 | @inproceedings{wu2023learning, 155 | title={Learning Controllable Adaptive Simulation for Multi-resolution Physics}, 156 | author={Tailin Wu and Takashi Maruyama and Qingqing Zhao and Gordon Wetzstein and Jure Leskovec}, 157 | booktitle={The Eleventh International Conference on Learning Representations}, 158 | year={2023}, 159 | url={https://openreview.net/forum?id=PbfgkZ2HdbE} 160 | } 161 | ``` 162 | -------------------------------------------------------------------------------- /analysis_1d_evo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "ca791c51-0ee5-40b0-b207-75ac8e2a491d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "\n", 13 | "import argparse\n", 14 | "from collections import OrderedDict\n", 15 | "import datetime\n", 16 | "import gc\n", 17 | "get_ipython().run_line_magic('matplotlib', 'inline')\n", 18 | "import matplotlib\n", 19 | "import matplotlib.pylab as plt\n", 20 | "from numbers import Number\n", 21 | "import numpy as np\n", 22 | "import pandas as pd\n", 23 | "pd.options.display.max_rows = 1500\n", 24 | "pd.options.display.max_columns = 200\n", 25 | "pd.options.display.width = 1000\n", 26 | "pd.set_option('max_colwidth', 400)\n", 27 | "import pdb\n", 28 | "import pickle\n", 29 | "import pprint as pp\n", 30 | "import time\n", 31 | "import torch\n", 32 | "import torch.nn as nn\n", 33 | "import torch.nn.functional as F\n", 34 | "from torch import optim\n", 35 | "from torch.utils.data import DataLoader\n", 36 | "from deepsnap.batch import Batch as deepsnap_Batch\n", 37 | "\n", 38 | "import sys, os\n", 39 | "sys.path.append(os.path.join(os.path.dirname(\"__file__\"), '..'))\n", 40 | "sys.path.append(os.path.join(os.path.dirname(\"__file__\"), '..', '..'))\n", 41 | "from lamp.argparser import arg_parse\n", 42 | "from lamp.datasets.load_dataset import load_data\n", 43 | "from lamp.gnns import get_data_dropout\n", 44 | "from lamp.models import load_model\n", 45 | "from lamp.pytorch_net.util import Interp1d_torch, groupby_add_keys, filter_df, get_unique_keys_df, Attr_Dict, Printer, get_num_params, get_machine_name, pload, pdump, to_np_array, get_pdict, reshape_weight_to_matrix, ddeepcopy as deepcopy, plot_vectors, record_data, filter_filename, Early_Stopping, str2bool, get_filename_short, print_banner, plot_matrices, get_num_params, init_args, filter_kwargs, to_string, COLOR_LIST\n", 46 | "from lamp.utils import p, update_legacy_default_hyperparam, EXP_PATH, deepsnap_to_pyg, LpLoss, to_cpu, to_tuple_shape, parse_multi_step, loss_op, get_cholesky_inverse, get_device, get_data_comb\n", 47 | "\n", 48 | "device = torch.device(\"cuda:5\")" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "fd583f18-3ce5-469a-a89f-a19887163e00", 54 | "metadata": {}, 55 | "source": [ 56 | "## Functions:" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "df5ff29f-9493-4e3b-a5e7-9545e7c0a143", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# Analysis:\n", 67 | "def get_results_1d(\n", 68 | " all_hash,\n", 69 | " mode=\"best\",\n", 70 | " exclude_idx=(None,),\n", 71 | " dropout_mode=\"None\",\n", 72 | " n_rollout_steps=-1,\n", 73 | " is_full_eval=True,\n", 74 | " dirname=None,\n", 75 | " suffix=\"\",\n", 76 | "):\n", 77 
| " \"\"\"\n", 78 | " Perform analysis on the 1D Burgers' benchmark.\n", 79 | "\n", 80 | " Args:\n", 81 | " all_hash: a list of hashes which indicates the experiments to load for analysis\n", 82 | " mode: choose from \"best\" (load the best model with lowest validation loss) or an integer, \n", 83 | " e.g. -1 (last saved model), -2 (second last saved model)\n", 84 | " is_full_eval: if True, evalute on all grount-truth points. E.g., assuming is_full_eval is True, \n", 85 | " if the prediction is on 50 vertices, it will be first interpolated to the vertices \n", 86 | " of ground-truth, then compute the loss w.r.t. ground-truth.\n", 87 | " dirname: if not None, will use the dirnaem provided. E.g. tailin-1d_2022-7-27\n", 88 | " suffix: suffix for saving the analysis result.\n", 89 | " \"\"\"\n", 90 | " \n", 91 | " isplot = True\n", 92 | " df_dict_list = []\n", 93 | " dirname_start = dirname\n", 94 | " for hash_str in all_hash:\n", 95 | " df_dict = {}\n", 96 | " df_dict[\"hash\"] = hash_str\n", 97 | " # Load model:\n", 98 | " is_found = False\n", 99 | " for dirname_core in [\n", 100 | " dirname_start,\n", 101 | " ]:\n", 102 | " filename = filter_filename(EXP_PATH + dirname_core, include=hash_str)\n", 103 | " if len(filename) == 1:\n", 104 | " is_found = True\n", 105 | " break\n", 106 | " if not is_found:\n", 107 | " print(f\"hash {hash_str} does not exist in {dirname}! Please pass in the correct dirname.\")\n", 108 | " continue\n", 109 | " dirname = EXP_PATH + dirname_core\n", 110 | " if not dirname.endswith(\"/\"):\n", 111 | " dirname += \"/\"\n", 112 | "\n", 113 | " try:\n", 114 | " data_record = pload(dirname + filename[0])\n", 115 | " except Exception as e:\n", 116 | " # p.print(f\"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:\", banner_size=100)\n", 117 | " print(f\"error {e} in hash_str {hash_str}\")\n", 118 | " continue\n", 119 | " p.print(f\"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:\", banner_size=160)\n", 120 | " if isplot:\n", 121 | " plot_learning_curve(data_record)\n", 122 | " args = init_args(update_legacy_default_hyperparam(data_record[\"args\"]))\n", 123 | " args.filename = filename\n", 124 | " if mode == \"best\":\n", 125 | " model = load_model(data_record[\"best_model_dict\"], device=device)\n", 126 | " print(\"Load the model with best validation loss.\")\n", 127 | " else:\n", 128 | " assert isinstance(mode, int)\n", 129 | " print(f'Load the model at epoch {data_record[\"epoch\"][mode]}')\n", 130 | " model = load_model(data_record[\"model_dict\"][mode], device=device)\n", 131 | " model.eval()\n", 132 | " # pp.pprint(args.__dict__)\n", 133 | " kwargs = {}\n", 134 | " if data_record[\"best_model_dict\"][\"type\"].startswith(\"GNNPolicy\"):\n", 135 | " kwargs[\"is_deepsnap\"] = True\n", 136 | "\n", 137 | " # Load test dataset:\n", 138 | " args_test = deepcopy(args)\n", 139 | " multi_step = (250 - 50) // args_test.temporal_bundle_steps\n", 140 | " args_test.multi_step = f\"1^{multi_step}\"\n", 141 | " args_test.is_test_only = True\n", 142 | " args_test.n_train = \"-1\"\n", 143 | " n_test_traj = 128\n", 144 | " (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args_test)\n", 145 | " nx = int(args.dataset.split(\"-\")[2])\n", 146 | " time_stamps_effective = len(dataset_test) // n_test_traj\n", 147 | " for exclude_idx_ele in exclude_idx:\n", 148 | " loss_list = []\n", 149 | " pred_list = []\n", 150 | " y_list = []\n", 151 | " for i in range(n_test_traj):\n", 152 | " idx = i * time_stamps_effective + 
args_test.temporal_bundle_steps\n", 153 | " data = deepcopy(dataset_test[idx])\n", 154 | " data_ori = deepcopy(dataset_test[idx])\n", 155 | " if dropout_mode == \"None\":\n", 156 | " if exclude_idx_ele is not None:\n", 157 | " data = get_data_dropout(data, dropout_mode=\"node:0\", exclude_idx=exclude_idx_ele)\n", 158 | " else:\n", 159 | " data = get_data_dropout(data, dropout_mode=dropout_mode)\n", 160 | " data = data.to(device)\n", 161 | " preds, info = model(\n", 162 | " data,\n", 163 | " pred_steps=np.arange(1,n_rollout_steps+1) if n_rollout_steps != -1 else np.arange(1, max(parse_multi_step(args_test.multi_step).keys())+1),\n", 164 | " latent_pred_steps=None,\n", 165 | " is_recons=False,\n", 166 | " use_grads=False,\n", 167 | " use_pos=args.use_pos,\n", 168 | " is_y_diff=False,\n", 169 | " is_rollout=False,\n", 170 | " **kwargs\n", 171 | " )\n", 172 | " y = data.node_label[\"n0\"]\n", 173 | " if n_rollout_steps != -1:\n", 174 | " y = y[:,:25*n_rollout_steps]\n", 175 | " pred = preds[\"n0\"].reshape(y.shape)\n", 176 | " pred_list.append(pred.detach())\n", 177 | " y_list.append(y.detach())\n", 178 | " if is_full_eval:\n", 179 | " \"\"\"\n", 180 | " y_ori: [100, 175, 1]\n", 181 | " x_pos_ori: [100, 1]\n", 182 | " x_pos: [50, 1] for uniform:2\n", 183 | " pred: [50, 175, 1]\n", 184 | " Interp1d_torch(x, y, x_new):\n", 185 | " x: (B, N)\n", 186 | " y: (B, N)\n", 187 | " x_new: (B, P)\n", 188 | " \"\"\"\n", 189 | " y_ori = data_ori.node_label[\"n0\"]\n", 190 | " if n_rollout_steps != -1:\n", 191 | " y_ori = y_ori[:,:25*n_rollout_steps]\n", 192 | " x_pos_ori = data_ori.node_pos[\"n0\"] # [100, 1]\n", 193 | " x_pos = data.node_pos[\"n0\"] # [50, 1]\n", 194 | " fnc = Interp1d_torch()\n", 195 | " n_steps = y_ori.shape[1]\n", 196 | " x_pos_format = x_pos.transpose(0,1).expand(n_steps, x_pos.shape[0]) # [175, 50]\n", 197 | " pred_format = pred.squeeze(2).transpose(0,1)\n", 198 | " x_pos_ori_format = x_pos_ori.transpose(0,1).expand(n_steps, x_pos_ori.shape[0]).to(device) # [175, 100]\n", 199 | " pred_interp_format = fnc(x_pos_format, pred_format, x_pos_ori_format) # [175, 100]\n", 200 | " y_ori_format = y_ori.squeeze(2).transpose(0,1).to(device) # [175, 100]\n", 201 | " loss_ele = nn.MSELoss(reduction=\"sum\")(pred_interp_format, y_ori_format) / nx\n", 202 | " else:\n", 203 | " loss_ele = nn.MSELoss(reduction=\"sum\")(pred, y) / nx\n", 204 | " loss_list.append(loss_ele.item())\n", 205 | "\n", 206 | " loss_mean = np.mean(loss_list)\n", 207 | " pred_list = torch.stack(pred_list).squeeze(-1)\n", 208 | " y_list = torch.stack(y_list).squeeze(-1)\n", 209 | " df_dict[f\"loss_cumu_{exclude_idx_ele}\"] = loss_mean \n", 210 | " print(\"\\nTest for {} for exclude_idx={} is: {:.9f} at epoch {}, for {}/{} epochs\".format(hash_str, exclude_idx_ele, loss_mean, data_record['best_epoch'], len(data_record[\"train_loss\"]), args.epochs))\n", 211 | "\n", 212 | " mse_full = nn.MSELoss(reduction=\"none\")(pred_list, y_list)\n", 213 | " mse_time = to_np_array(mse_full.mean((0,1)))\n", 214 | " p.print(\"Learning curve:\", is_datetime=False, banner_size=100)\n", 215 | " plt.figure(figsize=(12,5))\n", 216 | " plt.subplot(1,2,1)\n", 217 | " plt.plot(mse_time)\n", 218 | " plt.xlabel(\"rollout step\")\n", 219 | " plt.ylabel(\"MSE\")\n", 220 | " plt.title(\"MSE vs. rollout step (linear scale)\")\n", 221 | " plt.subplot(1,2,2)\n", 222 | " plt.semilogy(mse_time)\n", 223 | " plt.xlabel(\"rollout step\")\n", 224 | " plt.ylabel(\"MSE\")\n", 225 | " plt.title(\"MSE vs. 
rollout step (log scale)\")\n", 226 | " plt.show()\n", 227 | " plt.figure(figsize=(6,5))\n", 228 | " plt.plot(mse_time.cumsum())\n", 229 | " plt.title(\"cumulative MSE vs. rollout step\")\n", 230 | " plt.xlabel(\"rollout step\")\n", 231 | " plt.ylabel(\"cumulative MSE\")\n", 232 | " plt.show()\n", 233 | "\n", 234 | " # Visualization:\n", 235 | " for idx in range(6,8):\n", 236 | " p.print(f\"Example {idx*128}:\", banner_size=100, is_datetime=False)\n", 237 | " data = deepcopy(dataset_test[idx*128]).to(device)\n", 238 | " if exclude_idx_ele is not None:\n", 239 | " data = get_data_dropout(data, dropout_mode=\"node:0\", exclude_idx=exclude_idx_ele)\n", 240 | " preds, info = model(\n", 241 | " data,\n", 242 | " pred_steps=np.arange(1,max(parse_multi_step(args_test.multi_step).keys())+1),\n", 243 | " latent_pred_steps=None,\n", 244 | " is_recons=False,\n", 245 | " use_grads=False,\n", 246 | " use_pos=args.use_pos,\n", 247 | " is_y_diff=False,\n", 248 | " is_rollout=False,\n", 249 | " **kwargs\n", 250 | " )\n", 251 | " y = data.node_label[\"n0\"]\n", 252 | " pred = preds[\"n0\"].reshape(y.shape)\n", 253 | " visualize(pred, y)\n", 254 | " visualize_paper(pred, y)\n", 255 | "\n", 256 | " p.print(f\"Individual prediction at rollout step {y.shape[1]}:\", banner_size=100, is_datetime=False)\n", 257 | " time_step = -1\n", 258 | " for idx in range(0, 20, 5):\n", 259 | " plt.figure(figsize=(6,4))\n", 260 | " plt.plot(to_np_array(pred_list[idx,:,time_step]), label=\"pred\")\n", 261 | " plt.plot(to_np_array(y_list[idx,:,time_step]), \"--\", label=\"y\")\n", 262 | " plt.legend()\n", 263 | " plt.show()\n", 264 | " df_dict[\"best_epoch\"] = data_record['best_epoch']\n", 265 | " df_dict[\"epoch\"] = len(data_record[\"train_loss\"])\n", 266 | " df_dict.update(args.__dict__)\n", 267 | " df_dict_list.append(df_dict)\n", 268 | " df = pd.DataFrame(df_dict_list)\n", 269 | " pdump(df, f\"df_1d{suffix}.p\")\n", 270 | " return df\n", 271 | "\n", 272 | "# Plotting:\n", 273 | "def plot_learning_curve(data_record):\n", 274 | " plt.figure(figsize=(16,6))\n", 275 | " plt.subplot(1,2,1)\n", 276 | " plt.plot(data_record[\"epoch\"], data_record[\"train_loss\"], label=\"train\")\n", 277 | " plt.plot(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"val_loss\"], label=\"val\")\n", 278 | " plt.plot(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"test_loss\"], label=\"test\")\n", 279 | " plt.title(\"Learning curve, linear scale\")\n", 280 | " plt.legend()\n", 281 | " plt.subplot(1,2,2)\n", 282 | " plt.semilogy(data_record[\"epoch\"], data_record[\"train_loss\"], label=\"train\")\n", 283 | " plt.semilogy(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"val_loss\"], label=\"val\")\n", 284 | " plt.semilogy(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"test_loss\"], label=\"test\")\n", 285 | " plt.title(\"Learning curve, log scale\")\n", 286 | " plt.legend()\n", 287 | " plt.show()\n", 288 | "\n", 289 | "\n", 290 | "def plot_colorbar(matrix, vmax=None, vmin=None, cmap=\"seismic\", label=None):\n", 291 | " if vmax==None:\n", 292 | " vmax = matrix.max()\n", 293 | " vmin = matrix.min()\n", 294 | " im = plt.imshow(matrix,vmax=vmax,vmin=vmin,cmap=cmap)\n", 295 | " plt.title(label)\n", 296 | " im_ratio = matrix.shape[0]/matrix.shape[1]\n", 297 | " plt.colorbar(im,fraction=0.046*im_ratio,pad=0.04)\n", 298 | "\n", 299 | 
"\n", 300 | "def visualize(pred, gt, animate=False):\n", 301 | " if torch.is_tensor(gt):\n", 302 | " gt = to_np_array(gt)\n", 303 | " pred = to_np_array(pred)\n", 304 | " mse_over_t = ((gt-pred)**2).mean(axis=0).mean(axis=-1)\n", 305 | " \n", 306 | " if not animate:\n", 307 | " vmax = gt.max()\n", 308 | " vmin = gt.min()\n", 309 | " plt.figure(figsize=[15,5])\n", 310 | " plt.subplot(1,4,1)\n", 311 | " plot_colorbar(gt[:,:,0].T,label=\"gt\")\n", 312 | " plt.subplot(1,4,2)\n", 313 | " plot_colorbar(pred[:,:,0].T,label=\"pred\")\n", 314 | " plt.subplot(1,4,3)\n", 315 | " plot_colorbar((pred-gt)[:,:,0].T,vmax=np.abs(pred-gt).max(),vmin=(-1*np.abs(pred-gt).max()),label=\"diff\")\n", 316 | " plt.subplot(1,4,4)\n", 317 | " plt.plot(mse_over_t);plt.title(\"mse over t\");plt.yscale('log');\n", 318 | " plt.tight_layout()\n", 319 | " plt.show()\n", 320 | "\n", 321 | "def visualize_paper(pred, gt, is_save=False):\n", 322 | " idx = 6\n", 323 | " nx = pred.shape[0]\n", 324 | "\n", 325 | " fontsize = 14\n", 326 | " idx_list = np.arange(0, 200, 15)\n", 327 | " color_list = np.linspace(0.01, 0.9, len(idx_list))\n", 328 | " x_axis = np.linspace(0,16,nx)\n", 329 | " cmap = matplotlib.cm.get_cmap('jet')\n", 330 | "\n", 331 | " plt.figure(figsize=(16,5))\n", 332 | " plt.subplot(1,2,1)\n", 333 | " for i, idx in enumerate(idx_list):\n", 334 | " pred_i = to_np_array(pred[...,idx,:].squeeze())\n", 335 | " rgb = cmap(color_list[i])[:3]\n", 336 | " plt.plot(x_axis, pred_i, color=rgb, label=f\"t={np.round(i*0.3, 1)}s\")\n", 337 | " plt.ylabel(\"u(t,x)\", fontsize=fontsize)\n", 338 | " plt.xlabel(\"x\", fontsize=fontsize)\n", 339 | " plt.tick_params(labelsize=fontsize)\n", 340 | " # plt.legend(fontsize=10, bbox_to_anchor=[1,1])\n", 341 | " plt.xticks([0,8,16], [0,8,16])\n", 342 | " plt.ylim([-2.5,2.5])\n", 343 | " plt.title(\"Prediction\")\n", 344 | " if is_save:\n", 345 | " plt.savefig(f\"1D_E2-{nx}.pdf\", bbox_inches='tight')\n", 346 | "\n", 347 | " plt.subplot(1,2,2)\n", 348 | " for i, idx in enumerate(idx_list):\n", 349 | " y_i = to_np_array(gt[...,idx,:])\n", 350 | " rgb = cmap(color_list[i])[:3]\n", 351 | " plt.plot(x_axis, y_i, color=rgb, label=f\"t={np.round(i*0.3, 1)}s\")\n", 352 | " plt.ylabel(\"u(t,x)\", fontsize=fontsize)\n", 353 | " plt.xlabel(\"x\", fontsize=fontsize)\n", 354 | " plt.tick_params(labelsize=fontsize)\n", 355 | " plt.legend(fontsize=10, bbox_to_anchor=[1,1])\n", 356 | " plt.xticks([0,8,16], [0,8,16])\n", 357 | " plt.ylim([-2.5,2.5])\n", 358 | " plt.title(\"Ground-truth\")\n", 359 | " if is_save:\n", 360 | " plt.savefig(f\"1D_gt-{nx}.pdf\", bbox_inches='tight')\n", 361 | " plt.show()" 362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "id": "bc711711-1ca8-4944-92d3-848313c275e9", 367 | "metadata": {}, 368 | "source": [ 369 | "## Analysis:" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": null, 375 | "id": "84b58196-d00d-4519-8e3b-a0789dbf9444", 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "\"\"\"\n", 380 | "all_hash is a list of hashes, each of which corresponds to one experiment.\n", 381 | " For example, if one experiment is saved under ./results/evo-1d_2023-01-01/mppde1d-E2-50_train_-1_algo_contrast_ebm_False_ebmt_cd_enc_cnn-s_evo_cnn_act_elu_hid_128_lo_rmse_recef_1.0_conef_1.0_nconv_4_nlat_1_clat_3_lf_True_reg_None_id_0_Hash_qvQry9QJ_ampere3.p\n", 382 | " Then, the \"qvQry9QJ\" (located at the end of the filename) is the {hash} of this file.\n", 383 | " The \"evo-1d_2023-01-01\" is the \"{--exp_id}_{--date_time}\" of the 
training command.\n", 384 | "    all_hash can contain multiple hashes, which will be analyzed sequentially.\n", 385 | "\"\"\"\n", 386 | "all_hash = [\n", 387 | "    \"Q66bz42y_ampere3\",\n", 388 | "]\n", 389 | "df9 = get_results_1d(\n", 390 | "    all_hash,\n", 391 | "    dirname=\"evo-1d_2023-01-01\",\n", 392 | "    n_rollout_steps=7,\n", 393 | "    dropout_mode=\"uniform:2\",\n", 394 | "    is_full_eval=True,\n", 395 | "    suffix=\"_0\")" 396 | ] 397 | } 398 | ], 399 | "metadata": { 400 | "kernelspec": { 401 | "display_name": "Python 3", 402 | "language": "python", 403 | "name": "python3" 404 | }, 405 | "language_info": { 406 | "codemirror_mode": { 407 | "name": "ipython", 408 | "version": 3 409 | }, 410 | "file_extension": ".py", 411 | "mimetype": "text/x-python", 412 | "name": "python", 413 | "nbconvert_exporter": "python", 414 | "pygments_lexer": "ipython3", 415 | "version": "3.9.5" 416 | } 417 | }, 418 | "nbformat": 4, 419 | "nbformat_minor": 5 420 | } 421 | -------------------------------------------------------------------------------- /analysis_2d_full.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import OrderedDict 3 | import datetime 4 | import gc 5 | import matplotlib 6 | import matplotlib.pylab as plt 7 | from numbers import Number 8 | import numpy as np 9 | import pandas as pd 10 | pd.options.display.max_rows = 1500 11 | pd.options.display.max_columns = 200 12 | pd.options.display.width = 1000 13 | pd.set_option('max_colwidth', 400) 14 | import pdb 15 | import pickle 16 | import pprint as pp 17 | import time 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch import optim 22 | from torch.utils.data import DataLoader 23 | from deepsnap.batch import Batch as deepsnap_Batch 24 | import xarray as xr 25 | 26 | import sys, os 27 | from argparser import arg_parse 28 | from lamp.datasets.load_dataset import load_data 29 | from lamp.gnns import GNNRemesher, Value_Model_Summation, Value_Model, GNNRemesherPolicy, GNNPolicySizing, GNNPolicyAgent, get_reward_batch, get_data_dropout, GNNPolicyAgent_Sampling 30 | from lamp.models import get_model, load_model, unittest_model, build_optimizer, test 31 | 32 | from lamp.pytorch_net.util import Attr_Dict, Batch, filter_filename, pload, pdump, Printer, get_time, init_args, update_args, clip_grad, set_seed, update_dict, filter_kwargs, plot_vectors, plot_matrices, make_dir, get_pdict, to_np_array, record_data, make_dir, str2bool, get_filename_short, print_banner, get_num_params, ddeepcopy as deepcopy, write_to_config 33 | from lamp.utils import p, update_legacy_default_hyperparam, EXP_PATH, seed_everything 34 | 35 | def plot_fig(info, data, index=20): 36 |     vers_gt = data.reind_yfeatures["n0"][index].detach().cpu().numpy() 37 |     faces_gt = data.yface_list["n0"][index] 38 | 39 |     faces_gt = np.stack(faces_gt) 40 |     batch_0_idx = np.where(vers_gt[faces_gt[:,0],0]<1.5) 41 |     faces_gt = np.stack(faces_gt)[batch_0_idx] 42 | 43 |     batch_0_idx = np.where(vers_gt[:,0]<1.5) 44 |     vers_gt = vers_gt[batch_0_idx][:,:] 45 | 46 |     plot0 = info['state_preds'][0] 47 |     plot20 = info['state_preds'][index] 48 |     plot20alt_gt_evl = info['state_preds_alt_gt_evl'][index] 49 |     plot20alt_gt_mesh_gt_evl = info['state_preds_alt_gt_mesh_gt_evl'][index] 50 |     plot20alt_008_gt_evl = info['state_preds_alt_008_gt_evl'][index] 51 |     plot20heuristic_gt_evl = info['state_preds_heuristic_gt_evl'][index] 52 | 53 | 54 |     from matplotlib.backends.backend_pdf import PdfPages 55 |     fig = 
plt.figure(figsize=(30,6)) 56 | n = 2 57 | ax0= fig.add_subplot(n,7,1,projection='3d') 58 | ax0.set_axis_off() 59 | ax0.set_xlim([-0.6, 0.6]) 60 | ax0.set_ylim([-0.6, 0.6]) 61 | ax0.set_zlim([-0.5, 0.1]) 62 | ax0.view_init(30, 10) 63 | ax0.plot_trisurf(plot0['history'][-1][:,3].detach().cpu().numpy(),plot0['history'][-1][:,4].detach().cpu().numpy(), plot0['xfaces'].detach().cpu().numpy().T, plot0['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1)) 64 | plt.title("plot0") 65 | 66 | ax0= fig.add_subplot(n,7,2,projection='3d') 67 | ax0.set_axis_off() 68 | ax0.set_xlim([-0.6, 0.6]) 69 | ax0.set_ylim([-0.6, 0.6]) 70 | ax0.set_zlim([-0.5, 0.1]) 71 | ax0.view_init(30, 10) 72 | ax0.plot_trisurf(plot20alt_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20alt_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20alt_gt_evl['xfaces'].detach().cpu().numpy().T, plot20alt_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1)) 73 | plt.title("plot20alt_gt_evl") 74 | 75 | ax0= fig.add_subplot(n,7,3,projection='3d') 76 | ax0.set_axis_off() 77 | ax0.set_xlim([-0.6, 0.6]) 78 | ax0.set_ylim([-0.6, 0.6]) 79 | ax0.set_zlim([-0.5, 0.1]) 80 | ax0.view_init(30, 10) 81 | ax0.plot_trisurf(plot20alt_gt_mesh_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20alt_gt_mesh_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20alt_gt_mesh_gt_evl['xfaces'].detach().cpu().numpy().T, plot20alt_gt_mesh_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1)) 82 | plt.title("plot20alt_gt_mesh_gt_evl") 83 | 84 | ax0= fig.add_subplot(n,7,4,projection='3d') 85 | ax0.set_axis_off() 86 | ax0.set_xlim([-0.6, 0.6]) 87 | ax0.set_ylim([-0.6, 0.6]) 88 | ax0.set_zlim([-0.5, 0.1]) 89 | ax0.view_init(30, 10) 90 | ax0.plot_trisurf(plot20alt_008_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20alt_008_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20alt_008_gt_evl['xfaces'].detach().cpu().numpy().T, plot20alt_008_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1)) 91 | plt.title("plot20alt_008_gt_evl") 92 | 93 | ax0= fig.add_subplot(n,7,5,projection='3d') 94 | ax0.set_axis_off() 95 | ax0.set_xlim([-0.6, 0.6]) 96 | ax0.set_ylim([-0.6, 0.6]) 97 | ax0.set_zlim([-0.5, 0.1]) 98 | ax0.view_init(30, 10) 99 | ax0.plot_trisurf(plot20heuristic_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20heuristic_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20heuristic_gt_evl['xfaces'].detach().cpu().numpy().T, plot20heuristic_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1)) 100 | plt.title("plot20heuristic_gt_evl") 101 | 102 | 103 | ax0= fig.add_subplot(n,7,6,projection='3d') 104 | ax0.set_axis_off() 105 | ax0.set_xlim([-0.6, 0.6]) 106 | ax0.set_ylim([-0.6, 0.6]) 107 | ax0.set_zlim([-0.5, 0.1]) 108 | ax0.view_init(30, 10) 109 | ax0.plot_trisurf(plot20['history'][-1][:,3].detach().cpu().numpy(),plot20['history'][-1][:,4].detach().cpu().numpy(), plot20['xfaces'].detach().cpu().numpy().T, plot20['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1)) 110 | plt.title("plot20_policy") 111 | 112 | 113 | ax0= fig.add_subplot(n,7,7,projection='3d') 114 | 
ax0.set_axis_off() 115 | ax0.set_xlim([-0.6, 0.6]) 116 | ax0.set_ylim([-0.6, 0.6]) 117 | ax0.set_zlim([-0.5, 0.1]) 118 | ax0.view_init(30, 10) 119 | ax0.plot_trisurf(vers_gt[:,3],vers_gt[:,4],faces_gt, vers_gt[:,5], shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1)) 120 | plt.title("plot20gt") 121 | 122 | 123 | 124 | ax0= fig.add_subplot(n,7,7+1) 125 | ax0.set_axis_off() 126 | ax0.triplot(plot0['history'][-1][:,0].detach().cpu().numpy(),plot0['history'][-1][:,1].detach().cpu().numpy(), plot0['xfaces'].detach().cpu().numpy().T) 127 | plt.title("plot0") 128 | 129 | ax0= fig.add_subplot(n,7,7+2) 130 | ax0.set_axis_off() 131 | ax0.triplot(plot20alt_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20alt_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20alt_gt_evl['xfaces'].detach().cpu().numpy().T,) 132 | plt.title("plot20alt_gt_evl") 133 | 134 | 135 | ax0= fig.add_subplot(n,7,7+3) 136 | ax0.set_axis_off() 137 | ax0.triplot(plot20alt_gt_mesh_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20alt_gt_mesh_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20alt_gt_mesh_gt_evl['xfaces'].detach().cpu().numpy().T,) 138 | plt.title("plot20alt_gt_mesh_gt_evl") 139 | 140 | ax0= fig.add_subplot(n,7,7+4) 141 | ax0.set_axis_off() 142 | ax0.triplot(plot20alt_008_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20alt_008_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20alt_008_gt_evl['xfaces'].detach().cpu().numpy().T,) 143 | plt.title("plot20alt_008_gt_evl") 144 | 145 | ax0= fig.add_subplot(n,7,7+5) 146 | ax0.set_axis_off() 147 | ax0.triplot(plot20heuristic_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20heuristic_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20heuristic_gt_evl['xfaces'].detach().cpu().numpy().T,) 148 | plt.title("plot20heuristic_gt_evl") 149 | 150 | 151 | ax0= fig.add_subplot(n,7,7+6) 152 | ax0.set_axis_off() 153 | ax0.triplot(plot20['history'][-1][:,0].detach().cpu().numpy(),plot20['history'][-1][:,1].detach().cpu().numpy(), plot20['xfaces'].detach().cpu().numpy().T,) 154 | plt.title("plot20_policy") 155 | 156 | 157 | ax0= fig.add_subplot(n,7,7+7) 158 | ax0.set_axis_off() 159 | ax0.triplot(vers_gt[:,0],vers_gt[:,1],faces_gt) 160 | plt.title("plot20gt") 161 | 162 | 163 | 164 | # Analysis: 165 | def get_results_2d( 166 |     all_hash, 167 |     mode="best", 168 |     exclude_idx=(None,), 169 |     dirname=None, 170 |     suffix="",device=None 171 | ): 172 |     """ 173 |     Perform analysis on the 2D cloth benchmark. 174 | 175 |     Args: 176 |         all_hash: a list of hashes indicating the experiments to load for analysis 177 |         mode: choose from "best" (load the best model with lowest validation loss) or an integer, 178 |             e.g. -1 (last saved model), -2 (second last saved model) 179 |         dirname: if not None, will use the dirname provided. E.g. tailin-1d_2022-7-27 180 |         suffix: suffix for saving the analysis result. 181 |     """ 182 | 183 |     isplot = True 184 |     df_dict_list = [] 185 |     dirname_start = "2d_rl_reproduce_2023-02-26" if dirname is None else dirname 186 |     for hash_str in all_hash: 187 |         df_dict = {} 188 |         df_dict["hash"] = hash_str 189 |         is_found = False 190 |         for dirname_core in [ 191 |             dirname_start, 192 |         ]: 193 |             filename = filter_filename(EXP_PATH + dirname_core, include=hash_str) 194 |             if len(filename) == 1: 195 |                 is_found = True 196 |                 break 197 |         if not is_found: 198 |             print(f"hash {hash_str} does not exist in {dirname}! 
Please pass in the correct dirname.") 199 |             continue 200 |         dirname = EXP_PATH + dirname_core 201 |         if not dirname.endswith("/"): 202 |             dirname += "/" 203 | 204 |         try: 205 |             data_record = pload(dirname + filename[0]) 206 |         except Exception as e: 207 |             print(f"error {e} in hash_str {hash_str}") 208 |             continue 209 | 210 |         args = init_args(update_legacy_default_hyperparam(data_record["args"])) 211 |         args.filename = filename 212 |         if not("processor_aggr" in data_record["model_dict"][-1]["actor_model_dict"].keys()): 213 |             data_record["model_dict"][-1]["actor_model_dict"]["processor_aggr"] = "max" 214 |             if "best_model_dict" in data_record.keys(): 215 |                 data_record["best_model_dict"]["actor_model_dict"]["processor_aggr"] = "max" 216 |         model = load_model(data_record["best_model_dict"], device=device) 217 |         evolution = load_model(data_record["best_evolution_model_dict"], device=device) 218 |         print("Load the model with best validation loss.") 219 |         model.eval() 220 |         evolution.eval() 221 | 222 |         # Load test dataset: 223 |         args_test = deepcopy(args) 224 |         args_test.dataset="arcsimmesh_square_annotated_coarse_minlen008_interp_500" 225 |         args_test.multi_step = "1^20" 226 |         args_test.n_train = "-1" 227 |         args_test.is_train=False 228 |         args_test.use_fineres_data=False 229 |         args_test.show_missing_files = False 230 |         args_test.input_steps=2 231 |         args_test.time_interval=2 232 |         args_test.val_batch_size = 1 233 |         args_test.device = device 234 |         args.device = device 235 |         (train_dataset, test_dataset), (train_loader, _, test_loader) = load_data(args_test) 236 |         print(len(test_loader)) 237 | 238 |         args_test_gt_008 = deepcopy(args) 239 |         args_test_gt_008.dataset="arcsimmesh_square_annotated_coarse_minlen008_interp_500_gt_500" 240 |         args_test_gt_008.n_train = "-1" 241 |         args_test_gt_008.multi_step = "1^20" 242 |         args_test_gt_008.is_train=False 243 |         args_test_gt_008.use_fineres_data=False 244 |         args_test_gt_008.show_missing_files = False 245 |         args_test_gt_008.input_steps=2 246 |         args_test_gt_008.time_interval=2 247 |         args_test_gt_008.val_batch_size = 1 248 |         args_test_gt_008.device = device 249 |         args.device = device 250 |         (_, _), (_, _, test_loader_008) = load_data(args_test_gt_008) 251 |         print(len(test_loader_008)) 252 | 253 |     args.noise_amp = 0 254 |     return model, evolution, test_loader, test_loader_008, args 255 | 256 | def get_eval(model, test_loader, test_loader_008,evolution_model,beta=0,break_index=2,hashva=None,name=None,best_evolution_model=None,args=None): 257 |     dicts = {} 258 |     min_evolution_loss_rmse_average = 1000 259 |     if True: 260 |         evolution_loss_mse = 0 261 |         evolution_loss_alt_gt_evl_mse = 0 262 |         evolution_loss_alt_gt_mesh_gt_evl_mse = 0 263 |         evolution_loss_alt_008_gt_evl_mse = 0 264 |         evolution_loss_heuristic_gt_evl_mse = 0 265 | 266 | 267 |         r_statediff = 0 268 |         v_statediff = 0 269 |         rewards = 0 270 |         count = 0 271 |         actual_count = 0 272 |         state_size = 0 273 |         state_size_remeshed = 0 274 |         args.pred_steps=20 275 | 276 |         list_evolution_loss_mse = [] 277 |         list_evolution_loss_mse_alt_gt_evl = [] 278 |         list_evolution_loss_mse_alt_gt_mesh_gt_evl = [] 279 |         list_evolution_loss_mse_alt_008_gt_evl = [] 280 |         list_evolution_loss_heuristic_gt_evl = [] 281 | 282 |         list_state_size_remeshed = [] 283 |         list_all_nodes = [] 284 |         list_all_nodes_remesh = [] 285 |         kwargs = {} 286 |         kwargs["evolution_model_alt"] = best_evolution_model 287 | 288 |         if not(os.path.exists("./results/LAMP_2d")): 289 |             os.mkdir("./results/LAMP_2d") 290 |         with torch.no_grad(): 291 |             for j, (data,data008) in 
enumerate(zip(test_loader,test_loader_008)): 292 |                 # if (j%91==0 and j>0) or j==0: 293 |                 if data.time_step in [10,30,50]: 294 |                     count += 1 295 |                     actual_count +=1 296 |                     if data.__class__.__name__ == "Attr_Dict": 297 |                         data = data.to(args.device) 298 |                         data008 = data008.to(args.device) 299 |                     else: 300 |                         data.to(args.device) 301 |                         data008.to(args.device) 302 |                     info, data_clone = model.get_loss( 303 |                         data, 304 |                         args, 305 |                         wandb=None, 306 |                         opt_evl=False, 307 |                         step_num=0, 308 |                         mode="test_gt_remesh_heursitc", 309 |                         beta=beta, 310 |                         is_gc_collect=False, 311 |                         evolution_model=evolution_model, 312 |                         data008 = data008, 313 |                         **kwargs 314 |                     ) 315 |                     state_size_elm = 0 316 |                     state_size_elm_list = [] 317 |                     for elem in data_clone.reind_yfeatures["n0"]: 318 |                         state_size_elm += elem.shape[0] 319 |                         state_size_elm_list.append(elem.shape[0]) 320 |                     evolution_loss_mse += (info['evolution/loss_mse'].sum().item()) 321 |                     state_size += state_size_elm 322 |                     evolution_loss_alt_gt_evl_mse += (info['evolution/loss_alt_gt_evl_mse'].sum().item()) 323 |                     evolution_loss_alt_gt_mesh_gt_evl_mse += (info['evolution/loss_alt_gt_mesh_gt_evl_mse'].sum().item()) 324 |                     evolution_loss_alt_008_gt_evl_mse += (info['evolution/loss_alt_008_gt_evl_mse'].sum().item()) 325 |                     evolution_loss_heuristic_gt_evl_mse += (info['evolution/loss_heuristic_gt_evl_mse'].sum().item()) 326 | 327 | 328 |                     state_size_remeshed += info['v/state_size'].sum().item() 329 | 330 |                     list_evolution_loss_mse.append(info['evolution/loss_mse']) 331 |                     list_evolution_loss_mse_alt_gt_evl.append(info['evolution/loss_alt_gt_evl_mse']) 332 |                     list_evolution_loss_mse_alt_gt_mesh_gt_evl.append(info['evolution/loss_alt_gt_mesh_gt_evl_mse']) 333 |                     list_evolution_loss_mse_alt_008_gt_evl.append(info['evolution/loss_alt_008_gt_evl_mse']) 334 |                     list_evolution_loss_heuristic_gt_evl.append(info['evolution/loss_heuristic_gt_evl_mse']) 335 | 336 |                     list_all_nodes.append(state_size_elm_list) 337 |                     list_all_nodes_remesh.append(info['v/state_size']) 338 | 339 |                     print(info['evolution/loss_mse'].shape[0]) 340 |                     print("index",actual_count,"loss", (info['evolution/loss_mse'].sum().item()/state_size_elm),"Gt state_size_elm",state_size_elm,"remeshed_state_size",info['v/state_size'].sum().item()) 341 |                     print((info['v/state_size'].reshape(-1))[::4]) 342 |                     print("current running rms alt gt evl:",np.sqrt(evolution_loss_alt_gt_evl_mse/state_size)) 343 |                     print("current running rms alt gt mesh gt evl :",np.sqrt(evolution_loss_alt_gt_mesh_gt_evl_mse/state_size)) 344 |                     print("current running rms alt 008 gt evl:",np.sqrt(evolution_loss_alt_008_gt_evl_mse/state_size)) 345 |                     print("current running rms heuristic gt evl:",np.sqrt(evolution_loss_heuristic_gt_evl_mse/state_size)) 346 |                     print("current running rms:",np.sqrt(evolution_loss_mse/state_size)) 347 | 348 | 349 | 350 |                     if not(os.path.exists("./results/LAMP_2d/{}".format(name))): 351 |                         os.mkdir("./results/LAMP_2d/{}".format(name)) 352 |                     if not(os.path.exists("./results/LAMP_2d/{}/{}".format(name,count))): 353 |                         os.mkdir("./results/LAMP_2d/{}/{}".format(name,count)) 354 |                     if not(os.path.exists("./results/LAMP_2d/{}/{}/{}".format(name,count,data.time_step))): 355 |                         os.mkdir("./results/LAMP_2d/{}/{}/{}".format(name,count,data.time_step)) 356 | 
np.save("./results/LAMP_2d/{}/{}/{}/{}_{}.npy".format(name,count,data.time_step,name,actual_count),{"loss_mse":info['evolution/loss_mse'].detach().cpu().numpy(),"loss_alt_gt_evl_mse":info['evolution/loss_alt_gt_evl_mse'].detach().cpu().numpy,"loss_alt_gt_mesh_gt_evl_mse":info['evolution/loss_alt_gt_mesh_gt_evl_mse'].detach().cpu().numpy(),"loss_alt_008_gt_evl_mse":info['evolution/loss_alt_008_gt_evl_mse'].detach().cpu().numpy(),"loss_heuristic_gt_evl_mse":info['evolution/loss_heuristic_gt_evl_mse'].detach().cpu().numpy(),"state_size_elm_list":state_size_elm_list,"state_size":info['v/state_size'].detach().cpu().numpy()}) 357 | np.save("./results/LAMP_2d/{}/{}/{}/{}_{}_all.npy".format(name,count,data.time_step,name,actual_count),{"info":info}) 358 | for index in range(5,20,2): 359 | plot_fig(info, data_clone,index=index) 360 | plt.savefig("./results/LAMP_2d/{}/{}/{}/{}_{}_{}.png".format(name,count,data.time_step,name,actual_count,index)) 361 | plt.close() 362 | torch.cuda.empty_cache() 363 | with open('./results/LAMP_2d/{}/summary.txt'.format(name), 'a') as f: 364 | f.write("current running rms alt gt evl: {}\n".format(np.sqrt(evolution_loss_alt_gt_evl_mse/state_size))) 365 | f.write("current running rms alt gt mesh gt evl : {}\n".format(np.sqrt(evolution_loss_alt_gt_mesh_gt_evl_mse/state_size))) 366 | f.write("current running rms alt 008 gt evl: {}\n".format(np.sqrt(evolution_loss_alt_008_gt_evl_mse/state_size))) 367 | f.write("current running rms heursitc gt evl: {}\n".format(np.sqrt(evolution_loss_heuristic_gt_evl_mse/state_size))) 368 | f.write("current running rms: {}\n".format(np.sqrt(evolution_loss_mse/state_size))) 369 | 370 | if count>=break_index: 371 | break 372 | # break 373 | evolution_loss_rmse_average = np.sqrt(evolution_loss_mse/state_size) 374 | evolution_loss_alt_gt_mesh_gt_evl_mse_average = np.sqrt(evolution_loss_alt_gt_mesh_gt_evl_mse/state_size) 375 | evolution_loss_alt_gt_evl_mse_average = np.sqrt(evolution_loss_alt_gt_evl_mse/state_size) 376 | evolution_loss_alt_008_gt_evl_mse_average = np.sqrt(evolution_loss_alt_008_gt_evl_mse/state_size) 377 | evolution_loss_heuristic_gt_evl_mse_average = np.sqrt(evolution_loss_heuristic_gt_evl_mse/state_size) 378 | 379 | 380 | state_size_remeshed = state_size_remeshed/actual_count 381 | print("count",count) 382 | print("{}, {}, evolution_loss_alt_gt_mesh_gt_evl_mse_average".format(hashva,beta),evolution_loss_alt_gt_mesh_gt_evl_mse_average) 383 | print("{}, {}, evolution_loss_alt_gt_evl_mse_average".format(hashva,beta),evolution_loss_alt_gt_evl_mse_average) 384 | print("{}, {}, evolution_loss_alt_008_gt_evl_mse_average".format(hashva,beta),evolution_loss_alt_008_gt_evl_mse_average) 385 | print("{}, {}, evolution_loss_heuristic_gt_evl_mse_average".format(hashva,beta),evolution_loss_heuristic_gt_evl_mse_average) 386 | print("{}, {}, evolution_loss_rmse_average".format(hashva,beta),evolution_loss_rmse_average, "state_size_remeshed", state_size_remeshed) 387 | with open('./results/LAMP_2d/{}/summary.txt'.format(name), 'a') as f: 388 | f.write("count {}".format(count)) 389 | f.write("{}, {}, evolution_loss_alt_gt_mesh_gt_evl_mse_average {}\n".format(hashva,beta,evolution_loss_alt_gt_mesh_gt_evl_mse_average)) 390 | f.write("{}, {}, evolution_loss_alt_gt_evl_mse_average {}\n".format(hashva,beta,evolution_loss_alt_gt_evl_mse_average)) 391 | f.write("{}, {}, evolution_loss_alt_008_gt_evl_mse_average {}\n".format(hashva,beta,evolution_loss_alt_008_gt_evl_mse_average)) 392 | f.write("{}, {}, evolution_loss_heuristic_gt_evl_mse_average 
{}\n".format(hashva,beta,evolution_loss_heuristic_gt_evl_mse_average)) 393 | f.write("{}, {}, evolution_loss_rmse_average {}\n".format(hashva,beta,evolution_loss_rmse_average)) 394 | 395 | dicts= [info, data_clone, evolution_loss_rmse_average,state_size_remeshed,list_evolution_loss_mse,list_evolution_loss_mse_alt,list_all_nodes,list_all_nodes_remesh] 396 | torch.cuda.empty_cache() 397 | return info, data_clone, dicts, min_evolution_loss_rmse_average 398 | 399 | 400 | def run_hash(all_hashes,dirname=None,evo_dirname="evo-2d_2023_02_18",evo_hash="9UQLIKKc_ampere1",gpu=0): 401 | 402 | p = Printer() 403 | 404 | all_hashes = [all_hashes] 405 | print(all_hashes) 406 | load_best_model = False 407 | device = "cuda:{}".format(gpu) 408 | beta = 0 409 | constrain_edge_size = True 410 | 411 | mode = "best" 412 | evo_dirname = os.path.join(EXP_PATH, evo_dirname) 413 | if not dirname.endswith("/"): 414 | evo_dirname += "/" 415 | all_dict = {} 416 | isplot = False 417 | seed_everything(42) 418 | filename = filter_filename(evo_dirname, include=evo_hash) 419 | print(filename) 420 | try: 421 | data_record = pload(evo_dirname + filename[0]) 422 | except Exception as e: 423 | print(f"error {e} in evo_hash {evo_hash}") 424 | p.print(f"Hash {evo_hash}, best model at epoch {data_record['best_epoch']}:", banner_size=160) 425 | args = init_args(update_legacy_default_hyperparam(data_record["args"])) 426 | args.filename = filename 427 | data_record['best_model_dict']['type'] = 'GNNRemesherPolicy' 428 | data_record['best_model_dict']['noise_amp'] = 0. 429 | data_record['best_model_dict']["correction_rate"] = 0. 430 | data_record['best_model_dict']["batch_size"] = 16 431 | best_gnn_evl_model = load_model(data_record["best_model_dict"], device=device) 432 | print("Load the model with best validation loss.") 433 | best_gnn_evl_model.eval() 434 | 435 | value = {} 436 | min_evolution_loss_rmse_average = 1000 437 | min_hash = None 438 | 439 | for hashva in all_hashes: 440 | seed_everything(42) 441 | model, evolution_model, test_loader, test_loader_008, args = get_results_2d([hashva], 442 | mode=-1, 443 | exclude_idx=(None,), 444 | dirname=dirname, 445 | suffix="",device=device) 446 | model.to(args.device) 447 | evolution_model.to(args.device) 448 | if load_best_model: 449 | evolution_model = best_gnn_evl_model.eval().to(args.device) 450 | if constrain_edge_size: 451 | model.actor.min_edge_size=max(0.04,model.actor.min_edge_size) 452 | model.eval() 453 | evolution_model.eval() 454 | 455 | name = "{}_{}_{}_{}_004".format(hashva,load_best_model,constrain_edge_size,beta) 456 | seed_everything(42) 457 | info, data, dicts, minval = get_eval(model, test_loader,test_loader_008, evolution_model,break_index=50*3,hashva=hashva,beta=beta,name=name,best_evolution_model=best_gnn_evl_model.eval(),args=args) 458 | 459 | if minval 0: 127 | p.print(i) 128 | p.print(i + 1) 129 | dataset = GraphDataset(graph_list, task='node', minimum_node_per_graph=1) 130 | dataset.n_simu = pyg_dataset.n_simu 131 | dataset.time_stamps = pyg_dataset.time_stamps 132 | return dataset 133 | 134 | train_val_fraction = 0.9 135 | train_fraction = args.train_fraction 136 | multi_step_dict = parse_multi_step(args.multi_step) 137 | max_pred_steps = max(list(multi_step_dict.keys()) + [1]) * args.temporal_bundle_steps 138 | filename_train_val = os.path.join(PDE_PATH, "deepsnap", "{}_train_val_in_{}_out_{}{}{}{}.p".format( 139 | args.dataset, args.input_steps * args.temporal_bundle_steps, max_pred_steps, 140 | "_itv_{}".format(args.time_interval) if args.time_interval > 
1 else "", 141 | "_yvar_{}".format(args.is_y_variable_length) if args.is_y_variable_length is True else "", 142 | "_noise_{}".format(args.data_noise_amp) if args.data_noise_amp > 0 else "", 143 | )) 144 | filename_test = os.path.join(PDE_PATH, "deepsnap", "{}_test_in_{}_out_{}{}{}{}.p".format( 145 | args.dataset, args.input_steps * args.temporal_bundle_steps, max_pred_steps, 146 | "_itv_{}".format(args.time_interval) if args.time_interval > 1 else "", 147 | "_yvar_{}".format(args.is_y_variable_length) if args.is_y_variable_length is True else "", 148 | "_noise_{}".format(args.data_noise_amp) if args.data_noise_amp > 0 else "", 149 | )) 150 | make_dir(filename_train_val) 151 | is_to_deepsnap = True 152 | if (os.path.isfile(filename_train_val) or args.is_test_only) and os.path.isfile(filename_test) and args.n_train == "-1" and not args.dataset.startswith("arcsimmesh"): 153 | if not args.is_test_only: 154 | p.print(f"Loading {filename_train_val}") 155 | loaded = pickle.load(open(filename_train_val, "rb")) 156 | if isinstance(loaded, tuple): 157 | dataset_train, dataset_val = loaded 158 | else: 159 | dataset_train_val = loaded 160 | p.print(f"Loading {filename_test}") 161 | dataset_test = pickle.load(open(filename_test, "rb")) 162 | p.print("Loaded pre-saved deepsnap file at {}.".format(filename_test)) 163 | 164 | else: 165 | p.print("{} does not exist. Generating...".format(filename_test)) 166 | is_save = True # If True, will save generated deepsnap dataset. 167 | 168 | if args.dataset.startswith("mppde1d"): 169 | if not args.is_test_only: 170 | pyg_dataset_train = MPPDE1D( 171 | dataset=args.dataset, 172 | input_steps=args.input_steps * args.temporal_bundle_steps, 173 | output_steps=max_pred_steps, 174 | time_interval=args.time_interval, 175 | is_y_diff=args.is_y_diff, 176 | split="train", 177 | ) 178 | pyg_dataset_val = MPPDE1D( 179 | dataset=args.dataset, 180 | input_steps=args.input_steps * args.temporal_bundle_steps, 181 | output_steps=max_pred_steps, 182 | time_interval=args.time_interval, 183 | is_y_diff=args.is_y_diff, 184 | split="valid", 185 | ) 186 | pyg_dataset_test = MPPDE1D( 187 | dataset=args.dataset, 188 | input_steps=args.input_steps * args.temporal_bundle_steps, 189 | output_steps=max_pred_steps, 190 | time_interval=args.time_interval, 191 | is_y_diff=args.is_y_diff, 192 | split="test", 193 | ) 194 | elif args.dataset.startswith("arcsimmesh"): 195 | if args.algo.startswith("rlgnnremesher"): 196 | max_pred_steps = args.rl_horizon 197 | if not args.is_test_only: 198 | pyg_dataset_train_val = ArcsimMesh( 199 | dataset=args.dataset, 200 | input_steps=args.input_steps * args.temporal_bundle_steps, 201 | output_steps=max_pred_steps, 202 | time_interval=args.time_interval, 203 | use_fineres_data=args.use_fineres_data, 204 | is_train=True 205 | #is_y_diff=args.is_y_diff, 206 | #split="train", 207 | ) 208 | pyg_dataset_test = ArcsimMesh( 209 | dataset=args.dataset, 210 | input_steps=args.input_steps * args.temporal_bundle_steps, 211 | output_steps=max_pred_steps, 212 | time_interval=args.time_interval, 213 | is_shifted_data=args.is_shifted_data, 214 | use_fineres_data=args.use_fineres_data, 215 | is_train=False, 216 | #is_y_diff=args.is_y_diff, 217 | #split="test", 218 | ) 219 | is_to_deepsnap = False 220 | else: 221 | raise 222 | 223 | if args.n_train != "-1": 224 | # Test overfitting: 225 | is_save = False 226 | if not args.is_test_only: 227 | if "pyg_dataset_train_val" in locals(): 228 | pyg_dataset_train_val = get_elements(pyg_dataset_train_val, args.n_train) 229 | else: 230 | 
pyg_dataset_train_val = get_elements(pyg_dataset_train, args.n_train) 231 | pyg_dataset_test = pyg_dataset_train_val 232 | p.print(", using the following elements {}.".format(args.n_train)) 233 | else: 234 | p.print(":") 235 | 236 | # Transform to deepsnap format: 237 | if is_to_deepsnap: 238 | if not args.is_test_only: 239 | if "pyg_dataset_train_val" in locals(): 240 | dataset_train_val = to_deepsnap(pyg_dataset_train_val, args) 241 | else: 242 | dataset_train = to_deepsnap(pyg_dataset_train, args) 243 | dataset_val = to_deepsnap(pyg_dataset_val, args) 244 | dataset_test = to_deepsnap(pyg_dataset_test, args) 245 | else: 246 | if not args.is_test_only: 247 | dataset_train_val = pyg_dataset_train_val 248 | dataset_test = pyg_dataset_test 249 | 250 | # Save pre-processed dataset into file: 251 | if is_save: 252 | if not args.is_test_only: 253 | if "pyg_dataset_train_val" in locals(): 254 | if not os.path.isfile(filename_train_val): 255 | pickle.dump(dataset_train_val, open(filename_train_val, "wb")) 256 | else: 257 | pickle.dump((dataset_train, dataset_val), open(filename_train_val, "wb")) 258 | try: 259 | pickle.dump(dataset_test, open(filename_test, "wb")) 260 | p.print("saved generated deepsnap dataset to {}".format(filename_test)) 261 | except Exception as e: 262 | p.print(f"Cannot save dataset object. Reason: {e}") 263 | 264 | 265 | # Split into train, val and test: 266 | collate_fn = deepsnap_Batch.collate() if is_to_deepsnap else MeshBatch(is_absorb_batch=True, is_collate_tuple=True).collate() 267 | if not args.is_test_only: 268 | if args.n_train == "-1": 269 | if "dataset_train" in locals() and "dataset_val" in locals(): 270 | dataset_train_val = (dataset_train, dataset_val) 271 | elif args.dataset_split_type == "standard": 272 | if args.dataset.startswith("VL") or args.dataset.startswith("PL") or args.dataset.startswith("PIL"): 273 | train_idx, val_idx = get_train_val_idx(len(dataset_train_val), chunk_size=200) 274 | dataset_train = dataset_train_val[train_idx] 275 | dataset_val = dataset_train_val[val_idx] 276 | else: 277 | num_train = int(len(dataset_train_val) * train_fraction) 278 | dataset_train, dataset_val = dataset_train_val[:num_train], dataset_train_val[num_train:] 279 | elif args.dataset_split_type == "random": 280 | train_idx, val_idx = get_train_val_idx_random(len(dataset_train_val), train_fraction=train_fraction) 281 | dataset_train, dataset_val = dataset_train_val[train_idx], dataset_train_val[val_idx] 282 | elif args.dataset_split_type == "order": 283 | n_train = int(len(dataset_train_val) * train_fraction) 284 | dataset_train, dataset_val = dataset_train_val[:n_train], dataset_train_val[n_train:] 285 | else: 286 | raise Exception("dataset_split_type '{}' is not valid!".format(args.dataset_split_type)) 287 | else: 288 | # train, val, test are all the same as designated by args.n_train: 289 | dataset_train = deepcopy(dataset_train_val) if is_to_deepsnap else dataset_train_val 290 | dataset_val = deepcopy(dataset_train_val) if is_to_deepsnap else dataset_train_val 291 | train_loader = DataLoader(dataset_train, num_workers=args.n_workers, collate_fn=collate_fn, 292 | batch_size=args.batch_size, shuffle=True if args.dataset_split_type!="order" else False, drop_last=True) 293 | val_loader = DataLoader(dataset_val, num_workers=args.n_workers, collate_fn=collate_fn, 294 | batch_size=args.val_batch_size if not args.algo.startswith("supn") else 1, shuffle=False, drop_last=False) 295 | else: 296 | dataset_train_val = None 297 | train_loader, val_loader = None, None 298 | 
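    # Only the test split always gets a DataLoader: in --is_test_only mode the
    # train/val datasets and loaders above are left as None. The collate_fn
    # chosen above pairs deepsnap graphs with deepsnap_Batch.collate() and
    # mesh data with MeshBatch(...).collate().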
test_loader = DataLoader(dataset_test, num_workers=args.n_workers, collate_fn=collate_fn, 299 | batch_size=args.val_batch_size if not args.algo.startswith("supn") else 1, shuffle=False, drop_last=False) 300 | return (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) 301 | 302 | class MeshBatch(object): 303 | def __init__(self, is_absorb_batch=False, is_collate_tuple=False): 304 | """ 305 | 306 | Args: 307 | is_collate_tuple: if True, will collate inside the tuple. 308 | """ 309 | self.is_absorb_batch = is_absorb_batch 310 | self.is_collate_tuple = is_collate_tuple 311 | 312 | def collate(self): 313 | import re 314 | if torch.__version__.startswith("1.9") or torch.__version__.startswith("1.10") or torch.__version__.startswith("1.11"): 315 | from torch._six import string_classes 316 | from collections import abc as container_abcs 317 | else: 318 | from torch._six import container_abcs, string_classes, int_classes 319 | from pstar import pdict, plist 320 | default_collate_err_msg_format = ( 321 | "collate_fn: batch must contain tensors, numpy arrays, numbers, " 322 | "dicts or lists; found {}") 323 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 324 | def default_convert(data): 325 | r"""Converts each NumPy array data field into a tensor""" 326 | elem_type = type(data) 327 | if isinstance(data, torch.Tensor): 328 | return data 329 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 330 | and elem_type.__name__ != 'string_': 331 | # array of string classes and object 332 | if elem_type.__name__ == 'ndarray' \ 333 | and np_str_obj_array_pattern.search(data.dtype.str) is not None: 334 | return data 335 | return torch.as_tensor(data) 336 | elif isinstance(data, container_abcs.Mapping): 337 | return {key: default_convert(data[key]) for key in data} 338 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple 339 | return elem_type(*(default_convert(d) for d in data)) 340 | elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes): 341 | return [default_convert(d) for d in data] 342 | else: 343 | return data 344 | 345 | def collate_fn(batch): 346 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate.""" 347 | # pdb.set_trace() 348 | elem = batch[0] 349 | elem_type = type(elem) 350 | if isinstance(elem, torch.Tensor): 351 | out = None 352 | if torch.utils.data.get_worker_info() is not None: 353 | # If we're in a background process, concatenate directly into a 354 | # shared memory tensor to avoid an extra copy 355 | numel = sum([x.numel() for x in batch]) 356 | storage = elem.storage()._new_shared(numel) 357 | out = elem.new(storage) 358 | tensor = torch.cat(batch, 0, out=out) 359 | if self.is_absorb_batch: 360 | # pdb.set_trace() 361 | if tensor.shape[1] == 0: 362 | tensor = tensor.view(tensor.shape[0], 0) 363 | else: 364 | tensor = tensor.view(-1, *tensor.shape[2:]) 365 | return tensor 366 | elif elem is None: 367 | return None 368 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 369 | and elem_type.__name__ != 'string_': 370 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 371 | # array of string classes and object 372 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 373 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 374 | return collate_fn([torch.as_tensor(b) for b in batch]) 375 | elif elem.shape == (): # scalars 376 | return torch.as_tensor(batch) 377 | 
elif isinstance(elem, float): 378 | return torch.tensor(batch, dtype=torch.float64) 379 | elif isinstance(elem, int): 380 | return torch.tensor(batch) 381 | elif isinstance(elem, string_classes): 382 | return batch 383 | elif isinstance(elem, container_abcs.Mapping): 384 | Dict = {} 385 | for key in elem: 386 | if key == "node_feature": 387 | Dict["vers"] = collate_trans_fn([d[key] for d in batch]) 388 | Dict[key] = collate_fn([d[key] for d in batch]) 389 | Dict["batch"] = {"n0": []} 390 | batch_nodes = [d[key]["n0"] for d in batch] 391 | for i in range(len(batch_nodes)): 392 | item = torch.full((batch_nodes[i].shape[0],), i, dtype=torch.long) 393 | Dict["batch"]["n0"].append(item) 394 | Dict["batch"]["n0"] = torch.cat(Dict["batch"]["n0"]) 395 | elif key in ["y_tar", "reind_yfeatures"]: 396 | # pdb.set_trace() 397 | Dict[key] = collate_ytar_trans_fn([d[key] for d in batch]) 398 | elif key in ["history"]: 399 | Dict[key] = collate_fn([d[key] for d in batch]) 400 | Dict["batch_history"] = {"n0": []} 401 | batch_nodes = [d[key]["n0"] for d in batch] 402 | for i in range(len(batch_nodes)): 403 | item = torch.full((batch_nodes[i][0].shape[0],), i, dtype=torch.long) 404 | Dict["batch_history"]["n0"].append(item) 405 | Dict["batch_history"]["n0"] = torch.cat(Dict["batch_history"]["n0"]) 406 | elif key == "edge_index": 407 | Dict[key] = collate_edgeidshift_fn([d[key] for d in batch]) 408 | elif key == "yedge_index": 409 | # pdb.set_trace() 410 | Dict[key] = collate_y_edgeidshift_fn([d[key] for d in batch]) 411 | elif key == "xfaces": 412 | Dict[key] = collate_xfaceshift_fn([d[key] for d in batch]) 413 | elif key in ["bary_indices", "hist_indices"]: 414 | #pdb.set_trace() 415 | Dict[key] = collate_bary_indices_fn([d[key] for d in batch]) 416 | elif key in ["yface_list", "xface_list"]: 417 | Dict[key] = collate_fn([d[key] for d in batch]) 418 | else: 419 | Dict[key] = collate_fn([d[key] for d in batch]) 420 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict): 421 | Dict = elem.__class__(**Dict) 422 | return Dict 423 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple: 424 | return elem_type(*(collate_fn(samples) for samples in zip(*batch))) 425 | elif isinstance(elem, My_Tuple): 426 | it = iter(batch) 427 | elem_size = len(next(it)) 428 | if not all(len(elem) == elem_size for elem in it): 429 | raise RuntimeError('each element in list of batch should be of equal size') 430 | transposed = zip(*batch) 431 | return elem.__class__([collate_fn(samples) for samples in transposed]) 432 | elif isinstance(elem, tuple): 433 | # pdb.set_trace() 434 | if self.is_collate_tuple: 435 | #pdb.set_trace() 436 | if len(elem) == 0: 437 | return batch[0] 438 | elif isinstance(elem[0], torch.Tensor): 439 | newbatch = () 440 | for i in range(len(elem)): 441 | newbatch = newbatch + tuple([torch.cat([tup[i] for tup in batch], dim=0)]) 442 | return newbatch 443 | elif type(elem[0]) == list: 444 | newbatch = () 445 | for i in range(len(elem)): 446 | cumsum = 0 447 | templist = [] 448 | for k in range(len(batch)): 449 | shiftbatch = np.array(batch[k][i]) + cumsum 450 | cumsum = shiftbatch.max() + 1 451 | templist.extend(shiftbatch.tolist()) 452 | newbatch = newbatch + (templist,) 453 | elif type(elem[0]).__module__ == np.__name__: 454 | newbatch = () 455 | for i in range(len(elem)): 456 | cumsum = 0 457 | templist = [] 458 | for k in range(len(batch)): 459 | shiftbatch = batch[k][i] + cumsum 460 | cumsum = shiftbatch.max() + 1 461 | templist.extend(shiftbatch.tolist()) 462 | newbatch = newbatch + 
(templist,) 463 | else: 464 | newbatch = batch[0] 465 | return newbatch 466 | else: 467 | return batch 468 | elif isinstance(elem, container_abcs.Sequence): 469 | # check to make sure that the elements in batch have consistent size 470 | it = iter(batch) 471 | elem_size = len(next(it)) 472 | if not all(len(elem) == elem_size for elem in it): 473 | raise RuntimeError('each element in list of batch should be of equal size') 474 | transposed = zip(*batch) 475 | return [collate_fn(samples) for samples in transposed] 476 | elif elem.__class__.__name__ == 'Dictionary': 477 | return batch 478 | elif elem.__class__.__name__ == 'DGLHeteroGraph': 479 | import dgl 480 | return dgl.batch(batch) 481 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 482 | 483 | def collate_bary_indices_fn(batch): 484 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate.""" 485 | # pdb.set_trace() 486 | elem = batch[0] 487 | elem_type = type(elem) 488 | if isinstance(elem, container_abcs.Mapping): 489 | Dict = {key: collate_bary_indices_fn([d[key] for d in batch]) for key in elem} 490 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict): 491 | Dict = elem.__class__(**Dict) 492 | return Dict 493 | elif isinstance(elem, tuple): 494 | # pdb.set_trace() 495 | if self.is_collate_tuple: 496 | #pdb.set_trace() 497 | if type(elem[0]).__module__ == np.__name__: 498 | newbatch = () 499 | for i in range(len(elem)): 500 | cumsum = 0 501 | templist = [] 502 | for k in range(len(batch)): 503 | shiftbatch = batch[k][i] + cumsum 504 | cumsum = shiftbatch.max() + 1 505 | templist.extend(shiftbatch) 506 | newbatch = newbatch + (templist,) 507 | return newbatch 508 | else: 509 | return batch 510 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 511 | 512 | def collate_y_edgeidshift_fn(batch): 513 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate.""" 514 | # pdb.set_trace() 515 | elem = batch[0] 516 | elem_type = type(elem) 517 | if isinstance(elem, torch.Tensor): 518 | out = None 519 | if torch.utils.data.get_worker_info() is not None: 520 | # If we're in a background process, concatenate directly into a 521 | # shared memory tensor to avoid an extra copy 522 | numel = sum([x.numel() for x in batch]) 523 | storage = elem.storage()._new_shared(numel) 524 | out = elem.new(storage) 525 | cumsum = 0 526 | newbatch = [] 527 | for i in range(len(batch)): 528 | shiftbatch = batch[i] + cumsum 529 | newbatch.append(shiftbatch) 530 | cumsum = shiftbatch.max().item() + 1 531 | tensor = torch.cat(newbatch, dim=1) 532 | return tensor 533 | elif isinstance(elem, container_abcs.Mapping): 534 | Dict = {key: collate_y_edgeidshift_fn([d[key] for d in batch]) for key in elem} 535 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict): 536 | Dict = elem.__class__(**Dict) 537 | return Dict 538 | elif isinstance(elem, tuple): 539 | if self.is_collate_tuple: 540 | if isinstance(elem[0], torch.Tensor): 541 | newbatch = () 542 | for i in range(len(elem)): 543 | cumsum = 0 544 | tempbatch = [] 545 | for tup in batch: 546 | shiftbatch = tup[i] + cumsum 547 | tempbatch.append(shiftbatch) 548 | cumsum = shiftbatch.max().item() + 1 549 | newbatch = newbatch + tuple([torch.cat(tempbatch, dim=-1)]) 550 | return newbatch 551 | else: 552 | return batch 553 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 554 | 555 | def collate_edgeidshift_fn(batch): 556 | r"""Puts each data field into 
a tensor with outer dimension batch size, adapted from PyTorch's default_collate.""" 557 | # pdb.set_trace() 558 | elem = batch[0] 559 | elem_type = type(elem) 560 | if isinstance(elem, torch.Tensor): 561 | out = None 562 | if torch.utils.data.get_worker_info() is not None: 563 | # If we're in a background process, concatenate directly into a 564 | # shared memory tensor to avoid an extra copy 565 | numel = sum([x.numel() for x in batch]) 566 | storage = elem.storage()._new_shared(numel) 567 | out = elem.new(storage) 568 | cumsum = 0 569 | newbatch = [] 570 | for i in range(len(batch)): 571 | shiftbatch = batch[i] + cumsum 572 | newbatch.append(shiftbatch) 573 | cumsum = shiftbatch.max().item() + 1 574 | tensor = torch.cat(newbatch, dim=1) 575 | return tensor 576 | elif isinstance(elem, container_abcs.Mapping): 577 | Dict = {key: collate_edgeidshift_fn([d[key] for d in batch]) for key in elem} 578 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict): 579 | Dict = elem.__class__(**Dict) 580 | return Dict 581 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 582 | 583 | def collate_xfaceshift_fn(batch): 584 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate.""" 585 | # pdb.set_trace() 586 | elem = batch[0] 587 | elem_type = type(elem) 588 | if isinstance(elem, torch.Tensor): 589 | out = None 590 | if torch.utils.data.get_worker_info() is not None: 591 | # If we're in a background process, concatenate directly into a 592 | # shared memory tensor to avoid an extra copy 593 | numel = sum([x.numel() for x in batch]) 594 | storage = elem.storage()._new_shared(numel) 595 | out = elem.new(storage) 596 | cumsum = 0 597 | newbatch = [] 598 | for i in range(len(batch)): 599 | shiftbatch = batch[i] + cumsum 600 | newbatch.append(shiftbatch) 601 | cumsum = shiftbatch.max().item() + 1 602 | tensor = torch.cat(newbatch, dim=-1) 603 | return tensor 604 | elif isinstance(elem, container_abcs.Mapping): 605 | Dict = {key: collate_xfaceshift_fn([d[key] for d in batch]) for key in elem} 606 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict): 607 | Dict = elem.__class__(**Dict) 608 | return Dict 609 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 610 | 611 | def collate_trans_fn(batch): 612 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate.""" 613 | # pdb.set_trace() 614 | elem = batch[0] 615 | elem_type = type(elem) 616 | if isinstance(elem, torch.Tensor): 617 | out = None 618 | if torch.utils.data.get_worker_info() is not None: 619 | # If we're in a background process, concatenate directly into a 620 | # shared memory tensor to avoid an extra copy 621 | numel = sum([x.numel() for x in batch]) 622 | storage = elem.storage()._new_shared(numel) 623 | out = elem.new(storage) 624 | batch = [batch[i] + 2*i for i in range(len(batch))] 625 | try: 626 | tensor = torch.cat(batch, 0, out=out) 627 | except: 628 | pdb.set_trace() 629 | if self.is_absorb_batch: 630 | # pdb.set_trace() 631 | # if tensor.shape[1] == 0: 632 | # tensor = tensor.view(tensor.shape[0]*tensor.shape[1], 0) 633 | # else: 634 | tensor = tensor.view(-1, *tensor.shape[2:]) 635 | return tensor 636 | elif isinstance(elem, container_abcs.Mapping): 637 | Dict = {key: collate_trans_fn([d[key] for d in batch]) for key in elem} 638 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict): 639 | Dict = elem.__class__(**Dict) 640 | return Dict 641 | raise 
TypeError(default_collate_err_msg_format.format(elem_type)) 642 | 643 | def collate_ytar_trans_fn(batch): 644 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate.""" 645 | # pdb.set_trace() 646 | elem = batch[0] 647 | elem_type = type(elem) 648 | if isinstance(elem, tuple): 649 | # pdb.set_trace() 650 | if self.is_collate_tuple: 651 | #pdb.set_trace() 652 | if len(elem) == 0: 653 | return batch[0] 654 | elif isinstance(elem[0], torch.Tensor): 655 | newbatch = () 656 | for i in range(len(elem)): 657 | templist = [batch[j][i] + 2*j for j in range(len(batch))] 658 | newbatch = newbatch + (torch.cat(templist, dim=0),) 659 | return newbatch 660 | elif isinstance(elem, container_abcs.Mapping): 661 | Dict = {key: collate_ytar_trans_fn([d[key] for d in batch]) for key in elem} 662 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict): 663 | Dict = elem.__class__(**Dict) 664 | return Dict 665 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 666 | return collate_fn 667 | -------------------------------------------------------------------------------- /datasets/mppde1d_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | import scipy.io 8 | import numpy as np 9 | import h5py 10 | import pickle 11 | import torch 12 | from torch_geometric.data import Dataset, Data 13 | import pdb 14 | import sys, os 15 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..')) 16 | from lamp.utils import MPPDE1D_PATH, PDE_PATH 17 | from lamp.pytorch_net.util import Attr_Dict, plot_matrices 18 | from lamp.utils import get_root_dir, to_tuple_shape 19 | from MP_Neural_PDE_Solvers.common.utils import HDF5Dataset 20 | from MP_Neural_PDE_Solvers.equations.PDEs import * 21 | 22 | 23 | # In[ ]: 24 | 25 | 26 | class MPPDE1D(Dataset): 27 | def __init__( 28 | self, 29 | dataset="mppde1d-E1-100", 30 | input_steps=1, 31 | output_steps=1, 32 | time_interval=1, 33 | is_y_diff=False, 34 | split="train", 35 | transform=None, 36 | pre_transform=None, 37 | verbose=False, 38 | ): 39 | assert dataset.startswith("mppde1d") 40 | self.dataset = dataset 41 | self.dirname = MPPDE1D_PATH 42 | self.root = PDE_PATH 43 | 44 | if len(dataset.split("-")) == 3: 45 | _, self.mode, self.nx = dataset.split("-") 46 | self.nt_total, self.nx_total = 250, 200 47 | else: 48 | assert len(dataset.split("-")) == 7 49 | _, self.mode, self.nx, _, self.nt_total, _, self.nx_total = dataset.split("-") 50 | self.nt_total, self.nx_total = int(self.nt_total), int(self.nx_total) 51 | self.nx = int(self.nx) 52 | self.input_steps = input_steps 53 | self.output_steps = output_steps 54 | self.time_interval = time_interval 55 | self.is_y_diff = is_y_diff 56 | self.split = split 57 | assert self.split in ["train", "valid", "test"] 58 | self.verbose = verbose 59 | 60 | self.t_cushion_input = self.input_steps * self.time_interval if self.input_steps * self.time_interval > 1 else 1 61 | self.t_cushion_output = self.output_steps * self.time_interval if self.output_steps * self.time_interval > 1 else 1 62 | 63 | self.original_shape = (self.nx,) 64 | self.dyn_dims = 1 # density 65 | 66 | pde=CE(device="cpu") 67 | if (self.nt_total, self.nx_total) == (250, 200): 68 | path = os.path.join(PDE_PATH, MPPDE1D_PATH) + f'{pde}_{self.split}_{self.mode}.h5' 69 | else: 70 | path = os.path.join(PDE_PATH, MPPDE1D_PATH) + 
f'{pde}_{self.split}_{self.mode}_nt_{self.nt_total}_nx_{self.nx_total}.h5'
71 |         print(f"Load dataset {path}")
72 |         self.time_stamps = self.nt_total
73 |         base_resolution=[self.nt_total, self.nx]
74 |         super_resolution=[self.nt_total, self.nx_total]
75 | 
76 |         self.dataset_cache = HDF5Dataset(path, pde=pde,
77 |                                          mode=self.split, base_resolution=base_resolution,
78 |                                          super_resolution=super_resolution)
79 |         self.n_simu = len(self.dataset_cache)
80 |         self.time_stamps_effective = (self.time_stamps - self.t_cushion_input - self.t_cushion_output + self.time_interval) // self.time_interval
81 |         super(MPPDE1D, self).__init__(self.root, transform, pre_transform)
82 | 
83 |     @property
84 |     def raw_file_names(self):
85 |         return []
86 | 
87 |     @property
88 |     def processed_dir(self):
89 |         return os.path.join(self.root, self.dirname)
90 | 
91 |     @property
92 |     def processed_file_names(self):
93 |         return []
94 | 
95 |     def download(self):
96 |         # Download to `self.raw_dir`.
97 |         pass
98 | 
99 |     def _process(self):
100 |         import warnings, re
101 |         from typing import Any, List
102 |         from torch_geometric.data.makedirs import makedirs
103 |         def _repr(obj: Any) -> str:
104 |             if obj is None:
105 |                 return 'None'
106 |             return re.sub('(<.*?)\\s.*(>)', r'\1\2', obj.__repr__())
107 | 
108 |         def files_exist(files: List[str]) -> bool:
109 |             # NOTE: We return `False` in case `files` is empty, leading to a
110 |             # re-processing of files on every instantiation.
111 |             return len(files) != 0 and all([os.path.exists(f) for f in files])
112 | 
113 |         f = os.path.join(self.processed_dir, 'pre_transform.pt')
114 |         if os.path.exists(f) and torch.load(f) != _repr(self.pre_transform):
115 |             warnings.warn(
116 |                 f"The `pre_transform` argument differs from the one used in "
117 |                 f"the pre-processed version of this dataset. If you want to "
118 |                 f"make use of another pre-processing technique, make sure to "
119 |                 f"delete '{self.processed_dir}' first")
120 | 
121 |         f = os.path.join(self.processed_dir, 'pre_filter.pt')
122 |         if os.path.exists(f) and torch.load(f) != _repr(self.pre_filter):
123 |             warnings.warn(
124 |                 "The `pre_filter` argument differs from the one used in the "
125 |                 "pre-processed version of this dataset. If you want to make "
126 |                 "use of another pre-filtering technique, make sure to delete "
127 |                 f"'{self.processed_dir}' first")
128 | 
129 |         if files_exist(self.processed_paths):  # pragma: no cover
130 |             return
131 | 
132 |         makedirs(self.processed_dir)
133 |         self.process()
134 | 
135 |         path = os.path.join(self.processed_dir, 'pre_transform.pt')
136 |         if not os.path.isfile(path):
137 |             torch.save(_repr(self.pre_transform), path)
138 |         path = os.path.join(self.processed_dir, 'pre_filter.pt')
139 |         if not os.path.isfile(path):
140 |             torch.save(_repr(self.pre_filter), path)
141 | 
142 |     def get_edge_index(self):
143 |         edge_index_filename = os.path.join(self.processed_dir, f"{self.dataset}_edge_index.p")
144 |         mask_valid_filename = os.path.join(self.root, self.dirname, f"{self.dataset}_mask_index.p")
145 |         if os.path.isfile(edge_index_filename) and os.path.isfile(mask_valid_filename):
146 |             edge_index = pickle.load(open(edge_index_filename, "rb"))
147 |             mask_valid = pickle.load(open(mask_valid_filename, "rb"))
148 |             return edge_index, mask_valid
149 |         mask_valid = torch.ones(self.original_shape).bool()
150 |         #velo_invalid_ids = np.where(velo_invalid_mask.flatten())[0]
151 |         rows, cols = (*self.original_shape, 1)
152 |         cube = np.arange(rows * cols).reshape(rows, cols)
153 |         edge_list = []
154 |         for i in range(rows):
155 |             for j in range(cols):
156 |                 if i + 1 < rows:  #and cube[i, j] not in velo_invalid_ids and cube[i+1, j] not in velo_invalid_ids:
157 |                     edge_list.append([cube[i, j], cube[i+1, j]])
158 |                     edge_list.append([cube[i+1, j], cube[i, j]])
159 |                 if j + 1 < cols:  #and cube[i, j]: #not in velo_invalid_ids and cube[i, j+1] not in velo_invalid_ids:
160 |                     edge_list.append([cube[i, j], cube[i, j+1]])
161 |                     edge_list.append([cube[i, j+1], cube[i, j]])
162 |         edge_index = torch.LongTensor(edge_list).T
163 |         pickle.dump(edge_index, open(edge_index_filename, "wb"))
164 |         pickle.dump(mask_valid, open(mask_valid_filename, "wb"))
165 |         return edge_index, mask_valid
166 | 
167 |     def process(self):
168 |         pass
169 | 
170 |     def len(self):
171 |         return self.time_stamps_effective * self.n_simu
172 | 
173 |     def get(self, idx):
174 |         # assert self.time_interval == 1
175 |         sim_id, time_id = divmod(idx, self.time_stamps_effective)
176 |         _, data_traj, x_pos, param = self.dataset_cache[sim_id]
177 |         if self.verbose:
178 |             print(f"sim_id: {sim_id} time_id: {time_id} input: ({time_id * self.time_interval + self.t_cushion_input -self.input_steps * self.time_interval}, {time_id * self.time_interval + self.t_cushion_input}) output: ({time_id * self.time_interval + self.t_cushion_input}, {time_id * self.time_interval + self.t_cushion_input + self.output_steps * self.time_interval})")
179 |         x_dens = torch.FloatTensor(np.stack([data_traj[time_id * self.time_interval + self.t_cushion_input + j] for j in range(-self.input_steps * self.time_interval, 0, self.time_interval)], -1))
180 |         y_dens = torch.FloatTensor(np.stack([data_traj[time_id * self.time_interval + self.t_cushion_input + j] for j in range(0, self.output_steps * self.time_interval, self.time_interval)], -1))  # [1, rows, cols, output_steps, 1]
181 |         edge_index, mask_valid = self.get_edge_index()
182 |         param = torch.cat([torch.FloatTensor([ele]) for key, ele in param.items()])
183 |         x_bdd = torch.ones(x_dens.shape[0])
184 |         x_bdd[0] = 0
185 |         x_bdd[-1] = 0
186 |         x_pos = torch.FloatTensor(x_pos)[...,None]
187 |         for dim in range(len(self.original_shape)):
188 |             x_pos[..., dim] /= self.original_shape[dim]
189 | 
190 |         data = Data(
191 |             x=x_dens.reshape(-1,
*x_dens.shape[-1:], 1).clone(), # [number_nodes: 64 * 64, input_steps, 1] 192 | x_pos=x_pos, # [number_nodes: 128 * 128, 2] 193 | x_bdd=x_bdd[...,None], 194 | xfaces=torch.tensor([]), 195 | y=y_dens.reshape(-1, *y_dens.shape[-1:], 1).clone(), # [number_nodes: 64 * 64, input_steps, 1] 196 | edge_index=edge_index, 197 | mask=mask_valid, 198 | param=param, 199 | original_shape=self.original_shape, 200 | dyn_dims=self.dyn_dims, 201 | compute_func=(0, None), 202 | dataset=self.dataset, 203 | ) 204 | # data = Attr_Dict( 205 | # node_feature={"n0": x_dens.reshape(-1, *x_dens.shape[-1:], 1).clone()}, 206 | # node_label={"n0": y_dens.reshape(-1, *y_dens.shape[-1:], 1).clone()}, 207 | # node_pos={"n0": x_pos}, 208 | # x_bdd={"n0": x_bdd[...,None]}, 209 | # xfaces={"n0": torch.tensor([])}, 210 | # edge_index={("n0","0","n0"): edge_index}, 211 | # mask={"n0": mask_valid}, 212 | # param=My_Freeze_Tuple((("n0", param),)), 213 | # original_shape=My_Freeze_Tuple((("n0", self.original_shape),)), 214 | # dyn_dims=My_Freeze_Tuple((("n0", self.dyn_dims),)), 215 | # compute_func=My_Freeze_Tuple((("n0", (0, None)),)), 216 | # grid_keys=("n0",), 217 | # part_keys=(), 218 | # time_step=time_id, 219 | # sim_id=sim_id, 220 | # dataset=self.dataset, 221 | # ) 222 | update_edge_attr_1d(data) 223 | return data 224 | 225 | 226 | def update_edge_attr_1d(data): 227 | dataset_str = to_tuple_shape(data.dataset) 228 | if dataset_str.split("-")[0] != "mppde1d": 229 | if hasattr(data, "node_feature"): 230 | edge_attr = data.node_pos["n0"][data.edge_index[("n0","0","n0")][0]] - data.node_pos["n0"][data.edge_index[("n0","0","n0")][1]] 231 | if dataset_str.split("-")[0] in ["mppde1de"]: 232 | data.edge_attr = {("n0","0","n0"): edge_attr} 233 | elif dataset_str.split("-")[0] in ["mppde1df", "mppde1dg", "mppde1dh"]: 234 | data.edge_attr = {("n0","0","n0"): torch.cat([edge_attr, edge_attr.abs()], -1)} 235 | else: 236 | raise 237 | if dataset_str.split("-")[0] in ["mppde1dg"]: 238 | data.x_bdd = {"n0": 1 - data.x_bdd["n0"]} 239 | else: 240 | edge_attr = data.x_pos[data.edge_index[0]] - data.x_pos[data.edge_index[1]] 241 | if dataset_str.split("-")[0] in ["mppde1de"]: 242 | data.edge_attr = edge_attr 243 | elif dataset_str.split("-")[0] in ["mppde1df", "mppde1dg", "mppde1dh"]: 244 | data.edge_attr = torch.cat([edge_attr, edge_attr.abs()], -1) 245 | else: 246 | raise 247 | if dataset_str.split("-")[0] in ["mppde1dg"]: 248 | data.x_bdd = 1 - data.x_bdd 249 | return data 250 | 251 | 252 | def get_data_pred(state_preds, step, data, **kwargs): 253 | """Get a new mppde1d Data from the state_preds at step "step" (e.g. 0,1,2....). 254 | Here we assume that the mesh does not change. 255 | 256 | Args: 257 | state_preds: has shape of [n_nodes, n_steps, feature_size:temporal_bundle_steps] 258 | data: a Deepsnap Data object 259 | **kwargs: keys for which need to use the value instead of data's. 
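
    Example (an illustrative sketch, not from the repo; `batch_data` stands for
    a collated deepsnap batch from load_data, and the shapes assume a 100-node
    1D field, 4 rollout steps, and temporal_bundle_steps=25):

        state_preds = torch.rand(100, 4, 25)   # [n_nodes, n_steps, bundle]
        data_pred = get_data_pred(state_preds, step=0, data=batch_data)
        # data_pred.node_feature["n0"] now has shape [100, 25, 1]; node
        # positions, edges and edge attributes are carried over from batch_data.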
260 | """ 261 | is_deepsnap = hasattr(data, "node_feature") 262 | is_list = isinstance(state_preds, list) 263 | if is_list: 264 | assert len(state_preds[0].shape) == 3 265 | state_pred = state_preds[step].reshape(state_preds[step].shape[0], -1, data.node_feature["n0"].shape[-1]) 266 | else: 267 | assert isinstance(state_preds, torch.Tensor) 268 | state_pred = state_preds[...,step,:].reshape(state_preds.shape[0], state_preds.shape[-1], 1) 269 | data_pred = Attr_Dict( 270 | node_feature={"n0": state_pred}, 271 | node_label={"n0": None}, 272 | node_pos={"n0": kwargs["x_pos"] if "x_pos" in kwargs else data.node_pos["n0"] if is_deepsnap else data.x_pos}, 273 | x_bdd={"n0": kwargs["x_bdd"] if "x_bdd" in kwargs else data.x_bdd["n0"] if is_deepsnap else data.x_bdd}, 274 | xfaces={"n0": torch.tensor([])}, 275 | edge_index={("n0","0","n0"): kwargs["edge_index"] if "edge_index" in kwargs else data.edge_index[("n0","0","n0")] if is_deepsnap else data.edge_index}, 276 | mask={"n0": data.mask["n0"] if is_deepsnap else data.mask}, 277 | param={"n0": data.param["n0"][:1] if is_deepsnap else data.param[:1]}, 278 | original_shape=to_tuple_shape(data.original_shape), 279 | dyn_dims=to_tuple_shape(data.dyn_dims), 280 | compute_func=to_tuple_shape(data.compute_func), 281 | grid_keys=("n0",), 282 | part_keys=(), 283 | dataset=to_tuple_shape(data.dataset), 284 | batch=kwargs["batch"] if "batch" in kwargs else data.batch, 285 | ) 286 | update_edge_attr_1d(data_pred) 287 | return data_pred 288 | 289 | 290 | # In[ ]: 291 | 292 | 293 | if __name__ == "__main__": 294 | dataset = MPPDE1D( 295 | dataset="mppde1dh-E2-100", 296 | input_steps=1, 297 | output_steps=1, 298 | time_interval=1, 299 | is_y_diff=False, 300 | split="valid", 301 | transform=None, 302 | pre_transform=None, 303 | verbose=False, 304 | ) 305 | # dataset2 = MPPDE1D( 306 | # dataset="mppde1de-E2-400-nt-1000-nx-400", 307 | # input_steps=1, 308 | # output_steps=1, 309 | # time_interval=5, 310 | # is_y_diff=False, 311 | # split="valid", 312 | # transform=None, 313 | # pre_transform=None, 314 | # verbose=True, 315 | # ) 316 | import matplotlib.pylab as plt 317 | plt.plot(dataset[198].x.squeeze()) 318 | 319 | -------------------------------------------------------------------------------- /license: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 snap-stanford 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | wandb
2 | deepsnap==0.2.1
3 | matplotlib==3.4.3
4 | numpy==1.23.5
5 | plotly
6 | tqdm
7 | scipy
8 | h5py
9 | pandas
10 | pyyaml
11 | numba
12 | pstar
13 | xarray
14 | 
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | # Results
2 | 
3 | This results/ folder stores the experiment results and model checkpoints. The files (with the ".p" suffix) are stored under "results/{--exp_id}_{--date_time}/", where --exp_id and --date_time, specified on the command line, give the experiment id and the date for this batch of experiments. Each ".p" file is a dictionary containing all necessary information, and can be loaded via `data_record = pickle.load(open({FILENAME.p}, "rb"))`. The `data_record` dictionary contains keys such as:
4 | 
5 | * "model_dict": a list of model checkpoints, saved every --inspect_interval epochs. A model can be loaded via `model = load_model(data_record["model_dict"][id])`, where id indicates which checkpoint to load.
6 | * "epoch": a list of integers indicating the epoch at which each model_dict was saved.
7 | * "train_loss" (or "val_loss", "test_loss"): a list containing the training (validation/test) loss at the corresponding epoch.
8 | * "last_model_dict": the last model_dict.
9 | * "last_optimizer_dict": the optimizer state, which can be used to resume training.
10 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | from deepsnap.batch import Batch as deepsnap_Batch
4 | import gc
5 | import numpy as np
6 | import pdb
7 | import pickle
8 | import pprint as pp
9 | import scipy
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 | from torch import optim
14 | import torch.multiprocessing
15 | torch.multiprocessing.set_sharing_strategy('file_system')
16 | import time
17 | 
18 | import sys, os
19 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
20 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
21 | from lamp.argparser import arg_parse
22 | from lamp.models import get_model, load_model, unittest_model, build_optimizer, test
23 | from lamp.gnns import GNNRemesher
24 | from lamp.datasets.load_dataset import load_data
25 | from lamp.pytorch_net.util import Attr_Dict, Batch, filter_filename, pload, pdump, Printer, get_time, init_args, update_args, clip_grad, set_seed, update_dict, filter_kwargs, plot_vectors, plot_matrices, make_dir, get_pdict, to_np_array, record_data, make_dir, Early_Stopping, str2bool, get_filename_short, print_banner, get_num_params, ddeepcopy as deepcopy, write_to_config
26 | from lamp.utils import EXP_PATH, MeshBatch
27 | from lamp.utils import p, update_legacy_default_hyperparam, get_grad_norm, loss_op_core, get_model_dict, get_elements, is_diagnose, get_keys_values, loss_op, to_tuple_shape, parse_multi_step, get_device, seed_everything
28 | 
29 | 
30 | # In[ ]:
31 | 
32 | def find_hash_and_load(all_hash,mode=-1,exclude_idx=(None,),dirname=None,suffix=""):
33 |     isplot = True
34 |     df_dict_list = []
35 |     dirname_start =
"tailin-rl_2022-9-22/" if dirname is None else dirname 36 | for hash_str in all_hash: 37 | df_dict = {} 38 | df_dict["hash"] = hash_str 39 | # Load model: 40 | is_found = False 41 | for dirname_core in [ 42 | dirname_start, 43 | "tailin-rl_2022-9-22/", 44 | "qq-rl_2022-11-5/", 45 | "qq-rl_2022-9-26/", 46 | "multiscale_cloth_2022-9-26/", 47 | ]: 48 | filename = filter_filename(EXP_PATH + dirname_core, include=hash_str) 49 | if len(filename) == 1: 50 | is_found = True 51 | break 52 | if not is_found: 53 | print(f"hash {hash_str} does not exist in {dirname}! Please pass in the correct dirname.") 54 | continue 55 | dirname = EXP_PATH + dirname_core 56 | if not dirname.endswith("/"): 57 | dirname += "/" 58 | 59 | try: 60 | data_record = pload(dirname + filename[0]) 61 | except Exception as e: 62 | # p.print(f"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:", banner_size=100) 63 | print(f"error {e} in hash_str {hash_str}") 64 | continue 65 | return data_record 66 | 67 | args = arg_parse() 68 | try: 69 | get_ipython().run_line_magic('matplotlib', 'inline') 70 | get_ipython().run_line_magic('load_ext', 'autoreload') 71 | get_ipython().run_line_magic('autoreload', '2') 72 | is_jupyter = True 73 | args.exp_id = "tailin-test" 74 | args.date_time = "8-27" 75 | # args.date_time = "{}-{}".format(datetime.datetime.now().month, datetime.datetime.now().day) 76 | 77 | # Train: 78 | args.epochs = 200 79 | args.contrastive_rel_coef = 0 80 | args.n_conv_blocks = 6 81 | args.latent_noise_amp = 1e-5 82 | args.multi_step = "1" 83 | args.latent_multi_step = "1^2^3^4" 84 | args.latent_loss_normalize_mode = "targetindi" 85 | args.channel_mode = "exp-16" 86 | args.batch_size = 20 87 | args.val_batch_size = 20 88 | args.reg_type = "None" 89 | args.reg_coef = 1e-4 90 | args.is_reg_anneal = True 91 | args.lr_scheduler_type = "cos" 92 | args.id = "test2" 93 | args.n_workers = 0 94 | args.plot_interval = 50 95 | args.temporal_bundle_steps = 1 96 | 97 | ################################## 98 | # RL algorithm: 99 | ################################## 100 | # args.algo chooses from "gnnremesher-evolution(+reward:32)", "rlgnnremesher^sizing", "rlgnnremesher^agent" 101 | args.algo = "rlgnnremesher^agent" 102 | if args.algo.startswith("rlgnnremesher"): 103 | args.rl_coefs = "None" 104 | args.rl_horizon = 4 105 | args.reward_mode = "lossdiff+statediff" 106 | args.reward_beta = "1" 107 | args.reward_src = "env" 108 | args.rl_lambda = 0.95 109 | args.rl_gamma = 0.99 110 | args.rl_rho = 1. 
111 | args.rl_eta = 1e-4 112 | args.rl_critic_update_iterations = 10 113 | args.rl_data_dropout = "node:0-0.3:0.5" 114 | 115 | args.value_latent_size = 32 116 | args.value_num_pool = 1 117 | args.value_act_name = "elu" 118 | args.value_act_name_final = "linear" 119 | args.value_layer_norm = False 120 | args.value_batch_norm = False 121 | args.value_num_steps = 3 122 | args.value_pooling_type = "global_mean_pool" 123 | args.value_target_mode = "value-lambda" 124 | 125 | args.load_dirname = "tailin-multi_2022-8-27" 126 | args.load_filename = "IHvBKQ8K_ampere3" 127 | 128 | ################################## 129 | # Dataset and model: 130 | ################################## 131 | # args.dataset = "mppde1df-E2-100" 132 | # args.dataset = "arcsimmesh_square" 133 | args.dataset = "arcsimmesh_square_annotated" 134 | if args.dataset.startswith("mppde1d"): 135 | args.latent_size = 64 136 | args.act_name = "elu" 137 | args.use_grads = False 138 | args.n_train = "-1" 139 | args.epochs = 2000 140 | args.use_pos = False 141 | args.latent_size = 64 142 | args.contrastive_rel_coef = 0 143 | args.is_prioritized_dropout = False 144 | args.input_steps = 1 145 | args.multi_step = "1^2:0.1^3:0.1^4:0.1" 146 | args.temporal_bundle_steps = 25 147 | args.n_train = ":100" 148 | args.epochs = 2000 149 | args.test_interval = 100 150 | args.save_interval = 100 151 | 152 | args.data_dropout = "node:0-0.1:0.1" 153 | args.use_pos = False 154 | args.rl_coefs = "reward:0.1" 155 | 156 | # Data: 157 | args.time_interval = 1 158 | args.dataset_split_type = "random" 159 | args.train_fraction = 1 160 | 161 | # Model: 162 | args.evolution_type = "mlp-3-elu-2" 163 | args.forward_type = "Euler" 164 | args.act_name = "elu" 165 | 166 | args.gpuid = "7" 167 | args.is_unittest = True 168 | 169 | elif args.dataset.startswith("arcsimmesh"): 170 | args.exp_id = "takashi-2dtest" 171 | args.date_time = "{}-{}".format(datetime.datetime.now().month, datetime.datetime.now().day) 172 | args.algo = "gnnremesher-evolution" 173 | args.encoder_type = "cnn-s" 174 | args.evo_conv_type = "cnn" 175 | args.decoder_type = "cnn-tr" 176 | args.padding_mode = "zeros" 177 | args.n_conv_layers_latent = 3 178 | args.n_conv_blocks = 4 179 | args.n_latent_levs = 1 180 | args.is_latent_flatten = True 181 | args.latent_size = 16 182 | args.act_name = "elu" 183 | args.decoder_act_name = "rational" 184 | args.use_grads = False 185 | args.n_train = "-1" 186 | args.use_pos = False 187 | args.contrastive_rel_coef = 0 188 | args.is_prioritized_dropout = False 189 | args.input_steps = 2 190 | args.multi_step = "1" 191 | args.latent_multi_step = "1" 192 | args.temporal_bundle_steps = 1 193 | # args.static_encoder_type = "param-2-elu" 194 | args.static_latent_size = 16 195 | args.n_train = ":100" 196 | args.epochs = 20 197 | args.test_interval = 10 198 | args.save_interval = 10 199 | 200 | #args.data_dropout = "node:0-0.4" 201 | args.use_pos = False 202 | args.load_filename = "IHvBKQ8K_ampere3" 203 | args.rl_coefs = "reward:0.1" 204 | 205 | # Data: 206 | args.time_interval = 1 207 | args.dataset_split_type = "random" 208 | args.train_fraction = 1 209 | 210 | # Model: 211 | args.evolution_type = "mlp-3-elu-2" 212 | args.forward_type = "Euler" 213 | args.act_name = "elu" 214 | args.is_mesh = True 215 | args.edge_attr=True 216 | # args.edge_threshold=0.000001 217 | args.edge_threshold=0. 
218 | 219 | args.gpuid = "3" 220 | args.is_unittest = True 221 | 222 | except: 223 | is_jupyter = False 224 | 225 | if args.dataset.startswith("mppde1d"): 226 | if args.dataset.endswith("-40"): 227 | args.output_padding_str = "0-0-0-0" 228 | elif args.dataset.endswith("-50"): 229 | args.output_padding_str = "1-0-1-0" 230 | elif args.dataset.endswith("-100"): 231 | args.output_padding_str = "1-1-0-0" 232 | 233 | 234 | # # 2. Load data and model: 235 | 236 | # In[ ]: 237 | 238 | 239 | set_seed(args.seed) 240 | (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args) 241 | p.print(f"Minibatches for train: {len(train_loader)}") 242 | p.print(f"Minibatches for val: {len(val_loader)}") 243 | p.print(f"Minibatches for test: {len(test_loader)}") 244 | args_test8 = deepcopy(args) 245 | if args.dataset.startswith("m"): 246 | args_test8.multi_step = "1^8" 247 | args_test8.pred_steps = 8 248 | args_test8.is_train = False 249 | else: 250 | args_test8.multi_step = "1^75" 251 | args_test8.pred_steps = 75 252 | args_test8.is_train = False 253 | args_test8.batch_size = 1 254 | args_test8.val_batch_size = 1 255 | # seed_everything(42) 256 | (_, dataset_test8), (_, val_loader8, test_loader8) = load_data(args_test8) 257 | 258 | p.print(f"Minibatches for val8: {len(val_loader8)}") 259 | p.print(f"Minibatches for test8: {len(test_loader8)}") 260 | val_loader8 = None 261 | device = get_device(args) 262 | args.device = device 263 | args_test8.device = device 264 | 265 | data = deepcopy(dataset_test[0]) 266 | epoch = 0 267 | 268 | if args.rl_is_finetune_evolution and args.dataset.startswith("a"): 269 | args_testfine = deepcopy(args) 270 | args_testfine.use_fineres_data = True 271 | (dataset_train_val_fine,_), (train_loader_fine, _, _) = load_data(args_testfine) 272 | p.print(f"Minibatches for trainfine: {len(train_loader_fine)}") 273 | else: 274 | args_testfine=None 275 | 276 | model = get_model(args, data, device) 277 | if args.algo.startswith("rlgnnremesher") or args.algo.startswith("srlgnnremesher"): 278 | device = get_device(args) 279 | loaded_dirname = EXP_PATH + args.load_dirname 280 | filenames = filter_filename(loaded_dirname, include=[args.load_filename]) 281 | assert len(filenames) == 1, f"There are {len(filenames)} files under ./results/{args.load_dirname} that contain the str {args.load_filename}. Re-check the argument of --load_dirname and --load_filename." 
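    # filter_filename returns every file under ./results/{--load_dirname} whose
    # name contains --load_filename, so the assert above rejects missing or
    # ambiguous matches before the pretrained evolution model is deserialized.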
282 | loaded_filename = os.path.join(loaded_dirname, filenames[0]) 283 | data_record_load = pload(loaded_filename) 284 | 285 | args_load = init_args(update_legacy_default_hyperparam(data_record_load["args"])) 286 | args_load.multi_step = args.multi_step 287 | evolution_model = load_model(data_record_load["model_dict"][-1], device) 288 | if args.fix_alt_evolution_model: 289 | evolution_model_alt = load_model(data_record_load["model_dict"][-1], device) 290 | evolution_model_alt.to(device) 291 | evolution_model_alt.eval() 292 | else: 293 | evolution_model_alt = None 294 | if not args.rl_is_finetune_evolution: 295 | evolution_model.eval() 296 | 297 | 298 | if args.load_hash!="None": 299 | data_record_load = find_hash_and_load([args.load_hash]) 300 | model.actor.load_state_dict(data_record_load["model_dict"][-1]["actor_model_dict"]["state_dict"]) 301 | model.critic.load_state_dict(data_record_load["model_dict"][-1]["critic_model_dict"]["state_dict"]) 302 | model.critic_target.load_state_dict(data_record_load["model_dict"][-1]["critic_model_dict"]["state_dict"]) 303 | # model = load_model(data_record_load["model_dict"][-1],device=device) 304 | evolution_model = load_model(data_record_load["evolution_model_dict"][-1],device=device) 305 | if not args.rl_is_finetune_evolution: 306 | evolution_model.eval() 307 | 308 | # # 3. Training: 309 | 310 | # In[ ]: 311 | 312 | 313 | if args.algo.startswith("rlgnnremesher"): 314 | separate_params = [ 315 | {'params': model.actor.parameters(), 'lr': args.actor_lr}, 316 | {'params': model.critic.parameters(), 'lr': args.value_lr}, 317 | ] 318 | opt, scheduler = build_optimizer( 319 | args, params=None, 320 | separate_params=separate_params, 321 | ) 322 | elif args.algo.startswith("srlgnnremesher"): 323 | separate_params = [ {'params': model.parameters(), 'lr': args.actor_lr},] 324 | opt, scheduler = build_optimizer(args, params=None,separate_params=separate_params,) 325 | else: 326 | opt, scheduler = build_optimizer(args, model.parameters()) 327 | 328 | if args.rl_is_finetune_evolution: 329 | opt_params = [{'params': evolution_model.parameters(), 'lr': args.lr}] 330 | opt_evolution, opt_scheduler = build_optimizer( 331 | args, params=None, 332 | separate_params=opt_params, 333 | ) 334 | 335 | n_params_model = get_num_params(model) 336 | p.print("n_params_model: {}".format(n_params_model), end="") 337 | machine_name = os.uname()[1].split('.')[0] 338 | data_record = {"n_params_model": n_params_model, "args": update_dict(args.__dict__, "machine_name", machine_name), 339 | "best_train_model_dict": [], "best_train_loss": [], "best_train_loss_history":[]} 340 | early_stopping = Early_Stopping(patience=args.early_stopping_patience) 341 | 342 | short_str_dict = { 343 | "dataset": "", 344 | "n_train": "train", 345 | "algo": "algo", 346 | "act_name": "act", 347 | "latent_size": "hid", 348 | "multi_step": "mt", 349 | "temporal_bundle_steps": "tb", 350 | "loss_type": "lo", 351 | "gpuid": "gpu", 352 | "id": "id", 353 | } 354 | 355 | filename_short = get_filename_short( 356 | short_str_dict.keys(), 357 | short_str_dict, 358 | args_dict=args.__dict__, 359 | ) 360 | filename = EXP_PATH + "{}_{}/".format(args.exp_id, args.date_time) + filename_short[:-2] + "_{}.p".format(machine_name) 361 | write_to_config(args, filename) 362 | args.filename = filename 363 | kwargs = {} 364 | if args.algo.startswith("rlgnnremesher") or args.algo.startswith("srlgnnremesher"): 365 | kwargs["evolution_model"] = evolution_model 366 | kwargs["evolution_model_alt"] = evolution_model_alt 367 | 
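# kwargs carries the evolution model (and, when --fix_alt_evolution_model is
# set, a frozen copy of it) into the unittest and test calls below, so the
# remeshing policy is always evaluated against the same learned evolution model.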
p.print(filename, banner_size=100)
368 | # print(model)
369 | make_dir(filename)
370 | best_val_loss = np.Inf
371 | if args.load_filename != "None":
372 |     val_loss = np.Inf
373 | collate_fn = deepsnap_Batch.collate() if data.__class__.__name__ == "HeteroGraph" else MeshBatch(
374 |     is_absorb_batch=True, is_collate_tuple=True).collate() if args.dataset.startswith("arcsimmesh") else Batch(
375 |     is_absorb_batch=True, is_collate_tuple=True).collate()
376 | if args.is_unittest:
377 |     unittest_model(model,
378 |                    collate_fn([data, data]), args, device, use_grads=args.use_grads, use_pos=args.use_pos, is_mesh=args.is_mesh,
379 |                    test_cases="all" if not (args.dataset.startswith("PIL") or args.dataset.startswith("PHIL")) else "model_dict", algo=args.algo,
380 |                    **kwargs
381 |     )
382 | if args.is_tensorboard:
383 |     from tensorboardX import SummaryWriter
384 |     writer = SummaryWriter(EXP_PATH + '/log/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
385 | pp.pprint(args.__dict__)
386 | 
387 | if args.wandb:
388 |     import wandb
389 |     args.wandb_project_name = "_".join([args.wandb_project_name, args.algo])
390 |     wandb.init(project=args.wandb_project_name, entity="multi-scale", name="_".join(["beta_",args.reward_beta]+filename.split("_")[-2:])[:-2] + f'_ntrain_{args.n_train}{"_" + args.id if args.id != "0" else ""}',
391 |                config={"name": args.exp_id})
392 |     wandb.watch(model, log_freq=1000, log="all")
393 |     wandb.watch(evolution_model, log_freq=1000, log="all")
394 |     wandb.config=vars(args)
395 | else:
396 |     wandb = None
397 | step_num = 0
398 | opt_actor = None
399 | opt_evl = None
400 | 
401 | # if (args.algo.startswith("rlgnnremesher") or args.algo.startswith("srlgnnremesher")) and args.wandb:
402 | #     model.get_tested(
403 | #         test_loader8,
404 | #         args_test8,
405 | #         current_epoch=0,
406 | #         current_minibatch=0,
407 | #         wandb=wandb,
408 | #         step_num=step_num,
409 | #         **kwargs
410 | #     )
411 | 
412 | if args.load_hash!="None":
413 |     if "last_optimizer_dict" in data_record_load.keys():
414 |         opt.load_state_dict(data_record_load["last_optimizer_dict"])
415 |     if "last_evolution_optimizer_dict" in data_record_load.keys():
416 |         opt_evolution.load_state_dict(data_record_load["last_evolution_optimizer_dict"])
417 |     if "last_scheduler_dict" in data_record_load.keys():
418 |         opt_scheduler.load_state_dict(data_record_load["last_scheduler_dict"])
419 | 
420 | 
421 | 
422 | while epoch < args.epochs:
423 |     total_loss = 0
424 |     count = 0
425 | 
426 |     model.train()
427 |     train_info = {}
428 |     best_train_loss = np.Inf
429 |     last_few_losses = []
430 |     num_losses = 20
431 |     t_start = time.time()
432 | 
433 |     if args.rl_is_finetune_evolution and args.dataset.startswith("a"):
434 |         train_loader_fine_iterator = iter(train_loader_fine)
435 |     for j, data in enumerate(train_loader):
436 |         if args.rl_is_finetune_evolution and args.dataset.startswith("a"):
437 |             try:
438 |                 data_fine = next(train_loader_fine_iterator)
439 |             except StopIteration:
440 |                 train_loader_fine_iterator = iter(train_loader_fine)
441 |                 data_fine = next(train_loader_fine_iterator)
442 |         else:
443 |             data_fine=None
444 | 
445 |         t_end = time.time()
446 |         if args.verbose >= 2 and j % 100 == 0:
447 |             p.print(f"Data loading time: {t_end - t_start:.6f}")
448 |         if data.__class__.__name__ == "Attr_Dict":
449 |             data = data.to(device)
450 |             if args.rl_is_finetune_evolution and args.dataset.startswith("a"): data_fine = data_fine.to(device)
451 |         else:
452 |             data.to(device)
453 |             if args.rl_is_finetune_evolution and args.dataset.startswith("a"): data_fine.to(device)
454 | 
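        # Update schedule: with --actor_critic_step left at None, both the
        # actor/critic optimizer and the evolution-model optimizer are enabled
        # on every minibatch; otherwise step_num cycles minibatches between
        # actor/critic updates and evolution fine-tuning steps below.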
opt.zero_grad() 455 | if args.rl_is_finetune_evolution: 456 | opt_evolution.zero_grad() 457 | if args.actor_critic_step==None: 458 | opt_actor=True 459 | opt_evl = True 460 | else: 461 | if step_num%(args.actor_critic_step+args.evolution_steps)= 0: 625 | p.print(filename) 626 | record_data(data_record, [epoch, get_model_dict(model)], ["save_epoch", "model_dict"]) 627 | if "evolution_model" in locals(): 628 | record_data(data_record, [evolution_model.model_dict], ["evolution_model_dict"]) 629 | with open(filename, "wb") as f: 630 | pickle.dump(data_record, f) 631 | if val_loss < best_val_loss: 632 | best_val_loss = val_loss 633 | data_record["best_model_dict"] = get_model_dict(model) 634 | data_record["best_optimizer_dict"] = opt.state_dict() 635 | data_record["best_scheduler_dict"] = scheduler.state_dict() if scheduler is not None else None 636 | if "evolution_model" in locals(): 637 | data_record["best_evolution_model_dict"] = evolution_model.model_dict 638 | if args.rl_is_finetune_evolution: 639 | data_record["best_evolution_optimizer_dict"] = opt_evolution.state_dict() 640 | data_record["best_epoch"] = epoch 641 | data_record["last_model_dict"] = get_model_dict(model) 642 | data_record["last_optimizer_dict"] = opt.state_dict() 643 | data_record["last_scheduler_dict"] = scheduler.state_dict() if scheduler is not None else None 644 | data_record["last_epoch"] = epoch 645 | if "evolution_model" in locals(): 646 | data_record["last_evolution_model_dict"] = evolution_model.model_dict 647 | if args.rl_is_finetune_evolution: 648 | data_record["last_evolution_optimizer_dict"] = opt_evolution.state_dict() 649 | p.print("12", precision="millisecond", is_silent=args.is_timing<1, avg_window=1) 650 | 651 | pdump(data_record, filename) 652 | if "to_stop" in locals() and to_stop: 653 | p.print("Early-stop at epoch {}.".format(epoch)) 654 | break 655 | epoch += 1 656 | record_data(data_record, [epoch, get_model_dict(model)], ["save_epoch", "model_dict"]) 657 | if "evolution_model" in locals(): 658 | record_data(data_record, [evolution_model.model_dict], ["evolution_model_dict"]) 659 | pdump(data_record, filename) 660 | 661 | -------------------------------------------------------------------------------- /utils_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[ ]: 5 | 6 | 7 | import torch 8 | import math 9 | import numpy as np 10 | from torch import nn 11 | # from torchmeta.modules import MetaModule 12 | from collections import OrderedDict 13 | import copy 14 | import torch.nn.functional as F 15 | 16 | 17 | def get_conv_func(pos_dim, *args, **kwargs): 18 | if "reg_type_list" in kwargs: 19 | reg_type_list = kwargs.pop("reg_type_list") 20 | else: 21 | reg_type_list = None 22 | if pos_dim == 1: 23 | conv = nn.Conv1d(*args, **kwargs) 24 | elif pos_dim == 2: 25 | conv = nn.Conv2d(*args, **kwargs) 26 | elif pos_dim == 3: 27 | conv = nn.Conv3d(*args, **kwargs) 28 | else: 29 | raise Exception("The pos_dim can only be 1, 2 or 3!") 30 | if reg_type_list is not None: 31 | if "snn" in reg_type_list: 32 | conv = SpectralNorm(conv) 33 | elif "snr" in reg_type_list: 34 | conv = SpectralNormReg(conv) 35 | return conv 36 | 37 | 38 | def get_conv_trans_func(pos_dim, *args, **kwargs): 39 | if "reg_type_list" in kwargs: 40 | reg_type_list = kwargs.pop("reg_type_list") 41 | else: 42 | reg_type_list = None 43 | if pos_dim == 1: 44 | conv_trans = nn.ConvTranspose1d(*args, **kwargs) 45 | elif pos_dim == 2: 46 | conv_trans = 
nn.ConvTranspose2d(*args, **kwargs) 47 | elif pos_dim == 3: 48 | conv_trans = nn.ConvTranspose3d(*args, **kwargs) 49 | else: 50 | raise Exception("The pos_dim can only be 1, 2 or 3!") 51 | # The weight's output dim=1 for ConvTranspose 52 | if reg_type_list is not None: 53 | if "snn" in reg_type_list: 54 | conv_trans = SpectralNorm(conv_trans, dim=1) 55 | elif "snr" in reg_type_list: 56 | conv_trans = SpectralNormReg(conv_trans, dim=1) 57 | return conv_trans 58 | 59 | 60 | # ### Spectral Norm: 61 | 62 | # In[ ]: 63 | 64 | 65 | def l2normalize(v, eps=1e-12): 66 | return v / (v.norm() + eps) 67 | 68 | 69 | class SpectralNorm(nn.Module): 70 | def __init__(self, module, name='weight', power_iterations=1, dim=0): 71 | super(SpectralNorm, self).__init__() 72 | self.module = module 73 | self.name = name 74 | self.power_iterations = power_iterations 75 | self.dim = dim 76 | if not self._made_params(): 77 | self._make_params() 78 | 79 | def reshape_weight_to_matrix(self, weight): 80 | weight_mat = weight 81 | if self.dim != 0: 82 | # permute dim to front 83 | weight_mat = weight_mat.permute(self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]) 84 | height = weight_mat.size(0) 85 | return weight_mat.reshape(height, -1) 86 | 87 | def _update_u_v(self): 88 | u = getattr(self.module, self.name + "_u") 89 | v = getattr(self.module, self.name + "_v") 90 | w = getattr(self.module, self.name + "_bar") 91 | w_mat = self.reshape_weight_to_matrix(w) 92 | 93 | height = w_mat.shape[0] 94 | for _ in range(self.power_iterations): 95 | v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data)) 96 | u.data = l2normalize(torch.mv(w_mat.data, v.data)) 97 | 98 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 99 | sigma = u.dot(w_mat.mv(v)) 100 | setattr(self.module, self.name, w / sigma.expand_as(w)) 101 | 102 | def _made_params(self): 103 | try: 104 | u = getattr(self.module, self.name + "_u") 105 | v = getattr(self.module, self.name + "_v") 106 | w = getattr(self.module, self.name + "_bar") 107 | return True 108 | except AttributeError: 109 | return False 110 | 111 | def _make_params(self): 112 | w = getattr(self.module, self.name) 113 | w_mat = self.reshape_weight_to_matrix(w) 114 | 115 | height = w_mat.shape[0] 116 | width = w_mat.shape[1] 117 | 118 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 119 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 120 | u.data = l2normalize(u.data) 121 | v.data = l2normalize(v.data) 122 | w_bar = nn.Parameter(w.data) 123 | 124 | del self.module._parameters[self.name] 125 | 126 | self.module.register_parameter(self.name + "_u", u) 127 | self.module.register_parameter(self.name + "_v", v) 128 | self.module.register_parameter(self.name + "_bar", w_bar) 129 | 130 | def forward(self, *args): 131 | if self.training: 132 | self._update_u_v() 133 | else: 134 | setattr(self.module, self.name, getattr(self.module, self.name + "_bar") / 1) 135 | return self.module.forward(*args) 136 | 137 | 138 | # ### SpectralNormReg: 139 | 140 | # In[ ]: 141 | 142 | 143 | class SpectralNormReg(nn.Module): 144 | def __init__(self, module, name='weight', power_iterations=1, dim=0): 145 | super(SpectralNormReg, self).__init__() 146 | self.module = module 147 | self.name = name 148 | self.power_iterations = power_iterations 149 | self.dim = dim 150 | if not self._made_params(): 151 | self._make_params() 152 | 153 | def reshape_weight_to_matrix(self, weight): 154 | weight_mat = weight 155 | if self.dim != 0: 156 | # permute 
dim to front 157 | weight_mat = weight_mat.permute(self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]) 158 | height = weight_mat.size(0) 159 | return weight_mat.reshape(height, -1) 160 | 161 | def compute_snreg(self): 162 | u = getattr(self.module, self.name + "_u") 163 | v = getattr(self.module, self.name + "_v") 164 | w = getattr(self.module, self.name + "_bar") 165 | w_mat = self.reshape_weight_to_matrix(w) 166 | 167 | height = w_mat.shape[0] 168 | for _ in range(self.power_iterations): 169 | v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data)) 170 | u.data = l2normalize(torch.mv(w_mat.data, v.data)) 171 | 172 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 173 | sigma = u.dot(w_mat.mv(v)) 174 | self.snreg = sigma.square() / 2 175 | setattr(self.module, self.name, w / 1) # Here the " / 1" is to prevent state_dict() to record self.module.weight 176 | 177 | def _made_params(self): 178 | try: 179 | u = getattr(self.module, self.name + "_u") 180 | v = getattr(self.module, self.name + "_v") 181 | w = getattr(self.module, self.name + "_bar") 182 | return True 183 | except AttributeError: 184 | return False 185 | 186 | def _make_params(self): 187 | w = getattr(self.module, self.name) 188 | w_mat = self.reshape_weight_to_matrix(w) 189 | 190 | height = w_mat.shape[0] 191 | width = w_mat.shape[1] 192 | 193 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 194 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 195 | u.data = l2normalize(u.data) 196 | v.data = l2normalize(v.data) 197 | w_bar = nn.Parameter(w.data) 198 | 199 | del self.module._parameters[self.name] 200 | 201 | self.module.register_parameter(self.name + "_u", u) 202 | self.module.register_parameter(self.name + "_v", v) 203 | self.module.register_parameter(self.name + "_bar", w_bar) 204 | 205 | def forward(self, *args): 206 | self.compute_snreg() 207 | return self.module.forward(*args) 208 | 209 | 210 | # ### Hessian regularization: 211 | 212 | # In[ ]: 213 | 214 | 215 | def get_Hessian_penalty( 216 | G, 217 | z, 218 | mode, 219 | k=2, 220 | epsilon=0.1, 221 | reduction=torch.max, 222 | return_separately=False, 223 | G_z=None, 224 | is_nondimensionalize=False, 225 | **G_kwargs 226 | ): 227 | """ 228 | Adapted from https://github.com/wpeebles/hessian_penalty/ (Peebles et al. 2020). 229 | Note: If you want to regularize multiple network activations simultaneously, you need to 230 | make sure the function G you pass to hessian_penalty returns a list of those activations when it's called with 231 | G(z, **G_kwargs). Otherwise, if G returns a tensor the Hessian Penalty will only be computed for the final 232 | output of G. 233 | 234 | Args: 235 | G: Function that maps input z to either a tensor or a list of tensors (activations) 236 | z: Input to G that the Hessian Penalty will be computed with respect to 237 | mode: choose from "Hdiag", "Hoff" or "Hall", specifying the scope of Hessian values to perform sum square on. 238 | "Hall" will be the sum of "Hdiag" (for diagonal elements) and "Hoff" (for off-diagonal elements). 239 | k: Number of Hessian directions to sample (must be >= 2) 240 | epsilon: Amount to blur G before estimating Hessian (must be > 0) 241 | reduction: Many-to-one function to reduce each pixel/neuron's individual hessian penalty into a final loss 242 | return_separately: If False, hessian penalties for each activation output by G are automatically summed into 243 | a final loss. 
If True, the hessian penalties for each layer will be returned in a list 244 | instead. If G outputs a single tensor, setting this to True will produce a length-1 245 | list. 246 | G_z: [Optional small speed-up] If you have already computed G(z, **G_kwargs) for the current training 247 | iteration, then you can provide it here to reduce the number of forward passes of this method by 1 248 | G_kwargs: Additional inputs to G besides the z vector. For example, in BigGAN you 249 | would pass the class label into this function via the y= keyword argument 250 | Returns: A differentiable scalar (the hessian penalty), or a list of hessian penalties if return_separately is True 251 | """ 252 | if G_z is None: 253 | G_z = G(z, **G_kwargs) 254 | rademacher_size = torch.Size((k, *z.size())) # (k, *z.size()) 255 | if mode == "Hall": 256 | loss_diag = get_Hessian_penalty(G=G, z=z, mode="Hdiag", k=k, epsilon=epsilon, reduction=reduction, return_separately=return_separately, G_z=G_z, **G_kwargs) 257 | loss_offdiag = get_Hessian_penalty(G=G, z=z, mode="Hoff", k=k, epsilon=epsilon, reduction=reduction, return_separately=return_separately, G_z=G_z, **G_kwargs) 258 | if return_separately: 259 | loss = [] 260 | for loss_i_diag, loss_i_offdiag in zip(loss_diag, loss_offdiag): 261 | loss.append(loss_i_diag + loss_i_offdiag) 262 | else: 263 | loss = loss_diag + loss_offdiag 264 | return loss 265 | elif mode == "Hdiag": 266 | xs = epsilon * complex_rademacher(rademacher_size, device=z.device) 267 | elif mode == "Hoff": 268 | xs = epsilon * rademacher(rademacher_size, device=z.device) 269 | else: 270 | raise ValueError(f"mode must be one of 'Hdiag', 'Hoff', 'Hall'; got '{mode}'.") 271 | second_orders = [] 272 | 273 | if mode == "Hdiag" and isinstance(G, nn.Module): 274 | # Use the complex64 dtype: 275 | dtype_ori = next(iter(G.parameters())).dtype 276 | G.type(torch.complex64) 277 | if isinstance(G, nn.Module): 278 | G_wrapper = get_listified_fun(G) 279 | G_z = listity_tensor(G_z) 280 | else: 281 | G_wrapper = G 282 | 283 | for x in xs: # Iterate over each of the k direction tensors (each of shape z.size()) 284 | central_second_order = multi_layer_second_directional_derivative(G_wrapper, z, x, G_z, epsilon, **G_kwargs) 285 | second_orders.append(central_second_order) # Appends a tensor with shape equal to G(z).size() 286 | loss = multi_stack_metric_and_reduce(second_orders, mode, reduction, return_separately) # (k, G(z).size()) --> scalar 287 | 288 | if mode == "Hdiag" and isinstance(G, nn.Module): 289 | # Revert back to original dtype: 290 | G.type(dtype_ori) 291 | 292 | if is_nondimensionalize: 293 | # Multiply a factor ||z||_2^2 so that the result is dimensionless: 294 | factor = z.square().mean() 295 | if return_separately: 296 | loss = [ele * factor for ele in loss] 297 | else: 298 | loss = loss * factor 299 | return loss 300 | 301 | 302 | def listity_tensor(tensor): 303 | """Turn the output features (except for the first batch dimension) of a function into a list 304 | 305 | Args: 306 | tensor: has shape [B, d1, d2, ...] 
307 | """ 308 | batch_size = tensor.shape[0] 309 | shape = tensor.shape[1:] 310 | tensor_reshape = tensor.reshape(batch_size, -1) 311 | tensor_listify = [tensor_reshape[:, i] for i in range(tensor_reshape.shape[1])] 312 | return tensor_listify 313 | 314 | 315 | def get_listified_fun(G): 316 | def fun(z, **Gkwargs): 317 | G_out = G(z, **Gkwargs) 318 | return listity_tensor(G_out) 319 | return fun 320 | 321 | 322 | def rademacher(shape, device='cpu'): 323 | """Creates a random tensor of size [shape] under the Rademacher distribution (P(x=1) == P(x=-1) == 0.5)""" 324 | return torch.randint(2, size=shape, device=device).float() * 2 - 1 325 | 326 | 327 | def complex_rademacher(shape, device='cpu'): 328 | """Creates a random tensor of size [shape] with (P(x=1) == P(x=-1) == P(x=1j) == P(x=-1j) == 0.25)""" 329 | collection = torch.from_numpy(np.array([1., -1, 1j, -1j])).type(torch.complex64).to(device) 330 | x = x.randint(4, size=shape, device=device) # Creates random tensor of 0, 1, 2, 3 331 | return collection[x] # Map tensor of 0, 1, 2, 3 to 1, -1, 1j, -1j 332 | 333 | 334 | def multi_layer_second_directional_derivative(G, z, x, G_z, epsilon, **G_kwargs): 335 | """Estimates the second directional derivative of G w.r.t. its input at z in the direction x""" 336 | G_to_x = G(z + x, **G_kwargs) 337 | G_from_x = G(z - x, **G_kwargs) 338 | 339 | G_to_x = listify(G_to_x) 340 | G_from_x = listify(G_from_x) 341 | G_z = listify(G_z) 342 | 343 | eps_sqr = epsilon ** 2 344 | sdd = [(G2x - 2 * G_z_base + Gfx) / eps_sqr for G2x, G_z_base, Gfx in zip(G_to_x, G_z, G_from_x)] 345 | return sdd 346 | 347 | 348 | def stack_metric_and_reduce(list_of_activations, mode, reduction=torch.max): 349 | """Equation (5) from the paper.""" 350 | second_orders = torch.stack(list_of_activations) # (k, N, C, H, W) 351 | if mode == "Hoff": 352 | tensor = torch.var(second_orders, dim=0, unbiased=True) / 2 # (N, C, H, W) 353 | elif mode == "Hdiag": 354 | tensor = torch.mean((second_orders ** 2).real, dim=0) 355 | else: 356 | raise 357 | penalty = reduction(tensor) # (1,) (scalar) 358 | return penalty 359 | 360 | 361 | def multi_stack_metric_and_reduce(sdds, mode, reduction=torch.max, return_separately=False): 362 | """Iterate over all activations to be regularized, then apply Equation (5) to each.""" 363 | sum_of_penalties = 0 if not return_separately else [] 364 | for activ_n in zip(*sdds): 365 | penalty = stack_metric_and_reduce(activ_n, mode, reduction) 366 | sum_of_penalties += penalty if not return_separately else [penalty] 367 | return sum_of_penalties 368 | 369 | 370 | def listify(x): 371 | """If x is already a list, do nothing. Otherwise, wrap x in a list.""" 372 | if isinstance(x, list): 373 | return x 374 | else: 375 | return [x] 376 | 377 | 378 | def _test_hessian_penalty(mode, k=100): 379 | """ 380 | A simple multi-layer test to verify the implementation. 381 | Function: G(z) = [z_0 * z_1 + z0 ** 2, z_0**2 * z_1 + 2 * z1 ** 2] 382 | The Hessian for the first value is [ 383 | [2, 1], 384 | [1, 0], 385 | ] 386 | so the offdiagonal sum square is 2 387 | the diagonal sum square is 4. 388 | 389 | The Hessian for the second function value is: [ 390 | [2 * z_1, 2 * z_0], 391 | [2 * z_0, 4], 392 | ] 393 | so the offdiagonal sum square is 8 * z_0**2 394 | the diagonal sum square is 16 + 4 * z_1**2. 
395 | Ground Truth Hessian Penalty: [2, 8 * z_0**2] 396 | """ 397 | batch_size = 10 398 | nz = 2 399 | z = torch.randn(batch_size, nz) 400 | def reduction(x): return x.abs().mean() 401 | def G(z): return [z[:, 0] * z[:, 1] + z[:, 0] ** 2, (z[:, 0] ** 2) * z[:, 1] + 2 * z[:, 1] ** 2] 402 | if mode == "Hdiag": 403 | ground_truth = [4, 16 + 4 * (z[:, 1] ** 2).mean().item()] 404 | elif mode == "Hoff": 405 | ground_truth = [2, reduction(8 * z[:, 0] ** 2).item()] 406 | elif mode == "Hall": 407 | ground_truth = [4+2, 16 + 4 * (z[:, 1] ** 2).mean().item() + reduction(8 * z[:, 0] ** 2).item()] 408 | else: 409 | raise ValueError(f"mode must be one of 'Hdiag', 'Hoff', 'Hall'; got '{mode}'.") 410 | # In this simple example, we use k=100 to reduce variance, but when applied to neural networks 411 | # you will probably want to use a small k (e.g., k=2) due to memory considerations. 412 | predicted = get_Hessian_penalty(G, z, mode=mode, G_z=None, k=k, reduction=reduction, return_separately=True) 413 | predicted = [p.item() for p in predicted] 414 | print('Ground Truth: %s' % ground_truth) 415 | print('Approximation: %s' % predicted) # This should be close to ground_truth, but not exactly correct 416 | print('Difference: %s' % [str(100 * abs(p - gt) / gt) + '%' for p, gt in zip(predicted, ground_truth)]) 417 | 418 | 419 | ## >>> functions for the MeshGraphNets: 420 | 421 | def init_weights_requ(m): 422 | if type(m) == BatchLinear or type(m) == nn.Linear: 423 | if hasattr(m, 'weight'): 424 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in') 425 | 426 | 427 | def init_weights_normal(m): 428 | if type(m) == BatchLinear or type(m) == nn.Linear: 429 | if hasattr(m, 'weight'): 430 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_out') 431 | 432 | 433 | def init_weights_selu(m): 434 | if type(m) == BatchLinear or type(m) == nn.Linear: 435 | if hasattr(m, 'weight'): 436 | num_input = m.weight.size(-1) 437 | nn.init.normal_(m.weight, std=1/math.sqrt(num_input)) 438 | 439 | 440 | def init_weights_elu(m): 441 | if type(m) == BatchLinear or type(m) == nn.Linear: 442 | if hasattr(m, 'weight'): 443 | num_input = m.weight.size(-1) 444 | nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277)/math.sqrt(num_input)) 445 | 446 | 447 | def init_weights_xavier(m): 448 | if type(m) == BatchLinear or type(m) == nn.Linear: 449 | if hasattr(m, 'weight'): 450 | nn.init.xavier_normal_(m.weight) 451 | 452 | 453 | def init_weights_uniform(m): 454 | if type(m) == BatchLinear or type(m) == nn.Linear: 455 | if hasattr(m, 'weight'): 456 | torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu') 457 | 458 | 459 | def sine_init(m, w0=60): 460 | with torch.no_grad(): 461 | if hasattr(m, 'weight'): 462 | num_input = m.weight.size(-1) 463 | m.weight.uniform_(-np.sqrt(6/num_input)/w0, np.sqrt(6/num_input)/w0) 464 | 465 | 466 | def first_layer_sine_init(m): 467 | with torch.no_grad(): 468 | if hasattr(m, 'weight'): 469 | num_input = m.weight.size(-1) 470 | m.weight.uniform_(-1/num_input, 1/num_input) 471 | 472 | 473 | class MetaModule(nn.Module): 474 | """ 475 | Base class for PyTorch meta-learning modules. These modules accept an 476 | additional argument `params` in their `forward` method. 477 | Notes 478 | ----- 479 | Objects inherited from `MetaModule` are fully compatible with PyTorch 480 | modules from `torch.nn.Module`. The argument `params` is a dictionary of 481 | tensors, with full support of the computation graph (for differentiation). 
482 | """ 483 | def __init__(self): 484 | super(MetaModule, self).__init__() 485 | self._children_modules_parameters_cache = dict() 486 | 487 | def meta_named_parameters(self, prefix='', recurse=True): 488 | gen = self._named_members( 489 | lambda module: module._parameters.items() 490 | if isinstance(module, MetaModule) else [], 491 | prefix=prefix, recurse=recurse) 492 | for elem in gen: 493 | yield elem 494 | 495 | def meta_parameters(self, recurse=True): 496 | for name, param in self.meta_named_parameters(recurse=recurse): 497 | yield param 498 | 499 | def get_subdict(self, params, key=None): 500 | if params is None: 501 | return None 502 | 503 | all_names = tuple(params.keys()) 504 | if (key, all_names) not in self._children_modules_parameters_cache: 505 | if key is None: 506 | self._children_modules_parameters_cache[(key, all_names)] = all_names 507 | 508 | else: 509 | key_escape = re.escape(key) 510 | key_re = re.compile(r'^{0}\.(.+)'.format(key_escape)) 511 | 512 | self._children_modules_parameters_cache[(key, all_names)] = [ 513 | key_re.sub(r'\1', k) for k in all_names if key_re.match(k) is not None] 514 | 515 | names = self._children_modules_parameters_cache[(key, all_names)] 516 | if not names: 517 | warnings.warn('Module `{0}` has no parameter corresponding to the ' 518 | 'submodule named `{1}` in the dictionary `params` ' 519 | 'provided as an argument to `forward()`. Using the ' 520 | 'default parameters for this submodule. The list of ' 521 | 'the parameters in `params`: [{2}].'.format( 522 | self.__class__.__name__, key, ', '.join(all_names)), 523 | stacklevel=2) 524 | return None 525 | 526 | return OrderedDict([(name, params[f'{key}.{name}']) for name in names]) 527 | 528 | 529 | class BatchLinear(nn.Linear, MetaModule): 530 | '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a 531 | hypernetwork.''' 532 | __doc__ = nn.Linear.__doc__ 533 | 534 | def forward(self, input, params=None): 535 | if params is None: 536 | params = OrderedDict(self.named_parameters()) 537 | 538 | bias = params.get('bias', None) 539 | weight = params['weight'] 540 | 541 | output = input.matmul(weight.permute(*[i for i in range(len(weight.shape)-2)], -1, -2)) 542 | output += bias.unsqueeze(-2) 543 | return output 544 | 545 | 546 | class FirstSine(nn.Module): 547 | def __init__(self, w0=60): 548 | super().__init__() 549 | self.w0 = torch.tensor(w0) 550 | 551 | def forward(self, input): 552 | return torch.sin(self.w0*input) 553 | 554 | 555 | class Sine(nn.Module): 556 | def __init__(self, w0=60): 557 | super().__init__() 558 | self.w0 = torch.tensor(w0) 559 | 560 | def forward(self, input): 561 | return torch.sin(self.w0*input) 562 | 563 | 564 | class ReQU(nn.Module): 565 | def __init__(self, inplace=True): 566 | super().__init__() 567 | self.relu = nn.ReLU(inplace) 568 | 569 | def forward(self, input): 570 | # return torch.sin(np.sqrt(256)*input) 571 | return .5*self.relu(input)**2 572 | 573 | 574 | class MSoftplus(nn.Module): 575 | def __init__(self): 576 | super().__init__() 577 | self.softplus = nn.Softplus() 578 | self.cst = torch.log(torch.tensor(2.)) 579 | 580 | def forward(self, input): 581 | return self.softplus(input)-self.cst 582 | 583 | 584 | class Swish(nn.Module): 585 | def __init__(self): 586 | super().__init__() 587 | 588 | def forward(self, input): 589 | return input*torch.sigmoid(input) 590 | 591 | 592 | def layer_factory(layer_type): 593 | layer_dict = { 594 | 'relu': (nn.ReLU(inplace=True), init_weights_normal), 595 | 
'leakyrelu': (nn.LeakyReLU(inplace=True), init_weights_normal), 596 | 'requ': (ReQU(inplace=False), init_weights_requ), 597 | 'sigmoid': (nn.Sigmoid(), None), 598 | 'fsine': (Sine(), first_layer_sine_init), 599 | 'sine': (Sine(), sine_init), 600 | 'tanh': (nn.Tanh(), init_weights_xavier), 601 | 'selu': (nn.SELU(inplace=True), init_weights_selu), 602 | 'gelu': (nn.GELU(), init_weights_selu), 603 | 'swish': (Swish(), init_weights_selu), 604 | 'softplus': (nn.Softplus(), init_weights_normal), 605 | 'msoftplus': (MSoftplus(), init_weights_normal), 606 | 'elu': (nn.ELU(), init_weights_elu), 607 | 'silu': (nn.SiLU(), init_weights_selu), 608 | } 609 | return layer_dict[layer_type] 610 | 611 | 612 | class PositionalEncoding(nn.Module): 613 | def __init__(self, num_encoding_functions=6, include_input=True, log_sampling=True, normalize=False, 614 | input_dim=2, gaussian_pe=False, gaussian_variance=0.1): 615 | super().__init__() 616 | self.num_encoding_functions = num_encoding_functions 617 | self.include_input = include_input 618 | self.log_sampling = log_sampling 619 | self.normalize = normalize 620 | self.gaussian_pe = gaussian_pe 621 | self.normalization = None 622 | 623 | if self.gaussian_pe: 624 | # this needs to be registered as a parameter so that it is saved in the model state dict 625 | # and so that it is converted using .cuda(). Doesn't need to be trained though 626 | self.gaussian_weights = nn.Parameter(2*np.pi*gaussian_variance * torch.randn((num_encoding_functions*2), input_dim), 627 | requires_grad=False) 628 | 629 | else: 630 | self.frequency_bands = None 631 | if self.log_sampling: 632 | self.frequency_bands = 2.0 ** torch.linspace( 633 | 0.0, 634 | self.num_encoding_functions - 1, 635 | self.num_encoding_functions) 636 | else: 637 | self.frequency_bands = torch.linspace( 638 | 2.0 ** 0.0, 639 | 2.0 ** (self.num_encoding_functions - 1), 640 | self.num_encoding_functions) 641 | 642 | if normalize: 643 | self.normalization = torch.tensor(1/self.frequency_bands) 644 | 645 | def forward(self, tensor) -> torch.Tensor: 646 | r"""Apply positional encoding to the input. 647 | Args: 648 | tensor (torch.Tensor): Input tensor to be positionally encoded. 649 | encoding_size (optional, int): Number of encoding functions used to compute 650 | a positional encoding (default: 6). 651 | include_input (optional, bool): Whether or not to include the input in the 652 | positional encoding (default: True). 653 | Returns: 654 | (torch.Tensor): Positional encoding of the input tensor. 655 | """ 656 | 657 | encoding = [tensor] if self.include_input else [] 658 | if self.gaussian_pe: 659 | for func in [torch.sin, torch.cos]: 660 | encoding.append(func(torch.matmul(tensor, self.gaussian_weights.T))) 661 | else: 662 | for idx, freq in enumerate(self.frequency_bands): 663 | for func in [torch.sin, torch.cos]: 664 | if self.normalization is not None: 665 | encoding.append(self.normalization[idx]*func(tensor * freq)) 666 | else: 667 | encoding.append(func(tensor * freq)) 668 | 669 | # Special case, for no positional encoding 670 | if len(encoding) == 1: 671 | return encoding[0] 672 | else: 673 | return torch.cat(encoding, dim=-1) 674 | 675 | 676 | 677 | class FCBlock(nn.Module): 678 | '''A fully connected neural network that also allows swapping out the weights when used with a hypernetwork. 679 | Can be used just as a normal neural network though, as well. 
680 | ''' 681 | def __init__(self, in_features, out_features, 682 | num_hidden_layers, hidden_features, 683 | outermost_linear=False, outmost_nonlinearity=None, nonlinearity='relu', 684 | weight_init=None, w0=60, set_bias=None, 685 | dropout=0.0, layer_norm=False,latent_dim=64,skip_connect=None): 686 | super().__init__() 687 | 688 | self.skip_connect = skip_connect 689 | self.latent_dim = latent_dim 690 | self.first_layer_init = None 691 | self.dropout = dropout 692 | 693 | if outmost_nonlinearity==None: 694 | outmost_nonlinearity = nonlinearity 695 | 696 | # Create hidden features list 697 | if not isinstance(hidden_features, list): 698 | num_hidden_features = hidden_features 699 | hidden_features = [] 700 | for i in range(num_hidden_layers+1): 701 | hidden_features.append(num_hidden_features) 702 | else: 703 | num_hidden_layers = len(hidden_features)-1 704 | #print(f"net_size={hidden_features}") 705 | 706 | # Create the net 707 | #print(f"num_layers={len(hidden_features)}") 708 | if isinstance(nonlinearity, list): 709 | print(f"num_non_lin={len(nonlinearity)}") 710 | assert len(hidden_features) == len(nonlinearity), "Num hidden layers needs to " "match the length of the list of non-linearities" 711 | 712 | self.net = [] 713 | self.net.append(nn.Sequential( 714 | nn.Linear(in_features, hidden_features[0]), 715 | layer_factory(nonlinearity[0])[0] 716 | )) 717 | for i in range(num_hidden_layers): 718 | if self.skip_connect==None: 719 | self.net.append(nn.Sequential( 720 | nn.Linear(hidden_features[i], hidden_features[i+1]), 721 | layer_factory(nonlinearity[i+1])[0] 722 | )) 723 | else: 724 | if i+1 in self.skip_connect: 725 | self.net.append(nn.Sequential( 726 | nn.Linear(hidden_features[i]+self.latent_dim, hidden_features[i+1]), 727 | layer_factory(nonlinearity[i+1])[0] 728 | )) 729 | else: 730 | self.net.append(nn.Sequential( 731 | nn.Linear(hidden_features[i], hidden_features[i+1]), 732 | layer_factory(nonlinearity[i+1])[0] 733 | )) 734 | 735 | if outermost_linear: 736 | self.net.append(nn.Sequential( 737 | nn.Linear(hidden_features[-1], out_features), 738 | )) 739 | else: 740 | self.net.append(nn.Sequential( 741 | nn.Linear(hidden_features[-1], out_features), 742 | layer_factory(nonlinearity[-1])[0] 743 | )) 744 | elif isinstance(nonlinearity, str): 745 | nl, weight_init = layer_factory(nonlinearity) 746 | outmost_nl, _ = layer_factory(outmost_nonlinearity) 747 | if(nonlinearity == 'sine'): 748 | first_nl = FirstSine() 749 | self.first_layer_init = first_layer_sine_init 750 | else: 751 | first_nl = nl 752 | 753 | if weight_init is not None: 754 | self.weight_init = weight_init 755 | 756 | self.net = [] 757 | self.net.append(nn.Sequential( 758 | nn.Linear(in_features, hidden_features[0]), 759 | first_nl 760 | )) 761 | 762 | for i in range(num_hidden_layers): 763 | if(self.dropout > 0): 764 | self.net.append(nn.Dropout(self.dropout)) 765 | if self.skip_connect == None: 766 | self.net.append(nn.Sequential( 767 | nn.Linear(hidden_features[i], hidden_features[i+1]), 768 | copy.deepcopy(nl) 769 | )) 770 | else: 771 | if i+1 in self.skip_connect: 772 | self.net.append(nn.Sequential( 773 | nn.Linear(hidden_features[i]+self.latent_dim, hidden_features[i+1]), 774 | copy.deepcopy(nl) 775 | )) 776 | else: 777 | self.net.append(nn.Sequential( 778 | nn.Linear(hidden_features[i], hidden_features[i+1]), 779 | copy.deepcopy(nl) 780 | )) 781 | 782 | if (self.dropout > 0): 783 | self.net.append(nn.Dropout(self.dropout)) 784 | if outermost_linear: 785 | self.net.append(nn.Sequential( 786 | 
nn.Linear(hidden_features[-1], out_features), 787 | )) 788 | else: 789 | self.net.append(nn.Sequential( 790 | nn.Linear(hidden_features[-1], out_features), 791 | copy.deepcopy(outmost_nl) 792 | )) 793 | if layer_norm: 794 | self.net.append(nn.LayerNorm([out_features])) 795 | 796 | self.net = nn.Sequential(*self.net) 797 | 798 | if isinstance(nonlinearity, list): 799 | for layer_num, layer_name in enumerate(nonlinearity): 800 | self.net[layer_num].apply(layer_factory(layer_name)[1]) 801 | elif isinstance(nonlinearity, str): 802 | if getattr(self, "weight_init", None) is not None:  # e.g. 'sigmoid' has no initializer, so self.weight_init may be unset 803 | self.net.apply(self.weight_init) 804 | 805 | if self.first_layer_init is not None: 806 | self.net[0].apply(self.first_layer_init) 807 | 808 | if set_bias is not None: 809 | self.net[-1][0].bias.data = set_bias * torch.ones_like(self.net[-1][0].bias.data) 810 | 811 | def forward(self, coords, batch_vecs=None): 812 | if self.skip_connect is None: 813 | output = self.net(coords) 814 | else: 815 | input = coords 816 | for i in range(len(self.net)): 817 | output = self.net[i](input) 818 | if i+1 in self.skip_connect: 819 | input = torch.cat([batch_vecs, output], dim=-1) 820 | else: 821 | input = output 822 | return output 823 | 824 | 825 | class CoordinateNet_autodecoder(nn.Module): 826 | '''An autodecoder network''' 827 | def __init__(self, latent_size=64, out_features=1, nl='sine', in_features=64+2, 828 | hidden_features=256, num_hidden_layers=3, num_pe_fns=6, 829 | w0=60,use_pe=False,skip_connect=None,dataset_size=100, 830 | outmost_nonlinearity=None,outermost_linear=True): 831 | super().__init__() 832 | 833 | self.nl = nl 834 | self.use_pe = use_pe 835 | self.latent_size = latent_size 836 | self.lat_vecs = torch.nn.Embedding(dataset_size, self.latent_size) 837 | torch.nn.init.normal_(self.lat_vecs.weight.data, 0.0, 1 / math.sqrt(self.latent_size)) 838 | 839 | if self.nl != 'sine' and use_pe: 840 | in_features = 2 * (2*num_pe_fns + 1)+latent_size 841 | 842 | if self.use_pe: 843 | self.pe = PositionalEncoding(num_encoding_functions=num_pe_fns) 844 | self.decoder = FCBlock(in_features=in_features, 845 | out_features=out_features, 846 | num_hidden_layers=num_hidden_layers, 847 | hidden_features=hidden_features, 848 | outermost_linear=outermost_linear, 849 | nonlinearity=nl, 850 | w0=w0,skip_connect=skip_connect,latent_dim=latent_size,outmost_nonlinearity=outmost_nonlinearity) 851 | self.mean = torch.mean(torch.mean(self.lat_vecs.weight.data.detach(), dim=1)).cuda()  # note: .cuda() assumes a CUDA device 852 | self.varience = torch.mean(torch.var(self.lat_vecs.weight.data.detach(), dim=1)).cuda() 853 | 854 | 855 | def forward(self, model_input,latent=None): 856 | coords = model_input['coords'].clone().detach().requires_grad_(True) 857 | if latent is None: 858 | batch_vecs = self.lat_vecs(model_input['idx']).unsqueeze(1).repeat(1,coords.shape[1],1) 859 | else: 860 | batch_vecs = latent.unsqueeze(1).repeat(1,coords.shape[1],1) 861 | 862 | if self.nl != 'sine' and self.use_pe: 863 | coords_pe = self.pe(coords) 864 | input = torch.cat([batch_vecs, coords_pe], dim=-1) 865 | output = self.decoder(input,batch_vecs) 866 | else: 867 | input = torch.cat([batch_vecs, coords], dim=-1) 868 | output = self.decoder(input,batch_vecs) 869 | 870 | return {'model_in': coords, 'model_out': output,'batch_vecs': batch_vecs, 'meta': model_input} --------------------------------------------------------------------------------
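The Hessian-penalty helpers in utils_model.py are self-contained, so they can be exercised directly. Below is a minimal, hypothetical usage sketch mirroring the built-in _test_hessian_penalty; the toy G, and the choices of k and reduction, are illustrative and not values fixed by the repo:

import torch
from utils_model import get_Hessian_penalty

# Toy generator: maps a batch of 2-D latents to two activations per sample.
def G(z):
    return [z[:, 0] * z[:, 1] + z[:, 0] ** 2,
            (z[:, 0] ** 2) * z[:, 1] + 2 * z[:, 1] ** 2]

z = torch.randn(10, 2, requires_grad=True)  # (batch, nz)
# "Hoff" penalizes off-diagonal Hessian entries of G w.r.t. z; k is the number
# of sampled Rademacher directions (the estimate's variance shrinks as k grows).
penalty = get_Hessian_penalty(G, z, mode="Hoff", k=100,
                              reduction=lambda x: x.abs().mean())
penalty.backward()  # differentiable, so it can be added to a training loss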
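Similarly, MetaModule.get_subdict and BatchLinear together implement the hypernetwork path: forward can run with externally supplied parameters instead of the layer's own. A small sketch under assumed (hypothetical) shapes:

import torch
from collections import OrderedDict
from utils_model import BatchLinear

layer = BatchLinear(3, 5)
x = torch.randn(7, 3)
y_default = layer(x)  # params=None: falls back to the layer's own weight/bias

# Swap in a batch of externally generated weights, one matrix per sample,
# as a hypernetwork would output them:
params = OrderedDict([
    ("weight", torch.randn(7, 5, 3)),  # (batch, out_features, in_features)
    ("bias", torch.randn(7, 5)),
])
y_hyper = layer(x.unsqueeze(1), params=params).squeeze(1)  # (7, 5)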
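Finally, a sketch of driving CoordinateNet_autodecoder end to end. The batch and point counts here are hypothetical, and note that the constructor calls .cuda() on its latent statistics, so as written it assumes a CUDA device:

import torch
from utils_model import CoordinateNet_autodecoder

# 100 training samples, each with its own 64-D latent; 2-D query coordinates.
model = CoordinateNet_autodecoder(latent_size=64, in_features=64 + 2,
                                  out_features=1, nl='sine', dataset_size=100)
model_input = {
    "coords": torch.rand(4, 1024, 2),      # (batch, num_points, coord_dim)
    "idx": torch.randint(100, size=(4,)),  # which per-sample latent to look up
}
out = model(model_input)
print(out["model_out"].shape)  # (4, 1024, 1)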