├── .gitignore ├── LICENSE ├── README.md ├── icon-lm ├── README.md ├── analysis │ ├── analysis.py │ ├── analysis_partial.sh │ ├── analysis_seeds.sh │ ├── analysis_weno.sh │ └── analysis_weno_aug.py ├── attention_module.py ├── config_data │ ├── test_icon_config.json │ ├── test_lm_config.json │ ├── test_lm_pde_full_config.json │ ├── test_lm_plot_mfc_config.json │ ├── test_lm_precise_config.json │ ├── test_lm_vague_config.json │ ├── test_lm_weno_config.json │ ├── train_icon_config.json │ ├── train_lm_config.json │ ├── train_lm_pde_full_config.json │ ├── train_lm_precise_config.json │ ├── train_lm_vague_config.json │ └── train_lm_weno_config.json ├── config_model │ ├── model_deepo_pde_config.json │ ├── model_deepo_weno_config.json │ ├── model_fno_pde_config.json │ ├── model_fno_weno_config.json │ ├── model_gpt2_config.json │ ├── model_icon_config.json │ └── model_lm_config.json ├── data_preparation │ ├── captions_1009 │ │ ├── mfc_gparam.md │ │ ├── mfc_rhoparam.md │ │ ├── ode1.md │ │ ├── ode2.md │ │ ├── ode3.md │ │ ├── pde1.md │ │ ├── pde2.md │ │ ├── pde3.md │ │ ├── resolve.py │ │ ├── series.md │ │ └── suffix.json │ ├── data_dynamics.py │ ├── data_io.py │ ├── data_mfc_hj.py │ ├── data_mfc_pdhg.py │ ├── data_pdes.py │ ├── data_series.py │ ├── data_utils.py │ ├── data_writetfrecord.py │ ├── datagen.py │ ├── datagen.sh │ ├── datagen_weno.py │ ├── datagen_weno.sh │ ├── datagen_weno_fix.sh │ └── weno │ │ ├── utils.py │ │ ├── weno_3_coeff.py │ │ ├── weno_roll.py │ │ ├── weno_scheme.py │ │ ├── weno_solver.py │ │ ├── weno_test_acc.py │ │ └── weno_test_disc.py ├── data_sequence.py ├── dataloader.py ├── dataloader_realtime.py ├── environment.yml ├── models_deepo.py ├── models_fno.py ├── models_gpt2_check.py ├── models_gpt2_icon.py ├── models_gpt2_source.py ├── models_icon.py ├── models_lm.py ├── models_utils.py ├── models_utils_pytorch.py ├── operator │ ├── finetune.py │ ├── tune_deepo.sh │ └── tune_fno.sh ├── operator_weno │ ├── analysis.py │ ├── analysis_plot.py │ ├── 
analysis_time.py │ ├── datagen.py │ └── datagen_time.py ├── plot.py ├── plot_icon_lm │ ├── plot_benchmark.py │ ├── plot_cap.py │ ├── plot_finetune.py │ ├── plot_profile.py │ ├── plot_scratch.py │ ├── plot_separate.py │ └── plot_utils.py ├── plot_icon_weno │ ├── cubic_approx.py │ ├── cubic_approx_nochange.py │ ├── plot_weno_cubic.py │ ├── plot_weno_cubic_decay.py │ ├── plot_weno_cubic_profile.py │ ├── plot_weno_new.py │ ├── plot_weno_new_cubicfit.py │ ├── plot_weno_new_profile.py │ ├── plot_weno_new_profile_split.py │ ├── plot_weno_new_time.py │ └── plot_weno_new_variable.py ├── plot_mask_lm.py ├── run.py ├── run.sh ├── run_weno.sh ├── runner_deepo_torch.py ├── runner_jax.py ├── runner_torch.py ├── transformer_flax.py ├── transformer_hk.py └── utils.py └── icon ├── Dockerfile ├── README.md ├── analysis ├── analysis_accelerate.py ├── analysis_ind.sh ├── analysis_len.sh ├── analysis_ood.sh ├── plot_analysis_ind.ipynb ├── plot_analysis_len.ipynb ├── plot_analysis_nt.py ├── plot_analysis_ood.py ├── plot_analysis_plot_1d.py ├── plot_analysis_plot_2d.py ├── plot_analysis_utils.py ├── test_config.json ├── test_config_len.json └── test_config_ood.json ├── data_generation ├── data_dynamics.py ├── data_mfc_hj.py ├── data_pdes.py ├── data_series.py ├── datagen.py ├── datagen.sh ├── datagen_ood.py ├── datagen_ood.sh └── datawrite_tfrecord.py ├── data_sequence.py ├── dataloader.py ├── models.py ├── plot.py ├── run.py ├── run_group.sh ├── train_config.json ├── transformer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/.DS_Store 2 | **/*.pyc 3 | useless/* 4 | **/results*/* 5 | **/*log 6 | **/check_points*/* 7 | **/*pickle 8 | **/*tfrecord 9 | **/*pdf 10 | **/*png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Liu Yang 4 | 5 | Permission is hereby 
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /icon-lm/README.md: -------------------------------------------------------------------------------- 1 | # ICON-LM for Multi-Modal In-Context Operator Learning, and Conservation Laws 2 | 3 | This folder contains the code associated with the following two papers: 4 | - [Fine-Tune Language Models as Multi-Modal Differential Equation Solvers](https://arxiv.org/pdf/2308.05061.pdf). This paper focuses on improving the model architecture and training scheme for in-context operator learning. 5 | 6 | - [PDE Generalization of In-Context Operator Networks: A Study on 1D Scalar Nonlinear Conservation Laws](https://arxiv.org/pdf/2401.07364.pdf). This paper focuses on solving PDE problems with in-context operator learning, especially its generalization capability to new PDEs. 
7 | 8 | ## Environment Setup 9 | 10 | The YAML file `environment.yml` contains the environment setup for the code. To create the environment, run the following command: 11 | ``` 12 | conda env create -f environment.yml 13 | ``` 14 | which will create an environment named `icon`. You may need to update the `prefix` field in the YAML file to specify the location of the environment. 15 | 16 | ## Multi-Modal In-Context Operator Learning 17 | 18 | Fine-Tune Language Models as Multi-Modal Differential Equation Solvers 19 | 20 | ### Data Preparation 21 | 22 | The captions are already stored in the `data_preparation/captions_1009` folder. The code for function data generation is located in the `data_preparation/` folder. Navigate to the `data_preparation/` folder and run `bash datagen.sh`, which will generate the function data for the experiments. 23 | 24 | The generated data will be stored in the `data_preparation/data` folder. We moved data to `/home/shared/icon/data/data0910c` for our experiments. 25 | 26 | ### Training 27 | 28 | All the in-context operator learning models shown in the paper are supported here, including (1) encoder-decoder ICON (baseline), (2) ICON-LM (ours), and (3) fine-tuning GPT-2. The run commands for training each model are listed in `run.sh`. 29 | 30 | ### Analysis and Visualization 31 | 32 | The analysis code and run commands are located in the `analysis/` folder. The visualization code is located in the `plot_icon_lm/` folder. 33 | 34 | ### Classic Operator Learning Models 35 | 36 | In the paper we compared ICON-LM with classic operator learning models, including FNO and DeepONet. The run commands for pretraining FNO and DeepONet are listed in `run.sh`. The code and commands for fine-tuning FNO and DeepONet are located in the `operator` folder. 
37 | 38 | 39 | ## Conservation Laws 40 | PDE Generalization of In-Context Operator Networks: A Study on 1D Scalar Nonlinear Conservation Laws 41 | 42 | ### Data Preparation 43 | 44 | Navigate to the `data_preparation/` folder and run `bash datagen_weno.sh` to generate the data for training ICON-LM. Run `bash datagen_weno_fix.sh` to generate the data for pretraining classic operator learning methods. 45 | 46 | ### Training 47 | 48 | The run commands for training the ICON-LM models and pretraining classic operator learning models are listed in `run_weno.sh`. 49 | 50 | ### Analysis and Visualization 51 | 52 | The analysis code is located in the `analysis/` folder, with run commands in `analysis/analysis_weno.sh`. The visualization code is located in the `plot_icon_weno/` folder. 53 | 54 | ### Classic Operator Learning Models 55 | 56 | The run commands for pretraining FNO and DeepONet are listed in `run_weno.sh`. The code and commands for fine-tuning FNO and DeepONet are located in the `operator_weno` folder. You need to first generate fine-tuning data using `operator_weno/datagen.py` before fine-tuning. 
57 | -------------------------------------------------------------------------------- /icon-lm/analysis/analysis_partial.sh: -------------------------------------------------------------------------------- 1 | # This sh file focuses on Problem #14 the inverse nonlinear reaction-diffusion PDE problem 2 | # It will restore the trained model from restore_dir, run analysis on the test set of Problem #14, and store results in the analysis directory 3 | 4 | 5 | stamp='20231005-094726' 6 | analysis_dir='/home/shared/icon/analysis/icon_lm_learn_'$stamp'-pde3-inverse' 7 | restore_dir='/home/shared/icon/save/user/ckpts/icon_lm_learn/'$stamp 8 | 9 | 10 | CUDA_VISIBLE_DEVICES=0 python3 analysis.py --backend jax --model icon_lm --task ind --write quest,demo,equation \ 11 | --test_demo_num_list 5 --test_caption_id_list -1 --loss_mode nocap \ 12 | --test_config_filename test_lm_pde_full_config.json \ 13 | --test_data_globs 'test_pde_cubic_spatial_inverse*' \ 14 | --model_config_filename model_lm_config.json \ 15 | --analysis_dir $analysis_dir \ 16 | --restore_dir $restore_dir \ 17 | --batch_size 10 >out_analysis_icon_lm_learn-$stamp-pde3-inverse.log 2>&1 && 18 | 19 | 20 | echo "Done." 
21 | -------------------------------------------------------------------------------- /icon-lm/analysis/analysis_weno.sh: -------------------------------------------------------------------------------- 1 | stamp='20231209-222440' 2 | analysis_dir='/home/shared/icon/analysis/icon_weno_'$stamp'_light' 3 | restore_dir='/home/shared/icon/save/user/ckpts/icon_weno/'$stamp 4 | test_data_dirs='/home/shared/icon/data/data0904_weno_cubic_test_light' # use the light version for quick analysis 5 | 6 | 7 | CUDA_VISIBLE_DEVICES=0 python3 analysis.py --model 'icon_lm' --backend jax --task weno_cubic --write quest,demo \ 8 | --test_caption_id_list -1 --test_data_dirs $test_data_dirs --loss_mode nocap \ 9 | --test_config_filename 'test_lm_weno_config.json' --model_config_filename 'model_lm_config.json' \ 10 | --restore_dir $restore_dir --analysis_dir $analysis_dir --batch_size 64 \ 11 | >out_analysis_icon_weno_$stamp.light.log 2>&1 && 12 | 13 | 14 | echo "Done." 15 | -------------------------------------------------------------------------------- /icon-lm/analysis/analysis_weno_aug.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib.patches as patches 6 | import jax 7 | import jax.numpy as jnp 8 | from jax.config import config 9 | import tensorflow as tf 10 | import os 11 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 12 | tf.config.set_visible_devices([], device_type='GPU') 13 | from collections import OrderedDict 14 | 15 | import sys 16 | sys.path.append('../data_preparation/') 17 | sys.path.append('../') 18 | sys.path.append('../data_preparation/weno/') 19 | import utils 20 | from datagen_weno import generate_weno_scalar_sol 21 | 22 | 23 | def write_quadratic_consistency_error(folder, eqn_name): 24 | with open("{}/result_dict.pkl".format(folder), "rb") as file: 25 | result_dict = pickle.load(file) 26 | consistency_dict = {} 27 | for 
key in result_dict.keys(): 28 | if key[0] == eqn_name and key[3] == 'pred':  # keep only this equation's entries whose tag field (key[3]) is 'pred' 29 | print(key, flush = True) 30 | _, coeff_a, coeff_b, _, demo_num, caption_id = key 31 | fn = jax.jit(lambda u: coeff_a * u * u + coeff_b * u)  # f(u) = a*u^2 + b*u 32 | grad_fn = jax.jit(lambda u: 2 * coeff_a * u + coeff_b)  # f'(u) = 2*a*u + b, the derivative of fn 33 | forward = generate_weno_scalar_sol(dx = 0.01, dt = 0.001, init = result_dict[key], fn = fn, steps = 100, grad_fn = grad_fn)[:,-1,...]  # index -1 along axis 1 -- presumably the final solver step; TODO confirm against generate_weno_scalar_sol 34 | consistency_dict[(eqn_name, coeff_a, coeff_b, 'forward', demo_num, caption_id)] = forward  # same key layout, with 'pred' replaced by 'forward' 35 | # save consistency_dict to the same folder 36 | with open("{}/consistency_dict.pkl".format(folder), "wb") as file: 37 | pickle.dump(consistency_dict, file) 38 | 39 | 40 | def write_cubic_consistency_error(folder, eqn_name):  # cubic counterpart: reads <folder>/result_dict.pkl, passes each 'pred' entry of eqn_name to generate_weno_scalar_sol as init, writes <folder>/consistency_dict.pkl 41 | with open("{}/result_dict.pkl".format(folder), "rb") as file: 42 | result_dict = pickle.load(file)  # NOTE: unpickling can execute arbitrary code; only load result files you trust 43 | consistency_dict = {} 44 | for key in result_dict.keys(): 45 | if key[0] == eqn_name and key[4] == 'pred':  # cubic keys carry three coefficients, so the tag field is key[4] 46 | print(key, flush = True) 47 | _, coeff_a, coeff_b, coeff_c, _, demo_num, caption_id = key 48 | fn = jax.jit(lambda u: coeff_a * u * u * u + coeff_b * u * u + coeff_c * u)  # f(u) = a*u^3 + b*u^2 + c*u 49 | grad_fn = jax.jit(lambda u: 3 * coeff_a * u * u + 2 * coeff_b * u + coeff_c)  # f'(u) = 3*a*u^2 + 2*b*u + c, the derivative of fn 50 | forward = generate_weno_scalar_sol(dx = 0.01, dt = 0.0005, init = result_dict[key], fn = fn, steps = 200, grad_fn = grad_fn)[:,-1,...]  # index -1 along axis 1 -- presumably the final solver step; TODO confirm against generate_weno_scalar_sol 
51 | consistency_dict[(eqn_name, coeff_a, coeff_b, coeff_c, 'forward', demo_num, caption_id)] = forward 52 | # save consistency_dict to the same folder 53 | with open("{}/consistency_dict.pkl".format(folder), "wb") as file: 54 | pickle.dump(consistency_dict, file) 55 | 56 | def write_sin_consistency_error(folder, eqn_name): 57 | with open("{}/result_dict.pkl".format(folder), "rb") as file: 58 | result_dict = pickle.load(file) 59 | consistency_dict = {} 60 | for key in result_dict.keys(): 61 | if key[0] == eqn_name and key[5] == 'pred': 62 | print(key, flush = True) 63 | _, coeff_a, coeff_b, coeff_c, stride, _, demo_num, caption_id = key 64 | fn = jax.jit(lambda u: coeff_a * jnp.sin(coeff_c * u) + coeff_b * jnp.cos(coeff_c * u)) 65 | grad_fn = jax.jit(lambda u: coeff_a * coeff_c + jnp.cos(coeff_c * u) - coeff_b * coeff_c + jnp.sin(coeff_c * u)) 66 | forward = generate_weno_scalar_sol(dx = 0.01, dt = 0.0005, init = result_dict[key], fn = fn, steps = int(stride), grad_fn = grad_fn)[:,-1,...] 
67 | consistency_dict[(eqn_name, coeff_a, coeff_b, coeff_c, stride, 'forward', demo_num, caption_id)] = forward  # same key layout, with 'pred' replaced by 'forward' 68 | # save consistency_dict to the same folder 69 | with open("{}/consistency_dict.pkl".format(folder), "wb") as file: 70 | pickle.dump(consistency_dict, file) 71 | 72 | if __name__ == "__main__": 73 | 74 | # folder = "/home/shared/icon/analysis/icon_weno_20230829-170831_light" 75 | # eqn_name = 'conservation_weno_quadratic_backward' 76 | # write_quadratic_consistency_error(folder, eqn_name) 77 | 78 | folder = "/home/shared/icon/analysis/icon_weno_20230904-184910_sin" 79 | eqn_name = 'conservation_weno_sin_backward' 80 | write_sin_consistency_error(folder, eqn_name) 81 | -------------------------------------------------------------------------------- /icon-lm/config_data/test_icon_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "random", 13 | "return_raw": false, 14 | "load_list": [], 15 | "pde_mask":{ 16 | "demo_num_begin": 5, 17 | "demo_num_end": 6, 18 | "start_ind_begin": 5, 19 | "start_ind_end": 45, 20 | "demo_cond_len_in_use_begin": 41, 21 | "demo_cond_len_in_use_end": 51, 22 | "demo_qoi_len_in_use_begin": 41, 23 | "demo_qoi_len_in_use_end": 51, 24 | "quest_cond_len_in_use_begin": 41, 25 | "quest_cond_len_in_use_end": 51, 26 | "quest_qoi_len_in_use_begin": 41, 27 | "quest_qoi_len_in_use_end": 51}, 28 | "pde_spatial_forward":{ 29 | "demo_num_begin": 5, 30 | "demo_num_end": 6, 31 | "demo_cond_select": "random", 32 | "demo_qoi_select": "random", 33 | "quest_cond_select": "random", 34 | "quest_qoi_select": "random", 35 | "demo_cond_len_in_use_begin": 41, 36 | "demo_cond_len_in_use_end": 51, 37 | "demo_qoi_len_in_use_begin": 41, 38 | "demo_qoi_len_in_use_end": 51, 
39 | "quest_cond_len_in_use_begin": 41, 40 | "quest_cond_len_in_use_end": 51, 41 | "quest_qoi_len_in_use_begin": 41, 42 | "quest_qoi_len_in_use_end": 51}, 43 | "pde_spatial_inverse":{ 44 | "demo_num_begin": 5, 45 | "demo_num_end": 6, 46 | "demo_cond_select": "random", 47 | "demo_qoi_select": "random", 48 | "quest_cond_select": "random", 49 | "quest_qoi_select": "random", 50 | "demo_cond_len_in_use_begin": 41, 51 | "demo_cond_len_in_use_end": 51, 52 | "demo_qoi_len_in_use_begin": 41, 53 | "demo_qoi_len_in_use_end": 51, 54 | "quest_cond_len_in_use_begin": 41, 55 | "quest_cond_len_in_use_end": 51, 56 | "quest_qoi_len_in_use_begin": 41, 57 | "quest_qoi_len_in_use_end": 51}, 58 | "ode_forward":{ 59 | "demo_num_begin": 5, 60 | "demo_num_end": 6, 61 | "demo_cond_select": "even", 62 | "demo_qoi_select": "even", 63 | "quest_cond_select": "even", 64 | "quest_qoi_select": "even", 65 | "demo_qoi_len_in_use_begin": 41, 66 | "demo_qoi_len_in_use_end": 51, 67 | "quest_qoi_len_in_use_begin": 41, 68 | "quest_qoi_len_in_use_end": 51}, 69 | "ode_inverse":{ 70 | "demo_num_begin": 5, 71 | "demo_num_end": 6, 72 | "demo_cond_select": "even", 73 | "demo_qoi_select": "even", 74 | "quest_cond_select": "even", 75 | "quest_qoi_select": "even", 76 | "demo_qoi_len_in_use_begin": 40, 77 | "demo_qoi_len_in_use_end": 50, 78 | "quest_qoi_len_in_use_begin": 40, 79 | "quest_qoi_len_in_use_end": 50}, 80 | "time_series":{ 81 | "demo_num_begin": 5, 82 | "demo_num_end": 6, 83 | "demo_cond_select": "random", 84 | "demo_qoi_select": "random", 85 | "quest_cond_select": "random", 86 | "quest_qoi_select": "random", 87 | "demo_cond_len_in_use_begin": 41, 88 | "demo_cond_len_in_use_end": 51, 89 | "demo_qoi_len_in_use_begin": 41, 90 | "demo_qoi_len_in_use_end": 51, 91 | "quest_cond_len_in_use_begin": 41, 92 | "quest_cond_len_in_use_end": 51, 93 | "quest_qoi_len_in_use_begin": 41, 94 | "quest_qoi_len_in_use_end": 51}, 95 | "mfc_gparam_forward":{ 96 | "demo_num_begin": 5, 97 | "demo_num_end": 6, 98 | 
"demo_cond_select": "random", 99 | "demo_qoi_select": "random", 100 | "quest_cond_select": "random", 101 | "quest_qoi_select": "random", 102 | "demo_cond_len_in_use_begin": 41, 103 | "demo_cond_len_in_use_end": 51, 104 | "demo_qoi_len_in_use_begin": 41, 105 | "demo_qoi_len_in_use_end": 51, 106 | "quest_cond_len_in_use_begin": 41, 107 | "quest_cond_len_in_use_end": 51, 108 | "quest_qoi_len_in_use_begin": 41, 109 | "quest_qoi_len_in_use_end": 51}, 110 | "mfc_rhoparam_forward":{ 111 | "demo_num_begin": 5, 112 | "demo_num_end": 6, 113 | "demo_cond_select": "random", 114 | "demo_qoi_select": "random", 115 | "quest_cond_select": "random", 116 | "quest_qoi_select": "random", 117 | "demo_cond_len_in_use_begin": 41, 118 | "demo_cond_len_in_use_end": 51, 119 | "demo_qoi_len_in_use_begin": 41, 120 | "demo_qoi_len_in_use_end": 51, 121 | "quest_cond_len_in_use_begin": 41, 122 | "quest_cond_len_in_use_end": 51, 123 | "quest_qoi_len_in_use_begin": 41, 124 | "quest_qoi_len_in_use_end": 51}, 125 | "others":{ 126 | "demo_num_begin": 5, 127 | "demo_num_end": 6, 128 | "demo_cond_select": "random", 129 | "demo_qoi_select": "random", 130 | "quest_cond_select": "random", 131 | "quest_qoi_select": "random", 132 | "demo_cond_len_in_use_begin": 41, 133 | "demo_cond_len_in_use_end": 51, 134 | "demo_qoi_len_in_use_begin": 41, 135 | "demo_qoi_len_in_use_end": 51, 136 | "quest_cond_len_in_use_begin": 41, 137 | "quest_cond_len_in_use_end": 51, 138 | "quest_qoi_len_in_use_begin": 41, 139 | "quest_qoi_len_in_use_end": 51} 140 | } -------------------------------------------------------------------------------- /icon-lm/config_data/test_lm_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "random", 13 
| "load_caption": {"0":"list(range(80))", "1":"list(range(80))"}, 14 | "caption_dir": "captions_1009", 15 | "return_raw": false, 16 | "load_list": ["input_id"], 17 | "pde_mask":{ 18 | "demo_num_begin": 5, 19 | "demo_num_end": 6, 20 | "start_ind_begin": 5, 21 | "start_ind_end": 45, 22 | "demo_cond_len_in_use_begin": 41, 23 | "demo_cond_len_in_use_end": 51, 24 | "demo_qoi_len_in_use_begin": 41, 25 | "demo_qoi_len_in_use_end": 51, 26 | "quest_cond_len_in_use_begin": 41, 27 | "quest_cond_len_in_use_end": 51, 28 | "quest_qoi_len_in_use_begin": 41, 29 | "quest_qoi_len_in_use_end": 51}, 30 | "pde_spatial_forward":{ 31 | "demo_num_begin": 5, 32 | "demo_num_end": 6, 33 | "demo_cond_select": "random", 34 | "demo_qoi_select": "random", 35 | "quest_cond_select": "random", 36 | "quest_qoi_select": "random", 37 | "demo_cond_len_in_use_begin": 41, 38 | "demo_cond_len_in_use_end": 51, 39 | "demo_qoi_len_in_use_begin": 41, 40 | "demo_qoi_len_in_use_end": 51, 41 | "quest_cond_len_in_use_begin": 41, 42 | "quest_cond_len_in_use_end": 51, 43 | "quest_qoi_len_in_use_begin": 41, 44 | "quest_qoi_len_in_use_end": 51}, 45 | "pde_spatial_inverse":{ 46 | "demo_num_begin": 5, 47 | "demo_num_end": 6, 48 | "demo_cond_select": "random", 49 | "demo_qoi_select": "random", 50 | "quest_cond_select": "random", 51 | "quest_qoi_select": "random", 52 | "demo_cond_len_in_use_begin": 41, 53 | "demo_cond_len_in_use_end": 51, 54 | "demo_qoi_len_in_use_begin": 41, 55 | "demo_qoi_len_in_use_end": 51, 56 | "quest_cond_len_in_use_begin": 41, 57 | "quest_cond_len_in_use_end": 51, 58 | "quest_qoi_len_in_use_begin": 41, 59 | "quest_qoi_len_in_use_end": 51}, 60 | "ode_forward":{ 61 | "demo_num_begin": 5, 62 | "demo_num_end": 6, 63 | "demo_cond_select": "even", 64 | "demo_qoi_select": "even", 65 | "quest_cond_select": "even", 66 | "quest_qoi_select": "even", 67 | "demo_qoi_len_in_use_begin": 41, 68 | "demo_qoi_len_in_use_end": 51, 69 | "quest_qoi_len_in_use_begin": 41, 70 | "quest_qoi_len_in_use_end": 51}, 71 | 
"ode_inverse":{ 72 | "demo_num_begin": 5, 73 | "demo_num_end": 6, 74 | "demo_cond_select": "even", 75 | "demo_qoi_select": "even", 76 | "quest_cond_select": "even", 77 | "quest_qoi_select": "even", 78 | "demo_qoi_len_in_use_begin": 40, 79 | "demo_qoi_len_in_use_end": 50, 80 | "quest_qoi_len_in_use_begin": 40, 81 | "quest_qoi_len_in_use_end": 50}, 82 | "time_series":{ 83 | "demo_num_begin": 5, 84 | "demo_num_end": 6, 85 | "demo_cond_select": "random", 86 | "demo_qoi_select": "random", 87 | "quest_cond_select": "random", 88 | "quest_qoi_select": "random", 89 | "demo_cond_len_in_use_begin": 41, 90 | "demo_cond_len_in_use_end": 51, 91 | "demo_qoi_len_in_use_begin": 41, 92 | "demo_qoi_len_in_use_end": 51, 93 | "quest_cond_len_in_use_begin": 41, 94 | "quest_cond_len_in_use_end": 51, 95 | "quest_qoi_len_in_use_begin": 41, 96 | "quest_qoi_len_in_use_end": 51}, 97 | "mfc_gparam_forward":{ 98 | "demo_num_begin": 5, 99 | "demo_num_end": 6, 100 | "demo_cond_select": "random", 101 | "demo_qoi_select": "random", 102 | "quest_cond_select": "random", 103 | "quest_qoi_select": "random", 104 | "demo_cond_len_in_use_begin": 41, 105 | "demo_cond_len_in_use_end": 51, 106 | "demo_qoi_len_in_use_begin": 41, 107 | "demo_qoi_len_in_use_end": 51, 108 | "quest_cond_len_in_use_begin": 41, 109 | "quest_cond_len_in_use_end": 51, 110 | "quest_qoi_len_in_use_begin": 41, 111 | "quest_qoi_len_in_use_end": 51}, 112 | "mfc_rhoparam_forward":{ 113 | "demo_num_begin": 5, 114 | "demo_num_end": 6, 115 | "demo_cond_select": "random", 116 | "demo_qoi_select": "random", 117 | "quest_cond_select": "random", 118 | "quest_qoi_select": "random", 119 | "demo_cond_len_in_use_begin": 41, 120 | "demo_cond_len_in_use_end": 51, 121 | "demo_qoi_len_in_use_begin": 41, 122 | "demo_qoi_len_in_use_end": 51, 123 | "quest_cond_len_in_use_begin": 41, 124 | "quest_cond_len_in_use_end": 51, 125 | "quest_qoi_len_in_use_begin": 41, 126 | "quest_qoi_len_in_use_end": 51}, 127 | "others":{ 128 | "demo_num_begin": 5, 129 | 
"demo_num_end": 6, 130 | "demo_cond_select": "random", 131 | "demo_qoi_select": "random", 132 | "quest_cond_select": "random", 133 | "quest_qoi_select": "random", 134 | "demo_cond_len_in_use_begin": 41, 135 | "demo_cond_len_in_use_end": 51, 136 | "demo_qoi_len_in_use_begin": 41, 137 | "demo_qoi_len_in_use_end": 51, 138 | "quest_cond_len_in_use_begin": 41, 139 | "quest_cond_len_in_use_end": 51, 140 | "quest_qoi_len_in_use_begin": 41, 141 | "quest_qoi_len_in_use_end": 51} 142 | } -------------------------------------------------------------------------------- /icon-lm/config_data/test_lm_plot_mfc_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 2600, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "random", 13 | "load_caption": {"0":"list(range(80))", "1":"list(range(80))"}, 14 | "caption_dir": "captions_1009", 15 | "return_raw": true, 16 | "load_list": ["input_id"], 17 | "pde_mask":{ 18 | "demo_num_begin": 5, 19 | "demo_num_end": 6, 20 | "start_ind_begin": 5, 21 | "start_ind_end": 45, 22 | "demo_cond_len_in_use_begin": 41, 23 | "demo_cond_len_in_use_end": 51, 24 | "demo_qoi_len_in_use_begin": 41, 25 | "demo_qoi_len_in_use_end": 51, 26 | "quest_cond_len_in_use_begin": 41, 27 | "quest_cond_len_in_use_end": 51, 28 | "quest_qoi_len_in_use_begin": 41, 29 | "quest_qoi_len_in_use_end": 51}, 30 | "pde_spatial_forward":{ 31 | "demo_num_begin": 5, 32 | "demo_num_end": 6, 33 | "demo_cond_select": "random", 34 | "demo_qoi_select": "random", 35 | "quest_cond_select": "random", 36 | "quest_qoi_select": "random", 37 | "demo_cond_len_in_use_begin": 41, 38 | "demo_cond_len_in_use_end": 51, 39 | "demo_qoi_len_in_use_begin": 41, 40 | "demo_qoi_len_in_use_end": 51, 41 | "quest_cond_len_in_use_begin": 41, 42 | "quest_cond_len_in_use_end": 
51, 43 | "quest_qoi_len_in_use_begin": 41, 44 | "quest_qoi_len_in_use_end": 51}, 45 | "pde_spatial_inverse":{ 46 | "demo_num_begin": 5, 47 | "demo_num_end": 6, 48 | "demo_cond_select": "random", 49 | "demo_qoi_select": "random", 50 | "quest_cond_select": "random", 51 | "quest_qoi_select": "random", 52 | "demo_cond_len_in_use_begin": 41, 53 | "demo_cond_len_in_use_end": 51, 54 | "demo_qoi_len_in_use_begin": 41, 55 | "demo_qoi_len_in_use_end": 51, 56 | "quest_cond_len_in_use_begin": 41, 57 | "quest_cond_len_in_use_end": 51, 58 | "quest_qoi_len_in_use_begin": 41, 59 | "quest_qoi_len_in_use_end": 51}, 60 | "ode_forward":{ 61 | "demo_num_begin": 5, 62 | "demo_num_end": 6, 63 | "demo_cond_select": "even", 64 | "demo_qoi_select": "even", 65 | "quest_cond_select": "even", 66 | "quest_qoi_select": "even", 67 | "demo_qoi_len_in_use_begin": 41, 68 | "demo_qoi_len_in_use_end": 51, 69 | "quest_qoi_len_in_use_begin": 41, 70 | "quest_qoi_len_in_use_end": 51}, 71 | "ode_inverse":{ 72 | "demo_num_begin": 5, 73 | "demo_num_end": 6, 74 | "demo_cond_select": "even", 75 | "demo_qoi_select": "even", 76 | "quest_cond_select": "even", 77 | "quest_qoi_select": "even", 78 | "demo_qoi_len_in_use_begin": 40, 79 | "demo_qoi_len_in_use_end": 50, 80 | "quest_qoi_len_in_use_begin": 40, 81 | "quest_qoi_len_in_use_end": 50}, 82 | "time_series":{ 83 | "demo_num_begin": 5, 84 | "demo_num_end": 6, 85 | "demo_cond_select": "random", 86 | "demo_qoi_select": "random", 87 | "quest_cond_select": "random", 88 | "quest_qoi_select": "random", 89 | "demo_cond_len_in_use_begin": 41, 90 | "demo_cond_len_in_use_end": 51, 91 | "demo_qoi_len_in_use_begin": 41, 92 | "demo_qoi_len_in_use_end": 51, 93 | "quest_cond_len_in_use_begin": 41, 94 | "quest_cond_len_in_use_end": 51, 95 | "quest_qoi_len_in_use_begin": 41, 96 | "quest_qoi_len_in_use_end": 51}, 97 | "mfc_gparam_forward":{ 98 | "demo_num_begin": 5, 99 | "demo_num_end": 6, 100 | "demo_cond_select": "random", 101 | "demo_qoi_select": "random", 102 | 
"quest_cond_select": "random", 103 | "quest_qoi_select": "first", 104 | "demo_cond_len_in_use_begin": 41, 105 | "demo_cond_len_in_use_end": 51, 106 | "demo_qoi_len_in_use_begin": 41, 107 | "demo_qoi_len_in_use_end": 51, 108 | "quest_cond_len_in_use_begin": 41, 109 | "quest_cond_len_in_use_end": 51, 110 | "quest_qoi_len_in_use_begin": 2600, 111 | "quest_qoi_len_in_use_end": 2601}, 112 | "mfc_rhoparam_forward":{ 113 | "demo_num_begin": 5, 114 | "demo_num_end": 6, 115 | "demo_cond_select": "random", 116 | "demo_qoi_select": "random", 117 | "quest_cond_select": "random", 118 | "quest_qoi_select": "random", 119 | "demo_cond_len_in_use_begin": 41, 120 | "demo_cond_len_in_use_end": 51, 121 | "demo_qoi_len_in_use_begin": 41, 122 | "demo_qoi_len_in_use_end": 51, 123 | "quest_cond_len_in_use_begin": 41, 124 | "quest_cond_len_in_use_end": 51, 125 | "quest_qoi_len_in_use_begin": 41, 126 | "quest_qoi_len_in_use_end": 51}, 127 | "others":{ 128 | "demo_num_begin": 5, 129 | "demo_num_end": 6, 130 | "demo_cond_select": "random", 131 | "demo_qoi_select": "random", 132 | "quest_cond_select": "random", 133 | "quest_qoi_select": "random", 134 | "demo_cond_len_in_use_begin": 41, 135 | "demo_cond_len_in_use_end": 51, 136 | "demo_qoi_len_in_use_begin": 41, 137 | "demo_qoi_len_in_use_end": 51, 138 | "quest_cond_len_in_use_begin": 41, 139 | "quest_cond_len_in_use_end": 51, 140 | "quest_qoi_len_in_use_begin": 41, 141 | "quest_qoi_len_in_use_end": 51} 142 | } -------------------------------------------------------------------------------- /icon-lm/config_data/test_lm_precise_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "random", 13 | "load_caption": {"0":"[]", "1":"list(range(80,100))"}, 14 
| "caption_dir": "captions_1009", 15 | "return_raw": false, 16 | "load_list": ["input_id"], 17 | "pde_mask":{ 18 | "demo_num_begin": 5, 19 | "demo_num_end": 6, 20 | "start_ind_begin": 5, 21 | "start_ind_end": 45, 22 | "demo_cond_len_in_use_begin": 41, 23 | "demo_cond_len_in_use_end": 51, 24 | "demo_qoi_len_in_use_begin": 41, 25 | "demo_qoi_len_in_use_end": 51, 26 | "quest_cond_len_in_use_begin": 41, 27 | "quest_cond_len_in_use_end": 51, 28 | "quest_qoi_len_in_use_begin": 41, 29 | "quest_qoi_len_in_use_end": 51}, 30 | "pde_spatial_forward":{ 31 | "demo_num_begin": 5, 32 | "demo_num_end": 6, 33 | "demo_cond_select": "random", 34 | "demo_qoi_select": "random", 35 | "quest_cond_select": "random", 36 | "quest_qoi_select": "random", 37 | "demo_cond_len_in_use_begin": 41, 38 | "demo_cond_len_in_use_end": 51, 39 | "demo_qoi_len_in_use_begin": 41, 40 | "demo_qoi_len_in_use_end": 51, 41 | "quest_cond_len_in_use_begin": 41, 42 | "quest_cond_len_in_use_end": 51, 43 | "quest_qoi_len_in_use_begin": 41, 44 | "quest_qoi_len_in_use_end": 51}, 45 | "pde_spatial_inverse":{ 46 | "demo_num_begin": 5, 47 | "demo_num_end": 6, 48 | "demo_cond_select": "random", 49 | "demo_qoi_select": "random", 50 | "quest_cond_select": "random", 51 | "quest_qoi_select": "random", 52 | "demo_cond_len_in_use_begin": 41, 53 | "demo_cond_len_in_use_end": 51, 54 | "demo_qoi_len_in_use_begin": 41, 55 | "demo_qoi_len_in_use_end": 51, 56 | "quest_cond_len_in_use_begin": 41, 57 | "quest_cond_len_in_use_end": 51, 58 | "quest_qoi_len_in_use_begin": 41, 59 | "quest_qoi_len_in_use_end": 51}, 60 | "ode_forward":{ 61 | "demo_num_begin": 5, 62 | "demo_num_end": 6, 63 | "demo_cond_select": "even", 64 | "demo_qoi_select": "even", 65 | "quest_cond_select": "even", 66 | "quest_qoi_select": "even", 67 | "demo_qoi_len_in_use_begin": 41, 68 | "demo_qoi_len_in_use_end": 51, 69 | "quest_qoi_len_in_use_begin": 41, 70 | "quest_qoi_len_in_use_end": 51}, 71 | "ode_inverse":{ 72 | "demo_num_begin": 5, 73 | "demo_num_end": 6, 74 | 
"demo_cond_select": "even", 75 | "demo_qoi_select": "even", 76 | "quest_cond_select": "even", 77 | "quest_qoi_select": "even", 78 | "demo_qoi_len_in_use_begin": 40, 79 | "demo_qoi_len_in_use_end": 50, 80 | "quest_qoi_len_in_use_begin": 40, 81 | "quest_qoi_len_in_use_end": 50}, 82 | "time_series":{ 83 | "demo_num_begin": 5, 84 | "demo_num_end": 6, 85 | "demo_cond_select": "random", 86 | "demo_qoi_select": "random", 87 | "quest_cond_select": "random", 88 | "quest_qoi_select": "random", 89 | "demo_cond_len_in_use_begin": 41, 90 | "demo_cond_len_in_use_end": 51, 91 | "demo_qoi_len_in_use_begin": 41, 92 | "demo_qoi_len_in_use_end": 51, 93 | "quest_cond_len_in_use_begin": 41, 94 | "quest_cond_len_in_use_end": 51, 95 | "quest_qoi_len_in_use_begin": 41, 96 | "quest_qoi_len_in_use_end": 51}, 97 | "mfc_gparam_forward":{ 98 | "demo_num_begin": 5, 99 | "demo_num_end": 6, 100 | "demo_cond_select": "random", 101 | "demo_qoi_select": "random", 102 | "quest_cond_select": "random", 103 | "quest_qoi_select": "random", 104 | "demo_cond_len_in_use_begin": 41, 105 | "demo_cond_len_in_use_end": 51, 106 | "demo_qoi_len_in_use_begin": 41, 107 | "demo_qoi_len_in_use_end": 51, 108 | "quest_cond_len_in_use_begin": 41, 109 | "quest_cond_len_in_use_end": 51, 110 | "quest_qoi_len_in_use_begin": 41, 111 | "quest_qoi_len_in_use_end": 51}, 112 | "mfc_rhoparam_forward":{ 113 | "demo_num_begin": 5, 114 | "demo_num_end": 6, 115 | "demo_cond_select": "random", 116 | "demo_qoi_select": "random", 117 | "quest_cond_select": "random", 118 | "quest_qoi_select": "random", 119 | "demo_cond_len_in_use_begin": 41, 120 | "demo_cond_len_in_use_end": 51, 121 | "demo_qoi_len_in_use_begin": 41, 122 | "demo_qoi_len_in_use_end": 51, 123 | "quest_cond_len_in_use_begin": 41, 124 | "quest_cond_len_in_use_end": 51, 125 | "quest_qoi_len_in_use_begin": 41, 126 | "quest_qoi_len_in_use_end": 51}, 127 | "others":{ 128 | "demo_num_begin": 5, 129 | "demo_num_end": 6, 130 | "demo_cond_select": "random", 131 | "demo_qoi_select": 
"random", 132 | "quest_cond_select": "random", 133 | "quest_qoi_select": "random", 134 | "demo_cond_len_in_use_begin": 41, 135 | "demo_cond_len_in_use_end": 51, 136 | "demo_qoi_len_in_use_begin": 41, 137 | "demo_qoi_len_in_use_end": 51, 138 | "quest_cond_len_in_use_begin": 41, 139 | "quest_cond_len_in_use_end": 51, 140 | "quest_qoi_len_in_use_begin": 41, 141 | "quest_qoi_len_in_use_end": 51} 142 | } -------------------------------------------------------------------------------- /icon-lm/config_data/test_lm_vague_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "random", 13 | "load_caption": {"0":"list(range(80,100))", "1":"[]"}, 14 | "caption_dir": "captions_1009", 15 | "return_raw": false, 16 | "load_list": ["input_id"], 17 | "pde_mask":{ 18 | "demo_num_begin": 5, 19 | "demo_num_end": 6, 20 | "start_ind_begin": 5, 21 | "start_ind_end": 45, 22 | "demo_cond_len_in_use_begin": 41, 23 | "demo_cond_len_in_use_end": 51, 24 | "demo_qoi_len_in_use_begin": 41, 25 | "demo_qoi_len_in_use_end": 51, 26 | "quest_cond_len_in_use_begin": 41, 27 | "quest_cond_len_in_use_end": 51, 28 | "quest_qoi_len_in_use_begin": 41, 29 | "quest_qoi_len_in_use_end": 51}, 30 | "pde_spatial_forward":{ 31 | "demo_num_begin": 5, 32 | "demo_num_end": 6, 33 | "demo_cond_select": "random", 34 | "demo_qoi_select": "random", 35 | "quest_cond_select": "random", 36 | "quest_qoi_select": "random", 37 | "demo_cond_len_in_use_begin": 41, 38 | "demo_cond_len_in_use_end": 51, 39 | "demo_qoi_len_in_use_begin": 41, 40 | "demo_qoi_len_in_use_end": 51, 41 | "quest_cond_len_in_use_begin": 41, 42 | "quest_cond_len_in_use_end": 51, 43 | "quest_qoi_len_in_use_begin": 41, 44 | "quest_qoi_len_in_use_end": 51}, 45 | 
"pde_spatial_inverse":{ 46 | "demo_num_begin": 5, 47 | "demo_num_end": 6, 48 | "demo_cond_select": "random", 49 | "demo_qoi_select": "random", 50 | "quest_cond_select": "random", 51 | "quest_qoi_select": "random", 52 | "demo_cond_len_in_use_begin": 41, 53 | "demo_cond_len_in_use_end": 51, 54 | "demo_qoi_len_in_use_begin": 41, 55 | "demo_qoi_len_in_use_end": 51, 56 | "quest_cond_len_in_use_begin": 41, 57 | "quest_cond_len_in_use_end": 51, 58 | "quest_qoi_len_in_use_begin": 41, 59 | "quest_qoi_len_in_use_end": 51}, 60 | "ode_forward":{ 61 | "demo_num_begin": 5, 62 | "demo_num_end": 6, 63 | "demo_cond_select": "even", 64 | "demo_qoi_select": "even", 65 | "quest_cond_select": "even", 66 | "quest_qoi_select": "even", 67 | "demo_qoi_len_in_use_begin": 41, 68 | "demo_qoi_len_in_use_end": 51, 69 | "quest_qoi_len_in_use_begin": 41, 70 | "quest_qoi_len_in_use_end": 51}, 71 | "ode_inverse":{ 72 | "demo_num_begin": 5, 73 | "demo_num_end": 6, 74 | "demo_cond_select": "even", 75 | "demo_qoi_select": "even", 76 | "quest_cond_select": "even", 77 | "quest_qoi_select": "even", 78 | "demo_qoi_len_in_use_begin": 40, 79 | "demo_qoi_len_in_use_end": 50, 80 | "quest_qoi_len_in_use_begin": 40, 81 | "quest_qoi_len_in_use_end": 50}, 82 | "time_series":{ 83 | "demo_num_begin": 5, 84 | "demo_num_end": 6, 85 | "demo_cond_select": "random", 86 | "demo_qoi_select": "random", 87 | "quest_cond_select": "random", 88 | "quest_qoi_select": "random", 89 | "demo_cond_len_in_use_begin": 41, 90 | "demo_cond_len_in_use_end": 51, 91 | "demo_qoi_len_in_use_begin": 41, 92 | "demo_qoi_len_in_use_end": 51, 93 | "quest_cond_len_in_use_begin": 41, 94 | "quest_cond_len_in_use_end": 51, 95 | "quest_qoi_len_in_use_begin": 41, 96 | "quest_qoi_len_in_use_end": 51}, 97 | "mfc_gparam_forward":{ 98 | "demo_num_begin": 5, 99 | "demo_num_end": 6, 100 | "demo_cond_select": "random", 101 | "demo_qoi_select": "random", 102 | "quest_cond_select": "random", 103 | "quest_qoi_select": "random", 104 | 
"demo_cond_len_in_use_begin": 41, 105 | "demo_cond_len_in_use_end": 51, 106 | "demo_qoi_len_in_use_begin": 41, 107 | "demo_qoi_len_in_use_end": 51, 108 | "quest_cond_len_in_use_begin": 41, 109 | "quest_cond_len_in_use_end": 51, 110 | "quest_qoi_len_in_use_begin": 41, 111 | "quest_qoi_len_in_use_end": 51}, 112 | "mfc_rhoparam_forward":{ 113 | "demo_num_begin": 5, 114 | "demo_num_end": 6, 115 | "demo_cond_select": "random", 116 | "demo_qoi_select": "random", 117 | "quest_cond_select": "random", 118 | "quest_qoi_select": "random", 119 | "demo_cond_len_in_use_begin": 41, 120 | "demo_cond_len_in_use_end": 51, 121 | "demo_qoi_len_in_use_begin": 41, 122 | "demo_qoi_len_in_use_end": 51, 123 | "quest_cond_len_in_use_begin": 41, 124 | "quest_cond_len_in_use_end": 51, 125 | "quest_qoi_len_in_use_begin": 41, 126 | "quest_qoi_len_in_use_end": 51}, 127 | "others":{ 128 | "demo_num_begin": 5, 129 | "demo_num_end": 6, 130 | "demo_cond_select": "random", 131 | "demo_qoi_select": "random", 132 | "quest_cond_select": "random", 133 | "quest_qoi_select": "random", 134 | "demo_cond_len_in_use_begin": 41, 135 | "demo_cond_len_in_use_end": 51, 136 | "demo_qoi_len_in_use_begin": 41, 137 | "demo_qoi_len_in_use_end": 51, 138 | "quest_cond_len_in_use_begin": 41, 139 | "quest_cond_len_in_use_end": 51, 140 | "quest_qoi_len_in_use_begin": 41, 141 | "quest_qoi_len_in_use_end": 51} 142 | } -------------------------------------------------------------------------------- /icon-lm/config_data/test_lm_weno_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 100, 5 | "demo_qoi_len": 100, 6 | "quest_cond_len": 100, 7 | "quest_qoi_len": 100, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "all", 13 | "return_raw": false, 14 | "load_list": [], 15 | "pde_mask":{ 16 | "demo_num_begin": 5, 17 | "demo_num_end": 6, 18 | "start_ind_begin": 
5, 19 | "start_ind_end": 45, 20 | "demo_cond_len_in_use_begin": 41, 21 | "demo_cond_len_in_use_end": 51, 22 | "demo_qoi_len_in_use_begin": 41, 23 | "demo_qoi_len_in_use_end": 51, 24 | "quest_cond_len_in_use_begin": 41, 25 | "quest_cond_len_in_use_end": 51, 26 | "quest_qoi_len_in_use_begin": 41, 27 | "quest_qoi_len_in_use_end": 51}, 28 | "pde_spatial_forward":{ 29 | "demo_num_begin": 5, 30 | "demo_num_end": 6, 31 | "demo_cond_select": "random", 32 | "demo_qoi_select": "random", 33 | "quest_cond_select": "random", 34 | "quest_qoi_select": "random", 35 | "demo_cond_len_in_use_begin": 41, 36 | "demo_cond_len_in_use_end": 51, 37 | "demo_qoi_len_in_use_begin": 41, 38 | "demo_qoi_len_in_use_end": 51, 39 | "quest_cond_len_in_use_begin": 41, 40 | "quest_cond_len_in_use_end": 51, 41 | "quest_qoi_len_in_use_begin": 41, 42 | "quest_qoi_len_in_use_end": 51}, 43 | "pde_spatial_inverse":{ 44 | "demo_num_begin": 5, 45 | "demo_num_end": 6, 46 | "demo_cond_select": "random", 47 | "demo_qoi_select": "random", 48 | "quest_cond_select": "random", 49 | "quest_qoi_select": "random", 50 | "demo_cond_len_in_use_begin": 41, 51 | "demo_cond_len_in_use_end": 51, 52 | "demo_qoi_len_in_use_begin": 41, 53 | "demo_qoi_len_in_use_end": 51, 54 | "quest_cond_len_in_use_begin": 41, 55 | "quest_cond_len_in_use_end": 51, 56 | "quest_qoi_len_in_use_begin": 41, 57 | "quest_qoi_len_in_use_end": 51}, 58 | "ode_forward":{ 59 | "demo_num_begin": 5, 60 | "demo_num_end": 6, 61 | "demo_cond_select": "even", 62 | "demo_qoi_select": "even", 63 | "quest_cond_select": "even", 64 | "quest_qoi_select": "even", 65 | "demo_qoi_len_in_use_begin": 41, 66 | "demo_qoi_len_in_use_end": 51, 67 | "quest_qoi_len_in_use_begin": 41, 68 | "quest_qoi_len_in_use_end": 51}, 69 | "ode_inverse":{ 70 | "demo_num_begin": 5, 71 | "demo_num_end": 6, 72 | "demo_cond_select": "even", 73 | "demo_qoi_select": "even", 74 | "quest_cond_select": "even", 75 | "quest_qoi_select": "even", 76 | "demo_qoi_len_in_use_begin": 40, 77 | 
"demo_qoi_len_in_use_end": 50, 78 | "quest_qoi_len_in_use_begin": 40, 79 | "quest_qoi_len_in_use_end": 50}, 80 | "time_series":{ 81 | "demo_num_begin": 5, 82 | "demo_num_end": 6, 83 | "demo_cond_select": "random", 84 | "demo_qoi_select": "random", 85 | "quest_cond_select": "random", 86 | "quest_qoi_select": "random", 87 | "demo_cond_len_in_use_begin": 41, 88 | "demo_cond_len_in_use_end": 51, 89 | "demo_qoi_len_in_use_begin": 41, 90 | "demo_qoi_len_in_use_end": 51, 91 | "quest_cond_len_in_use_begin": 41, 92 | "quest_cond_len_in_use_end": 51, 93 | "quest_qoi_len_in_use_begin": 41, 94 | "quest_qoi_len_in_use_end": 51}, 95 | "mfc_gparam_forward":{ 96 | "demo_num_begin": 5, 97 | "demo_num_end": 6, 98 | "demo_cond_select": "random", 99 | "demo_qoi_select": "random", 100 | "quest_cond_select": "random", 101 | "quest_qoi_select": "random", 102 | "demo_cond_len_in_use_begin": 41, 103 | "demo_cond_len_in_use_end": 51, 104 | "demo_qoi_len_in_use_begin": 41, 105 | "demo_qoi_len_in_use_end": 51, 106 | "quest_cond_len_in_use_begin": 41, 107 | "quest_cond_len_in_use_end": 51, 108 | "quest_qoi_len_in_use_begin": 41, 109 | "quest_qoi_len_in_use_end": 51}, 110 | "mfc_rhoparam_forward":{ 111 | "demo_num_begin": 5, 112 | "demo_num_end": 6, 113 | "demo_cond_select": "random", 114 | "demo_qoi_select": "random", 115 | "quest_cond_select": "random", 116 | "quest_qoi_select": "random", 117 | "demo_cond_len_in_use_begin": 41, 118 | "demo_cond_len_in_use_end": 51, 119 | "demo_qoi_len_in_use_begin": 41, 120 | "demo_qoi_len_in_use_end": 51, 121 | "quest_cond_len_in_use_begin": 41, 122 | "quest_cond_len_in_use_end": 51, 123 | "quest_qoi_len_in_use_begin": 41, 124 | "quest_qoi_len_in_use_end": 51}, 125 | "others":{ 126 | "demo_num_begin": 5, 127 | "demo_num_end": 6, 128 | "demo_cond_select": "first", 129 | "demo_qoi_select": "first", 130 | "quest_cond_select": "first", 131 | "quest_qoi_select": "first", 132 | "demo_cond_len_in_use_begin": 100, 133 | "demo_cond_len_in_use_end": 101, 134 | 
"demo_qoi_len_in_use_begin": 100, 135 | "demo_qoi_len_in_use_end": 101, 136 | "quest_cond_len_in_use_begin": 100, 137 | "quest_cond_len_in_use_end": 101, 138 | "quest_qoi_len_in_use_begin": 100, 139 | "quest_qoi_len_in_use_end": 101} 140 | } -------------------------------------------------------------------------------- /icon-lm/config_data/train_icon_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "random", 12 | "select_caption": "random", 13 | "return_raw": false, 14 | "load_list": [], 15 | "pde_mask":{ 16 | "demo_num_begin": 1, 17 | "demo_num_end": 6, 18 | "start_ind_begin": 5, 19 | "start_ind_end": 45, 20 | "demo_cond_len_in_use_begin": 41, 21 | "demo_cond_len_in_use_end": 51, 22 | "demo_qoi_len_in_use_begin": 41, 23 | "demo_qoi_len_in_use_end": 51, 24 | "quest_cond_len_in_use_begin": 41, 25 | "quest_cond_len_in_use_end": 51, 26 | "quest_qoi_len_in_use_begin": 41, 27 | "quest_qoi_len_in_use_end": 51}, 28 | "pde_spatial_forward":{ 29 | "demo_num_begin": 1, 30 | "demo_num_end": 6, 31 | "demo_cond_select": "random", 32 | "demo_qoi_select": "random", 33 | "quest_cond_select": "random", 34 | "quest_qoi_select": "random", 35 | "demo_cond_len_in_use_begin": 41, 36 | "demo_cond_len_in_use_end": 51, 37 | "demo_qoi_len_in_use_begin": 41, 38 | "demo_qoi_len_in_use_end": 51, 39 | "quest_cond_len_in_use_begin": 41, 40 | "quest_cond_len_in_use_end": 51, 41 | "quest_qoi_len_in_use_begin": 41, 42 | "quest_qoi_len_in_use_end": 51}, 43 | "pde_spatial_inverse":{ 44 | "demo_num_begin": 1, 45 | "demo_num_end": 6, 46 | "demo_cond_select": "random", 47 | "demo_qoi_select": "random", 48 | "quest_cond_select": "random", 49 | "quest_qoi_select": "random", 50 | "demo_cond_len_in_use_begin": 41, 51 | 
"demo_cond_len_in_use_end": 51, 52 | "demo_qoi_len_in_use_begin": 41, 53 | "demo_qoi_len_in_use_end": 51, 54 | "quest_cond_len_in_use_begin": 41, 55 | "quest_cond_len_in_use_end": 51, 56 | "quest_qoi_len_in_use_begin": 41, 57 | "quest_qoi_len_in_use_end": 51}, 58 | "ode_forward":{ 59 | "demo_num_begin": 1, 60 | "demo_num_end": 6, 61 | "demo_cond_select": "even", 62 | "demo_qoi_select": "even", 63 | "quest_cond_select": "even", 64 | "quest_qoi_select": "even", 65 | "demo_qoi_len_in_use_begin": 41, 66 | "demo_qoi_len_in_use_end": 51, 67 | "quest_qoi_len_in_use_begin": 41, 68 | "quest_qoi_len_in_use_end": 51}, 69 | "ode_inverse":{ 70 | "demo_num_begin": 1, 71 | "demo_num_end": 6, 72 | "demo_cond_select": "even", 73 | "demo_qoi_select": "even", 74 | "quest_cond_select": "even", 75 | "quest_qoi_select": "even", 76 | "demo_qoi_len_in_use_begin": 40, 77 | "demo_qoi_len_in_use_end": 50, 78 | "quest_qoi_len_in_use_begin": 40, 79 | "quest_qoi_len_in_use_end": 50}, 80 | "time_series":{ 81 | "demo_num_begin": 1, 82 | "demo_num_end": 6, 83 | "demo_cond_select": "random", 84 | "demo_qoi_select": "random", 85 | "quest_cond_select": "random", 86 | "quest_qoi_select": "random", 87 | "demo_cond_len_in_use_begin": 41, 88 | "demo_cond_len_in_use_end": 51, 89 | "demo_qoi_len_in_use_begin": 41, 90 | "demo_qoi_len_in_use_end": 51, 91 | "quest_cond_len_in_use_begin": 41, 92 | "quest_cond_len_in_use_end": 51, 93 | "quest_qoi_len_in_use_begin": 41, 94 | "quest_qoi_len_in_use_end": 51}, 95 | "mfc_gparam_forward":{ 96 | "demo_num_begin": 1, 97 | "demo_num_end": 6, 98 | "demo_cond_select": "random", 99 | "demo_qoi_select": "random", 100 | "quest_cond_select": "random", 101 | "quest_qoi_select": "random", 102 | "demo_cond_len_in_use_begin": 41, 103 | "demo_cond_len_in_use_end": 51, 104 | "demo_qoi_len_in_use_begin": 41, 105 | "demo_qoi_len_in_use_end": 51, 106 | "quest_cond_len_in_use_begin": 41, 107 | "quest_cond_len_in_use_end": 51, 108 | "quest_qoi_len_in_use_begin": 41, 109 | 
"quest_qoi_len_in_use_end": 51}, 110 | "mfc_rhoparam_forward":{ 111 | "demo_num_begin": 1, 112 | "demo_num_end": 6, 113 | "demo_cond_select": "random", 114 | "demo_qoi_select": "random", 115 | "quest_cond_select": "random", 116 | "quest_qoi_select": "random", 117 | "demo_cond_len_in_use_begin": 41, 118 | "demo_cond_len_in_use_end": 51, 119 | "demo_qoi_len_in_use_begin": 41, 120 | "demo_qoi_len_in_use_end": 51, 121 | "quest_cond_len_in_use_begin": 41, 122 | "quest_cond_len_in_use_end": 51, 123 | "quest_qoi_len_in_use_begin": 41, 124 | "quest_qoi_len_in_use_end": 51}, 125 | "others":{ 126 | "demo_num_begin": 1, 127 | "demo_num_end": 6, 128 | "demo_cond_select": "random", 129 | "demo_qoi_select": "random", 130 | "quest_cond_select": "random", 131 | "quest_qoi_select": "random", 132 | "demo_cond_len_in_use_begin": 41, 133 | "demo_cond_len_in_use_end": 51, 134 | "demo_qoi_len_in_use_begin": 41, 135 | "demo_qoi_len_in_use_end": 51, 136 | "quest_cond_len_in_use_begin": 41, 137 | "quest_cond_len_in_use_end": 51, 138 | "quest_qoi_len_in_use_begin": 41, 139 | "quest_qoi_len_in_use_end": 51} 140 | } -------------------------------------------------------------------------------- /icon-lm/config_data/train_lm_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "random", 12 | "select_caption": "random", 13 | "load_caption": {"0":"list(range(80))", "1":"list(range(80))"}, 14 | "caption_dir": "captions_1009", 15 | "return_raw": false, 16 | "load_list": ["input_id"], 17 | "pde_mask":{ 18 | "demo_num_begin": 5, 19 | "demo_num_end": 6, 20 | "start_ind_begin": 5, 21 | "start_ind_end": 45, 22 | "demo_cond_len_in_use_begin": 41, 23 | "demo_cond_len_in_use_end": 51, 24 | "demo_qoi_len_in_use_begin": 41, 25 | 
"demo_qoi_len_in_use_end": 51, 26 | "quest_cond_len_in_use_begin": 41, 27 | "quest_cond_len_in_use_end": 51, 28 | "quest_qoi_len_in_use_begin": 41, 29 | "quest_qoi_len_in_use_end": 51}, 30 | "pde_spatial_forward":{ 31 | "demo_num_begin": 5, 32 | "demo_num_end": 6, 33 | "demo_cond_select": "random", 34 | "demo_qoi_select": "random", 35 | "quest_cond_select": "random", 36 | "quest_qoi_select": "random", 37 | "demo_cond_len_in_use_begin": 41, 38 | "demo_cond_len_in_use_end": 51, 39 | "demo_qoi_len_in_use_begin": 41, 40 | "demo_qoi_len_in_use_end": 51, 41 | "quest_cond_len_in_use_begin": 41, 42 | "quest_cond_len_in_use_end": 51, 43 | "quest_qoi_len_in_use_begin": 41, 44 | "quest_qoi_len_in_use_end": 51}, 45 | "pde_spatial_inverse":{ 46 | "demo_num_begin": 5, 47 | "demo_num_end": 6, 48 | "demo_cond_select": "random", 49 | "demo_qoi_select": "random", 50 | "quest_cond_select": "random", 51 | "quest_qoi_select": "random", 52 | "demo_cond_len_in_use_begin": 41, 53 | "demo_cond_len_in_use_end": 51, 54 | "demo_qoi_len_in_use_begin": 41, 55 | "demo_qoi_len_in_use_end": 51, 56 | "quest_cond_len_in_use_begin": 41, 57 | "quest_cond_len_in_use_end": 51, 58 | "quest_qoi_len_in_use_begin": 41, 59 | "quest_qoi_len_in_use_end": 51}, 60 | "ode_forward":{ 61 | "demo_num_begin": 5, 62 | "demo_num_end": 6, 63 | "demo_cond_select": "even", 64 | "demo_qoi_select": "even", 65 | "quest_cond_select": "even", 66 | "quest_qoi_select": "even", 67 | "demo_qoi_len_in_use_begin": 41, 68 | "demo_qoi_len_in_use_end": 51, 69 | "quest_qoi_len_in_use_begin": 41, 70 | "quest_qoi_len_in_use_end": 51}, 71 | "ode_inverse":{ 72 | "demo_num_begin": 5, 73 | "demo_num_end": 6, 74 | "demo_cond_select": "even", 75 | "demo_qoi_select": "even", 76 | "quest_cond_select": "even", 77 | "quest_qoi_select": "even", 78 | "demo_qoi_len_in_use_begin": 40, 79 | "demo_qoi_len_in_use_end": 50, 80 | "quest_qoi_len_in_use_begin": 40, 81 | "quest_qoi_len_in_use_end": 50}, 82 | "time_series":{ 83 | "demo_num_begin": 5, 84 | 
"demo_num_end": 6, 85 | "demo_cond_select": "random", 86 | "demo_qoi_select": "random", 87 | "quest_cond_select": "random", 88 | "quest_qoi_select": "random", 89 | "demo_cond_len_in_use_begin": 41, 90 | "demo_cond_len_in_use_end": 51, 91 | "demo_qoi_len_in_use_begin": 41, 92 | "demo_qoi_len_in_use_end": 51, 93 | "quest_cond_len_in_use_begin": 41, 94 | "quest_cond_len_in_use_end": 51, 95 | "quest_qoi_len_in_use_begin": 41, 96 | "quest_qoi_len_in_use_end": 51}, 97 | "mfc_gparam_forward":{ 98 | "demo_num_begin": 5, 99 | "demo_num_end": 6, 100 | "demo_cond_select": "random", 101 | "demo_qoi_select": "random", 102 | "quest_cond_select": "random", 103 | "quest_qoi_select": "random", 104 | "demo_cond_len_in_use_begin": 41, 105 | "demo_cond_len_in_use_end": 51, 106 | "demo_qoi_len_in_use_begin": 41, 107 | "demo_qoi_len_in_use_end": 51, 108 | "quest_cond_len_in_use_begin": 41, 109 | "quest_cond_len_in_use_end": 51, 110 | "quest_qoi_len_in_use_begin": 41, 111 | "quest_qoi_len_in_use_end": 51}, 112 | "mfc_rhoparam_forward":{ 113 | "demo_num_begin": 5, 114 | "demo_num_end": 6, 115 | "demo_cond_select": "random", 116 | "demo_qoi_select": "random", 117 | "quest_cond_select": "random", 118 | "quest_qoi_select": "random", 119 | "demo_cond_len_in_use_begin": 41, 120 | "demo_cond_len_in_use_end": 51, 121 | "demo_qoi_len_in_use_begin": 41, 122 | "demo_qoi_len_in_use_end": 51, 123 | "quest_cond_len_in_use_begin": 41, 124 | "quest_cond_len_in_use_end": 51, 125 | "quest_qoi_len_in_use_begin": 41, 126 | "quest_qoi_len_in_use_end": 51}, 127 | "others":{ 128 | "demo_num_begin": 5, 129 | "demo_num_end": 6, 130 | "demo_cond_select": "random", 131 | "demo_qoi_select": "random", 132 | "quest_cond_select": "random", 133 | "quest_qoi_select": "random", 134 | "demo_cond_len_in_use_begin": 41, 135 | "demo_cond_len_in_use_end": 51, 136 | "demo_qoi_len_in_use_begin": 41, 137 | "demo_qoi_len_in_use_end": 51, 138 | "quest_cond_len_in_use_begin": 41, 139 | "quest_cond_len_in_use_end": 51, 140 | 
"quest_qoi_len_in_use_begin": 41, 141 | "quest_qoi_len_in_use_end": 51} 142 | } -------------------------------------------------------------------------------- /icon-lm/config_data/train_lm_precise_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "random", 13 | "load_caption": {"0":"[]", "1":"list(range(0,80))"}, 14 | "caption_dir": "captions_1009", 15 | "return_raw": false, 16 | "load_list": ["input_id"], 17 | "pde_mask":{ 18 | "demo_num_begin": 5, 19 | "demo_num_end": 6, 20 | "start_ind_begin": 5, 21 | "start_ind_end": 45, 22 | "demo_cond_len_in_use_begin": 41, 23 | "demo_cond_len_in_use_end": 51, 24 | "demo_qoi_len_in_use_begin": 41, 25 | "demo_qoi_len_in_use_end": 51, 26 | "quest_cond_len_in_use_begin": 41, 27 | "quest_cond_len_in_use_end": 51, 28 | "quest_qoi_len_in_use_begin": 41, 29 | "quest_qoi_len_in_use_end": 51}, 30 | "pde_spatial_forward":{ 31 | "demo_num_begin": 5, 32 | "demo_num_end": 6, 33 | "demo_cond_select": "random", 34 | "demo_qoi_select": "random", 35 | "quest_cond_select": "random", 36 | "quest_qoi_select": "random", 37 | "demo_cond_len_in_use_begin": 41, 38 | "demo_cond_len_in_use_end": 51, 39 | "demo_qoi_len_in_use_begin": 41, 40 | "demo_qoi_len_in_use_end": 51, 41 | "quest_cond_len_in_use_begin": 41, 42 | "quest_cond_len_in_use_end": 51, 43 | "quest_qoi_len_in_use_begin": 41, 44 | "quest_qoi_len_in_use_end": 51}, 45 | "pde_spatial_inverse":{ 46 | "demo_num_begin": 5, 47 | "demo_num_end": 6, 48 | "demo_cond_select": "random", 49 | "demo_qoi_select": "random", 50 | "quest_cond_select": "random", 51 | "quest_qoi_select": "random", 52 | "demo_cond_len_in_use_begin": 41, 53 | "demo_cond_len_in_use_end": 51, 54 | "demo_qoi_len_in_use_begin": 41, 
55 | "demo_qoi_len_in_use_end": 51, 56 | "quest_cond_len_in_use_begin": 41, 57 | "quest_cond_len_in_use_end": 51, 58 | "quest_qoi_len_in_use_begin": 41, 59 | "quest_qoi_len_in_use_end": 51}, 60 | "ode_forward":{ 61 | "demo_num_begin": 5, 62 | "demo_num_end": 6, 63 | "demo_cond_select": "even", 64 | "demo_qoi_select": "even", 65 | "quest_cond_select": "even", 66 | "quest_qoi_select": "even", 67 | "demo_qoi_len_in_use_begin": 41, 68 | "demo_qoi_len_in_use_end": 51, 69 | "quest_qoi_len_in_use_begin": 41, 70 | "quest_qoi_len_in_use_end": 51}, 71 | "ode_inverse":{ 72 | "demo_num_begin": 5, 73 | "demo_num_end": 6, 74 | "demo_cond_select": "even", 75 | "demo_qoi_select": "even", 76 | "quest_cond_select": "even", 77 | "quest_qoi_select": "even", 78 | "demo_qoi_len_in_use_begin": 40, 79 | "demo_qoi_len_in_use_end": 50, 80 | "quest_qoi_len_in_use_begin": 40, 81 | "quest_qoi_len_in_use_end": 50}, 82 | "time_series":{ 83 | "demo_num_begin": 5, 84 | "demo_num_end": 6, 85 | "demo_cond_select": "random", 86 | "demo_qoi_select": "random", 87 | "quest_cond_select": "random", 88 | "quest_qoi_select": "random", 89 | "demo_cond_len_in_use_begin": 41, 90 | "demo_cond_len_in_use_end": 51, 91 | "demo_qoi_len_in_use_begin": 41, 92 | "demo_qoi_len_in_use_end": 51, 93 | "quest_cond_len_in_use_begin": 41, 94 | "quest_cond_len_in_use_end": 51, 95 | "quest_qoi_len_in_use_begin": 41, 96 | "quest_qoi_len_in_use_end": 51}, 97 | "mfc_gparam_forward":{ 98 | "demo_num_begin": 5, 99 | "demo_num_end": 6, 100 | "demo_cond_select": "random", 101 | "demo_qoi_select": "random", 102 | "quest_cond_select": "random", 103 | "quest_qoi_select": "random", 104 | "demo_cond_len_in_use_begin": 41, 105 | "demo_cond_len_in_use_end": 51, 106 | "demo_qoi_len_in_use_begin": 41, 107 | "demo_qoi_len_in_use_end": 51, 108 | "quest_cond_len_in_use_begin": 41, 109 | "quest_cond_len_in_use_end": 51, 110 | "quest_qoi_len_in_use_begin": 41, 111 | "quest_qoi_len_in_use_end": 51}, 112 | "mfc_rhoparam_forward":{ 113 | 
"demo_num_begin": 5, 114 | "demo_num_end": 6, 115 | "demo_cond_select": "random", 116 | "demo_qoi_select": "random", 117 | "quest_cond_select": "random", 118 | "quest_qoi_select": "random", 119 | "demo_cond_len_in_use_begin": 41, 120 | "demo_cond_len_in_use_end": 51, 121 | "demo_qoi_len_in_use_begin": 41, 122 | "demo_qoi_len_in_use_end": 51, 123 | "quest_cond_len_in_use_begin": 41, 124 | "quest_cond_len_in_use_end": 51, 125 | "quest_qoi_len_in_use_begin": 41, 126 | "quest_qoi_len_in_use_end": 51}, 127 | "others":{ 128 | "demo_num_begin": 5, 129 | "demo_num_end": 6, 130 | "demo_cond_select": "random", 131 | "demo_qoi_select": "random", 132 | "quest_cond_select": "random", 133 | "quest_qoi_select": "random", 134 | "demo_cond_len_in_use_begin": 41, 135 | "demo_cond_len_in_use_end": 51, 136 | "demo_qoi_len_in_use_begin": 41, 137 | "demo_qoi_len_in_use_end": 51, 138 | "quest_cond_len_in_use_begin": 41, 139 | "quest_cond_len_in_use_end": 51, 140 | "quest_qoi_len_in_use_begin": 41, 141 | "quest_qoi_len_in_use_end": 51} 142 | } -------------------------------------------------------------------------------- /icon-lm/config_data/train_lm_vague_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 50, 5 | "demo_qoi_len": 50, 6 | "quest_cond_len": 50, 7 | "quest_qoi_len": 50, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "ordered", 12 | "select_caption": "random", 13 | "load_caption": {"0":"list(range(0,80))", "1":"[]"}, 14 | "caption_dir": "captions_1009", 15 | "return_raw": false, 16 | "load_list": ["input_id"], 17 | "pde_mask":{ 18 | "demo_num_begin": 5, 19 | "demo_num_end": 6, 20 | "start_ind_begin": 5, 21 | "start_ind_end": 45, 22 | "demo_cond_len_in_use_begin": 41, 23 | "demo_cond_len_in_use_end": 51, 24 | "demo_qoi_len_in_use_begin": 41, 25 | "demo_qoi_len_in_use_end": 51, 26 | "quest_cond_len_in_use_begin": 41, 27 | 
"quest_cond_len_in_use_end": 51, 28 | "quest_qoi_len_in_use_begin": 41, 29 | "quest_qoi_len_in_use_end": 51}, 30 | "pde_spatial_forward":{ 31 | "demo_num_begin": 5, 32 | "demo_num_end": 6, 33 | "demo_cond_select": "random", 34 | "demo_qoi_select": "random", 35 | "quest_cond_select": "random", 36 | "quest_qoi_select": "random", 37 | "demo_cond_len_in_use_begin": 41, 38 | "demo_cond_len_in_use_end": 51, 39 | "demo_qoi_len_in_use_begin": 41, 40 | "demo_qoi_len_in_use_end": 51, 41 | "quest_cond_len_in_use_begin": 41, 42 | "quest_cond_len_in_use_end": 51, 43 | "quest_qoi_len_in_use_begin": 41, 44 | "quest_qoi_len_in_use_end": 51}, 45 | "pde_spatial_inverse":{ 46 | "demo_num_begin": 5, 47 | "demo_num_end": 6, 48 | "demo_cond_select": "random", 49 | "demo_qoi_select": "random", 50 | "quest_cond_select": "random", 51 | "quest_qoi_select": "random", 52 | "demo_cond_len_in_use_begin": 41, 53 | "demo_cond_len_in_use_end": 51, 54 | "demo_qoi_len_in_use_begin": 41, 55 | "demo_qoi_len_in_use_end": 51, 56 | "quest_cond_len_in_use_begin": 41, 57 | "quest_cond_len_in_use_end": 51, 58 | "quest_qoi_len_in_use_begin": 41, 59 | "quest_qoi_len_in_use_end": 51}, 60 | "ode_forward":{ 61 | "demo_num_begin": 5, 62 | "demo_num_end": 6, 63 | "demo_cond_select": "even", 64 | "demo_qoi_select": "even", 65 | "quest_cond_select": "even", 66 | "quest_qoi_select": "even", 67 | "demo_qoi_len_in_use_begin": 41, 68 | "demo_qoi_len_in_use_end": 51, 69 | "quest_qoi_len_in_use_begin": 41, 70 | "quest_qoi_len_in_use_end": 51}, 71 | "ode_inverse":{ 72 | "demo_num_begin": 5, 73 | "demo_num_end": 6, 74 | "demo_cond_select": "even", 75 | "demo_qoi_select": "even", 76 | "quest_cond_select": "even", 77 | "quest_qoi_select": "even", 78 | "demo_qoi_len_in_use_begin": 40, 79 | "demo_qoi_len_in_use_end": 50, 80 | "quest_qoi_len_in_use_begin": 40, 81 | "quest_qoi_len_in_use_end": 50}, 82 | "time_series":{ 83 | "demo_num_begin": 5, 84 | "demo_num_end": 6, 85 | "demo_cond_select": "random", 86 | "demo_qoi_select": 
"random", 87 | "quest_cond_select": "random", 88 | "quest_qoi_select": "random", 89 | "demo_cond_len_in_use_begin": 41, 90 | "demo_cond_len_in_use_end": 51, 91 | "demo_qoi_len_in_use_begin": 41, 92 | "demo_qoi_len_in_use_end": 51, 93 | "quest_cond_len_in_use_begin": 41, 94 | "quest_cond_len_in_use_end": 51, 95 | "quest_qoi_len_in_use_begin": 41, 96 | "quest_qoi_len_in_use_end": 51}, 97 | "mfc_gparam_forward":{ 98 | "demo_num_begin": 5, 99 | "demo_num_end": 6, 100 | "demo_cond_select": "random", 101 | "demo_qoi_select": "random", 102 | "quest_cond_select": "random", 103 | "quest_qoi_select": "random", 104 | "demo_cond_len_in_use_begin": 41, 105 | "demo_cond_len_in_use_end": 51, 106 | "demo_qoi_len_in_use_begin": 41, 107 | "demo_qoi_len_in_use_end": 51, 108 | "quest_cond_len_in_use_begin": 41, 109 | "quest_cond_len_in_use_end": 51, 110 | "quest_qoi_len_in_use_begin": 41, 111 | "quest_qoi_len_in_use_end": 51}, 112 | "mfc_rhoparam_forward":{ 113 | "demo_num_begin": 5, 114 | "demo_num_end": 6, 115 | "demo_cond_select": "random", 116 | "demo_qoi_select": "random", 117 | "quest_cond_select": "random", 118 | "quest_qoi_select": "random", 119 | "demo_cond_len_in_use_begin": 41, 120 | "demo_cond_len_in_use_end": 51, 121 | "demo_qoi_len_in_use_begin": 41, 122 | "demo_qoi_len_in_use_end": 51, 123 | "quest_cond_len_in_use_begin": 41, 124 | "quest_cond_len_in_use_end": 51, 125 | "quest_qoi_len_in_use_begin": 41, 126 | "quest_qoi_len_in_use_end": 51}, 127 | "others":{ 128 | "demo_num_begin": 5, 129 | "demo_num_end": 6, 130 | "demo_cond_select": "random", 131 | "demo_qoi_select": "random", 132 | "quest_cond_select": "random", 133 | "quest_qoi_select": "random", 134 | "demo_cond_len_in_use_begin": 41, 135 | "demo_cond_len_in_use_end": 51, 136 | "demo_qoi_len_in_use_begin": 41, 137 | "demo_qoi_len_in_use_end": 51, 138 | "quest_cond_len_in_use_begin": 41, 139 | "quest_cond_len_in_use_end": 51, 140 | "quest_qoi_len_in_use_begin": 41, 141 | "quest_qoi_len_in_use_end": 51} 142 | } 
-------------------------------------------------------------------------------- /icon-lm/config_data/train_lm_weno_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_num": 5, 3 | "quest_num": 1, 4 | "demo_cond_len": 100, 5 | "demo_qoi_len": 100, 6 | "quest_cond_len": 100, 7 | "quest_qoi_len": 100, 8 | "k_dim": 3, 9 | "v_dim": 1, 10 | "k_mode": "itx", 11 | "select_demo_quest": "random", 12 | "select_caption": "random", 13 | "return_raw": false, 14 | "load_list": [], 15 | "pde_mask":{ 16 | "demo_num_begin": 5, 17 | "demo_num_end": 6, 18 | "start_ind_begin": 5, 19 | "start_ind_end": 45, 20 | "demo_cond_len_in_use_begin": 41, 21 | "demo_cond_len_in_use_end": 51, 22 | "demo_qoi_len_in_use_begin": 41, 23 | "demo_qoi_len_in_use_end": 51, 24 | "quest_cond_len_in_use_begin": 41, 25 | "quest_cond_len_in_use_end": 51, 26 | "quest_qoi_len_in_use_begin": 41, 27 | "quest_qoi_len_in_use_end": 51}, 28 | "pde_spatial_forward":{ 29 | "demo_num_begin": 5, 30 | "demo_num_end": 6, 31 | "demo_cond_select": "random", 32 | "demo_qoi_select": "random", 33 | "quest_cond_select": "random", 34 | "quest_qoi_select": "random", 35 | "demo_cond_len_in_use_begin": 41, 36 | "demo_cond_len_in_use_end": 51, 37 | "demo_qoi_len_in_use_begin": 41, 38 | "demo_qoi_len_in_use_end": 51, 39 | "quest_cond_len_in_use_begin": 41, 40 | "quest_cond_len_in_use_end": 51, 41 | "quest_qoi_len_in_use_begin": 41, 42 | "quest_qoi_len_in_use_end": 51}, 43 | "pde_spatial_inverse":{ 44 | "demo_num_begin": 5, 45 | "demo_num_end": 6, 46 | "demo_cond_select": "random", 47 | "demo_qoi_select": "random", 48 | "quest_cond_select": "random", 49 | "quest_qoi_select": "random", 50 | "demo_cond_len_in_use_begin": 41, 51 | "demo_cond_len_in_use_end": 51, 52 | "demo_qoi_len_in_use_begin": 41, 53 | "demo_qoi_len_in_use_end": 51, 54 | "quest_cond_len_in_use_begin": 41, 55 | "quest_cond_len_in_use_end": 51, 56 | "quest_qoi_len_in_use_begin": 41, 57 | "quest_qoi_len_in_use_end": 
51}, 58 | "ode_forward":{ 59 | "demo_num_begin": 5, 60 | "demo_num_end": 6, 61 | "demo_cond_select": "even", 62 | "demo_qoi_select": "even", 63 | "quest_cond_select": "even", 64 | "quest_qoi_select": "even", 65 | "demo_qoi_len_in_use_begin": 41, 66 | "demo_qoi_len_in_use_end": 51, 67 | "quest_qoi_len_in_use_begin": 41, 68 | "quest_qoi_len_in_use_end": 51}, 69 | "ode_inverse":{ 70 | "demo_num_begin": 5, 71 | "demo_num_end": 6, 72 | "demo_cond_select": "even", 73 | "demo_qoi_select": "even", 74 | "quest_cond_select": "even", 75 | "quest_qoi_select": "even", 76 | "demo_qoi_len_in_use_begin": 40, 77 | "demo_qoi_len_in_use_end": 50, 78 | "quest_qoi_len_in_use_begin": 40, 79 | "quest_qoi_len_in_use_end": 50}, 80 | "time_series":{ 81 | "demo_num_begin": 5, 82 | "demo_num_end": 6, 83 | "demo_cond_select": "random", 84 | "demo_qoi_select": "random", 85 | "quest_cond_select": "random", 86 | "quest_qoi_select": "random", 87 | "demo_cond_len_in_use_begin": 41, 88 | "demo_cond_len_in_use_end": 51, 89 | "demo_qoi_len_in_use_begin": 41, 90 | "demo_qoi_len_in_use_end": 51, 91 | "quest_cond_len_in_use_begin": 41, 92 | "quest_cond_len_in_use_end": 51, 93 | "quest_qoi_len_in_use_begin": 41, 94 | "quest_qoi_len_in_use_end": 51}, 95 | "mfc_gparam_forward":{ 96 | "demo_num_begin": 5, 97 | "demo_num_end": 6, 98 | "demo_cond_select": "random", 99 | "demo_qoi_select": "random", 100 | "quest_cond_select": "random", 101 | "quest_qoi_select": "random", 102 | "demo_cond_len_in_use_begin": 41, 103 | "demo_cond_len_in_use_end": 51, 104 | "demo_qoi_len_in_use_begin": 41, 105 | "demo_qoi_len_in_use_end": 51, 106 | "quest_cond_len_in_use_begin": 41, 107 | "quest_cond_len_in_use_end": 51, 108 | "quest_qoi_len_in_use_begin": 41, 109 | "quest_qoi_len_in_use_end": 51}, 110 | "mfc_rhoparam_forward":{ 111 | "demo_num_begin": 5, 112 | "demo_num_end": 6, 113 | "demo_cond_select": "random", 114 | "demo_qoi_select": "random", 115 | "quest_cond_select": "random", 116 | "quest_qoi_select": "random", 117 | 
"demo_cond_len_in_use_begin": 41, 118 | "demo_cond_len_in_use_end": 51, 119 | "demo_qoi_len_in_use_begin": 41, 120 | "demo_qoi_len_in_use_end": 51, 121 | "quest_cond_len_in_use_begin": 41, 122 | "quest_cond_len_in_use_end": 51, 123 | "quest_qoi_len_in_use_begin": 41, 124 | "quest_qoi_len_in_use_end": 51}, 125 | "others":{ 126 | "demo_num_begin": 5, 127 | "demo_num_end": 6, 128 | "demo_cond_select": "first", 129 | "demo_qoi_select": "first", 130 | "quest_cond_select": "first", 131 | "quest_qoi_select": "first", 132 | "demo_cond_len_in_use_begin": 100, 133 | "demo_cond_len_in_use_end": 101, 134 | "demo_qoi_len_in_use_begin": 100, 135 | "demo_qoi_len_in_use_end": 101, 136 | "quest_cond_len_in_use_begin": 100, 137 | "quest_cond_len_in_use_end": 101, 138 | "quest_qoi_len_in_use_begin": 100, 139 | "quest_qoi_len_in_use_end": 101} 140 | } -------------------------------------------------------------------------------- /icon-lm/config_model/model_deepo_pde_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "in_dim": 101, 3 | "hidden_dim": 1024, 4 | "emb_dim": 1024, 5 | "hidden_layers": 6, 6 | "query_idx_start": 2, 7 | "query_idx_end": 3 8 | } -------------------------------------------------------------------------------- /icon-lm/config_model/model_deepo_weno_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "in_dim": 100, 3 | "hidden_dim": 1024, 4 | "emb_dim": 1024, 5 | "hidden_layers": 6, 6 | "query_idx_start": 2, 7 | "query_idx_end": 3 8 | } -------------------------------------------------------------------------------- /icon-lm/config_model/model_fno_pde_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_modes": 16, 3 | "hidden_channels": 512, 4 | "in_channels": 2, 5 | "out_channels": 1, 6 | "n_layers": 4, 7 | "cond_grid_idx_start": 2, 8 | "cond_grid_idx_end": 3 9 | } 
-------------------------------------------------------------------------------- /icon-lm/config_model/model_fno_weno_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_modes": 16, 3 | "hidden_channels": 512, 4 | "in_channels": 2, 5 | "out_channels": 1, 6 | "n_layers": 4, 7 | "cond_grid_idx_start": 2, 8 | "cond_grid_idx_end": 3 9 | } -------------------------------------------------------------------------------- /icon-lm/config_model/model_gpt2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_id_dim": 4, 3 | "demo_max_num": 6, 4 | "index_mode": "learn", 5 | "k_dim": 3, 6 | "v_dim": 1, 7 | "causal": "caption", 8 | "caption_len": 300, 9 | "input_net": {"hidden_dim": 1024}, 10 | "output_net": {"hidden_dim": 1024} 11 | } -------------------------------------------------------------------------------- /icon-lm/config_model/model_icon_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_max_num": 5, 3 | "encoder": {"num_heads":8, "num_layers":6, "model_size":256, "QK_size":256, "V_size":256, "widening_factor": 4, "initializer": "glorot_uniform"}, 4 | "decoder": {"num_heads":8, "num_layers":6, "model_size":256, "QK_size":256, "V_size":256, "widening_factor": 4, "initializer": "glorot_uniform"}, 5 | "out_size": 1 6 | } -------------------------------------------------------------------------------- /icon-lm/config_model/model_lm_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "demo_id_dim": 4, 3 | "demo_max_num": 6, 4 | "index_mode": "learn", 5 | "caption_len": 300, 6 | "caption_feature": "embedding", 7 | "caption_vocab_size": 30516, 8 | "transformer": {"n_layers":6, 9 | "n_heads":8, 10 | "head_dim":256, 11 | "model_dim":256, 12 | "dropout_rate":0, 13 | "widening_factor": 4, 14 | "kernel_init": "glorot_uniform", 15 | "attention_fn": "vanilla" 16 
# Ordered substitution rules for captions whose placeholders take one scalar each.
# Each rule is (identifier substring, ((placeholder, index into params, scale or None), ...)).
# The first rule whose substring occurs in the identifier wins, exactly like an if/elif chain.
_SCALAR_CAPTION_RULES = (
  ("ode_auto_const", (('[0.001]', 0, None), ('[0.002]', 1, None))),
  ("ode_auto_linear1", (('[0.001]', 0, None), ('[0.002]', 1, None))),
  ("ode_auto_linear2", (('[0.001]', 0, None), ('[0.002]', 1, None), ('[0.003]', 2, None))),
  ("series_damped_oscillator", (('[0.001]', 0, None),)),
  ("pde_poisson_spatial", (('[0.001]', 0, None), ('[0.002]', 1, None))),
  # Original note: params were written as (coeff_ul, coeff_ur, coeff_c, coeff_a);
  # the caption's lambda is 0.05 * coeff_a and the caption's "a" is coeff_c.
  ("pde_porous_spatial", (('[0.003]', 0, None), ('[0.004]', 1, None),
                          ('[0.002]', 2, None), ('[0.001]', 3, 0.05))),
  # Original note: params were written as (coeff_ul, coeff_ur, coeff_a, coeff_k);
  # the caption's lambda is 0.1 * coeff_a and the caption's "a" is coeff_k.
  ("pde_cubic_spatial", (('[0.003]', 0, None), ('[0.004]', 1, None),
                         ('[0.001]', 2, 0.1), ('[0.002]', 3, None))),
)

# Rules for captions with a single placeholder that receives ALL params, comma separated.
_LIST_CAPTION_RULES = (
  ("mfc_gparam_hj", '[g]'),
  ("mfc_rhoparam_hj", '[\\rho_0]'),
)


def fill_caption_number(caption_template, identifier, params, format_str = '{:.3g}'):
  '''
  Replace the numeric placeholders of a caption template with formatted parameters.
  @param
    caption_template: str, caption containing placeholders such as '[0.001]', '[g]' or '[\\rho_0]'
    identifier: str, equation identifier; the rule is chosen by substring match (e.g. "ode_auto_const_forward")
    params: sequence of numbers; its meaning depends on the equation type
    format_str: str.format template applied to every number
  @return
    caption: str, the template with every known placeholder substituted
  @raise
    NotImplementedError: if no rule matches the identifier
  '''
  for key, substitutions in _SCALAR_CAPTION_RULES:
    if key in identifier:
      caption = caption_template
      for placeholder, index, scale in substitutions:
        value = params[index] if scale is None else params[index] * scale
        caption = caption.replace(placeholder, format_str.format(value))
      return caption

  for key, placeholder in _LIST_CAPTION_RULES:
    if key in identifier:
      joined = ', '.join([format_str.format(g) for g in params])
      return caption_template.replace(placeholder, joined)

  # same exception type as before, but say which identifier was not recognised
  raise NotImplementedError("no caption rule for identifier {!r}".format(identifier))
x\\in[0,1]$, QoI: $c(x), x\\in[0,1]$.", 18 | "pde_porous_spatial_inverse" : " Condition: $u(x), x\\in[0,1]$, QoI: $c(x), x\\in[0,1]$.", 19 | "pde_cubic_spatial_inverse" : " Condition: $u(x), x\\in[0,1]$, QoI: $c(x), x\\in[0,1]$.", 20 | 21 | "mfc_gparam_hj_forward11" : " Condition: $\\rho(0,x), x\\in[0,1]$, QoI: $\\rho(1,x), x\\in[0,1]$.", 22 | "mfc_gparam_hj_forward12" : " Condition: $\\rho(0,x), x\\in[0,1]$, QoI: $\\rho(t,x), t\\in[0.5,1], x\\in[0,1]$.", 23 | "mfc_gparam_hj_forward22" : " Condition: $\\rho(t,x), t\\in[0,0.5), x\\in[0,1]$, QoI: $\\rho(t,x), t\\in[0.5,1], x\\in[0,1]$.", 24 | "mfc_rhoparam_hj_forward11" : " Condition: $g(x), x\\in[0,1]$, QoI: $\\rho(1,x), x\\in[0,1]$.", 25 | "mfc_rhoparam_hj_forward12" : " Condition: $g(x), x\\in[0,1]$, QoI: $\\rho(t,x), t\\in[0.5,1], x\\in[0,1]$." 26 | } -------------------------------------------------------------------------------- /icon-lm/data_preparation/data_dynamics.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from collections import namedtuple 4 | from functools import partial 5 | import data_utils 6 | from einshape import jax_einshape as einshape 7 | 8 | 9 | ''' 10 | the semantics of jax.lax.scan() are given roughly by: 11 | 12 | def scan(f, init, xs, length=None): 13 | if xs is None: 14 | xs = [None] * length 15 | carry = init 16 | ys = [] 17 | for x in xs: 18 | carry, y = f(carry, x) 19 | ys.append(y) 20 | return carry, np.stack(ys) 21 | ''' 22 | 23 | # traj[0] = init, final is affected by control[-1] 24 | 25 | def rk4_step(u, c, dt, rhs): 26 | k1 = dt * rhs(c, u) 27 | k2 = dt * rhs(c, u + 0.5 * k1) 28 | k3 = dt * rhs(c, u + 0.5 * k2) 29 | k4 = dt * rhs(c, u + k3) 30 | u_next = u + (1/6) * (k1 + 2*k2 + 2*k3 + k4) 31 | return u_next, u 32 | 33 | def euler_step(u, c, dt, rhs): 34 | u_next = u + dt * rhs(c, u) 35 | return u_next, u 36 | 37 | 38 | @partial(jax.jit, static_argnums=(-1,)) 39 | def ode_auto_const_fn(init, control,
@partial(jax.jit, static_argnames=('ode_batch_fn','length','num',))
def generate_one_dyn(key, ode_batch_fn, dt, length, num, k_sigma, k_l, init_range, coeffs,
                      control = None):
  '''
  generate data for dynamics: sample (or accept) a control signal, sample initial
  values uniformly, and integrate the batched ODE with the forward Euler step
  @param
    key: jax.random.PRNGKey
    ode_batch_fn: e.g. ode_auto_const_batch_fn, jitted function (static under jit)
    dt: float, time step
    length: int, length of time series (static under jit)
    num: int, number of samples (static under jit)
    k_sigma, k_l: float, kernel parameters of the RBF Gaussian process used for the control
    init_range: tuple, (min, max) range of the uniformly sampled initial values
    coeffs: tuple, coefficients of the dynamics, will be unpacked and passed to ode_batch_fn
    control: 2D array (num, length), control signal, if None, generate with Gaussian process
  @return
    ts: 3D array (num, length, 1), time grid arange(length) * dt repeated for every sample
    control: 3D array (num, length, 1), control signal
    traj: 3D array (num, length, 1), trajectory
  '''
  ts = jnp.arange(length) * dt
  # subkey1 drives the control sampling, subkey2 the initial values; the refreshed
  # `key` is not used again inside this function
  key, subkey1, subkey2 = jax.random.split(key, num = 3)
  if control is None:
    control = data_utils.generate_gaussian_process(subkey1, ts, num, kernel = data_utils.rbf_kernel_jax, k_sigma = k_sigma, k_l = k_l)
  init = jax.random.uniform(subkey2, (num,), minval = init_range[0], maxval = init_range[1])
  # traj[0] = init, final is affected by control[-1]
  # NOTE: the integrator is always euler_step here (rk4_step is not selectable from this entry point)
  _, traj = ode_batch_fn(init, control, dt, *coeffs, euler_step)
  # tile the 1D time grid to (num, length) so it lines up with control and traj
  ts_expand = einshape("i->ji", ts, j = num)
  return ts_expand[...,None], control[...,None], traj[...,None]
def read_lists_from_file(file_path, mode='separate'):
  '''
  Read a text file whose groups of lines are separated by empty lines.
  @param
    file_path: str, path of the file to read
    mode: 'separate' -> return the list of groups (list of list of strings)
          'one'      -> return all lines of all groups concatenated into one list
          dict       -> maps str(group index) to a string that evaluates to an iterable of
                        line indices to keep from that group, e.g. {"0": "range(0, 3)", "1": "[0, 2]"};
                        every group in the file needs an entry (KeyError otherwise)
  @return
    list of list of strings ('separate') or list of strings ('one' / dict)
  @raise
    ValueError: if mode is none of the above
  '''
  # list of list of strings, sublists are separated by empty lines in the file
  list_of_list = read_lists_from_file_separate(file_path)
  if mode == 'separate': # return original list of list
    return list_of_list
  if mode == 'one': # return the concatenation of all sublists
    return [item for sublist in list_of_list for item in sublist]
  if isinstance(mode, dict): # more detailed control; dict subclasses are accepted too
    return_list = []
    for i, sublist in enumerate(list_of_list):
      # SECURITY NOTE(review): eval executes arbitrary code taken from `mode`. It is kept
      # because index specs may be expressions such as "range(...)", which
      # ast.literal_eval cannot parse; only pass trusted configuration here.
      indices = eval(mode[str(i)]) # iterable of indices
      print(i, indices)
      for j in indices:
        try:
          return_list.append(sublist[j])
        except IndexError:
          # best effort: report the bad index and keep going
          print('IndexError: index {} out of range for list {} of length {}'.format(j, i, len(sublist)))
    return return_list
  raise ValueError('mode must be "separate", "one" or a dict')
import jax
import jax.numpy as jnp
from collections import namedtuple


def generate_sin(xs, amp, period, phase):
  '''
  sine wave amp * sin(2 * pi * xs / period + phase) evaluated at xs
  '''
  return amp * jnp.sin(xs * 2 * jnp.pi / period + phase)

# xs is shared in batch; amp, period, phase are batched along axis 0
generate_sin_batch = jax.jit(jax.vmap(generate_sin, [None, 0, 0, 0], 0))


def generate_sin_base(xs, amp, period, phase, base):
  '''
  sine wave shifted by a constant offset: base + generate_sin(xs, amp, period, phase)
  '''
  return base + generate_sin(xs, amp, period, phase)

# FIX: this used to vmap generate_sin (4 arguments) with 5 in_axes, so every call failed;
# the batched function has to wrap generate_sin_base.
generate_sin_base_batch = jax.jit(jax.vmap(generate_sin_base, [None, 0, 0, 0, None], 0)) # base is shared in batch


def generate_damped_oscillator(xs, amp, period, phase, decay):
  '''
  sine wave multiplied by the envelope exp(-decay * xs)
  '''
  return generate_sin(xs, amp, period, phase) * jnp.exp(-decay * xs)

generate_damped_oscillator_batch = jax.jit(jax.vmap(generate_damped_oscillator, [None, 0, 0, 0, None], 0)) # decay is shared in batch
def rbf_circle_kernel_jax(x1, x2, sigma, l):
  """
  RBF kernel on the unit circle: every point x is mapped to
  (sin(2*pi*x), cos(2*pi*x)) and the squared Euclidean distance between the
  mapped points, scaled by 1/l**2, is fed to sigma**2 * exp(-0.5 * d2).
  suppose x1, x2 in [0,1]; the result has shape (len(x1), len(x2)).
  """
  grid_a, grid_b = jnp.meshgrid(x1, x2, indexing='ij')
  # embed both grids on the circle
  sin_a = jnp.sin(grid_a * 2 * jnp.pi)
  cos_a = jnp.cos(grid_a * 2 * jnp.pi)
  sin_b = jnp.sin(grid_b * 2 * jnp.pi)
  cos_b = jnp.cos(grid_b * 2 * jnp.pi)
  sin_term = (sin_a-sin_b)**2/(l**2)
  cos_term = (cos_a-cos_b)**2/(l**2)
  scaled_sq_dist = sin_term + cos_term
  return sigma**2 * jnp.exp(-0.5 * scaled_sq_dist)
python3 datagen.py --dir $dir --caption_mode test --name test --eqns $testeqns --quests $testquests --eqn_types ode_auto_linear2 --seed 104 && 12 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode test --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --eqn_types pde_poisson_spatial --seed 105 && 13 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode test --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --eqn_types pde_porous_spatial --seed 106 && 14 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode test --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --eqn_types pde_cubic_spatial --seed 107 && 15 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode test --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types mfc_gparam_hj --seed 108 && 16 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode test --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types mfc_rhoparam_hj --seed 109 && 17 | 18 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --length 100 --dt 0.01 --eqn_types series_damped_oscillator --seed 1 && 19 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --eqn_types ode_auto_const --seed 2 && 20 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --eqn_types ode_auto_linear1 --seed 3 && 21 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --eqn_types ode_auto_linear2 --seed 4 && 22 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --length 100 --dx 0.01 --eqn_types pde_poisson_spatial 
--seed 5 && 23 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --length 100 --dx 0.01 --eqn_types pde_porous_spatial --seed 6 && 24 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --length 100 --dx 0.01 --eqn_types pde_cubic_spatial --seed 7 && 25 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types mfc_gparam_hj --seed 8 && 26 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --caption_mode train --name train --eqns $traineqns --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types mfc_rhoparam_hj --seed 9 && 27 | 28 | echo "Done" 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /icon-lm/data_preparation/datagen_weno.sh: -------------------------------------------------------------------------------- 1 | gpu=0 2 | 3 | dir=data0904_weno_cubic 4 | traineqns=1000 5 | trainnum=100 6 | testeqns=10 # only for visualization during training 7 | testnum=100 8 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_weno.py --eqn_types weno_cubic --dir $dir --name test --eqns $testeqns --num $testnum --dt 0.0005 --file_split 1 --truncate 100 --seed 101 && 9 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_weno.py --eqn_types weno_cubic --dir $dir --name train --eqns $traineqns --num $trainnum --dt 0.0005 --file_split 10 --truncate 100 --seed 1 && 10 | 11 | # for analysis 12 | dir=data0904_weno_cubic_test 13 | testeqns=11 14 | testnum=100 15 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_weno.py --eqn_types weno_cubic --dir $dir --name test --eqns $testeqns --num $testnum --dt 0.0005 --file_split $testeqns --eqn_mode grid_-1_1 --truncate 10 --seed 101 && 16 | 17 | # for quick analysis 18 | dir=data0904_weno_cubic_test_light 19 | testeqns=5 20 | testnum=100 21 | CUDA_VISIBLE_DEVICES=$gpu python3 
def timeit(func):
  '''
  Decorator: run func, print its wall-clock duration (time.perf_counter) and
  pass the return value through unchanged. Nothing is printed if func raises.
  '''
  @wraps(func)
  def timeit_wrapper(*args, **kwargs):
    tick = time.perf_counter()
    outcome = func(*args, **kwargs)
    elapsed = time.perf_counter() - tick
    print(f'Function {func.__name__} Took {elapsed:.4f} seconds', flush = True)
    return outcome
  return timeit_wrapper
def get_beta(u_roll):
  '''
  smoothness indicators of the three candidate stencils of the 5-point WENO scheme
  u_roll: [..., 5], last axis ordered like roll_list = [2,1,0,-1,-2]
  return: [..., 3], (beta_0, beta_1, beta_2)
  '''
  # name the five stencil entries by their position on the last axis:
  # entries 0..4 hold the rolls (2, 1, 0, -1, -2)
  r2, r1, r0, rm1, rm2 = (u_roll[..., j] for j in range(5))
  beta_0 = 13/12 * (r0 - 2 * rm1 + rm2)**2 + 1/4 * (3 * r0 - 4 * rm1 + rm2)**2
  beta_1 = 13/12 * (r1 - 2 * r0 + rm1)**2 + 1/4 * (r1 - rm1)**2
  beta_2 = 13/12 * (r2 - 2 * r1 + r0)**2 + 1/4 * (r2 - 4 * r1 + 3 * r0)**2
  betas = [beta_0, beta_1, beta_2]
  return jnp.stack(betas, axis = -1) # [..., 3]
def _periodic_roll(n_left, n_right, core):
  '''
  build a jitted u -> padded u for periodic boundaries:
  the last n_left rows of u are put in front of u[core] and the first n_right rows behind it
  '''
  def roll(u):
    pieces = []
    if n_left:
      pieces.append(u[-n_left:,:])
    pieces.append(u[core,:])
    if n_right:
      pieces.append(u[:n_right,:])
    return jnp.concatenate(pieces, axis = 0)
  return jax.jit(roll)

# periodic counterpart of the shifted stencils, u: [usize, udim] -> out: [usize+2, udim]
roll_pb_funs = {2: _periodic_roll(3, 0, slice(None, -1)),
                1: _periodic_roll(2, 0, slice(None)),
                0: _periodic_roll(1, 1, slice(None)),
                -1: _periodic_roll(0, 2, slice(None)),
                -2: _periodic_roll(0, 3, slice(1, None)),
                }
def generate_weno_scalar_sol(dx, dt, init, fn, steps, grad_fn = None, stable_tol = None):
  '''
  advance a batch of scalar fields for `steps` WENO steps with periodic boundaries
  @param
    dx, dt: float, grid size and time step
    init: (batch, N, 1), initial condition
    fn: flux function handed to weno_scheme.weno_step_batch
    steps: int, number of time steps
    grad_fn: handed to weno_scheme.get_scalar_alpha_batch to compute alpha
    stable_tol: float or None; if truthy, the final step is checked for NaN or
                max |u| > stable_tol
  @return
    out: (batch, steps + 1, N, 1), out[:, 0] is init
  @raise
    ValueError: if the stability check fails
  '''
  alpha = weno_scheme.get_scalar_alpha_batch(init, grad_fn, 100, 0.1) # (batch,)
  left_bound = jnp.zeros_like(init) # dummy, unused by the periodic roll
  right_bound = jnp.zeros_like(init) # dummy, unused by the periodic roll
  us = [init]
  for _ in range(steps):
    us.append(weno_scheme.weno_step_batch(dt, dx, us[-1], weno_scheme.get_w_classic, weno_roll.get_u_roll_periodic, fn,
                                          alpha, 'rk4', left_bound, right_bound))
  out = jnp.stack(us, axis = 1) # (batch, steps + 1, N, 1)
  # check if the solution is stable
  if stable_tol and (jnp.any(jnp.isnan(us[-1])) or jnp.max(jnp.abs(us[-1])) > stable_tol):
    print("sol instable", flush=True)
    # diagnosis: re-run every sample on its own (check disabled to avoid recursion into
    # this branch) and print the offenders. FIX: compare against stable_tol, the
    # threshold that triggered this branch, instead of a hard-coded 10.0.
    for init_i, this_init in enumerate(init):
      sol = generate_weno_scalar_sol(dx = dx, dt = dt, init = this_init[None,...], fn = fn, steps = steps, grad_fn = grad_fn, stable_tol=None)
      if jnp.any(jnp.isnan(sol)) or jnp.max(jnp.abs(sol)) > stable_tol:
        print(init_i)
        print(this_init)
    # always fail once the batch check tripped, even if no single sample reproduced it
    raise ValueError("sol contains nan")
  return out
def test_reconstruct(N, u_fn, u_integral_fn, get_w, get_u_roll):
    """Compare WENO interface reconstructions with point values of u on [0, 1].

    N: number of cells, so dx = 1/N and there are N + 1 cell edges
    u_fn: pointwise evaluation of u
    u_integral_fn: antiderivative of u, used to form the exact cell averages
    get_w, get_u_roll: forwarded unchanged to weno_scheme.weno_reconstruct
    returns: (u at the edges, minus-side reconstruction, plus-side reconstruction)
    """
    dx = 1/N
    edges = jnp.linspace(0, 1, N+1)
    point_values = u_fn(edges)[:, None]  # [N+1, 1]
    # exact cell average: antiderivative difference over each cell, divided by dx
    integral_diff = u_integral_fn(edges[1:]) - u_integral_fn(edges[:-1])
    cell_average = 1/dx * integral_diff[:, None]
    minus_all, plus_all = weno_scheme.weno_reconstruct(
        cell_average, get_w, get_u_roll, point_values[0], point_values[-1])
    # drop the last minus-side row and the first plus-side row
    return point_values, minus_all[:-1, :], plus_all[1:, :]
def split_equation(equation, lead_n=4):
    '''Split an equation tag into its identifier and its numeric parameters.

    equation: '_'-separated string, e.g. 'ode_auto_linear1_forward_0.806_0.047'
    lead_n: number of leading fields that make up the identifier
    returns: (identifier, params),
             e.g. ('ode_auto_linear1_forward', [0.806, 0.047])
    '''
    fields = equation.split('_')
    head, tail = fields[:lead_n], fields[lead_n:]
    # every field after the identifier must parse as a float
    return '_'.join(head), list(map(float, tail))
class FFN(nn.Module):
    """Fully connected network with GELU activations.

    Layout: in_dim -> dim_hidden, then `hidden_layers` layers of
    dim_hidden -> dim_hidden, then dim_hidden -> out_dim. GELU follows every
    layer except the last one.
    """

    def __init__(self, in_dim, dim_hidden, out_dim, hidden_layers):
        super().__init__()
        self.hidden_layers = hidden_layers
        # the (still empty) list is registered before lin1 so that the order of
        # sub-modules, and therefore of parameters, stays midlin, lin1, lin2
        self.midlin = nn.ModuleList()
        self.lin1 = nn.Linear(in_dim, dim_hidden)
        self.midlin.extend(
            [nn.Linear(dim_hidden, dim_hidden) for _ in range(self.hidden_layers)]
        )
        self.lin2 = nn.Linear(dim_hidden, out_dim)

    def forward(self, input):
        hidden = F.gelu(self.lin1(input))
        for layer in self.midlin:
            hidden = F.gelu(layer(hidden))
        return self.lin2(hidden)
class FNO(nn.Module):
    '''
    Wrapper around models.FNO (neuralop) taking (cond_k, cond_v, qoi_k) inputs.

    config keys read by this class:
      n_modes, hidden_channels, n_layers, in_channels, out_channels:
        forwarded to models.FNO; n_modes is wrapped into a 1-tuple
      cond_grid_idx_start, cond_grid_idx_end:
        slice of the last axis of cond_k that is appended to cond_v as
        extra input channel(s)
    '''
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fno = models.FNO(n_modes=(config['n_modes'],),
                              hidden_channels=config['hidden_channels'],
                              n_layers=config['n_layers'],
                              in_channels=config['in_channels'],
                              out_channels=config['out_channels'])

    def forward(self, cond_k, cond_v, qoi_k):
        """
        cond_k: (bs, cond_len, input_dim)
        cond_v: (bs, cond_len, input_dim)
        qoi_k: (bs, query_len, query_dim), accepted for interface
               compatibility but not read; the output is defined on the
               condition points
        return: (bs, cond_len, out_channels)
        """
        grid = cond_k[...,self.config['cond_grid_idx_start']:self.config['cond_grid_idx_end']]
        grid = grid.permute(0, 2, 1)  # (bs, grid_channels, n_points)
        x = cond_v.permute(0, 2, 1)
        x = torch.cat([x, grid], dim=1) # (bs, channels, n_points)
        out = self.fno(x) # (bs, out_channels, n_points)
        out = out.permute(0, 2, 1) # (bs, n_points, out_channels)
        return out
class CustomGPT2(nn.Module):
    '''
    Pre-trained GPT-2 built on the modified GPT2Model from models_gpt2_source,
    with all dropout rates set to 0 and a language-modeling head whose weight
    is tied to the token embedding.
    '''
    def __init__(self, model_name='gpt2'):
        super(CustomGPT2, self).__init__()

        # Load the pre-trained GPT-2 model and tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)

        gpt_config = GPT2Config.from_pretrained(model_name)
        gpt_config.resid_pdrop = 0
        gpt_config.attn_pdrop = 0
        gpt_config.embd_pdrop = 0
        gpt_config.summary_first_dropout = 0
        self.gpt2 = CustomGPT2Model.from_pretrained(model_name, config=gpt_config) # without LM Head
        # Define the language modeling head and tie its weights to the token embeddings
        self.lm_head = nn.Linear(self.gpt2.config.n_embd, self.tokenizer.vocab_size, bias=False)
        self.lm_head.weight = self.gpt2.wte.weight

    def forward(self, input_ids, attention_mask=None):
        '''
        input_ids: (batch_size, seq_length) token ids
        attention_mask: optional mask handed to the custom GPT-2 model as is
        return: (last_hidden_state, output), where output = lm_head(last_hidden_state)
        '''
        # manually add the positional embedding (the custom model does not do it)
        gpt2_embeddings = self.gpt2.wte(input_ids) # (batch_size, seq_length, hidden_size)
        # build the positions on the same device as input_ids, so that the
        # embedding lookup also works when the model is not on the CPU
        position_ids = torch.arange(0, input_ids.size(1), dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(input_ids.size(0), -1) # (batch_size, seq_length)
        position_embeddings = self.gpt2.wpe(position_ids) # (batch_size, seq_length, hidden_size)
        transformed_input = gpt2_embeddings + position_embeddings # (batch_size, seq_length, hidden_size)
        last_hidden_state = self.gpt2(inputs_embeds=transformed_input, attention_mask = attention_mask)[0]
        output = self.lm_head(last_hidden_state)

        return last_hidden_state, output
# Launch five background fine-tuning jobs for DeepONet on one GPU.
# Each job handles a range of 100 "bid" values and writes to its own log file.
gpu=0

icon_stamp=icon_lm_learn_20231005-094726-pde3-inverse
tune_stamp=icon_lm_deepo_20240121-203825-pde3-inverse
restore_dir=/home/shared/icon/save/user/ckpts/deepo_pretrain/20240121-203825
model_config=model_deepo_pde_config.json

for start in 0 100 200 300 400; do
  end=$((start + 100))
  CUDA_VISIBLE_DEVICES=$gpu python3 finetune.py --model_name deepo --model_config $model_config --icon_stamp $icon_stamp --tune_stamp $tune_stamp --restore_dir $restore_dir --tune_bid_range $start,$end >tune-$tune_stamp-$start-$end.log 2>&1 &
done

# printed right after the jobs are launched; the jobs keep running in the background
echo "Done"
# Launch five background fine-tuning jobs for FNO on one GPU.
# Each job handles a range of 100 "bid" values and writes to its own log file.
gpu=0

icon_stamp=icon_lm_learn_20231005-094726-pde3-inverse
tune_stamp=icon_lm_fno_20240121-203841-pde3-inverse
restore_dir=/home/shared/icon/save/user/ckpts/fno_pretrain/20240121-203841
model_config=model_fno_pde_config.json

for start in 0 100 200 300 400; do
  end=$((start + 100))
  CUDA_VISIBLE_DEVICES=$gpu python3 finetune.py --model_name fno --model_config $model_config --icon_stamp $icon_stamp --tune_stamp $tune_stamp --restore_dir $restore_dir --tune_bid_range $start,$end >tune-$tune_stamp-$start-$end.log 2>&1 &
done

# printed right after the jobs are launched; the jobs keep running in the background
echo "Done"
def get_error_tune(file_name):
    '''
    Mean absolute error of the fine-tuned predictions, one entry per recorded step.

    file_name: .npz file holding
      "label":     indexed as [:, 0, :, :], giving [bs, 100, 1]
      "pred_tune": indexed as [:, :, 0, :, :], giving [bs, step, 100, 1]
    return: (error_mean_list, error_std_list); for every step the mean and the
            standard deviation, over the batch, of the per-sample mean
            absolute error
    '''
    # read both arrays inside a context manager so that the archive is closed
    with np.load(file_name) as file:
        label = file["label"][:,0,:,:] # remove the dimension. [bs, 100, 1]
        pred_tune = file["pred_tune"][:,:,0,:,:] # [bs, step, 100, 1]

    error_mean_list = []
    error_std_list = []
    for i in range(pred_tune.shape[1]):
        pred = pred_tune[:,i,:,:]
        error = np.mean(np.abs(label - pred), axis = (-2,-1)) # [bs]
        error_mean_list.append(np.mean(error))
        error_std_list.append(np.std(error))
    return error_mean_list, error_std_list
def plot_error_vs_steps():
    '''
    Plot the average error against the number of fine-tuning steps for several
    numbers of examples and save the figure as error_vs_steps.pdf.
    '''
    plt.figure(figsize=(5,4))
    # 0, 10, ..., 90 followed by 100, 200, ..., 1000
    step_list = list(range(0, 100, 10)) + list(range(100, 1001, 100))
    # the entry for step 0 is skipped on the logarithmic axis
    shown_steps = step_list[1:]
    for demonum in (5, 10, 30, 100, 300):
        error_mean_list, error_std_list = get_error_tune(
            "tune_all_fno_data_fix_0.30_0.30_0.30.pkl_demonum_{}.npz".format(demonum))
        plt.semilogx(shown_steps, error_mean_list[1:], 'o--',
                     label = "{} examples".format(demonum))
    plt.legend()
    plt.xlabel('steps of fine-tuning')
    plt.ylabel('average error')
    plt.title("FNO, $f = 0.30u^3+0.30u^2+0.30u$")
    plt.tight_layout()
    plt.savefig("error_vs_steps.pdf")
def generate_weno_scalar_sol(dx, dt, init, fn, steps, grad_fn = None, stable_tol = None):
    '''
    Run the WENO scalar solver FLAGS.repeat times and report the wall time.

    init: (batch, N, 1)
    stable_tol: accepted but not used in this timing script
    return: the trajectory of the last repetition, (batch, steps + 1, N, 1)
    '''
    alpha = weno_scheme.get_scalar_alpha_batch(init, grad_fn, 100, 0.1) # (batch,)
    left_bound = jnp.zeros_like(init) # dummy
    right_bound = jnp.zeros_like(init) # dummy
    us = [init]
    # warm up
    warm = init
    for _ in range(100):
        warm = weno_scheme.weno_step_batch(dt, dx, us[-1], weno_scheme.get_w_classic, weno_roll.get_u_roll_periodic, fn,
                                           alpha, 'rk4', left_bound, right_bound)
    # JAX dispatches work asynchronously: wait for the warm-up to finish so
    # that it does not leak into the timed region
    jnp.asarray(warm).block_until_ready()
    # timing and repeat
    utils.timer.tic("sim")
    for _ in range(FLAGS.repeat):
        us = [init]
        for _ in range(steps):
            us.append(weno_scheme.weno_step_batch(dt, dx, us[-1], weno_scheme.get_w_classic, weno_roll.get_u_roll_periodic, fn,
                                                  alpha, 'rk4', left_bound, right_bound))
    # wait for the last step to be computed before stopping the timer;
    # otherwise the timer can stop before the computation has finished
    jnp.asarray(us[-1]).block_until_ready()
    utils.timer.toc("sim")
    print("repeat = {}, steps = {}, time = {:.3f}".format(FLAGS.repeat, steps, utils.timer.get_time("sim")))
    print("average time for each simulation {:.3f}".format(utils.timer.get_time("sim") / FLAGS.repeat))
    out = jnp.stack(us, axis = 1) # (batch, steps + 1, N, 1)
    return out
def simulate_conservation_weno_cubic(seed, length, steps, dt, bs):
    '''
    simulate the conservation law with cubic flux function using WENO scheme
    du/dt + d(a * u^3 + b * u^2 + c * u)/dx = 0
    with a = b = c = 1.0 fixed inside this function

    seed: seed of the random key used for the initial conditions
    length: number of grid points on [0, 1), so dx = 1 / length
    steps, dt: number of time steps and step size handed to the solver
    bs: number of initial conditions
    return: whatever generate_weno_scalar_sol returns, (bs, steps + 1, length, 1)
    '''
    rng = hk.PRNGSequence(jax.random.PRNGKey(seed))
    coeff_a = coeff_b = coeff_c = 1.0

    grid = jnp.linspace(0.0, 1.0, length, endpoint=False)
    samples = data_utils.generate_gaussian_process(
        next(rng), grid, bs, kernel = data_utils.rbf_circle_kernel_jax,
        k_sigma = 1.0, k_l = 1.0)
    init = samples[...,None] # (bs, N, 1)

    # flux and its derivative with respect to u
    fn = jax.jit(lambda u: coeff_a * u * u * u + coeff_b * u * u + coeff_c * u)
    grad_fn = jax.jit(lambda u: 3 * coeff_a * u * u + 2 * coeff_b * u + coeff_c)

    print("initial condition generated")
    return generate_weno_scalar_sol(dx = 1.0 / length, dt = dt, init = init, fn = fn,
                                    steps = steps, grad_fn = grad_fn, stable_tol = 10.0)
number of initial conditions 81 | 82 | eqn_mode = "fix_0.20_0.20_0.20", 83 | sol = simulate_conservation_weno_cubic(seed, length, steps, dt, sim_num) 84 | 85 | print(sol.shape, sol.dtype) 86 | # (out_bs,) (out_bs, out_examples, 100, 1) 0.20000000_0.20000000_0.20000000 87 | 88 | 89 | if __name__ == '__main__': 90 | FLAGS = flags.FLAGS 91 | flags.DEFINE_integer('repeat', 1000, 'repeat') 92 | app.run(main) 93 | 94 | 95 | ''' 96 | CUDA_VISIBLE_DEVICES="0" python3 datagen_time.py 97 | 98 | initial condition generated 99 | sim Took 16.3644 seconds 100 | repeat = 1000, steps = 200, time = 16.364 101 | average time for each simulation 0.016 102 | (1, 201, 100, 1) float32 103 | ''' -------------------------------------------------------------------------------- /icon-lm/plot_icon_lm/plot_benchmark.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib.patches as patches 6 | 7 | from jax.config import config 8 | import tensorflow as tf 9 | import os 10 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 11 | tf.config.set_visible_devices([], device_type='GPU') 12 | from collections import OrderedDict 13 | from pprint import pprint 14 | import jax.tree_util as tree 15 | from absl import app, flags, logging 16 | from plot_utils import get_error_from_dict 17 | 18 | FLAGS = flags.FLAGS 19 | 20 | 21 | def draw_join(demo_num_list, style_dict, folder_dict, title = None, plot = plt.plot, figsize=(4,3), ylim = (None, None)): 22 | 23 | all_error = {} 24 | for label, folder_list in folder_dict.items(): 25 | all_error[label] = [] 26 | 27 | for folder in folder_list: 28 | with open("{}/result_dict.pkl".format(folder), "rb") as file: 29 | result_dict = pickle.load(file) 30 | 31 | print("{}: {}".format(label, folder)) 32 | plot_key = [i[0] for i in result_dict.keys()] 33 | plot_key = list(OrderedDict.fromkeys(sorted(plot_key))) # remove duplicates and 
def plot_benchmark():
    '''
    compare ICON-LM with baseline, small models

    Every label maps to the result folders of three runs; draw_join is called
    with 1 to 5 examples.
    '''
    style_dict = {
        'ICON-LM (ours)': dict(line='-', color='red', marker='o'),
        'Encoder-Decoder ICON': dict(line='--', color='blue', marker='s'),
    }

    icon_lm_runs = [
        "/home/shared/icon/analysis/icon_lm_learn_s1-20231005-094726",
        "/home/shared/icon/analysis/icon_lm_learn_s2-20231008-110255",
        "/home/shared/icon/analysis/icon_lm_learn_s3-20231009-173624",
    ]
    baseline_runs = [
        "/home/shared/icon/analysis/v1baseline_s1-20231004-103259",
        "/home/shared/icon/analysis/v1baseline_s2-20231006-114142",
        "/home/shared/icon/analysis/v1baseline_s3-20231007-175037",
    ]
    folder_dict = {
        'ICON-LM (ours)': icon_lm_runs,
        'Encoder-Decoder ICON': baseline_runs,
    }

    draw_join(list(range(1, 6)), style_dict, folder_dict, figsize=(3.5,2.5))
result_dict = pickle.load(file) 31 | 32 | print("{}: {}".format(label, folder)) 33 | plot_key = [i[0] for i in result_dict.keys()] 34 | plot_key = list(OrderedDict.fromkeys(sorted(plot_key))) # remove duplicates and keep order 35 | 36 | relative_error_list = [] 37 | 38 | if ('ode_auto_const_forward', 'error', 1, 0) in result_dict: 39 | caption_id = 0 40 | else: 41 | caption_id = -1 42 | 43 | for key in plot_key: 44 | relative_error_list.append([get_error_from_dict(result_dict, key, demo_num, caption_id)[1] for demo_num in demo_num_list]) 45 | 46 | relative_error_list = np.array(relative_error_list) 47 | relative_error = np.mean(relative_error_list, axis = 0) 48 | all_error[label].append(relative_error) 49 | 50 | pprint(plot_key) 51 | print(tree.tree_map(lambda x: x.shape, all_error)) 52 | plt.figure(figsize=figsize) 53 | for label, folder_list in folder_dict.items(): 54 | error_mean = np.mean(all_error[label], axis = 0) 55 | error_std = np.std(all_error[label], axis = 0) 56 | plot(demo_num_list, error_mean, label= label, linestyle= style_dict[label]['line'], 57 | marker= style_dict[label]['marker'], markersize=7, color= style_dict[label]['color']) 58 | if len(all_error[label]) > 1: 59 | plt.fill_between(demo_num_list, error_mean - error_std, error_mean + error_std, alpha=0.2, color= style_dict[label]['color']) 60 | 61 | plt.xticks(demo_num_list) 62 | plt.xlabel('number of examples') 63 | plt.ylabel('relative error') 64 | # ymin = 0.01 65 | plt.ylim(ylim) 66 | # ax.set_ylim(figure_config[fi]['ylim']) 67 | # grid on 68 | plt.grid(True, which='both', axis='both', linestyle=':') 69 | plt.legend() 70 | if title is not None: 71 | plt.title(title) 72 | plt.tight_layout() 73 | plt.savefig('{}/ind_err_join_{}.pdf'.format(folder, title.replace(" ","_"))) 74 | print('saved to {}/ind_err_join_{}.pdf'.format(folder, title.replace(" ","_"))) 75 | else: 76 | plt.tight_layout() 77 | plt.savefig('{}/ind_err_join_err_cap.pdf'.format(folder)) 78 | print('saved to 
{}/ind_err_join_err_cap.pdf'.format(folder)) 79 | 80 | plt.close('all') 81 | 82 | 83 | 84 | def plot_hero_join_cap_vs_nocap(): 85 | 86 | stamp = FLAGS.ckpt 87 | 88 | demo_num_list = [0,1,2,3,4,5] 89 | style_dict = {'no caption': {'line': ':', 'color': 'red', 'marker': 'o'}, 90 | 'value caption': {'line': '--', 'color': 'blue', 'marker': '^'}, 91 | 'precise caption': {'line': '-', 'color': 'black', 'marker': 's'}, 92 | } 93 | folder_dict = { 94 | 'no caption': ["/home/shared/icon/analysis/{}-testdata-testcap-{}".format(stamp, 'nocap')], 95 | 'value caption': ["/home/shared/icon/analysis/{}-testdata-testcap-{}".format(stamp, 'vague')], 96 | 'precise caption': ["/home/shared/icon/analysis/{}-testdata-testcap-{}".format(stamp, 'precise')], 97 | } 98 | draw_join(demo_num_list, style_dict, folder_dict, 99 | title = None, plot = plt.semilogy, figsize=(3.5,2.5), ylim=(0.008,1)) 100 | 101 | 102 | def main(argv): 103 | del argv 104 | plot_hero_join_cap_vs_nocap() 105 | 106 | if __name__ == '__main__': 107 | 108 | FLAGS = flags.FLAGS 109 | flags.DEFINE_list 110 | flags.DEFINE_string('ckpt', "icon_gpt2_full_s1-20231014-194955", 'checkpoint for fine-tune GPT-2') 111 | 112 | app.run(main) 113 | 114 | 115 | -------------------------------------------------------------------------------- /icon-lm/plot_icon_lm/plot_scratch.py: -------------------------------------------------------------------------------- 1 | 2 | import pickle 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib.patches as patches 6 | 7 | from jax.config import config 8 | import tensorflow as tf 9 | import os 10 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 11 | tf.config.set_visible_devices([], device_type='GPU') 12 | from collections import OrderedDict 13 | from pprint import pprint 14 | import jax.tree_util as tree 15 | from absl import app, flags, logging 16 | from plot_utils import get_error_from_dict 17 | 18 | FLAGS = flags.FLAGS 19 | 20 | ''' 21 | This code will compare 
def get_join(demo_num_list, folder_dict, caption_id = 0):
    '''
    Collect the relative error, averaged over all problems, for every run.

    demo_num_list: demo counts to evaluate
    folder_dict: label -> list of result folders, each holding a result_dict.pkl
    caption_id: caption id of the predictions to read (default 0, the value
                that used to be hard-coded here)

    returns: dict label -> list with one array per folder; each array has one
             entry per element of demo_num_list
    '''
    all_error = {}
    plot_key = []  # stays empty (instead of unbound) when no folder is given

    for label, folder_list in folder_dict.items():
        all_error[label] = []

        for folder in folder_list:
            with open("{}/result_dict.pkl".format(folder), "rb") as file:
                result_dict = pickle.load(file)

            print("{}: {}".format(label, folder))
            # the first element of every key is the problem name; sorted + unique
            plot_key = sorted({key[0] for key in result_dict})

            # rows: problems, columns: demo numbers
            relative_error_list = np.array([
                [get_error_from_dict(result_dict, key, demo_num, caption_id)[1]
                 for demo_num in demo_num_list]
                for key in plot_key
            ])
            # average over problems -> one value per demo number
            all_error[label].append(np.mean(relative_error_list, axis = 0))

    pprint(plot_key)
    for k, v in all_error.items():
        print(k, v)

    return all_error
["/home/shared/icon/analysis/icon_gpt2_full_s1-{}-testdata-testcap-{}".format(pretrained_stamp, cap)], 72 | 'unpretrained GPT-2, training captions': ["/home/shared/icon/analysis/icon_gpt2_unpretrained-s1-{}-testdata-traincap-{}".format(unpretrained_stamp, cap)], 73 | 'unpretrained GPT-2, testing captions': ["/home/shared/icon/analysis/icon_gpt2_unpretrained-s1-{}-testdata-testcap-{}".format(unpretrained_stamp, cap)], 74 | } 75 | all_error = get_join(demo_num_list, folder_dict) 76 | 77 | ''' 78 | plt.figure(figsize=(3.5, 2.5)) 79 | plt.bar(0, all_error['pretrained GPT-2, training captions'][0], color = 'red', width = 0.2, label = 'pretrained GPT-2, training captions') 80 | plt.bar(0.2, all_error['pretrained GPT-2, testing captions'][0], color = 'red', width = 0.2, label = 'pretrained GPT-2, testing captions') 81 | plt.bar(0.4, all_error['unpretrained GPT-2, training captions'][0], color = 'blue', width = 0.2, label = 'unpretrained GPT-2, training captions') 82 | plt.bar(0.6, all_error['unpretrained GPT-2, testing captions'][0], color = 'blue', width = 0.2, label = 'unpretrained GPT-2, testing captions') 83 | plt.legend(loc = 'upper left') 84 | plt.xticks([0, 0.2, 0.4, 0.6], ['pretrained GPT-2', 'pretrained GPT-2', 'unpretrained GPT-2', 'unpretrained GPT-2']) 85 | plt.ylabel('Relative error') 86 | plt.ylim([0, 0.05]) 87 | plt.tight_layout() 88 | plt.savefig('plot_unpretrained_vs_pretrained.pdf') 89 | ''' 90 | 91 | 92 | def main(argv): 93 | del argv 94 | plot_train_vs_test() 95 | 96 | if __name__ == '__main__': 97 | 98 | FLAGS = flags.FLAGS 99 | 100 | app.run(main) 101 | 102 | 103 | -------------------------------------------------------------------------------- /icon-lm/plot_icon_lm/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | label_map = { 8 | "mfc_gparam_hj_forward11": {'legend': "MFC $g$-parameter 1D 
-> 1D", 'linestyle': '-', 'marker': 'o', 'xlabel': r"$x$"}, 9 | "mfc_gparam_hj_forward12": {'legend': "MFC $g$-parameter 1D -> 2D", 'linestyle': '--', 'marker': 'o', 'xlabel': r"$x$"}, 10 | "mfc_gparam_hj_forward22": {'legend': "MFC $g$-parameter 2D -> 2D", 'linestyle': ':', 'marker': 'o', 'xlabel': r"$x$"}, 11 | "mfc_rhoparam_hj_forward11": {'legend': r"MFC $\rho_0$-parameter 1D -> 1D", 'linestyle': '-', 'marker': 's', 'xlabel': r"$x$"}, 12 | "mfc_rhoparam_hj_forward12": {'legend': r"MFC $\rho_0$-parameter 1D -> 2D", 'linestyle': ':', 'marker': 's', 'xlabel': r"$x$"}, 13 | "mfc_gparam_solve_forward": {'legend': "MFC 1", 'linestyle': '-', 'marker': 'o', 'xlabel': r"$x$"}, 14 | "mfc_rhoparam_solve_forward": {'legend': "MFC 2", 'linestyle': '--', 'marker': 'o', 'xlabel': r"$x$"}, 15 | "ode_auto_const_forward": {'legend': "Forward problem of ODE 1", 'linestyle': '-', 'marker': 'o', 'xlabel': r"$t$"}, 16 | "ode_auto_const_inverse": {'legend': "Inverse problem of ODE 1", 'linestyle': ':', 'marker': 'o', 'xlabel': r"$t$"}, 17 | "ode_auto_linear1_forward": {'legend': "Forward problem of ODE 2", 'linestyle': '-', 'marker': 's', 'xlabel': r"$t$"}, 18 | "ode_auto_linear1_inverse": {'legend': "Inverse problem of ODE 2", 'linestyle': ':', 'marker': 's', 'xlabel': r"$t$"}, 19 | "ode_auto_linear2_forward": {'legend': "Forward problem of ODE 3", 'linestyle': '-', 'marker': 'd', 'xlabel': r"$t$"}, 20 | "ode_auto_linear2_inverse": {'legend': "Inverse problem of ODE 3", 'linestyle': ':', 'marker': 'd', 'xlabel': r"$t$"}, 21 | "pde_poisson_spatial_forward": {'legend': "Forward Poisson equation", 'linestyle': '-', 'marker': 'o', 'xlabel': r"$x$"}, 22 | "pde_poisson_spatial_inverse": {'legend': "Inverse Poisson equation", 'linestyle': ':', 'marker': 'o', 'xlabel': r"$x$"}, 23 | "pde_porous_spatial_forward" : {'legend': "Forward linear reaction-diffusion", 'linestyle': '-', 'marker': 's', 'xlabel': r"$x$"}, 24 | "pde_porous_spatial_inverse": {'legend': "Inverse linear 
def calculate_error(pred, label, mask):
    '''
    Masked mean error and relative error between prediction and ground truth.

    pred: [..., len, 1]
    label: [..., len, 1]
    mask: [..., len], nonzero entries mark the positions that are counted

    returns: (error, relative_error)
      error: mean over the unmasked positions of the last-axis norm of (pred - label)
      relative_error: error divided by the masked mean of the last-axis norm of label
    '''
    # accept anything array-like (lists, jax arrays), not only numpy arrays
    pred = np.asarray(pred)
    label = np.asarray(label)
    mask = np.asarray(mask).astype(bool)

    pointwise_error = np.linalg.norm(pred - label, axis = -1) # [..., len]
    error = np.mean(pointwise_error, where = mask)
    gt_norm_mean = np.mean(np.linalg.norm(label, axis = -1), where = mask)
    relative_error = error/gt_norm_mean
    return error, relative_error

def pattern_match(patterns, name):
    '''Return True if at least one entry of patterns is a substring of name.'''
    return any(pattern in name for pattern in patterns)

def get_error_from_dict(result_dict, key, demo_num, caption_id):
    '''
    Look up prediction, ground truth and mask of problem `key` in result_dict
    and return (error, relative_error) as computed by calculate_error.

    The prediction is stored under (key, 'pred', demo_num, caption_id), the
    ground truth under (key, 'ground_truth') and the mask under (key, 'mask').
    '''
    return calculate_error(result_dict[(key, 'pred', demo_num, caption_id)],
                           result_dict[(key, 'ground_truth')],
                           result_dict[(key, 'mask')])



@jax.jit
def laplace_u(u, dx):
    '''
    Finite-difference second derivative of the 1D array u on a uniform grid
    with spacing dx. The result has the same length as u.

    Interior points use the central 3-point stencil; the two end points use
    one-sided 4-point stencils, so u needs at least 4 entries.
    '''
    uxx = (u[:-2] + u[2:] - 2*u[1:-1])/dx**2
    uxx_left = (2 * u[0] - 5 * u[1] + 4 * u[2] - u[3])/dx**2
    uxx_right = (2 * u[-1] - 5 * u[-2] + 4 * u[-3] - u[-4])/dx**2
    # put the one-sided boundary values in front of / behind the interior values
    uxx = jnp.pad(uxx, (1, 1), mode='constant', constant_values = (uxx_left, uxx_right))
    return uxx
None))) 69 | -------------------------------------------------------------------------------- /icon-lm/plot_icon_weno/cubic_approx.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import curve_fit 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | # Define the cubic function to fit 7 | def cubic_function(x, a, b, c, d): 8 | return a*x**3 + b*x**2 + c*x + d 9 | 10 | def fix_x_in_range(xmin, xmax, function): 11 | # Generate x values in the range [-1, 1] 12 | x_values = np.linspace(xmin, xmax, 100) 13 | # Compute y values for these x values using the function 14 | y_values = function(x_values) 15 | # Fit the cubic function to the original function 16 | parameters, _ = curve_fit(cubic_function, x_values, y_values) 17 | return parameters 18 | 19 | # a, b, c, d = -0.5, 1.5, 1.2, 0.8 20 | a, b, c, d = 1, -1, 1, 1 21 | original_function = lambda x: a * np.sin(c * x) + b * np.cos(d * x) 22 | original_range = [-1,2] 23 | 24 | ''' 25 | apply change of variable, new_x = (x - xmin) / (xmax - xmin) * (new_xmax - new_xmin) + new_xmin 26 | so that new_x is in the range [new_xmin, new_xmax] 27 | x = (new_x - new_xmin) / (new_xmax - new_xmin) * (xmax - xmin) + xmin 28 | then new_f(new_x) = f(x) = f((new_x - new_xmin) / (new_xmax - new_xmin) * (xmax - xmin) + xmin) 29 | ''' 30 | new_range_list = [[-1,1], [-1,3]] 31 | plt.figure(figsize=(4*2, 3)) 32 | i = 0 33 | for new_range in new_range_list: 34 | i+=1 35 | plt.subplot(1,2,i) 36 | xmin, xmax = original_range 37 | new_xmin, new_xmax = new_range 38 | new_function = lambda new_x: original_function((new_x - new_xmin) / (new_xmax - new_xmin) * (xmax - xmin) + xmin) 39 | 40 | x_plot = np.linspace(new_xmin, new_xmax, 100) 41 | plt.plot(x_plot, new_function(x_plot), color='black') 42 | params = fix_x_in_range(new_xmin, new_xmax, new_function) 43 | plt.plot(x_plot, cubic_function(x_plot, *params), label='{:.3f}x^3 + {:.3f}x^2 + {:.3f}x + {:.3f}'.format(*params), 
# Define the cubic function to fit
def cubic_function(x, a, b, c, d):
    '''Cubic polynomial a*x^3 + b*x^2 + c*x + d, the model handed to curve_fit.'''
    return a*x**3 + b*x**2 + c*x + d


def fix_x_in_range(xmin, xmax, function=None):
    '''
    Least-squares fit of cubic_function on the interval [xmin, xmax].

    function: callable to approximate; when omitted, the module-level
              original_function that is current at call time is used
              (the behavior of the former two-argument version).

    returns: the fitted coefficients (a, b, c, d)
    '''
    if function is None:
        function = original_function
    # 100 equally spaced sample points in [xmin, xmax]
    x_values = np.linspace(xmin, xmax, 100)
    y_values = function(x_values)
    parameters, _ = curve_fit(cubic_function, x_values, y_values)
    return parameters


def _plot_cubic_comparison(name, function, reference_params, reference_label, fit_label):
    '''
    Draw `function` on [-2, 2] together with a fixed reference cubic and the
    cubic fits obtained on [-1, 1] and [-2, 2], then save the figure as PDF.

    fit_label: callable mapping the fitted coefficients to a legend string
    '''
    plt.figure(figsize=(5, 4))
    x_plot = np.linspace(-2, 2, 100)
    plt.plot(x_plot, function(x_plot), label=name, color='black')

    plt.plot(x_plot, cubic_function(x_plot, *reference_params), label=reference_label,
             color='green', linestyle='--')

    for (xmin, xmax), color in (((-1, 1), 'blue'), ((-2, 2), 'red')):
        params = fix_x_in_range(xmin, xmax, function)
        plt.plot(x_plot, cubic_function(x_plot, *params), label=fit_label(params),
                 color=color, linestyle='--')

    plt.xlabel('$u$')
    plt.ylabel('$f$')
    # mark the boundaries of the inner fitting interval [-1, 1]
    plt.vlines(1, -2, 2, linestyle=':', color='black')
    plt.vlines(-1, -2, 2, linestyle=':', color='black')
    plt.xlim(-2,2)
    plt.ylim(-2,2)
    plt.legend()
    plt.savefig(f'cubic_approx_{name[1:-1].replace(" ","_")}.pdf')


name = '$f = sin(u) - cos(u)$'
def original_function(x):
    return np.sin(x) - np.cos(x)

_plot_cubic_comparison(
    name, original_function,
    reference_params=[-1/6, 0.5, 1, -1],
    reference_label='$f = -1/6u^3 + 1/2u^2 + u - 1$',
    # the constant term carries its own sign so a non-negative value still reads "+ d"
    fit_label=lambda p: '$f = {:.3f}u^3 + {:.3f}u^2 + {:.3f}u {:+.3f}$'.format(*p),
)


name = '$f = tanh(u)$'
def original_function(x):
    return np.tanh(x)

_plot_cubic_comparison(
    name, original_function,
    reference_params=[-1/3, 0, 1, 0],
    reference_label='$f = - 1/3u^3 + u$',
    # only the odd coefficients are shown in the legend
    fit_label=lambda p: '$f = {:.3f}u^3 + {:.3f}u$'.format(p[0], p[2]),
)
def get_ylim(demo, quest):
    '''
    y-axis limits that contain both arrays, with 10% of the value range added
    below and 45% added above (the extra room on top is where the legend sits).
    '''
    ymin = min(np.min(demo), np.min(quest))
    ymax = max(np.max(demo), np.max(quest))
    gap = (ymax - ymin)
    return ymin - gap * 0.1, ymax + gap * 0.45

def draw_profile(eqn_name, a, b, c, demo_num, title, case_idx = 0, result_dir = None):
    '''
    Plot condition and QoI profiles of one test case and save them as a PDF.

    eqn_name, a, b, c: identify the problem; they form the prefix of the keys
                       in the stored dictionaries
    demo_num: number of demos drawn, and the demo count of the prediction read
    title: figure title
    case_idx: index of the case inside the batch
    result_dir: folder holding result_dict.pkl (and consistency_dict.pkl for
                backward problems); defaults to the module-level `folder`
    '''
    if result_dir is None:
        result_dir = folder  # module-level value set by the __main__ block

    is_backward = 'backward' in eqn_name
    caption_id = -1
    prefix = (eqn_name, a, b, c)

    with open("{}/result_dict.pkl".format(result_dir), "rb") as file:
        result_dict = pickle.load(file)

    x = result_dict[prefix + ('cond_k',)] # (bs, 100, 3)
    cond = result_dict[prefix + ('cond_v',)] # (bs, 100, 1)
    qoi = result_dict[prefix + ('ground_truth',)] # (bs, 100, 1)
    pred = result_dict[prefix + ('pred', demo_num, caption_id)] # (bs, 100, 1)
    demo_cond_v = result_dict[prefix + ('demo_cond_v',)] # (bs, 5, 100, 1)
    demo_qoi_v = result_dict[prefix + ('demo_qoi_v',)] # (bs, 5, 100, 1)

    print("x.shape", x.shape, "cond.shape", cond.shape, "qoi.shape", qoi.shape, "pred.shape", pred.shape,
          "demo_cond_v.shape", demo_cond_v.shape, "demo_qoi_v.shape", demo_qoi_v.shape)

    if is_backward:
        # the forward simulation is only drawn for backward problems
        with open("{}/consistency_dict.pkl".format(result_dir), "rb") as file:
            consistency_dict = pickle.load(file)
        forward = consistency_dict[prefix + ('forward', demo_num, caption_id)] # (bs, 100, 1)

    xs = x[case_idx,:,-1]  # last component of the key is used as the x coordinate

    plt.figure(figsize=(1 * 3.5, 2 * 3)) # 2 rows, 1 columns
    plt.subplot(2, 1, 1) # condition
    for i in range(demo_num):
        plt.plot(xs, demo_cond_v[case_idx, i,:,0], linestyle = ':', alpha = 0.8)
    plt.plot(xs, cond[case_idx,:,0], 'k-', label = "question condition")
    if is_backward:
        plt.plot(xs, forward[case_idx,:,0], 'b--', label = "forward simulation")
    plt.xlabel('$x$')
    plt.ylabel('$u$')
    ymin, ymax = get_ylim(demo_cond_v[case_idx], cond[case_idx])
    plt.ylim(ymin, ymax)
    plt.legend(loc = 'upper right')
    plt.title('Reverse Condition' if is_backward else 'Forward Condition')

    plt.subplot(2, 1, 2) # QoI
    for i in range(demo_num):
        plt.plot(xs, demo_qoi_v[case_idx, i,:,0], linestyle = ':', alpha = 0.8)
    if 'forward' in eqn_name:
        plt.plot(xs, qoi[case_idx,:,0], 'k-', label = "ground truth")
        plt.plot(xs, pred[case_idx,:,0], 'r--', label = "prediction")
    else:
        # no ground-truth curve is drawn for non-forward problems
        plt.plot(xs, pred[case_idx,:,0], 'r-', label = "prediction")
    plt.xlabel('$x$')
    plt.ylabel('$u$')
    ymin, ymax = get_ylim(demo_qoi_v[case_idx], qoi[case_idx])
    plt.ylim(ymin, ymax)
    plt.legend(loc = 'upper right')
    plt.title('Reverse QoI' if is_backward else 'Forward QoI')

    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig('{}/profile_{}_{}_{}_{}_demo{}_case{}.pdf'.format(result_dir, eqn_name, a, b, c, demo_num, case_idx))
    plt.close('all')

if __name__ == "__main__":
    folder = "/home/shared/icon/analysis/icon_weno_20231209-222440_light"

    # raw strings: "\p" is not a valid escape sequence in a normal string literal
    cases = [
        ((0.5, 0.5, 0.5), r"$\partial_t u + \partial_x(0.5u^3 + 0.5u^2 + 0.5u) = 0$"),
        ((-0.5, -0.5, -0.5), r"$\partial_t u + \partial_x(-0.5u^3 - 0.5u^2 - 0.5u) = 0$"),
    ]
    for (a, b, c), title in cases:
        for case_idx in range(5):
            for eqn_name in ('conservation_weno_cubic_forward', 'conservation_weno_cubic_backward'):
                draw_profile(eqn_name = eqn_name, a = a, b = b, c = c, demo_num = 5, title = title, case_idx = case_idx)
--train_data_globs 'train*' --test_data_globs 'test*' \ 19 | --test_demo_num_list 1,3,5 --model icon_lm --loss_mode nocap \ 20 | --nodeterministic --seed 1 --vistest --tfboard 21 | 22 | 23 | #GPT-2 (Multi-Modal): 24 | #Add `--unpretrained` option to start from an unpretrained GPT-2 model. 25 | python3 run.py --problem 'icon_gpt2' --epochs 100 \ 26 | --train_batch_size 10 --train_data_dirs '/home/shared/icon/data/data0910c' \ 27 | --model_config_filename 'model_gpt2_config.json' \ 28 | --train_config_filename 'train_lm_config.json' \ 29 | --test_config_filename 'test_lm_config.json' \ 30 | --train_data_globs 'train*' --test_data_globs 'test*' \ 31 | --test_demo_num_list 0,1,3,5 --test_caption_id_list -1,0 \ 32 | --model gpt2 --backend torch --loss_mode nocap,cap \ 33 | --nodeterministic --trainable_mode all --seed 1 --vistest --tfboard 34 | 35 | # Pretrain DeepONet on Problem #14: 36 | CUDA_VISIBLE_DEVICES=0 python3 run.py --problem 'deepo_pretrain' --epochs 10 \ 37 | --train_batch_size 32 --train_data_dirs '/home/shared/icon/data/data0910c' \ 38 | --train_data_globs 'train_pde_cubic_spatial_inverse*' \ 39 | --test_data_globs 'test_pde_cubic_spatial_inverse*' \ 40 | --model_config_filename 'model_deepo_pde_config.json' \ 41 | --train_config_filename 'train_lm_pde_full_config.json' \ 42 | --test_config_filename 'test_lm_pde_full_config.json' \ 43 | --test_demo_num_list 1,3,5 --model deepo \ 44 | --backend torch --loss_mode demo_quest \ 45 | --nodeterministic --trainable_mode all --seed 1 --vistest --tfboard 46 | 47 | # Pretrain FNO on Problem #14: 48 | CUDA_VISIBLE_DEVICES=0 python3 run.py --problem 'fno_pretrain' --epochs 10 \ 49 | --train_batch_size 32 --train_data_dirs '/home/shared/icon/data/data0910c' \ 50 | --train_data_globs 'train_pde_cubic_spatial_inverse*' \ 51 | --test_data_globs 'test_pde_cubic_spatial_inverse*' \ 52 | --model_config_filename 'model_fno_pde_config.json' \ 53 | --train_config_filename 'train_lm_pde_full_config.json' \ 54 | 
# Conservation Law L2 loss with batch size 8 for forward and 8 for reverse
# set "--train_batch_size 2n" for batch size n for forward and n for reverse
# (test config: config_data/ only ships test_lm_weno_config.json)
python3 run.py --problem 'icon_weno' --epochs 100 --train_batch_size 16 \
  --train_data_dirs '/home/shared/icon/data/data0904_weno_cubic' \
  --test_data_dirs '/home/shared/icon/data/data0904_weno_cubic' \
  --model_config_filename 'model_lm_config.json' \
  --train_config_filename 'train_lm_weno_config.json' \
  --test_config_filename 'test_lm_weno_config.json' \
  --test_demo_num_list 0,1,3,5 --model icon_lm --loss_mode nocap \
  --vistest --nodeterministic --tfboard



# Conservation Law with batch size 4 for forward and 4 for consistency loss
python3 run.py --problem 'icon_weno' --epochs 100 --train_batch_size 4 \
  --train_data_dirs '/home/shared/icon/data/data0904_weno_cubic' \
  --test_data_dirs '/home/shared/icon/data/data0904_weno_cubic' \
  --train_data_globs 'train*forward*' --test_data_globs 'test*forward*' \
  --model_config_filename 'model_lm_config.json' \
  --train_config_filename 'train_lm_weno_config.json' \
  --test_config_filename 'test_lm_weno_config.json' \
  --test_demo_num_list 0,1,3,5 --model icon_lm --loss_mode consist \
  --vistest --nodeterministic --tfboard
FROM ubuntu:22.04 AS base

# UID/GID for the in-container user, supplied with --build-arg at build time
ARG USER_ID
ARG GROUP_ID

RUN addgroup --gid $GROUP_ID user
RUN adduser --disabled-password --gecos '' --uid $USER_ID --gid $GROUP_ID user

# Remove the apt lists in the same layer that creates them; a separate RUN
# step cannot shrink a layer that has already been committed.
RUN apt-get update \
    && apt-get install -y --fix-missing git python3 python3-pip \
    && rm -rf /var/lib/apt/lists/*

USER user
RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> $HOME/.bashrc

# ENV persists for all later build steps and for the running container;
# a variable exported inside a RUN step is gone when that step ends.
ENV PIP_DEFAULT_TIMEOUT=100

RUN pip install --upgrade pip

RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip install numpy scipy absl-py dm-haiku flax einops einshape jupyter matplotlib opt-einsum optax tensorflow pytz tqdm gymnasium gymnax brax pgx jumanji evosax distrax
RUN pip install torch torchvision torchaudio transformers --index-url https://download.pytorch.org/whl/cu118

WORKDIR /workspace
You may also want to replace `xxx`): 17 | 18 | ``` 19 | docker run --gpus all --rm -itd --name xxx -v $(pwd):/workspace/ repo:tag bash 20 | ``` 21 | 22 | To attach to the Docker container (replace `xxx` with the one you specified): 23 | 24 | ``` 25 | docker attach xxx 26 | ``` 27 | 28 | To prevent disruptions due to timeouts, the following line is included in the Dockerfile: 29 | 30 | ``` 31 | RUN export PIP_DEFAULT_TIMEOUT=100 32 | ``` 33 | 34 | Please note that this Docker image does not have CUDA and cuDNN installed outside of pip. Therefore, TensorFlow will not be able to utilize the GPU. However, JAX and PyTorch will still be able to utilize the GPU. Since only TensorBoard and `tf.data.Dataset` are used from TensorFlow, GPU usage for TensorFlow is not necessary. 35 | 36 | ### Pip 37 | 38 | The pip install commands can be found in the Dockerfile. 39 | 40 | ## Data Preparation 41 | 42 | The code for data preparation is located in the `data_generation/` folder. To generate the training data, in-distribution testing data, and out-of-distribution testing data, navigate to the `data_generation/` folder and execute the following commands: 43 | 44 | ``` 45 | bash datagen.sh # training data and in-distribution testing data 46 | bash datagen_ood.sh # out-of-distribution testing data, including equations of new forms 47 | ``` 48 | 49 | ## Training 50 | 51 | The code for training is located in the current folder. 52 | 53 | To train with all training data, use the following command: 54 | 55 | ``` 56 | python3 run.py --problem hero --num_heads 8 --num_layers 6 --hidden_dim 256 --train_batch_size 32 --epochs 100 --train_data_dirs './data_generation/data0511a' --k_dim 3 --k_mode itx --tfboard --plot_num 16 57 | ``` 58 | 59 | In the paper, the effect of different training datasets was also studied. Execute `bash run_group.sh` to perform these experiments. 60 | 61 | ## Analysis 62 | 63 | The code for analysis is located in the `analysis/` folder. 
# In-distribution analysis: runs analysis_accelerate.py once per demo count
# (1..5) against the "hero" checkpoint and writes one log file per run.
problem="hero"
stamp="20230515-094404"
step=1000000
testdata="data0511a"
savedir="analysis0511a-v4-ind"

# you can comment out "--save_raw --save_prompt --batch_size 1 \" to speed up
# but you will not be able to plot the raw data then
# if you keep it, for each demo_num_begin, it takes about 90 seconds to finish

for demo_num_begin in 1 2 3 4 5; do
  # each run covers the half-open range [demo_num_begin, demo_num_begin + 1)
  demo_num_end=$((demo_num_begin + 1))
  logfile="out-analysis-ind-$stamp-$step-$testdata-$demo_num_begin-$demo_num_end.log"

  python3 analysis_accelerate.py --analysis_dir "$savedir" --task ind \
    --problem "$problem" --stamp "$stamp" --step "$step" \
    --test_data_dirs "$testdata" --test_data_globs "test*" \
    --k_dim 3 --k_mode itx \
    --demo_num_begin "$demo_num_begin" --demo_num_end "$demo_num_end" \
    --save_raw --save_prompt --batch_size 1 \
    >"$logfile" 2>&1
done

echo "Done"
# Out-of-distribution analysis driver. Three sections, each calling
# analysis_accelerate.py with --task ood and writing one log file per run:
#   1. every checkpoint in problems/stamps/steps on every test dataset
#   2. the "real_op" / "fake_op" modes on the new-term ODE dataset
#   3. the "fake_demo" mode on the new-term ODE dataset
savedir="../analysis/analysis0521-ood"
suffix="_light" # for light dataset
# suffix="" # for heavy dataset
figs=10
testdatas=(
  "data0520_ood_odeconst$suffix"
  "data0520_ood_odelinear1$suffix"
  "data0520_ood_odelinear2$suffix"
  "data0520_ood_pdeporous_randbdry$suffix"
  "data0520_ood_seriesdamped$suffix"
  "data0520_nt_odelinear3$suffix"
  "data0520_nt_ot_rho1param$suffix"
)

### =============================This is for k_mode = 'itx'======================

# problems / stamps / steps are parallel arrays: entry i describes checkpoint i.
# The per-run scalar is called `step`, the array `steps`; using one name for
# both would overwrite element 0 of the array on every iteration.
problems=(
  "hero"
  "group-3-itx" "group-3-itx" "group-3-itx" "group-3-itx"
)
stamps=(
  "20230515-094404" # hero v4
  "20230522-234219" "20230524-165312" "20230525-010119" "20230525-090937"
)
steps=(
  1000000
  200000 200000 200000 200000
)

for index in "${!problems[@]}"
do
  problem=${problems[$index]}
  stamp=${stamps[$index]}
  step=${steps[$index]}
  echo "index: $index" "problem: $problem" "stamp: $stamp" "step: $step"
  for testdata in "${testdatas[@]}"
  do
    for demo_num_begin in 5
    do
      echo "testdata: $testdata" "demo_num_begin: $demo_num_begin"
      demo_num_end=$((demo_num_begin + 1))
      python3 analysis_accelerate.py --analysis_dir "$savedir" --task ood --figs "$figs" \
        --problem "$problem" --stamp "$stamp" --step "$step" \
        --test_data_dirs "$testdata" --test_data_globs "test*" \
        --demo_num_begin "$demo_num_begin" --demo_num_end "$demo_num_end" \
        --test_config_filename "test_config_ood.json" \
        --k_dim 3 --k_mode itx \
        >"out-analysis-ood-$problem-$stamp-$step-$testdata-$demo_num_begin-$demo_num_end.log" 2>&1
    done
  done
done


## =============================fake operator ======================

testdatas=("data0520_nt_odelinear3$suffix")

# No --problem/--stamp/--step/--k_dim here: these runs use whatever defaults
# analysis_accelerate.py defines for those flags.
for testdata in "${testdatas[@]}"
do
  for mode in "real_op" "fake_op"
  do
    echo "testdata: $testdata" "mode: $mode"
    python3 analysis_accelerate.py --analysis_dir "$savedir" --task ood --mode "$mode" --figs "$figs" \
      --test_data_dirs "$testdata" --test_data_globs "test*" \
      --demo_num_begin 5 --demo_num_end 6 \
      --test_config_filename "test_config_ood.json" \
      >"out-analysis-ood-$mode-$testdata.log" 2>&1
  done
done

## =============================fake demo======================

testdatas=("data0520_nt_odelinear3$suffix")

problems=(
  "hero"
  "group-3-itx"
)
stamps=(
  "20230515-094404" # hero v4
  "20230522-234219"
)
steps=(
  1000000
  200000
)

for testdata in "${testdatas[@]}"
do
  for index in "${!problems[@]}"
  do
    problem=${problems[$index]}
    stamp=${stamps[$index]}
    step=${steps[$index]}
    echo "fake demo, index: $index" "problem: $problem" "stamp: $stamp" "step: $step"
    for demo_num_begin in 5
    do
      echo "testdata: $testdata" "demo_num_begin: $demo_num_begin"
      demo_num_end=$((demo_num_begin + 1))
      python3 analysis_accelerate.py --analysis_dir "$savedir" --task ood --mode fake_demo --batch_size 50 --figs "$figs" \
        --problem "$problem" --stamp "$stamp" --step "$step" \
        --test_data_dirs "$testdata" --test_data_globs "test*" \
        --demo_num_begin "$demo_num_begin" --demo_num_end "$demo_num_end" \
        --test_config_filename "test_config_ood.json" \
        --k_dim 3 --k_mode itx \
        >"out-analysis-ood-$problem-$stamp-$step-fake_demo-$testdata-$demo_num_begin-$demo_num_end.log" 2>&1
    done
  done
done

echo "Done"
{'marker': 'o'},\n", 45 | " 2: {'marker': 's'},\n", 46 | " 3: {'marker': 'd'},\n", 47 | " 4: {'marker': '^'},\n", 48 | " 5: {'marker': '*'},\n", 49 | " }\n", 50 | "\n", 51 | "datadir = 'data0511a'\n", 52 | "run = 'hero_v4'" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "equation = \"mfc_gparam_hj_forward22\"\n", 62 | "plt.figure(figsize=(4,3))\n", 63 | "for demo_num_begin in [1,2,3,4,5]:\n", 64 | " demo_num_end = demo_num_begin + 1\n", 65 | " len_list = [10, 20, 30, 40, 50, 60, 80, 100, 200, 500]\n", 66 | " rerr_list = []\n", 67 | "\n", 68 | " for this_len in len_list:\n", 69 | " this_analysis_folder = f\"{analysis_folder}-{this_len}-{this_len}-{this_len}-2600\"\n", 70 | " with open(f\"{this_analysis_folder}/err_{runs[run]['stamp']}_{datadir}_{demo_num_begin}_{demo_num_end}.pickle\", 'rb') as file:\n", 71 | " results = pickle.load(file)\n", 72 | " rerr_list.append(results[equation][\"relative_error_mean\"])\n", 73 | " plt.loglog(len_list, rerr_list, label=f\"# demos = {demo_num_begin}\", linestyle = 'dotted',\n", 74 | " marker= figure_config[demo_num_begin]['marker'], markersize=4)\n", 75 | " print(rerr_list)\n", 76 | "\n", 77 | "fill_y = np.linspace(0.001,0.2, 100)\n", 78 | "plt.fill_betweenx(fill_y, 0 * fill_y + 40, 0 * fill_y + 50, alpha=0.3, color='red', label='training region')\n", 79 | "plt.xlabel('number of key-value pairs in each condition/QoI')\n", 80 | "plt.ylabel('relative error')\n", 81 | "# plt.xticks(len_list, [str(this_len) for this_len in len_list])\n", 82 | "plt.legend()\n", 83 | "plt.ylim(0.005,0.11)\n", 84 | "plt.tight_layout()\n", 85 | "plt.savefig(f'{this_analysis_folder}/length_err_{equation}.pdf')\n", 86 | "plt.close('all')" 87 | ] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "jaxcuda", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 
# Plot styling and result-file stamp for every run that can be drawn.
# Keys are the legend labels accepted by --runs_plot.
runs = {"all data": {"stamp": "20230515-094404_1000000", "line": "solid", "color": "grey", "alpha": 1.0},
        "wrong operator": {"stamp": "fake_op", "line": "solid", "color": "black", "alpha": 1.0},
        "wrong demos": {"stamp": "20230522-234219_200000_fake_demo", "line": "dashed", "color": "red", "alpha": 1.0},
        "ODE 2": {"stamp":"20230522-234219_200000", "line": "dotted", "color": "r", "alpha": 1.0},
        "ODE 1,2": {"stamp": "20230524-165312_200000", "line": "dotted", "color": "orange", "alpha": 1.0},
        "ODE 2,3": {"stamp": "20230525-010119_200000", "line": "dashed", "color": "c", "alpha": 1.0},
        "ODE 1,2,3": {"stamp": "20230525-090937_200000", "line": "dashed", "color": "blue", "alpha": 1.0},
        }


def myplot(analysis_folder, runs_plot, datadir, mode):
    """Plot mean relative error against the new-term coefficient for each run.

    Args:
        analysis_folder: folder holding the ``err_*.pickle`` result files; the
            PDF is written to the same folder.
        runs_plot: labels (keys of ``runs``) to draw, in legend order.
        datadir: dataset name embedded in the result file names.
        mode: substring used to select result entries; ``main`` passes
            ``'forward'`` and ``'inverse'``.
    """
    plt.figure(figsize=(4, 3))

    for label in runs_plot:
        style = runs[label]
        stamp = style["stamp"]
        line = utils.linestyles[style["line"]]
        color = style["color"]
        alpha = style["alpha"]
        filename = f"{analysis_folder}/err_{stamp}_{datadir}_5_6.pickle"
        print(filename)
        with open(filename, 'rb') as file:
            results = pickle.load(file)

        # Result keys are (name, coefficient) pairs; keep those whose name
        # contains `mode`, ordered by coefficient so the curve is monotone in x.
        points = [[ncoeff, result["relative_error_mean"]]
                  for (name, ncoeff), result in results.items()
                  if mode in name]
        points.sort(key=lambda point: point[0])
        points = np.array(points)
        plt.plot(points[:, 0], points[:, 1], label=label, linestyle=line, color=color, alpha=alpha)

    plt.legend(ncols = 1)
    plt.xlabel('coefficient for new term ($b$)')
    plt.ylabel('average relative error')
    plt.xlim([FLAGS.xmin, FLAGS.xmax])
    # the y-range scales with the x-range, with a mode-dependent ratio
    yxrator = 0.55 if mode == "forward" else 1.5
    plt.ylim([0, FLAGS.xmax * yxrator])
    plt.title(f"{mode} problem", fontsize = 10)

    filename = f'{analysis_folder}/err_runs_on_{datadir}_{FLAGS.xmax}_{mode}.pdf'
    plt.savefig(filename, format = 'pdf', bbox_inches='tight')


def main(argv):
    """Print all flag values, then draw the forward and inverse error plots."""
    for value in FLAGS.__flags.values():
        print(value.name, ": ", value._value, flush=True)

    # Honour --runs_plot when given; previously the variable was only bound
    # in the default branch, so passing the flag raised NameError.
    runs_plot = FLAGS.runs_plot
    if runs_plot is None:
        runs_plot = ['wrong operator', 'wrong demos', 'ODE 2', 'ODE 1,2', 'ODE 2,3', 'ODE 1,2,3', 'all data']

    myplot(FLAGS.analysis_folder, runs_plot, FLAGS.datadir, 'forward')
    myplot(FLAGS.analysis_folder, runs_plot, FLAGS.datadir, 'inverse')


if __name__ == '__main__':

    FLAGS = flags.FLAGS

    flags.DEFINE_string('analysis_folder', '../analysis/analysis0521-ood', 'the folder where analysis results are stored')
    flags.DEFINE_list('runs_plot', None, 'the runs to plot')
    flags.DEFINE_string('datadir', 'data0520_nt_odelinear3', 'the datadir to plot')
    flags.DEFINE_float('xmin', -0.3, 'the xmin for the plot')
    flags.DEFINE_float('xmax', 0.3, 'the xmax for the plot')


    app.run(main)
def pattern_match(patterns, name):
    """Return True when at least one of `patterns` is a substring of `name`."""
    return any(pattern in name for pattern in patterns)


def myplot():
    """Draw condition/QoI figures for every matching equation in a result pickle.

    Reads all settings from FLAGS, loads
    ``{analysis_folder}/err_{stamp}_{dataset_name}_{demo_num}_{demo_num+1}.pickle``
    and, for each selected sample, saves two PDFs into ``analysis_folder``:
    the plain figure and a copy titled with the equation string (for debugging).
    """
    analysis_folder = FLAGS.analysis_folder
    stamp = FLAGS.stamp
    dataset_name = FLAGS.dataset_name
    k_dim = FLAGS.k_dim
    k_mode = FLAGS.k_mode
    plot_num = FLAGS.plot_num
    demo_num = FLAGS.demo_num
    demo_num_max = 5
    patterns = ['ode', 'series', 'pde', 'mfc']

    pickle_path = f"{analysis_folder}/err_{stamp}_{dataset_name}_{demo_num}_{demo_num + 1}.pickle"
    with open(pickle_path, 'rb') as file:
        results = pickle.load(file)

    for equation in results:
        if not pattern_match(patterns, equation):  # only keys matching the patterns
            continue

        entry = results[equation]
        prompt = entry['prompt']
        query = entry['query']
        query_mask = entry['query_mask']
        pred = entry['pred']
        cond_k_index, qoi_k_index = get_plot_k_index(k_mode, equation)
        raw_cond_k_index, raw_qoi_k_index = get_plot_k_index('naive', equation)

        print(equation)
        print(f'cond_k_index: {cond_k_index}, qoi_k_index: {qoi_k_index}')
        print(f'raw_cond_k_index: {raw_cond_k_index}, raw_qoi_k_index: {raw_qoi_k_index}')
        for part in ("demo", "quest"):
            for role in ("cond", "qoi"):
                for kind in ("k", "v"):
                    term = f"raw_{part}_{role}_{kind}"
                    print(term, entry[term].shape)

        # every 10th sample; per the flag help these should be different operators
        for plot_index in range(0, plot_num * 10, 10):
            sample = prompt[plot_index]

            plt.close('all')
            plt.figure(figsize=(4,5))

            # ---- top panel: conditions ----
            plt.subplot(2, 1, 1)
            for demo_i in range(demo_num):
                # raw (full) demo condition curve
                plt.plot(entry['raw_demo_cond_k'][plot_index, demo_i, :, raw_cond_k_index],
                         entry['raw_demo_cond_v'][plot_index, demo_i, :, 0], ':', alpha = 0.7)
                # prompt rows whose indicator column for this demo is around 1
                demo_col = demo_i - demo_num_max - 1
                is_cond = np.abs(sample[:, demo_col] - 1) < 0.01
                plt.plot(sample[is_cond, cond_k_index],
                         sample[is_cond, k_dim], 'o', markersize=3, color = 'grey', alpha = 0.7)
            # question condition: raw curve, then prompt rows with last column around 1
            plt.plot(entry['raw_quest_cond_k'][plot_index, 0, :, raw_cond_k_index],
                     entry['raw_quest_cond_v'][plot_index, 0, :, 0], 'k-')
            is_cond = np.abs(sample[:, -1] - 1) < 0.01
            plt.plot(sample[is_cond, cond_k_index],
                     sample[is_cond, k_dim], 's', markersize=3, color = 'blue', alpha = 0.7)
            plt.title(label_map[equation]['legend']+'\ncondition')
            plt.xlabel(label_map[equation]['xlabel'])

            # ---- bottom panel: quantities of interest ----
            plt.subplot(2, 1, 2)
            for demo_i in range(demo_num):
                plt.plot(entry['raw_demo_qoi_k'][plot_index, demo_i, :, raw_qoi_k_index],
                         entry['raw_demo_qoi_v'][plot_index, demo_i, :, 0], ':', alpha = 0.7)
                # prompt rows whose indicator column for this demo is around -1
                demo_col = demo_i - demo_num_max - 1
                is_qoi = np.abs(sample[:, demo_col] + 1) < 0.01
                plt.plot(sample[is_qoi, qoi_k_index],
                         sample[is_qoi, k_dim], 'o', markersize=3, color = 'grey', alpha = 0.7)
            # question QoI: ground truth curve and masked predictions
            plt.plot(entry['raw_quest_qoi_k'][plot_index, 0, :, raw_qoi_k_index],
                     entry['raw_quest_qoi_v'][plot_index, 0, :, 0], 'k-')
            keep = query_mask[plot_index]
            plt.plot(query[plot_index, keep, qoi_k_index],
                     pred[plot_index, keep, 0], 's', markersize=3, color = 'red', alpha = 0.7)
            plt.title('QoI')
            plt.xlabel(label_map[equation]['xlabel'])

            plt.tight_layout()
            plt.savefig(f"{analysis_folder}/{equation}_demo{demo_num}_index{plot_index}.pdf",
                        format = 'pdf', bbox_inches='tight')
            plt.title(entry['equation'][plot_index])  # for debugging
            plt.savefig(f"{analysis_folder}/{equation}_demo{demo_num}_index{plot_index}_equation.pdf",
                        format = 'pdf', bbox_inches='tight')


def main(argv):
    """Echo every flag value, then run the plotting routine."""
    for flag in FLAGS.__flags.values():
        print(flag.name, ": ", flag._value, flush=True)

    myplot()


if __name__ == '__main__':

    FLAGS = flags.FLAGS

    flags.DEFINE_integer('k_dim', 3, 'k dimension')
    flags.DEFINE_string('k_mode', 'itx', 'k mode')
    flags.DEFINE_integer('plot_num', 5, 'plot num for each equation, should correspond to different operators')
    flags.DEFINE_integer('demo_num', 5, 'demo num used, max 5')

    flags.DEFINE_string('analysis_folder', '../analysis/analysis0511a-v4-ind', 'the folder where analysis results are stored')
    flags.DEFINE_string('stamp', '20230515-094404_1000000', 'the stamp of the analysis result')
    flags.DEFINE_string('dataset_name', 'data0511a', 'the name of the dataset')

    app.run(main)
'xlabel': r"$x$"}, 6 | "mfc_gparam_hj_forward22": {'legend': "MFC $g$-parameter 2D -> 2D", 'linestyle': ':', 'marker': 'o', 'xlabel': r"$x$"}, 7 | "mfc_rhoparam_hj_forward11": {'legend': r"MFC $\rho_0$-parameter 1D -> 1D", 'linestyle': '-', 'marker': 's', 'xlabel': r"$x$"}, 8 | "mfc_rhoparam_hj_forward12": {'legend': r"MFC $\rho_0$-parameter 1D -> 2D", 'linestyle': ':', 'marker': 's', 'xlabel': r"$x$"}, 9 | "mfc_gparam_solve_forward": {'legend': "MFC 1", 'linestyle': '-', 'marker': 'o', 'xlabel': r"$x$"}, 10 | "mfc_rhoparam_solve_forward": {'legend': "MFC 2", 'linestyle': '--', 'marker': 'o', 'xlabel': r"$x$"}, 11 | "ode_auto_const_forward": {'legend': "Forward problem of ODE 1", 'linestyle': '-', 'marker': 'o', 'xlabel': r"$t$"}, 12 | "ode_auto_const_inverse": {'legend': "Inverse problem of ODE 1", 'linestyle': ':', 'marker': 'o', 'xlabel': r"$t$"}, 13 | "ode_auto_linear1_forward": {'legend': "Forward problem of ODE 2", 'linestyle': '-', 'marker': 's', 'xlabel': r"$t$"}, 14 | "ode_auto_linear1_inverse": {'legend': "Inverse problem of ODE 2", 'linestyle': ':', 'marker': 's', 'xlabel': r"$t$"}, 15 | "ode_auto_linear2_forward": {'legend': "Forward problem of ODE 3", 'linestyle': '-', 'marker': 'd', 'xlabel': r"$t$"}, 16 | "ode_auto_linear2_inverse": {'legend': "Inverse problem of ODE 3", 'linestyle': ':', 'marker': 'd', 'xlabel': r"$t$"}, 17 | "pde_poisson_spatial_forward": {'legend': "Forward Poisson equation", 'linestyle': '-', 'marker': 'o', 'xlabel': r"$x$"}, 18 | "pde_poisson_spatial_inverse": {'legend': "Inverse Poisson equation", 'linestyle': ':', 'marker': 'o', 'xlabel': r"$x$"}, 19 | "pde_porous_spatial_forward" : {'legend': "Forward linear reaction-diffusion", 'linestyle': '-', 'marker': 's', 'xlabel': r"$x$"}, 20 | "pde_porous_spatial_inverse": {'legend': "Inverse linear reaction-diffusion", 'linestyle': ':', 'marker': 's', 'xlabel': r"$x$"}, 21 | "pde_cubic_spatial_forward": {'legend': "Forward nonlinear reaction-diffusion", 'linestyle': '-', 'marker': 
'd', 'xlabel': r"$x$"}, 22 | "pde_cubic_spatial_inverse": {'legend': "Inverse nonlinear reaction-diffusion", 'linestyle': ':', 'marker': 'd', 'xlabel': r"$x$"}, 23 | "series_damped_oscillator": {'legend': "time series prediction", 'linestyle': '-', 'marker': '*', 'xlabel': r"$t$"}, 24 | "series_damped_oscillator_forward": {'legend': "Forward damped oscillator", 'linestyle': '-', 'marker': '*', 'xlabel': r"$t$"}, 25 | "series_damped_oscillator_inverse": {'legend': "Inverse damped oscillator", 'linestyle': ':', 'marker': '*', 'xlabel': r"$t$"}, 26 | } -------------------------------------------------------------------------------- /icon/analysis/test_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "pde_spatial_forward":{ 3 | "demo_num_begin": 1, 4 | "demo_num_end": 6, 5 | "select_cond_ind": "random", 6 | "select_qoi_ind": "random", 7 | "cond_len_in_use_begin": 41, 8 | "cond_len_in_use_end": 51, 9 | "qoi_len_in_use_begin": 41, 10 | "qoi_len_in_use_end": 51}, 11 | "pde_spatial_inverse":{ 12 | "demo_num_begin": 1, 13 | "demo_num_end": 6, 14 | "select_cond_ind": "random", 15 | "select_qoi_ind": "random", 16 | "cond_len_in_use_begin": 41, 17 | "cond_len_in_use_end": 51, 18 | "qoi_len_in_use_begin": 41, 19 | "qoi_len_in_use_end": 51}, 20 | "ode_forward":{ 21 | "demo_num_begin": 1, 22 | "demo_num_end": 6, 23 | "select_cond_ind": "even", 24 | "select_qoi_ind": "even", 25 | "qoi_len_in_use_begin": 41, 26 | "qoi_len_in_use_end": 51}, 27 | "ode_inverse":{ 28 | "demo_num_begin": 1, 29 | "demo_num_end": 6, 30 | "select_cond_ind": "even", 31 | "select_qoi_ind": "even", 32 | "qoi_len_in_use_begin": 40, 33 | "qoi_len_in_use_end": 50}, 34 | "time_series":{ 35 | "demo_num_begin": 1, 36 | "demo_num_end": 6, 37 | "select_cond_ind": "random", 38 | "select_qoi_ind": "random", 39 | "cond_len_in_use_begin": 41, 40 | "cond_len_in_use_end": 51, 41 | "qoi_len_in_use_begin": 41, 42 | "qoi_len_in_use_end": 51}, 43 | "mfc_gparam_forward":{ 
44 | "demo_num_begin": 1, 45 | "demo_num_end": 6, 46 | "select_cond_ind": "random", 47 | "select_qoi_ind": "random", 48 | "cond_len_in_use_begin": 41, 49 | "cond_len_in_use_end": 51, 50 | "qoi_len_in_use_begin": 41, 51 | "qoi_len_in_use_end": 51}, 52 | "mfc_rhoparam_forward":{ 53 | "demo_num_begin": 1, 54 | "demo_num_end": 6, 55 | "select_cond_ind": "random", 56 | "select_qoi_ind": "random", 57 | "cond_len_in_use_begin": 41, 58 | "cond_len_in_use_end": 51, 59 | "qoi_len_in_use_begin": 41, 60 | "qoi_len_in_use_end": 51}, 61 | "others":{ 62 | "demo_num_begin": 1, 63 | "demo_num_end": 6, 64 | "select_cond_ind": "random", 65 | "select_qoi_ind": "random", 66 | "cond_len_in_use_begin": 41, 67 | "cond_len_in_use_end": 51, 68 | "qoi_len_in_use_begin": 41, 69 | "qoi_len_in_use_end": 51} 70 | } -------------------------------------------------------------------------------- /icon/analysis/test_config_len.json: -------------------------------------------------------------------------------- 1 | { 2 | "pde_spatial_forward":{ 3 | "demo_num_begin":5, 4 | "demo_num_end": 6, 5 | "select_cond_ind": "random", 6 | "select_qoi_ind": "random", 7 | "cond_len_in_use_begin": 50, 8 | "cond_len_in_use_end": 51, 9 | "qoi_len_in_use_begin": 50, 10 | "qoi_len_in_use_end": 51}, 11 | "pde_spatial_inverse":{ 12 | "demo_num_begin":5, 13 | "demo_num_end": 6, 14 | "select_cond_ind": "random", 15 | "select_qoi_ind": "random", 16 | "cond_len_in_use_begin": 50, 17 | "cond_len_in_use_end": 51, 18 | "qoi_len_in_use_begin": 50, 19 | "qoi_len_in_use_end": 51}, 20 | "ode_forward":{ 21 | "demo_num_begin":5, 22 | "demo_num_end": 6, 23 | "select_cond_ind": "even", 24 | "select_qoi_ind": "even", 25 | "qoi_len_in_use_begin": 50, 26 | "qoi_len_in_use_end": 51}, 27 | "ode_inverse":{ 28 | "demo_num_begin":5, 29 | "demo_num_end": 6, 30 | "select_cond_ind": "even", 31 | "select_qoi_ind": "even", 32 | "qoi_len_in_use_begin": 49, 33 | "qoi_len_in_use_end": 50}, 34 | "time_series":{ 35 | "demo_num_begin":5, 36 | 
"demo_num_end": 6, 37 | "select_cond_ind": "random", 38 | "select_qoi_ind": "random", 39 | "cond_len_in_use_begin": 50, 40 | "cond_len_in_use_end": 51, 41 | "qoi_len_in_use_begin": 50, 42 | "qoi_len_in_use_end": 51}, 43 | "mfc_gparam_forward":{ 44 | "demo_num_begin":5, 45 | "demo_num_end": 6, 46 | "select_cond_ind": "random", 47 | "select_qoi_ind": "random", 48 | "cond_len_in_use_begin": 2400, 49 | "cond_len_in_use_end": 2401, 50 | "qoi_len_in_use_begin": 2600, 51 | "qoi_len_in_use_end": 2601}, 52 | "mfc_rhoparam_forward":{ 53 | "demo_num_begin":5, 54 | "demo_num_end": 6, 55 | "select_cond_ind": "random", 56 | "select_qoi_ind": "random", 57 | "cond_len_in_use_begin": 50, 58 | "cond_len_in_use_end": 51, 59 | "qoi_len_in_use_begin": 50, 60 | "qoi_len_in_use_end": 51}, 61 | "others":{ 62 | "demo_num_begin":5, 63 | "demo_num_end": 6, 64 | "select_cond_ind": "random", 65 | "select_qoi_ind": "random", 66 | "cond_len_in_use_begin": 50, 67 | "cond_len_in_use_end": 51, 68 | "qoi_len_in_use_begin": 50, 69 | "qoi_len_in_use_end": 51} 70 | } -------------------------------------------------------------------------------- /icon/analysis/test_config_ood.json: -------------------------------------------------------------------------------- 1 | { 2 | "pde_spatial_forward":{ 3 | "demo_num_begin":5, 4 | "demo_num_end": 6, 5 | "select_cond_ind": "random", 6 | "select_qoi_ind": "random", 7 | "cond_len_in_use_begin": 50, 8 | "cond_len_in_use_end": 51, 9 | "qoi_len_in_use_begin": 50, 10 | "qoi_len_in_use_end": 51}, 11 | "pde_spatial_inverse":{ 12 | "demo_num_begin":5, 13 | "demo_num_end": 6, 14 | "select_cond_ind": "random", 15 | "select_qoi_ind": "random", 16 | "cond_len_in_use_begin": 50, 17 | "cond_len_in_use_end": 51, 18 | "qoi_len_in_use_begin": 50, 19 | "qoi_len_in_use_end": 51}, 20 | "ode_forward":{ 21 | "demo_num_begin":5, 22 | "demo_num_end": 6, 23 | "select_cond_ind": "even", 24 | "select_qoi_ind": "even", 25 | "qoi_len_in_use_begin": 50, 26 | "qoi_len_in_use_end": 51}, 27 | 
# The semantics of jax.lax.scan() are given roughly by:
#
# def scan(f, init, xs, length=None):
#   if xs is None:
#     xs = [None] * length
#   carry = init
#   ys = []
#   for x in xs:
#     carry, y = f(carry, x)
#     ys.append(y)
#   return carry, np.stack(ys)

DynamicsFn = namedtuple("DynamicsFn", ["name", "dyn", "dyn_batch"])


def _build_scan_dynamics(step, name):
    """Wrap a scan step ``(u, c) -> (u_next, u)`` into a DynamicsFn.

    ``dyn(init, control)`` returns ``(final, traj)``; ``dyn_batch`` is the same
    function vmapped over the leading axis of both arguments.
    """
    @jax.jit
    def dyn(init, control):
        final, traj = jax.lax.scan(step, init, control)
        return (final, traj)

    dyn_batch = jax.jit(jax.vmap(dyn, [0, 0], 0))
    return DynamicsFn(name=name, dyn=dyn, dyn_batch=dyn_batch)


def _euler_scan(rhs, init, control, dt):
    """Forward-Euler rollout of du/dt = rhs(c, u), one step per control entry.

    Returns ``(final, traj)`` where ``traj[0] == init`` and ``final`` is the
    state after the step driven by ``control[-1]``.
    """
    step = lambda u, c: (u + dt * rhs(c, u), u)
    return jax.lax.scan(step, init, control)


def build_ode_auto(func, name, dt = 0.01):
    '''
    state-dependent right-hand side, forward Euler with step dt:
    du/dt = func(c(t), u(t))
    '''
    step = lambda u, c: (u + dt * func(c, u), u)
    return _build_scan_dynamics(step, "ode_auto_{}".format(name))


def build_ode_constant(func, name, dt = 0.01):
    '''
    right-hand side independent of the state, forward Euler with step dt:
    du/dt = func(control)
    '''
    step = lambda u, c: (u + dt * func(c), u)
    return _build_scan_dynamics(step, "ode_const_{}".format(name))


def build_ode_linear(func, name, dt = 0.01):
    '''
    right-hand side linear in the state, forward Euler with step dt:
    du/dt = func(control) * u
    '''
    step = lambda u, c: (u + dt * func(c) * u, u)
    return _build_scan_dynamics(step, "ode_linear_{}".format(name))


@jax.jit
def ode_auto_const_fn(init, control, dt, coeff_a, coeff_b):
    """Euler rollout of du/dt = coeff_a * c + coeff_b; returns (final, traj)."""
    rhs = lambda c, u: coeff_a * c + coeff_b
    return _euler_scan(rhs, init, control, dt)


@jax.jit
def ode_auto_linear1_fn(init, control, dt, coeff_a, coeff_b):
    """Euler rollout of du/dt = coeff_a * c * u + coeff_b; returns (final, traj)."""
    rhs = lambda c, u: (coeff_a * c * u + coeff_b)
    return _euler_scan(rhs, init, control, dt)


@jax.jit
def ode_auto_linear2_fn(init, control, dt, coeff_a1, coeff_a2, coeff_a3):
    """Euler rollout of du/dt = coeff_a1 * u + coeff_a2 * c + coeff_a3; returns (final, traj)."""
    rhs = lambda c, u: coeff_a1 * u + coeff_a2 * c + coeff_a3
    return _euler_scan(rhs, init, control, dt)


@jax.jit
def ode_auto_linear3_fn(init, control, dt, coeff_a1, coeff_a2, coeff_a3):
    """Euler rollout of du/dt = coeff_a1 * c * u + coeff_a2 * u + coeff_a3; returns (final, traj)."""
    rhs = lambda c, u: coeff_a1 * c * u + coeff_a2 * u + coeff_a3
    return _euler_scan(rhs, init, control, dt)


# Batched variants: init and control are mapped over axis 0; dt and the
# coefficients are shared by the whole batch.
ode_auto_const_batch_fn = jax.jit(jax.vmap(ode_auto_const_fn, [0,0, None, None, None], (0,0)))
ode_auto_linear1_batch_fn = jax.jit(jax.vmap(ode_auto_linear1_fn, [0,0, None, None, None],(0,0)))
ode_auto_linear2_batch_fn = jax.jit(jax.vmap(ode_auto_linear2_fn, [0,0, None, None, None, None],(0,0)))
ode_auto_linear3_batch_fn = jax.jit(jax.vmap(ode_auto_linear3_fn, [0,0, None, None, None, None],(0,0)))


if __name__ == "__main__":
    # `from jax.config import config` no longer imports on current JAX (the
    # jax.config submodule was removed); the attribute form works on old and new.
    jax.config.update('jax_enable_x64', True)

    dyn_fn = build_ode_constant(lambda c: c, name = "x1", dt = 0.02)
    final, traj = dyn_fn.dyn(1.0, jnp.ones((10,)))
    print(dyn_fn.name, final, traj)

    dyn_fn = build_ode_auto(lambda c, u: c, name = "x1", dt = 0.02)
    final, traj = dyn_fn.dyn(1.0, jnp.ones((10,)))
    print(dyn_fn.name, final, traj)

    dyn_fn = build_ode_linear(lambda c: c, name = "x1", dt = 0.02)
    final, traj = dyn_fn.dyn(1.0, jnp.ones((10,)))
    print(dyn_fn.name, final, traj)

    dyn_fn = build_ode_auto(lambda c, u: c * u, name = "x1", dt = 0.02)
    final, traj = dyn_fn.dyn(1.0, jnp.ones((10,)))
    print(dyn_fn.name, final, traj)
jax.numpy as jnp 3 | from collections import namedtuple 4 | 5 | 6 | def generate_sin(xs, amp, period, phase): 7 | return amp * jnp.sin(xs * 2 * jnp.pi / period + phase) 8 | generate_sin_batch = jax.jit(jax.vmap(generate_sin, [None, 0, 0, 0], 0)) 9 | 10 | def generate_sin_base(xs, amp, period, phase, base): 11 | return base + generate_sin(xs, amp, period, phase) 12 | generate_sin_base_batch = jax.jit(jax.vmap(generate_sin, [None, 0, 0, 0, None], 0)) # base is shared in batch 13 | 14 | def generate_damped_oscillator(xs, amp, period, phase, decay): 15 | return generate_sin(xs, amp, period, phase) * jnp.exp(-decay * xs) 16 | 17 | generate_damped_oscillator_batch = jax.jit(jax.vmap(generate_damped_oscillator, [None, 0, 0, 0, None], 0)) # decay is shared in batch -------------------------------------------------------------------------------- /icon/data_generation/datagen.sh: -------------------------------------------------------------------------------- 1 | gpu=0 2 | dir=data0511a 3 | testeqns=100 4 | testquests=5 5 | traineqns=1000 6 | 7 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --length 100 --dt 0.01 --eqn_types series_damped_oscillator --seed 101 && 8 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --eqn_types ode_auto_const --seed 102 && 9 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --eqn_types ode_auto_linear1 --seed 103 && 10 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --eqn_types ode_auto_linear2 --seed 104 && 11 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --eqn_types pde_poisson_spatial --seed 105 && 12 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --length 100 --dx 
0.01 --eqn_types pde_porous_spatial --seed 106 && 13 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --eqn_types pde_cubic_spatial --seed 107 && 14 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types mfc_gparam_hj --seed 108 && 15 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name test --eqns $testeqns --quests $testquests --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types mfc_rhoparam_hj --seed 109 && 16 | 17 | 18 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --length 100 --dt 0.01 --eqn_types series_damped_oscillator --seed 1 && 19 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --eqn_types ode_auto_const --seed 2 && 20 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --eqn_types ode_auto_linear1 --seed 3 && 21 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --eqn_types ode_auto_linear2 --seed 4 && 22 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --length 100 --dx 0.01 --eqn_types pde_poisson_spatial --seed 5 && 23 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --length 100 --dx 0.01 --eqn_types pde_porous_spatial --seed 6 && 24 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --length 100 --dx 0.01 --eqn_types pde_cubic_spatial --seed 7 && 25 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types mfc_gparam_hj --seed 8 && 26 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen.py --dir $dir --name train --eqns $traineqns --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 
--eqn_types mfc_rhoparam_hj --seed 9 && 27 | 28 | echo "Done" 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /icon/data_generation/datagen_ood.sh: -------------------------------------------------------------------------------- 1 | gpu=0 2 | prefix=data0520 3 | 4 | # light dataset 5 | quests=1 6 | ood_coeff1_grids=10 7 | ood_coeff2_grids=11 8 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --ood_coeff1_grids $ood_coeff1_grids --ood_coeff2_grids $ood_coeff2_grids --quests $quests --eqn_types ode_auto_const --dir ./${prefix}_ood_odeconst_light >out_data_ood_ode_auto_const_light.log 2>&1 && 9 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --ood_coeff1_grids $ood_coeff1_grids --ood_coeff2_grids $ood_coeff2_grids --quests $quests --eqn_types ode_auto_linear1 --dir ./${prefix}_ood_odelinear1_light >out_data_ood_ode_auto_linear1_light.log 2>&1 && 10 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --ood_coeff1_grids $ood_coeff1_grids --ood_coeff2_grids $ood_coeff2_grids --quests $quests --eqn_types ode_auto_linear2 --dir ./${prefix}_ood_odelinear2_light >out_data_ood_ode_auto_linear2_light.log 2>&1 && 11 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --ood_coeff1_grids $ood_coeff1_grids --ood_coeff2_grids $ood_coeff2_grids --quests $quests --length 100 --dx 0.01 --eqn_types pde_porous_spatial --dir ./${prefix}_ood_pdeporous_randbdry_light >out_data_ood_pde_porous_spatial_light.log 2>&1 && 12 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --ood_coeff1_grids $ood_coeff1_grids --ood_coeff2_grids $ood_coeff2_grids --quests $quests --length 100 --dt 0.01 --eqn_types series_damped_oscillator --dir ./${prefix}_ood_seriesdamped_light >out_data_ood_seriesdamped_light.log 2>&1 && 13 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --ood_coeff1_grids $ood_coeff1_grids --ood_coeff2_grids $ood_coeff2_grids --quests $quests --eqn_types 
ode_auto_linear3 --dir ./${prefix}_nt_odelinear3_light >out_data_nt_ode_auto_linear3_light.log 2>&1 && 14 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --ood_coeff1_grids $ood_coeff1_grids --ood_coeff2_grids $ood_coeff2_grids --quests $quests --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types ot_rho1param --dir ./${prefix}_nt_ot_rho1param_light >out_data_nt_ot_rho1param_light.log 2>&1 && 15 | 16 | # formal dataset 17 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --eqn_types ode_auto_const --dir ./${prefix}_ood_odeconst >out_data_ood_ode_auto_const.log 2>&1 && 18 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --eqn_types ode_auto_linear1 --dir ./${prefix}_ood_odelinear1 >out_data_ood_ode_auto_linear1.log 2>&1 && 19 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --eqn_types ode_auto_linear2 --dir ./${prefix}_ood_odelinear2 >out_data_ood_ode_auto_linear2.log 2>&1 && 20 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --length 100 --dx 0.01 --eqn_types pde_porous_spatial --dir ./${prefix}_ood_pdeporous_randbdry >out_data_ood_pde_porous_spatial.log 2>&1 && 21 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --length 100 --dt 0.01 --eqn_types series_damped_oscillator --dir ./${prefix}_ood_seriesdamped >out_data_ood_seriesdamped.log 2>&1 && 22 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --eqn_types ode_auto_linear3 --ood_coeff2_grids 201 --dir ./${prefix}_nt_odelinear3 >out_data_nt_ode_auto_linear3.log 2>&1 && 23 | CUDA_VISIBLE_DEVICES=$gpu python3 datagen_ood.py --seed 202 --length 100 --dx 0.01 --dt 0.02 --nu_nx_ratio 1 --eqn_types ot_rho1param --dir ./${prefix}_nt_ot_rho1param >out_data_nt_ot_rho1param.log 2>&1 && 24 | 25 | 26 | echo "Done" 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /icon/run_group.sh: -------------------------------------------------------------------------------- 1 | gpu=0 2 | 3 
| k_dim=3 4 | k_mode=itx 5 | problem=group-$k_dim-$k_mode 6 | 7 | CUDA_VISIBLE_DEVICES=$gpu python3 run.py --problem $problem --num_heads 8 --num_layers 6 --hidden_dim 256 --train_batch_size 16 --epochs 20 --train_warmup_percent 40 --train_data_dirs './data_generation/data0511a' --k_dim $k_dim --k_mode $k_mode --train_data_globs 'train*ode*linear1*' --test_data_globs 'test*ode*linear1*' --tfboard >out_0511a_${problem}_odes_l1.log 2>&1 && 8 | CUDA_VISIBLE_DEVICES=$gpu python3 run.py --problem $problem --num_heads 8 --num_layers 6 --hidden_dim 256 --train_batch_size 16 --epochs 20 --train_warmup_percent 40 --train_data_dirs './data_generation/data0511a' --k_dim $k_dim --k_mode $k_mode --train_data_globs 'train*ode*linear1*','train*ode*const*' --test_data_globs 'test*ode*linear1*','test*ode*const*' --tfboard >out_0511a_${problem}_odes_l1c.log 2>&1 && 9 | CUDA_VISIBLE_DEVICES=$gpu python3 run.py --problem $problem --num_heads 8 --num_layers 6 --hidden_dim 256 --train_batch_size 16 --epochs 20 --train_warmup_percent 40 --train_data_dirs './data_generation/data0511a' --k_dim $k_dim --k_mode $k_mode --train_data_globs 'train*ode*linear1*','train*ode*linear2*' --test_data_globs 'test*ode*linear1*','test*ode*linear2*' --tfboard >out_0511a_${problem}_odes_l1l2.log 2>&1 && 10 | CUDA_VISIBLE_DEVICES=$gpu python3 run.py --problem $problem --num_heads 8 --num_layers 6 --hidden_dim 256 --train_batch_size 16 --epochs 20 --train_warmup_percent 40 --train_data_dirs './data_generation/data0511a' --k_dim $k_dim --k_mode $k_mode --train_data_globs 'train*ode*' --test_data_globs 'test*ode*' --tfboard >out_0511a_${problem}_odes_l1l2c.log 2>&1 && 11 | 12 | echo "Done" -------------------------------------------------------------------------------- /icon/train_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "pde_spatial_forward":{ 3 | "demo_num_begin": 1, 4 | "demo_num_end": 6, 5 | "select_cond_ind": "random", 6 | "select_qoi_ind": 
"random", 7 | "cond_len_in_use_begin": 41, 8 | "cond_len_in_use_end": 51, 9 | "qoi_len_in_use_begin": 41, 10 | "qoi_len_in_use_end": 51}, 11 | "pde_spatial_inverse":{ 12 | "demo_num_begin": 1, 13 | "demo_num_end": 6, 14 | "select_cond_ind": "random", 15 | "select_qoi_ind": "random", 16 | "cond_len_in_use_begin": 41, 17 | "cond_len_in_use_end": 51, 18 | "qoi_len_in_use_begin": 41, 19 | "qoi_len_in_use_end": 51}, 20 | "ode_forward":{ 21 | "demo_num_begin": 1, 22 | "demo_num_end": 6, 23 | "select_cond_ind": "even", 24 | "select_qoi_ind": "even", 25 | "qoi_len_in_use_begin": 41, 26 | "qoi_len_in_use_end": 51}, 27 | "ode_inverse":{ 28 | "demo_num_begin": 1, 29 | "demo_num_end": 6, 30 | "select_cond_ind": "even", 31 | "select_qoi_ind": "even", 32 | "qoi_len_in_use_begin": 40, 33 | "qoi_len_in_use_end": 50}, 34 | "time_series":{ 35 | "demo_num_begin": 1, 36 | "demo_num_end": 6, 37 | "select_cond_ind": "random", 38 | "select_qoi_ind": "random", 39 | "cond_len_in_use_begin": 41, 40 | "cond_len_in_use_end": 51, 41 | "qoi_len_in_use_begin": 41, 42 | "qoi_len_in_use_end": 51}, 43 | "mfc_gparam_forward":{ 44 | "demo_num_begin": 1, 45 | "demo_num_end": 6, 46 | "select_cond_ind": "random", 47 | "select_qoi_ind": "random", 48 | "cond_len_in_use_begin": 41, 49 | "cond_len_in_use_end": 51, 50 | "qoi_len_in_use_begin": 41, 51 | "qoi_len_in_use_end": 51}, 52 | "mfc_rhoparam_forward":{ 53 | "demo_num_begin": 1, 54 | "demo_num_end": 6, 55 | "select_cond_ind": "random", 56 | "select_qoi_ind": "random", 57 | "cond_len_in_use_begin": 41, 58 | "cond_len_in_use_end": 51, 59 | "qoi_len_in_use_begin": 41, 60 | "qoi_len_in_use_end": 51}, 61 | "others":{ 62 | "demo_num_begin": 1, 63 | "demo_num_end": 6, 64 | "select_cond_ind": "random", 65 | "select_qoi_ind": "random", 66 | "cond_len_in_use_begin": 41, 67 | "cond_len_in_use_end": 51, 68 | "qoi_len_in_use_begin": 41, 69 | "qoi_len_in_use_end": 51} 70 | } -------------------------------------------------------------------------------- 
# /icon/transformer.py
'''Self-attention and cross-attention transformer stacks built on Haiku.'''
import dataclasses
from typing import Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import utils


def layer_norm(x: jnp.ndarray) -> jnp.ndarray:
    '''add a LayerNorm layer and apply to x'''
    ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
    return ln(x)


def _build_initializer(kind: str, num_layers: int):
    '''
    Map an initializer name to the Haiku weight initializer used by both transformers.

    kind: 'glorot_uniform' or 'layer_scale'
    num_layers: depth of the stack; only used to scale the 'layer_scale' variant
    Raises NotImplementedError for any other name.
    '''
    if kind == 'glorot_uniform':
        return hk.initializers.VarianceScaling(1.0, "fan_avg", "uniform")  # glorot_uniform
    if kind == 'layer_scale':
        return hk.initializers.VarianceScaling(2 / num_layers)
    raise NotImplementedError("unknown initializer: {}".format(kind))


@dataclasses.dataclass
class SelfAttnTransformer(hk.Module):
    '''
    self-attention transformer

    Input embeddings are layer-normalized once; every layer then applies
    multi-head self-attention and a gelu MLP, each followed by a residual
    add and a LayerNorm (post-norm).
    '''

    num_heads: int
    num_layers: int
    model_size: int  # the size of embedding or h
    QK_size: int
    V_size: int
    widening_factor: int = 4
    initializer: str = 'glorot_uniform'
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jnp.ndarray,  # [...,T, model_size]
        mask=None,  # [...,1, T, T]
    ) -> jnp.ndarray:  # [..., T, model_size]

        initializer = _build_initializer(self.initializer, self.num_layers)

        e_norm = layer_norm(embeddings)
        h = e_norm
        for _ in range(self.num_layers):
            # First the attention block.
            attn_block = hk.MultiHeadAttention(
                num_heads=self.num_heads,
                key_size=self.QK_size,
                value_size=self.V_size,
                model_size=self.model_size,
                w_init=initializer)
            h_attn = attn_block(h, h, h, mask=mask)
            h = h + h_attn
            h = layer_norm(h)

            # Then the dense block.
            dense_block = hk.Sequential([
                hk.Linear(self.widening_factor * self.model_size, w_init=initializer),
                jax.nn.gelu,
                hk.Linear(self.model_size, w_init=initializer),
            ])
            h_dense = dense_block(h)
            h = h + h_dense
            h = layer_norm(h)

        return h


@dataclasses.dataclass
class CrossAttnTransformer(hk.Module):
    '''
    cross-attention transformer
    Note that in MultiHeadAttention, the query, key, value will be reshaped via linear projection
    query -> Q [t, key_size]
    key -> K [T, key_size]
    value -> V [T, value_size]

    The same (layer-normalized) key and value are attended to in every layer;
    only the query stream is updated. Because the attention output is added
    back onto the query, the query feature dim has to match model_size.
    '''

    num_heads: int
    num_layers: int
    model_size: int  # dim for query
    QK_size: int  # dim for Q and K
    V_size: int  # dim for V
    widening_factor: int = 4
    initializer: str = 'glorot_uniform'
    name: Optional[str] = None

    def __call__(
        self,
        query: jnp.ndarray,  # [t, model_size]
        key: jnp.ndarray,  # [T, key_size]
        value: jnp.ndarray,  # [T, value_size]
        mask=None,  # [1, t, T]
        final_norm=True
    ) -> jnp.ndarray:  # [t, model_size]
        '''
        final_norm: when False, the LayerNorm after the last dense block is
            skipped and the raw residual sum is returned.
        '''

        initializer = _build_initializer(self.initializer, self.num_layers)

        query_norm = layer_norm(query)
        key_norm = layer_norm(key)
        value_norm = layer_norm(value)

        last_layer = self.num_layers - 1
        for i in range(self.num_layers):
            # First the attention block.
            attn_block = hk.MultiHeadAttention(
                num_heads=self.num_heads,
                key_size=self.QK_size,
                w_init=initializer,
                value_size=self.V_size,
                model_size=self.model_size,
                name="attn_{}".format(i))

            this_query = attn_block(query=query_norm, key=key_norm, value=value_norm, mask=mask)
            query_norm = layer_norm(this_query + query_norm)

            # Then the dense block.
            dense_block = hk.Sequential([
                hk.Linear(self.widening_factor * self.model_size, w_init=initializer),
                jax.nn.gelu,
                hk.Linear(self.model_size, w_init=initializer),
            ], name="dense_{}".format(i))

            this_query = dense_block(query_norm)

            # i runs over 0..num_layers-1, so the last layer is num_layers - 1.
            # Comparing against num_layers would never match and final_norm=False
            # would be ignored. NOTE: with final_norm=False one fewer LayerNorm
            # is created, so the parameter tree is smaller than with final_norm=True.
            if i == last_layer and not final_norm:
                query_norm = this_query + query_norm
            else:
                query_norm = layer_norm(this_query + query_norm)

        return query_norm


if __name__ == "__main__":

    query_size = 3
    key_size = 4
    value_size = 5
    QK_size = 10
    V_size = 12

    t = 20
    T = 40
    query = jax.random.normal(jax.random.PRNGKey(1), [t, query_size])
    key = jax.random.normal(jax.random.PRNGKey(2), [T, key_size])
    value = jax.random.normal(jax.random.PRNGKey(3), [T, value_size])

    def f(q, k, v):
        net = CrossAttnTransformer(num_heads=8,
                                   num_layers=4,
                                   model_size=query_size,
                                   QK_size=QK_size,
                                   V_size=V_size,
                                   widening_factor=4)
        return net(q, k, v)

    f = hk.transform(f)
    rng_key = jax.random.PRNGKey(1234)
    params = f.init(rng_key, query, key, value)
    out_query = f.apply(params, rng_key, query, key, value)
    utils.print_pytree(params)
    print(out_query.shape)  # (20,3)