├── .gitignore
├── .gitmodules
├── README.md
├── analysis_1d_evo.ipynb
├── analysis_2d_evo.ipynb
├── analysis_2d_full.py
├── analysis_2d_rl.ipynb
├── argparser.py
├── assets
│   ├── gif-MeshGraphNets+gt remeshing.gif
│   ├── gif-MeshGraphNets+heuristic remeshing.gif
│   ├── gif-ground_truth.gif
│   ├── gif-lamp.gif
│   ├── gif-lamp_no_remeshing.gif
│   ├── lamp_architecture.png
│   └── lamp_poster.pdf
├── data
│   ├── arcsimmesh_data
│   │   ├── README.md
│   │   └── __init__.py
│   └── mppde1d_data
│       └── __init__.py
├── datasets
│   ├── README.md
│   ├── arcsimmesh_dataset.py
│   ├── datagen_square.ipynb
│   ├── load_dataset.py
│   └── mppde1d_dataset.py
├── gnns.py
├── license
├── models.py
├── requirements.txt
├── results
│   └── README.md
├── train.py
├── utils.py
└── utils_model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Added by author
7 | data/*
8 | results/*
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "pytorch_net"]
2 | path = pytorch_net
3 | url = git@github.com:tailintalent/pytorch_net.git
4 | [submodule "MP_Neural_PDE_Solvers"]
5 | path = MP_Neural_PDE_Solvers
6 | url = git@github.com:tailintalent/MP_Neural_PDE_Solvers.git
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LAMP: Learning Controllable Adaptive Simulation for Multi-resolution Physics (ICLR 2023 Notable-Top-25%)
2 |
3 | [Paper](https://openreview.net/forum?id=PbfgkZ2HdbE) | [arXiv](https://arxiv.org/abs/2305.01122) | [Poster](https://github.com/snap-stanford/lamp/blob/master/assets/lamp_poster.pdf) | [Slides](https://docs.google.com/presentation/d/1cMRGe2qNIrzSNRTUtbsVUod_PvyhDHcHzEa8wfxiQsw/edit?usp=sharing) | [Tweet](https://twitter.com/tailin_wu/status/1653253117671272448) | [Project Page](https://snap.stanford.edu/lamp/)
4 |
5 | Official repo for the paper [Learning Controllable Adaptive Simulation for Multi-resolution Physics](https://openreview.net/forum?id=PbfgkZ2HdbE)
6 | [Tailin Wu*](https://tailin.org/), [Takashi Maruyama*](https://sites.google.com/view/tmaruyama/home), [Qingqing Zhao*](https://cyanzhao42.github.io/), [Gordon Wetzstein](https://stanford.edu/~gordonwz/), [Jure Leskovec](https://cs.stanford.edu/people/jure/)
7 | ICLR 2023 **Notable-Top-25%**.
8 |
9 | LAMP is the first fully deep-learning-based surrogate model that jointly learns the evolution model and optimizes spatial resolutions to reduce computational cost, with the remeshing policy learned via reinforcement learning.
10 |
11 | We demonstrate that LAMP is able to adaptively trade off computation to reduce long-term prediction error, by performing spatial refinement and coarsening of the mesh. LAMP outperforms state-of-the-art (SOTA) deep learning surrogate models, with an average error reduction of 33.7% for 1D nonlinear PDEs, and outperforms SOTA MeshGraphNets + Adaptive Mesh Refinement in 2D mesh-based simulations.
12 |
13 |
14 |
15 | Learned remeshing & evolution by LAMP:
16 |
17 |
18 | ## Installation
19 |
20 | 1. First clone the repository. Then run the following command to initialize the submodules:
21 |
22 | ```code
23 | git submodule init; git submodule update
24 | ```
25 |
26 | (If you get a permission error, you need to first [add a new SSH key to your GitHub account](https://docs.github.com/en/authentication/connecting-to-github-with-ssh/adding-a-new-ssh-key-to-your-github-account).)
27 |
28 | 2. Install dependencies.
29 |
30 | First, create a new environment using [conda](https://docs.conda.io/en/latest/miniconda.html) (with Python >= 3.7). Then install PyTorch, torch-geometric, and the other dependencies as follows. (The repository was run with the dependency versions below; other versions of torch-geometric or deepsnap may work, but there is no guarantee.)
31 |
32 | Install PyTorch (replace "cu113" with the appropriate CUDA version; for example, CUDA 11.1 uses "cu111"):
33 | ```code
34 | pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/torch_stable.html
35 | ```
36 |
37 | Install torch-geometric by running the following commands:
38 | ```code
39 | pip install torch-scatter==2.0.9 -f https://data.pyg.org/whl/torch-1.10.2+cu113.html
40 | pip install torch-sparse==0.6.12 -f https://data.pyg.org/whl/torch-1.10.2+cu113.html
41 | pip install torch-geometric==1.7.2
42 | pip install torch-cluster==1.5.9 -f https://data.pyg.org/whl/torch-1.10.2+cu113.html
43 | ```
44 |
45 | Install other dependencies:
46 | ```code
47 | pip install -r requirements.txt
48 | ```
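
After installing, a quick sanity check such as the following can confirm the environment (a minimal sketch; the expected version strings assume the cu113 wheels above and should be adjusted for other CUDA builds):

```python
# Minimal environment check (assumes the cu113 installs above).
import torch
import torch_geometric

print(torch.__version__)            # expected: 1.10.2+cu113
print(torch_geometric.__version__)  # expected: 1.7.2
print(torch.cuda.is_available())    # should print True on a GPU machine
```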
49 |
50 | To use wandb (--wandb=True), you need to set up wandb first, following [this link](https://docs.wandb.ai/quickstart).
51 |
52 | To run the 2D mesh-based simulation, FEniCS needs to be installed:
53 |
54 | ```code
55 | conda install -c conda-forge fenics
56 | ```
57 |
58 |
59 | ## Dataset
60 |
61 | The dataset files can be downloaded via [this link](https://drive.google.com/drive/folders/1ld5I86mPC7wWTxPhbCtG2AcH0vLW3o25?usp=share_link).
62 | * To run the 1D experiment, download the files under "mppde1d_data/" in the link into the "data/mppde1d_data/" folder of the local repo.
63 | * To run the 2D mesh-based experiment, download the files under "arcsimmesh_data/" in the link into the "data/arcsimmesh_data/" folder of the local repo. A script for data generation is also provided ("datasets/datagen_square.ipynb"). To run the script, compile [ARCSim v0.2.1](http://graphics.berkeley.edu/resources/ARCSim/#:~:text=ArcSim%20is%20a%20simulation,detail%20of%20the%20simulated%20objects) and place the script in the ARCSim folder. A detailed explanation of the attributes of the mesh-based dataset is provided under [datasets/README.md](https://github.com/snap-stanford/lamp/tree/master/datasets). A quick layout check is sketched after this list.
64 |
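As referenced above, here is a quick layout check (a minimal sketch, assuming the repository root as the working directory; it only lists what is present in each dataset folder and does not validate individual file names):

```python
# Sketch: list what is currently inside the expected dataset folders.
from pathlib import Path

for folder in ["data/mppde1d_data", "data/arcsimmesh_data"]:
    path = Path(folder)
    entries = sorted(p.name for p in path.iterdir()) if path.is_dir() else []
    print(f"{folder}: {len(entries)} entries -> {entries[:5]}")
```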
65 |
66 | ## Training
67 |
68 | Below we provide example commands for training LAMP.
69 |
70 | ### 1D nonlinear PDE:
71 |
72 | First, **pre-train** the evolution model for 1D:
73 |
74 | ```code
75 | python train.py --exp_id=evo-1d --date_time=2023-01-01 --dataset=mppde1df-E2-100-nt-250-nx-200 --time_interval=1 --data_dropout=node:0-0.3:0.1 --latent_size=64 --n_train=-1 --save_interval=5 --test_interval=5 --algo=gnnremesher --rl_coefs=None --input_steps=1 --act_name=silu --multi_step=1^2:0.1^3:0.1^4:0.1 --temporal_bundle_steps=25 --use_grads=False --is_y_diff=False --loss_type=mse --batch_size=16 --val_batch_size=16 --epochs=50 --opt=adam --weight_decay=0 --seed=0 --id=0 --verbose=1 --n_workers=0 --gpuid=0
76 | ```
77 |
78 | The learned model will be saved under `./results/{--exp_id}_{--date_time}/`, where the `{--exp_id}` and `{--date_time}` are specified in the above command. The filename has the format `*{hash}_{machine_name}.p`; e.g., for "mppde1df-E2-100-nt-250-nx-200_train_-1_algo_gnnremesher_..._Hash_mhkVkAaz_ampere3.p", the `{hash}` is `mhkVkAaz` and the `{machine_name}` is `ampere3`. The `{hash}` is uniquely determined by **all** the argument settings in [argparser.py](https://github.com/snap-stanford/lamp/blob/master/argparser.py), so as long as any argument setting differs, the filenames will differ and runs will not overwrite each other.
79 |
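As a concrete example, a lookup like the following finds the checkpoint file for a given hash (a minimal sketch; the folder name and hash below are the placeholders from the paragraph above):

```python
# Sketch: locate the saved checkpoint whose filename contains a given hash.
from glob import glob

exp_folder = "results/evo-1d_2023-01-01"  # ./results/{--exp_id}_{--date_time}/
hash_str = "mhkVkAaz"                     # example hash from the paragraph above
matches = glob(f"{exp_folder}/*{hash_str}*.p")
assert len(matches) == 1, f"expected exactly one matching checkpoint, got {matches}"
print(matches[0])
```
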
80 | Then, **jointly train** the remeshing model via reinforcement learning (RL) and the evolution model. The `--load_dirname` below should use folder name `{exp_id}_{date_time}` where the evolution model is located (as specified above), and the `--load_filename` should use part of the filename that can uniquely identify this model file, and should include the `{hash}` of this model.
81 |
82 | ```code
83 | python train.py --load_dirname=evo-1d_2023-01-01 --load_filename=Q66bz42y --exp_id=rl-1d --date_time=2023-01-02 --wandb_project_name=rl-1d_2023-01-02 --wandb=True --dataset=mppde1df-E2-100-nt-250-nx-200 --time_interval=1 --data_dropout=None --latent_size=64 --n_train=-1 --input_steps=1 --act_name=elu --multi_step=1^2:0.1^3:0.1^4:0.1 --temporal_bundle_steps=25 --use_grads=False --is_y_diff=False --loss_type=mse --batch_size=128 --val_batch_size=128 --epochs=30 --opt=adam --weight_decay=0 --seed=0 --verbose=1 --n_workers=0 --gpuid=7 --algo=srlgnnremesher --reward_mode=lossdiff+statediff --reward_beta=0-0.5 --rl_data_dropout=uniform:2 --min_edge_size=0.0014 --rl_horizon=4 --reward_loss_coef=5 --rl_eta=1e-2 --actor_lr=5e-4 --value_lr=1e-4 --value_num_pool=1 --value_pooling_type=global_mean_pool --value_latent_size=32 --value_batch_norm=False --actor_batch_norm=True --rescale=10 --edge_attr=True --rl_gamma=0.9 --value_loss_coef=0.5 --max_grad_norm=2 --is_single_action=False --value_target_mode=vanilla --wandb_step_plot=100 --wandb_step=20 --save_iteration=1000 --save_interval=1 --test_interval=1 --gpuid=3 --lr=1e-4 --actor_critic_step=200 --evolution_steps=200 --reward_condition=True --max_action=20 --rl_is_finetune_evolution=True --rl_finetune_evalution_mode=policy:fine --id=0
84 | ```
85 |
86 | ### 2D mesh-based simulation:
87 |
88 | **Pre-train** the evolution model for 2D (requires FEniCS; see the "Installation" section):
89 |
90 | ```code
91 | export OMP_NUM_THREADS=6; python train.py --exp_id=evo-2d --date_time=2023-01-01 --dataset=arcsimmesh_square_annotated --time_interval=2 --data_dropout=None --n_train=-1 --save_interval=5 --algo=gnnremesher-evolution --rl_coefs=None --input_steps=2 --act_name=silu --multi_step=1 --temporal_bundle_steps=1 --edge_attr=True --use_grads=False --is_y_diff=False --loss_type=l2 --batch_size=10 --val_batch_size=10 --latent_size=56 --n_layers=8 --noise_amp=1e-2 --correction_rate=0.9 --epochs=100 --opt=adam --weight_decay=0 --is_mesh=True --seed=0 --id=0 --verbose=2 --test_interval=2 --n_workers=20 --gpuid=0
92 | ```
93 |
94 | Then, **jointly train** the remeshing model via RL and the evolution model:
95 |
96 | ```code
97 | export OMP_NUM_THREADS=6; python train.py --exp_id=2d_rl --wandb_project_name=2d_rerun --wandb=True --date_time=2023-02-26 --dataset=arcsimmesh_square_annotated_coarse_minlen008_interp_500 --time_interval=2 --n_train=-1 --latent_size=64 --load_dirname=evo-2d_2023_02_18 --load_filename=9UQLIKKc_ampere1 --input_steps=2 --act_name=elu --temporal_bundle_steps=1 --use_grads=False --is_y_diff=True --loss_type=l2 --epochs=300 --opt=adam --weight_decay=0 --verbose=1 --algo=srlgnnremesher --reward_mode=lossdiff+statediff --rl_data_dropout=None --min_edge_size=0.04 --actor_lr=5e-4 --value_lr=1e-4 --value_num_pool=1 --value_pooling_type=global_mean_pool --value_latent_size=64 --value_batch_norm=False --actor_batch_norm=True --rescale=10 --edge_attr=True --rl_gamma=0.9 --value_loss_coef=0.5 --max_grad_norm=20 --is_single_action=False --value_target_mode=vanilla --wandb_step_plot=50 --wandb_step=2 --id=0 --save_iteration=500 --save_interval=1 --test_interval=1 --is_mesh=True --is_unittest=False --rl_horizon=6 --multi_step=6 --rl_eta=2e-2 --reward_beta=0 --reward_condition=True --max_action=20 --rl_is_finetune_evolution=True --lr=1e-4 --actor_critic_step=200 --evolution_steps=100 --rl_finetune_evalution_mode=policy:fine --wandb=True --batch_size=64 --val_batch_size=64 --n_workers=6 --reward_loss_coef=1000 --evl_stop_gradient=True --noise_amp=0.01 --gpuid=5 --is_eval_sample=True --seed=256 --n_train=:-1 --soft_update=False --fine_tune_gt_input=True --policy_input_feature=coords --skip_coarse=False --skip_flip=True --processor_aggr=mean --fix_alt_evolution_model=True
98 | ```
99 |
100 | For commands for the **baseline** models in 1D, see the README in [./MP_Neural_PDE_Solvers/](https://github.com/tailintalent/MP_Neural_PDE_Solvers).
101 |
102 | We also provide pre-trained evolution models for RL training [here](https://drive.google.com/drive/folders/1ioR5gjYeQaNQMvqrdMYMZI-0n5k5UM8f). Put the folders from the Google Drive link (e.g., "evo-1d_2023-01-01" for the pre-trained 1D evolution model, "evo-2d_2023-02-18" for the pre-trained 2D evolution model) under the ./results/ folder; you can then use the RL commands above to perform joint training.
103 |
104 | ## Analysis
105 |
106 | * For 1D experiments, to analyze the pretrained evolution model for LAMP, use [analysis_1d_evo.ipynb](https://github.com/snap-stanford/lamp/blob/master/analysis_1d_evo.ipynb).
107 |
108 | * For 1D experiments, to analyze the full model for LAMP and the baselines, use [analysis_1d_full.py](https://github.com/snap-stanford/lamp/blob/master/analysis_1d_full.py).
109 |
110 | * For 1D experiments, to analyze the baseline models (MP-PDE, FNO, CNN), use [./MP_Neural_PDE_Solvers/analysis.ipynb](https://github.com/tailintalent/MP_Neural_PDE_Solvers/blob/master/analysis.ipynb).
111 |
112 | * For 2D experiments, to analyze the pretrained evolution model for LAMP, use [analysis_2d_evo.ipynb](https://github.com/snap-stanford/lamp/blob/master/analysis_2d_evo.ipynb).
113 |
114 | * For 2D experiments, to analyze the full model for LAMP and the baselines, use [analysis_2d_full.py](https://github.com/snap-stanford/lamp/blob/master/analysis_2d_full.py) and [analysis_2d_rl.ipynb](https://github.com/snap-stanford/lamp/blob/master/analysis_2d_rl.ipynb).
115 |
116 | ## Visualization
117 |
118 | Example visualization of learned remeshing & evolution:
119 |
120 | LAMP:
121 |
122 |
123 |
124 | LAMP (no remeshing):
125 |
126 |
127 |
128 | MeshGraphNets + ground-truth remeshing:
129 |
130 |
131 |
132 |
133 | MeshGraphNets + heuristic remeshing:
134 |
135 |
136 |
137 |
138 | Ground-truth (fine-grained):
139 |
140 |
141 |
142 | ## Related Projects
143 |
144 | * [LE-PDE](https://github.com/snap-stanford/le_pde) (NeurIPS 2022): Accelerate the simulation and inverse optimization of PDEs. Compared to state-of-the-art deep learning-based surrogate models (e.g., FNO, MP-PDE), it is up to 15x improvement in speed, while achieving competitive accuracy.
145 |
146 | * [CinDM](https://github.com/AI4Science-WestlakeU/cindm) (ICLR 2024 spotlight): We introduce a method that uses compositional generative models to design boundaries and initial states significantly more complex than the ones seen in training for physical simulations.
147 |
148 | * [BENO](https://github.com/AI4Science-WestlakeU/beno) (ICLR 2024): We introduce a boundary-embedded neural operator that incorporates complex boundary shape and inhomogeneous boundary values into the solving of Elliptic PDEs.
149 |
150 | ## Citation
151 | If you find our work and/or our code useful, please cite us via:
152 |
153 | ```bibtex
154 | @inproceedings{wu2023learning,
155 | title={Learning Controllable Adaptive Simulation for Multi-resolution Physics},
156 | author={Tailin Wu and Takashi Maruyama and Qingqing Zhao and Gordon Wetzstein and Jure Leskovec},
157 | booktitle={The Eleventh International Conference on Learning Representations},
158 | year={2023},
159 | url={https://openreview.net/forum?id=PbfgkZ2HdbE}
160 | }
161 | ```
162 |
--------------------------------------------------------------------------------
/analysis_1d_evo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "ca791c51-0ee5-40b0-b207-75ac8e2a491d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2\n",
12 | "\n",
13 | "import argparse\n",
14 | "from collections import OrderedDict\n",
15 | "import datetime\n",
16 | "import gc\n",
17 | "get_ipython().run_line_magic('matplotlib', 'inline')\n",
18 | "import matplotlib\n",
19 | "import matplotlib.pylab as plt\n",
20 | "from numbers import Number\n",
21 | "import numpy as np\n",
22 | "import pandas as pd\n",
23 | "pd.options.display.max_rows = 1500\n",
24 | "pd.options.display.max_columns = 200\n",
25 | "pd.options.display.width = 1000\n",
26 | "pd.set_option('max_colwidth', 400)\n",
27 | "import pdb\n",
28 | "import pickle\n",
29 | "import pprint as pp\n",
30 | "import time\n",
31 | "import torch\n",
32 | "import torch.nn as nn\n",
33 | "import torch.nn.functional as F\n",
34 | "from torch import optim\n",
35 | "from torch.utils.data import DataLoader\n",
36 | "from deepsnap.batch import Batch as deepsnap_Batch\n",
37 | "\n",
38 | "import sys, os\n",
39 | "sys.path.append(os.path.join(os.path.dirname(\"__file__\"), '..'))\n",
40 | "sys.path.append(os.path.join(os.path.dirname(\"__file__\"), '..', '..'))\n",
41 | "from lamp.argparser import arg_parse\n",
42 | "from lamp.datasets.load_dataset import load_data\n",
43 | "from lamp.gnns import get_data_dropout\n",
44 | "from lamp.models import load_model\n",
45 | "from lamp.pytorch_net.util import Interp1d_torch, groupby_add_keys, filter_df, get_unique_keys_df, Attr_Dict, Printer, get_num_params, get_machine_name, pload, pdump, to_np_array, get_pdict, reshape_weight_to_matrix, ddeepcopy as deepcopy, plot_vectors, record_data, filter_filename, Early_Stopping, str2bool, get_filename_short, print_banner, plot_matrices, get_num_params, init_args, filter_kwargs, to_string, COLOR_LIST\n",
46 | "from lamp.utils import p, update_legacy_default_hyperparam, EXP_PATH, deepsnap_to_pyg, LpLoss, to_cpu, to_tuple_shape, parse_multi_step, loss_op, get_cholesky_inverse, get_device, get_data_comb\n",
47 | "\n",
48 | "device = torch.device(\"cuda:5\")"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "id": "fd583f18-3ce5-469a-a89f-a19887163e00",
54 | "metadata": {},
55 | "source": [
56 | "## Functions:"
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "id": "df5ff29f-9493-4e3b-a5e7-9545e7c0a143",
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "# Analysis:\n",
67 | "def get_results_1d(\n",
68 | " all_hash,\n",
69 | " mode=\"best\",\n",
70 | " exclude_idx=(None,),\n",
71 | " dropout_mode=\"None\",\n",
72 | " n_rollout_steps=-1,\n",
73 | " is_full_eval=True,\n",
74 | " dirname=None,\n",
75 | " suffix=\"\",\n",
76 | "):\n",
77 | " \"\"\"\n",
78 | " Perform analysis on the 1D Burgers' benchmark.\n",
79 | "\n",
80 | " Args:\n",
81 | " all_hash: a list of hashes which indicates the experiments to load for analysis\n",
82 | " mode: choose from \"best\" (load the best model with lowest validation loss) or an integer, \n",
83 | " e.g. -1 (last saved model), -2 (second last saved model)\n",
84 | " is_full_eval: if True, evalute on all grount-truth points. E.g., assuming is_full_eval is True, \n",
85 | " if the prediction is on 50 vertices, it will be first interpolated to the vertices \n",
86 | " of ground-truth, then compute the loss w.r.t. ground-truth.\n",
87 | " dirname: if not None, will use the dirnaem provided. E.g. tailin-1d_2022-7-27\n",
88 | " suffix: suffix for saving the analysis result.\n",
89 | " \"\"\"\n",
90 | " \n",
91 | " isplot = True\n",
92 | " df_dict_list = []\n",
93 | " dirname_start = dirname\n",
94 | " for hash_str in all_hash:\n",
95 | " df_dict = {}\n",
96 | " df_dict[\"hash\"] = hash_str\n",
97 | " # Load model:\n",
98 | " is_found = False\n",
99 | " for dirname_core in [\n",
100 | " dirname_start,\n",
101 | " ]:\n",
102 | " filename = filter_filename(EXP_PATH + dirname_core, include=hash_str)\n",
103 | " if len(filename) == 1:\n",
104 | " is_found = True\n",
105 | " break\n",
106 | " if not is_found:\n",
107 | " print(f\"hash {hash_str} does not exist in {dirname}! Please pass in the correct dirname.\")\n",
108 | " continue\n",
109 | " dirname = EXP_PATH + dirname_core\n",
110 | " if not dirname.endswith(\"/\"):\n",
111 | " dirname += \"/\"\n",
112 | "\n",
113 | " try:\n",
114 | " data_record = pload(dirname + filename[0])\n",
115 | " except Exception as e:\n",
116 | " # p.print(f\"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:\", banner_size=100)\n",
117 | " print(f\"error {e} in hash_str {hash_str}\")\n",
118 | " continue\n",
119 | " p.print(f\"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:\", banner_size=160)\n",
120 | " if isplot:\n",
121 | " plot_learning_curve(data_record)\n",
122 | " args = init_args(update_legacy_default_hyperparam(data_record[\"args\"]))\n",
123 | " args.filename = filename\n",
124 | " if mode == \"best\":\n",
125 | " model = load_model(data_record[\"best_model_dict\"], device=device)\n",
126 | " print(\"Load the model with best validation loss.\")\n",
127 | " else:\n",
128 | " assert isinstance(mode, int)\n",
129 | " print(f'Load the model at epoch {data_record[\"epoch\"][mode]}')\n",
130 | " model = load_model(data_record[\"model_dict\"][mode], device=device)\n",
131 | " model.eval()\n",
132 | " # pp.pprint(args.__dict__)\n",
133 | " kwargs = {}\n",
134 | " if data_record[\"best_model_dict\"][\"type\"].startswith(\"GNNPolicy\"):\n",
135 | " kwargs[\"is_deepsnap\"] = True\n",
136 | "\n",
137 | " # Load test dataset:\n",
138 | " args_test = deepcopy(args)\n",
139 | " multi_step = (250 - 50) // args_test.temporal_bundle_steps\n",
140 | " args_test.multi_step = f\"1^{multi_step}\"\n",
141 | " args_test.is_test_only = True\n",
142 | " args_test.n_train = \"-1\"\n",
143 | " n_test_traj = 128\n",
144 | " (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args_test)\n",
145 | " nx = int(args.dataset.split(\"-\")[2])\n",
146 | " time_stamps_effective = len(dataset_test) // n_test_traj\n",
147 | " for exclude_idx_ele in exclude_idx:\n",
148 | " loss_list = []\n",
149 | " pred_list = []\n",
150 | " y_list = []\n",
151 | " for i in range(n_test_traj):\n",
152 | " idx = i * time_stamps_effective + args_test.temporal_bundle_steps\n",
153 | " data = deepcopy(dataset_test[idx])\n",
154 | " data_ori = deepcopy(dataset_test[idx])\n",
155 | " if dropout_mode == \"None\":\n",
156 | " if exclude_idx_ele is not None:\n",
157 | " data = get_data_dropout(data, dropout_mode=\"node:0\", exclude_idx=exclude_idx_ele)\n",
158 | " else:\n",
159 | " data = get_data_dropout(data, dropout_mode=dropout_mode)\n",
160 | " data = data.to(device)\n",
161 | " preds, info = model(\n",
162 | " data,\n",
163 | " pred_steps=np.arange(1,n_rollout_steps+1) if n_rollout_steps != -1 else np.arange(1, max(parse_multi_step(args_test.multi_step).keys())+1),\n",
164 | " latent_pred_steps=None,\n",
165 | " is_recons=False,\n",
166 | " use_grads=False,\n",
167 | " use_pos=args.use_pos,\n",
168 | " is_y_diff=False,\n",
169 | " is_rollout=False,\n",
170 | " **kwargs\n",
171 | " )\n",
172 | " y = data.node_label[\"n0\"]\n",
173 | " if n_rollout_steps != -1:\n",
174 | " y = y[:,:25*n_rollout_steps]\n",
175 | " pred = preds[\"n0\"].reshape(y.shape)\n",
176 | " pred_list.append(pred.detach())\n",
177 | " y_list.append(y.detach())\n",
178 | " if is_full_eval:\n",
179 | " \"\"\"\n",
180 | " y_ori: [100, 175, 1]\n",
181 | " x_pos_ori: [100, 1]\n",
182 | " x_pos: [50, 1] for uniform:2\n",
183 | " pred: [50, 175, 1]\n",
184 | " Interp1d_torch(x, y, x_new):\n",
185 | " x: (B, N)\n",
186 | " y: (B, N)\n",
187 | " x_new: (B, P)\n",
188 | " \"\"\"\n",
189 | " y_ori = data_ori.node_label[\"n0\"]\n",
190 | " if n_rollout_steps != -1:\n",
191 | " y_ori = y_ori[:,:25*n_rollout_steps]\n",
192 | " x_pos_ori = data_ori.node_pos[\"n0\"] # [100, 1]\n",
193 | " x_pos = data.node_pos[\"n0\"] # [50, 1]\n",
194 | " fnc = Interp1d_torch()\n",
195 | " n_steps = y_ori.shape[1]\n",
196 | " x_pos_format = x_pos.transpose(0,1).expand(n_steps, x_pos.shape[0]) # [175, 50]\n",
197 | " pred_format = pred.squeeze(2).transpose(0,1)\n",
198 | " x_pos_ori_format = x_pos_ori.transpose(0,1).expand(n_steps, x_pos_ori.shape[0]).to(device) # [175, 100]\n",
199 | " pred_interp_format = fnc(x_pos_format, pred_format, x_pos_ori_format) # [175, 100]\n",
200 | " y_ori_format = y_ori.squeeze(2).transpose(0,1).to(device) # [175, 100]\n",
201 | " loss_ele = nn.MSELoss(reduction=\"sum\")(pred_interp_format, y_ori_format) / nx\n",
202 | " else:\n",
203 | " loss_ele = nn.MSELoss(reduction=\"sum\")(pred, y) / nx\n",
204 | " loss_list.append(loss_ele.item())\n",
205 | "\n",
206 | " loss_mean = np.mean(loss_list)\n",
207 | " pred_list = torch.stack(pred_list).squeeze(-1)\n",
208 | " y_list = torch.stack(y_list).squeeze(-1)\n",
209 | " df_dict[f\"loss_cumu_{exclude_idx_ele}\"] = loss_mean \n",
210 | " print(\"\\nTest for {} for exclude_idx={} is: {:.9f} at epoch {}, for {}/{} epochs\".format(hash_str, exclude_idx_ele, loss_mean, data_record['best_epoch'], len(data_record[\"train_loss\"]), args.epochs))\n",
211 | "\n",
212 | " mse_full = nn.MSELoss(reduction=\"none\")(pred_list, y_list)\n",
213 | " mse_time = to_np_array(mse_full.mean((0,1)))\n",
214 | " p.print(\"Learning curve:\", is_datetime=False, banner_size=100)\n",
215 | " plt.figure(figsize=(12,5))\n",
216 | " plt.subplot(1,2,1)\n",
217 | " plt.plot(mse_time)\n",
218 | " plt.xlabel(\"rollout step\")\n",
219 | " plt.ylabel(\"MSE\")\n",
220 | " plt.title(\"MSE vs. rollout step (linear scale)\")\n",
221 | " plt.subplot(1,2,2)\n",
222 | " plt.semilogy(mse_time)\n",
223 | " plt.xlabel(\"rollout step\")\n",
224 | " plt.ylabel(\"MSE\")\n",
225 | " plt.title(\"MSE vs. rollout step (log scale)\")\n",
226 | " plt.show()\n",
227 | " plt.figure(figsize=(6,5))\n",
228 | " plt.plot(mse_time.cumsum())\n",
229 | " plt.title(\"cumulative MSE vs. rollout step\")\n",
230 | " plt.xlabel(\"rollout step\")\n",
231 | " plt.ylabel(\"cumulative MSE\")\n",
232 | " plt.show()\n",
233 | "\n",
234 | " # Visualization:\n",
235 | " for idx in range(6,8):\n",
236 | " p.print(f\"Example {idx*128}:\", banner_size=100, is_datetime=False)\n",
237 | " data = deepcopy(dataset_test[idx*128]).to(device)\n",
238 | " if exclude_idx_ele is not None:\n",
239 | " data = get_data_dropout(data, dropout_mode=\"node:0\", exclude_idx=exclude_idx_ele)\n",
240 | " preds, info = model(\n",
241 | " data,\n",
242 | " pred_steps=np.arange(1,max(parse_multi_step(args_test.multi_step).keys())+1),\n",
243 | " latent_pred_steps=None,\n",
244 | " is_recons=False,\n",
245 | " use_grads=False,\n",
246 | " use_pos=args.use_pos,\n",
247 | " is_y_diff=False,\n",
248 | " is_rollout=False,\n",
249 | " **kwargs\n",
250 | " )\n",
251 | " y = data.node_label[\"n0\"]\n",
252 | " pred = preds[\"n0\"].reshape(y.shape)\n",
253 | " visualize(pred, y)\n",
254 | " visualize_paper(pred, y)\n",
255 | "\n",
256 | " p.print(f\"Individual prediction at rollout step {y.shape[1]}:\", banner_size=100, is_datetime=False)\n",
257 | " time_step = -1\n",
258 | " for idx in range(0, 20, 5):\n",
259 | " plt.figure(figsize=(6,4))\n",
260 | " plt.plot(to_np_array(pred_list[idx,:,time_step]), label=\"pred\")\n",
261 | " plt.plot(to_np_array(y_list[idx,:,time_step]), \"--\", label=\"y\")\n",
262 | " plt.legend()\n",
263 | " plt.show()\n",
264 | " df_dict[\"best_epoch\"] = data_record['best_epoch']\n",
265 | " df_dict[\"epoch\"] = len(data_record[\"train_loss\"])\n",
266 | " df_dict.update(args.__dict__)\n",
267 | " df_dict_list.append(df_dict)\n",
268 | " df = pd.DataFrame(df_dict_list)\n",
269 | " pdump(df, f\"df_1d{suffix}.p\")\n",
270 | " return df\n",
271 | "\n",
272 | "# Plotting:\n",
273 | "def plot_learning_curve(data_record):\n",
274 | " plt.figure(figsize=(16,6))\n",
275 | " plt.subplot(1,2,1)\n",
276 | " plt.plot(data_record[\"epoch\"], data_record[\"train_loss\"], label=\"train\")\n",
277 | " plt.plot(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"val_loss\"], label=\"val\")\n",
278 | " plt.plot(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"test_loss\"], label=\"test\")\n",
279 | " plt.title(\"Learning curve, linear scale\")\n",
280 | " plt.legend()\n",
281 | " plt.subplot(1,2,2)\n",
282 | " plt.semilogy(data_record[\"epoch\"], data_record[\"train_loss\"], label=\"train\")\n",
283 | " plt.semilogy(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"val_loss\"], label=\"val\")\n",
284 | " plt.semilogy(data_record[\"test_epoch\"] if \"test_epoch\" in data_record else data_record[\"epoch\"], data_record[\"test_loss\"], label=\"test\")\n",
285 | " plt.title(\"Learning curve, log scale\")\n",
286 | " plt.legend()\n",
287 | " plt.show()\n",
288 | "\n",
289 | "\n",
290 | "def plot_colorbar(matrix, vmax=None, vmin=None, cmap=\"seismic\", label=None):\n",
291 | " if vmax==None:\n",
292 | " vmax = matrix.max()\n",
293 | " vmin = matrix.min()\n",
294 | " im = plt.imshow(matrix,vmax=vmax,vmin=vmin,cmap=cmap)\n",
295 | " plt.title(label)\n",
296 | " im_ratio = matrix.shape[0]/matrix.shape[1]\n",
297 | " plt.colorbar(im,fraction=0.046*im_ratio,pad=0.04)\n",
298 | "\n",
299 | "\n",
300 | "def visualize(pred, gt, animate=False):\n",
301 | " if torch.is_tensor(gt):\n",
302 | " gt = to_np_array(gt)\n",
303 | " pred = to_np_array(pred)\n",
304 | " mse_over_t = ((gt-pred)**2).mean(axis=0).mean(axis=-1)\n",
305 | " \n",
306 | " if not animate:\n",
307 | " vmax = gt.max()\n",
308 | " vmin = gt.min()\n",
309 | " plt.figure(figsize=[15,5])\n",
310 | " plt.subplot(1,4,1)\n",
311 | " plot_colorbar(gt[:,:,0].T,label=\"gt\")\n",
312 | " plt.subplot(1,4,2)\n",
313 | " plot_colorbar(pred[:,:,0].T,label=\"pred\")\n",
314 | " plt.subplot(1,4,3)\n",
315 | " plot_colorbar((pred-gt)[:,:,0].T,vmax=np.abs(pred-gt).max(),vmin=(-1*np.abs(pred-gt).max()),label=\"diff\")\n",
316 | " plt.subplot(1,4,4)\n",
317 | " plt.plot(mse_over_t);plt.title(\"mse over t\");plt.yscale('log');\n",
318 | " plt.tight_layout()\n",
319 | " plt.show()\n",
320 | "\n",
321 | "def visualize_paper(pred, gt, is_save=False):\n",
322 | " idx = 6\n",
323 | " nx = pred.shape[0]\n",
324 | "\n",
325 | " fontsize = 14\n",
326 | " idx_list = np.arange(0, 200, 15)\n",
327 | " color_list = np.linspace(0.01, 0.9, len(idx_list))\n",
328 | " x_axis = np.linspace(0,16,nx)\n",
329 | " cmap = matplotlib.cm.get_cmap('jet')\n",
330 | "\n",
331 | " plt.figure(figsize=(16,5))\n",
332 | " plt.subplot(1,2,1)\n",
333 | " for i, idx in enumerate(idx_list):\n",
334 | " pred_i = to_np_array(pred[...,idx,:].squeeze())\n",
335 | " rgb = cmap(color_list[i])[:3]\n",
336 | " plt.plot(x_axis, pred_i, color=rgb, label=f\"t={np.round(i*0.3, 1)}s\")\n",
337 | " plt.ylabel(\"u(t,x)\", fontsize=fontsize)\n",
338 | " plt.xlabel(\"x\", fontsize=fontsize)\n",
339 | " plt.tick_params(labelsize=fontsize)\n",
340 | " # plt.legend(fontsize=10, bbox_to_anchor=[1,1])\n",
341 | " plt.xticks([0,8,16], [0,8,16])\n",
342 | " plt.ylim([-2.5,2.5])\n",
343 | " plt.title(\"Prediction\")\n",
344 | " if is_save:\n",
345 | " plt.savefig(f\"1D_E2-{nx}.pdf\", bbox_inches='tight')\n",
346 | "\n",
347 | " plt.subplot(1,2,2)\n",
348 | " for i, idx in enumerate(idx_list):\n",
349 | " y_i = to_np_array(gt[...,idx,:])\n",
350 | " rgb = cmap(color_list[i])[:3]\n",
351 | " plt.plot(x_axis, y_i, color=rgb, label=f\"t={np.round(i*0.3, 1)}s\")\n",
352 | " plt.ylabel(\"u(t,x)\", fontsize=fontsize)\n",
353 | " plt.xlabel(\"x\", fontsize=fontsize)\n",
354 | " plt.tick_params(labelsize=fontsize)\n",
355 | " plt.legend(fontsize=10, bbox_to_anchor=[1,1])\n",
356 | " plt.xticks([0,8,16], [0,8,16])\n",
357 | " plt.ylim([-2.5,2.5])\n",
358 | " plt.title(\"Ground-truth\")\n",
359 | " if is_save:\n",
360 | " plt.savefig(f\"1D_gt-{nx}.pdf\", bbox_inches='tight')\n",
361 | " plt.show()"
362 | ]
363 | },
364 | {
365 | "cell_type": "markdown",
366 | "id": "bc711711-1ca8-4944-92d3-848313c275e9",
367 | "metadata": {},
368 | "source": [
369 | "## Analysis:"
370 | ]
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": null,
375 | "id": "84b58196-d00d-4519-8e3b-a0789dbf9444",
376 | "metadata": {},
377 | "outputs": [],
378 | "source": [
379 | "\"\"\"\n",
380 | "all_hash is a list of hashes, each of which corresponds to one experiment.\n",
381 | " For example, if one experiment is saved under ./results/evo-1d_2023-01-01/mppde1d-E2-50_train_-1_algo_contrast_ebm_False_ebmt_cd_enc_cnn-s_evo_cnn_act_elu_hid_128_lo_rmse_recef_1.0_conef_1.0_nconv_4_nlat_1_clat_3_lf_True_reg_None_id_0_Hash_qvQry9QJ_ampere3.p\n",
382 | " Then, the \"qvQry9QJ\" (located at the end of the filename) is the {hash} of this file.\n",
383 | " The \"evo-1d_2023-01-01\" is the \"{--exp_id}_{--date_time}\" of the training command.\n",
384 | " all_hash can contain multiple hashes, and analyze them sequentially.\n",
385 | "\"\"\"\n",
386 | "all_hash = [\n",
387 | " \"Q66bz42y_ampere3\",\n",
388 | "]\n",
389 | "df9 = get_results_1d(\n",
390 | " all_hash,\n",
391 | " dirname=\"evo-1d_2023-01-01\",\n",
392 | " n_rollout_steps=7,\n",
393 | " dropout_mode=\"uniform:2\",\n",
394 | " is_full_eval=True,\n",
395 | " suffix=\"_0\")"
396 | ]
397 | }
398 | ],
399 | "metadata": {
400 | "kernelspec": {
401 | "display_name": "Python 3",
402 | "language": "python",
403 | "name": "python3"
404 | },
405 | "language_info": {
406 | "codemirror_mode": {
407 | "name": "ipython",
408 | "version": 3
409 | },
410 | "file_extension": ".py",
411 | "mimetype": "text/x-python",
412 | "name": "python",
413 | "nbconvert_exporter": "python",
414 | "pygments_lexer": "ipython3",
415 | "version": "3.9.5"
416 | }
417 | },
418 | "nbformat": 4,
419 | "nbformat_minor": 5
420 | }
421 |
--------------------------------------------------------------------------------
/analysis_2d_full.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import OrderedDict
3 | import datetime
4 | import gc
5 | import matplotlib
6 | import matplotlib.pylab as plt
7 | from numbers import Number
8 | import numpy as np
9 | import pandas as pd
10 | pd.options.display.max_rows = 1500
11 | pd.options.display.max_columns = 200
12 | pd.options.display.width = 1000
13 | pd.set_option('max_colwidth', 400)
14 | import pdb
15 | import pickle
16 | import pprint as pp
17 | import time
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from torch import optim
22 | from torch.utils.data import DataLoader
23 | from deepsnap.batch import Batch as deepsnap_Batch
24 | import xarray as xr
25 |
26 | import sys, os
27 | from argparser import arg_parse
28 | from lamp.datasets.load_dataset import load_data
29 | from lamp.gnns import GNNRemesher, Value_Model_Summation, Value_Model, GNNRemesherPolicy, GNNPolicySizing, GNNPolicyAgent, get_reward_batch, get_data_dropout, GNNPolicyAgent_Sampling
30 | from lamp.models import get_model, load_model, unittest_model, build_optimizer, test
31 |
32 | from lamp.pytorch_net.util import Attr_Dict, Batch, filter_filename, pload, pdump, Printer, get_time, init_args, update_args, clip_grad, set_seed, update_dict, filter_kwargs, plot_vectors, plot_matrices, make_dir, get_pdict, to_np_array, record_data, make_dir, str2bool, get_filename_short, print_banner, get_num_params, ddeepcopy as deepcopy, write_to_config
33 | from lamp.utils import p, update_legacy_default_hyperparam, EXP_PATH, seed_everything
34 |
35 | def plot_fig(info, data, index=20):
36 | vers_gt = data.reind_yfeatures["n0"][index].detach().cpu().numpy()
37 | faces_gt = data.yface_list["n0"][index]
38 |
39 | faces_gt = np.stack(faces_gt)
40 | batch_0_idx = np.where(vers_gt[faces_gt[:,0],0]<1.5)
41 | faces_gt = np.stack(faces_gt)[batch_0_idx]
42 |
43 | batch_0_idx = np.where(vers_gt[:,0]<1.5)
44 | vers_gt = vers_gt[batch_0_idx][:,:]
45 |
46 | plot0 = info['state_preds'][0]
47 | plot20 = info['state_preds'][index]
48 | plot20alt_gt_evl = info['state_preds_alt_gt_evl'][index]
49 | plot20alt_gt_mesh_gt_evl = info['state_preds_alt_gt_mesh_gt_evl'][index]
50 | plot20alt_008_gt_evl = info['state_preds_alt_008_gt_evl'][index]
51 | plot20heuristic_gt_evl = info['state_preds_heuristic_gt_evl'][index]
52 |
53 |
54 | from matplotlib.backends.backend_pdf import PdfPages
55 | fig = plt.figure(figsize=(30,6))
56 | n = 2
57 | ax0= fig.add_subplot(n,7,1,projection='3d')
58 | ax0.set_axis_off()
59 | ax0.set_xlim([-0.6, 0.6])
60 | ax0.set_ylim([-0.6, 0.6])
61 | ax0.set_zlim([-0.5, 0.1])
62 | ax0.view_init(30, 10)
63 | ax0.plot_trisurf(plot0['history'][-1][:,3].detach().cpu().numpy(),plot0['history'][-1][:,4].detach().cpu().numpy(), plot0['xfaces'].detach().cpu().numpy().T, plot0['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1))
64 | plt.title("plot0")
65 |
66 | ax0= fig.add_subplot(n,7,2,projection='3d')
67 | ax0.set_axis_off()
68 | ax0.set_xlim([-0.6, 0.6])
69 | ax0.set_ylim([-0.6, 0.6])
70 | ax0.set_zlim([-0.5, 0.1])
71 | ax0.view_init(30, 10)
72 | ax0.plot_trisurf(plot20alt_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20alt_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20alt_gt_evl['xfaces'].detach().cpu().numpy().T, plot20alt_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1))
73 | plt.title("plot20alt_gt_evl")
74 |
75 | ax0= fig.add_subplot(n,7,3,projection='3d')
76 | ax0.set_axis_off()
77 | ax0.set_xlim([-0.6, 0.6])
78 | ax0.set_ylim([-0.6, 0.6])
79 | ax0.set_zlim([-0.5, 0.1])
80 | ax0.view_init(30, 10)
81 | ax0.plot_trisurf(plot20alt_gt_mesh_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20alt_gt_mesh_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20alt_gt_mesh_gt_evl['xfaces'].detach().cpu().numpy().T, plot20alt_gt_mesh_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1))
82 | plt.title("plot20alt_gt_mesh_gt_evl")
83 |
84 | ax0= fig.add_subplot(n,7,4,projection='3d')
85 | ax0.set_axis_off()
86 | ax0.set_xlim([-0.6, 0.6])
87 | ax0.set_ylim([-0.6, 0.6])
88 | ax0.set_zlim([-0.5, 0.1])
89 | ax0.view_init(30, 10)
90 | ax0.plot_trisurf(plot20alt_008_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20alt_008_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20alt_008_gt_evl['xfaces'].detach().cpu().numpy().T, plot20alt_008_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1))
91 | plt.title("plot20alt_008_gt_evl")
92 |
93 | ax0= fig.add_subplot(n,7,5,projection='3d')
94 | ax0.set_axis_off()
95 | ax0.set_xlim([-0.6, 0.6])
96 | ax0.set_ylim([-0.6, 0.6])
97 | ax0.set_zlim([-0.5, 0.1])
98 | ax0.view_init(30, 10)
99 | ax0.plot_trisurf(plot20heuristic_gt_evl['history'][-1][:,3].detach().cpu().numpy(),plot20heuristic_gt_evl['history'][-1][:,4].detach().cpu().numpy(), plot20heuristic_gt_evl['xfaces'].detach().cpu().numpy().T, plot20heuristic_gt_evl['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1))
100 | plt.title("plot20heuristic_gt_evl")
101 |
102 |
103 | ax0= fig.add_subplot(n,7,6,projection='3d')
104 | ax0.set_axis_off()
105 | ax0.set_xlim([-0.6, 0.6])
106 | ax0.set_ylim([-0.6, 0.6])
107 | ax0.set_zlim([-0.5, 0.1])
108 | ax0.view_init(30, 10)
109 | ax0.plot_trisurf(plot20['history'][-1][:,3].detach().cpu().numpy(),plot20['history'][-1][:,4].detach().cpu().numpy(), plot20['xfaces'].detach().cpu().numpy().T, plot20['history'][-1][:,5].detach().cpu().numpy(), shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1))
110 | plt.title("plot20_policy")
111 |
112 |
113 | ax0= fig.add_subplot(n,7,7,projection='3d')
114 | ax0.set_axis_off()
115 | ax0.set_xlim([-0.6, 0.6])
116 | ax0.set_ylim([-0.6, 0.6])
117 | ax0.set_zlim([-0.5, 0.1])
118 | ax0.view_init(30, 10)
119 | ax0.plot_trisurf(vers_gt[:,3],vers_gt[:,4],faces_gt, vers_gt[:,5], shade=True, linewidth = 1., edgecolor = 'black', color=(9/255,237/255,249/255,1))
120 | plt.title("plot20gt")
121 |
122 |
123 |
124 | ax0= fig.add_subplot(n,7,7+1)
125 | ax0.set_axis_off()
126 | ax0.triplot(plot0['history'][-1][:,0].detach().cpu().numpy(),plot0['history'][-1][:,1].detach().cpu().numpy(), plot0['xfaces'].detach().cpu().numpy().T)
127 | plt.title("plot0")
128 |
129 | ax0= fig.add_subplot(n,7,7+2)
130 | ax0.set_axis_off()
131 | ax0.triplot(plot20alt_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20alt_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20alt_gt_evl['xfaces'].detach().cpu().numpy().T,)
132 | plt.title("plot20alt_gt_evl")
133 |
134 |
135 | ax0= fig.add_subplot(n,7,7+3)
136 | ax0.set_axis_off()
137 | ax0.triplot(plot20alt_gt_mesh_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20alt_gt_mesh_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20alt_gt_mesh_gt_evl['xfaces'].detach().cpu().numpy().T,)
138 | plt.title("plot20alt_gt_mesh_gt_evl")
139 |
140 | ax0= fig.add_subplot(n,7,7+4)
141 | ax0.set_axis_off()
142 | ax0.triplot(plot20alt_008_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20alt_008_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20alt_008_gt_evl['xfaces'].detach().cpu().numpy().T,)
143 | plt.title("plot20alt_008_gt_evl")
144 |
145 | ax0= fig.add_subplot(n,7,7+5)
146 | ax0.set_axis_off()
147 | ax0.triplot(plot20heuristic_gt_evl['history'][-1][:,0].detach().cpu().numpy(),plot20heuristic_gt_evl['history'][-1][:,1].detach().cpu().numpy(), plot20heuristic_gt_evl['xfaces'].detach().cpu().numpy().T,)
148 | plt.title("plot20heuristic_gt_evl")
149 |
150 |
151 | ax0= fig.add_subplot(n,7,7+6)
152 | ax0.set_axis_off()
153 | ax0.triplot(plot20['history'][-1][:,0].detach().cpu().numpy(),plot20['history'][-1][:,1].detach().cpu().numpy(), plot20['xfaces'].detach().cpu().numpy().T,)
154 | plt.title("plot20_policy")
155 |
156 |
157 | ax0= fig.add_subplot(n,7,7+7)
158 | ax0.set_axis_off()
159 | ax0.triplot(vers_gt[:,0],vers_gt[:,1],faces_gt)
160 | plt.title("plot20gt")
161 |
162 |
163 |
164 | # Analysis:
165 | def get_results_2d(
166 | all_hash,
167 | mode="best",
168 | exclude_idx=(None,),
169 | dirname=None,
170 | suffix="",device=None
171 | ):
172 | """
173 |     Perform analysis on the 2D cloth benchmark.
174 |
175 | Args:
176 | all_hash: a list of hashes which indicates the experiments to load for analysis
177 | mode: choose from "best" (load the best model with lowest validation loss) or an integer,
178 | e.g. -1 (last saved model), -2 (second last saved model)
179 |         dirname: if not None, will use the dirname provided. E.g. tailin-1d_2022-7-27
180 | suffix: suffix for saving the analysis result.
181 | """
182 |
183 | isplot = True
184 | df_dict_list = []
185 | dirname_start = "2d_rl_reproduce_2023-02-26" if dirname is None else dirname
186 | for hash_str in all_hash:
187 | df_dict = {}
188 | df_dict["hash"] = hash_str
189 | is_found = False
190 | for dirname_core in [
191 | dirname_start,
192 | ]:
193 | filename = filter_filename(EXP_PATH + dirname_core, include=hash_str)
194 | if len(filename) == 1:
195 | is_found = True
196 | break
197 | if not is_found:
198 | print(f"hash {hash_str} does not exist in {dirname}! Please pass in the correct dirname.")
199 | continue
200 | dirname = EXP_PATH + dirname_core
201 | if not dirname.endswith("/"):
202 | dirname += "/"
203 |
204 | try:
205 | data_record = pload(dirname + filename[0])
206 | except Exception as e:
207 | print(f"error {e} in hash_str {hash_str}")
208 | continue
209 |
210 | args = init_args(update_legacy_default_hyperparam(data_record["args"]))
211 | args.filename = filename
212 | if not("processor_aggr" in data_record["model_dict"][-1]["actor_model_dict"].keys()):
213 | data_record["model_dict"][-1]["actor_model_dict"]["processor_aggr"] = "max"
214 | if "best_model_dict" in data_record.keys():
215 | data_record["best_model_dict"]["actor_model_dict"]["processor_aggr"] = "max"
216 | model = load_model(data_record["best_model_dict"], device=device)
217 | evolution = load_model(data_record["best_evolution_model_dict"], device=device)
218 | print("Load the model with best validation loss.")
219 | model.eval()
220 | evolution.eval()
221 |
222 | # Load test dataset:
223 | args_test = deepcopy(args)
224 | args_test.dataset="arcsimmesh_square_annotated_coarse_minlen008_interp_500"
225 | args_test.multi_step = "1^20"
226 | args_test.n_train = "-1"
227 | args_test.is_train=False
228 | args_test.use_fineres_data=False
229 | args_test.show_missing_files = False
230 | args_test.input_steps=2
231 | args_test.time_interval=2
232 | args_test.val_batch_size = 1
233 | args_test.device = device
234 | args.device = device
235 |         (train_dataset, test_dataset), (train_loader, _, test_loader) = load_data(args_test)
236 | print(len(test_loader))
237 |
238 | args_test_gt_008 = deepcopy(args)
239 | args_test_gt_008.dataset="arcsimmesh_square_annotated_coarse_minlen008_interp_500_gt_500"
240 | args_test_gt_008.n_train = "-1"
241 | args_test_gt_008.multi_step = "1^20"
242 | args_test_gt_008.is_train=False
243 | args_test_gt_008.use_fineres_data=False
244 | args_test_gt_008.show_missing_files = False
245 | args_test_gt_008.input_steps=2
246 | args_test_gt_008.time_interval=2
247 | args_test_gt_008.val_batch_size = 1
248 | args_test_gt_008.device = device
249 | args.device = device
250 | (_, _), (_, _, test_loader_008) = load_data(args_test_gt_008)
251 | print(len(test_loader_008))
252 |
253 | args.noise_amp = 0
254 | return model, evolution, test_loader, test_loader_008, args
255 |
256 | def get_eval(model, test_loader, test_loader_008,evolution_model,beta=0,break_index=2,hashva=None,name=None,best_evolution_model=None,args=None):
257 | dicts = {}
258 | min_evolution_loss_rmse_average = 1000
259 | if True:
260 | evolution_loss_mse = 0
261 | evolution_loss_alt_gt_evl_mse = 0
262 | evolution_loss_alt_gt_mesh_gt_evl_mse = 0
263 | evolution_loss_alt_008_gt_evl_mse = 0
264 | evolution_loss_heuristic_gt_evl_mse = 0
265 |
266 |
267 | r_statediff = 0
268 | v_statediff = 0
269 | rewards = 0
270 | count = 0
271 | actual_count = 0
272 | state_size = 0
273 | state_size_remeshed = 0
274 | args.pred_steps=20
275 |
276 | list_evolution_loss_mse = []
277 | list_evolution_loss_mse_alt_gt_evl = []
278 | list_evolution_loss_mse_alt_gt_mesh_gt_evl = []
279 | list_evolution_loss_mse_alt_008_gt_evl = []
280 | list_evolution_loss_heuristic_gt_evl = []
281 |
282 | list_state_size_remeshed = []
283 | list_all_nodes = []
284 | list_all_nodes_remesh = []
285 | kwargs = {}
286 | kwargs["evolution_model_alt"] = best_evolution_model
287 |
288 | if not(os.path.exists("./results/LAMP_2d")):
289 | os.mkdir("./results/LAMP_2d")
290 | with torch.no_grad():
291 | for j, (data,data008) in enumerate(zip(test_loader,test_loader_008)):
292 | # if (j%91==0 and j>0) or j==0:
293 | if data.time_step in [10,30,50]:
294 | count += 1
295 | actual_count +=1
296 | if data.__class__.__name__ == "Attr_Dict":
297 | data = data.to(args.device)
298 | data008 = data008.to(args.device)
299 | else:
300 | data.to(args.device)
301 | data008.to(args.device)
302 | info, data_clone = model.get_loss(
303 | data,
304 | args,
305 | wandb=None,
306 | opt_evl=False,
307 | step_num=0,
308 | mode="test_gt_remesh_heursitc",
309 | beta=beta,
310 | is_gc_collect=False,
311 | evolution_model=evolution_model,
312 | data008 = data008,
313 | **kwargs
314 | )
315 | state_size_elm = 0
316 | state_size_elm_list = []
317 | for elem in data_clone.reind_yfeatures["n0"]:
318 | state_size_elm += elem.shape[0]
319 | state_size_elm_list.append(elem.shape[0])
320 | evolution_loss_mse += (info['evolution/loss_mse'].sum().item())
321 | state_size += state_size_elm
322 | evolution_loss_alt_gt_evl_mse += (info['evolution/loss_alt_gt_evl_mse'].sum().item())
323 | evolution_loss_alt_gt_mesh_gt_evl_mse += (info['evolution/loss_alt_gt_mesh_gt_evl_mse'].sum().item())
324 | evolution_loss_alt_008_gt_evl_mse += (info['evolution/loss_alt_008_gt_evl_mse'].sum().item())
325 | evolution_loss_heuristic_gt_evl_mse += (info['evolution/loss_heuristic_gt_evl_mse'].sum().item())
326 |
327 |
328 | state_size_remeshed += info['v/state_size'].sum().item()
329 |
330 | list_evolution_loss_mse.append(info['evolution/loss_mse'])
331 | list_evolution_loss_mse_alt_gt_evl.append(info['evolution/loss_alt_gt_evl_mse'])
332 | list_evolution_loss_mse_alt_gt_mesh_gt_evl.append(info['evolution/loss_alt_gt_mesh_gt_evl_mse'])
333 | list_evolution_loss_mse_alt_008_gt_evl.append(info['evolution/loss_alt_008_gt_evl_mse'])
334 | list_evolution_loss_heuristic_gt_evl.append(info['evolution/loss_heuristic_gt_evl_mse'])
335 |
336 | list_all_nodes.append(state_size_elm_list)
337 | list_all_nodes_remesh.append(info['v/state_size'])
338 |
339 | print(info['evolution/loss_mse'].shape[0])
340 | print("index",actual_count,"loss", (info['evolution/loss_mse'].sum().item()/state_size_elm),"Gt state_size_elm",state_size_elm,"remeshed_state_size",info['v/state_size'].sum().item())
341 | print((info['v/state_size'].reshape(-1))[::4])
342 | print("current running rms alt gt evl:",np.sqrt(evolution_loss_alt_gt_evl_mse/state_size))
343 | print("current running rms alt gt mesh gt evl :",np.sqrt(evolution_loss_alt_gt_mesh_gt_evl_mse/state_size))
344 | print("current running rms alt 008 gt evl:",np.sqrt(evolution_loss_alt_008_gt_evl_mse/state_size))
345 |                     print("current running rms heuristic gt evl:",np.sqrt(evolution_loss_heuristic_gt_evl_mse/state_size))
346 | print("current running rms:",np.sqrt(evolution_loss_mse/state_size))
347 |
348 |
349 |
350 | if not(os.path.exists("./results/LAMP_2d/{}".format(name))):
351 | os.mkdir("./results/LAMP_2d/{}".format(name))
352 | if not(os.path.exists("./results/LAMP_2d/{}/{}".format(name,count))):
353 | os.mkdir("./results/LAMP_2d/{}/{}".format(name,count))
354 | if not(os.path.exists("./results/LAMP_2d/{}/{}/{}".format(name,count,data.time_step))):
355 | os.mkdir("./results/LAMP_2d/{}/{}/{}".format(name,count,data.time_step))
356 |                     np.save("./results/LAMP_2d/{}/{}/{}/{}_{}.npy".format(name,count,data.time_step,name,actual_count),{"loss_mse":info['evolution/loss_mse'].detach().cpu().numpy(),"loss_alt_gt_evl_mse":info['evolution/loss_alt_gt_evl_mse'].detach().cpu().numpy(),"loss_alt_gt_mesh_gt_evl_mse":info['evolution/loss_alt_gt_mesh_gt_evl_mse'].detach().cpu().numpy(),"loss_alt_008_gt_evl_mse":info['evolution/loss_alt_008_gt_evl_mse'].detach().cpu().numpy(),"loss_heuristic_gt_evl_mse":info['evolution/loss_heuristic_gt_evl_mse'].detach().cpu().numpy(),"state_size_elm_list":state_size_elm_list,"state_size":info['v/state_size'].detach().cpu().numpy()})
357 | np.save("./results/LAMP_2d/{}/{}/{}/{}_{}_all.npy".format(name,count,data.time_step,name,actual_count),{"info":info})
358 | for index in range(5,20,2):
359 | plot_fig(info, data_clone,index=index)
360 | plt.savefig("./results/LAMP_2d/{}/{}/{}/{}_{}_{}.png".format(name,count,data.time_step,name,actual_count,index))
361 | plt.close()
362 | torch.cuda.empty_cache()
363 | with open('./results/LAMP_2d/{}/summary.txt'.format(name), 'a') as f:
364 | f.write("current running rms alt gt evl: {}\n".format(np.sqrt(evolution_loss_alt_gt_evl_mse/state_size)))
365 | f.write("current running rms alt gt mesh gt evl : {}\n".format(np.sqrt(evolution_loss_alt_gt_mesh_gt_evl_mse/state_size)))
366 | f.write("current running rms alt 008 gt evl: {}\n".format(np.sqrt(evolution_loss_alt_008_gt_evl_mse/state_size)))
367 |                         f.write("current running rms heuristic gt evl: {}\n".format(np.sqrt(evolution_loss_heuristic_gt_evl_mse/state_size)))
368 | f.write("current running rms: {}\n".format(np.sqrt(evolution_loss_mse/state_size)))
369 |
370 | if count>=break_index:
371 | break
372 | # break
373 | evolution_loss_rmse_average = np.sqrt(evolution_loss_mse/state_size)
374 | evolution_loss_alt_gt_mesh_gt_evl_mse_average = np.sqrt(evolution_loss_alt_gt_mesh_gt_evl_mse/state_size)
375 | evolution_loss_alt_gt_evl_mse_average = np.sqrt(evolution_loss_alt_gt_evl_mse/state_size)
376 | evolution_loss_alt_008_gt_evl_mse_average = np.sqrt(evolution_loss_alt_008_gt_evl_mse/state_size)
377 | evolution_loss_heuristic_gt_evl_mse_average = np.sqrt(evolution_loss_heuristic_gt_evl_mse/state_size)
378 |
379 |
380 | state_size_remeshed = state_size_remeshed/actual_count
381 | print("count",count)
382 | print("{}, {}, evolution_loss_alt_gt_mesh_gt_evl_mse_average".format(hashva,beta),evolution_loss_alt_gt_mesh_gt_evl_mse_average)
383 | print("{}, {}, evolution_loss_alt_gt_evl_mse_average".format(hashva,beta),evolution_loss_alt_gt_evl_mse_average)
384 | print("{}, {}, evolution_loss_alt_008_gt_evl_mse_average".format(hashva,beta),evolution_loss_alt_008_gt_evl_mse_average)
385 | print("{}, {}, evolution_loss_heuristic_gt_evl_mse_average".format(hashva,beta),evolution_loss_heuristic_gt_evl_mse_average)
386 | print("{}, {}, evolution_loss_rmse_average".format(hashva,beta),evolution_loss_rmse_average, "state_size_remeshed", state_size_remeshed)
387 | with open('./results/LAMP_2d/{}/summary.txt'.format(name), 'a') as f:
388 | f.write("count {}".format(count))
389 | f.write("{}, {}, evolution_loss_alt_gt_mesh_gt_evl_mse_average {}\n".format(hashva,beta,evolution_loss_alt_gt_mesh_gt_evl_mse_average))
390 | f.write("{}, {}, evolution_loss_alt_gt_evl_mse_average {}\n".format(hashva,beta,evolution_loss_alt_gt_evl_mse_average))
391 | f.write("{}, {}, evolution_loss_alt_008_gt_evl_mse_average {}\n".format(hashva,beta,evolution_loss_alt_008_gt_evl_mse_average))
392 | f.write("{}, {}, evolution_loss_heuristic_gt_evl_mse_average {}\n".format(hashva,beta,evolution_loss_heuristic_gt_evl_mse_average))
393 | f.write("{}, {}, evolution_loss_rmse_average {}\n".format(hashva,beta,evolution_loss_rmse_average))
394 |
395 |         dicts = [info, data_clone, evolution_loss_rmse_average, state_size_remeshed, list_evolution_loss_mse, list_evolution_loss_mse_alt_gt_evl, list_all_nodes, list_all_nodes_remesh]
396 | torch.cuda.empty_cache()
397 | return info, data_clone, dicts, min_evolution_loss_rmse_average
398 |
399 |
400 | def run_hash(all_hashes,dirname=None,evo_dirname="evo-2d_2023_02_18",evo_hash="9UQLIKKc_ampere1",gpu=0):
401 |
402 | p = Printer()
403 |
404 | all_hashes = [all_hashes]
405 | print(all_hashes)
406 | load_best_model = False
407 | device = "cuda:{}".format(gpu)
408 | beta = 0
409 | constrain_edge_size = True
410 |
411 | mode = "best"
412 | evo_dirname = os.path.join(EXP_PATH, evo_dirname)
413 |     if not evo_dirname.endswith("/"):
414 |         evo_dirname += "/"
415 | all_dict = {}
416 | isplot = False
417 | seed_everything(42)
418 | filename = filter_filename(evo_dirname, include=evo_hash)
419 | print(filename)
420 | try:
421 | data_record = pload(evo_dirname + filename[0])
422 | except Exception as e:
423 | print(f"error {e} in evo_hash {evo_hash}")
424 | p.print(f"Hash {evo_hash}, best model at epoch {data_record['best_epoch']}:", banner_size=160)
425 | args = init_args(update_legacy_default_hyperparam(data_record["args"]))
426 | args.filename = filename
427 | data_record['best_model_dict']['type'] = 'GNNRemesherPolicy'
428 | data_record['best_model_dict']['noise_amp'] = 0.
429 | data_record['best_model_dict']["correction_rate"] = 0.
430 | data_record['best_model_dict']["batch_size"] = 16
431 | best_gnn_evl_model = load_model(data_record["best_model_dict"], device=device)
432 | print("Load the model with best validation loss.")
433 | best_gnn_evl_model.eval()
434 |
435 | value = {}
436 | min_evolution_loss_rmse_average = 1000
437 | min_hash = None
438 |
439 | for hashva in all_hashes:
440 | seed_everything(42)
441 | model, evolution_model, test_loader, test_loader_008, args = get_results_2d([hashva],
442 | mode=-1,
443 | exclude_idx=(None,),
444 | dirname=dirname,
445 | suffix="",device=device)
446 | model.to(args.device)
447 | evolution_model.to(args.device)
448 | if load_best_model:
449 | evolution_model = best_gnn_evl_model.eval().to(args.device)
450 | if constrain_edge_size:
451 | model.actor.min_edge_size=max(0.04,model.actor.min_edge_size)
452 | model.eval()
453 | evolution_model.eval()
454 |
455 | name = "{}_{}_{}_{}_004".format(hashva,load_best_model,constrain_edge_size,beta)
456 | seed_everything(42)
457 | info, data, dicts, minval = get_eval(model, test_loader,test_loader_008, evolution_model,break_index=50*3,hashva=hashva,beta=beta,name=name,best_evolution_model=best_gnn_evl_model.eval(),args=args)
458 |
459 | if minval < min_evolution_loss_rmse_average and minval > 0:

--------------------------------------------------------------------------------
/datasets/load_dataset.py:
--------------------------------------------------------------------------------
127 | p.print(i)
128 | p.print(i + 1)
129 | dataset = GraphDataset(graph_list, task='node', minimum_node_per_graph=1)
130 | dataset.n_simu = pyg_dataset.n_simu
131 | dataset.time_stamps = pyg_dataset.time_stamps
132 | return dataset
133 |
134 | train_val_fraction = 0.9
135 | train_fraction = args.train_fraction
136 | multi_step_dict = parse_multi_step(args.multi_step)
137 | max_pred_steps = max(list(multi_step_dict.keys()) + [1]) * args.temporal_bundle_steps
138 | filename_train_val = os.path.join(PDE_PATH, "deepsnap", "{}_train_val_in_{}_out_{}{}{}{}.p".format(
139 | args.dataset, args.input_steps * args.temporal_bundle_steps, max_pred_steps,
140 | "_itv_{}".format(args.time_interval) if args.time_interval > 1 else "",
141 | "_yvar_{}".format(args.is_y_variable_length) if args.is_y_variable_length is True else "",
142 | "_noise_{}".format(args.data_noise_amp) if args.data_noise_amp > 0 else "",
143 | ))
144 | filename_test = os.path.join(PDE_PATH, "deepsnap", "{}_test_in_{}_out_{}{}{}{}.p".format(
145 | args.dataset, args.input_steps * args.temporal_bundle_steps, max_pred_steps,
146 | "_itv_{}".format(args.time_interval) if args.time_interval > 1 else "",
147 | "_yvar_{}".format(args.is_y_variable_length) if args.is_y_variable_length is True else "",
148 | "_noise_{}".format(args.data_noise_amp) if args.data_noise_amp > 0 else "",
149 | ))
150 | make_dir(filename_train_val)
151 | is_to_deepsnap = True
152 | if (os.path.isfile(filename_train_val) or args.is_test_only) and os.path.isfile(filename_test) and args.n_train == "-1" and not args.dataset.startswith("arcsimmesh"):
153 | if not args.is_test_only:
154 | p.print(f"Loading {filename_train_val}")
155 | loaded = pickle.load(open(filename_train_val, "rb"))
156 | if isinstance(loaded, tuple):
157 | dataset_train, dataset_val = loaded
158 | else:
159 | dataset_train_val = loaded
160 | p.print(f"Loading {filename_test}")
161 | dataset_test = pickle.load(open(filename_test, "rb"))
162 | p.print("Loaded pre-saved deepsnap file at {}.".format(filename_test))
163 |
164 | else:
165 | p.print("{} does not exist. Generating...".format(filename_test))
166 | is_save = True # If True, will save generated deepsnap dataset.
167 |
168 | if args.dataset.startswith("mppde1d"):
169 | if not args.is_test_only:
170 | pyg_dataset_train = MPPDE1D(
171 | dataset=args.dataset,
172 | input_steps=args.input_steps * args.temporal_bundle_steps,
173 | output_steps=max_pred_steps,
174 | time_interval=args.time_interval,
175 | is_y_diff=args.is_y_diff,
176 | split="train",
177 | )
178 | pyg_dataset_val = MPPDE1D(
179 | dataset=args.dataset,
180 | input_steps=args.input_steps * args.temporal_bundle_steps,
181 | output_steps=max_pred_steps,
182 | time_interval=args.time_interval,
183 | is_y_diff=args.is_y_diff,
184 | split="valid",
185 | )
186 | pyg_dataset_test = MPPDE1D(
187 | dataset=args.dataset,
188 | input_steps=args.input_steps * args.temporal_bundle_steps,
189 | output_steps=max_pred_steps,
190 | time_interval=args.time_interval,
191 | is_y_diff=args.is_y_diff,
192 | split="test",
193 | )
194 | elif args.dataset.startswith("arcsimmesh"):
195 | if args.algo.startswith("rlgnnremesher"):
196 | max_pred_steps = args.rl_horizon
197 | if not args.is_test_only:
198 | pyg_dataset_train_val = ArcsimMesh(
199 | dataset=args.dataset,
200 | input_steps=args.input_steps * args.temporal_bundle_steps,
201 | output_steps=max_pred_steps,
202 | time_interval=args.time_interval,
203 | use_fineres_data=args.use_fineres_data,
204 | is_train=True
205 | #is_y_diff=args.is_y_diff,
206 | #split="train",
207 | )
208 | pyg_dataset_test = ArcsimMesh(
209 | dataset=args.dataset,
210 | input_steps=args.input_steps * args.temporal_bundle_steps,
211 | output_steps=max_pred_steps,
212 | time_interval=args.time_interval,
213 | is_shifted_data=args.is_shifted_data,
214 | use_fineres_data=args.use_fineres_data,
215 | is_train=False,
216 | #is_y_diff=args.is_y_diff,
217 | #split="test",
218 | )
219 | is_to_deepsnap = False
220 | else:
221 | raise
222 |
223 | if args.n_train != "-1":
224 | # Test overfitting:
225 | is_save = False
226 | if not args.is_test_only:
227 | if "pyg_dataset_train_val" in locals():
228 | pyg_dataset_train_val = get_elements(pyg_dataset_train_val, args.n_train)
229 | else:
230 | pyg_dataset_train_val = get_elements(pyg_dataset_train, args.n_train)
231 | pyg_dataset_test = pyg_dataset_train_val
232 | p.print(", using the following elements {}.".format(args.n_train))
233 | else:
234 | p.print(":")
235 |
236 | # Transform to deepsnap format:
237 | if is_to_deepsnap:
238 | if not args.is_test_only:
239 | if "pyg_dataset_train_val" in locals():
240 | dataset_train_val = to_deepsnap(pyg_dataset_train_val, args)
241 | else:
242 | dataset_train = to_deepsnap(pyg_dataset_train, args)
243 | dataset_val = to_deepsnap(pyg_dataset_val, args)
244 | dataset_test = to_deepsnap(pyg_dataset_test, args)
245 | else:
246 | if not args.is_test_only:
247 | dataset_train_val = pyg_dataset_train_val
248 | dataset_test = pyg_dataset_test
249 |
250 | # Save pre-processed dataset into file:
251 | if is_save:
252 | if not args.is_test_only:
253 | if "pyg_dataset_train_val" in locals():
254 | if not os.path.isfile(filename_train_val):
255 | pickle.dump(dataset_train_val, open(filename_train_val, "wb"))
256 | else:
257 | pickle.dump((dataset_train, dataset_val), open(filename_train_val, "wb"))
258 | try:
259 | pickle.dump(dataset_test, open(filename_test, "wb"))
260 | p.print("saved generated deepsnap dataset to {}".format(filename_test))
261 | except Exception as e:
262 | p.print(f"Cannot save dataset object. Reason: {e}")
263 |
264 |
265 | # Split into train, val and test:
266 | collate_fn = deepsnap_Batch.collate() if is_to_deepsnap else MeshBatch(is_absorb_batch=True, is_collate_tuple=True).collate()
267 | if not args.is_test_only:
268 | if args.n_train == "-1":
269 | if "dataset_train" in locals() and "dataset_val" in locals():
270 | dataset_train_val = (dataset_train, dataset_val)
271 | elif args.dataset_split_type == "standard":
272 | if args.dataset.startswith("VL") or args.dataset.startswith("PL") or args.dataset.startswith("PIL"):
273 | train_idx, val_idx = get_train_val_idx(len(dataset_train_val), chunk_size=200)
274 | dataset_train = dataset_train_val[train_idx]
275 | dataset_val = dataset_train_val[val_idx]
276 | else:
277 | num_train = int(len(dataset_train_val) * train_fraction)
278 | dataset_train, dataset_val = dataset_train_val[:num_train], dataset_train_val[num_train:]
279 | elif args.dataset_split_type == "random":
280 | train_idx, val_idx = get_train_val_idx_random(len(dataset_train_val), train_fraction=train_fraction)
281 | dataset_train, dataset_val = dataset_train_val[train_idx], dataset_train_val[val_idx]
282 | elif args.dataset_split_type == "order":
283 | n_train = int(len(dataset_train_val) * train_fraction)
284 | dataset_train, dataset_val = dataset_train_val[:n_train], dataset_train_val[n_train:]
285 | else:
286 | raise Exception("dataset_split_type '{}' is not valid!".format(args.dataset_split_type))
287 | else:
288 | # train, val, test are all the same as designated by args.n_train:
289 | dataset_train = deepcopy(dataset_train_val) if is_to_deepsnap else dataset_train_val
290 | dataset_val = deepcopy(dataset_train_val) if is_to_deepsnap else dataset_train_val
291 | train_loader = DataLoader(dataset_train, num_workers=args.n_workers, collate_fn=collate_fn,
292 | batch_size=args.batch_size, shuffle=True if args.dataset_split_type!="order" else False, drop_last=True)
293 | val_loader = DataLoader(dataset_val, num_workers=args.n_workers, collate_fn=collate_fn,
294 | batch_size=args.val_batch_size if not args.algo.startswith("supn") else 1, shuffle=False, drop_last=False)
295 | else:
296 | dataset_train_val = None
297 | train_loader, val_loader = None, None
298 | test_loader = DataLoader(dataset_test, num_workers=args.n_workers, collate_fn=collate_fn,
299 | batch_size=args.val_batch_size if not args.algo.startswith("supn") else 1, shuffle=False, drop_last=False)
300 | return (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader)
301 |
302 | class MeshBatch(object):
303 | def __init__(self, is_absorb_batch=False, is_collate_tuple=False):
304 | """Collate factory for batching mesh-based graph samples.
305 | Args:
306 |     is_absorb_batch: if True, will absorb the batch dimension into the node dimension when collating tensors.
307 |     is_collate_tuple: if True, will collate inside the tuple.
308 | """
309 | self.is_absorb_batch = is_absorb_batch
310 | self.is_collate_tuple = is_collate_tuple
311 |
312 | def collate(self):
313 | import re
314 | if torch.__version__.startswith("1.9") or torch.__version__.startswith("1.10") or torch.__version__.startswith("1.11"):
315 | from torch._six import string_classes
316 | from collections import abc as container_abcs
317 | else:
318 | from torch._six import container_abcs, string_classes, int_classes
319 | from pstar import pdict, plist
320 | default_collate_err_msg_format = (
321 | "collate_fn: batch must contain tensors, numpy arrays, numbers, "
322 | "dicts or lists; found {}")
323 | np_str_obj_array_pattern = re.compile(r'[SaUO]')
324 | def default_convert(data):
325 | r"""Converts each NumPy array data field into a tensor"""
326 | elem_type = type(data)
327 | if isinstance(data, torch.Tensor):
328 | return data
329 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
330 | and elem_type.__name__ != 'string_':
331 | # array of string classes and object
332 | if elem_type.__name__ == 'ndarray' \
333 | and np_str_obj_array_pattern.search(data.dtype.str) is not None:
334 | return data
335 | return torch.as_tensor(data)
336 | elif isinstance(data, container_abcs.Mapping):
337 | return {key: default_convert(data[key]) for key in data}
338 | elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
339 | return elem_type(*(default_convert(d) for d in data))
340 | elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes):
341 | return [default_convert(d) for d in data]
342 | else:
343 | return data
344 |
345 | def collate_fn(batch):
346 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate."""
347 | # pdb.set_trace()
348 | elem = batch[0]
349 | elem_type = type(elem)
350 | if isinstance(elem, torch.Tensor):
351 | out = None
352 | if torch.utils.data.get_worker_info() is not None:
353 | # If we're in a background process, concatenate directly into a
354 | # shared memory tensor to avoid an extra copy
355 | numel = sum([x.numel() for x in batch])
356 | storage = elem.storage()._new_shared(numel)
357 | out = elem.new(storage)
358 | tensor = torch.cat(batch, 0, out=out)
359 | if self.is_absorb_batch:
360 | # pdb.set_trace()
361 | if tensor.shape[1] == 0:
362 | tensor = tensor.view(tensor.shape[0], 0)
363 | else:
364 | tensor = tensor.view(-1, *tensor.shape[2:])
365 | return tensor
366 | elif elem is None:
367 | return None
368 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
369 | and elem_type.__name__ != 'string_':
370 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
371 | # array of string classes and object
372 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
373 | raise TypeError(default_collate_err_msg_format.format(elem.dtype))
374 | return collate_fn([torch.as_tensor(b) for b in batch])
375 | elif elem.shape == (): # scalars
376 | return torch.as_tensor(batch)
377 | elif isinstance(elem, float):
378 | return torch.tensor(batch, dtype=torch.float64)
379 | elif isinstance(elem, int):
380 | return torch.tensor(batch)
381 | elif isinstance(elem, string_classes):
382 | return batch
383 | elif isinstance(elem, container_abcs.Mapping):
384 | Dict = {}
385 | for key in elem:
386 | if key == "node_feature":
387 | Dict["vers"] = collate_trans_fn([d[key] for d in batch])
388 | Dict[key] = collate_fn([d[key] for d in batch])
389 | Dict["batch"] = {"n0": []}
390 | batch_nodes = [d[key]["n0"] for d in batch]
391 | for i in range(len(batch_nodes)):
392 | item = torch.full((batch_nodes[i].shape[0],), i, dtype=torch.long)
393 | Dict["batch"]["n0"].append(item)
394 | Dict["batch"]["n0"] = torch.cat(Dict["batch"]["n0"])
395 | elif key in ["y_tar", "reind_yfeatures"]:
396 | # pdb.set_trace()
397 | Dict[key] = collate_ytar_trans_fn([d[key] for d in batch])
398 | elif key in ["history"]:
399 | Dict[key] = collate_fn([d[key] for d in batch])
400 | Dict["batch_history"] = {"n0": []}
401 | batch_nodes = [d[key]["n0"] for d in batch]
402 | for i in range(len(batch_nodes)):
403 | item = torch.full((batch_nodes[i][0].shape[0],), i, dtype=torch.long)
404 | Dict["batch_history"]["n0"].append(item)
405 | Dict["batch_history"]["n0"] = torch.cat(Dict["batch_history"]["n0"])
406 | elif key == "edge_index":
407 | Dict[key] = collate_edgeidshift_fn([d[key] for d in batch])
408 | elif key == "yedge_index":
409 | # pdb.set_trace()
410 | Dict[key] = collate_y_edgeidshift_fn([d[key] for d in batch])
411 | elif key == "xfaces":
412 | Dict[key] = collate_xfaceshift_fn([d[key] for d in batch])
413 | elif key in ["bary_indices", "hist_indices"]:
414 | #pdb.set_trace()
415 | Dict[key] = collate_bary_indices_fn([d[key] for d in batch])
416 | elif key in ["yface_list", "xface_list"]:
417 | Dict[key] = collate_fn([d[key] for d in batch])
418 | else:
419 | Dict[key] = collate_fn([d[key] for d in batch])
420 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict):
421 | Dict = elem.__class__(**Dict)
422 | return Dict
423 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple:
424 | return elem_type(*(collate_fn(samples) for samples in zip(*batch)))
425 | elif isinstance(elem, My_Tuple):
426 | it = iter(batch)
427 | elem_size = len(next(it))
428 | if not all(len(elem) == elem_size for elem in it):
429 | raise RuntimeError('each element in list of batch should be of equal size')
430 | transposed = zip(*batch)
431 | return elem.__class__([collate_fn(samples) for samples in transposed])
432 | elif isinstance(elem, tuple):
433 | # pdb.set_trace()
434 | if self.is_collate_tuple:
435 | #pdb.set_trace()
436 | if len(elem) == 0:
437 | return batch[0]
438 | elif isinstance(elem[0], torch.Tensor):
439 | newbatch = ()
440 | for i in range(len(elem)):
441 | newbatch = newbatch + tuple([torch.cat([tup[i] for tup in batch], dim=0)])
442 | return newbatch
443 | elif type(elem[0]) == list:
444 | newbatch = ()
445 | for i in range(len(elem)):
446 | cumsum = 0
447 | templist = []
448 | for k in range(len(batch)):
449 | shiftbatch = np.array(batch[k][i]) + cumsum
450 | cumsum = shiftbatch.max() + 1
451 | templist.extend(shiftbatch.tolist())
452 | newbatch = newbatch + (templist,)
453 | elif type(elem[0]).__module__ == np.__name__:
454 | newbatch = ()
455 | for i in range(len(elem)):
456 | cumsum = 0
457 | templist = []
458 | for k in range(len(batch)):
459 | shiftbatch = batch[k][i] + cumsum
460 | cumsum = shiftbatch.max() + 1
461 | templist.extend(shiftbatch.tolist())
462 | newbatch = newbatch + (templist,)
463 | else:
464 | newbatch = batch[0]
465 | return newbatch
466 | else:
467 | return batch
468 | elif isinstance(elem, container_abcs.Sequence):
469 | # check to make sure that the elements in batch have consistent size
470 | it = iter(batch)
471 | elem_size = len(next(it))
472 | if not all(len(elem) == elem_size for elem in it):
473 | raise RuntimeError('each element in list of batch should be of equal size')
474 | transposed = zip(*batch)
475 | return [collate_fn(samples) for samples in transposed]
476 | elif elem.__class__.__name__ == 'Dictionary':
477 | return batch
478 | elif elem.__class__.__name__ == 'DGLHeteroGraph':
479 | import dgl
480 | return dgl.batch(batch)
481 | raise TypeError(default_collate_err_msg_format.format(elem_type))
482 |
483 | def collate_bary_indices_fn(batch):
484 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate."""
485 | # pdb.set_trace()
486 | elem = batch[0]
487 | elem_type = type(elem)
488 | if isinstance(elem, container_abcs.Mapping):
489 | Dict = {key: collate_bary_indices_fn([d[key] for d in batch]) for key in elem}
490 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict):
491 | Dict = elem.__class__(**Dict)
492 | return Dict
493 | elif isinstance(elem, tuple):
494 | # pdb.set_trace()
495 | if self.is_collate_tuple:
496 | #pdb.set_trace()
497 | if type(elem[0]).__module__ == np.__name__:
498 | newbatch = ()
499 | for i in range(len(elem)):
500 | cumsum = 0
501 | templist = []
502 | for k in range(len(batch)):
503 | shiftbatch = batch[k][i] + cumsum
504 | cumsum = shiftbatch.max() + 1
505 | templist.extend(shiftbatch)
506 | newbatch = newbatch + (templist,)
507 | return newbatch
508 | else:
509 | return batch
510 | raise TypeError(default_collate_err_msg_format.format(elem_type))
511 |
512 | def collate_y_edgeidshift_fn(batch):
513 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate."""
514 | # pdb.set_trace()
515 | elem = batch[0]
516 | elem_type = type(elem)
517 | if isinstance(elem, torch.Tensor):
518 | out = None
519 | if torch.utils.data.get_worker_info() is not None:
520 | # If we're in a background process, concatenate directly into a
521 | # shared memory tensor to avoid an extra copy
522 | numel = sum([x.numel() for x in batch])
523 | storage = elem.storage()._new_shared(numel)
524 | out = elem.new(storage)
525 | cumsum = 0
526 | newbatch = []
527 | for i in range(len(batch)):
528 | shiftbatch = batch[i] + cumsum
529 | newbatch.append(shiftbatch)
530 | cumsum = shiftbatch.max().item() + 1
531 | tensor = torch.cat(newbatch, dim=1)
532 | return tensor
533 | elif isinstance(elem, container_abcs.Mapping):
534 | Dict = {key: collate_y_edgeidshift_fn([d[key] for d in batch]) for key in elem}
535 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict):
536 | Dict = elem.__class__(**Dict)
537 | return Dict
538 | elif isinstance(elem, tuple):
539 | if self.is_collate_tuple:
540 | if isinstance(elem[0], torch.Tensor):
541 | newbatch = ()
542 | for i in range(len(elem)):
543 | cumsum = 0
544 | tempbatch = []
545 | for tup in batch:
546 | shiftbatch = tup[i] + cumsum
547 | tempbatch.append(shiftbatch)
548 | cumsum = shiftbatch.max().item() + 1
549 | newbatch = newbatch + tuple([torch.cat(tempbatch, dim=-1)])
550 | return newbatch
551 | else:
552 | return batch
553 | raise TypeError(default_collate_err_msg_format.format(elem_type))
554 |
555 | def collate_edgeidshift_fn(batch):
556 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate."""
557 | # pdb.set_trace()
558 | elem = batch[0]
559 | elem_type = type(elem)
560 | if isinstance(elem, torch.Tensor):
561 | out = None
562 | if torch.utils.data.get_worker_info() is not None:
563 | # If we're in a background process, concatenate directly into a
564 | # shared memory tensor to avoid an extra copy
565 | numel = sum([x.numel() for x in batch])
566 | storage = elem.storage()._new_shared(numel)
567 | out = elem.new(storage)
568 | cumsum = 0
569 | newbatch = []
570 | for i in range(len(batch)):
571 | shiftbatch = batch[i] + cumsum
572 | newbatch.append(shiftbatch)
573 | cumsum = shiftbatch.max().item() + 1
574 | tensor = torch.cat(newbatch, dim=1)
575 | return tensor
576 | elif isinstance(elem, container_abcs.Mapping):
577 | Dict = {key: collate_edgeidshift_fn([d[key] for d in batch]) for key in elem}
578 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict):
579 | Dict = elem.__class__(**Dict)
580 | return Dict
581 | raise TypeError(default_collate_err_msg_format.format(elem_type))
582 |
583 | def collate_xfaceshift_fn(batch):
584 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate."""
585 | # pdb.set_trace()
586 | elem = batch[0]
587 | elem_type = type(elem)
588 | if isinstance(elem, torch.Tensor):
589 | out = None
590 | if torch.utils.data.get_worker_info() is not None:
591 | # If we're in a background process, concatenate directly into a
592 | # shared memory tensor to avoid an extra copy
593 | numel = sum([x.numel() for x in batch])
594 | storage = elem.storage()._new_shared(numel)
595 | out = elem.new(storage)
596 | cumsum = 0
597 | newbatch = []
598 | for i in range(len(batch)):
599 | shiftbatch = batch[i] + cumsum
600 | newbatch.append(shiftbatch)
601 | cumsum = shiftbatch.max().item() + 1
602 | tensor = torch.cat(newbatch, dim=-1)
603 | return tensor
604 | elif isinstance(elem, container_abcs.Mapping):
605 | Dict = {key: collate_xfaceshift_fn([d[key] for d in batch]) for key in elem}
606 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict):
607 | Dict = elem.__class__(**Dict)
608 | return Dict
609 | raise TypeError(default_collate_err_msg_format.format(elem_type))
610 |
611 | def collate_trans_fn(batch):
612 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate."""
613 | # pdb.set_trace()
614 | elem = batch[0]
615 | elem_type = type(elem)
616 | if isinstance(elem, torch.Tensor):
617 | out = None
618 | if torch.utils.data.get_worker_info() is not None:
619 | # If we're in a background process, concatenate directly into a
620 | # shared memory tensor to avoid an extra copy
621 | numel = sum([x.numel() for x in batch])
622 | storage = elem.storage()._new_shared(numel)
623 | out = elem.new(storage)
624 | batch = [batch[i] + 2*i for i in range(len(batch))]
625 | try:
626 | tensor = torch.cat(batch, 0, out=out)
627 | except:
628 | pdb.set_trace()
629 | if self.is_absorb_batch:
630 | # pdb.set_trace()
631 | # if tensor.shape[1] == 0:
632 | # tensor = tensor.view(tensor.shape[0]*tensor.shape[1], 0)
633 | # else:
634 | tensor = tensor.view(-1, *tensor.shape[2:])
635 | return tensor
636 | elif isinstance(elem, container_abcs.Mapping):
637 | Dict = {key: collate_trans_fn([d[key] for d in batch]) for key in elem}
638 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict):
639 | Dict = elem.__class__(**Dict)
640 | return Dict
641 | raise TypeError(default_collate_err_msg_format.format(elem_type))
642 |
643 | def collate_ytar_trans_fn(batch):
644 | r"""Puts each data field into a tensor with outer dimension batch size, adapted from PyTorch's default_collate."""
645 | # pdb.set_trace()
646 | elem = batch[0]
647 | elem_type = type(elem)
648 | if isinstance(elem, tuple):
649 | # pdb.set_trace()
650 | if self.is_collate_tuple:
651 | #pdb.set_trace()
652 | if len(elem) == 0:
653 | return batch[0]
654 | elif isinstance(elem[0], torch.Tensor):
655 | newbatch = ()
656 | for i in range(len(elem)):
657 | templist = [batch[j][i] + 2*j for j in range(len(batch))]
658 | newbatch = newbatch + (torch.cat(templist, dim=0),)
659 | return newbatch
660 | elif isinstance(elem, container_abcs.Mapping):
661 | Dict = {key: collate_ytar_trans_fn([d[key] for d in batch]) for key in elem}
662 | if isinstance(elem, pdict) or isinstance(elem, Attr_Dict):
663 | Dict = elem.__class__(**Dict)
664 | return Dict
665 | raise TypeError(default_collate_err_msg_format.format(elem_type))
666 | return collate_fn
667 |
--------------------------------------------------------------------------------
/datasets/mppde1d_dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[ ]:
5 |
6 |
7 | import scipy.io
8 | import numpy as np
9 | import h5py
10 | import pickle
11 | import torch
12 | from torch_geometric.data import Dataset, Data
13 | import pdb
14 | import sys, os
15 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
16 | from lamp.utils import MPPDE1D_PATH, PDE_PATH
17 | from lamp.pytorch_net.util import Attr_Dict, plot_matrices
18 | from lamp.utils import get_root_dir, to_tuple_shape
19 | from MP_Neural_PDE_Solvers.common.utils import HDF5Dataset
20 | from MP_Neural_PDE_Solvers.equations.PDEs import *
21 |
22 |
23 | # In[ ]:
24 |
25 |
26 | class MPPDE1D(Dataset):
27 | def __init__(
28 | self,
29 | dataset="mppde1d-E1-100",
30 | input_steps=1,
31 | output_steps=1,
32 | time_interval=1,
33 | is_y_diff=False,
34 | split="train",
35 | transform=None,
36 | pre_transform=None,
37 | verbose=False,
38 | ):
39 | assert dataset.startswith("mppde1d")
40 | self.dataset = dataset
41 | self.dirname = MPPDE1D_PATH
42 | self.root = PDE_PATH
43 |
44 | if len(dataset.split("-")) == 3:
45 | _, self.mode, self.nx = dataset.split("-")
46 | self.nt_total, self.nx_total = 250, 200
47 | else:
48 | assert len(dataset.split("-")) == 7
49 | _, self.mode, self.nx, _, self.nt_total, _, self.nx_total = dataset.split("-")
50 | self.nt_total, self.nx_total = int(self.nt_total), int(self.nx_total)
51 | self.nx = int(self.nx)
52 | self.input_steps = input_steps
53 | self.output_steps = output_steps
54 | self.time_interval = time_interval
55 | self.is_y_diff = is_y_diff
56 | self.split = split
57 | assert self.split in ["train", "valid", "test"]
58 | self.verbose = verbose
59 |
60 | self.t_cushion_input = self.input_steps * self.time_interval if self.input_steps * self.time_interval > 1 else 1
61 | self.t_cushion_output = self.output_steps * self.time_interval if self.output_steps * self.time_interval > 1 else 1
62 |
63 | self.original_shape = (self.nx,)
64 | self.dyn_dims = 1 # density
65 |
66 | pde=CE(device="cpu")
67 | if (self.nt_total, self.nx_total) == (250, 200):
68 | path = os.path.join(PDE_PATH, MPPDE1D_PATH) + f'{pde}_{self.split}_{self.mode}.h5'
69 | else:
70 | path = os.path.join(PDE_PATH, MPPDE1D_PATH) + f'{pde}_{self.split}_{self.mode}_nt_{self.nt_total}_nx_{self.nx_total}.h5'
71 | print(f"Load dataset {path}")
72 | self.time_stamps = self.nt_total
73 | base_resolution=[self.nt_total, self.nx]
74 | super_resolution=[self.nt_total, self.nx_total]
75 |
76 | self.dataset_cache = HDF5Dataset(path, pde=pde,
77 | mode=self.split, base_resolution=base_resolution,
78 | super_resolution=super_resolution)
79 | self.n_simu = len(self.dataset_cache)
80 | self.time_stamps_effective = (self.time_stamps - self.t_cushion_input - self.t_cushion_output + self.time_interval) // self.time_interval
81 | super(MPPDE1D, self).__init__(self.root, transform, pre_transform)
82 |
83 | @property
84 | def raw_file_names(self):
85 | return []
86 |
87 | @property
88 | def processed_dir(self):
89 | return os.path.join(self.root, self.dirname)
90 |
91 | @property
92 | def processed_file_names(self):
93 | return []
94 |
95 | def download(self):
96 | # Download to `self.raw_dir`.
97 | pass
98 |
99 | def _process(self):
100 | import re, warnings
101 | from typing import Any, List
102 | from torch_geometric.data.makedirs import makedirs
103 | def _repr(obj: Any) -> str:
104 | if obj is None:
105 | return 'None'
106 | return re.sub('(<.*?)\\s.*(>)', r'\1\2', obj.__repr__())
107 |
108 | def files_exist(files: List[str]) -> bool:
109 | # NOTE: We return `False` in case `files` is empty, leading to a
110 | # re-processing of files on every instantiation.
111 | return len(files) != 0 and all([os.path.exists(f) for f in files])
112 |
113 | f = os.path.join(self.processed_dir, 'pre_transform.pt')
114 | if os.path.exists(f) and torch.load(f) != _repr(self.pre_transform):
115 | warnings.warn(
116 | f"The `pre_transform` argument differs from the one used in "
117 | f"the pre-processed version of this dataset. If you want to "
118 | f"make use of another pre-processing technique, make sure to "
119 | f"sure to delete '{self.processed_dir}' first")
120 |
121 | f = os.path.join(self.processed_dir, 'pre_filter.pt')
122 | if os.path.exists(f) and torch.load(f) != _repr(self.pre_filter):
123 | warnings.warn(
124 | "The `pre_filter` argument differs from the one used in the "
125 | "pre-processed version of this dataset. If you want to make "
126 | "use of another pre-fitering technique, make sure to delete "
127 | "'{self.processed_dir}' first")
128 |
129 | if files_exist(self.processed_paths): # pragma: no cover
130 | return
131 |
132 | makedirs(self.processed_dir)
133 | self.process()
134 |
135 | path = os.path.join(self.processed_dir, 'pre_transform.pt')
136 | if not os.path.isfile(path):
137 | torch.save(_repr(self.pre_transform), path)
138 | path = os.path.join(self.processed_dir, 'pre_filter.pt')
139 | if not os.path.isfile(path):
140 | torch.save(_repr(self.pre_filter), path)
141 |
142 | def get_edge_index(self):
143 | edge_index_filename = os.path.join(self.processed_dir, f"{self.dataset}_edge_index.p")
144 | mask_valid_filename = os.path.join(self.root, self.dirname, f"{self.dataset}_mask_index.p")
145 | if os.path.isfile(edge_index_filename) and os.path.isfile(mask_valid_filename):
146 | edge_index = pickle.load(open(edge_index_filename, "rb"))
147 | mask_valid = pickle.load(open(mask_valid_filename, "rb"))
148 | return edge_index, mask_valid
149 | mask_valid = torch.ones(self.original_shape).bool()
150 | #velo_invalid_ids = np.where(velo_invalid_mask.flatten())[0]
151 | rows, cols = (*self.original_shape, 1)
152 | cube = np.arange(rows * cols).reshape(rows, cols)
153 | edge_list = []
154 | for i in range(rows):
155 | for j in range(cols):
156 | if i + 1 < rows: #and cube[i, j] not in velo_invalid_ids and cube[i+1, j] not in velo_invalid_ids:
157 | edge_list.append([cube[i, j], cube[i+1, j]])
158 | edge_list.append([cube[i+1, j], cube[i, j]])
159 | if j + 1 < cols: #and cube[i, j]: #not in velo_invalid_ids and cube[i, j+1] not in velo_invalid_ids:
160 | edge_list.append([cube[i, j], cube[i, j+1]])
161 | edge_list.append([cube[i, j+1], cube[i, j]])
162 | edge_index = torch.LongTensor(edge_list).T
163 | pickle.dump(edge_index, open(edge_index_filename, "wb"))
164 | pickle.dump(mask_valid, open(mask_valid_filename, "wb"))
165 | return edge_index, mask_valid
166 |
167 | def process(self):
168 | pass
169 |
170 | def len(self):
171 | return self.time_stamps_effective * self.n_simu
172 |
173 | def get(self, idx):
174 | # assert self.time_interval == 1
175 | sim_id, time_id = divmod(idx, self.time_stamps_effective)
176 | _, data_traj, x_pos, param = self.dataset_cache[sim_id]
177 | if self.verbose:
178 | print(f"sim_id: {sim_id} time_id: {time_id} input: ({time_id * self.time_interval + self.t_cushion_input -self.input_steps * self.time_interval}, {time_id * self.time_interval + self.t_cushion_input}) output: ({time_id * self.time_interval + self.t_cushion_input}, {time_id * self.time_interval + self.t_cushion_input + self.output_steps * self.time_interval})")
179 | x_dens = torch.FloatTensor(np.stack([data_traj[time_id * self.time_interval + self.t_cushion_input + j] for j in range(-self.input_steps * self.time_interval, 0, self.time_interval)], -1))
180 | y_dens = torch.FloatTensor(np.stack([data_traj[time_id * self.time_interval + self.t_cushion_input + j] for j in range(0, self.output_steps * self.time_interval, self.time_interval)], -1))  # [nx, output_steps]
181 | edge_index, mask_valid = self.get_edge_index()
182 | param = torch.cat([torch.FloatTensor([ele]) for key, ele in param.items()])
183 | x_bdd = torch.ones(x_dens.shape[0])
184 | x_bdd[0] = 0
185 | x_bdd[-1] = 0
186 | x_pos = torch.FloatTensor(x_pos)[...,None]
187 | for dim in range(len(self.original_shape)):
188 | x_pos[..., dim] /= self.original_shape[dim]
189 |
190 | data = Data(
191 | x=x_dens.reshape(-1, *x_dens.shape[-1:], 1).clone(), # [n_nodes, input_steps, 1]
192 | x_pos=x_pos, # [n_nodes, 1]
193 | x_bdd=x_bdd[...,None],
194 | xfaces=torch.tensor([]),
195 | y=y_dens.reshape(-1, *y_dens.shape[-1:], 1).clone(), # [n_nodes, output_steps, 1]
196 | edge_index=edge_index,
197 | mask=mask_valid,
198 | param=param,
199 | original_shape=self.original_shape,
200 | dyn_dims=self.dyn_dims,
201 | compute_func=(0, None),
202 | dataset=self.dataset,
203 | )
204 | # data = Attr_Dict(
205 | # node_feature={"n0": x_dens.reshape(-1, *x_dens.shape[-1:], 1).clone()},
206 | # node_label={"n0": y_dens.reshape(-1, *y_dens.shape[-1:], 1).clone()},
207 | # node_pos={"n0": x_pos},
208 | # x_bdd={"n0": x_bdd[...,None]},
209 | # xfaces={"n0": torch.tensor([])},
210 | # edge_index={("n0","0","n0"): edge_index},
211 | # mask={"n0": mask_valid},
212 | # param=My_Freeze_Tuple((("n0", param),)),
213 | # original_shape=My_Freeze_Tuple((("n0", self.original_shape),)),
214 | # dyn_dims=My_Freeze_Tuple((("n0", self.dyn_dims),)),
215 | # compute_func=My_Freeze_Tuple((("n0", (0, None)),)),
216 | # grid_keys=("n0",),
217 | # part_keys=(),
218 | # time_step=time_id,
219 | # sim_id=sim_id,
220 | # dataset=self.dataset,
221 | # )
222 | update_edge_attr_1d(data)
223 | return data
224 |
225 |
226 | def update_edge_attr_1d(data):
227 | dataset_str = to_tuple_shape(data.dataset)
228 | if dataset_str.split("-")[0] != "mppde1d":
229 | if hasattr(data, "node_feature"):
230 | edge_attr = data.node_pos["n0"][data.edge_index[("n0","0","n0")][0]] - data.node_pos["n0"][data.edge_index[("n0","0","n0")][1]]
231 | if dataset_str.split("-")[0] in ["mppde1de"]:
232 | data.edge_attr = {("n0","0","n0"): edge_attr}
233 | elif dataset_str.split("-")[0] in ["mppde1df", "mppde1dg", "mppde1dh"]:
234 | data.edge_attr = {("n0","0","n0"): torch.cat([edge_attr, edge_attr.abs()], -1)}
235 | else:
236 | raise
237 | if dataset_str.split("-")[0] in ["mppde1dg"]:
238 | data.x_bdd = {"n0": 1 - data.x_bdd["n0"]}
239 | else:
240 | edge_attr = data.x_pos[data.edge_index[0]] - data.x_pos[data.edge_index[1]]
241 | if dataset_str.split("-")[0] in ["mppde1de"]:
242 | data.edge_attr = edge_attr
243 | elif dataset_str.split("-")[0] in ["mppde1df", "mppde1dg", "mppde1dh"]:
244 | data.edge_attr = torch.cat([edge_attr, edge_attr.abs()], -1)
245 | else:
246 | raise
247 | if dataset_str.split("-")[0] in ["mppde1dg"]:
248 | data.x_bdd = 1 - data.x_bdd
249 | return data
250 |
251 |
252 | def get_data_pred(state_preds, step, data, **kwargs):
253 | """Get a new mppde1d Data from the state_preds at step "step" (e.g. 0,1,2....).
254 | Here we assume that the mesh does not change.
255 |
256 | Args:
257 | state_preds: has shape of [n_nodes, n_steps, feature_size:temporal_bundle_steps]
258 | data: a Deepsnap Data object
259 | **kwargs: keys for which need to use the value instead of data's.
260 | """
261 | is_deepsnap = hasattr(data, "node_feature")
262 | is_list = isinstance(state_preds, list)
263 | if is_list:
264 | assert len(state_preds[0].shape) == 3
265 | state_pred = state_preds[step].reshape(state_preds[step].shape[0], -1, data.node_feature["n0"].shape[-1])
266 | else:
267 | assert isinstance(state_preds, torch.Tensor)
268 | state_pred = state_preds[...,step,:].reshape(state_preds.shape[0], state_preds.shape[-1], 1)
269 | data_pred = Attr_Dict(
270 | node_feature={"n0": state_pred},
271 | node_label={"n0": None},
272 | node_pos={"n0": kwargs["x_pos"] if "x_pos" in kwargs else data.node_pos["n0"] if is_deepsnap else data.x_pos},
273 | x_bdd={"n0": kwargs["x_bdd"] if "x_bdd" in kwargs else data.x_bdd["n0"] if is_deepsnap else data.x_bdd},
274 | xfaces={"n0": torch.tensor([])},
275 | edge_index={("n0","0","n0"): kwargs["edge_index"] if "edge_index" in kwargs else data.edge_index[("n0","0","n0")] if is_deepsnap else data.edge_index},
276 | mask={"n0": data.mask["n0"] if is_deepsnap else data.mask},
277 | param={"n0": data.param["n0"][:1] if is_deepsnap else data.param[:1]},
278 | original_shape=to_tuple_shape(data.original_shape),
279 | dyn_dims=to_tuple_shape(data.dyn_dims),
280 | compute_func=to_tuple_shape(data.compute_func),
281 | grid_keys=("n0",),
282 | part_keys=(),
283 | dataset=to_tuple_shape(data.dataset),
284 | batch=kwargs["batch"] if "batch" in kwargs else data.batch,
285 | )
286 | update_edge_attr_1d(data_pred)
287 | return data_pred
288 |
289 |
290 | # In[ ]:
291 |
292 |
293 | if __name__ == "__main__":
294 | dataset = MPPDE1D(
295 | dataset="mppde1dh-E2-100",
296 | input_steps=1,
297 | output_steps=1,
298 | time_interval=1,
299 | is_y_diff=False,
300 | split="valid",
301 | transform=None,
302 | pre_transform=None,
303 | verbose=False,
304 | )
305 | # dataset2 = MPPDE1D(
306 | # dataset="mppde1de-E2-400-nt-1000-nx-400",
307 | # input_steps=1,
308 | # output_steps=1,
309 | # time_interval=5,
310 | # is_y_diff=False,
311 | # split="valid",
312 | # transform=None,
313 | # pre_transform=None,
314 | # verbose=True,
315 | # )
316 | import matplotlib.pylab as plt
317 | plt.plot(dataset[198].x.squeeze())
318 |
319 |
--------------------------------------------------------------------------------
/license:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 snap-stanford
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | wandb
2 | deepsnap==0.2.1
3 | matplotlib==3.4.3
4 | numpy==1.23.5
5 | plotly
6 | tqdm
7 | scipy
8 | h5py
9 | pandas
10 | pyyaml
11 | numba
12 | pstar
13 | xarray
14 |
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | # Results
2 |
3 | This results/ folder stores the experiment results and model checkpoints. The files (with the ".p" suffix) are saved under "results/{--exp_id}_{--date_time}/", where --exp_id and --date_time are specified on the command line and indicate the experiment id and date for this batch of experiments. Each ".p" file is a pickled dictionary containing all necessary information and can be loaded via `data_record = pickle.load(open({FILENAME.p}, "rb"))`. The `data_record` dictionary contains keys such as:
4 |
5 | * "model_dict": a list of model checkpoints, saved every --inspect_interval epochs. The model can be loaded via the command: `model = load_model(data_record["model_dict"][id])`, where id indicates which checkpoint you want to load from.
6 | * "epoch": a list of integers indicating the corresponding epoch number when the model_dict is saved.
7 | * "train_loss" (or "val_loss", "test_loss"): a list containing the loss for training (validation/test) at the corresponding epoch.
8 | * "last_model_dict": the last model_dict.
9 | * "last_optimizer_dict": optimizer state which can be used for resuming the training.
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | from deepsnap.batch import Batch as deepsnap_Batch
4 | import gc
5 | import numpy as np
6 | import pdb
7 | import pickle
8 | import pprint as pp
9 | import scipy
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils.data import DataLoader
13 | from torch import optim
14 | import torch.multiprocessing
15 | torch.multiprocessing.set_sharing_strategy('file_system')
16 | import time
17 |
18 | import sys, os
19 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..'))
20 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..'))
21 | from lamp.argparser import arg_parse
22 | from lamp.models import get_model, load_model, unittest_model, build_optimizer, test
23 | from lamp.gnns import GNNRemesher
24 | from lamp.datasets.load_dataset import load_data
25 | from lamp.pytorch_net.util import Attr_Dict, Batch, filter_filename, pload, pdump, Printer, get_time, init_args, update_args, clip_grad, set_seed, update_dict, filter_kwargs, plot_vectors, plot_matrices, make_dir, get_pdict, to_np_array, record_data, make_dir, Early_Stopping, str2bool, get_filename_short, print_banner, get_num_params, ddeepcopy as deepcopy, write_to_config
26 | from lamp.utils import EXP_PATH, MeshBatch
27 | from lamp.utils import p, update_legacy_default_hyperparam, get_grad_norm, loss_op_core, get_model_dict, get_elements, is_diagnose, get_keys_values, loss_op, to_tuple_shape, parse_multi_step, get_device, seed_everything
28 |
29 |
30 | # In[ ]:
31 |
32 | def find_hash_and_load(all_hash,mode=-1,exclude_idx=(None,),dirname=None,suffix=""):
33 | isplot = True
34 | df_dict_list = []
35 | dirname_start = "tailin-rl_2022-9-22/" if dirname is None else dirname
36 | for hash_str in all_hash:
37 | df_dict = {}
38 | df_dict["hash"] = hash_str
39 | # Load model:
40 | is_found = False
41 | for dirname_core in [
42 | dirname_start,
43 | "tailin-rl_2022-9-22/",
44 | "qq-rl_2022-11-5/",
45 | "qq-rl_2022-9-26/",
46 | "multiscale_cloth_2022-9-26/",
47 | ]:
48 | filename = filter_filename(EXP_PATH + dirname_core, include=hash_str)
49 | if len(filename) == 1:
50 | is_found = True
51 | break
52 | if not is_found:
53 | print(f"hash {hash_str} does not exist in {dirname}! Please pass in the correct dirname.")
54 | continue
55 | dirname = EXP_PATH + dirname_core
56 | if not dirname.endswith("/"):
57 | dirname += "/"
58 |
59 | try:
60 | data_record = pload(dirname + filename[0])
61 | except Exception as e:
62 | # p.print(f"Hash {hash_str}, best model at epoch {data_record['best_epoch']}:", banner_size=100)
63 | print(f"error {e} in hash_str {hash_str}")
64 | continue
65 | return data_record
66 |
67 | args = arg_parse()
68 | try:
69 | get_ipython().run_line_magic('matplotlib', 'inline')
70 | get_ipython().run_line_magic('load_ext', 'autoreload')
71 | get_ipython().run_line_magic('autoreload', '2')
72 | is_jupyter = True
73 | args.exp_id = "tailin-test"
74 | args.date_time = "8-27"
75 | # args.date_time = "{}-{}".format(datetime.datetime.now().month, datetime.datetime.now().day)
76 |
77 | # Train:
78 | args.epochs = 200
79 | args.contrastive_rel_coef = 0
80 | args.n_conv_blocks = 6
81 | args.latent_noise_amp = 1e-5
82 | args.multi_step = "1"
83 | args.latent_multi_step = "1^2^3^4"
84 | args.latent_loss_normalize_mode = "targetindi"
85 | args.channel_mode = "exp-16"
86 | args.batch_size = 20
87 | args.val_batch_size = 20
88 | args.reg_type = "None"
89 | args.reg_coef = 1e-4
90 | args.is_reg_anneal = True
91 | args.lr_scheduler_type = "cos"
92 | args.id = "test2"
93 | args.n_workers = 0
94 | args.plot_interval = 50
95 | args.temporal_bundle_steps = 1
96 |
97 | ##################################
98 | # RL algorithm:
99 | ##################################
100 | # args.algo chooses from "gnnremesher-evolution(+reward:32)", "rlgnnremesher^sizing", "rlgnnremesher^agent"
101 | args.algo = "rlgnnremesher^agent"
102 | if args.algo.startswith("rlgnnremesher"):
103 | args.rl_coefs = "None"
104 | args.rl_horizon = 4
105 | args.reward_mode = "lossdiff+statediff"
106 | args.reward_beta = "1"
107 | args.reward_src = "env"
108 | args.rl_lambda = 0.95
109 | args.rl_gamma = 0.99
110 | args.rl_rho = 1.
111 | args.rl_eta = 1e-4
112 | args.rl_critic_update_iterations = 10
113 | args.rl_data_dropout = "node:0-0.3:0.5"
114 |
115 | args.value_latent_size = 32
116 | args.value_num_pool = 1
117 | args.value_act_name = "elu"
118 | args.value_act_name_final = "linear"
119 | args.value_layer_norm = False
120 | args.value_batch_norm = False
121 | args.value_num_steps = 3
122 | args.value_pooling_type = "global_mean_pool"
123 | args.value_target_mode = "value-lambda"
124 |
125 | args.load_dirname = "tailin-multi_2022-8-27"
126 | args.load_filename = "IHvBKQ8K_ampere3"
127 |
128 | ##################################
129 | # Dataset and model:
130 | ##################################
131 | # args.dataset = "mppde1df-E2-100"
132 | # args.dataset = "arcsimmesh_square"
133 | args.dataset = "arcsimmesh_square_annotated"
134 | if args.dataset.startswith("mppde1d"):
135 | args.latent_size = 64
136 | args.act_name = "elu"
137 | args.use_grads = False
138 | args.n_train = "-1"
139 | args.epochs = 2000
140 | args.use_pos = False
141 | args.latent_size = 64
142 | args.contrastive_rel_coef = 0
143 | args.is_prioritized_dropout = False
144 | args.input_steps = 1
145 | args.multi_step = "1^2:0.1^3:0.1^4:0.1"
146 | args.temporal_bundle_steps = 25
147 | args.n_train = ":100"
148 | args.epochs = 2000
149 | args.test_interval = 100
150 | args.save_interval = 100
151 |
152 | args.data_dropout = "node:0-0.1:0.1"
153 | args.use_pos = False
154 | args.rl_coefs = "reward:0.1"
155 |
156 | # Data:
157 | args.time_interval = 1
158 | args.dataset_split_type = "random"
159 | args.train_fraction = 1
160 |
161 | # Model:
162 | args.evolution_type = "mlp-3-elu-2"
163 | args.forward_type = "Euler"
164 | args.act_name = "elu"
165 |
166 | args.gpuid = "7"
167 | args.is_unittest = True
168 |
169 | elif args.dataset.startswith("arcsimmesh"):
170 | args.exp_id = "takashi-2dtest"
171 | args.date_time = "{}-{}".format(datetime.datetime.now().month, datetime.datetime.now().day)
172 | args.algo = "gnnremesher-evolution"
173 | args.encoder_type = "cnn-s"
174 | args.evo_conv_type = "cnn"
175 | args.decoder_type = "cnn-tr"
176 | args.padding_mode = "zeros"
177 | args.n_conv_layers_latent = 3
178 | args.n_conv_blocks = 4
179 | args.n_latent_levs = 1
180 | args.is_latent_flatten = True
181 | args.latent_size = 16
182 | args.act_name = "elu"
183 | args.decoder_act_name = "rational"
184 | args.use_grads = False
185 | args.n_train = "-1"
186 | args.use_pos = False
187 | args.contrastive_rel_coef = 0
188 | args.is_prioritized_dropout = False
189 | args.input_steps = 2
190 | args.multi_step = "1"
191 | args.latent_multi_step = "1"
192 | args.temporal_bundle_steps = 1
193 | # args.static_encoder_type = "param-2-elu"
194 | args.static_latent_size = 16
195 | args.n_train = ":100"
196 | args.epochs = 20
197 | args.test_interval = 10
198 | args.save_interval = 10
199 |
200 | #args.data_dropout = "node:0-0.4"
201 | args.use_pos = False
202 | args.load_filename = "IHvBKQ8K_ampere3"
203 | args.rl_coefs = "reward:0.1"
204 |
205 | # Data:
206 | args.time_interval = 1
207 | args.dataset_split_type = "random"
208 | args.train_fraction = 1
209 |
210 | # Model:
211 | args.evolution_type = "mlp-3-elu-2"
212 | args.forward_type = "Euler"
213 | args.act_name = "elu"
214 | args.is_mesh = True
215 | args.edge_attr=True
216 | # args.edge_threshold=0.000001
217 | args.edge_threshold=0.
218 |
219 | args.gpuid = "3"
220 | args.is_unittest = True
221 |
222 | except:
223 | is_jupyter = False
224 |
225 | if args.dataset.startswith("mppde1d"):
226 | if args.dataset.endswith("-40"):
227 | args.output_padding_str = "0-0-0-0"
228 | elif args.dataset.endswith("-50"):
229 | args.output_padding_str = "1-0-1-0"
230 | elif args.dataset.endswith("-100"):
231 | args.output_padding_str = "1-1-0-0"
232 |
233 |
234 | # # 2. Load data and model:
235 |
236 | # In[ ]:
237 |
238 |
239 | set_seed(args.seed)
240 | (dataset_train_val, dataset_test), (train_loader, val_loader, test_loader) = load_data(args)
241 | p.print(f"Minibatches for train: {len(train_loader)}")
242 | p.print(f"Minibatches for val: {len(val_loader)}")
243 | p.print(f"Minibatches for test: {len(test_loader)}")
244 | args_test8 = deepcopy(args)
245 | if args.dataset.startswith("m"):
246 | args_test8.multi_step = "1^8"
247 | args_test8.pred_steps = 8
248 | args_test8.is_train = False
249 | else:
250 | args_test8.multi_step = "1^75"
251 | args_test8.pred_steps = 75
252 | args_test8.is_train = False
253 | args_test8.batch_size = 1
254 | args_test8.val_batch_size = 1
255 | # seed_everything(42)
256 | (_, dataset_test8), (_, val_loader8, test_loader8) = load_data(args_test8)
257 |
258 | p.print(f"Minibatches for val8: {len(val_loader8)}")
259 | p.print(f"Minibatches for test8: {len(test_loader8)}")
260 | val_loader8 = None
261 | device = get_device(args)
262 | args.device = device
263 | args_test8.device = device
264 |
265 | data = deepcopy(dataset_test[0])
266 | epoch = 0
267 |
268 | if args.rl_is_finetune_evolution and args.dataset.startswith("a"):
269 | args_testfine = deepcopy(args)
270 | args_testfine.use_fineres_data = True
271 | (dataset_train_val_fine,_), (train_loader_fine, _, _) = load_data(args_testfine)
272 | p.print(f"Minibatches for trainfine: {len(train_loader_fine)}")
273 | else:
274 | args_testfine=None
275 |
276 | model = get_model(args, data, device)
277 | if args.algo.startswith("rlgnnremesher") or args.algo.startswith("srlgnnremesher"):
278 | device = get_device(args)
279 | loaded_dirname = EXP_PATH + args.load_dirname
280 | filenames = filter_filename(loaded_dirname, include=[args.load_filename])
281 | assert len(filenames) == 1, f"There are {len(filenames)} files under ./results/{args.load_dirname} that contain the str {args.load_filename}. Re-check the argument of --load_dirname and --load_filename."
282 | loaded_filename = os.path.join(loaded_dirname, filenames[0])
283 | data_record_load = pload(loaded_filename)
284 |
285 | args_load = init_args(update_legacy_default_hyperparam(data_record_load["args"]))
286 | args_load.multi_step = args.multi_step
287 | evolution_model = load_model(data_record_load["model_dict"][-1], device)
288 | if args.fix_alt_evolution_model:
289 | evolution_model_alt = load_model(data_record_load["model_dict"][-1], device)
290 | evolution_model_alt.to(device)
291 | evolution_model_alt.eval()
292 | else:
293 | evolution_model_alt = None
294 | if not args.rl_is_finetune_evolution:
295 | evolution_model.eval()
296 |
297 |
298 | if args.load_hash!="None":
299 | data_record_load = find_hash_and_load([args.load_hash])
300 | model.actor.load_state_dict(data_record_load["model_dict"][-1]["actor_model_dict"]["state_dict"])
301 | model.critic.load_state_dict(data_record_load["model_dict"][-1]["critic_model_dict"]["state_dict"])
302 | model.critic_target.load_state_dict(data_record_load["model_dict"][-1]["critic_model_dict"]["state_dict"])
303 | # model = load_model(data_record_load["model_dict"][-1],device=device)
304 | evolution_model = load_model(data_record_load["evolution_model_dict"][-1],device=device)
305 | if not args.rl_is_finetune_evolution:
306 | evolution_model.eval()
307 |
308 | # # 3. Training:
309 |
310 | # In[ ]:
311 |
312 |
313 | if args.algo.startswith("rlgnnremesher"):
314 | separate_params = [
315 | {'params': model.actor.parameters(), 'lr': args.actor_lr},
316 | {'params': model.critic.parameters(), 'lr': args.value_lr},
317 | ]
318 | opt, scheduler = build_optimizer(
319 | args, params=None,
320 | separate_params=separate_params,
321 | )
322 | elif args.algo.startswith("srlgnnremesher"):
323 | separate_params = [ {'params': model.parameters(), 'lr': args.actor_lr},]
324 | opt, scheduler = build_optimizer(args, params=None,separate_params=separate_params,)
325 | else:
326 | opt, scheduler = build_optimizer(args, model.parameters())
327 |
328 | if args.rl_is_finetune_evolution:
329 | opt_params = [{'params': evolution_model.parameters(), 'lr': args.lr}]
330 | opt_evolution, opt_scheduler = build_optimizer(
331 | args, params=None,
332 | separate_params=opt_params,
333 | )
334 |
335 | n_params_model = get_num_params(model)
336 | p.print("n_params_model: {}".format(n_params_model), end="")
337 | machine_name = os.uname()[1].split('.')[0]
338 | data_record = {"n_params_model": n_params_model, "args": update_dict(args.__dict__, "machine_name", machine_name),
339 | "best_train_model_dict": [], "best_train_loss": [], "best_train_loss_history":[]}
340 | early_stopping = Early_Stopping(patience=args.early_stopping_patience)
341 |
342 | short_str_dict = {
343 | "dataset": "",
344 | "n_train": "train",
345 | "algo": "algo",
346 | "act_name": "act",
347 | "latent_size": "hid",
348 | "multi_step": "mt",
349 | "temporal_bundle_steps": "tb",
350 | "loss_type": "lo",
351 | "gpuid": "gpu",
352 | "id": "id",
353 | }
354 |
355 | filename_short = get_filename_short(
356 | short_str_dict.keys(),
357 | short_str_dict,
358 | args_dict=args.__dict__,
359 | )
360 | filename = EXP_PATH + "{}_{}/".format(args.exp_id, args.date_time) + filename_short[:-2] + "_{}.p".format(machine_name)
361 | write_to_config(args, filename)
362 | args.filename = filename
363 | kwargs = {}
364 | if args.algo.startswith("rlgnnremesher") or args.algo.startswith("srlgnnremesher"):
365 | kwargs["evolution_model"] = evolution_model
366 | kwargs["evolution_model_alt"] = evolution_model_alt
367 | p.print(filename, banner_size=100)
368 | # print(model)
369 | make_dir(filename)
370 | best_val_loss = np.Inf
371 | if args.load_filename != "None":
372 | val_loss = np.Inf
373 | collate_fn = (deepsnap_Batch.collate() if data.__class__.__name__ == "HeteroGraph"
374 |               else MeshBatch(is_absorb_batch=True, is_collate_tuple=True).collate() if args.dataset.startswith("arcsimmesh")
375 |               else Batch(is_absorb_batch=True, is_collate_tuple=True).collate())
376 | if args.is_unittest:
377 | unittest_model(model,
378 | collate_fn([data, data]), args, device, use_grads=args.use_grads, use_pos=args.use_pos, is_mesh=args.is_mesh,
379 | test_cases="all" if not (args.dataset.startswith("PIL") or args.dataset.startswith("PHIL")) else "model_dict", algo=args.algo,
380 | **kwargs
381 | )
382 | if args.is_tensorboard:
383 | from tensorboardX import SummaryWriter
384 | writer = SummaryWriter(EXP_PATH + '/log/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
385 | pp.pprint(args.__dict__)
386 |
387 | if args.wandb:
388 | import wandb
389 | args.wandb_project_name = "_".join([args.wandb_project_name, args.algo])
390 | wandb.init(project=args.wandb_project_name, entity="multi-scale", name="_".join(["beta_",args.reward_beta]+filename.split("_")[-2:])[:-2] + f'_ntrain_{args.n_train}{"_" + args.id if args.id != "0" else ""}',
391 | config={"name": args.exp_id})
392 | wandb.watch(model, log_freq=1000, log="all")
393 | wandb.watch(evolution_model, log_freq=1000, log="all")
394 | wandb.config=vars(args)
395 | else:
396 | wandb = None
397 | step_num = 0
398 | opt_actor = None
399 | opt_evl = None
400 |
401 | # if (args.algo.startswith("rlgnnremesher") or args.algo.startswith("srlgnnremesher")) and args.wandb:
402 | # model.get_tested(
403 | # test_loader8,
404 | # args_test8,
405 | # current_epoch=0,
406 | # current_minibatch=0,
407 | # wandb=wandb,
408 | # step_num=step_num,
409 | # **kwargs
410 | # )
411 |
412 | if args.load_hash!="None":
413 | if "last_optimizer_dict" in data_record_load.keys():
414 | opt.load_state_dict(data_record_load["last_optimizer_dict"])
415 | if "last_evolution_optimizer_dict" in data_record_load.keys():
416 | opt_evolution.load_state_dict(data_record_load["last_evolution_optimizer_dict"])
417 | if "last_scheduler_dict" in data_record_load.keys():
418 | opt_scheduler.load_state_dict(data_record_load["last_scheduler_dict"])
419 |
420 |
421 |
422 | while epoch < args.epochs:
423 | total_loss = 0
424 | count = 0
425 |
426 | model.train()
427 | train_info = {}
428 | best_train_loss = np.Inf
429 | last_few_losses = []
430 | num_losses = 20
431 | t_start = time.time()
432 |
433 | if args.rl_is_finetune_evolution and args.dataset.startswith("a"):
434 | train_loader_fine_iterator = iter(train_loader_fine)
435 | for j, data in enumerate(train_loader):
436 | if args.rl_is_finetune_evolution and args.dataset.startswith("a"):
437 | try:
438 | data_fine = next(train_loader_fine_iterator)
439 | except StopIteration:
440 | train_loader_fine_iterator = iter(train_loader_fine)
441 |                     data_fine = next(train_loader_fine_iterator)
442 |             else:
443 |                 data_fine = None
444 |
445 | t_end = time.time()
446 | if args.verbose >= 2 and j % 100 == 0:
447 | p.print(f"Data loading time: {t_end - t_start:.6f}")
448 | if data.__class__.__name__ == "Attr_Dict":
449 | data = data.to(device)
450 | if args.rl_is_finetune_evolution and args.dataset.startswith("a"): data_fine = data_fine.to(device)
451 | else:
452 | data.to(device)
453 | if args.rl_is_finetune_evolution and args.dataset.startswith("a"): data_fine.to(device)
454 | opt.zero_grad()
455 | if args.rl_is_finetune_evolution:
456 | opt_evolution.zero_grad()
457 |             if args.actor_critic_step is None:
458 |                 opt_actor = True
459 |                 opt_evl = True
460 | else:
461 |                 if step_num % (args.actor_critic_step + args.evolution_steps) == 0:
625 | p.print(filename)
626 | record_data(data_record, [epoch, get_model_dict(model)], ["save_epoch", "model_dict"])
627 | if "evolution_model" in locals():
628 | record_data(data_record, [evolution_model.model_dict], ["evolution_model_dict"])
629 | with open(filename, "wb") as f:
630 | pickle.dump(data_record, f)
631 | if val_loss < best_val_loss:
632 | best_val_loss = val_loss
633 | data_record["best_model_dict"] = get_model_dict(model)
634 | data_record["best_optimizer_dict"] = opt.state_dict()
635 | data_record["best_scheduler_dict"] = scheduler.state_dict() if scheduler is not None else None
636 | if "evolution_model" in locals():
637 | data_record["best_evolution_model_dict"] = evolution_model.model_dict
638 | if args.rl_is_finetune_evolution:
639 | data_record["best_evolution_optimizer_dict"] = opt_evolution.state_dict()
640 | data_record["best_epoch"] = epoch
641 | data_record["last_model_dict"] = get_model_dict(model)
642 | data_record["last_optimizer_dict"] = opt.state_dict()
643 | data_record["last_scheduler_dict"] = scheduler.state_dict() if scheduler is not None else None
644 | data_record["last_epoch"] = epoch
645 | if "evolution_model" in locals():
646 | data_record["last_evolution_model_dict"] = evolution_model.model_dict
647 | if args.rl_is_finetune_evolution:
648 | data_record["last_evolution_optimizer_dict"] = opt_evolution.state_dict()
649 | p.print("12", precision="millisecond", is_silent=args.is_timing<1, avg_window=1)
650 |
651 | pdump(data_record, filename)
652 | if "to_stop" in locals() and to_stop:
653 | p.print("Early-stop at epoch {}.".format(epoch))
654 | break
655 | epoch += 1
656 | record_data(data_record, [epoch, get_model_dict(model)], ["save_epoch", "model_dict"])
657 | if "evolution_model" in locals():
658 | record_data(data_record, [evolution_model.model_dict], ["evolution_model_dict"])
659 | pdump(data_record, filename)
660 |
661 |
--------------------------------------------------------------------------------
/utils_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # In[ ]:
5 |
6 |
7 | import torch
8 | import math
9 | import numpy as np
10 | from torch import nn
11 | # from torchmeta.modules import MetaModule
12 | from collections import OrderedDict
13 | import copy, re, warnings  # re and warnings are used by MetaModule.get_subdict below
14 | import torch.nn.functional as F
15 |
16 |
17 | def get_conv_func(pos_dim, *args, **kwargs):
18 | if "reg_type_list" in kwargs:
19 | reg_type_list = kwargs.pop("reg_type_list")
20 | else:
21 | reg_type_list = None
22 | if pos_dim == 1:
23 | conv = nn.Conv1d(*args, **kwargs)
24 | elif pos_dim == 2:
25 | conv = nn.Conv2d(*args, **kwargs)
26 | elif pos_dim == 3:
27 | conv = nn.Conv3d(*args, **kwargs)
28 | else:
29 | raise Exception("The pos_dim can only be 1, 2 or 3!")
30 | if reg_type_list is not None:
31 | if "snn" in reg_type_list:
32 | conv = SpectralNorm(conv)
33 | elif "snr" in reg_type_list:
34 | conv = SpectralNormReg(conv)
35 | return conv
36 |
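
def _example_get_conv_func():
    """A minimal usage sketch (an illustration, not part of the original code):
    build a 2-D convolution wrapped with spectral normalization via the "snn"
    flag in reg_type_list."""
    conv = get_conv_func(2, 16, 32, kernel_size=3, padding=1, reg_type_list=["snn"])
    y = conv(torch.randn(4, 16, 8, 8))  # -> (4, 32, 8, 8)
    return y
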
37 |
38 | def get_conv_trans_func(pos_dim, *args, **kwargs):
39 | if "reg_type_list" in kwargs:
40 | reg_type_list = kwargs.pop("reg_type_list")
41 | else:
42 | reg_type_list = None
43 | if pos_dim == 1:
44 | conv_trans = nn.ConvTranspose1d(*args, **kwargs)
45 | elif pos_dim == 2:
46 | conv_trans = nn.ConvTranspose2d(*args, **kwargs)
47 | elif pos_dim == 3:
48 | conv_trans = nn.ConvTranspose3d(*args, **kwargs)
49 | else:
50 | raise Exception("The pos_dim can only be 1, 2 or 3!")
51 | # The weight's output dim=1 for ConvTranspose
52 | if reg_type_list is not None:
53 | if "snn" in reg_type_list:
54 | conv_trans = SpectralNorm(conv_trans, dim=1)
55 | elif "snr" in reg_type_list:
56 | conv_trans = SpectralNormReg(conv_trans, dim=1)
57 | return conv_trans
58 |
59 |
60 | # ### Spectral Norm:
61 |
62 | # In[ ]:
63 |
64 |
65 | def l2normalize(v, eps=1e-12):
66 | return v / (v.norm() + eps)
67 |
68 |
69 | class SpectralNorm(nn.Module):
70 | def __init__(self, module, name='weight', power_iterations=1, dim=0):
71 | super(SpectralNorm, self).__init__()
72 | self.module = module
73 | self.name = name
74 | self.power_iterations = power_iterations
75 | self.dim = dim
76 | if not self._made_params():
77 | self._make_params()
78 |
79 | def reshape_weight_to_matrix(self, weight):
80 | weight_mat = weight
81 | if self.dim != 0:
82 | # permute dim to front
83 | weight_mat = weight_mat.permute(self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim])
84 | height = weight_mat.size(0)
85 | return weight_mat.reshape(height, -1)
86 |
87 | def _update_u_v(self):
88 | u = getattr(self.module, self.name + "_u")
89 | v = getattr(self.module, self.name + "_v")
90 | w = getattr(self.module, self.name + "_bar")
91 | w_mat = self.reshape_weight_to_matrix(w)
92 |
93 | height = w_mat.shape[0]
94 | for _ in range(self.power_iterations):
95 | v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
96 | u.data = l2normalize(torch.mv(w_mat.data, v.data))
97 |
98 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
99 | sigma = u.dot(w_mat.mv(v))
100 | setattr(self.module, self.name, w / sigma.expand_as(w))
101 |
102 | def _made_params(self):
103 | try:
104 | u = getattr(self.module, self.name + "_u")
105 | v = getattr(self.module, self.name + "_v")
106 | w = getattr(self.module, self.name + "_bar")
107 | return True
108 | except AttributeError:
109 | return False
110 |
111 | def _make_params(self):
112 | w = getattr(self.module, self.name)
113 | w_mat = self.reshape_weight_to_matrix(w)
114 |
115 | height = w_mat.shape[0]
116 | width = w_mat.shape[1]
117 |
118 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
119 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
120 | u.data = l2normalize(u.data)
121 | v.data = l2normalize(v.data)
122 | w_bar = nn.Parameter(w.data)
123 |
124 | del self.module._parameters[self.name]
125 |
126 | self.module.register_parameter(self.name + "_u", u)
127 | self.module.register_parameter(self.name + "_v", v)
128 | self.module.register_parameter(self.name + "_bar", w_bar)
129 |
130 | def forward(self, *args):
131 | if self.training:
132 | self._update_u_v()
133 | else:
134 |             setattr(self.module, self.name, getattr(self.module, self.name + "_bar") / 1)  # " / 1" prevents state_dict() from recording self.module.weight
135 | return self.module.forward(*args)
136 |
137 |
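# Note on the class above: SpectralNorm (Miyato et al., 2018) estimates the largest
# singular value sigma of the weight matrix by power iteration and divides the
# weight by sigma on each training-mode forward pass, constraining the layer's
# spectral norm to be approximately 1.
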
138 | # ### SpectralNormReg:
139 |
140 | # In[ ]:
141 |
142 |
143 | class SpectralNormReg(nn.Module):
144 | def __init__(self, module, name='weight', power_iterations=1, dim=0):
145 | super(SpectralNormReg, self).__init__()
146 | self.module = module
147 | self.name = name
148 | self.power_iterations = power_iterations
149 | self.dim = dim
150 | if not self._made_params():
151 | self._make_params()
152 |
153 | def reshape_weight_to_matrix(self, weight):
154 | weight_mat = weight
155 | if self.dim != 0:
156 | # permute dim to front
157 | weight_mat = weight_mat.permute(self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim])
158 | height = weight_mat.size(0)
159 | return weight_mat.reshape(height, -1)
160 |
161 | def compute_snreg(self):
162 | u = getattr(self.module, self.name + "_u")
163 | v = getattr(self.module, self.name + "_v")
164 | w = getattr(self.module, self.name + "_bar")
165 | w_mat = self.reshape_weight_to_matrix(w)
166 |
167 | height = w_mat.shape[0]
168 | for _ in range(self.power_iterations):
169 | v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
170 | u.data = l2normalize(torch.mv(w_mat.data, v.data))
171 |
172 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
173 | sigma = u.dot(w_mat.mv(v))
174 | self.snreg = sigma.square() / 2
175 | setattr(self.module, self.name, w / 1) # Here the " / 1" is to prevent state_dict() to record self.module.weight
176 |
177 | def _made_params(self):
178 | try:
179 | u = getattr(self.module, self.name + "_u")
180 | v = getattr(self.module, self.name + "_v")
181 | w = getattr(self.module, self.name + "_bar")
182 | return True
183 | except AttributeError:
184 | return False
185 |
186 | def _make_params(self):
187 | w = getattr(self.module, self.name)
188 | w_mat = self.reshape_weight_to_matrix(w)
189 |
190 | height = w_mat.shape[0]
191 | width = w_mat.shape[1]
192 |
193 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
194 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
195 | u.data = l2normalize(u.data)
196 | v.data = l2normalize(v.data)
197 | w_bar = nn.Parameter(w.data)
198 |
199 | del self.module._parameters[self.name]
200 |
201 | self.module.register_parameter(self.name + "_u", u)
202 | self.module.register_parameter(self.name + "_v", v)
203 | self.module.register_parameter(self.name + "_bar", w_bar)
204 |
205 | def forward(self, *args):
206 | self.compute_snreg()
207 | return self.module.forward(*args)
208 |
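
def _example_snreg_aggregation(model, task_loss, reg_coeff=1e-4):
    """A hedged sketch (an illustration, not part of the original code): unlike
    SpectralNorm, SpectralNormReg leaves the weight unscaled and only stores
    `snreg` (= sigma^2 / 2) on each forward pass, so a training loop is expected
    to collect these terms and add them to the loss."""
    reg = sum(m.snreg for m in model.modules() if isinstance(m, SpectralNormReg))
    return task_loss + reg_coeff * reg
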
209 |
210 | # ### Hessian regularization:
211 |
212 | # In[ ]:
213 |
214 |
215 | def get_Hessian_penalty(
216 | G,
217 | z,
218 | mode,
219 | k=2,
220 | epsilon=0.1,
221 | reduction=torch.max,
222 | return_separately=False,
223 | G_z=None,
224 | is_nondimensionalize=False,
225 | **G_kwargs
226 | ):
227 | """
228 | Adapted from https://github.com/wpeebles/hessian_penalty/ (Peebles et al. 2020).
229 | Note: If you want to regularize multiple network activations simultaneously, you need to
230 | make sure the function G you pass to hessian_penalty returns a list of those activations when it's called with
231 | G(z, **G_kwargs). Otherwise, if G returns a tensor the Hessian Penalty will only be computed for the final
232 | output of G.
233 |
234 | Args:
235 | G: Function that maps input z to either a tensor or a list of tensors (activations)
236 | z: Input to G that the Hessian Penalty will be computed with respect to
237 | mode: choose from "Hdiag", "Hoff" or "Hall", specifying the scope of Hessian values to perform sum square on.
238 | "Hall" will be the sum of "Hdiag" (for diagonal elements) and "Hoff" (for off-diagonal elements).
239 | k: Number of Hessian directions to sample (must be >= 2)
240 | epsilon: Amount to blur G before estimating Hessian (must be > 0)
241 | reduction: Many-to-one function to reduce each pixel/neuron's individual hessian penalty into a final loss
242 | return_separately: If False, hessian penalties for each activation output by G are automatically summed into
243 | a final loss. If True, the hessian penalties for each layer will be returned in a list
244 | instead. If G outputs a single tensor, setting this to True will produce a length-1
245 | list.
246 |         G_z: [Optional small speed-up] If you have already computed G(z, **G_kwargs) for the current training
247 |             iteration, then you can provide it here to reduce the number of forward passes of this method by 1.
248 |         G_kwargs: Additional inputs to G besides the z vector. For example, in BigGAN you
249 |             would pass the class label into this function via y=
250 |     Returns: a differentiable scalar (the hessian penalty), or a list of hessian penalties if return_separately is True.
251 | """
252 | if G_z is None:
253 | G_z = G(z, **G_kwargs)
254 |     rademacher_size = torch.Size((k, *z.size()))  # (k, *z.size()) = (k, N, nz)
255 | if mode == "Hall":
256 | loss_diag = get_Hessian_penalty(G=G, z=z, mode="Hdiag", k=k, epsilon=epsilon, reduction=reduction, return_separately=return_separately, G_z=G_z, **G_kwargs)
257 | loss_offdiag = get_Hessian_penalty(G=G, z=z, mode="Hoff", k=k, epsilon=epsilon, reduction=reduction, return_separately=return_separately, G_z=G_z, **G_kwargs)
258 | if return_separately:
259 | loss = []
260 | for loss_i_diag, loss_i_offdiag in zip(loss_diag, loss_offdiag):
261 | loss.append(loss_i_diag + loss_i_offdiag)
262 | else:
263 | loss = loss_diag + loss_offdiag
264 | return loss
265 | elif mode == "Hdiag":
266 | xs = epsilon * complex_rademacher(rademacher_size, device=z.device)
267 | elif mode == "Hoff":
268 | xs = epsilon * rademacher(rademacher_size, device=z.device)
269 | else:
270 |         raise Exception("mode {} is not valid; choose from 'Hdiag', 'Hoff' or 'Hall'.".format(mode))
271 | second_orders = []
272 |
273 | if mode == "Hdiag" and isinstance(G, nn.Module):
274 | # Use the complex64 dtype:
275 | dtype_ori = next(iter(G.parameters())).dtype
276 | G.type(torch.complex64)
277 | if isinstance(G, nn.Module):
278 | G_wrapper = get_listified_fun(G)
279 | G_z = listity_tensor(G_z)
280 | else:
281 | G_wrapper = G
282 |
283 | for x in xs: # Iterate over each (N, z.size()) tensor in xs
284 | central_second_order = multi_layer_second_directional_derivative(G_wrapper, z, x, G_z, epsilon, **G_kwargs)
285 | second_orders.append(central_second_order) # Appends a tensor with shape equal to G(z).size()
286 | loss = multi_stack_metric_and_reduce(second_orders, mode, reduction, return_separately) # (k, G(z).size()) --> scalar
287 |
288 | if mode == "Hdiag" and isinstance(G, nn.Module):
289 | # Revert back to original dtype:
290 | G.type(dtype_ori)
291 |
292 | if is_nondimensionalize:
293 | # Multiply a factor ||z||_2^2 so that the result is dimensionless:
294 | factor = z.square().mean()
295 | if return_separately:
296 | loss = [ele * factor for ele in loss]
297 | else:
298 | loss = loss * factor
299 | return loss
300 |
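
def _example_hessian_penalty():
    """A minimal usage sketch (an illustration, not part of the original code):
    penalize off-diagonal Hessian entries of a small MLP w.r.t. its input. The
    returned value is a differentiable scalar that can be added to the loss."""
    G = nn.Sequential(nn.Linear(8, 32), nn.Tanh(), nn.Linear(32, 4))
    z = torch.randn(16, 8)
    return get_Hessian_penalty(G, z, mode="Hoff", k=2, epsilon=0.1)
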
301 |
302 | def listity_tensor(tensor):
303 | """Turn the output features (except for the first batch dimension) of a function into a list
304 |
305 | Args:
306 | tensor: has shape [B, d1, d2, ...]
307 | """
308 | batch_size = tensor.shape[0]
309 | shape = tensor.shape[1:]
310 | tensor_reshape = tensor.reshape(batch_size, -1)
311 | tensor_listify = [tensor_reshape[:, i] for i in range(tensor_reshape.shape[1])]
312 | return tensor_listify
313 |
314 |
315 | def get_listified_fun(G):
316 | def fun(z, **Gkwargs):
317 | G_out = G(z, **Gkwargs)
318 | return listity_tensor(G_out)
319 | return fun
320 |
321 |
322 | def rademacher(shape, device='cpu'):
323 | """Creates a random tensor of size [shape] under the Rademacher distribution (P(x=1) == P(x=-1) == 0.5)"""
324 | return torch.randint(2, size=shape, device=device).float() * 2 - 1
325 |
326 |
327 | def complex_rademacher(shape, device='cpu'):
328 | """Creates a random tensor of size [shape] with (P(x=1) == P(x=-1) == P(x=1j) == P(x=-1j) == 0.25)"""
329 | collection = torch.from_numpy(np.array([1., -1, 1j, -1j])).type(torch.complex64).to(device)
330 |     x = torch.randint(4, size=shape, device=device)  # Creates random tensor of 0, 1, 2, 3
331 | return collection[x] # Map tensor of 0, 1, 2, 3 to 1, -1, 1j, -1j
332 |
333 |
334 | def multi_layer_second_directional_derivative(G, z, x, G_z, epsilon, **G_kwargs):
335 | """Estimates the second directional derivative of G w.r.t. its input at z in the direction x"""
336 | G_to_x = G(z + x, **G_kwargs)
337 | G_from_x = G(z - x, **G_kwargs)
338 |
339 | G_to_x = listify(G_to_x)
340 | G_from_x = listify(G_from_x)
341 | G_z = listify(G_z)
342 |
343 | eps_sqr = epsilon ** 2
344 | sdd = [(G2x - 2 * G_z_base + Gfx) / eps_sqr for G2x, G_z_base, Gfx in zip(G_to_x, G_z, G_from_x)]
345 | return sdd
346 |
347 |
348 | def stack_metric_and_reduce(list_of_activations, mode, reduction=torch.max):
349 | """Equation (5) from the paper."""
350 | second_orders = torch.stack(list_of_activations) # (k, N, C, H, W)
351 | if mode == "Hoff":
352 | tensor = torch.var(second_orders, dim=0, unbiased=True) / 2 # (N, C, H, W)
353 | elif mode == "Hdiag":
354 | tensor = torch.mean((second_orders ** 2).real, dim=0)
355 | else:
356 |         raise Exception("mode {} is not valid; choose from 'Hdiag' or 'Hoff'.".format(mode))
357 | penalty = reduction(tensor) # (1,) (scalar)
358 | return penalty
359 |
360 |
361 | def multi_stack_metric_and_reduce(sdds, mode, reduction=torch.max, return_separately=False):
362 | """Iterate over all activations to be regularized, then apply Equation (5) to each."""
363 | sum_of_penalties = 0 if not return_separately else []
364 | for activ_n in zip(*sdds):
365 | penalty = stack_metric_and_reduce(activ_n, mode, reduction)
366 | sum_of_penalties += penalty if not return_separately else [penalty]
367 | return sum_of_penalties
368 |
369 |
370 | def listify(x):
371 | """If x is already a list, do nothing. Otherwise, wrap x in a list."""
372 | if isinstance(x, list):
373 | return x
374 | else:
375 | return [x]
376 |
377 |
378 | def _test_hessian_penalty(mode, k=100):
379 | """
380 | A simple multi-layer test to verify the implementation.
381 |     Function: G(z) = [z_0 * z_1 + z_0 ** 2, z_0 ** 2 * z_1 + 2 * z_1 ** 2]
382 | The Hessian for the first value is [
383 | [2, 1],
384 | [1, 0],
385 | ]
386 | so the offdiagonal sum square is 2
387 | the diagonal sum square is 4.
388 |
389 | The Hessian for the second function value is: [
390 | [2 * z_1, 2 * z_0],
391 | [2 * z_0, 4],
392 | ]
393 | so the offdiagonal sum square is 8 * z_0**2
394 | the diagonal sum square is 16 + 4 * z_1**2.
395 |     Ground Truth Hessian Penalty (mode "Hoff"): [2, 8 * z_0**2]
396 | """
397 | batch_size = 10
398 | nz = 2
399 | z = torch.randn(batch_size, nz)
400 | def reduction(x): return x.abs().mean()
401 | def G(z): return [z[:, 0] * z[:, 1] + z[:, 0] ** 2, (z[:, 0] ** 2) * z[:, 1] + 2 * z[:, 1] ** 2]
402 | if mode == "Hdiag":
403 | ground_truth = [4, 16 + 4 * (z[:, 1] ** 2).mean().item()]
404 | elif mode == "Hoff":
405 | ground_truth = [2, reduction(8 * z[:, 0] ** 2).item()]
406 | elif mode == "Hall":
407 | ground_truth = [4+2, 16 + 4 * (z[:, 1] ** 2).mean().item() + reduction(8 * z[:, 0] ** 2).item()]
408 | else:
409 |         raise Exception("mode {} is not valid; choose from 'Hdiag', 'Hoff' or 'Hall'.".format(mode))
410 | # In this simple example, we use k=100 to reduce variance, but when applied to neural networks
411 | # you will probably want to use a small k (e.g., k=2) due to memory considerations.
412 | predicted = get_Hessian_penalty(G, z, mode=mode, G_z=None, k=k, reduction=reduction, return_separately=True)
413 | predicted = [p.item() for p in predicted]
414 | print('Ground Truth: %s' % ground_truth)
415 | print('Approximation: %s' % predicted) # This should be close to ground_truth, but not exactly correct
416 | print('Difference: %s' % [str(100 * abs(p - gt) / gt) + '%' for p, gt in zip(predicted, ground_truth)])
417 |
418 |
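# A hedged sketch of invoking the check above (run on CPU; the test uses k=100
# to keep the Monte-Carlo variance of the estimate low):
#     for mode in ["Hdiag", "Hoff", "Hall"]:
#         _test_hessian_penalty(mode)
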
419 | ## >>> functions for the MeshGraphNets:
420 |
421 | def init_weights_requ(m):
422 | if type(m) == BatchLinear or type(m) == nn.Linear:
423 | if hasattr(m, 'weight'):
424 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
425 |
426 |
427 | def init_weights_normal(m):
428 | if type(m) == BatchLinear or type(m) == nn.Linear:
429 | if hasattr(m, 'weight'):
430 | nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_out')
431 |
432 |
433 | def init_weights_selu(m):
434 | if type(m) == BatchLinear or type(m) == nn.Linear:
435 | if hasattr(m, 'weight'):
436 | num_input = m.weight.size(-1)
437 | nn.init.normal_(m.weight, std=1/math.sqrt(num_input))
438 |
439 |
440 | def init_weights_elu(m):
441 | if type(m) == BatchLinear or type(m) == nn.Linear:
442 | if hasattr(m, 'weight'):
443 | num_input = m.weight.size(-1)
444 | nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277)/math.sqrt(num_input))
445 |
446 |
447 | def init_weights_xavier(m):
448 | if type(m) == BatchLinear or type(m) == nn.Linear:
449 | if hasattr(m, 'weight'):
450 | nn.init.xavier_normal_(m.weight)
451 |
452 |
453 | def init_weights_uniform(m):
454 | if type(m) == BatchLinear or type(m) == nn.Linear:
455 | if hasattr(m, 'weight'):
456 | torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
457 |
458 |
459 | def sine_init(m, w0=60):
460 | with torch.no_grad():
461 | if hasattr(m, 'weight'):
462 | num_input = m.weight.size(-1)
463 | m.weight.uniform_(-np.sqrt(6/num_input)/w0, np.sqrt(6/num_input)/w0)
464 |
465 |
466 | def first_layer_sine_init(m):
467 | with torch.no_grad():
468 | if hasattr(m, 'weight'):
469 | num_input = m.weight.size(-1)
470 | m.weight.uniform_(-1/num_input, 1/num_input)
471 |
472 |
473 | class MetaModule(nn.Module):
474 | """
475 | Base class for PyTorch meta-learning modules. These modules accept an
476 | additional argument `params` in their `forward` method.
477 | Notes
478 | -----
479 | Objects inherited from `MetaModule` are fully compatible with PyTorch
480 | modules from `torch.nn.Module`. The argument `params` is a dictionary of
481 | tensors, with full support of the computation graph (for differentiation).
482 | """
483 | def __init__(self):
484 | super(MetaModule, self).__init__()
485 | self._children_modules_parameters_cache = dict()
486 |
487 | def meta_named_parameters(self, prefix='', recurse=True):
488 | gen = self._named_members(
489 | lambda module: module._parameters.items()
490 | if isinstance(module, MetaModule) else [],
491 | prefix=prefix, recurse=recurse)
492 | for elem in gen:
493 | yield elem
494 |
495 | def meta_parameters(self, recurse=True):
496 | for name, param in self.meta_named_parameters(recurse=recurse):
497 | yield param
498 |
499 | def get_subdict(self, params, key=None):
500 | if params is None:
501 | return None
502 |
503 | all_names = tuple(params.keys())
504 | if (key, all_names) not in self._children_modules_parameters_cache:
505 | if key is None:
506 | self._children_modules_parameters_cache[(key, all_names)] = all_names
507 |
508 | else:
509 | key_escape = re.escape(key)
510 | key_re = re.compile(r'^{0}\.(.+)'.format(key_escape))
511 |
512 | self._children_modules_parameters_cache[(key, all_names)] = [
513 | key_re.sub(r'\1', k) for k in all_names if key_re.match(k) is not None]
514 |
515 | names = self._children_modules_parameters_cache[(key, all_names)]
516 | if not names:
517 | warnings.warn('Module `{0}` has no parameter corresponding to the '
518 | 'submodule named `{1}` in the dictionary `params` '
519 | 'provided as an argument to `forward()`. Using the '
520 | 'default parameters for this submodule. The list of '
521 | 'the parameters in `params`: [{2}].'.format(
522 | self.__class__.__name__, key, ', '.join(all_names)),
523 | stacklevel=2)
524 | return None
525 |
526 | return OrderedDict([(name, params[f'{key}.{name}']) for name in names])
527 |
528 |
529 | class BatchLinear(nn.Linear, MetaModule):
530 | '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
531 | hypernetwork.'''
532 | __doc__ = nn.Linear.__doc__
533 |
534 | def forward(self, input, params=None):
535 | if params is None:
536 | params = OrderedDict(self.named_parameters())
537 |
538 | bias = params.get('bias', None)
539 | weight = params['weight']
540 |
541 | output = input.matmul(weight.permute(*[i for i in range(len(weight.shape)-2)], -1, -2))
542 |         if bias is not None: output = output + bias.unsqueeze(-2)
543 | return output
544 |
545 |
546 | class FirstSine(nn.Module):
547 | def __init__(self, w0=60):
548 | super().__init__()
549 | self.w0 = torch.tensor(w0)
550 |
551 | def forward(self, input):
552 | return torch.sin(self.w0*input)
553 |
554 |
555 | class Sine(nn.Module):
556 | def __init__(self, w0=60):
557 | super().__init__()
558 | self.w0 = torch.tensor(w0)
559 |
560 | def forward(self, input):
561 | return torch.sin(self.w0*input)
562 |
563 |
564 | class ReQU(nn.Module):
565 | def __init__(self, inplace=True):
566 | super().__init__()
567 | self.relu = nn.ReLU(inplace)
568 |
569 | def forward(self, input):
570 | # return torch.sin(np.sqrt(256)*input)
571 | return .5*self.relu(input)**2
572 |
573 |
574 | class MSoftplus(nn.Module):
575 | def __init__(self):
576 | super().__init__()
577 | self.softplus = nn.Softplus()
578 | self.cst = torch.log(torch.tensor(2.))
579 |
580 | def forward(self, input):
581 | return self.softplus(input)-self.cst
582 |
583 |
584 | class Swish(nn.Module):
585 | def __init__(self):
586 | super().__init__()
587 |
588 | def forward(self, input):
589 | return input*torch.sigmoid(input)
590 |
591 |
592 | def layer_factory(layer_type):
593 | layer_dict = {
594 | 'relu': (nn.ReLU(inplace=True), init_weights_normal),
595 | 'leakyrelu': (nn.LeakyReLU(inplace=True), init_weights_normal),
596 | 'requ': (ReQU(inplace=False), init_weights_requ),
597 | 'sigmoid': (nn.Sigmoid(), None),
598 | 'fsine': (Sine(), first_layer_sine_init),
599 | 'sine': (Sine(), sine_init),
600 | 'tanh': (nn.Tanh(), init_weights_xavier),
601 | 'selu': (nn.SELU(inplace=True), init_weights_selu),
602 | 'gelu': (nn.GELU(), init_weights_selu),
603 | 'swish': (Swish(), init_weights_selu),
604 | 'softplus': (nn.Softplus(), init_weights_normal),
605 | 'msoftplus': (MSoftplus(), init_weights_normal),
606 | 'elu': (nn.ELU(), init_weights_elu),
607 | 'silu': (nn.SiLU(), init_weights_selu),
608 | }
609 | return layer_dict[layer_type]
610 |
611 |
612 | class PositionalEncoding(nn.Module):
613 | def __init__(self, num_encoding_functions=6, include_input=True, log_sampling=True, normalize=False,
614 | input_dim=2, gaussian_pe=False, gaussian_variance=0.1):
615 | super().__init__()
616 | self.num_encoding_functions = num_encoding_functions
617 | self.include_input = include_input
618 | self.log_sampling = log_sampling
619 | self.normalize = normalize
620 | self.gaussian_pe = gaussian_pe
621 | self.normalization = None
622 |
623 | if self.gaussian_pe:
624 | # this needs to be registered as a parameter so that it is saved in the model state dict
625 | # and so that it is converted using .cuda(). Doesn't need to be trained though
626 | self.gaussian_weights = nn.Parameter(2*np.pi*gaussian_variance * torch.randn((num_encoding_functions*2), input_dim),
627 | requires_grad=False)
628 |
629 | else:
630 | self.frequency_bands = None
631 | if self.log_sampling:
632 | self.frequency_bands = 2.0 ** torch.linspace(
633 | 0.0,
634 | self.num_encoding_functions - 1,
635 | self.num_encoding_functions)
636 | else:
637 | self.frequency_bands = torch.linspace(
638 | 2.0 ** 0.0,
639 | 2.0 ** (self.num_encoding_functions - 1),
640 | self.num_encoding_functions)
641 |
642 | if normalize:
643 |             self.normalization = (1 / self.frequency_bands).clone()
644 |
645 | def forward(self, tensor) -> torch.Tensor:
646 | r"""Apply positional encoding to the input.
647 | Args:
648 | tensor (torch.Tensor): Input tensor to be positionally encoded.
649 |         Note:
650 |             The number of encoding functions and whether to include the raw input
651 |             are configured in __init__ (num_encoding_functions, include_input),
652 |             not passed as arguments here.
653 | Returns:
654 | (torch.Tensor): Positional encoding of the input tensor.
655 | """
656 |
657 | encoding = [tensor] if self.include_input else []
658 | if self.gaussian_pe:
659 | for func in [torch.sin, torch.cos]:
660 | encoding.append(func(torch.matmul(tensor, self.gaussian_weights.T)))
661 | else:
662 | for idx, freq in enumerate(self.frequency_bands):
663 | for func in [torch.sin, torch.cos]:
664 | if self.normalization is not None:
665 | encoding.append(self.normalization[idx]*func(tensor * freq))
666 | else:
667 | encoding.append(func(tensor * freq))
668 |
669 | # Special case, for no positional encoding
670 | if len(encoding) == 1:
671 | return encoding[0]
672 | else:
673 | return torch.cat(encoding, dim=-1)
674 |
675 |
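def _example_positional_encoding():
    """A minimal usage sketch (an illustration, not part of the original code):
    NeRF-style encoding of 2-D coordinates. With include_input=True the output
    width is input_dim * (2 * num_encoding_functions + 1)."""
    pe = PositionalEncoding(num_encoding_functions=6, input_dim=2)
    coords = torch.rand(8, 2)
    return pe(coords)  # -> (8, 2 * (2 * 6 + 1)) = (8, 26)
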
676 |
677 | class FCBlock(nn.Module):
678 | '''A fully connected neural network that also allows swapping out the weights when used with a hypernetwork.
679 | Can be used just as a normal neural network though, as well.
680 | '''
681 | def __init__(self, in_features, out_features,
682 | num_hidden_layers, hidden_features,
683 | outermost_linear=False, outmost_nonlinearity=None, nonlinearity='relu',
684 | weight_init=None, w0=60, set_bias=None,
685 | dropout=0.0, layer_norm=False,latent_dim=64,skip_connect=None):
686 | super().__init__()
687 |
688 | self.skip_connect = skip_connect
689 | self.latent_dim = latent_dim
690 | self.first_layer_init = None
691 | self.dropout = dropout
692 |
693 |         if outmost_nonlinearity is None:
694 | outmost_nonlinearity = nonlinearity
695 |
696 | # Create hidden features list
697 | if not isinstance(hidden_features, list):
698 | num_hidden_features = hidden_features
699 | hidden_features = []
700 | for i in range(num_hidden_layers+1):
701 | hidden_features.append(num_hidden_features)
702 | else:
703 | num_hidden_layers = len(hidden_features)-1
704 | #print(f"net_size={hidden_features}")
705 |
706 | # Create the net
707 | #print(f"num_layers={len(hidden_features)}")
708 | if isinstance(nonlinearity, list):
709 | print(f"num_non_lin={len(nonlinearity)}")
710 | assert len(hidden_features) == len(nonlinearity), "Num hidden layers needs to " "match the length of the list of non-linearities"
711 |
712 | self.net = []
713 | self.net.append(nn.Sequential(
714 | nn.Linear(in_features, hidden_features[0]),
715 | layer_factory(nonlinearity[0])[0]
716 | ))
717 | for i in range(num_hidden_layers):
718 |             if self.skip_connect is None:
719 | self.net.append(nn.Sequential(
720 | nn.Linear(hidden_features[i], hidden_features[i+1]),
721 | layer_factory(nonlinearity[i+1])[0]
722 | ))
723 | else:
724 | if i+1 in self.skip_connect:
725 | self.net.append(nn.Sequential(
726 | nn.Linear(hidden_features[i]+self.latent_dim, hidden_features[i+1]),
727 | layer_factory(nonlinearity[i+1])[0]
728 | ))
729 | else:
730 | self.net.append(nn.Sequential(
731 | nn.Linear(hidden_features[i], hidden_features[i+1]),
732 | layer_factory(nonlinearity[i+1])[0]
733 | ))
734 |
735 | if outermost_linear:
736 | self.net.append(nn.Sequential(
737 | nn.Linear(hidden_features[-1], out_features),
738 | ))
739 | else:
740 | self.net.append(nn.Sequential(
741 | nn.Linear(hidden_features[-1], out_features),
742 | layer_factory(nonlinearity[-1])[0]
743 | ))
744 | elif isinstance(nonlinearity, str):
745 | nl, weight_init = layer_factory(nonlinearity)
746 | outmost_nl, _ = layer_factory(outmost_nonlinearity)
747 | if(nonlinearity == 'sine'):
748 | first_nl = FirstSine()
749 | self.first_layer_init = first_layer_sine_init
750 | else:
751 | first_nl = nl
752 |
753 | if weight_init is not None:
754 | self.weight_init = weight_init
755 |
756 | self.net = []
757 | self.net.append(nn.Sequential(
758 | nn.Linear(in_features, hidden_features[0]),
759 | first_nl
760 | ))
761 |
762 | for i in range(num_hidden_layers):
763 | if(self.dropout > 0):
764 | self.net.append(nn.Dropout(self.dropout))
765 |             if self.skip_connect is None:
766 | self.net.append(nn.Sequential(
767 | nn.Linear(hidden_features[i], hidden_features[i+1]),
768 | copy.deepcopy(nl)
769 | ))
770 | else:
771 | if i+1 in self.skip_connect:
772 | self.net.append(nn.Sequential(
773 | nn.Linear(hidden_features[i]+self.latent_dim, hidden_features[i+1]),
774 | copy.deepcopy(nl)
775 | ))
776 | else:
777 | self.net.append(nn.Sequential(
778 | nn.Linear(hidden_features[i], hidden_features[i+1]),
779 | copy.deepcopy(nl)
780 | ))
781 |
782 | if (self.dropout > 0):
783 | self.net.append(nn.Dropout(self.dropout))
784 | if outermost_linear:
785 | self.net.append(nn.Sequential(
786 | nn.Linear(hidden_features[-1], out_features),
787 | ))
788 | else:
789 | self.net.append(nn.Sequential(
790 | nn.Linear(hidden_features[-1], out_features),
791 | copy.deepcopy(outmost_nl)
792 | ))
793 | if layer_norm:
794 | self.net.append(nn.LayerNorm([out_features]))
795 |
796 | self.net = nn.Sequential(*self.net)
797 |
798 | if isinstance(nonlinearity, list):
799 | for layer_num, layer_name in enumerate(nonlinearity):
800 | self.net[layer_num].apply(layer_factory(layer_name)[1])
801 | elif isinstance(nonlinearity, str):
802 | if self.weight_init is not None:
803 | self.net.apply(self.weight_init)
804 |
805 | if self.first_layer_init is not None:
806 | self.net[0].apply(self.first_layer_init)
807 |
808 | if set_bias is not None:
809 | self.net[-1][0].bias.data = set_bias * torch.ones_like(self.net[-1][0].bias.data)
810 |
811 | def forward(self, coords, batch_vecs=None):
812 |         if self.skip_connect is None:
813 | output = self.net(coords)
814 | else:
815 | input = coords
816 | for i in range(len(self.net)):
817 | output = self.net[i](input)
818 | if i+1 in self.skip_connect:
819 | input = torch.cat([batch_vecs, output], dim=-1)
820 | else:
821 | input = output
822 | return output
823 |
824 |
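def _example_fcblock():
    """A minimal usage sketch (an illustration, not part of the original code):
    a plain ReLU MLP, with no hypernetwork parameters and no skip connections."""
    mlp = FCBlock(in_features=2, out_features=1, num_hidden_layers=3,
                  hidden_features=256, outermost_linear=True, nonlinearity='relu')
    return mlp(torch.rand(4, 2))  # -> (4, 1)

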
825 | class CoordinateNet_autodecoder(nn.Module):
826 |     '''An auto-decoder network.'''
827 | def __init__(self, latent_size=64, out_features=1, nl='sine', in_features=64+2,
828 | hidden_features=256, num_hidden_layers=3, num_pe_fns=6,
829 | w0=60,use_pe=False,skip_connect=None,dataset_size=100,
830 | outmost_nonlinearity=None,outermost_linear=True):
831 | super().__init__()
832 |
833 | self.nl = nl
834 | self.use_pe = use_pe
835 | self.latent_size = latent_size
836 | self.lat_vecs = torch.nn.Embedding(dataset_size, self.latent_size)
837 | torch.nn.init.normal_(self.lat_vecs.weight.data, 0.0, 1/ math.sqrt(self.latent_size))
838 |
839 | if self.nl != 'sine' and use_pe:
840 | in_features = 2 * (2*num_pe_fns + 1)+latent_size
841 |
842 | if self.use_pe:
843 | self.pe = PositionalEncoding(num_encoding_functions=num_pe_fns)
844 | self.decoder = FCBlock(in_features=in_features,
845 | out_features=out_features,
846 | num_hidden_layers=num_hidden_layers,
847 | hidden_features=hidden_features,
848 | outermost_linear=outermost_linear,
849 | nonlinearity=nl,
850 | w0=w0,skip_connect=skip_connect,latent_dim=latent_size,outmost_nonlinearity=outmost_nonlinearity)
851 | self.mean = torch.mean(torch.mean(self.lat_vecs.weight.data.detach(), dim=1)).cuda()
852 | self.varience = torch.mean(torch.var(self.lat_vecs.weight.data.detach(), dim=1)).cuda()
853 |
854 |
855 | def forward(self, model_input,latent=None):
856 | coords = model_input['coords'].clone().detach().requires_grad_(True)
857 |         if latent is None:
858 | batch_vecs = self.lat_vecs(model_input['idx']).unsqueeze(1).repeat(1,coords.shape[1],1)
859 | else:
860 | batch_vecs = latent.unsqueeze(1).repeat(1,coords.shape[1],1)
861 |
862 | if self.nl != 'sine' and self.use_pe:
863 | coords_pe = self.pe(coords)
864 | input = torch.cat([batch_vecs, coords_pe], dim=-1)
865 | output = self.decoder(input,batch_vecs)
866 | else:
867 | input = torch.cat([batch_vecs, coords], dim=-1)
868 | output = self.decoder(input,batch_vecs)
869 |
870 | return {'model_in': coords, 'model_out': output,'batch_vecs': batch_vecs, 'meta': model_input}
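

def _example_autodecoder():
    """A minimal usage sketch (an illustration, not part of the original code).
    Note that __init__ moves the latent statistics to GPU with .cuda(), so CUDA
    is assumed; 'idx' selects each sample's latent code from the embedding."""
    model = CoordinateNet_autodecoder(latent_size=64, in_features=64 + 2,
                                      dataset_size=100).cuda()
    batch = {"coords": torch.rand(4, 128, 2).cuda(),
             "idx": torch.randint(100, (4,)).cuda()}
    return model(batch)["model_out"]  # -> (4, 128, 1)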
--------------------------------------------------------------------------------