├── .gitignore
├── LICENSE
├── README.md
├── exp_bc_h.py
├── exp_bc_h_vortex.py
├── exp_sea_h.py
├── exp_sea_h_vortex.py
├── exp_smoke_h.py
├── model_dict.py
├── models
│   ├── DeepLag_2D.py
│   ├── DeepLag_3D.py
│   ├── FNO_2D.py
│   ├── FNO_3D.py
│   ├── Factformer_2D.py
│   ├── Factformer_3D.py
│   ├── GNOT_2D.py
│   ├── GNOT_3D.py
│   ├── GkTrm_2D.py
│   ├── GkTrm_3D.py
│   ├── LSM_2D.py
│   ├── LSM_3D.py
│   ├── UNet_2D.py
│   ├── UNet_3D.py
│   ├── Vortex_2D.py
│   └── libs
│       ├── __init__.py
│       ├── attention.py
│       ├── basics.py
│       ├── fact
│       │   ├── attention.py
│       │   ├── basics.py
│       │   ├── factorization_module.py
│       │   └── positional_encoding_module.py
│       ├── factorization_module.py
│       ├── gktrm
│       │   ├── __init__.py
│       │   ├── ft.py
│       │   ├── layers.py
│       │   ├── ns_lite.py
│       │   ├── utils.py
│       │   └── utils_ft.py
│       ├── positional_encoding_module.py
│       └── vortex
│           ├── io_utils.py
│           ├── learning_utils.py
│           ├── pretrained.tar
│           └── simulation_utils.py
├── pic
│   ├── bounded-navier-stokes.gif
│   ├── eulag_block_v4.3.png
│   ├── framework_v4.3.png
│   ├── ocean-current.gif
│   └── traj.png
├── requirements.txt
├── scripts
│   ├── bc_deeplag.sh
│   ├── bc_factformer.sh
│   ├── bc_fno.sh
│   ├── bc_gktrm.sh
│   ├── bc_gnot.sh
│   ├── bc_lsm.sh
│   ├── bc_unet.sh
│   ├── bc_vortex.sh
│   ├── sea_deeplag.sh
│   ├── sea_factformer.sh
│   ├── sea_fno.sh
│   ├── sea_gktrm.sh
│   ├── sea_gnot.sh
│   ├── sea_lsm.sh
│   ├── sea_unet.sh
│   ├── sea_vortex.sh
│   ├── smoke_deeplag3d.sh
│   ├── smoke_factformer3d.sh
│   ├── smoke_fno3d.sh
│   ├── smoke_gktrm3d.sh
│   ├── smoke_gnot3d.sh
│   ├── smoke_lsm3d.sh
│   ├── smoke_unet3d.sh
│   ├── test_all.sh
│   └── test_all_longrollout.sh
├── test_bc_h.py
├── test_bc_h_vortex.py
├── test_sea_h.py
├── test_sea_h_vortex.py
├── test_smoke_h.py
└── utils
    ├── adam.py
    ├── data_factory.py
    ├── params.py
    ├── split_merge_npy_file.py
    └── utilities3.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 | 
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | 
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 | 
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 | 
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 THUML @ Tsinghua University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /exp_bc_h_vortex.py: -------------------------------------------------------------------------------- 1 | import os 2 | from timeit import default_timer 3 | from datetime import datetime, timedelta 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | from utils.data_factory import get_bc_dataset, BoundedNSDataset 8 | from utils.utilities3 import * 9 | from utils.params import get_args 10 | from utils.adam import Adam 11 | from model_dict import get_model 12 | 13 | from tqdm import tqdm 14 | 15 | time_str = (datetime.now()).strftime("%Y%m%d_%H%M%S") 16 | 17 | torch.manual_seed(0) 18 | np.random.seed(0) 19 | torch.cuda.manual_seed(0) 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | ################################################################ 24 | # configs 25 | ################################################################ 26 | args = get_args(time=time_str) 27 | 28 | TRAIN_PATH = os.path.join(args.data_path, 're0.25_4c_gray_10000.npy') 29 | TEST_PATH = os.path.join(args.data_path, 're0.25_4c_gray_10000.npy') 30 | BOUNDARY_PATH = os.path.join(args.data_path, 'boundary_4c_rot.npy') 31 | 32 | padding = [int(p) for p in args.padding.split(',')] 33 | ntrain = args.ntrain 34 | ntest = args.ntest 35 | N = args.ntotal 36 | args.in_channels = args.in_dim * args.in_var 37 | args.out_channels = args.out_dim * args.out_var 38 | r1 = args.h_down 39 | r2 = args.w_down 40 | s1 = int(((args.h - 1) / r1) + 1) 41 | s2 = int(((args.w - 1) / r2) + 1) 42 | T_in = args.T_in 43 | T_out = args.T_out 44 | patch_size = tuple(int(x) for x in args.patch_size.split(',')) 45 | 46 | batch_size = args.batch_size 47 | learning_rate = args.learning_rate 48 | epochs = args.epochs 49 | step_size = args.step_size 50 | gamma = args.gamma 51 | delta_t = args.delta_t 52 | 53 | model_save_path = args.model_save_path 54 | model_save_name = args.model_save_name 55 | 56 | 57 | ################################################################ 58 | # models 59 | ################################################################ 60 | model = get_model(args) 61 | 62 | 63 | ################################################################ 64 | # load data and data normalization 65 | ################################################################ 66 | train_dataset = BoundedNSDataset(args, dataset_file=TRAIN_PATH, split='train',delta_t=delta_t, return_idx=True) 67 | test_dataset = BoundedNSDataset(args, dataset_file=TEST_PATH, split='test',delta_t=delta_t, return_idx=True) 68 | train_loader = train_dataset.loader() 69 | test_loader = test_dataset.loader() 70 | 71 | boundary, domain = process_boundary_condition(BOUNDARY_PATH, ds_rate=(r1,r2)) 72 | 73 | 74 | ################################################################ 75 | # training and evaluation 76 | ################################################################ 77 | optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 78 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 79 | 80 | myloss = LpLoss(size_average=False) 81 | 82 | step = 1 83 | min_test_l2_full = 114514 84 | for ep in range(epochs): 85 | model.train() 86 | t1 = 
default_timer() 87 | train_l2_step = 0 88 | train_l2_full = 0 89 | for batch_idx, (index, xx, yy) in enumerate(tqdm(train_loader)): 90 | index = index.to(device) * 0.01 91 | loss = 0 92 | xx = xx.to(device) 93 | yy = yy.to(device) 94 | xx /= 256.0 95 | yy /= 256.0 96 | 97 | for t in range(0, T_out, step): 98 | y = yy[..., t*args.out_var : (t + step)*args.out_var] 99 | im, vel_loss = model(xx, index) 100 | 101 | # print(xx.shape, y.shape) 102 | loss += nn.MSELoss().cuda()(im, y) 103 | loss += vel_loss 104 | if t == 0: 105 | pred = im 106 | else: 107 | pred = torch.cat((pred, im), -1) 108 | 109 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 110 | 111 | train_l2_step += loss.item() 112 | l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)) 113 | train_l2_full += l2_full.item() 114 | 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | t2 = default_timer() 119 | 120 | test_l2_step = 0 121 | test_l2_full = 0 122 | with torch.no_grad(): 123 | for batch_idx, (index, xx, yy) in enumerate(test_loader): 124 | index = index.to(device) * 0.01 125 | loss = 0 126 | xx = xx.to(device) 127 | yy = yy.to(device) 128 | xx /= 256.0 129 | 130 | for t in range(0, T_out, step): 131 | y = yy[..., t*args.out_var : (t + step)*args.out_var] 132 | im, _ = model(xx, index) 133 | im *= 256.0 134 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 135 | 136 | if t == 0: 137 | pred = im 138 | else: 139 | pred = torch.cat((pred, im), -1) 140 | 141 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 142 | 143 | test_l2_step += loss.item() 144 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 145 | 146 | # t2 = default_timer() 147 | scheduler.step() 148 | if test_l2_full / ntest < min_test_l2_full: 149 | min_test_l2_full = test_l2_full / ntest 150 | print(ep, t2 - t1, train_l2_step / ntrain / (T_out / step), train_l2_full / ntrain, 151 | test_l2_step / ntest / (T_out / step), 152 | test_l2_full / ntest, 'new_best!') 153 | print('save best model') 154 | torch.save(model.state_dict(), os.path.join(args.run_save_path, model_save_name[:-3]+f'_best.pt')) 155 | pd = pred[-1, :, :, -1].detach().cpu().numpy() 156 | gt = yy[-1, :, :, -1].detach().cpu().numpy() 157 | visual(pd, os.path.join(args.run_save_path, f'best_pred.png')) 158 | visual(gt, os.path.join(args.run_save_path, f'best_gt.png')) 159 | visual(np.abs(gt-pd), os.path.join(args.run_save_path, f'best_err.png')) 160 | else: 161 | print(ep, t2 - t1, train_l2_step / ntrain / (T_out / step), train_l2_full / ntrain, 162 | test_l2_step / ntest / (T_out / step), 163 | test_l2_full / ntest) 164 | if ep % 10 == 0: 165 | # if not os.path.exists(model_save_path): 166 | # os.makedirs(model_save_path) 167 | print('save latest model') 168 | torch.save(model.state_dict(), os.path.join(args.run_save_path, model_save_name[:-3]+f'_latest.pt')) 169 | if ep % 100 == 0: 170 | pd = pred[-1, :, :, -1].detach().cpu().numpy() 171 | gt = yy[-1, :, :, -1].detach().cpu().numpy() 172 | visual(pd, os.path.join(args.run_save_path, f'ep_{ep}_pred.png')) 173 | visual(gt, os.path.join(args.run_save_path, f'ep_{ep}_gt.png')) 174 | visual(np.abs(gt-pd), os.path.join(args.run_save_path, f'ep_{ep}_err.png')) 175 | -------------------------------------------------------------------------------- /exp_sea_h_vortex.py: -------------------------------------------------------------------------------- 1 | import os 2 | from timeit import default_timer 3 | from datetime import datetime, 
timedelta 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | from utils.data_factory import SeaDataset, SeaDatasetMemory 8 | from utils.utilities3 import * 9 | from utils.params import get_args 10 | from utils.adam import Adam 11 | from model_dict import get_model 12 | 13 | from tqdm import tqdm 14 | 15 | time_str = (datetime.now()).strftime("%Y%m%d_%H%M%S") 16 | 17 | torch.manual_seed(0) 18 | np.random.seed(0) 19 | torch.cuda.manual_seed(0) 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | ################################################################ 24 | # configs 25 | ################################################################ 26 | args = get_args(time=time_str) 27 | 28 | padding = [int(p) for p in args.padding.split(',')] 29 | ntrain = args.ntrain 30 | ntest = args.ntest 31 | N = args.ntotal 32 | args.in_channels = args.in_dim * args.in_var 33 | args.out_channels = args.out_dim * args.out_var 34 | r1 = args.h_down 35 | r2 = args.w_down 36 | s1 = int(((args.h - 1) / r1) + 1) 37 | s2 = int(((args.w - 1) / r2) + 1) 38 | T_in = args.T_in 39 | T_out = args.T_out 40 | patch_size = tuple(int(x) for x in args.patch_size.split(',')) 41 | 42 | batch_size = args.batch_size 43 | learning_rate = args.learning_rate 44 | epochs = args.epochs 45 | step_size = args.step_size 46 | gamma = args.gamma 47 | 48 | model_save_path = args.model_save_path 49 | model_save_name = args.model_save_name 50 | 51 | 52 | ################################################################ 53 | # models 54 | ################################################################ 55 | model = get_model(args) 56 | 57 | 58 | ################################################################ 59 | # load data and data normalization 60 | ################################################################ 61 | train_dataset = SeaDatasetMemory(args, region=args.region, split='train', return_idx=True) 62 | test_dataset = SeaDatasetMemory(args, region=args.region, split='test', return_idx=True) 63 | train_loader = train_dataset.loader() 64 | test_loader = test_dataset.loader() 65 | 66 | land, sea = get_land_sea_mask(args.data_path, args.fill_value) 67 | 68 | 69 | ################################################################ 70 | # training and evaluation 71 | ################################################################ 72 | optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) 73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) 74 | myloss = LpLoss(size_average=False, channel_wise=False) 75 | 76 | step = 1 77 | min_test_l2_full = 114514 78 | for ep in range(epochs): 79 | model.train() 80 | t1 = default_timer() 81 | train_l2_step = 0 82 | train_l2_full = 0 83 | for batch_idx, (index, xx, yy) in enumerate(tqdm(train_loader)): 84 | index = index.to(device) * 0.01 85 | loss = 0 86 | xx = xx.to(device) 87 | yy = yy.to(device) 88 | for t in range(0, T_out, step): 89 | y = yy[..., t*args.out_var : (t + step)*args.out_var] 90 | im, vel_loss = model(xx, index) 91 | 92 | # print(xx.shape, y.shape) 93 | loss += nn.MSELoss().cuda()(im, y) 94 | loss += vel_loss 95 | if t == 0: 96 | pred = im 97 | else: 98 | pred = torch.cat((pred, im), -1) 99 | 100 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 101 | 102 | train_l2_step += loss.item() 103 | l2_full = myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)) 104 | train_l2_full += l2_full.item() 105 | 106 | optimizer.zero_grad() 107 | loss.backward() 108 | optimizer.step() 109 | t2 = 
default_timer() 110 | 111 | test_l2_step = 0 112 | test_l2_full = 0 113 | with torch.no_grad(): 114 | for batch_idx, (index, xx, yy) in enumerate(test_loader): 115 | index = index.to(device) * 0.01 116 | loss = 0 117 | xx = xx.to(device) 118 | yy = yy.to(device) 119 | 120 | for t in range(0, T_out, step): 121 | y = yy[..., t*args.out_var : (t + step)*args.out_var] 122 | im, _ = model(xx, index) 123 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 124 | 125 | if t == 0: 126 | pred = im 127 | else: 128 | pred = torch.cat((pred, im), -1) 129 | 130 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 131 | 132 | test_l2_step += loss.item() 133 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 134 | 135 | scheduler.step() 136 | if test_l2_full / ntest < min_test_l2_full: 137 | min_test_l2_full = test_l2_full / ntest 138 | print(ep, t2 - t1, train_l2_step / ntrain / (T_out / step), train_l2_full / ntrain, 139 | test_l2_step / ntest / (T_out / step), 140 | test_l2_full / ntest, 'new_best!') 141 | print('save best model') 142 | torch.save(model.state_dict(), os.path.join(args.run_save_path, model_save_name[:-3]+f'_best.pt')) 143 | pd = pred[-1, :, :, -5:].detach().cpu().numpy() 144 | gt = yy[-1, :, :, -5:].detach().cpu().numpy() 145 | vars = ['thetao', 'so', 'uo', 'vo', 'zos'] 146 | for i in range(5): 147 | visual(pd[...,i], os.path.join(args.run_save_path, f'best_{vars[i]}_pred.png')) 148 | visual(gt[...,i], os.path.join(args.run_save_path, f'best_{vars[i]}_gt.png')) 149 | visual(np.abs(gt-pd)[...,i], os.path.join(args.run_save_path, f'best_{vars[i]}_err.png')) 150 | else: 151 | print(ep, t2 - t1, train_l2_step / ntrain / (T_out / step), train_l2_full / ntrain, 152 | test_l2_step / ntest / (T_out / step), 153 | test_l2_full / ntest) 154 | if ep % 10 == 0: 155 | # if not os.path.exists(model_save_path): 156 | # os.makedirs(model_save_path) 157 | print('save latest model') 158 | torch.save(model.state_dict(), os.path.join(args.run_save_path, model_save_name[:-3]+f'_latest.pt')) 159 | if ep % 100 == 0: 160 | pd = pred[-1, :, :, -5:].detach().cpu().numpy() 161 | gt = yy[-1, :, :, -5:].detach().cpu().numpy() 162 | vars = ['thetao', 'so', 'uo', 'vo', 'zos'] 163 | for i in range(5): 164 | visual(pd[...,i], os.path.join(args.run_save_path, f'ep_{ep}_{vars[i]}_pred.png')) 165 | visual(gt[...,i], os.path.join(args.run_save_path, f'ep_{ep}_{vars[i]}_gt.png')) 166 | visual(np.abs(gt-pd)[...,i], os.path.join(args.run_save_path, f'ep_{ep}_{vars[i]}_err.png')) 167 | -------------------------------------------------------------------------------- /model_dict.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from models import FNO_2D, FNO_3D, DeepLag_2D, DeepLag_3D, LSM_2D, LSM_3D, Factformer_2D, Factformer_3D, GNOT_2D, GNOT_3D, UNet_2D, UNet_3D, GkTrm_2D, GkTrm_3D, Vortex_2D 4 | 5 | 6 | def get_model(args, ckpt_dir=None): 7 | model_dict = { 8 | 'FNO_2D': FNO_2D, 9 | 'FNO_3D': FNO_3D, 10 | 'LSM_2D': LSM_2D, 11 | 'LSM_3D': LSM_3D, 12 | 'DeepLag_2D': DeepLag_2D, 13 | 'DeepLag_3D': DeepLag_3D, 14 | 'Factformer_2D': Factformer_2D, 15 | 'Factformer_3D': Factformer_3D, 16 | 'GNOT_2D': GNOT_2D, 17 | 'GNOT_3D': GNOT_3D, 18 | 'UNet_2D': UNet_2D, 19 | 'UNet_3D': UNet_3D, 20 | 'GkTrm_2D': GkTrm_2D, 21 | 'GkTrm_3D': GkTrm_3D, 22 | 'Vortex_2D': Vortex_2D, 23 | } 24 | if ckpt_dir is None: 25 | return model_dict[args.model].Model(args=args).cuda() 26 | else: 27 | os.system(f'cp 
{str(ckpt_dir)}/{args.model}.py ./models/tmp_test_model.py') 28 | from models import tmp_test_model 29 | model = tmp_test_model.Model(args=args).cuda() 30 | os.system(f'rm -f ./models/tmp_test_model.py') 31 | return model 32 | -------------------------------------------------------------------------------- /models/FNO_2D.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Zongyi Li 3 | modified by Haixu Wu to adapt to this code base 4 | """ 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import torch 8 | import numpy as np 9 | import math 10 | 11 | 12 | ################################################################ 13 | # fourier layer 14 | ################################################################ 15 | class SpectralConv2d(nn.Module): 16 | def __init__(self, in_channels, out_channels, modes1, modes2): 17 | super(SpectralConv2d, self).__init__() 18 | 19 | """ 20 | 2D Fourier layer. It does FFT, linear transform, and Inverse FFT. 21 | """ 22 | 23 | self.in_channels = in_channels 24 | self.out_channels = out_channels 25 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 26 | self.modes2 = modes2 27 | 28 | self.scale = (1 / (in_channels * out_channels)) 29 | self.weights1 = nn.Parameter( 30 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 31 | self.weights2 = nn.Parameter( 32 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) 33 | 34 | # Complex multiplication 35 | def compl_mul2d(self, input, weights): 36 | # (batch, in_channel, x,y ), (in_channel, out_channel, x,y) -> (batch, out_channel, x,y) 37 | return torch.einsum("bixy,ioxy->boxy", input, weights) 38 | 39 | def forward(self, x): 40 | batchsize = x.shape[0] 41 | # Compute Fourier coeffcients up to factor of e^(- something constant) 42 | x_ft = torch.fft.rfft2(x) 43 | 44 | # Multiply relevant Fourier modes 45 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, 46 | device=x.device) 47 | out_ft[:, :, :self.modes1, :self.modes2] = \ 48 | self.compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) 49 | out_ft[:, :, -self.modes1:, :self.modes2] = \ 50 | self.compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) 51 | 52 | # Return to physical space 53 | x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) 54 | return x 55 | 56 | 57 | class Model(nn.Module): 58 | def __init__(self, args): 59 | super(Model, self).__init__() 60 | in_channels = args.in_dim * args.in_var 61 | out_channels = args.out_dim * args.out_var 62 | self.modes1 = args.num_basis 63 | self.modes2 = args.num_basis 64 | self.width = args.d_model 65 | self.padding = [int(x) for x in args.padding.split(',')] 66 | self.num_layers = args.num_layers 67 | # self.conv0 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 68 | # self.conv1 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 69 | # self.conv2 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 70 | # self.conv3 = SpectralConv2d(self.width, self.width, self.modes1, self.modes2) 71 | self.convs = nn.ModuleList([SpectralConv2d(self.width, self.width, self.modes1, self.modes2) for _ in range(args.num_layers)]) 72 | # self.w0 = nn.Conv2d(self.width, self.width, 1) 73 | # self.w1 = nn.Conv2d(self.width, self.width, 1) 74 | # self.w2 = nn.Conv2d(self.width, self.width, 1) 75 | # self.w3 = 
nn.Conv2d(self.width, self.width, 1)
76 |         self.ws = nn.ModuleList([nn.Conv2d(self.width, self.width, 1) for _ in range(args.num_layers)])
77 | 
78 |         self.fc0 = nn.Linear(in_channels + 2, self.width)  # input: in_channels field values plus the two grid coordinates (x, y)
79 |         self.fc1 = nn.Linear(self.width, 128)
80 |         self.fc2 = nn.Linear(128, out_channels)
81 | 
82 |     def forward(self, x):
83 |         grid = self.get_grid(x.shape, x.device)
84 |         x = torch.cat((x, grid), dim=-1)
85 |         x = self.fc0(x)
86 |         x = x.permute(0, 3, 1, 2)
87 |         if not all(item == 0 for item in self.padding):
88 |             x = F.pad(x, [0, self.padding[0], 0, self.padding[1]])
89 | 
90 |         # x1 = self.conv0(x)
91 |         # x2 = self.w0(x)
92 |         # x = x1 + x2
93 |         # x = F.gelu(x)
94 | 
95 |         # x1 = self.conv1(x)
96 |         # x2 = self.w1(x)
97 |         # x = x1 + x2
98 |         # x = F.gelu(x)
99 | 
100 |         # x1 = self.conv2(x)
101 |         # x2 = self.w2(x)
102 |         # x = x1 + x2
103 |         # x = F.gelu(x)
104 | 
105 |         # x1 = self.conv3(x)
106 |         # x2 = self.w3(x)
107 |         # x = x1 + x2
108 |         for i in range(self.num_layers):
109 |             x1 = self.convs[i](x)
110 |             x2 = self.ws[i](x)
111 |             x = x1 + x2
112 |             if i < self.num_layers - 1:
113 |                 x = F.gelu(x)
114 | 
115 |         if not all(item == 0 for item in self.padding):
116 |             x = x[..., :-self.padding[1], :-self.padding[0]]
117 |         x = x.permute(0, 2, 3, 1)
118 |         x = self.fc1(x)
119 |         x = F.gelu(x)
120 |         x = self.fc2(x)
121 |         return x
122 | 
123 |     def get_grid(self, shape, device):
124 |         batchsize, size_x, size_y = shape[0], shape[1], shape[2]
125 |         gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
126 |         gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1])
127 |         gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float)
128 |         gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1])
129 |         return torch.cat((gridx, gridy), dim=-1).to(device)
130 | 
131 | 
132 | if __name__ == "__main__":
133 |     conv = SpectralConv2d(in_channels=64, out_channels=64, modes1=32, modes2=32)
134 | 
135 |     x = torch.randn([1, 64, 128, 128])
136 |     print(conv(x).shape)
--------------------------------------------------------------------------------
/models/FNO_3D.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Zongyi Li
3 | modified by Haixu Wu to adapt to this code base
4 | """
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import torch
8 | import numpy as np
9 | import math
10 | 
11 | 
12 | ################################################################
13 | # 3d fourier layers
14 | ################################################################
15 | 
16 | class SpectralConv3d(nn.Module):
17 |     def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
18 |         super(SpectralConv3d, self).__init__()
19 | 
20 |         """
21 |         3D Fourier layer. It does FFT, linear transform, and Inverse FFT.
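        Tensors are laid out as (batch, channels, x, y, z). Only the lowest
        modes1 x modes2 x modes3 frequency block is transformed (together with
        the mirrored negative-frequency corners along the two full FFT axes);
        all remaining modes are zeroed before the inverse transform.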
22 | """ 23 | 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.modes1 = modes1 # Number of Fourier modes to multiply, at most floor(N/2) + 1 27 | self.modes2 = modes2 28 | self.modes3 = modes3 29 | 30 | self.scale = (1 / (in_channels * out_channels)) 31 | self.weights1 = nn.Parameter( 32 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 33 | dtype=torch.cfloat)) 34 | self.weights2 = nn.Parameter( 35 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 36 | dtype=torch.cfloat)) 37 | self.weights3 = nn.Parameter( 38 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 39 | dtype=torch.cfloat)) 40 | self.weights4 = nn.Parameter( 41 | self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, 42 | dtype=torch.cfloat)) 43 | 44 | # Complex multiplication 45 | def compl_mul3d(self, input, weights): 46 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 47 | return torch.einsum("bixyz,ioxyz->boxyz", input, weights) 48 | 49 | def forward(self, x): 50 | batchsize = x.shape[0] 51 | # Compute Fourier coeffcients up to factor of e^(- something constant) 52 | x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1]) 53 | 54 | # Multiply relevant Fourier modes 55 | out_ft = torch.zeros(batchsize, self.out_channels, x.size(-3), x.size(-2), x.size(-1) // 2 + 1, 56 | dtype=torch.cfloat, device=x.device) 57 | out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \ 58 | self.compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1) 59 | out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \ 60 | self.compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2) 61 | out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \ 62 | self.compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3) 63 | out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \ 64 | self.compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4) 65 | 66 | # Return to physical space 67 | x = torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) 68 | return x 69 | 70 | 71 | class Model(nn.Module): 72 | def __init__(self, args): 73 | super(Model, self).__init__() 74 | in_channels = args.in_dim * args.in_var 75 | out_channels = args.out_dim * args.out_var 76 | self.modes1 = args.num_basis 77 | self.modes2 = args.num_basis 78 | self.modes3 = args.num_basis // 2 79 | self.width = args.d_model 80 | self.padding = [int(x) for x in args.padding.split(',')] 81 | 82 | # self.conv0 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 83 | # self.conv1 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 84 | # self.conv2 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 85 | # self.conv3 = SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) 86 | self.convs = nn.ModuleList([SpectralConv3d(self.width, self.width, self.modes1, self.modes2, self.modes3) for _ in range(args.num_layers)]) 87 | # self.w0 = nn.Conv3d(self.width, self.width, 1) 88 | # self.w1 = nn.Conv3d(self.width, self.width, 1) 89 | # self.w2 = nn.Conv3d(self.width, self.width, 1) 90 | # self.w3 = nn.Conv3d(self.width, self.width, 1) 91 | self.ws = nn.ModuleList([nn.Conv3d(self.width, self.width, 1) for _ in range(args.num_layers)]) 92 | # self.bn0 = 
torch.nn.BatchNorm3d(self.width) 93 | # self.bn1 = torch.nn.BatchNorm3d(self.width) 94 | # self.bn2 = torch.nn.BatchNorm3d(self.width) 95 | # self.bn3 = torch.nn.BatchNorm3d(self.width) 96 | self.bns = nn.ModuleList([torch.nn.BatchNorm3d(self.width) for _ in range(args.num_layers)]) 97 | 98 | self.fc0 = nn.Linear(in_channels + 3, self.width) 99 | self.fc1 = nn.Linear(self.width, 128) 100 | self.fc2 = nn.Linear(128, out_channels) 101 | 102 | def forward(self, x): 103 | grid = self.get_grid(x.shape, x.device) 104 | x = torch.cat((x, grid), dim=-1) 105 | x = self.fc0(x) 106 | x = x.permute(0, 4, 1, 2, 3) 107 | if not all(item == 0 for item in self.padding): 108 | x = F.pad(x, [0, self.padding[0], 0, self.padding[1], 0, self.padding[2]]) 109 | 110 | # x1 = self.conv0(x) 111 | # x2 = self.w0(x) 112 | # x = x1 + x2 113 | # x = F.gelu(x) 114 | 115 | # x1 = self.conv1(x) 116 | # x2 = self.w1(x) 117 | # x = x1 + x2 118 | # x = F.gelu(x) 119 | 120 | # x1 = self.conv2(x) 121 | # x2 = self.w2(x) 122 | # x = x1 + x2 123 | # x = F.gelu(x) 124 | 125 | # x1 = self.conv3(x) 126 | # x2 = self.w3(x) 127 | # x = x1 + x2 128 | for i in range(self.num_layers): 129 | x1 = self.convs[i](x) 130 | x2 = self.ws[i](x) 131 | x = x1 + x2 132 | if i < self.num_layers - 1: 133 | x = F.gelu(x) 134 | 135 | if not all(item == 0 for item in self.padding): 136 | x = x[..., :-self.padding[2], :-self.padding[1], :-self.padding[0]] 137 | x = x.permute(0, 2, 3, 4, 1) # pad the domain if input is non-periodic 138 | x = self.fc1(x) 139 | x = F.gelu(x) 140 | x = self.fc2(x) 141 | return x 142 | 143 | def get_grid(self, shape, device): 144 | batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3] 145 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 146 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) 147 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 148 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) 149 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float) 150 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) 151 | return torch.cat((gridx, gridy, gridz), dim=-1).to(device) 152 | -------------------------------------------------------------------------------- /models/Factformer_2D.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | import numpy as np 7 | import argparse 8 | from tqdm import tqdm 9 | import time 10 | import os 11 | import gc 12 | from einops import rearrange, repeat, reduce 13 | from einops.layers.torch import Rearrange 14 | 15 | from .libs.fact.factorization_module import FABlock2D 16 | from .libs.fact.positional_encoding_module import GaussianFourierFeatureTransform 17 | 18 | class FactorizedTransformer(nn.Module): 19 | def __init__(self, 20 | dim, 21 | dim_head, 22 | heads, 23 | dim_out, 24 | depth, 25 | **kwargs): 26 | super().__init__() 27 | self.layers = nn.ModuleList([]) 28 | for _ in range(depth): 29 | 30 | layer = nn.ModuleList([]) 31 | layer.append(nn.Sequential( 32 | GaussianFourierFeatureTransform(2, dim // 2, 1), 33 | nn.Linear(dim, dim) 34 | )) 35 | layer.append(FABlock2D(dim, dim_head, dim, heads, dim_out, use_rope=True, 36 | **kwargs)) 37 | self.layers.append(layer) 38 | 39 | def forward(self, u, pos_lst=None): 40 | b, nx, ny, c = u.shape 41 | # nx, ny = pos_lst[0].shape[0], 
pos_lst[1].shape[0] 42 | # print(f'nx, ny: {nx}, {ny}') 43 | if pos_lst is None: 44 | pos, pos_lst = self.get_grid(u.shape, u.device) 45 | pos = pos.view(-1, 2) 46 | # print(pos.shape) 47 | for pos_enc, attn_layer in self.layers: 48 | tmp = pos_enc(pos).view(1, nx, ny, -1) 49 | u += tmp 50 | u = attn_layer(u, pos_lst) + u 51 | return u 52 | 53 | def get_grid(self, shape, device): 54 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 55 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 56 | gridx = gridx.reshape(1, size_x, 1, 1) 57 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 58 | gridy = gridy.reshape(1, 1, size_y, 1) 59 | return torch.cat((gridx.repeat([1, 1, size_y, 1]), gridy.repeat([1, size_x, 1, 1])), dim=-1).to(device), \ 60 | [gridx[0,...,0], gridy[0,0]] 61 | 62 | 63 | class Model(nn.Module): 64 | def __init__(self, args): 65 | super().__init__() 66 | self.args = args 67 | self.H = int(((args.h - 1) / args.h_down) + 1) 68 | self.W = int(((args.w - 1) / args.w_down) + 1) 69 | # self.resolutions = args.resolutions # hierachical resolutions, [16, 8, 4] 70 | # self.out_resolution = args.out_resolution 71 | 72 | self.in_dim = args.in_dim * args.in_var 73 | self.out_dim = args.out_dim * args.out_var 74 | 75 | self.depth = args.depth # depth of the encoder transformer 76 | self.dim = args.d_model # dimension of the transformer 77 | self.heads = args.heads 78 | self.dim_head = args.dim_head 79 | # self.reducer = args.reducer 80 | # self.resolution = args.resolution 81 | 82 | # self.pos_in_dim = args.pos_in_dim 83 | # self.pos_out_dim = args.pos_out_dim 84 | self.positional_embedding = 'rotary' 85 | self.kernel_multiplier = 3 86 | 87 | self.to_in = nn.Linear(self.in_dim, self.dim, bias=True) 88 | 89 | self.encoder = FactorizedTransformer(self.dim, self.dim_head, self.heads, self.dim, self.depth, 90 | kernel_multiplier=self.kernel_multiplier) 91 | 92 | self.down_block = nn.Sequential( 93 | nn.InstanceNorm2d(self.dim), 94 | nn.Conv2d(self.dim, self.dim//2, kernel_size=3, stride=2, padding=1, bias=True), 95 | nn.GELU(), 96 | nn.Conv2d(self.dim//2, self.dim//2, kernel_size=3, stride=1, padding=1, bias=True)) 97 | 98 | self.up_block = nn.Sequential( 99 | nn.Upsample(size=(self.H, self.W), mode='nearest'), 100 | nn.Conv2d(self.dim//2, self.dim//2, kernel_size=3, stride=1, padding=1, bias=True), 101 | nn.GELU(), 102 | nn.Conv2d(self.dim//2, self.dim, kernel_size=3, stride=1, padding=1, bias=True)) 103 | 104 | self.simple_to_out = nn.Sequential( 105 | Rearrange('b nx ny c -> b c (nx ny)'), 106 | nn.GroupNorm(num_groups=8, num_channels=self.dim*2), 107 | nn.Conv1d(self.dim*2, self.dim, kernel_size=1, stride=1, padding=0, bias=False), 108 | nn.GELU(), 109 | nn.Conv1d(self.dim, self.out_dim, kernel_size=1, stride=1, padding=0, bias=True) 110 | ) 111 | 112 | def forward(self, 113 | u, 114 | fx=None, T=None 115 | ): 116 | pos_lst = None 117 | # b, _, c = u.shape 118 | # u = u.view(b, self.H, self.W, c) 119 | _, nx, ny, _ = u.shape 120 | u = self.to_in(u) 121 | u_last = self.encoder(u, pos_lst) 122 | u = rearrange(u_last, 'b nx ny c -> b c nx ny') 123 | assert u.shape[1] == self.dim 124 | u = self.down_block(u) 125 | u = self.up_block(u) 126 | u = rearrange(u, 'b c nx ny -> b nx ny c') 127 | # print(u.shape, u_last.shape) 128 | u = torch.cat([u, u_last], dim=-1) 129 | u = self.simple_to_out(u) 130 | u = rearrange(u, 'b c (nx ny) -> b nx ny c', nx=nx, ny=ny) 131 | return u -------------------------------------------------------------------------------- 
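A minimal shape-check sketch for the 2D Factformer wrapper above (an editorial example, not a repository file; the SimpleNamespace fields mirror exactly the argparse attributes that Model.__init__ reads, and every value is an illustrative assumption):

# Hypothetical smoke test for models/Factformer_2D.py; all values are assumptions.
from types import SimpleNamespace

import torch

from models.Factformer_2D import Model

args = SimpleNamespace(
    h=64, w=64, h_down=1, w_down=1,  # full grid, no down-sampling
    in_dim=10, in_var=1,             # 10 input frames x 1 variable = 10 channels
    out_dim=1, out_var=1,            # one predicted channel per forward call
    depth=2, d_model=64,             # encoder depth / width; d_model*2 must be divisible by 8 for the GroupNorm
    heads=4, dim_head=16,
)
model = Model(args)
u = torch.randn(2, 64, 64, args.in_dim * args.in_var)  # (batch, h, w, channels)
print(model(u).shape)  # expected: torch.Size([2, 64, 64, 1])

The down_block/up_block pair adds a coarse convolutional branch whose output is concatenated with the attention encoder's output before the final 1x1 projections, which is why simple_to_out starts from 2*d_model channels.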
/models/Factformer_3D.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | import numpy as np 7 | import argparse 8 | from tqdm import tqdm 9 | import time 10 | import os 11 | import gc 12 | from einops import rearrange, repeat, reduce 13 | from einops.layers.torch import Rearrange 14 | 15 | from .libs.fact.factorization_module import FABlock3D 16 | from .libs.fact.positional_encoding_module import GaussianFourierFeatureTransform 17 | 18 | class FactorizedTransformer(nn.Module): 19 | def __init__(self, 20 | dim, 21 | dim_head, 22 | heads, 23 | dim_out, 24 | depth, 25 | **kwargs): 26 | super().__init__() 27 | self.layers = nn.ModuleList([]) 28 | for _ in range(depth): 29 | 30 | layer = nn.ModuleList([]) 31 | layer.append(nn.Sequential( 32 | GaussianFourierFeatureTransform(3, dim // 2, 1), 33 | nn.Linear(dim, dim) 34 | )) 35 | layer.append(FABlock3D(dim, dim_head, dim, heads, dim_out, use_rope=True, 36 | **kwargs)) 37 | self.layers.append(layer) 38 | 39 | def forward(self, u, pos_lst=None): 40 | b, nz, nx, ny, c = u.shape 41 | # nz, nx, ny = pos_lst[0].shape[0], pos_lst[1].shape[0], pos_lst[2].shape[0] 42 | # print(f'nz, nx, ny: {nz}, {nx}, {ny}') 43 | if pos_lst is None: 44 | pos, pos_lst = self.get_grid(u.shape, u.device) 45 | pos = pos.view(-1, 3) 46 | # print(pos.shape) 47 | for pos_enc, attn_layer in self.layers: 48 | tmp = pos_enc(pos).view(1, nz, nx, ny, -1) 49 | u += tmp 50 | u = attn_layer(u, pos_lst) + u 51 | return u 52 | 53 | def get_grid(self, shape, device): 54 | batchsize, size_z, size_x, size_y = shape[0], shape[1], shape[2], shape[3] 55 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float) 56 | gridz = gridz.reshape(1, size_z, 1, 1, 1) 57 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 58 | gridx = gridx.reshape(1, 1, size_x, 1, 1) 59 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 60 | gridy = gridy.reshape(1, 1, 1, size_y, 1) 61 | return torch.cat((gridz.repeat([1, 1, size_x, size_y, 1]), gridx.repeat([1, size_z, 1, size_y, 1]), gridy.repeat([1, size_z, size_x, 1, 1])), dim=-1).to(device), \ 62 | [gridz[0,:,:,0,0], gridx[0,0,:,:,0], gridy[0,0,0,:,:]] 63 | 64 | 65 | class Model(nn.Module): 66 | def __init__(self, args): 67 | super().__init__() 68 | self.args = args 69 | self.Z = int(((args.z - 1) / args.z_down) + 1) 70 | self.H = int(((args.h - 1) / args.h_down) + 1) 71 | self.W = int(((args.w - 1) / args.w_down) + 1) 72 | # self.resolutions = args.resolutions # hierachical resolutions, [16, 8, 4] 73 | # self.out_resolution = args.out_resolution 74 | 75 | self.in_dim = args.in_dim * args.in_var 76 | self.out_dim = args.out_dim * args.out_var 77 | 78 | self.depth = args.depth # depth of the encoder transformer 79 | self.dim = args.d_model # dimension of the transformer 80 | self.heads = args.heads 81 | self.dim_head = args.dim_head 82 | # self.reducer = args.reducer 83 | # self.resolution = args.resolution 84 | 85 | # self.pos_in_dim = args.pos_in_dim 86 | # self.pos_out_dim = args.pos_out_dim 87 | self.positional_embedding = 'rotary' 88 | self.kernel_multiplier = 3 89 | 90 | self.to_in = nn.Linear(self.in_dim, self.dim, bias=True) 91 | 92 | self.encoder = FactorizedTransformer(self.dim, self.dim_head, self.heads, self.dim, self.depth, 93 | kernel_multiplier=self.kernel_multiplier) 94 | 95 | self.down_block = nn.Sequential( 96 | nn.InstanceNorm3d(self.dim), 97 | nn.Conv3d(self.dim, 
self.dim//2, kernel_size=3, stride=2, padding=1, bias=True), 98 | nn.GELU(), 99 | nn.Conv3d(self.dim//2, self.dim//2, kernel_size=3, stride=1, padding=1, bias=True)) 100 | 101 | self.up_block = nn.Sequential( 102 | nn.Upsample(size=(self.Z, self.H, self.W), mode='nearest'), 103 | nn.Conv3d(self.dim//2, self.dim//2, kernel_size=3, stride=1, padding=1, bias=True), 104 | nn.GELU(), 105 | nn.Conv3d(self.dim//2, self.dim, kernel_size=3, stride=1, padding=1, bias=True)) 106 | 107 | self.simple_to_out = nn.Sequential( 108 | Rearrange('b nz nx ny c -> b c (nz nx ny)'), 109 | nn.GroupNorm(num_groups=8, num_channels=self.dim*2), 110 | nn.Conv1d(self.dim*2, self.dim, kernel_size=1, stride=1, padding=0, bias=False), 111 | nn.GELU(), 112 | nn.Conv1d(self.dim, self.out_dim, kernel_size=1, stride=1, padding=0, bias=True) 113 | ) 114 | 115 | def forward(self, 116 | u, 117 | fx=None, T=None 118 | ): 119 | pos_lst = None 120 | # b, _, c = u.shape 121 | # u = u.view(b, self.Z, self.H, self.W, c) 122 | _, nz, nx, ny, _ = u.shape 123 | u = self.to_in(u) 124 | u_last = self.encoder(u, pos_lst) 125 | u = rearrange(u_last, 'b nz nx ny c -> b c nz nx ny') 126 | assert u.shape[1] == self.dim 127 | u = self.down_block(u) 128 | u = self.up_block(u) 129 | u = rearrange(u, 'b c nz nx ny -> b nz nx ny c') 130 | # print(u.shape, u_last.shape) 131 | u = torch.cat([u, u_last], dim=-1) 132 | u = self.simple_to_out(u) 133 | u = rearrange(u, 'b c (nz nx ny) -> b nz nx ny c', nx=nx, ny=ny) 134 | return u -------------------------------------------------------------------------------- /models/UNet_2D.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | import numpy as np 5 | import math 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | ################################################################ 10 | # Multiscale modules 2D 11 | ################################################################ 12 | class DoubleConv(nn.Module): 13 | """(convolution => [BN] => ReLU) * 2""" 14 | 15 | def __init__(self, in_channels, out_channels, mid_channels=None): 16 | super().__init__() 17 | if not mid_channels: 18 | mid_channels = out_channels 19 | self.double_conv = nn.Sequential( 20 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 21 | nn.BatchNorm2d(mid_channels), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 24 | nn.BatchNorm2d(out_channels), 25 | nn.ReLU(inplace=True) 26 | ) 27 | 28 | def forward(self, x): 29 | return self.double_conv(x) 30 | 31 | 32 | class Down(nn.Module): 33 | """Downscaling with maxpool then double conv""" 34 | 35 | def __init__(self, in_channels, out_channels): 36 | super().__init__() 37 | self.maxpool_conv = nn.Sequential( 38 | nn.MaxPool2d(2), 39 | DoubleConv(in_channels, out_channels) 40 | ) 41 | 42 | def forward(self, x): 43 | return self.maxpool_conv(x) 44 | 45 | 46 | class Up(nn.Module): 47 | """Upscaling then double conv""" 48 | 49 | def __init__(self, in_channels, out_channels, bilinear=True): 50 | super().__init__() 51 | 52 | # if bilinear, use the normal convolutions to reduce the number of channels 53 | if bilinear: 54 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 55 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 56 | else: 57 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 58 | self.conv = 
DoubleConv(in_channels, out_channels) 59 | 60 | def forward(self, x1, x2): 61 | x1 = self.up(x1) 62 | # input is CHW 63 | diffY = x2.size()[2] - x1.size()[2] 64 | diffX = x2.size()[3] - x1.size()[3] 65 | 66 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 67 | diffY // 2, diffY - diffY // 2]) 68 | # if you have padding issues, see 69 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 70 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 71 | x = torch.cat([x2, x1], dim=1) 72 | return self.conv(x) 73 | 74 | 75 | class OutConv(nn.Module): 76 | def __init__(self, in_channels, out_channels): 77 | super(OutConv, self).__init__() 78 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 79 | 80 | def forward(self, x): 81 | return self.conv(x) 82 | 83 | 84 | class Model(nn.Module): 85 | def __init__(self, args, bilinear=True): 86 | super(Model, self).__init__() 87 | in_channels = args.in_dim * args.in_var 88 | out_channels = args.out_dim * args.out_var 89 | width = args.d_model 90 | num_token = args.num_token 91 | num_basis = args.num_basis 92 | patch_size = [int(x) for x in args.patch_size.split(',')] 93 | padding = [int(x) for x in args.padding.split(',')] 94 | # multiscale modules 95 | self.inc = DoubleConv(width, width) 96 | self.down1 = Down(width, width * 2) 97 | self.down2 = Down(width * 2, width * 4) 98 | self.down3 = Down(width * 4, width * 8) 99 | factor = 2 if bilinear else 1 100 | self.down4 = Down(width * 8, width * 16 // factor) 101 | self.up1 = Up(width * 16, width * 8 // factor, bilinear) 102 | self.up2 = Up(width * 8, width * 4 // factor, bilinear) 103 | self.up3 = Up(width * 4, width * 2 // factor, bilinear) 104 | self.up4 = Up(width * 2, width, bilinear) 105 | self.outc = OutConv(width, width) 106 | # projectors 107 | self.padding = padding 108 | self.fc0 = nn.Linear(in_channels + 2, width) 109 | self.fc1 = nn.Linear(width, 128) 110 | self.fc2 = nn.Linear(128, out_channels) 111 | 112 | def forward(self, x): 113 | grid = self.get_grid(x.shape, x.device) 114 | x = torch.cat((x, grid), dim=-1) 115 | x = self.fc0(x) 116 | x = x.permute(0, 3, 1, 2) 117 | 118 | if not all(item == 0 for item in self.padding): 119 | x = F.pad(x, [self.padding[1]//2, self.padding[1]-self.padding[1]//2, 120 | self.padding[0]//2, self.padding[0]-self.padding[0]//2]) 121 | 122 | x1 = self.inc(x) 123 | x2 = self.down1(x1) 124 | x3 = self.down2(x2) 125 | x4 = self.down3(x3) 126 | x5 = self.down4(x4) 127 | x = self.up1(x5, x4) 128 | x = self.up2(x, x3) 129 | x = self.up3(x, x2) 130 | x = self.up4(x, x1) 131 | x = self.outc(x) 132 | 133 | if not all(item == 0 for item in self.padding): 134 | x = x[..., self.padding[0]//2:-(self.padding[0]-self.padding[0]//2), 135 | self.padding[1]//2:-(self.padding[1]-self.padding[1]//2)] 136 | x = x.permute(0, 2, 3, 1) 137 | x = self.fc1(x) 138 | x = F.gelu(x) 139 | x = self.fc2(x) 140 | return x 141 | 142 | def get_grid(self, shape, device): 143 | batchsize, size_x, size_y = shape[0], shape[1], shape[2] 144 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 145 | gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) 146 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 147 | gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) 148 | return torch.cat((gridx, gridy), dim=-1).to(device) 149 | -------------------------------------------------------------------------------- 
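A matching shape-check sketch for the 2D U-Net above (again not a repository file; values are illustrative assumptions, and with padding '0,0' the input height and width must be divisible by 16 to survive the four MaxPool2d down-samplings):

# Hypothetical smoke test for models/UNet_2D.py; all values are assumptions.
from types import SimpleNamespace

import torch

from models.UNet_2D import Model

args = SimpleNamespace(
    in_dim=10, in_var=1, out_dim=1, out_var=1,
    d_model=32,                  # base channel width
    num_token=64, num_basis=12,  # parsed by __init__ but unused in this model
    patch_size='4,4', padding='0,0',
)
model = Model(args)
x = torch.randn(2, 64, 64, args.in_dim * args.in_var)  # (batch, h, w, channels); 64 = 16 * 4
print(model(x).shape)  # expected: torch.Size([2, 64, 64, 1])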
/models/UNet_3D.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import torch
4 | import numpy as np
5 | import math
6 | 
7 | 
8 | ################################################################
9 | # Multiscale modules 3D
10 | ################################################################
11 | 
12 | class DoubleConv(nn.Module):
13 |     """(convolution => [BN] => ReLU) * 13: despite the name inherited from the 2D code, this 3D variant stacks thirteen conv blocks (in->mid, eleven mid->mid, one mid->out)"""
14 | 
15 |     def __init__(self, in_channels, out_channels, mid_channels=None):
16 |         super().__init__()
17 |         if not mid_channels:
18 |             mid_channels = out_channels
19 |         self.double_conv = nn.Sequential(
20 |             nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
21 |             nn.BatchNorm3d(mid_channels),
22 |             nn.ReLU(inplace=True),
23 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
24 |             nn.BatchNorm3d(mid_channels),
25 |             nn.ReLU(inplace=True),
26 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
27 |             nn.BatchNorm3d(mid_channels),
28 |             nn.ReLU(inplace=True),
29 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
30 |             nn.BatchNorm3d(mid_channels),
31 |             nn.ReLU(inplace=True),
32 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
33 |             nn.BatchNorm3d(mid_channels),
34 |             nn.ReLU(inplace=True),
35 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
36 |             nn.BatchNorm3d(mid_channels),
37 |             nn.ReLU(inplace=True),
38 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
39 |             nn.BatchNorm3d(mid_channels),
40 |             nn.ReLU(inplace=True),
41 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
42 |             nn.BatchNorm3d(mid_channels),
43 |             nn.ReLU(inplace=True),
44 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
45 |             nn.BatchNorm3d(mid_channels),
46 |             nn.ReLU(inplace=True),
47 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
48 |             nn.BatchNorm3d(mid_channels),
49 |             nn.ReLU(inplace=True),
50 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
51 |             nn.BatchNorm3d(mid_channels),
52 |             nn.ReLU(inplace=True),
53 |             nn.Conv3d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
54 |             nn.BatchNorm3d(mid_channels),
55 |             nn.ReLU(inplace=True),
56 |             nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
57 |             nn.BatchNorm3d(out_channels),
58 |             nn.ReLU(inplace=True),
59 |         )
60 | 
61 |     def forward(self, x):
62 |         return self.double_conv(x)
63 | 
64 | 
65 | class Down(nn.Module):
66 |     """Downscaling with maxpool then double conv"""
67 | 
68 |     def __init__(self, in_channels, out_channels):
69 |         super().__init__()
70 |         self.maxpool_conv = nn.Sequential(
71 |             nn.MaxPool3d(2),
72 |             DoubleConv(in_channels, out_channels)
73 |         )
74 | 
75 |     def forward(self, x):
76 |         return self.maxpool_conv(x)
77 | 
78 | 
79 | class Up(nn.Module):
80 |     """Upscaling then double conv"""
81 | 
82 |     def __init__(self, in_channels, out_channels, bilinear=True):
83 |         super().__init__()
84 | 
85 |         # if bilinear, use the normal convolutions to reduce the number of channels
86 |         if bilinear:
87 |             self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
88 |             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
89 |         else:
90 |             self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
91 |             self.conv = DoubleConv(in_channels, 
out_channels) 92 | 93 | def forward(self, x1, x2): 94 | x1 = self.up(x1) 95 | x = torch.cat([x2, x1], dim=1) 96 | return self.conv(x) 97 | 98 | 99 | class OutConv(nn.Module): 100 | def __init__(self, in_channels, out_channels): 101 | super(OutConv, self).__init__() 102 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1) 103 | 104 | def forward(self, x): 105 | return self.conv(x) 106 | 107 | 108 | class Model(nn.Module): 109 | def __init__(self, args, bilinear=True): 110 | super(Model, self).__init__() 111 | in_channels = args.in_dim * args.in_var 112 | out_channels = args.out_dim * args.out_var 113 | width = args.d_model 114 | num_token = args.num_token 115 | num_basis = args.num_basis 116 | patch_size = [int(x) for x in args.patch_size.split(',')] 117 | padding = [int(x) for x in args.padding.split(',')] 118 | # multiscale modules 119 | self.inc = DoubleConv(width, width) 120 | self.down1 = Down(width, width * 2) 121 | self.down2 = Down(width * 2, width * 4) 122 | self.down3 = Down(width * 4, width * 8) 123 | factor = 2 if bilinear else 1 124 | self.down4 = Down(width * 8, width * 16 // factor) 125 | self.up1 = Up(width * 16, width * 8 // factor, bilinear) 126 | self.up2 = Up(width * 8, width * 4 // factor, bilinear) 127 | self.up3 = Up(width * 4, width * 2 // factor, bilinear) 128 | self.up4 = Up(width * 2, width, bilinear) 129 | self.outc = OutConv(width, width) 130 | # projectors 131 | self.padding = padding 132 | self.fc0 = nn.Linear(in_channels + 3, width) 133 | self.fc1 = nn.Linear(width, 128) 134 | self.fc2 = nn.Linear(128, out_channels) 135 | 136 | def forward(self, x): 137 | grid = self.get_grid(x.shape, x.device) 138 | x = torch.cat((x, grid), dim=-1) 139 | x = self.fc0(x) 140 | x = x.permute(0, 4, 1, 2, 3) 141 | if not all(item == 0 for item in self.padding): 142 | x = F.pad(x, [0, self.padding[0], 0, self.padding[1], 0, self.padding[2]]) 143 | 144 | x1 = self.inc(x) 145 | x2 = self.down1(x1) 146 | x3 = self.down2(x2) 147 | x4 = self.down3(x3) 148 | x5 = self.down4(x4) 149 | x = self.up1(x5, x4) 150 | x = self.up2(x, x3) 151 | x = self.up3(x, x2) 152 | x = self.up4(x, x1) 153 | x = self.outc(x) 154 | 155 | if not all(item == 0 for item in self.padding): 156 | x = x[..., :-self.padding[2], :-self.padding[1], :-self.padding[0]] 157 | x = x.permute(0, 2, 3, 4, 1) 158 | x = self.fc1(x) 159 | x = F.gelu(x) 160 | x = self.fc2(x) 161 | return x 162 | 163 | def get_grid(self, shape, device): 164 | batchsize, size_x, size_y, size_z = shape[0], shape[1], shape[2], shape[3] 165 | gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) 166 | gridx = gridx.reshape(1, size_x, 1, 1, 1).repeat([batchsize, 1, size_y, size_z, 1]) 167 | gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) 168 | gridy = gridy.reshape(1, 1, size_y, 1, 1).repeat([batchsize, size_x, 1, size_z, 1]) 169 | gridz = torch.tensor(np.linspace(0, 1, size_z), dtype=torch.float) 170 | gridz = gridz.reshape(1, 1, 1, size_z, 1).repeat([batchsize, size_x, size_y, 1, 1]) 171 | return torch.cat((gridx, gridy, gridz), dim=-1).to(device) 172 | -------------------------------------------------------------------------------- /models/Vortex_2D.py: -------------------------------------------------------------------------------- 1 | from .libs.vortex.io_utils import * 2 | from .libs.vortex.simulation_utils import * 3 | from .libs.vortex.learning_utils import L2_Loss, vort_to_vel 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | import math 8 | torch.manual_seed(123) 9 
| import sys 10 | import os 11 | 12 | from functorch import jacrev, vmap 13 | import torch 14 | class SineResidualBlock(nn.Module): 15 | 16 | def __init__(self, in_features, out_features, bias=True, 17 | is_first=False, omega_0=30): 18 | super(SineResidualBlock, self).__init__() 19 | self.omega_0 = omega_0 20 | self.is_first = is_first 21 | 22 | self.in_features = in_features 23 | self.linear = nn.Linear(in_features, out_features, bias=bias) 24 | # add shortcut 25 | self.shortcut = nn.Sequential() 26 | if in_features != out_features: 27 | self.shortcut = nn.Sequential( 28 | nn.Linear(in_features, out_features), 29 | ) 30 | 31 | self.init_weights() 32 | 33 | def init_weights(self): 34 | with torch.no_grad(): 35 | if self.is_first: 36 | self.linear.weight.uniform_(-1 / self.in_features, 37 | 1 / self.in_features) 38 | else: 39 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 40 | np.sqrt(6 / self.in_features) / self.omega_0) 41 | 42 | def forward(self, input): 43 | out = torch.sin(self.omega_0 * self.linear(input)) 44 | out += self.shortcut(input) 45 | out = nn.functional.relu(out) 46 | return out 47 | 48 | class Dynamics_Net(nn.Module): 49 | def __init__(self): 50 | super(Dynamics_Net, self).__init__() 51 | in_dim = 1 52 | out_dim = 1 53 | width = 40 54 | self.layers = nn.Sequential(SineResidualBlock(in_dim, width, omega_0=1., is_first=True), 55 | SineResidualBlock(width, width, omega_0=1.), 56 | SineResidualBlock(width, width, omega_0=1.), 57 | SineResidualBlock(width, width, omega_0=1.), 58 | nn.Linear(width, out_dim), 59 | ) 60 | 61 | def forward(self, x): 62 | '''Forward pass''' 63 | return self.layers(x) 64 | 65 | class Position_Net(nn.Module): 66 | def __init__(self, num_vorts): 67 | super(Position_Net, self).__init__() 68 | in_dim = 1 69 | out_dim = num_vorts * 2 70 | self.layers = nn.Sequential(SineResidualBlock(in_dim, 64, omega_0=1., is_first=True), 71 | SineResidualBlock(64, 128, omega_0=1.), 72 | SineResidualBlock(128, 256, omega_0=1.), 73 | nn.Linear(256, out_dim) 74 | ) 75 | 76 | def forward(self, x): 77 | '''Forward pass''' 78 | return self.layers(x) 79 | 80 | class Model(nn.Module): 81 | def __init__(self, args): 82 | super(Model, self).__init__() 83 | self.ckptdir = args.ckptdir if hasattr(args, 'ckptdir') else 'checkpoints' 84 | self.num_vorts = args.num_vorts if hasattr(args, 'num_vorts') else 16 85 | self.decay_gamma = args.decay_gamma if hasattr(args, 'decay_gamma') else 0.99 86 | self.decimate_point = args.decimate_point if hasattr(args,'decimate_point') else 20000 # LR decimates at this point 87 | self.decay_step = max(1, int(self.decimate_point/math.log(0.1, self.decay_gamma))) # decay once every (# >= 1) learning steps 88 | self.pre_ckptdir = args.pre_ckptdir if hasattr(args, 'pre_ckptdir') else './models/libs/vortex/pretrained.tar' 89 | self.width = int(((args.w - 1) / args.w_down) + 1) 90 | self.height = int(((args.h - 1) / args.h_down) + 1) 91 | self.C_out = args.out_dim * args.out_var 92 | self.device = torch.device('cuda') 93 | #self.net_dict, self.start, self.grad_vars, self.optimizer, self.lr_scheduler = create_bundle(self.ckptdir, self.num_vorts, self.decay_step, self.decay_gamma, pretrain_dir = self.pre_ckptdir) 94 | self.img_x = gen_grid(self.width, self.height, device) # grid coordinates' 95 | self.batch_size = args.batch_size 96 | self.vort_scale = args.vort_scale if hasattr(args, 'vort_scale') else 0.5 97 | self.num_sims = 1 98 | self.net_dict_len = Dynamics_Net() 99 | self.net_dict_pos = Position_Net(self.num_vorts) 100 | 
pre_ckpt = torch.load(self.pre_ckptdir) 101 | self.net_dict_pos.load_state_dict(pre_ckpt['model_pos_state_dict']) 102 | self.register_parameter('w_pred_param', nn.Parameter(torch.zeros(self.num_vorts, 1, dtype=torch.float32))) 103 | self.register_parameter('size_pred_param', nn.Parameter(torch.zeros(self.num_vorts, 1,dtype=torch.float32))) 104 | def eval_vel(self, vorts_size, vorts_w, vorts_pos, query_pos): 105 | return vort_to_vel(self.net_dict_len, vorts_size, vorts_w, vorts_pos, query_pos, length_scale = self.vort_scale) 106 | 107 | def dist_2_len_(self, dist): 108 | return self.net_dict_len(dist) 109 | 110 | def size_pred(self): 111 | pred = self.size_pred_param 112 | size = 0.03 + torch.sigmoid(pred) 113 | return size 114 | 115 | def w_pred(self): 116 | pred = self.w_pred_param 117 | w = torch.sin(pred) 118 | return w 119 | 120 | def comp_velocity(self, timestamps): 121 | jac = vmap(jacrev((self.net_dict_pos)))(timestamps) 122 | post = jac[:, :, 0:1].view((timestamps.shape[0],-1,2,1)) 123 | xt = post[:, :, 0, :] 124 | yt = post[:, :, 1, :] 125 | uv = torch.cat((xt, yt), dim = 2) 126 | return uv 127 | 128 | def forward(self, x, index): 129 | #x: b h w c 130 | index = index.unsqueeze(1).float() 131 | with torch.no_grad(): 132 | pos_pred_gradless = self.net_dict_pos(index).view((-1,self.num_vorts,2)) 133 | D_vel = self.eval_vel(self.size_pred(), self.w_pred(), pos_pred_gradless, pos_pred_gradless) 134 | 135 | # if boundary is not None: 136 | # D_vel = boundary_treatment(pos_pred_gradless, D_vel, boundary, mode = 1) 137 | 138 | # velocity loss 139 | T_vel = self.comp_velocity(index) # velocity prescribed by trajectory module 140 | vel_loss = 0.001 * L2_Loss(T_vel, D_vel) 141 | pos_pred = self.net_dict_pos(index).view((self.batch_size,self.num_vorts,2)) 142 | sim_imgs, sim_vorts_poss, sim_img_vels, sim_vorts_vels = simulate(x.clone(), self.img_x, pos_pred, self.w_pred(), \ 143 | self.size_pred(), self.num_sims, vel_func = self.eval_vel, boundary = None) 144 | return sim_imgs[-1][:, :, :, :self.C_out], vel_loss -------------------------------------------------------------------------------- /models/libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/DeepLag/694b3cd98d8fc1c65803d4bdbbbc174f8a4d1322/models/libs/__init__.py -------------------------------------------------------------------------------- /models/libs/basics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | from einops import rearrange, repeat 7 | 8 | 9 | class PreNorm(nn.Module): 10 | def __init__(self, dim, fn): 11 | super().__init__() 12 | self.norm = nn.LayerNorm(dim) 13 | self.fn = fn 14 | 15 | def forward(self, x, **kwargs): 16 | return self.fn(self.norm(x), **kwargs) 17 | 18 | 19 | class PostNorm(nn.Module): 20 | def __init__(self, dim, fn): 21 | super().__init__() 22 | self.norm = nn.LayerNorm(dim) 23 | self.fn = fn 24 | 25 | def forward(self, x, **kwargs): 26 | return self.norm(self.fn(x, **kwargs)) 27 | 28 | 29 | class GeAct(nn.Module): 30 | """Gated activation function""" 31 | def __init__(self, act_fn): 32 | super().__init__() 33 | self.fn = act_fn 34 | 35 | def forward(self, x): 36 | c = x.shape[-1] # channel last arrangement 37 | return self.fn(x[..., :int(c//2)]) * x[..., int(c//2):] 38 | 39 | 40 | class MLP(nn.Module): 41 | def __init__(self, dims, act_fn, dropout=0.): 42 | 
super().__init__() 43 | layers = [] 44 | 45 | for i in range(len(dims) - 1): 46 | if isinstance(act_fn, GeAct) and i < len(dims) - 2: 47 | layers.append(nn.Linear(dims[i], dims[i+1] * 2)) 48 | else: 49 | layers.append(nn.Linear(dims[i], dims[i+1])) 50 | if i < len(dims) - 2: 51 | layers.append(act_fn) 52 | layers.append(nn.Dropout(dropout)) 53 | self.net = nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | return self.net(x) 57 | 58 | 59 | def masked_instance_norm(x, mask, eps = 1e-6): 60 | """ 61 | x of shape: [batch_size (N), num_objects (L), features(C)] 62 | mask of shape: [batch_size (N), num_objects (L), 1] 63 | """ 64 | mask = mask.float() # (N,L,1) 65 | mean = (torch.sum(x * mask, 1) / torch.sum(mask, 1)) # (N,C) 66 | mean = mean.detach() 67 | var_term = ((x - mean.unsqueeze(1).expand_as(x)) * mask)**2 # (N,L,C) 68 | var = (torch.sum(var_term, 1) / torch.sum(mask, 1)) #(N,C) 69 | var = var.detach() 70 | mean_reshaped = mean.unsqueeze(1).expand_as(x) # (N, L, C) 71 | var_reshaped = var.unsqueeze(1).expand_as(x) # (N, L, C) 72 | ins_norm = (x - mean_reshaped) / torch.sqrt(var_reshaped + eps) # (N, L, C) 73 | return ins_norm 74 | 75 | 76 | def get_time_embedding(t, dim): 77 | """ 78 | Build sinusoidal timestep embeddings, as used in Denoising Diffusion 79 | Probabilistic Models (the implementation originally comes from Fairseq). 80 | 81 | This matches the implementation in tensor2tensor, but differs slightly 82 | from the description in Section 3.5 of "Attention Is All You Need". 83 | """ 84 | assert len(t.shape) == 1 85 | 86 | half_dim = dim // 2 87 | emb = math.log(10000) / (half_dim - 1) 88 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 89 | emb = emb.to(device=t.device) 90 | emb = t.float()[:, None] * emb[None, :] 91 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 92 | if dim % 2 == 1: # zero pad 93 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 94 | return emb 95 | 96 | 97 | # the code below is taken from the amazing Hyena repo: 98 | # https://github.com/HazyResearch/safari/blob/9ecfaf0e49630b5913fce19adec231b41c2e0e39/src/models/sequence/hyena.py#L64 99 | 100 | class Sin(nn.Module): 101 | def __init__(self, dim, w=10, train_freq=True): 102 | super().__init__() 103 | self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) 104 | 105 | def forward(self, x): 106 | return torch.sin(self.freq * x) 107 | 108 | 109 | class PositionalEmbedding(nn.Module): 110 | def __init__(self, 111 | emb_dim: int, 112 | seq_len: int, 113 | lr_pos_emb: float = 1e-5, 114 | **kwargs): 115 | """Complex exponential positional embeddings for Hyena filters.""" 116 | super().__init__() 117 | 118 | self.seq_len = seq_len 119 | # The time embedding fed to the filters is normalized so that t_f = 1 120 | t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 121 | 122 | assert emb_dim > 1 123 | bands = (emb_dim - 1) // 2 124 | # To compute the right embeddings we use the "proper" linspace 125 | t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] 126 | w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 127 | 128 | f = torch.linspace(1e-4, bands - 1, bands)[None, None] # 1, 1, bands 129 | z = torch.exp(-1j * f * w) 130 | z = torch.cat([t, z.real, z.imag], dim=-1) 131 | self.register_parameter("z", nn.Parameter(z)) 132 | optim = {"lr": lr_pos_emb} 133 | setattr(getattr(self, "z"), "_optim", optim) 134 | self.register_buffer("t", t) 135 | 136 | def forward(self, L): 137 | return self.z[:, :L], self.t[:, :L] 138 | 139 | 140 | 
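# usage sketch (illustration only; shapes follow from the code above): for a
# 1-D tensor of timesteps t with t.shape == (B,), get_time_embedding(t, dim)
# returns a (B, dim) tensor with the sin terms in the first dim // 2 columns
# and the matching cos terms in the rest, e.g.
#   t = torch.tensor([0, 10, 100])
#   emb = get_time_embedding(t, 128)  # emb.shape == (3, 128)
# reference 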
convolution with residual connection 141 | def fftconv_ref(u, k, D, dropout_mask, gelu=False, k_rev=None): 142 | seqlen = u.shape[-1] 143 | fft_size = 2 * seqlen 144 | k_f = torch.fft.rfft(k, n=fft_size) / fft_size 145 | if k_rev is not None: 146 | k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size 147 | k_f = k_f + k_rev_f.conj() 148 | u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) 149 | 150 | if len(u.shape) > 3: 151 | k_f = k_f.unsqueeze(1) 152 | 153 | y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] 154 | 155 | out = y + u * D.unsqueeze(-1) # bias term 156 | if gelu: 157 | out = F.gelu(out) 158 | if dropout_mask is not None: 159 | return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) 160 | else: 161 | return out.to(dtype=u.dtype) 162 | 163 | 164 | class ExponentialModulation(nn.Module): 165 | def __init__( 166 | self, 167 | d_model, 168 | fast_decay_pct=0.3, 169 | slow_decay_pct=1.5, 170 | target=1e-2, 171 | modulate: bool = True, 172 | shift: float = 0.0, 173 | **kwargs 174 | ): 175 | super().__init__() 176 | self.modulate = modulate 177 | self.shift = shift 178 | max_decay = math.log(target) / fast_decay_pct 179 | min_decay = math.log(target) / slow_decay_pct 180 | deltas = torch.linspace(min_decay, max_decay, d_model)[None, None] 181 | self.register_buffer("deltas", deltas) 182 | 183 | def forward(self, t, x): 184 | if self.modulate: 185 | decay = torch.exp(-t * self.deltas.abs()) 186 | x = x * (decay + self.shift) 187 | return x 188 | 189 | 190 | -------------------------------------------------------------------------------- /models/libs/fact/basics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | from einops import rearrange, repeat 7 | 8 | 9 | class PreNorm(nn.Module): 10 | def __init__(self, dim, fn): 11 | super().__init__() 12 | self.norm = nn.LayerNorm(dim) 13 | self.fn = fn 14 | 15 | def forward(self, x, **kwargs): 16 | return self.fn(self.norm(x), **kwargs) 17 | 18 | 19 | class PostNorm(nn.Module): 20 | def __init__(self, dim, fn): 21 | super().__init__() 22 | self.norm = nn.LayerNorm(dim) 23 | self.fn = fn 24 | 25 | def forward(self, x, **kwargs): 26 | return self.norm(self.fn(x, **kwargs)) 27 | 28 | 29 | class GeAct(nn.Module): 30 | """Gated activation function""" 31 | def __init__(self, act_fn): 32 | super().__init__() 33 | self.fn = act_fn 34 | 35 | def forward(self, x): 36 | c = x.shape[-1] # channel last arrangement 37 | return self.fn(x[..., :int(c//2)]) * x[..., int(c//2):] 38 | 39 | 40 | class MLP(nn.Module): 41 | def __init__(self, dims, act_fn, dropout=0.): 42 | super().__init__() 43 | layers = [] 44 | 45 | for i in range(len(dims) - 1): 46 | if isinstance(act_fn, GeAct) and i < len(dims) - 2: 47 | layers.append(nn.Linear(dims[i], dims[i+1] * 2)) 48 | else: 49 | layers.append(nn.Linear(dims[i], dims[i+1])) 50 | if i < len(dims) - 2: 51 | layers.append(act_fn) 52 | layers.append(nn.Dropout(dropout)) 53 | self.net = nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | return self.net(x) 57 | 58 | 59 | def masked_instance_norm(x, mask, eps = 1e-6): 60 | """ 61 | x of shape: [batch_size (N), num_objects (L), features(C)] 62 | mask of shape: [batch_size (N), num_objects (L), 1] 63 | """ 64 | mask = mask.float() # (N,L,1) 65 | mean = (torch.sum(x * mask, 1) / torch.sum(mask, 1)) # (N,C) 66 | mean = mean.detach() 67 | var_term = ((x - 
mean.unsqueeze(1).expand_as(x)) * mask)**2 # (N,L,C) 68 | var = (torch.sum(var_term, 1) / torch.sum(mask, 1)) #(N,C) 69 | var = var.detach() 70 | mean_reshaped = mean.unsqueeze(1).expand_as(x) # (N, L, C) 71 | var_reshaped = var.unsqueeze(1).expand_as(x) # (N, L, C) 72 | ins_norm = (x - mean_reshaped) / torch.sqrt(var_reshaped + eps) # (N, L, C) 73 | return ins_norm 74 | 75 | 76 | def get_time_embedding(t, dim): 77 | """ 78 | Build sinusoidal timestep embeddings, as used in Denoising Diffusion 79 | Probabilistic Models (the implementation originally comes from Fairseq). 80 | 81 | This matches the implementation in tensor2tensor, but differs slightly 82 | from the description in Section 3.5 of "Attention Is All You Need". 83 | """ 84 | assert len(t.shape) == 1 85 | 86 | half_dim = dim // 2 87 | emb = math.log(10000) / (half_dim - 1) 88 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 89 | emb = emb.to(device=t.device) 90 | emb = t.float()[:, None] * emb[None, :] 91 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 92 | if dim % 2 == 1: # zero pad 93 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 94 | return emb 95 | 96 | 97 | # the code below is taken from the amazing Hyena repo: 98 | # https://github.com/HazyResearch/safari/blob/9ecfaf0e49630b5913fce19adec231b41c2e0e39/src/models/sequence/hyena.py#L64 99 | 100 | class Sin(nn.Module): 101 | def __init__(self, dim, w=10, train_freq=True): 102 | super().__init__() 103 | self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) 104 | 105 | def forward(self, x): 106 | return torch.sin(self.freq * x) 107 | 108 | 109 | class PositionalEmbedding(nn.Module): 110 | def __init__(self, 111 | emb_dim: int, 112 | seq_len: int, 113 | lr_pos_emb: float = 1e-5, 114 | **kwargs): 115 | """Complex exponential positional embeddings for Hyena filters.""" 116 | super().__init__() 117 | 118 | self.seq_len = seq_len 119 | # The time embedding fed to the filters is normalized so that t_f = 1 120 | t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 121 | 122 | assert emb_dim > 1 123 | bands = (emb_dim - 1) // 2 124 | # To compute the right embeddings we use the "proper" linspace 125 | t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] 126 | w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 127 | 128 | f = torch.linspace(1e-4, bands - 1, bands)[None, None] # 1, 1, bands 129 | z = torch.exp(-1j * f * w) 130 | z = torch.cat([t, z.real, z.imag], dim=-1) 131 | self.register_parameter("z", nn.Parameter(z)) 132 | optim = {"lr": lr_pos_emb} 133 | setattr(getattr(self, "z"), "_optim", optim) 134 | self.register_buffer("t", t) 135 | 136 | def forward(self, L): 137 | return self.z[:, :L], self.t[:, :L] 138 | 139 | 140 | # reference convolution with residual connection 141 | def fftconv_ref(u, k, D, dropout_mask, gelu=False, k_rev=None): 142 | seqlen = u.shape[-1] 143 | fft_size = 2 * seqlen 144 | k_f = torch.fft.rfft(k, n=fft_size) / fft_size 145 | if k_rev is not None: 146 | k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size 147 | k_f = k_f + k_rev_f.conj() 148 | u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) 149 | 150 | if len(u.shape) > 3: 151 | k_f = k_f.unsqueeze(1) 152 | 153 | y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] 154 | 155 | out = y + u * D.unsqueeze(-1) # bias term 156 | if gelu: 157 | out = F.gelu(out) 158 | if dropout_mask is not None: 159 | return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) 160 | else: 161 | return 
out.to(dtype=u.dtype) 162 | 163 | 164 | class ExponentialModulation(nn.Module): 165 | def __init__( 166 | self, 167 | d_model, 168 | fast_decay_pct=0.3, 169 | slow_decay_pct=1.5, 170 | target=1e-2, 171 | modulate: bool = True, 172 | shift: float = 0.0, 173 | **kwargs 174 | ): 175 | super().__init__() 176 | self.modulate = modulate 177 | self.shift = shift 178 | max_decay = math.log(target) / fast_decay_pct 179 | min_decay = math.log(target) / slow_decay_pct 180 | deltas = torch.linspace(min_decay, max_decay, d_model)[None, None] 181 | self.register_buffer("deltas", deltas) 182 | 183 | def forward(self, t, x): 184 | if self.modulate: 185 | decay = torch.exp(-t * self.deltas.abs()) 186 | x = x * (decay + self.shift) 187 | return x 188 | 189 | 190 | -------------------------------------------------------------------------------- /models/libs/fact/factorization_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | from typing import Union, Tuple, List, Optional 8 | from .positional_encoding_module import RotaryEmbedding, apply_rotary_pos_emb, SirenNet 9 | from .basics import PreNorm, PostNorm, GeAct, MLP, masked_instance_norm 10 | from .attention import LowRankKernel 11 | 12 | 13 | class PoolingReducer(nn.Module): 14 | def __init__(self, 15 | in_dim, 16 | hidden_dim, 17 | out_dim): 18 | super().__init__() 19 | self.to_in = nn.Linear(in_dim, hidden_dim, bias=False) 20 | self.out_ffn = PreNorm(in_dim, MLP([hidden_dim, hidden_dim, out_dim], GeAct(nn.GELU()))) 21 | 22 | def forward(self, x): 23 | # note that the dimension to be pooled will be the last dimension 24 | # x: b nx ... 
c 25 | x = self.to_in(x) 26 | # pool all spatial dimension but the first one 27 | ndim = len(x.shape) 28 | x = x.mean(dim=tuple(range(2, ndim-1))) 29 | x = self.out_ffn(x) 30 | return x # b nx c 31 | 32 | 33 | class FABlock2D(nn.Module): 34 | # contains factorization and attention on each axis 35 | def __init__(self, 36 | dim, 37 | dim_head, 38 | latent_dim, 39 | heads, 40 | dim_out, 41 | use_rope=True, 42 | kernel_multiplier=3, 43 | scaling_factor=1.0): 44 | super().__init__() 45 | 46 | self.dim = dim 47 | self.latent_dim = latent_dim 48 | self.heads = heads 49 | self.dim_head = dim_head 50 | self.in_norm = nn.LayerNorm(dim) 51 | self.to_v = nn.Linear(self.dim, heads * dim_head, bias=False) 52 | self.to_in = nn.Linear(self.dim, self.dim, bias=False) 53 | 54 | self.to_x = nn.Sequential( 55 | PoolingReducer(self.dim, self.dim, self.latent_dim), 56 | ) 57 | self.to_y = nn.Sequential( 58 | Rearrange('b nx ny c -> b ny nx c'), 59 | PoolingReducer(self.dim, self.dim, self.latent_dim), 60 | ) 61 | 62 | positional_encoding = 'rotary' if use_rope else 'none' 63 | use_softmax = False 64 | self.low_rank_kernel_x = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 65 | positional_embedding=positional_encoding, 66 | residual=False, # add a diagonal bias 67 | softmax=use_softmax, 68 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 69 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 70 | self.low_rank_kernel_y = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 71 | positional_embedding=positional_encoding, 72 | residual=False, 73 | softmax=use_softmax, 74 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 75 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 76 | 77 | self.instance = nn.InstanceNorm2d(dim_head * heads) 78 | self.to_out = nn.Sequential( 79 | nn.Linear(dim_head * heads, dim_out, bias=False), 80 | nn.GELU(), 81 | nn.Linear(dim_out, dim_out, bias=False)) 82 | 83 | def forward(self, u, pos_lst): 84 | # x: b c h w 85 | u = self.in_norm(u) 86 | v = self.to_v(u) 87 | u = self.to_in(u) 88 | 89 | u_x = self.to_x(u) 90 | u_y = self.to_y(u) 91 | 92 | pos_x, pos_y = pos_lst 93 | k_x = self.low_rank_kernel_x(u_x, pos_x=pos_x) 94 | k_y = self.low_rank_kernel_y(u_y, pos_x=pos_y) 95 | 96 | u_phi = rearrange(v, 'b i l (h c) -> b h i l c', h=self.heads) 97 | u_phi = torch.einsum('bhij,bhjmc->bhimc', k_x, u_phi) 98 | u_phi = torch.einsum('bhlm,bhimc->bhilc', k_y, u_phi) 99 | u_phi = rearrange(u_phi, 'b h i l c -> b i l (h c)', h=self.heads) 100 | u_phi = u_phi.permute(0, 3, 1, 2).contiguous() 101 | u_phi = self.instance(u_phi) 102 | u_phi = u_phi.permute(0, 2, 3, 1).contiguous() 103 | ret = self.to_out(u_phi) 104 | return ret 105 | 106 | 107 | class FABlock3D(nn.Module): 108 | # contains factorization and attention on each axis 109 | def __init__(self, 110 | dim, 111 | dim_head, 112 | latent_dim, 113 | heads, 114 | dim_out, 115 | use_rope=True, 116 | kernel_multiplier=3, 117 | scaling_factor=1.0): 118 | super().__init__() 119 | 120 | self.dim = dim 121 | self.latent_dim = latent_dim 122 | self.heads = heads 123 | self.dim_head = dim_head 124 | self.in_norm = nn.LayerNorm(dim) 125 | self.to_v = nn.Linear(self.dim, heads * dim_head, bias=False) 126 | self.to_in = nn.Linear(self.dim, self.dim, bias=False) 127 | 128 | self.to_x = nn.Sequential( 129 | PoolingReducer(self.dim, self.dim, self.latent_dim), 130 | ) 131 | self.to_y = nn.Sequential( 132 | Rearrange('b nx ny nz c -> b ny nx nz c'), 133 | PoolingReducer(self.dim, self.dim, 
self.latent_dim), 134 | ) 135 | self.to_z = nn.Sequential( 136 | Rearrange('b nx ny nz c -> b nz nx ny c'), 137 | PoolingReducer(self.dim, self.dim, self.latent_dim), 138 | ) 139 | 140 | positional_encoding = 'rotary' if use_rope else 'none' 141 | use_softmax = False 142 | self.low_rank_kernel_x = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 143 | positional_embedding=positional_encoding, 144 | residual=False, # add a diagonal bias 145 | softmax=use_softmax, 146 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 147 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 148 | self.low_rank_kernel_y = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 149 | positional_embedding=positional_encoding, 150 | residual=False, 151 | softmax=use_softmax, 152 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 153 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 154 | self.low_rank_kernel_z = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 155 | positional_embedding=positional_encoding, 156 | residual=False, 157 | softmax=use_softmax, 158 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 159 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 160 | 161 | self.to_out = nn.Sequential( 162 | nn.InstanceNorm3d(dim_head * heads), 163 | nn.Linear(dim_head * heads, dim_out, bias=False), 164 | nn.GELU(), 165 | nn.Linear(dim_out, dim_out, bias=False)) 166 | 167 | def forward(self, u, pos_lst): 168 | # x: b h w d c 169 | u = self.in_norm(u) 170 | v = self.to_v(u) 171 | u = self.to_in(u) 172 | 173 | u_x = self.to_x(u) 174 | u_y = self.to_y(u) 175 | u_z = self.to_z(u) 176 | pos_x, pos_y, pos_z = pos_lst 177 | 178 | k_x = self.low_rank_kernel_x(u_x, pos_x=pos_x) 179 | k_y = self.low_rank_kernel_y(u_y, pos_x=pos_y) 180 | k_z = self.low_rank_kernel_z(u_z, pos_x=pos_z) 181 | 182 | u_phi = rearrange(v, 'b i l r (h c) -> b h i l r c', h=self.heads) 183 | u_phi = torch.einsum('bhij,bhjmsc->bhimsc', k_x, u_phi) 184 | u_phi = torch.einsum('bhlm,bhimsc->bhilsc', k_y, u_phi) 185 | u_phi = torch.einsum('bhrs,bhilsc->bhilrc', k_z, u_phi) 186 | u_phi = rearrange(u_phi, 'b h i l r c -> b i l r (h c)', h=self.heads) 187 | 188 | return self.to_out(u_phi) 189 | -------------------------------------------------------------------------------- /models/libs/fact/positional_encoding_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat, reduce 5 | import numpy as np 6 | 7 | 8 | # modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py 9 | class RotaryEmbedding(nn.Module): 10 | def __init__(self, dim, min_freq=1/64, scale=1.): 11 | super().__init__() 12 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 13 | self.min_freq = min_freq 14 | self.scale = scale 15 | self.register_buffer('inv_freq', inv_freq) 16 | 17 | def forward(self, coordinates, device): 18 | # coordinates [b, n] 19 | t = coordinates.to(device).type_as(self.inv_freq) 20 | t = t * (self.scale / self.min_freq) 21 | freqs = torch.einsum('... i , j -> ... i j', t, self.inv_freq) # [b, n, d//2] 22 | return torch.cat((freqs, freqs), dim=-1) # [b, n, d] 23 | 24 | 25 | def rotate_half(x): 26 | x = rearrange(x, '... (j d) -> ... 
j d', j = 2) 27 | x1, x2 = x.unbind(dim = -2) 28 | return torch.cat((-x2, x1), dim = -1) 29 | 30 | 31 | def apply_rotary_pos_emb(t, freqs): 32 | return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) 33 | 34 | 35 | def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y): 36 | # split t into first half and second half 37 | # t: [b, h, n, d] 38 | # freq_x/y: [b, n, d] 39 | d = t.shape[-1] 40 | t_x, t_y = t[..., :d//2], t[..., d//2:] 41 | 42 | return torch.cat((apply_rotary_pos_emb(t_x, freqs_x), 43 | apply_rotary_pos_emb(t_y, freqs_y)), dim=-1) 44 | 45 | def apply_3d_rotary_pos_emb(t, freqs_x, freqs_y, freqs_z): 46 | # split t into three parts 47 | # t: [b, h, n, d] 48 | # freq_x/y: [b, n, d] 49 | d = t.shape[-1] 50 | t_x, t_y, t_z = t[..., :d//3], t[..., d//3:2*d//3], t[..., 2*d//3:] 51 | 52 | return torch.cat((apply_rotary_pos_emb(t_x, freqs_x), 53 | apply_rotary_pos_emb(t_y, freqs_y), 54 | apply_rotary_pos_emb(t_z, freqs_z)), dim=-1) 55 | 56 | 57 | # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py 58 | def get_emb(sin_inp): 59 | """ 60 | Gets a base embedding for one dimension with sin and cos intertwined 61 | """ 62 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) 63 | return torch.flatten(emb, -2, -1) 64 | 65 | 66 | # Gaussian Fourier features 67 | # code modified from: https://github.com/ndahlquist/pytorch-fourier-feature-networks 68 | # author: Nic Dahlquist 69 | class GaussianFourierFeatureTransform(nn.Module): 70 | """ 71 | An implementation of Gaussian Fourier feature mapping. 72 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": 73 | https://arxiv.org/abs/2006.10739 74 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 75 | Given an input of size [batches, n, num_input_channels], 76 | returns a tensor of size [batches, n, mapping_size*2]. 77 | """ 78 | 79 | def __init__(self, num_input_channels, 80 | mapping_size=256, scale=10, learnable=False, 81 | num_heads=1): 82 | super().__init__() 83 | 84 | self._num_input_channels = num_input_channels 85 | self._mapping_size = mapping_size 86 | 87 | self._B = nn.Parameter(torch.randn((num_input_channels, mapping_size * num_heads)) * scale, 88 | requires_grad=learnable) 89 | self.num_heads = num_heads 90 | 91 | def forward(self, x, unfold_head=False): 92 | if len(x.shape) == 2: 93 | x = x.unsqueeze(0) 94 | batches, num_of_points, channels = x.shape 95 | 96 | # Make shape compatible for matmul with _B. 97 | # From [B, N, C] to [(B*N), C]. 
98 | x = rearrange(x, 'b n c -> (b n) c') 99 | 100 | x = x @ self._B.to(x.device) 101 | 102 | # From [(B*W*H), C] to [B, W, H, C] 103 | x = rearrange(x, '(b n) c -> b n c', b=batches) 104 | 105 | x = 2 * np.pi * x 106 | if unfold_head: 107 | x = rearrange(x, 'b n (h d) -> b h n d', h=self.num_heads) 108 | return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 109 | 110 | 111 | 112 | # helpers 113 | 114 | def exists(val): 115 | return val is not None 116 | 117 | def cast_tuple(val, repeat = 1): 118 | return val if isinstance(val, tuple) else ((val,) * repeat) 119 | 120 | # sin activation 121 | 122 | class Sine(nn.Module): 123 | def __init__(self, w0 = 1.): 124 | super().__init__() 125 | self.w0 = w0 126 | 127 | def forward(self, x): 128 | return torch.sin(self.w0 * x) 129 | 130 | # siren layer 131 | 132 | class Siren(nn.Module): 133 | def __init__(self, 134 | dim_in, 135 | dim_out, 136 | w0=1., 137 | c=6., 138 | is_first=False, 139 | use_bias=True, 140 | activation=None): 141 | super().__init__() 142 | self.dim_in = dim_in 143 | self.is_first = is_first 144 | 145 | weight = torch.zeros(dim_out, dim_in) 146 | bias = torch.zeros(dim_out) if use_bias else None 147 | self.init_(weight, bias, c=c, w0=w0) 148 | 149 | self.weight = nn.Parameter(weight) 150 | self.bias = nn.Parameter(bias) if use_bias else None 151 | self.activation = Sine(w0) if activation is None else activation 152 | 153 | def init_(self, weight, bias, c, w0): 154 | dim = self.dim_in 155 | 156 | w_std = (1 / dim) if self.is_first else (np.sqrt(c / dim) / w0) 157 | weight.uniform_(-w_std, w_std) 158 | 159 | if exists(bias): 160 | bias.uniform_(-w_std, w_std) 161 | 162 | def forward(self, x): 163 | out = F.linear(x, self.weight, self.bias) 164 | out = self.activation(out) 165 | return out 166 | 167 | # siren network 168 | class SirenNet(nn.Module): 169 | def __init__(self, 170 | dim_in, 171 | dim_hidden, dim_out, num_layers, 172 | w0=1., 173 | w0_initial=30., 174 | use_bias=True, final_activation=None, 175 | normalize_input=True): 176 | super().__init__() 177 | self.num_layers = num_layers 178 | self.dim_hidden = dim_hidden 179 | self.normalize_input = normalize_input 180 | 181 | self.layers = nn.ModuleList([]) 182 | for ind in range(num_layers): 183 | is_first = ind == 0 184 | layer_w0 = w0_initial if is_first else w0 185 | layer_dim_in = dim_in if is_first else dim_hidden 186 | 187 | self.layers.append(Siren( 188 | dim_in=layer_dim_in, 189 | dim_out=dim_hidden, 190 | w0=layer_w0, 191 | use_bias=use_bias, 192 | is_first=is_first, 193 | )) 194 | 195 | final_activation = nn.Identity() if not exists(final_activation) else final_activation 196 | self.last_layer = Siren(dim_in=dim_hidden, 197 | dim_out=dim_out, 198 | w0=w0, 199 | use_bias=use_bias, 200 | activation=final_activation) 201 | 202 | # self.last_layer = nn.Linear(dim_hidden, dim_out) 203 | # init last layer orthogonally 204 | # nn.init.orthogonal_(self.last_layer.weight, gain=1/dim_out) 205 | 206 | def in_norm(self, x): 207 | return (2 * x - torch.min(x, dim=1, keepdim=True)[0] - torch.max(x, dim=1, keepdim=True)[0]) /\ 208 | (torch.max(x, dim=1, keepdim=True)[0] - torch.min(x, dim=1, keepdim=True)[0]) 209 | 210 | def forward(self, x, mods=None): 211 | if self.normalize_input: 212 | x = self.in_norm(x) 213 | # x = (x - 0.5) * 2 214 | 215 | for layer in self.layers: 216 | x = layer(x) 217 | if mods is not None: 218 | x *= mods 219 | x = self.last_layer(x) 220 | # x = self.final_activation(x) 221 | return x 222 | 
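# usage sketch (illustration only; the names below are the ones defined in
# this file): map 2-D coordinates to Gaussian Fourier features, then decode
# them with a SIREN:
#   coords = torch.rand(1, 4096, 2)                           # [B, N, 2]
#   ffm = GaussianFourierFeatureTransform(2, mapping_size=128, scale=10)
#   feats = ffm(coords)                                       # [1, 4096, 256]
#   net = SirenNet(dim_in=256, dim_hidden=256, dim_out=3, num_layers=4)
#   out = net(feats)                                          # [1, 4096, 3]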
-------------------------------------------------------------------------------- /models/libs/factorization_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | from typing import Union, Tuple, List, Optional 8 | from .positional_encoding_module import RotaryEmbedding, apply_rotary_pos_emb, SirenNet # package-relative imports (matching models/libs/fact/) so the module resolves inside the models.libs package 9 | from .basics import PreNorm, PostNorm, GeAct, MLP, masked_instance_norm 10 | from .attention import LowRankKernel 11 | 12 | 13 | class PoolingReducer(nn.Module): 14 | def __init__(self, 15 | in_dim, 16 | hidden_dim, 17 | out_dim): 18 | super().__init__() 19 | self.to_in = nn.Linear(in_dim, hidden_dim, bias=False) 20 | self.out_ffn = PreNorm(in_dim, MLP([hidden_dim, hidden_dim, out_dim], GeAct(nn.GELU()))) 21 | 22 | def forward(self, x): 23 | # note that the dimension to be pooled will be the last dimension 24 | # x: b nx ... c 25 | x = self.to_in(x) 26 | # pool all spatial dimensions but the first one 27 | ndim = len(x.shape) 28 | x = x.mean(dim=tuple(range(2, ndim-1))) 29 | x = self.out_ffn(x) 30 | return x # b nx c 31 | 32 | 33 | class FABlock2D(nn.Module): 34 | # contains factorization and attention on each axis 35 | def __init__(self, 36 | dim, 37 | dim_head, 38 | latent_dim, 39 | heads, 40 | dim_out, 41 | use_rope=True, 42 | kernel_multiplier=3, 43 | scaling_factor=1.0): 44 | super().__init__() 45 | 46 | self.dim = dim 47 | self.latent_dim = latent_dim 48 | self.heads = heads 49 | self.dim_head = dim_head 50 | self.in_norm = nn.LayerNorm(dim) 51 | self.to_v = nn.Linear(self.dim, heads * dim_head, bias=False) 52 | self.to_in = nn.Linear(self.dim, self.dim, bias=False) 53 | 54 | self.to_x = nn.Sequential( 55 | PoolingReducer(self.dim, self.dim, self.latent_dim), 56 | ) 57 | self.to_y = nn.Sequential( 58 | Rearrange('b nx ny c -> b ny nx c'), 59 | PoolingReducer(self.dim, self.dim, self.latent_dim), 60 | ) 61 | 62 | positional_encoding = 'rotary' if use_rope else 'none' 63 | use_softmax = False 64 | self.low_rank_kernel_x = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 65 | positional_embedding=positional_encoding, 66 | residual=False, # add a diagonal bias 67 | softmax=use_softmax, 68 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 69 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 70 | self.low_rank_kernel_y = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 71 | positional_embedding=positional_encoding, 72 | residual=False, 73 | softmax=use_softmax, 74 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 75 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 76 | 77 | self.instance = nn.InstanceNorm2d(dim_head * heads) 78 | self.to_out = nn.Sequential( 79 | nn.Linear(dim_head * heads, dim_out, bias=False), 80 | nn.GELU(), 81 | nn.Linear(dim_out, dim_out, bias=False)) 82 | 83 | def forward(self, u, pos_lst): 84 | # u: b nx ny c (channel last) 85 | u = self.in_norm(u) 86 | v = self.to_v(u) 87 | u = self.to_in(u) 88 | 89 | u_x = self.to_x(u) 90 | u_y = self.to_y(u) 91 | 92 | pos_x, pos_y = pos_lst 93 | k_x = self.low_rank_kernel_x(u_x, pos_x=pos_x) 94 | k_y = self.low_rank_kernel_y(u_y, pos_x=pos_y) 95 | 96 | u_phi = rearrange(v, 'b i l (h c) -> b h i l c', h=self.heads) 97 | u_phi = torch.einsum('bhij,bhjmc->bhimc', k_x, u_phi) 98 | u_phi = torch.einsum('bhlm,bhimc->bhilc', k_y, u_phi) 99 | u_phi = 
rearrange(u_phi, 'b h i l c -> b i l (h c)', h=self.heads) 100 | u_phi = u_phi.permute(0, 3, 1, 2).contiguous() 101 | u_phi = self.instance(u_phi) 102 | u_phi = u_phi.permute(0, 2, 3, 1).contiguous() 103 | ret = self.to_out(u_phi) 104 | return ret 105 | 106 | 107 | class FABlock3D(nn.Module): 108 | # contains factorization and attention on each axis 109 | def __init__(self, 110 | dim, 111 | dim_head, 112 | latent_dim, 113 | heads, 114 | dim_out, 115 | use_rope=True, 116 | kernel_multiplier=3, 117 | scaling_factor=1.0): 118 | super().__init__() 119 | 120 | self.dim = dim 121 | self.latent_dim = latent_dim 122 | self.heads = heads 123 | self.dim_head = dim_head 124 | self.in_norm = nn.LayerNorm(dim) 125 | self.to_v = nn.Linear(self.dim, heads * dim_head, bias=False) 126 | self.to_in = nn.Linear(self.dim, self.dim, bias=False) 127 | 128 | self.to_x = nn.Sequential( 129 | PoolingReducer(self.dim, self.dim, self.latent_dim), 130 | ) 131 | self.to_y = nn.Sequential( 132 | Rearrange('b nx ny nz c -> b ny nx nz c'), 133 | PoolingReducer(self.dim, self.dim, self.latent_dim), 134 | ) 135 | self.to_z = nn.Sequential( 136 | Rearrange('b nx ny nz c -> b nz nx ny c'), 137 | PoolingReducer(self.dim, self.dim, self.latent_dim), 138 | ) 139 | 140 | positional_encoding = 'rotary' if use_rope else 'none' 141 | use_softmax = False 142 | self.low_rank_kernel_x = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 143 | positional_embedding=positional_encoding, 144 | residual=False, # add a diagonal bias 145 | softmax=use_softmax, 146 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 147 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 148 | self.low_rank_kernel_y = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 149 | positional_embedding=positional_encoding, 150 | residual=False, 151 | softmax=use_softmax, 152 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 153 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 154 | self.low_rank_kernel_z = LowRankKernel(self.latent_dim, dim_head * kernel_multiplier, heads, 155 | positional_embedding=positional_encoding, 156 | residual=False, 157 | softmax=use_softmax, 158 | scaling=1 / np.sqrt(dim_head * kernel_multiplier) 159 | if kernel_multiplier > 4 or use_softmax else scaling_factor) 160 | 161 | self.to_out = nn.Sequential( 162 | nn.InstanceNorm3d(dim_head * heads), 163 | nn.Linear(dim_head * heads, dim_out, bias=False), 164 | nn.GELU(), 165 | nn.Linear(dim_out, dim_out, bias=False)) 166 | 167 | def forward(self, u, pos_lst): 168 | # x: b h w d c 169 | u = self.in_norm(u) 170 | v = self.to_v(u) 171 | u = self.to_in(u) 172 | 173 | u_x = self.to_x(u) 174 | u_y = self.to_y(u) 175 | u_z = self.to_z(u) 176 | pos_x, pos_y, pos_z = pos_lst 177 | 178 | k_x = self.low_rank_kernel_x(u_x, pos_x=pos_x) 179 | k_y = self.low_rank_kernel_y(u_y, pos_x=pos_y) 180 | k_z = self.low_rank_kernel_z(u_z, pos_x=pos_z) 181 | 182 | u_phi = rearrange(v, 'b i l r (h c) -> b h i l r c', h=self.heads) 183 | u_phi = torch.einsum('bhij,bhjmsc->bhimsc', k_x, u_phi) 184 | u_phi = torch.einsum('bhlm,bhimsc->bhilsc', k_y, u_phi) 185 | u_phi = torch.einsum('bhrs,bhilsc->bhilrc', k_z, u_phi) 186 | u_phi = rearrange(u_phi, 'b h i l r c -> b i l r (h c)', h=self.heads) 187 | 188 | return self.to_out(u_phi) 189 | -------------------------------------------------------------------------------- /models/libs/gktrm/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | # from 
.utils import * 3 | # from .utils_ft import * 4 | # from .ft import * 5 | # from .model import * 6 | -------------------------------------------------------------------------------- /models/libs/positional_encoding_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import rearrange, repeat, reduce 5 | import numpy as np 6 | 7 | 8 | # modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py 9 | class RotaryEmbedding(nn.Module): 10 | def __init__(self, dim, min_freq=1/64, scale=1.): 11 | super().__init__() 12 | inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 13 | self.min_freq = min_freq 14 | self.scale = scale 15 | self.register_buffer('inv_freq', inv_freq) 16 | 17 | def forward(self, coordinates, device): 18 | # coordinates [b, n] 19 | t = coordinates.to(device).type_as(self.inv_freq) 20 | t = t * (self.scale / self.min_freq) 21 | freqs = torch.einsum('... i , j -> ... i j', t, self.inv_freq) # [b, n, d//2] 22 | return torch.cat((freqs, freqs), dim=-1) # [b, n, d] 23 | 24 | 25 | def rotate_half(x): 26 | x = rearrange(x, '... (j d) -> ... j d', j = 2) 27 | x1, x2 = x.unbind(dim = -2) 28 | return torch.cat((-x2, x1), dim = -1) 29 | 30 | 31 | def apply_rotary_pos_emb(t, freqs): 32 | return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) 33 | 34 | 35 | def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y): 36 | # split t into first half and second half 37 | # t: [b, h, n, d] 38 | # freq_x/y: [b, n, d] 39 | d = t.shape[-1] 40 | t_x, t_y = t[..., :d//2], t[..., d//2:] 41 | 42 | return torch.cat((apply_rotary_pos_emb(t_x, freqs_x), 43 | apply_rotary_pos_emb(t_y, freqs_y)), dim=-1) 44 | 45 | def apply_3d_rotary_pos_emb(t, freqs_x, freqs_y, freqs_z): 46 | # split t into three parts 47 | # t: [b, h, n, d] 48 | # freq_x/y: [b, n, d] 49 | d = t.shape[-1] 50 | t_x, t_y, t_z = t[..., :d//3], t[..., d//3:2*d//3], t[..., 2*d//3:] 51 | 52 | return torch.cat((apply_rotary_pos_emb(t_x, freqs_x), 53 | apply_rotary_pos_emb(t_y, freqs_y), 54 | apply_rotary_pos_emb(t_z, freqs_z)), dim=-1) 55 | 56 | 57 | # https://github.com/tatp22/multidim-positional-encoding/blob/master/positional_encodings/torch_encodings.py 58 | def get_emb(sin_inp): 59 | """ 60 | Gets a base embedding for one dimension with sin and cos intertwined 61 | """ 62 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) 63 | return torch.flatten(emb, -2, -1) 64 | 65 | 66 | # Gaussian Fourier features 67 | # code modified from: https://github.com/ndahlquist/pytorch-fourier-feature-networks 68 | # author: Nic Dahlquist 69 | class GaussianFourierFeatureTransform(nn.Module): 70 | """ 71 | An implementation of Gaussian Fourier feature mapping. 72 | "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains": 73 | https://arxiv.org/abs/2006.10739 74 | https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html 75 | Given an input of size [batches, n, num_input_channels], 76 | returns a tensor of size [batches, n, mapping_size*2]. 
77 | """ 78 | 79 | def __init__(self, num_input_channels, 80 | mapping_size=256, scale=10, learnable=False, 81 | num_heads=1): 82 | super().__init__() 83 | 84 | self._num_input_channels = num_input_channels 85 | self._mapping_size = mapping_size 86 | 87 | self._B = nn.Parameter(torch.randn((num_input_channels, mapping_size * num_heads)) * scale, 88 | requires_grad=learnable) 89 | self.num_heads = num_heads 90 | 91 | def forward(self, x, unfold_head=False): 92 | if len(x.shape) == 2: 93 | x = x.unsqueeze(0) 94 | batches, num_of_points, channels = x.shape 95 | 96 | # Make shape compatible for matmul with _B. 97 | # From [B, N, C] to [(B*N), C]. 98 | x = rearrange(x, 'b n c -> (b n) c') 99 | 100 | x = x @ self._B.to(x.device) 101 | 102 | # From [(B*W*H), C] to [B, W, H, C] 103 | x = rearrange(x, '(b n) c -> b n c', b=batches) 104 | 105 | x = 2 * np.pi * x 106 | if unfold_head: 107 | x = rearrange(x, 'b n (h d) -> b h n d', h=self.num_heads) 108 | return torch.cat([torch.sin(x), torch.cos(x)], dim=-1) 109 | 110 | 111 | 112 | # helpers 113 | 114 | def exists(val): 115 | return val is not None 116 | 117 | def cast_tuple(val, repeat = 1): 118 | return val if isinstance(val, tuple) else ((val,) * repeat) 119 | 120 | # sin activation 121 | 122 | class Sine(nn.Module): 123 | def __init__(self, w0 = 1.): 124 | super().__init__() 125 | self.w0 = w0 126 | 127 | def forward(self, x): 128 | return torch.sin(self.w0 * x) 129 | 130 | # siren layer 131 | 132 | class Siren(nn.Module): 133 | def __init__(self, 134 | dim_in, 135 | dim_out, 136 | w0=1., 137 | c=6., 138 | is_first=False, 139 | use_bias=True, 140 | activation=None): 141 | super().__init__() 142 | self.dim_in = dim_in 143 | self.is_first = is_first 144 | 145 | weight = torch.zeros(dim_out, dim_in) 146 | bias = torch.zeros(dim_out) if use_bias else None 147 | self.init_(weight, bias, c=c, w0=w0) 148 | 149 | self.weight = nn.Parameter(weight) 150 | self.bias = nn.Parameter(bias) if use_bias else None 151 | self.activation = Sine(w0) if activation is None else activation 152 | 153 | def init_(self, weight, bias, c, w0): 154 | dim = self.dim_in 155 | 156 | w_std = (1 / dim) if self.is_first else (np.sqrt(c / dim) / w0) 157 | weight.uniform_(-w_std, w_std) 158 | 159 | if exists(bias): 160 | bias.uniform_(-w_std, w_std) 161 | 162 | def forward(self, x): 163 | out = F.linear(x, self.weight, self.bias) 164 | out = self.activation(out) 165 | return out 166 | 167 | # siren network 168 | class SirenNet(nn.Module): 169 | def __init__(self, 170 | dim_in, 171 | dim_hidden, dim_out, num_layers, 172 | w0=1., 173 | w0_initial=30., 174 | use_bias=True, final_activation=None, 175 | normalize_input=True): 176 | super().__init__() 177 | self.num_layers = num_layers 178 | self.dim_hidden = dim_hidden 179 | self.normalize_input = normalize_input 180 | 181 | self.layers = nn.ModuleList([]) 182 | for ind in range(num_layers): 183 | is_first = ind == 0 184 | layer_w0 = w0_initial if is_first else w0 185 | layer_dim_in = dim_in if is_first else dim_hidden 186 | 187 | self.layers.append(Siren( 188 | dim_in=layer_dim_in, 189 | dim_out=dim_hidden, 190 | w0=layer_w0, 191 | use_bias=use_bias, 192 | is_first=is_first, 193 | )) 194 | 195 | final_activation = nn.Identity() if not exists(final_activation) else final_activation 196 | self.last_layer = Siren(dim_in=dim_hidden, 197 | dim_out=dim_out, 198 | w0=w0, 199 | use_bias=use_bias, 200 | activation=final_activation) 201 | 202 | # self.last_layer = nn.Linear(dim_hidden, dim_out) 203 | # init last layer orthogonally 204 | # 
nn.init.orthogonal_(self.last_layer.weight, gain=1/dim_out) 205 | 206 | def in_norm(self, x): 207 | return (2 * x - torch.min(x, dim=1, keepdim=True)[0] - torch.max(x, dim=1, keepdim=True)[0]) /\ 208 | (torch.max(x, dim=1, keepdim=True)[0] - torch.min(x, dim=1, keepdim=True)[0]) 209 | 210 | def forward(self, x, mods=None): 211 | if self.normalize_input: 212 | x = self.in_norm(x) 213 | # x = (x - 0.5) * 2 214 | 215 | for layer in self.layers: 216 | x = layer(x) 217 | if mods is not None: 218 | x *= mods 219 | x = self.last_layer(x) 220 | # x = self.final_activation(x) 221 | return x 222 | -------------------------------------------------------------------------------- /models/libs/vortex/io_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import imageio.v2 as imageio 5 | import torch 6 | import matplotlib.pyplot as plt 7 | import copy 8 | import configargparse 9 | 10 | def to_numpy(x): 11 | return x.detach().cpu().numpy() 12 | 13 | def to8b(x): 14 | return (255*np.clip(x,0,1)).astype(np.uint8) 15 | 16 | # create gif from images using FFMPEG 17 | def merge_imgs(framerate, save_dir): 18 | os.system('ffmpeg -hide_banner -loglevel error -y -i {0}/%03d.jpg -vf palettegen {0}/palette.png'.format(save_dir)) 19 | os.system('ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/%03d.jpg -i {1}/palette.png -lavfi paletteuse {1}/output.gif'.format(framerate, save_dir)) 20 | 21 | # remove everything in dir 22 | def remove_everything_in(folder): 23 | for filename in os.listdir(folder): 24 | file_path = os.path.join(folder, filename) 25 | try: 26 | if os.path.isfile(file_path) or os.path.islink(file_path): 27 | os.unlink(file_path) 28 | elif os.path.isdir(file_path): 29 | shutil.rmtree(file_path) 30 | except Exception as e: 31 | print('Failed to delete %s. Reason: %s' % (file_path, e)) 32 | 33 | # write a single image to file f 34 | def imwrite(f, img): 35 | img = to8b(img) 36 | imageio.imwrite(f, img) # save frame as jpeg file 37 | 38 | # generate grid coordinates 39 | def gen_grid(width, height, device): 40 | img_n_grid_x = width 41 | img_n_grid_y = height 42 | img_dx = 1./img_n_grid_y 43 | c_x, c_y = torch.meshgrid(torch.arange(img_n_grid_x), torch.arange(img_n_grid_y), indexing = "ij") 44 | img_x = img_dx * (torch.cat((c_x[..., None], c_y[..., None]), axis = 2) + 0.5).to(device) # grid center locations, normalized so the domain height is 1 45 | return img_x 46 | 47 | # write image 48 | def write_image(img_xy, outdir, i): 49 | img_xy = copy.deepcopy(img_xy) 50 | c_pred = np.flip(img_xy.transpose([1,0,2]), 0) 51 | img8b = to8b(c_pred) 52 | save_filepath = os.path.join(outdir, '{:03d}.jpg'.format(i)) 53 | imageio.imwrite(save_filepath, img8b) 54 | 55 | # write vortex (particles) with their velocities 56 | def write_vorts(vorts_pos, vorts_uv, outdir, i): 57 | vorts_pos = copy.deepcopy(vorts_pos) 58 | pos = vorts_pos 59 | fig = plt.figure(num=1, figsize=(7, 7), clear=True) 60 | ax = fig.add_subplot() 61 | fig.subplots_adjust(0.1,0.1,0.9,0.9) 62 | ax.set_xlim([0, 1]) 63 | ax.set_ylim([0, 1]) 64 | s = ax.scatter(pos[..., 0], pos[..., 1], s = 100) 65 | ax.quiver(pos[..., 0], pos[..., 1], vorts_uv[..., 0], vorts_uv[..., 1], color = "red", scale = 10.) 
66 | fig.savefig(os.path.join(outdir, '{:03d}.jpg'.format(i)), dpi = 512//8) 67 | 68 | # write vorticity field 69 | def write_vorticity(vort_img, outdir, i): 70 | vort_img = copy.deepcopy(vort_img) 71 | array = vort_img 72 | scale = array.shape[1] 73 | array = np.transpose(array, (1, 0, 2)) # from X, Y to Y, X 74 | fig = plt.figure(num=1, figsize=(8, 7), clear=True) 75 | ax = fig.add_subplot() 76 | fig.subplots_adjust(0.05,0.,0.9,1) 77 | ax.set_xlim([0, array.shape[0]]) 78 | ax.set_ylim([0, array.shape[1]]) 79 | p = ax.imshow(array, alpha = 0.75, vmin = -10, vmax = 10) 80 | fig.colorbar(p, fraction=0.04) 81 | fig.savefig(os.path.join(outdir, '{:03d}.jpg'.format(i)), dpi = 512//8) 82 | 83 | # write vortices over image 84 | def write_visualization(img, vorts_pos, vorts_w, outdir, i, boundary = None): 85 | img = copy.deepcopy(img) 86 | # mask the out-of-bound area as green 87 | if boundary is not None: 88 | OUT = (boundary[0] >= -boundary[2]).cpu().numpy() 89 | img[OUT] = (0.5 * img[OUT]) 90 | img[OUT, 1] += 0.5 91 | vorts_pos = copy.deepcopy(vorts_pos) 92 | vorts_w = copy.deepcopy(vorts_w) 93 | array = img 94 | scale = array.shape[1] 95 | pos = vorts_pos * scale 96 | array = np.transpose(array, (1, 0, 2)) # from X, Y to Y, X 97 | fig = plt.figure(num=1, figsize=(8, 7), clear=True) 98 | ax = fig.add_subplot() 99 | fig.subplots_adjust(0.05,0.,0.9,1) 100 | ax.set_xlim([0, array.shape[0]]) 101 | ax.set_ylim([0, array.shape[1]]) 102 | s = ax.scatter(pos[..., 0], pos[..., 1], s = 100, c = vorts_w.flatten(), vmin = None, vmax = None) 103 | p = ax.imshow(array, alpha = 0.75) 104 | fig.colorbar(s, fraction=0.04) 105 | fig.savefig(os.path.join(outdir, '{:03d}.jpg'.format(i)), dpi = 512//8) 106 | 107 | # convert rgb image to yuv parametrization 108 | def rgb_to_yuv(_rgb_image): 109 | rgb_image = _rgb_image[..., None] 110 | matrix = np.array([[0.299, 0.587, 0.114], 111 | [-0.14713, -0.28886, 0.436], 112 | [0.615, -0.51499, -0.10001]]).astype(np.float32) 113 | matrix = torch.from_numpy(matrix)[None, None, None, ...].to(rgb_image.get_device()) 114 | yuv_image = torch.einsum("abcde, abcef -> abcdf", matrix, rgb_image) 115 | return yuv_image.squeeze() 116 | 117 | # command line tools 118 | def config_parser(): 119 | parser = configargparse.ArgumentParser() 120 | parser.add_argument('--config', is_config_file=True, 121 | help='config file path') 122 | parser.add_argument('--seen_ratio', type=float, default=0.3333, 123 | help='fraction of input video available during training') 124 | parser.add_argument("--data_name", type=str, default='synthetic_1', 125 | help='name of video data') 126 | parser.add_argument("--run_pretrain", type=bool, default = False, 127 | help='whether to run pretrain only') 128 | parser.add_argument("--test_only", type=bool, default = False, 129 | help='whether to run test only') 130 | parser.add_argument("--start_over", type=bool, default = False, 131 | help='whether to clear previous record on this experiment') 132 | parser.add_argument("--exp_name", type=str, default = "exp_0", 133 | help='the name of current experiment') 134 | parser.add_argument('--vort_scale', type=float, default=0.33, 135 | help='characteristic scale for vortices') 136 | parser.add_argument('--num_train', type=int, default=40000, 137 | help='number of training iterations') 138 | parser.add_argument("--init_vort_dist", type=float, default=1.0, 139 | help='how spread out are init vortices') 140 | return parser 141 | -------------------------------------------------------------------------------- 
/models/libs/vortex/learning_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from torch.optim.lr_scheduler import LambdaLR, StepLR 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | 9 | device = torch.device("cuda") 10 | real = torch.float32 11 | 12 | L2_Loss = nn.MSELoss().cuda() 13 | 14 | class SineResidualBlock(nn.Module): 15 | 16 | def __init__(self, in_features, out_features, bias=True, 17 | is_first=False, omega_0=30): 18 | super().__init__() 19 | self.omega_0 = omega_0 20 | self.is_first = is_first 21 | 22 | self.in_features = in_features 23 | self.linear = nn.Linear(in_features, out_features, bias=bias) 24 | # add shortcut 25 | self.shortcut = nn.Sequential() 26 | if in_features != out_features: 27 | self.shortcut = nn.Sequential( 28 | nn.Linear(in_features, out_features), 29 | ) 30 | 31 | self.init_weights() 32 | 33 | def init_weights(self): 34 | with torch.no_grad(): 35 | if self.is_first: 36 | self.linear.weight.uniform_(-1 / self.in_features, 37 | 1 / self.in_features) 38 | else: 39 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 40 | np.sqrt(6 / self.in_features) / self.omega_0) 41 | 42 | def forward(self, input): 43 | out = torch.sin(self.omega_0 * self.linear(input)) 44 | out += self.shortcut(input) 45 | out = nn.functional.relu(out) 46 | return out 47 | 48 | class Dynamics_Net(nn.Module): 49 | def __init__(self): 50 | super().__init__() 51 | in_dim = 1 52 | out_dim = 1 53 | width = 40 54 | self.layers = nn.Sequential(SineResidualBlock(in_dim, width, omega_0=1., is_first=True), 55 | SineResidualBlock(width, width, omega_0=1.), 56 | SineResidualBlock(width, width, omega_0=1.), 57 | SineResidualBlock(width, width, omega_0=1.), 58 | nn.Linear(width, out_dim), 59 | ) 60 | 61 | def forward(self, x): 62 | '''Forward pass''' 63 | return self.layers(x) 64 | 65 | class Position_Net(nn.Module): 66 | def __init__(self, num_vorts): 67 | super().__init__() 68 | in_dim = 1 69 | out_dim = num_vorts * 2 70 | self.layers = nn.Sequential(SineResidualBlock(in_dim, 64, omega_0=1., is_first=True), 71 | SineResidualBlock(64, 128, omega_0=1.), 72 | SineResidualBlock(128, 256, omega_0=1.), 73 | nn.Linear(256, out_dim) 74 | ) 75 | 76 | def forward(self, x): 77 | '''Forward pass''' 78 | return self.layers(x) 79 | 80 | 81 | def create_bundle(logdir, num_vorts, decay_step, decay_gamma, pretrain_dir = None): 82 | model_len = Dynamics_Net().to(device) 83 | model_pos = Position_Net(num_vorts).to(device) 84 | grad_vars = list(model_len.parameters()) 85 | grad_vars2 = list(model_pos.parameters()) 86 | ########################## 87 | # Load checkpoints 88 | ckpts = [os.path.join(logdir, f) for f in sorted(os.listdir(logdir)) if 'tar' in f] 89 | pretrain_ckpts = [] 90 | if pretrain_dir: 91 | pretrain_ckpts = [os.path.join(pretrain_dir, f) for f in sorted(os.listdir(pretrain_dir)) if 'tar' in f] 92 | print("[Initialize] Pretrain ckpts: ", pretrain_ckpts) 93 | 94 | if len(ckpts) <= 0: # no checkpoints to load 95 | w_pred = torch.zeros(num_vorts, 1, device = device, dtype = real) 96 | w_pred.requires_grad = True 97 | size_pred = torch.zeros(num_vorts, 1, device = device, dtype = real) 98 | size_pred.requires_grad = True 99 | start = 0 100 | optimizer = torch.optim.Adam([{'params': grad_vars}, \ 101 | {'params': grad_vars2, 'lr':3.e-4},\ 102 | {'params': w_pred, 'lr':5.e-3},\ 103 | {'params': size_pred, 'lr':5.e-3}], lr=1.e-3, betas=(0.9, 0.999)) 104 | # Load 
pretrained if there is one and no checkpoint exists 105 | if len(pretrain_ckpts) > 0: 106 | pre_ckpt_path = pretrain_ckpts[-1] 107 | print ("[Initialize] Has pretrained available, reloading from: ", pre_ckpt_path) 108 | pre_ckpt = torch.load(pre_ckpt_path) 109 | model_pos.load_state_dict(pre_ckpt['model_pos_state_dict']) 110 | 111 | else: # has checkpoints to load: 112 | ckpt_path = ckpts[-1] 113 | print ("[Initialize] Has checkpoint available, reloading from: ", ckpt_path) 114 | ckpt = torch.load(ckpt_path) 115 | start = ckpt['global_step'] 116 | w_pred = ckpt['w_pred'] 117 | w_pred.requires_grad = True 118 | size_pred = ckpt['size_pred'] 119 | size_pred.requires_grad = True 120 | model_len.load_state_dict(ckpt["model_len_state_dict"]) 121 | model_pos.load_state_dict(ckpt["model_pos_state_dict"]) 122 | optimizer = torch.optim.Adam([{'params': grad_vars}, \ 123 | {'params': grad_vars2, 'lr':3.e-4},\ 124 | {'params': w_pred, 'lr':5.e-3},\ 125 | {'params': size_pred, 'lr':5.e-3}], lr=1.e-3, betas=(0.9, 0.999)) 126 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 127 | 128 | lr_scheduler = StepLR(optimizer, step_size = decay_step, gamma = decay_gamma) 129 | 130 | 131 | ########################## 132 | net_dict = { 133 | 'model_len' : model_len, 134 | 'model_pos' : model_pos, 135 | 'w_pred' : w_pred, 136 | 'size_pred' : size_pred, 137 | } 138 | 139 | return net_dict, start, grad_vars, optimizer, lr_scheduler 140 | 141 | 142 | # vels: [batch, width, height, 2] 143 | def calc_div(vels): 144 | batch_size, width, height, D = vels.shape 145 | dx = 1./height 146 | du_dx = 1./(2*dx) * (vels[:, 2:, 1:-1, 0] - vels[:, :-2, 1:-1, 0]) 147 | dv_dy = 1./(2*dx) * (vels[:, 1:-1, 2:, 1] - vels[:, 1:-1, :-2, 1]) 148 | return du_dx + dv_dy 149 | 150 | # field: [batch, width, height, 1] 151 | def calc_grad(field): 152 | batch_size, width, height, _ = field.shape 153 | dx = 1./height 154 | df_dx = 1./(2*dx) * (field[:, 2:, 1:-1] - field[:, :-2, 1:-1]) 155 | df_dy = 1./(2*dx) * (field[:, 1:-1, 2:] - field[:, 1:-1, :-2]) 156 | return torch.cat((df_dx, df_dy), dim = -1) 157 | 158 | def calc_vort(vel_img, boundary = None): # compute the curl of velocity 159 | W, H, _ = vel_img.shape 160 | dx = 1./H 161 | vort_img = torch.zeros(W, H, 1, device = device, dtype = real) 162 | u = vel_img[...,[0]] 163 | v = vel_img[...,[1]] 164 | dvdx = 1/(2*dx) * (v[2:, 1:-1] - v[:-2, 1:-1]) 165 | dudy = 1/(2*dx) * (u[1:-1, 2:] - u[1:-1, :-2]) 166 | vort_img[1:-1, 1:-1] = dvdx - dudy 167 | if boundary is not None: 168 | # set out-of-bound pixels to 0 because velocity undefined there 169 | OUT = (boundary[0] >= -boundary[2] - 4) 170 | vort_img[OUT] *= 0 171 | return vort_img 172 | 173 | # sdf: [W, H] 174 | # sdf normal: [W, H, 2] 175 | def calc_sdf_normal(sdf): 176 | W, H = sdf.shape 177 | sdf_normal = torch.zeros((W, H, 2)).cuda() #[W, H, 2] 178 | sdf_normal[1:-1, 1:-1] = calc_grad(sdf[None,...,None])[0] # outward pointing [W, H, 2] 179 | sdf_normal = F.normalize(sdf_normal, dim = -1, p = 2) 180 | return sdf_normal 181 | 182 | # vorts_pos: [batch, num_vorts, 2] 183 | # query_pos: [num_query, 2] or [batch, num_query, 2] 184 | # return: [batch, num_queries, num_vorts, 2] 185 | def calc_diff_batched(_vorts_pos, _query_pos): 186 | vorts_pos = _vorts_pos[:, None, :, :] # [batch, 1, num_vorts, 2] 187 | if len(_query_pos.shape) > 2: 188 | query_pos = _query_pos[:, :, None, :] # [batch, num_query, 1, 2] 189 | else: 190 | query_pos = _query_pos[None, :, None, :] # [1, num_query, 1, 2] 191 | diff = query_pos - vorts_pos # [batch, 
num_queries, num_vorts, 2] 192 | return diff 193 | 194 | 195 | # vorts_pos shape: [batch, num_vorts, 2] 196 | # vorts_w shape: [num_vorts, 1] or [batch, num_vorts, 1] 197 | # vorts_size shape: [num_vorts, 1] or [batch, num_vorts, 1] 198 | def vort_to_vel(network_length, vorts_size, vorts_w, vorts_pos, query_pos, length_scale): 199 | diff = calc_diff_batched(vorts_pos, query_pos) # [batch, num_query, num_vorts, 2] 200 | # some broadcasting 201 | if len(vorts_size.shape) > 2: 202 | blob_size = vorts_size[:, None, ...] # [batch, 1, num_vorts, 1] 203 | else: 204 | blob_size = vorts_size[None, None, ...] # [1, 1, num_vorts, 1] 205 | if len(vorts_w.shape) > 2: 206 | vorts_w = vorts_w[:, None, ...] # [batch, 1, num_vorts, 1] 207 | else: 208 | vorts_w = vorts_w[None, None, ...] # [1, 1, num_vorts, 1] 209 | 210 | 211 | dist = torch.norm(diff, dim = -1, p = 2, keepdim = True) 212 | dist_not_zero = dist > 0.0 213 | 214 | # cross product in 2D 215 | R = diff.flip([-1]) # (x, y) becomes (y, x) 216 | R[..., 0] *= -1 # (y, x) becomes (-y, x) 217 | R = F.normalize(R, dim = -1) 218 | 219 | dist = dist / (blob_size/length_scale) 220 | dist[dist_not_zero] = torch.pow(dist[dist_not_zero], 0.3) 221 | magnitude = network_length(dist) 222 | magnitude = magnitude / (blob_size/length_scale) 223 | 224 | result = magnitude * R * vorts_w 225 | result = torch.sum(result, dim = -2) # [batch_size, num_queries, 2] 226 | 227 | return result -------------------------------------------------------------------------------- /models/libs/vortex/pretrained.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/DeepLag/694b3cd98d8fc1c65803d4bdbbbc174f8a4d1322/models/libs/vortex/pretrained.tar -------------------------------------------------------------------------------- /models/libs/vortex/simulation_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(123) 3 | import numpy as np 4 | import os 5 | from .learning_utils import calc_grad 6 | import math 7 | import time 8 | from functorch import vmap 9 | import torch.nn.functional as F 10 | 11 | device = torch.device("cuda") 12 | real = torch.float32 13 | 14 | def RK1(pos, u, dt): 15 | return pos + dt * u 16 | 17 | def RK2(pos, u, dt): 18 | p_mid = pos + 0.5 * dt * u 19 | return pos + dt * sample_grid_batched(u, p_mid) 20 | 21 | def RK3(pos, u, dt): 22 | u1 = u 23 | p1 = pos + 0.5 * dt * u1 24 | u2 = sample_grid_batched(u, p1) 25 | p2 = pos + 0.75 * dt * u2 26 | u3 = sample_grid_batched(u, p2) 27 | return pos + dt * (2/9 * u1 + 1/3 * u2 + 4/9 * u3) 28 | 29 | def advect_quantity_batched(quantity, u, x, dt, boundary): 30 | return advect_quantity_batched_BFECC(quantity, u, x, dt, boundary) # BFECC is the default advection scheme (defined below) 31 | 32 | # pos: [num_queries, 2] 33 | # if a backtraced position is out-of-bound, project it to the interior 34 | def project_to_inside(pos, boundary): 35 | if boundary is None: # if no boundary then do nothing 36 | return pos 37 | sdf, sdf_normal, _ = boundary 38 | W, H = sdf.shape 39 | dx = 1./H 40 | pos_grid = (pos / dx).floor().long() 41 | pos_grid_x = pos_grid[...,0] 42 | pos_grid_y = pos_grid[...,1] 43 | pos_grid_x = torch.clamp(pos_grid_x, 0, W-1) 44 | pos_grid_y = torch.clamp(pos_grid_y, 0, H-1) 45 | sd_at_pos = sdf[pos_grid_x, pos_grid_y][...,None] # [num_queries, 1] 46 | sd_normal_at_pos = sdf_normal[pos_grid_x, pos_grid_y] # [num_queries, 2] 47 | OUT = (sd_at_pos >= 
-boundary[2]).squeeze(-1) # [num_queries] 48 | OUT_pos = pos[OUT] #[num_out_queries, 2] 49 | OUT_pos_fixed = OUT_pos - (sd_at_pos[OUT]+boundary[2]) * dx * sd_normal_at_pos[OUT] # remember to multiply by dx 50 | pos[OUT] = OUT_pos_fixed 51 | return pos 52 | 53 | 54 | def index_take_2D(source, index_x, index_y): 55 | W, H, Channel = source.shape 56 | W_, H_ = index_x.shape 57 | index_flattened_x = index_x.flatten() 58 | index_flattened_y = index_y.flatten() 59 | sampled = source[index_flattened_x, index_flattened_y].view((W_, H_, Channel)) 60 | return sampled 61 | 62 | index_take_batched = vmap(index_take_2D) 63 | 64 | # clipping used for MacCormack and BFECC 65 | def MacCormack_clip(advected_quantity, quantity, u, x, dt, boundary): 66 | batch, W, H, _ = u.shape 67 | prev_pos = RK3(x, u, -1. * dt) # [batch, W, H, 2] 68 | prev_pos = project_to_inside(prev_pos.view((-1, 2)), boundary).view(prev_pos.shape) 69 | dx = 1./H 70 | pos_grid = (prev_pos / dx - 0.5).floor().long() 71 | pos_grid_x = torch.clamp(pos_grid[..., 0], 0, W-2) 72 | pos_grid_y = torch.clamp(pos_grid[..., 1], 0, H-2) 73 | pos_grid_x_plus = pos_grid_x + 1 74 | pos_grid_y_plus = pos_grid_y + 1 75 | BL = index_take_batched(quantity, pos_grid_x, pos_grid_y) 76 | BR = index_take_batched(quantity, pos_grid_x_plus, pos_grid_y) 77 | TR = index_take_batched(quantity, pos_grid_x_plus, pos_grid_y_plus) 78 | TL = index_take_batched(quantity, pos_grid_x, pos_grid_y_plus) 79 | stacked = torch.stack((BL, BR, TR, TL), dim = 0) 80 | maxed = torch.max(stacked, dim = 0).values # [batch, W, H, 3] 81 | mined = torch.min(stacked, dim = 0).values # [batch, W, H, 3] 82 | _advected_quantity = torch.clamp(advected_quantity, mined, maxed) 83 | return _advected_quantity 84 | 85 | # SL 86 | def advect_quantity_batched_SL(quantity, u, x, dt, boundary): 87 | prev_pos = RK3(x, u, -1. 
* dt) # [batch, W, H, 2] 88 | prev_pos = project_to_inside(prev_pos.view((-1, 2)), boundary).view(prev_pos.shape) 89 | new_quantity = sample_grid_batched(quantity, prev_pos) 90 | return new_quantity 91 | 92 | # BFECC 93 | def advect_quantity_batched_BFECC(quantity, u, x, dt, boundary): 94 | quantity1 = advect_quantity_batched_SL(quantity, u, x, dt, boundary) 95 | quantity2 = advect_quantity_batched_SL(quantity1, u, x, -1.*dt, boundary) 96 | new_quantity = advect_quantity_batched_SL(quantity + 0.5 * (quantity-quantity2), u, x, dt, boundary) 97 | new_quantity = MacCormack_clip(new_quantity, quantity, u, x, dt, boundary) 98 | return new_quantity 99 | 100 | # MacCormack 101 | def advect_quantity_batched_MacCormack(quantity, u, x, dt, boundary): 102 | quantity1 = advect_quantity_batched_SL(quantity, u, x, dt, boundary) 103 | quantity2 = advect_quantity_batched_SL(quantity1, u, x, -1.*dt, boundary) 104 | new_quantity = quantity1 + 0.5 * (quantity - quantity2) 105 | new_quantity = MacCormack_clip(new_quantity, quantity, u, x, dt, boundary) 106 | return new_quantity 107 | 108 | # data = [batch, X, Y, n_channel] 109 | # pos = [batch, X, Y, 2] 110 | def sample_grid_batched(data, pos): 111 | data_ = data.permute([0, 3, 2, 1]) 112 | pos_ = pos.clone().permute([0, 2, 1, 3]) 113 | pos_ = (pos_ - 0.5) * 2 114 | F_sample_grid = F.grid_sample(data_, pos_, padding_mode = 'border', align_corners = False, mode = "bilinear") 115 | F_sample_grid = F_sample_grid.permute([0, 3, 2, 1]) 116 | return F_sample_grid 117 | 118 | # pos: [num_query, 2] or [batch, num_query, 2] 119 | # vel: [batch, num_query, 2] 120 | # mode: 0 for image, 1 for vort 121 | def boundary_treatment(pos, vel, boundary, mode = 0): 122 | vel_after = vel.clone() 123 | batch, num_query, _ = vel.shape 124 | sdf = boundary[0] # [W, H] 125 | sdf_normal = boundary[1] 126 | if mode == 0: 127 | score = torch.clamp((sdf / -15.), min = 0.).flatten() 128 | inside_band = (score < 1.).squeeze(-1).flatten() 129 | score = score[None, ..., None] 130 | vel_after[:, inside_band, :] = score[:, inside_band, :] * vel[:, inside_band, :] 131 | else: 132 | W, H = sdf.shape 133 | dx = 1./H 134 | pos_grid = (pos / dx).floor().long() 135 | pos_grid_x = pos_grid[...,0] 136 | pos_grid_y = pos_grid[...,1] 137 | pos_grid_x = torch.clamp(pos_grid_x, 0, W-1) 138 | pos_grid_y = torch.clamp(pos_grid_y, 0, H-1) 139 | sd = sdf[pos_grid_x, pos_grid_y][...,None] 140 | sd_normal = sdf_normal[pos_grid_x, pos_grid_y] 141 | score = torch.clamp((sd / -75.), min = 0.) 142 | inside_band = (score < 1.).squeeze(-1) 143 | vel_normal = torch.einsum('bij,bij->bi', vel, sd_normal)[...,None] * sd_normal 144 | vel_tang = vel - vel_normal 145 | tang_at_boundary = 0.33 146 | vel_after[inside_band] = ((1.-tang_at_boundary) * score[inside_band] + tang_at_boundary) * vel_tang[inside_band] + score[inside_band] * vel_normal[inside_band] 147 | 148 | return vel_after 149 | 150 | # simulate a single step 151 | def simulate_step(img, img_x, vorts_pos, vorts_w, vorts_size, vel_func, dt, boundary): 152 | batch_size = vorts_pos.shape[0] 153 | img_x_flattened = img_x.view(-1, 2) 154 | if boundary is None: 155 | img_vel_flattened = vel_func(vorts_size, vorts_w, vorts_pos, img_x_flattened) 156 | img_vel = img_vel_flattened.view((batch_size, img_x.shape[0], img_x.shape[1], -1)) 157 | new_img = torch.clip(advect_quantity_batched(img, img_vel, img_x, dt, boundary), 0., 1.) 
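# self-advect the vortex particles: evaluate the induced velocity at the vortex
# positions themselves, then step them forward with RK1 (forward Euler)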
158 | vorts_vel = vel_func(vorts_size, vorts_w, vorts_pos, vorts_pos) 159 | new_vorts_pos = RK1(vorts_pos, vorts_vel, dt) 160 | else: 161 | OUT = (boundary[0]>=-boundary[2]) 162 | IN = ~OUT 163 | img_x_flattened = img_x.view(-1, 2) 164 | IN_flattened = IN.expand(img_x.shape[:-1]).flatten() 165 | img_vel_flattened = torch.zeros(batch_size, *img_x_flattened.shape).to(device) 166 | # only the velocity of the IN part will be computed, the rest will be left as 0 167 | img_vel_flattened[:, IN_flattened] = vel_func(vorts_size, vorts_w, vorts_pos, img_x_flattened[IN_flattened]) 168 | img_vel_flattened = boundary_treatment(img_x_flattened, img_vel_flattened, boundary, mode = 0) 169 | img_vel = img_vel_flattened.view((batch_size, img_x.shape[0], img_x.shape[1], -1)) 170 | new_img = torch.clip(advect_quantity_batched(img, img_vel, img_x, dt, boundary), 0., 1.) 171 | new_img[:, OUT] = img[:, OUT] # the image of the OUT part will be left unchanged 172 | vorts_vel = vel_func(vorts_size, vorts_w, vorts_pos, vorts_pos) 173 | vorts_vel = boundary_treatment(vorts_pos, vorts_vel, boundary, mode = 1) 174 | new_vorts_pos = RK1(vorts_pos, vorts_vel, dt) 175 | 176 | return new_img, new_vorts_pos, img_vel, vorts_vel 177 | 178 | # simulate in batches 179 | # img: the initial image 180 | # img_x: the grid coordinates (meshgrid) 181 | # vorts_pos: init vortex positions 182 | # vorts_w: vorticity 183 | # vorts_size: size 184 | # num_steps: how many steps to simulate 185 | # vel_func: how to compute velocity from vorticity 186 | def simulate(img, img_x, vorts_pos, vorts_w, vorts_size, num_steps, vel_func, boundary = None, dt = 0.01): 187 | imgs = [] 188 | vorts_poss = [] 189 | img_vels = [] 190 | vorts_vels = [] 191 | for i in range(num_steps): 192 | img, vorts_pos, img_vel, vorts_vel = simulate_step(img, img_x, vorts_pos, vorts_w, vorts_size, vel_func, dt, boundary = boundary) 193 | imgs.append(img.clone()) 194 | vorts_poss.append(vorts_pos.clone()) 195 | img_vels.append(img_vel) 196 | vorts_vels.append(vorts_vel) 197 | 198 | return imgs, vorts_poss, img_vels, vorts_vels 199 | -------------------------------------------------------------------------------- /pic/bounded-navier-stokes.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/DeepLag/694b3cd98d8fc1c65803d4bdbbbc174f8a4d1322/pic/bounded-navier-stokes.gif -------------------------------------------------------------------------------- /pic/eulag_block_v4.3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/DeepLag/694b3cd98d8fc1c65803d4bdbbbc174f8a4d1322/pic/eulag_block_v4.3.png -------------------------------------------------------------------------------- /pic/framework_v4.3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/DeepLag/694b3cd98d8fc1c65803d4bdbbbc174f8a4d1322/pic/framework_v4.3.png -------------------------------------------------------------------------------- /pic/ocean-current.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuml/DeepLag/694b3cd98d8fc1c65803d4bdbbbc174f8a4d1322/pic/ocean-current.gif -------------------------------------------------------------------------------- /pic/traj.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thuml/DeepLag/694b3cd98d8fc1c65803d4bdbbbc174f8a4d1322/pic/traj.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ConfigArgParse
2 | einops
3 | galerkin_transformer
4 | h5py
5 | imageio
6 | ipython
7 | jupyterthemes
8 | matplotlib
9 | model
10 | numpy
11 | pandas
12 | plotly
13 | psutil
14 | PyYAML
15 | scipy
16 | seaborn
17 | torch
18 | torchinfo
19 | tqdm
20 | 
--------------------------------------------------------------------------------
/scripts/bc_deeplag.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=6
2 | 
3 | python exp_bc_h.py \
4 | --dataset-nickname bc \
5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \
6 | --ntrain 2000 \
7 | --ntest 500 \
8 | --ntotal 3000 \
9 | --in-dim 10 \
10 | --out-dim 1 \
11 | --in-var 1 \
12 | --out-var 1 \
13 | --has-t \
14 | --tmin 0 \
15 | --tmax 9 \
16 | --h 512 \
17 | --w 512 \
18 | --h-down 4 \
19 | --w-down 4 \
20 | --T-in 10 \
21 | --T-out 10 \
22 | --batch-size 1 \
23 | --learning-rate 0.0002 \
24 | --epochs 101 \
25 | --step-size 100 \
26 | --model DeepLag_2D \
27 | --model-nickname deeplag \
28 | --d-model 64 \
29 | --num-samples 512 \
30 | --num-layers 4 \
31 | --num-basis 12 \
32 | --num-token 4 \
33 | --patch-size 4,4 \
34 | --padding 0,0 \
35 | --kernel-size 3 \
36 | --offset-ratio-range 16,8 \
37 | --resample-strategy learned \
38 | --model-save-path ./checkpoints/bc \
39 | --model-save-name deeplag.pt \
40 | --delta-t 1
--------------------------------------------------------------------------------
/scripts/bc_factformer.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=6
2 | 
3 | python exp_bc_h.py \
4 | --dataset-nickname bc \
5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \
6 | --ntrain 2000 \
7 | --ntest 500 \
8 | --ntotal 3000 \
9 | --in-dim 10 \
10 | --out-dim 1 \
11 | --in-var 1 \
12 | --out-var 1 \
13 | --has-t \
14 | --tmin 0 \
15 | --tmax 9 \
16 | --h 512 \
17 | --w 512 \
18 | --h-down 2 \
19 | --w-down 2 \
20 | --T-in 10 \
21 | --T-out 10 \
22 | --batch-size 1 \
23 | --learning-rate 0.0005 \
24 | --epochs 101 \
25 | --step-size 100 \
26 | --model Factformer_2D \
27 | --model-nickname factformer \
28 | --depth 12 \
29 | --d-model 128 \
30 | --heads 6 \
31 | --dim-head 64 \
32 | --num-samples 512 \
33 | --num-layers 4 \
34 | --num-basis 12 \
35 | --num-token 4 \
36 | --patch-size 1,1 \
37 | --padding 0,0 \
38 | --kernel-size 3 \
39 | --offset-ratio-range 16,8 \
40 | --resample-strategy learned \
41 | --model-save-path ./checkpoints/bc \
42 | --model-save-name factformer.pt \
43 | --delta-t 2
--------------------------------------------------------------------------------
/scripts/bc_fno.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=6
2 | 
3 | python exp_bc_h.py \
4 | --dataset-nickname bc \
5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \
6 | --ntrain 2000 \
7 | --ntest 500 \
8 | --ntotal 3000 \
9 | --in-dim 10 \
10 | --out-dim 1 \
11 | --in-var 1 \
12 | --out-var 1 \
13 | --has-t \
14 | --tmin 0 \
15 | --tmax 9 \
16 | --h 512 \
17 | --w 512 \
18 | --h-down 2 \
19 | --w-down 2 \
20 | --T-in 10 \
21 | --T-out 10 \
22 | --batch-size 1 \
23 | --learning-rate 0.0005 \
24 | --epochs 101 \
25 | --model FNO_2D \
26 | --model-nickname fno \
27 | --d-model 64 \
28 | --num-basis 12 \
29 | --num-token 4 \
30 | --patch-size 1,1 \
31 | --padding 0,0 \
32 | --model-save-path ./checkpoints/bc \
33 | --model-save-name fno.pt \
34 | --num-layers 10
--------------------------------------------------------------------------------
/scripts/bc_gktrm.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=3
2 | 
3 | python -u exp_bc_h.py \
4 | --dataset-nickname bc \
5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \
6 | --ntrain 2000 \
7 | --ntest 500 \
8 | --ntotal 3000 \
9 | --in-dim 10 \
10 | --out-dim 1 \
11 | --in-var 1 \
12 | --out-var 1 \
13 | --has-t \
14 | --tmin 0 \
15 | --tmax 9 \
16 | --h 512 \
17 | --w 512 \
18 | --h-down 2 \
19 | --w-down 2 \
20 | --T-in 10 \
21 | --T-out 10 \
22 | --batch-size 1 \
23 | --learning-rate 0.0005 \
24 | --epochs 101 \
25 | --step-size 100 \
26 | --model GkTrm_2D \
27 | --model-nickname gktrm \
28 | --d-model 64 \
29 | --num-samples 512 \
30 | --num-layers 4 \
31 | --num-basis 12 \
32 | --num-token 4 \
33 | --patch-size 1,1 \
34 | --padding 0,0 \
35 | --kernel-size 3 \
36 | --offset-ratio-range 16,8 \
37 | --resample-strategy learned \
38 | --model-save-path ./checkpoints/bc \
39 | --model-save-name gktrm.pt
--------------------------------------------------------------------------------
/scripts/bc_gnot.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=4
2 | 
3 | python exp_bc_h.py \
4 | --dataset-nickname bc \
5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \
6 | --ntrain 2000 \
7 | --ntest 500 \
8 | --ntotal 3000 \
9 | --in-dim 10 \
10 | --out-dim 1 \
11 | --in-var 1 \
12 | --out-var 1 \
13 | --has-t \
14 | --tmin 0 \
15 | --tmax 9 \
16 | --h 512 \
17 | --w 512 \
18 | --h-down 4 \
19 | --w-down 4 \
20 | --T-in 10 \
21 | --T-out 10 \
22 | --batch-size 1 \
23 | --learning-rate 0.0005 \
24 | --epochs 101 \
25 | --step-size 100 \
26 | --model GNOT_2D \
27 | --model-nickname gnot \
28 | --d-model 220 \
29 | --num-samples 512 \
30 | --num-layers 4 \
31 | --num-basis 12 \
32 | --num-token 4 \
33 | --patch-size 1,1 \
34 | --padding 0,0 \
35 | --kernel-size 3 \
36 | --offset-ratio-range 16,8 \
37 | --resample-strategy learned \
38 | --model-save-path ./checkpoints/bc \
39 | --model-save-name gnot.pt
--------------------------------------------------------------------------------
/scripts/bc_lsm.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=5
2 | 
3 | python exp_bc_h.py \
4 | --dataset-nickname bc \
5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \
6 | --ntrain 2000 \
7 | --ntest 500 \
8 | --ntotal 3000 \
9 | --in-dim 10 \
10 | --out-dim 1 \
11 | --in-var 1 \
12 | --out-var 1 \
13 | --has-t \
14 | --tmin 0 \
15 | --tmax 9 \
16 | --h 512 \
17 | --w 512 \
18 | --h-down 2 \
19 | --w-down 2 \
20 | --T-in 10 \
21 | --T-out 10 \
22 | --batch-size 1 \
23 | --learning-rate 0.0002 \
24 | --epochs 101 \
25 | --step-size 100 \
26 | --model LSM_2D \
27 | --model-nickname lsm \
28 | --d-model 64 \
29 | --num-basis 12 \
30 | --num-token 4 \
31 | --patch-size 4,4 \
32 | --padding 0,0 \
33 | --model-save-path ./checkpoints/bc \
34 | --model-save-name lsm.pt
--------------------------------------------------------------------------------
/scripts/bc_unet.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=6
2 | 
3 | python exp_bc_h.py \
4 | --dataset-nickname bc 
\ 5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 6 | --ntrain 2000 \ 7 | --ntest 500 \ 8 | --ntotal 3000 \ 9 | --in-dim 5 \ 10 | --out-dim 1 \ 11 | --in-var 1 \ 12 | --out-var 1 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 512 \ 17 | --w 512 \ 18 | --h-down 2 \ 19 | --w-down 2 \ 20 | --T-in 5 \ 21 | --T-out 5 \ 22 | --batch-size 1 \ 23 | --learning-rate 0.0002 \ 24 | --epochs 101 \ 25 | --step-size 100 \ 26 | --model UNet_2D \ 27 | --model-nickname unet \ 28 | --d-model 64 \ 29 | --num-samples 512 \ 30 | --num-layers 4 \ 31 | --num-basis 12 \ 32 | --num-token 4 \ 33 | --patch-size 4,4 \ 34 | --padding 0,0 \ 35 | --kernel-size 3 \ 36 | --offset-ratio-range 16,8 \ 37 | --resample-strategy learned \ 38 | --model-save-path ./checkpoints/bc \ 39 | --model-save-name unet.pt \ 40 | --delta-t 2 -------------------------------------------------------------------------------- /scripts/bc_vortex.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=5 2 | 3 | python exp_bc_h_vortex.py \ 4 | --dataset-nickname bc \ 5 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 6 | --ntrain 2000 \ 7 | --ntest 500 \ 8 | --ntotal 3000 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 1 \ 12 | --out-var 1 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 512 \ 17 | --w 512 \ 18 | --h-down 4 \ 19 | --w-down 4 \ 20 | --T-in 10 \ 21 | --T-out 10 \ 22 | --batch-size 1 \ 23 | --learning-rate 0.0002 \ 24 | --epochs 101 \ 25 | --step-size 100 \ 26 | --model Vortex_2D \ 27 | --model-nickname vortex \ 28 | --d-model 64 \ 29 | --num-samples 512 \ 30 | --num-layers 4 \ 31 | --num-basis 12 \ 32 | --num-token 4 \ 33 | --patch-size 1,1 \ 34 | --padding 0,0 \ 35 | --kernel-size 3 \ 36 | --offset-ratio-range 16,8 \ 37 | --resample-strategy learned \ 38 | --model-save-path ./checkpoints/bc \ 39 | --model-save-name vortex.pt -------------------------------------------------------------------------------- /scripts/sea_deeplag.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | python exp_sea_h.py \ 4 | --dataset-nickname sea \ 5 | --data-path /home/miaoshangchen/NAS/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | --out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | --h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 2 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 20 \ 28 | --model DeepLag_2D \ 29 | --model-nickname deeplag \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 2,2 \ 36 | --padding 12,20 \ 37 | --kernel-size 5 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/sea \ 41 | --model-save-name deeplag.pt -------------------------------------------------------------------------------- /scripts/sea_factformer.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python exp_sea_h.py \ 4 | --dataset-nickname sea \ 5 | --data-path /home/miaoshangchen/NAS/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | 
--out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | --h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model Factformer_2D \ 29 | --model-nickname factformer \ 30 | --depth 6 \ 31 | --d-model 128 \ 32 | --heads 6 \ 33 | --dim-head 64 \ 34 | --num-samples 512 \ 35 | --num-layers 4 \ 36 | --num-basis 12 \ 37 | --num-token 4 \ 38 | --patch-size 4,4 \ 39 | --padding 0,0 \ 40 | --kernel-size 3 \ 41 | --offset-ratio-range 16,8 \ 42 | --resample-strategy learned \ 43 | --model-save-path ./checkpoints/sea \ 44 | --model-save-name factformer.pt -------------------------------------------------------------------------------- /scripts/sea_fno.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python exp_sea_h.py \ 4 | --dataset-nickname sea \ 5 | --data-path /home/miaoshangchen/NAS/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | --out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | --h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 10 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model FNO_2D \ 29 | --model-nickname fno \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 12 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 4,4 \ 36 | --padding 0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/sea \ 41 | --model-save-name fno.pt -------------------------------------------------------------------------------- /scripts/sea_gktrm.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python exp_sea_h.py \ 4 | --dataset-nickname sea \ 5 | --data-path /home/miaoshangchen/NAS/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | --out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | --h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 10 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model GkTrm_2D \ 29 | --model-nickname gktrm \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 4,4 \ 36 | --padding 0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/sea \ 41 | --model-save-name gktrm.pt -------------------------------------------------------------------------------- /scripts/sea_gnot.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python exp_sea_h.py \ 4 | --dataset-nickname sea \ 5 | --data-path /data/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | --out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | 
--h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model GNOT_2D \ 29 | --model-nickname gnot \ 30 | --d-model 96 \ 31 | --num-samples 512 \ 32 | --num-layers 3 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 4,4 \ 36 | --padding 0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/sea \ 41 | --model-save-name gnot.pt -------------------------------------------------------------------------------- /scripts/sea_lsm.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | python exp_sea_h.py \ 4 | --dataset-nickname sea \ 5 | --data-path /data/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | --out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | --h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 3 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model LSM_2D \ 29 | --model-nickname lsm \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 4,4 \ 36 | --padding 12,20 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/sea \ 41 | --model-save-name lsm.pt -------------------------------------------------------------------------------- /scripts/sea_unet.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=3 2 | 3 | python exp_sea_h.py \ 4 | --dataset-nickname sea \ 5 | --data-path /home/miaoshangchen/NAS/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | --out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | --h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 3 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model UNet_2D \ 29 | --model-nickname unet \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 4,4 \ 36 | --padding 0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/sea \ 41 | --model-save-name unet.pt -------------------------------------------------------------------------------- /scripts/sea_vortex.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=7 2 | # if NAS path to /home/miaoshangchen/NAS/sea_data_small/data_sea 3 | python exp_sea_h_vortex.py \ 4 | --dataset-nickname sea \ 5 | --data-path /home/miaoshangchen/NAS/sea_data_small/data_sea \ 6 | --region kuroshio \ 7 | --ntrain 3000 \ 8 | --ntest 600 \ 9 | --ntotal 3600 \ 10 | --in-dim 10 \ 11 | --out-dim 1 \ 12 | --in-var 5 \ 13 | --out-var 5 \ 14 | --has-t \ 15 | --tmin 0 \ 16 | --tmax 9 \ 17 | --h 180 \ 18 | --w 300 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | 
--T-in 10 \ 22 | --T-out 10 \ 23 | --fill-value \-32760 \ 24 | --batch-size 3 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 20 \ 28 | --model Vortex_2D \ 29 | --model-nickname vortex \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 4,4 \ 36 | --padding 12,20 \ 37 | --kernel-size 5 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/sea \ 41 | --model-save-name vortex.pt -------------------------------------------------------------------------------- /scripts/smoke_deeplag3d.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python exp_smoke_h.py \ 4 | --dataset-nickname smoke \ 5 | --data-path /home/miaoshangchen/NAS/smoke_data \ 6 | --ntrain 1000 \ 7 | --ntest 200 \ 8 | --ntotal 1200 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 4 \ 12 | --out-var 4 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 32 \ 17 | --w 32 \ 18 | --z 32 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --z-down 1 \ 22 | --T-in 10 \ 23 | --T-out 10 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model DeepLag_3D \ 29 | --model-nickname deeplag3d \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 1,1,1 \ 36 | --padding 0,0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/smoke \ 41 | --model-save-name deeplag3d.pt -------------------------------------------------------------------------------- /scripts/smoke_factformer3d.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=4 2 | 3 | python exp_smoke_h.py \ 4 | --dataset-nickname smoke \ 5 | --data-path /home/miaoshangchen/NAS/smoke_data \ 6 | --ntrain 1000 \ 7 | --ntest 200 \ 8 | --ntotal 1200 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 4 \ 12 | --out-var 4 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 32 \ 17 | --w 32 \ 18 | --z 32 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --z-down 1 \ 22 | --T-in 10 \ 23 | --T-out 10 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model Factformer_3D \ 29 | --model-nickname factformer3d \ 30 | --depth 13 \ 31 | --d-model 64 \ 32 | --heads 6 \ 33 | --dim-head 64 \ 34 | --num-samples 512 \ 35 | --num-layers 4 \ 36 | --num-basis 12 \ 37 | --num-token 4 \ 38 | --patch-size 1,1,1 \ 39 | --padding 0,0,0 \ 40 | --kernel-size 3 \ 41 | --offset-ratio-range 16,8 \ 42 | --resample-strategy learned \ 43 | --model-save-path ./checkpoints/smoke \ 44 | --model-save-name factformer3d.pt -------------------------------------------------------------------------------- /scripts/smoke_fno3d.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | python exp_smoke_h.py \ 4 | --dataset-nickname smoke \ 5 | --data-path /home/miaoshangchen/NAS/smoke_data \ 6 | --ntrain 1000 \ 7 | --ntest 200 \ 8 | --ntotal 1200 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 4 \ 12 | --out-var 4 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 32 \ 17 | --w 32 \ 18 | --z 32 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --z-down 1 \ 22 | --T-in 10 \ 23 | --T-out 10 \ 24 | --batch-size 1 \ 25 | 
--learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model FNO_3D \ 29 | --model-nickname fno3d \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 1,1,1 \ 36 | --padding 0,0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/smoke \ 41 | --model-save-name fno3d.pt -------------------------------------------------------------------------------- /scripts/smoke_gktrm3d.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | 3 | python exp_smoke_h.py \ 4 | --dataset-nickname smoke \ 5 | --data-path /home/miaoshangchen/NAS/smoke_data \ 6 | --ntrain 1000 \ 7 | --ntest 200 \ 8 | --ntotal 1200 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 4 \ 12 | --out-var 4 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 32 \ 17 | --w 32 \ 18 | --z 32 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --z-down 1 \ 22 | --T-in 10 \ 23 | --T-out 10 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model GkTrm_3D \ 29 | --model-nickname gktrm3d \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 4,4,4 \ 36 | --padding 0,0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/smoke \ 41 | --model-save-name gktrm3d.pt -------------------------------------------------------------------------------- /scripts/smoke_gnot3d.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=6 2 | 3 | python exp_smoke_h.py \ 4 | --dataset-nickname smoke \ 5 | --data-path /home/miaoshangchen/NAS/smoke_data \ 6 | --ntrain 1000 \ 7 | --ntest 200 \ 8 | --ntotal 1200 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 4 \ 12 | --out-var 4 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 32 \ 17 | --w 32 \ 18 | --z 32 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --z-down 1 \ 22 | --T-in 10 \ 23 | --T-out 10 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model GNOT_3D \ 29 | --model-nickname gnot3d \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 6 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 1,1,1 \ 36 | --padding 0,0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/smoke \ 41 | --model-save-name gnot3d.pt -------------------------------------------------------------------------------- /scripts/smoke_lsm3d.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=2 2 | 3 | python exp_smoke_h.py \ 4 | --dataset-nickname smoke \ 5 | --data-path /home/miaoshangchen/NAS/smoke_data \ 6 | --ntrain 1000 \ 7 | --ntest 200 \ 8 | --ntotal 1200 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 4 \ 12 | --out-var 4 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 32 \ 17 | --w 32 \ 18 | --z 32 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --z-down 1 \ 22 | --T-in 10 \ 23 | --T-out 10 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model LSM_3D \ 29 | --model-nickname lsm3d \ 30 | --d-model 48 \ 31 | --num-samples 512 \ 32 | 
--num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 1,1,1 \ 36 | --padding 0,0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/smoke \ 41 | --model-save-name lsm3d.pt -------------------------------------------------------------------------------- /scripts/smoke_unet3d.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | 3 | python exp_smoke_h.py \ 4 | --dataset-nickname smoke \ 5 | --data-path /home/miaoshangchen/NAS/smoke_data \ 6 | --ntrain 1000 \ 7 | --ntest 200 \ 8 | --ntotal 1200 \ 9 | --in-dim 10 \ 10 | --out-dim 1 \ 11 | --in-var 4 \ 12 | --out-var 4 \ 13 | --has-t \ 14 | --tmin 0 \ 15 | --tmax 9 \ 16 | --h 32 \ 17 | --w 32 \ 18 | --z 32 \ 19 | --h-down 1 \ 20 | --w-down 1 \ 21 | --z-down 1 \ 22 | --T-in 10 \ 23 | --T-out 10 \ 24 | --batch-size 1 \ 25 | --learning-rate 0.0005 \ 26 | --epochs 101 \ 27 | --step-size 100 \ 28 | --model UNet_3D \ 29 | --model-nickname unet3d \ 30 | --d-model 64 \ 31 | --num-samples 512 \ 32 | --num-layers 4 \ 33 | --num-basis 12 \ 34 | --num-token 4 \ 35 | --patch-size 1,1,1 \ 36 | --padding 0,0,0 \ 37 | --kernel-size 3 \ 38 | --offset-ratio-range 16,8 \ 39 | --resample-strategy learned \ 40 | --model-save-path ./checkpoints/smoke \ 41 | --model-save-name unet3d.pt -------------------------------------------------------------------------------- /scripts/test_all.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=7 2 | 3 | 4 | # bounded NS 5 | python test_bc_h.py \ 6 | --ckpt-dir ./checkpoints \ 7 | --dataset-nickname bc \ 8 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 9 | --model-name UNet_2D \ 10 | --time-str 20240519_172142 \ 11 | --milestone best 12 | 13 | python test_bc_h.py \ 14 | --ckpt-dir ./checkpoints \ 15 | --dataset-nickname bc \ 16 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 17 | --model-name FNO_2D \ 18 | --time-str 20240520_093530 \ 19 | --milestone best 20 | 21 | python test_bc_h.py \ 22 | --ckpt-dir ./checkpoints \ 23 | --dataset-nickname bc \ 24 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 25 | --model-name GkTrm_2D \ 26 | --time-str 20240520_013510 \ 27 | --milestone best 28 | 29 | python test_bc_h.py \ 30 | --ckpt-dir ./checkpoints \ 31 | --dataset-nickname bc \ 32 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 33 | --model-name GNOT_2D \ 34 | --time-str 20240429_065015 \ 35 | --milestone best 36 | 37 | python test_bc_h.py \ 38 | --ckpt-dir ./checkpoints \ 39 | --dataset-nickname bc \ 40 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 41 | --model-name LSM_2D \ 42 | --time-str 20240520_094033 \ 43 | --milestone best 44 | 45 | python test_bc_h.py \ 46 | --ckpt-dir ./checkpoints \ 47 | --dataset-nickname bc \ 48 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 49 | --model-name Factformer_2D \ 50 | --time-str 20240429_064817 \ 51 | --milestone best 52 | 53 | python test_bc_h_vortex.py \ 54 | --ckpt-dir ./checkpoints \ 55 | --dataset-nickname bc \ 56 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 57 | --model-name Vortex_2D \ 58 | --time-str 20240520_055530 \ 59 | --milestone best 60 | 61 | python -u test_bc_h.py \ 62 | --ckpt-dir ./checkpoints \ 63 | --dataset-nickname bc \ 64 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 65 | --model-name DeepLag_2D \ 66 | --time-str 20240520_140626 \ 67 | --milestone best 68 | 69 | 70 | # Ocean Current 71 | python 
test_sea_h.py \ 72 | --ckpt-dir ./checkpoints \ 73 | --dataset-nickname sea \ 74 | --model-name UNet_2D \ 75 | --time-str 20240520_132532 \ 76 | --milestone best 77 | 78 | python test_sea_h.py \ 79 | --ckpt-dir ./checkpoints \ 80 | --dataset-nickname sea \ 81 | --model-name FNO_2D \ 82 | --time-str 20240520_020333 \ 83 | --milestone best 84 | 85 | python test_sea_h.py \ 86 | --ckpt-dir ./checkpoints \ 87 | --dataset-nickname sea \ 88 | --model-name GkTrm_2D \ 89 | --time-str 20240520_020707 \ 90 | --milestone best 91 | 92 | python test_sea_h.py \ 93 | --ckpt-dir ./checkpoints \ 94 | --dataset-nickname sea \ 95 | --model-name GNOT_2D \ 96 | --time-str 20240501_154700 \ 97 | --milestone best 98 | 99 | python test_sea_h.py \ 100 | --ckpt-dir ./checkpoints \ 101 | --dataset-nickname sea \ 102 | --model-name LSM_2D \ 103 | --time-str 20240520_132354 \ 104 | --milestone best 105 | 106 | python test_sea_h.py \ 107 | --ckpt-dir ./checkpoints \ 108 | --dataset-nickname sea \ 109 | --model-name Factformer_2D \ 110 | --time-str 20240501_152629 \ 111 | --milestone best 112 | 113 | python test_sea_h_vortex.py \ 114 | --ckpt-dir ./checkpoints \ 115 | --dataset-nickname sea \ 116 | --model-name Vortex_2D \ 117 | --time-str 20240522_023056 \ 118 | --milestone best 119 | 120 | python test_sea_h.py \ 121 | --ckpt-dir ./checkpoints \ 122 | --dataset-nickname sea \ 123 | --model-name DeepLag_2D \ 124 | --time-str 20240507_170237 \ 125 | --milestone best 126 | 127 | 128 | # Smoke 129 | python test_smoke_h.py \ 130 | --ckpt-dir ./checkpoints \ 131 | --dataset-nickname smoke \ 132 | --model-name UNet_3D \ 133 | --time-str 20240520_064624 \ 134 | --milestone best 135 | 136 | python test_smoke_h.py \ 137 | --ckpt-dir ./checkpoints \ 138 | --dataset-nickname smoke \ 139 | --model-name FNO_3D \ 140 | --time-str 20240520_064910 \ 141 | --milestone best 142 | 143 | python test_smoke_h.py \ 144 | --ckpt-dir ./checkpoints \ 145 | --dataset-nickname smoke \ 146 | --model-name GkTrm_3D \ 147 | --time-str 20240501_160526 \ 148 | --milestone best 149 | 150 | python test_smoke_h.py \ 151 | --ckpt-dir ./checkpoints \ 152 | --dataset-nickname smoke \ 153 | --model-name GNOT_3D \ 154 | --time-str 20240503_060448 \ 155 | --milestone best 156 | 157 | python test_smoke_h.py \ 158 | --ckpt-dir ./checkpoints \ 159 | --dataset-nickname smoke \ 160 | --model-name LSM_3D \ 161 | --time-str 20240520_063957 \ 162 | --milestone best 163 | 164 | python test_smoke_h.py \ 165 | --ckpt-dir ./checkpoints \ 166 | --dataset-nickname smoke \ 167 | --model-name Factformer_3D \ 168 | --time-str 20240520_173349 \ 169 | --milestone best 170 | 171 | python test_smoke_h.py \ 172 | --ckpt-dir ./checkpoints \ 173 | --dataset-nickname smoke \ 174 | --model-name DeepLag_3D \ 175 | --time-str 20240520_172807 \ 176 | --milestone best 177 | -------------------------------------------------------------------------------- /scripts/test_all_longrollout.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=5 2 | 3 | 4 | # bounded NS 5 | python test_bc_h.py \ 6 | --ckpt-dir ./checkpoints \ 7 | --dataset-nickname bc \ 8 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 9 | --model-name UNet_2D \ 10 | --time-str 20240519_172142 \ 11 | --milestone best \ 12 | --T-out 30 13 | 14 | python test_bc_h.py \ 15 | --ckpt-dir ./checkpoints \ 16 | --dataset-nickname bc \ 17 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 18 | --model-name FNO_2D \ 19 | --time-str 20240520_093530 \ 20 | --milestone best \ 21 | 
--T-out 30 22 | 23 | python test_bc_h.py \ 24 | --ckpt-dir ./checkpoints \ 25 | --dataset-nickname bc \ 26 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 27 | --model-name GkTrm_2D \ 28 | --time-str 20240520_013510 \ 29 | --milestone best \ 30 | --T-out 30 31 | 32 | python test_bc_h.py \ 33 | --ckpt-dir ./checkpoints \ 34 | --dataset-nickname bc \ 35 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 36 | --model-name GNOT_2D \ 37 | --time-str 20240429_065015 \ 38 | --milestone best \ 39 | --T-out 30 40 | 41 | python test_bc_h.py \ 42 | --ckpt-dir ./checkpoints \ 43 | --dataset-nickname bc \ 44 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 45 | --model-name LSM_2D \ 46 | --time-str 20240520_094033 \ 47 | --milestone best \ 48 | --T-out 30 49 | 50 | python test_bc_h.py \ 51 | --ckpt-dir ./checkpoints \ 52 | --dataset-nickname bc \ 53 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 54 | --model-name Factformer_2D \ 55 | --time-str 20240429_064817 \ 56 | --milestone best \ 57 | --T-out 30 58 | 59 | python test_bc_h_vortex.py \ 60 | --ckpt-dir ./checkpoints \ 61 | --dataset-nickname bc \ 62 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 63 | --model-name Vortex_2D \ 64 | --time-str 20240520_055530 \ 65 | --milestone best \ 66 | --T-out 30 67 | 68 | python -u test_bc_h.py \ 69 | --ckpt-dir ./checkpoints \ 70 | --dataset-nickname bc \ 71 | --data-path /home/miaoshangchen/NAS/Bounded_NS \ 72 | --model-name DeepLag_2D \ 73 | --time-str 20240520_140626 \ 74 | --milestone best \ 75 | --T-out 30 76 | 77 | 78 | # Ocean Current 79 | python test_sea_h.py \ 80 | --ckpt-dir ./checkpoints \ 81 | --dataset-nickname sea \ 82 | --model-name UNet_2D \ 83 | --time-str 20240520_132532 \ 84 | --milestone best \ 85 | --T-out 30 86 | 87 | python test_sea_h.py \ 88 | --ckpt-dir ./checkpoints \ 89 | --dataset-nickname sea \ 90 | --model-name FNO_2D \ 91 | --time-str 20240520_020333 \ 92 | --milestone best \ 93 | --T-out 30 94 | 95 | python test_sea_h.py \ 96 | --ckpt-dir ./checkpoints \ 97 | --dataset-nickname sea \ 98 | --model-name GkTrm_2D \ 99 | --time-str 20240520_020707 \ 100 | --milestone best \ 101 | --T-out 30 102 | 103 | python test_sea_h.py \ 104 | --ckpt-dir ./checkpoints \ 105 | --dataset-nickname sea \ 106 | --model-name GNOT_2D \ 107 | --time-str 20240501_154700 \ 108 | --milestone best \ 109 | --T-out 30 110 | 111 | python test_sea_h.py \ 112 | --ckpt-dir ./checkpoints \ 113 | --dataset-nickname sea \ 114 | --model-name LSM_2D \ 115 | --time-str 20240520_132354 \ 116 | --milestone best \ 117 | --T-out 30 118 | 119 | python test_sea_h.py \ 120 | --ckpt-dir ./checkpoints \ 121 | --dataset-nickname sea \ 122 | --model-name Factformer_2D \ 123 | --time-str 20240501_152629 \ 124 | --milestone best \ 125 | --T-out 30 126 | 127 | python test_sea_h_vortex.py \ 128 | --ckpt-dir ./checkpoints \ 129 | --dataset-nickname sea \ 130 | --model-name Vortex_2D \ 131 | --time-str 20240522_023056 \ 132 | --milestone best \ 133 | --T-out 30 134 | 135 | python test_sea_h.py \ 136 | --ckpt-dir ./checkpoints \ 137 | --dataset-nickname sea \ 138 | --model-name DeepLag_2D \ 139 | --time-str 20240507_170237 \ 140 | --milestone best \ 141 | --T-out 30 142 | 143 | 144 | # Smoke 145 | python test_smoke_h.py \ 146 | --ckpt-dir ./checkpoints \ 147 | --dataset-nickname smoke \ 148 | --model-name UNet_3D \ 149 | --time-str 20240520_064624 \ 150 | --milestone best \ 151 | --T-out 30 152 | 153 | python test_smoke_h.py \ 154 | --ckpt-dir ./checkpoints \ 155 | --dataset-nickname smoke \ 156 | --model-name FNO_3D \ 157 | 
--time-str 20240520_064910 \ 158 | --milestone best \ 159 | --T-out 30 160 | 161 | python test_smoke_h.py \ 162 | --ckpt-dir ./checkpoints \ 163 | --dataset-nickname smoke \ 164 | --model-name GkTrm_3D \ 165 | --time-str 20240501_160526 \ 166 | --milestone best \ 167 | --T-out 30 168 | 169 | python test_smoke_h.py \ 170 | --ckpt-dir ./checkpoints \ 171 | --dataset-nickname smoke \ 172 | --model-name GNOT_3D \ 173 | --time-str 20240503_060448 \ 174 | --milestone best \ 175 | --T-out 30 176 | 177 | python test_smoke_h.py \ 178 | --ckpt-dir ./checkpoints \ 179 | --dataset-nickname smoke \ 180 | --model-name LSM_3D \ 181 | --time-str 20240520_063957 \ 182 | --milestone best \ 183 | --T-out 30 184 | 185 | python test_smoke_h.py \ 186 | --ckpt-dir ./checkpoints \ 187 | --dataset-nickname smoke \ 188 | --model-name Factformer_3D \ 189 | --time-str 20240520_173349 \ 190 | --milestone best \ 191 | --T-out 30 192 | 193 | python test_smoke_h.py \ 194 | --ckpt-dir ./checkpoints \ 195 | --dataset-nickname smoke \ 196 | --model-name DeepLag_3D \ 197 | --time-str 20240520_172807 \ 198 | --milestone best \ 199 | --T-out 30 200 | -------------------------------------------------------------------------------- /test_bc_h.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import logging 4 | import pickle 5 | from timeit import default_timer 6 | from datetime import datetime, timedelta 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from utils.data_factory import get_bc_dataset, BoundedNSDataset 11 | from utils.utilities3 import * 12 | from utils.params import get_args, get_test_args 13 | from utils.adam import Adam 14 | from model_dict import get_model 15 | 16 | from tqdm import tqdm 17 | 18 | torch.manual_seed(0) 19 | np.random.seed(0) 20 | torch.cuda.manual_seed(0) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | ################################################################ 25 | # configs 26 | ################################################################ 27 | test_args = get_test_args() 28 | ckpt_dir = test_args.ckpt_dir 29 | dataset_nickname = test_args.dataset_nickname 30 | model_name = test_args.model_name 31 | time_str = test_args.time_str 32 | milestone = test_args.milestone 33 | T_out = test_args.T_out 34 | 35 | args = get_args(cfg_file=Path(ckpt_dir)/dataset_nickname/model_name/time_str/'configs.txt') 36 | test_save_path = os.path.join(args.run_save_path, f'test_{milestone}_{T_out}') 37 | if not os.path.isdir(test_save_path): 38 | os.makedirs(test_save_path) 39 | 40 | LOG_FORMAT = "%(message)s" 41 | logger = logging.getLogger('Loss logger') 42 | logger.setLevel(logging.INFO) 43 | f_handler = logging.FileHandler(os.path.join(test_save_path, args.log_save_name)) 44 | f_handler.setLevel(logging.INFO) 45 | f_handler.setFormatter(logging.Formatter(LOG_FORMAT)) 46 | logger.addHandler(f_handler) 47 | 48 | TRAIN_PATH = os.path.join(test_args.data_path, 're0.25_4c_gray_10000.npy') 49 | TEST_PATH = os.path.join(test_args.data_path, 're0.25_4c_gray_10000.npy') 50 | BOUNDARY_PATH = os.path.join(test_args.data_path, 'boundary_4c_rot.npy') 51 | 52 | padding = [int(p) for p in args.padding.split(',')] 53 | ntrain = args.ntrain 54 | ntest = args.ntest 55 | N = args.ntotal 56 | args.in_channels = args.in_dim * args.in_var 57 | args.out_channels = args.out_dim * args.out_var 58 | r1 = args.h_down 59 | r2 = args.w_down 60 | s1 = int(((args.h - 1) / r1) + 1) 61 | s2 = int(((args.w - 1) / r2) + 1) 62 | 
T_in = args.T_in 63 | # T_out = args.T_out 64 | patch_size = tuple(int(x) for x in args.patch_size.split(',')) 65 | 66 | batch_size = args.batch_size 67 | learning_rate = args.learning_rate 68 | epochs = args.epochs 69 | step_size = args.step_size 70 | gamma = args.gamma 71 | 72 | model_save_path = args.model_save_path 73 | model_save_name = args.model_save_name 74 | 75 | 76 | ################################################################ 77 | # models 78 | ################################################################ 79 | model = get_model(args, ckpt_dir=Path(ckpt_dir)/dataset_nickname/model_name/time_str) 80 | state_dict = torch.load(Path(ckpt_dir)/dataset_nickname/model_name/time_str/ (model_save_name[:-3]+f'_{milestone}.pt')) 81 | model.load_state_dict(state_dict) 82 | 83 | 84 | ################################################################ 85 | # load data and data normalization 86 | ################################################################ 87 | train_dataset = BoundedNSDataset(args, dataset_file=TRAIN_PATH, split='train') 88 | test_dataset = BoundedNSDataset(args, dataset_file=TEST_PATH, split='test') 89 | train_loader = train_dataset.loader() 90 | test_loader = test_dataset.loader() 91 | 92 | boundary, domain = process_boundary_condition(BOUNDARY_PATH, ds_rate=(r1,r2)) 93 | 94 | if 'DeepLag' in args.model: 95 | model.set_bdydom(boundary, domain) 96 | if args.resample_strategy == 'uniform' or args.resample_strategy == 'learned': 97 | model.num_samples = min(model.num_samples, s1*s2) 98 | elif args.resample_strategy == 'boundary': 99 | model.num_samples = min(model.num_samples, model.coo_boundary_ms[0].shape[0]) 100 | elif args.resample_strategy == 'domain': 101 | model.num_samples = min(model.num_samples, model.coo_domain_ms[0].shape[0]) 102 | 103 | 104 | ################################################################ 105 | # evaluation 106 | ################################################################ 107 | myloss = LpLoss(size_average=False) 108 | 109 | step = 1 110 | min_test_l2_full = 114514 111 | 112 | t1 = default_timer() 113 | test_l2_step = 0 114 | test_l2_full = 0 115 | timewise_l2_step = torch.zeros(T_out//step).to(device) 116 | timewise_l2_full = torch.zeros(T_out//step).to(device) 117 | print('ready') 118 | with torch.no_grad(): 119 | for batch_idx, (xx, yy) in enumerate(tqdm(test_loader)): 120 | loss = 0 121 | xx = xx.to(device) 122 | yy = yy.to(device) 123 | if 'DeepLag' in args.model: 124 | h_x_q, h_coo_q, h_coo_offset_q = [], [], [] 125 | for i in range(model.num_layers): 126 | if args.resample_strategy == 'uniform': 127 | num_samples = model.num_samples // (4**i) 128 | coo_q = torch.cat([ 129 | torch.randint(0,model.img_h_layers[i]-1,(batch_size,num_samples,1)), 130 | torch.randint(0,model.img_w_layers[i]-1,(batch_size,num_samples,1)) 131 | ], dim=-1).to(torch.float32) # b k 2 132 | elif args.resample_strategy == 'boundary': 133 | num_samples = min(model.num_samples//(4**i), model.coo_boundary_ms[i].shape[0]) 134 | idx_coo_sample = torch.multinomial(1./torch.ones(model.coo_boundary_ms[i].shape[0]), num_samples, replacement=False) # k 135 | coo_q = model.coo_boundary_ms[i][idx_coo_sample][None, ...].repeat(batch_size,1,1).to(torch.float32) # b k 2 136 | elif args.resample_strategy == 'domain': 137 | num_samples = min(model.num_samples//(4**i), model.coo_domain_ms[i].shape[0]) 138 | idx_coo_sample = torch.multinomial(1./torch.ones(model.coo_domain_ms[i].shape[0]), num_samples, replacement=False) # k 139 | coo_q = 
model.coo_domain_ms[i][idx_coo_sample][None, ...].repeat(batch_size,1,1).to(torch.float32) # b k 2 140 | elif args.resample_strategy == 'learned': 141 | num_samples = model.num_samples // (4**i) 142 | coo_q = None # new_prob 143 | num_chan = args.d_model*(2**i) if i < model.num_layers-1 else args.d_model*(2**(i-1)) 144 | h_x_q.append(torch.zeros(batch_size, num_samples, num_chan).to(device)) 145 | h_coo_q.append(coo_q.to(device) if args.resample_strategy != 'learned' else None) # new_prob 146 | h_coo_offset_q.append(torch.zeros(batch_size, num_samples, 2).to(device)) 147 | 148 | for i, t in enumerate(range(0, T_out, step)): 149 | y = yy[..., t*args.out_var : (t + step)*args.out_var] # B H W C_out=V_out 150 | if 'DeepLag' in args.model: 151 | im, h_x_q, h_coo_q, h_coo_offset_q, coo_offset_xys = model(xx, h_x_q, h_coo_q, h_coo_offset_q) # B H W C_out=V_out # with coo_offset 152 | else: 153 | im = model(xx) 154 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 155 | timewise_l2_step[i] += myloss(im, y) 156 | 157 | if t == 0: 158 | pred = im 159 | else: 160 | pred = torch.cat((pred, im), -1) 161 | 162 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 163 | 164 | test_l2_step += loss.item() 165 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 166 | for i, t in enumerate(range(0, T_out, step)): 167 | timewise_l2_full[i] += myloss(pred[..., i*args.in_var:(i+1)*args.in_var], yy[..., i*args.in_var:(i+1)*args.in_var]) 168 | 169 | t2 = default_timer() 170 | if test_l2_full / ntest < min_test_l2_full: 171 | print(t2 - t1, 172 | 'test_rel_l2:', 173 | test_l2_step / ntest / (T_out / step), 174 | test_l2_full / ntest, 175 | 'timewise_l2:', 176 | timewise_l2_step / ntest, 177 | timewise_l2_full / ntest) 178 | logger.info(f'{t2 - t1} ' + \ 179 | f'test_rel_l2: {test_l2_step / ntest / (T_out / step)} {test_l2_full / ntest} ' + \ 180 | f'timewise_l2: {timewise_l2_step / ntest} {timewise_l2_full / ntest}') 181 | pd = pred[-1, :, :, -1].detach().cpu().numpy() 182 | gt = yy[-1, :, :, -1].detach().cpu().numpy() 183 | visual(pd, os.path.join(test_save_path, f'{milestone}_pred.png')) 184 | visual(gt, os.path.join(test_save_path, f'{milestone}_gt.png')) 185 | visual(np.abs(gt-pd), os.path.join(test_save_path, f'{milestone}_err.png')) 186 | else: 187 | raise Exception('Abnormal loss!') 188 | -------------------------------------------------------------------------------- /test_bc_h_vortex.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import logging 4 | import pickle 5 | from timeit import default_timer 6 | from datetime import datetime, timedelta 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from utils.data_factory import get_bc_dataset, BoundedNSDataset 11 | from utils.utilities3 import * 12 | from utils.params import get_args, get_test_args 13 | from utils.adam import Adam 14 | from model_dict import get_model 15 | 16 | from tqdm import tqdm 17 | 18 | torch.manual_seed(0) 19 | np.random.seed(0) 20 | torch.cuda.manual_seed(0) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | ################################################################ 25 | # configs 26 | ################################################################ 27 | test_args = get_test_args() 28 | ckpt_dir = test_args.ckpt_dir 29 | dataset_nickname = test_args.dataset_nickname 30 | model_name = test_args.model_name 31 | time_str = test_args.time_str 32 | milestone 
= test_args.milestone 33 | T_out = test_args.T_out 34 | 35 | args = get_args(cfg_file=Path(ckpt_dir)/dataset_nickname/model_name/time_str/'configs.txt') 36 | test_save_path = os.path.join(args.run_save_path, f'test_{milestone}_{T_out}') 37 | if not os.path.isdir(test_save_path): 38 | os.makedirs(test_save_path) 39 | 40 | LOG_FORMAT = "%(message)s" 41 | logger = logging.getLogger('Loss logger') 42 | logger.setLevel(logging.INFO) 43 | f_handler = logging.FileHandler(os.path.join(test_save_path, args.log_save_name)) 44 | f_handler.setLevel(logging.INFO) 45 | f_handler.setFormatter(logging.Formatter(LOG_FORMAT)) 46 | logger.addHandler(f_handler) 47 | 48 | TRAIN_PATH = os.path.join(test_args.data_path, 're0.25_4c_gray_10000.npy') 49 | TEST_PATH = os.path.join(test_args.data_path, 're0.25_4c_gray_10000.npy') 50 | BOUNDARY_PATH = os.path.join(test_args.data_path, 'boundary_4c_rot.npy') 51 | 52 | padding = [int(p) for p in args.padding.split(',')] 53 | ntrain = args.ntrain 54 | ntest = args.ntest 55 | N = args.ntotal 56 | args.in_channels = args.in_dim * args.in_var 57 | args.out_channels = args.out_dim * args.out_var 58 | r1 = args.h_down 59 | r2 = args.w_down 60 | s1 = int(((args.h - 1) / r1) + 1) 61 | s2 = int(((args.w - 1) / r2) + 1) 62 | T_in = args.T_in 63 | # T_out = args.T_out 64 | patch_size = tuple(int(x) for x in args.patch_size.split(',')) 65 | 66 | batch_size = args.batch_size 67 | learning_rate = args.learning_rate 68 | epochs = args.epochs 69 | step_size = args.step_size 70 | gamma = args.gamma 71 | 72 | model_save_path = args.model_save_path 73 | model_save_name = args.model_save_name 74 | 75 | 76 | ################################################################ 77 | # models 78 | ################################################################ 79 | model = get_model(args, ckpt_dir=Path(ckpt_dir)/dataset_nickname/model_name/time_str) 80 | state_dict = torch.load(Path(ckpt_dir)/dataset_nickname/model_name/time_str/ (model_save_name[:-3]+f'_{milestone}.pt')) 81 | model.load_state_dict(state_dict) 82 | 83 | 84 | ################################################################ 85 | # load data and data normalization 86 | ################################################################ 87 | train_dataset = BoundedNSDataset(args, dataset_file=TRAIN_PATH, split='train', return_idx=True) 88 | test_dataset = BoundedNSDataset(args, dataset_file=TEST_PATH, split='test', return_idx=True) 89 | train_loader = train_dataset.loader() 90 | test_loader = test_dataset.loader() 91 | 92 | boundary, domain = process_boundary_condition(BOUNDARY_PATH, ds_rate=(r1,r2)) 93 | 94 | 95 | ################################################################ 96 | # evaluation 97 | ################################################################ 98 | myloss = LpLoss(size_average=False) 99 | 100 | step = 1 101 | min_test_l2_full = 114514 102 | 103 | t1 = default_timer() 104 | test_l2_step = 0 105 | test_l2_full = 0 106 | timewise_l2_step = torch.zeros(T_out//step).to(device) 107 | timewise_l2_full = torch.zeros(T_out//step).to(device) 108 | print('ready') 109 | with torch.no_grad(): 110 | for batch_idx, (index, xx, yy) in enumerate(tqdm(test_loader)): 111 | index = index.to(device) * 0.01 112 | loss = 0 113 | xx = xx.to(device) 114 | yy = yy.to(device) 115 | xx /= 256.0 116 | 117 | for i, t in enumerate(range(0, T_out, step)): 118 | y = yy[..., t*args.out_var : (t + step)*args.out_var] # B H W C_out=V_out 119 | im, _ = model(xx, index) 120 | im *= 256.0 121 | loss += myloss(im.reshape(batch_size, -1), 
y.reshape(batch_size, -1)) 122 | timewise_l2_step[i] += myloss(im, y) 123 | 124 | if t == 0: 125 | pred = im 126 | else: 127 | pred = torch.cat((pred, im), -1) 128 | 129 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 130 | 131 | test_l2_step += loss.item() 132 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 133 | for i, t in enumerate(range(0, T_out, step)): 134 | timewise_l2_full[i] += myloss(pred[..., i*args.in_var:(i+1)*args.in_var], yy[..., i*args.in_var:(i+1)*args.in_var]) 135 | 136 | t2 = default_timer() 137 | if test_l2_full / ntest < min_test_l2_full: 138 | print(t2 - t1, 139 | 'test_rel_l2:', 140 | test_l2_step / ntest / (T_out / step), 141 | test_l2_full / ntest, 142 | 'timewise_l2:', 143 | timewise_l2_step / ntest, 144 | timewise_l2_full / ntest) 145 | logger.info(f'{t2 - t1} ' + \ 146 | f'test_rel_l2: {test_l2_step / ntest / (T_out / step)} {test_l2_full / ntest} ' + \ 147 | f'timewise_l2: {timewise_l2_step / ntest} {timewise_l2_full / ntest}') 148 | pd = pred[-1, :, :, -1].detach().cpu().numpy() 149 | gt = yy[-1, :, :, -1].detach().cpu().numpy() 150 | visual(pd, os.path.join(test_save_path, f'{milestone}_pred.png')) 151 | visual(gt, os.path.join(test_save_path, f'{milestone}_gt.png')) 152 | visual(np.abs(gt-pd), os.path.join(test_save_path, f'{milestone}_err.png')) 153 | else: 154 | raise Exception('Abnormal loss!') 155 | -------------------------------------------------------------------------------- /test_sea_h.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import logging 4 | import pickle 5 | from timeit import default_timer 6 | from datetime import datetime, timedelta 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from utils.data_factory import SeaDataset, SeaDatasetMemory 11 | from utils.utilities3 import * 12 | from utils.params import get_args, get_test_args 13 | from utils.adam import Adam 14 | from model_dict import get_model 15 | 16 | from tqdm import tqdm 17 | 18 | torch.manual_seed(0) 19 | np.random.seed(0) 20 | torch.cuda.manual_seed(0) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | ################################################################ 25 | # configs 26 | ################################################################ 27 | test_args = get_test_args() 28 | ckpt_dir = test_args.ckpt_dir 29 | dataset_nickname = test_args.dataset_nickname 30 | model_name = test_args.model_name 31 | time_str = test_args.time_str 32 | milestone = test_args.milestone 33 | T_out = test_args.T_out 34 | 35 | args = get_args(cfg_file=Path(ckpt_dir)/dataset_nickname/model_name/time_str/'configs.txt') 36 | test_save_path = os.path.join(args.run_save_path, f'test_{milestone}_{T_out}') 37 | if not os.path.isdir(test_save_path): 38 | os.makedirs(test_save_path) 39 | 40 | LOG_FORMAT = "%(message)s" 41 | logger = logging.getLogger('Loss logger') 42 | logger.setLevel(logging.INFO) 43 | f_handler = logging.FileHandler(os.path.join(test_save_path, args.log_save_name)) 44 | f_handler.setLevel(logging.INFO) 45 | f_handler.setFormatter(logging.Formatter(LOG_FORMAT)) 46 | logger.addHandler(f_handler) 47 | 48 | padding = [int(p) for p in args.padding.split(',')] 49 | ntrain = args.ntrain 50 | ntest = args.ntest 51 | N = args.ntotal 52 | args.in_channels = args.in_dim * args.in_var 53 | args.out_channels = args.out_dim * args.out_var 54 | r1 = args.h_down 55 | r2 = args.w_down 56 | s1 = int(((args.h - 1) / r1) + 1) 57 | s2 
= int(((args.w - 1) / r2) + 1) 58 | T_in = args.T_in 59 | # T_out = args.T_out 60 | patch_size = tuple(int(x) for x in args.patch_size.split(',')) 61 | 62 | batch_size = args.batch_size 63 | learning_rate = args.learning_rate 64 | epochs = args.epochs 65 | step_size = args.step_size 66 | gamma = args.gamma 67 | 68 | model_save_path = args.model_save_path 69 | model_save_name = args.model_save_name 70 | 71 | 72 | ################################################################ 73 | # models 74 | ################################################################ 75 | model = get_model(args, ckpt_dir=Path(ckpt_dir)/dataset_nickname/model_name/time_str) 76 | state_dict = torch.load(Path(ckpt_dir)/dataset_nickname/model_name/time_str/ (model_save_name[:-3]+f'_{milestone}.pt')) 77 | model.load_state_dict(state_dict) 78 | 79 | 80 | ################################################################ 81 | # load data and data normalization 82 | ################################################################ 83 | train_dataset = SeaDatasetMemory(args, region=args.region, split='train') 84 | test_dataset = SeaDatasetMemory(args, region=args.region, split='test') 85 | train_loader = train_dataset.loader() 86 | test_loader = test_dataset.loader() 87 | 88 | land, sea = get_land_sea_mask(args.data_path, args.fill_value) 89 | if 'DeepLag' in args.model: 90 | model.set_bdydom(land, sea) 91 | if args.resample_strategy == 'uniform' or args.resample_strategy == 'learned': 92 | model.num_samples = min(model.num_samples, s1*s2) 93 | elif args.resample_strategy == 'boundary': 94 | model.num_samples = min(model.num_samples, model.coo_boundary_ms[0].shape[0]) 95 | elif args.resample_strategy == 'domain': 96 | model.num_samples = min(model.num_samples, model.coo_domain_ms[0].shape[0]) 97 | 98 | data_mean = np.load(Path(args.data_path)/'..'/f'sea_{args.region}_mean.npy') 99 | 100 | 101 | ################################################################ 102 | # evaluation 103 | ################################################################ 104 | myloss = LpLoss(size_average=False, channel_wise=False) 105 | mseloss = nn.MSELoss() 106 | 107 | step = 1 108 | min_test_l2_full = 114514 109 | 110 | t1 = default_timer() 111 | test_l2_step = 0 112 | test_l2_full = 0 113 | test_vor_step = 0 114 | test_vor_full = 0 115 | test_acc_step = torch.zeros(T_out//step).to(device) 116 | test_acc_full = torch.zeros(T_out//step).to(device) 117 | with torch.no_grad(): 118 | for batch_idx, (xx, yy) in enumerate(tqdm(test_loader)): 119 | loss = 0 120 | vor_loss = 0 121 | xx = xx.to(device) 122 | yy = yy.to(device) 123 | if 'DeepLag' in args.model: 124 | h_x_q, h_coo_q, h_coo_offset_q = [], [], [] 125 | for i in range(model.num_layers): 126 | if args.resample_strategy == 'uniform': 127 | num_samples = model.num_samples // (4**i) 128 | coo_q = torch.cat([ 129 | torch.randint(0,model.img_h_layers[i]-1,(batch_size,num_samples,1)), 130 | torch.randint(0,model.img_w_layers[i]-1,(batch_size,num_samples,1)) 131 | ], dim=-1).to(torch.float32) # b k 2 132 | elif args.resample_strategy == 'boundary': 133 | num_samples = min(model.num_samples//(4**i), model.coo_boundary_ms[i].shape[0]) 134 | idx_coo_sample = torch.multinomial(1./torch.ones(model.coo_boundary_ms[i].shape[0]), num_samples, replacement=False) # k 135 | coo_q = model.coo_boundary_ms[i][idx_coo_sample][None, ...].repeat(batch_size,1,1).to(torch.float32) # b k 2 136 | elif args.resample_strategy == 'domain': 137 | num_samples = min(model.num_samples//(4**i), model.coo_domain_ms[i].shape[0]) 138 
| idx_coo_sample = torch.multinomial(1./torch.ones(model.coo_domain_ms[i].shape[0]), num_samples, replacement=False) # k 139 | coo_q = model.coo_domain_ms[i][idx_coo_sample][None, ...].repeat(batch_size,1,1).to(torch.float32) # b k 2 140 | elif args.resample_strategy == 'learned': 141 | num_samples = model.num_samples // (4**i) 142 | coo_q = None # new_prob 143 | num_chan = args.d_model*(2**i) if i < model.num_layers-1 else args.d_model*(2**(i-1)) 144 | h_x_q.append(torch.zeros(batch_size, num_samples, num_chan).to(device)) 145 | h_coo_q.append(coo_q.to(device) if args.resample_strategy != 'learned' else None) # new_prob 146 | h_coo_offset_q.append(torch.zeros(batch_size, num_samples, 2).to(device)) 147 | 148 | for i, t in enumerate(range(0, T_out, step)): 149 | y = yy[..., t*args.out_var : (t + step)*args.out_var] # B H W C_out=V_out 150 | if 'DeepLag' in args.model: 151 | im, h_x_q, h_coo_q, h_coo_offset_q, coo_offset_xys = model(xx, h_x_q, h_coo_q, h_coo_offset_q) # B H W C_out=V_out # with coo_offset 152 | else: 153 | im = model(xx) 154 | loss += myloss(im, y) 155 | vor_loss += mseloss(vorticity(-im[..., -2], im[..., -3]), vorticity(-y[..., -2], y[..., -3])) 156 | test_acc_step[i] += correct_acc_loss(im, y, data_mean) 157 | 158 | if t == 0: 159 | pred = im 160 | else: 161 | pred = torch.cat((pred, im), -1) 162 | 163 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 164 | 165 | test_l2_step += loss.item() 166 | test_l2_full += myloss(pred, yy).item() 167 | test_vor_step += vor_loss.item() 168 | test_vor_full += mseloss(vorticity(-pred[..., 3::args.in_var], pred[..., 2::args.in_var]), vorticity(-yy[..., 3::args.in_var], yy[..., 2::args.in_var])).item() 169 | for i, t in enumerate(range(0, T_out, step)): 170 | test_acc_full[i] += correct_acc_loss(pred[..., i*args.in_var:(i+1)*args.in_var], yy[..., i*args.in_var:(i+1)*args.in_var], data_mean) 171 | 172 | t2 = default_timer() 173 | if test_l2_full / ntest < min_test_l2_full: 174 | print(t2 - t1, 175 | 'test_rel_l2:', 176 | test_l2_step / ntest / (T_out / step), 177 | test_l2_full / ntest, 178 | 'test_vor:', 179 | test_vor_step / ntest / (T_out / step), 180 | test_vor_full / ntest, 181 | 'test_acc:', 182 | test_acc_step / ntest, 183 | test_acc_full / ntest) 184 | logger.info(f'{t2 - t1} ' + \ 185 | f'test_rel_l2: {test_l2_step / ntest / (T_out / step)} {test_l2_full / ntest} ' + \ 186 | f'test_vor: {test_vor_step / ntest / (T_out / step)} {test_vor_full / ntest} ' + \ 187 | f'test_acc: {test_acc_step / ntest} {test_acc_full / ntest}') 188 | pd = pred[-1, :, :, -5:].detach().cpu().numpy() 189 | gt = yy[-1, :, :, -5:].detach().cpu().numpy() 190 | vars = ['thetao', 'so', 'uo', 'vo', 'zos'] 191 | for i in range(5): 192 | visual(pd[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_pred.png')) 193 | visual(gt[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_gt.png')) 194 | visual(np.abs(gt-pd)[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_err.png')) 195 | else: 196 | raise Exception('Abnormal loss!') 197 | -------------------------------------------------------------------------------- /test_sea_h_vortex.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import logging 4 | import pickle 5 | from timeit import default_timer 6 | from datetime import datetime, timedelta 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from utils.data_factory import SeaDataset, SeaDatasetMemory 11 | from utils.utilities3 
import * 12 | from utils.params import get_args, get_test_args 13 | from utils.adam import Adam 14 | from model_dict import get_model 15 | 16 | from tqdm import tqdm 17 | 18 | torch.manual_seed(0) 19 | np.random.seed(0) 20 | torch.cuda.manual_seed(0) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | ################################################################ 25 | # configs 26 | ################################################################ 27 | test_args = get_test_args() 28 | ckpt_dir = test_args.ckpt_dir 29 | dataset_nickname = test_args.dataset_nickname 30 | model_name = test_args.model_name 31 | time_str = test_args.time_str 32 | milestone = test_args.milestone 33 | T_out = test_args.T_out 34 | 35 | args = get_args(cfg_file=Path(ckpt_dir)/dataset_nickname/model_name/time_str/'configs.txt') 36 | test_save_path = os.path.join(args.run_save_path, f'test_{milestone}_{T_out}') 37 | if not os.path.isdir(test_save_path): 38 | os.makedirs(test_save_path) 39 | 40 | LOG_FORMAT = "%(message)s" 41 | logger = logging.getLogger('Loss logger') 42 | logger.setLevel(logging.INFO) 43 | f_handler = logging.FileHandler(os.path.join(test_save_path, args.log_save_name)) 44 | f_handler.setLevel(logging.INFO) 45 | f_handler.setFormatter(logging.Formatter(LOG_FORMAT)) 46 | logger.addHandler(f_handler) 47 | 48 | padding = [int(p) for p in args.padding.split(',')] 49 | ntrain = args.ntrain 50 | ntest = args.ntest 51 | N = args.ntotal 52 | args.in_channels = args.in_dim * args.in_var 53 | args.out_channels = args.out_dim * args.out_var 54 | r1 = args.h_down 55 | r2 = args.w_down 56 | s1 = int(((args.h - 1) / r1) + 1) 57 | s2 = int(((args.w - 1) / r2) + 1) 58 | T_in = args.T_in 59 | # T_out = args.T_out 60 | patch_size = tuple(int(x) for x in args.patch_size.split(',')) 61 | 62 | batch_size = args.batch_size 63 | learning_rate = args.learning_rate 64 | epochs = args.epochs 65 | step_size = args.step_size 66 | gamma = args.gamma 67 | 68 | model_save_path = args.model_save_path 69 | model_save_name = args.model_save_name 70 | 71 | 72 | ################################################################ 73 | # models 74 | ################################################################ 75 | model = get_model(args, ckpt_dir=Path(ckpt_dir)/dataset_nickname/model_name/time_str) 76 | state_dict = torch.load(Path(ckpt_dir)/dataset_nickname/model_name/time_str/ (model_save_name[:-3]+f'_{milestone}.pt')) 77 | model.load_state_dict(state_dict) 78 | 79 | 80 | ################################################################ 81 | # load data and data normalization 82 | ################################################################ 83 | train_dataset = SeaDatasetMemory(args, region=args.region, split='train', return_idx=True) 84 | test_dataset = SeaDatasetMemory(args, region=args.region, split='test', return_idx=True) 85 | train_loader = train_dataset.loader() 86 | test_loader = test_dataset.loader() 87 | 88 | land, sea = get_land_sea_mask(args.data_path, args.fill_value) 89 | 90 | data_mean = np.load(Path(args.data_path)/'..'/f'sea_{args.region}_mean.npy') 91 | 92 | 93 | ################################################################ 94 | # evaluation 95 | ################################################################ 96 | myloss = LpLoss(size_average=False, channel_wise=False) 97 | mseloss = nn.MSELoss() 98 | 99 | step = 1 100 | min_test_l2_full = 114514 101 | 102 | t1 = default_timer() 103 | test_l2_step = 0 104 | test_l2_full = 0 105 | test_vor_step = 0 106 | test_vor_full = 0 107 | test_acc_step = 
torch.zeros(T_out//step).to(device) 108 | test_acc_full = torch.zeros(T_out//step).to(device) 109 | with torch.no_grad(): 110 | for batch_idx, (index, xx, yy) in enumerate(tqdm(test_loader)): 111 | index = index.to(device) * 0.01 112 | loss = 0 113 | vor_loss = 0 114 | xx = xx.to(device) 115 | yy = yy.to(device) 116 | xx /= 256.0 117 | yy /= 256.0 118 | 119 | for i, t in enumerate(range(0, T_out, step)): 120 | y = yy[..., t*args.out_var : (t + step)*args.out_var] # B H W C_out=V_out 121 | im, _ = model(xx, index) 122 | im *= 256.0 123 | loss += myloss(im, y) 124 | vor_loss += mseloss(vorticity(-im[..., -2], im[..., -3]), vorticity(-y[..., -2], y[..., -3])) 125 | test_acc_step[i] += correct_acc_loss(im, y, data_mean) 126 | 127 | if t == 0: 128 | pred = im 129 | else: 130 | pred = torch.cat((pred, im), -1) 131 | 132 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 133 | 134 | test_l2_step += loss.item() 135 | test_l2_full += myloss(pred, yy).item() 136 | test_vor_step += vor_loss.item() 137 | test_vor_full += mseloss(vorticity(-pred[..., 3::args.in_var], pred[..., 2::args.in_var]), vorticity(-yy[..., 3::args.in_var], yy[..., 2::args.in_var])).item() 138 | for i, t in enumerate(range(0, T_out, step)): 139 | test_acc_full[i] += correct_acc_loss(pred[..., i*args.in_var:(i+1)*args.in_var], yy[..., i*args.in_var:(i+1)*args.in_var], data_mean) 140 | 141 | t2 = default_timer() 142 | if test_l2_full / ntest < min_test_l2_full: 143 | print(t2 - t1, 144 | 'test_rel_l2:', 145 | test_l2_step / ntest / (T_out / step), 146 | test_l2_full / ntest, 147 | 'test_vor:', 148 | test_vor_step / ntest / (T_out / step), 149 | test_vor_full / ntest, 150 | 'test_acc:', 151 | test_acc_step / ntest, 152 | test_acc_full / ntest) 153 | logger.info(f'{t2 - t1} ' + \ 154 | f'test_rel_l2: {test_l2_step / ntest / (T_out / step)} {test_l2_full / ntest} ' + \ 155 | f'test_vor: {test_vor_step / ntest / (T_out / step)} {test_vor_full / ntest} ' + \ 156 | f'test_acc: {test_acc_step / ntest} {test_acc_full / ntest}') 157 | pd = pred[-1, :, :, -5:].detach().cpu().numpy() 158 | gt = yy[-1, :, :, -5:].detach().cpu().numpy() 159 | vars = ['thetao', 'so', 'uo', 'vo', 'zos'] 160 | for i in range(5): 161 | visual(pd[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_pred.png')) 162 | visual(gt[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_gt.png')) 163 | visual(np.abs(gt-pd)[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_err.png')) 164 | else: 165 | raise Exception('Abnormal loss!') 166 | -------------------------------------------------------------------------------- /test_smoke_h.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import logging 4 | import pickle 5 | from timeit import default_timer 6 | from datetime import datetime, timedelta 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | from utils.data_factory import SmokeDataset, SmokeDatasetMemory 11 | from utils.utilities3 import * 12 | from utils.params import get_args, get_test_args 13 | from utils.adam import Adam 14 | from model_dict import get_model 15 | 16 | from tqdm import tqdm 17 | 18 | torch.manual_seed(0) 19 | np.random.seed(0) 20 | torch.cuda.manual_seed(0) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | ################################################################ 25 | # configs 26 | ################################################################ 27 | test_args = get_test_args() 28 | ckpt_dir = 
test_args.ckpt_dir 29 | dataset_nickname = test_args.dataset_nickname 30 | model_name = test_args.model_name 31 | time_str = test_args.time_str 32 | milestone = test_args.milestone 33 | T_out = test_args.T_out 34 | 35 | args = get_args(cfg_file=Path(ckpt_dir)/dataset_nickname/model_name/time_str/'configs.txt') 36 | test_save_path = os.path.join(args.run_save_path, f'test_{milestone}_{T_out}') 37 | if not os.path.isdir(test_save_path): 38 | os.makedirs(test_save_path) 39 | 40 | LOG_FORMAT = "%(message)s" 41 | logger = logging.getLogger('Loss logger') 42 | logger.setLevel(logging.INFO) 43 | f_handler = logging.FileHandler(os.path.join(test_save_path, args.log_save_name)) 44 | f_handler.setLevel(logging.INFO) 45 | f_handler.setFormatter(logging.Formatter(LOG_FORMAT)) 46 | logger.addHandler(f_handler) 47 | 48 | padding = [int(p) for p in args.padding.split(',')] 49 | ntrain = args.ntrain 50 | ntest = args.ntest 51 | N = args.ntotal 52 | args.in_channels = args.in_dim * args.in_var 53 | args.out_channels = args.out_dim * args.out_var 54 | r1 = args.z_down 55 | r2 = args.h_down 56 | r3 = args.w_down 57 | s1 = int(((args.z - 1) / r1) + 1) 58 | s2 = int(((args.h - 1) / r2) + 1) 59 | s3 = int(((args.w - 1) / r3) + 1) 60 | T_in = args.T_in 61 | # T_out = args.T_out 62 | patch_size = tuple(int(x) for x in args.patch_size.split(',')) 63 | 64 | batch_size = args.batch_size 65 | learning_rate = args.learning_rate 66 | epochs = args.epochs 67 | step_size = args.step_size 68 | gamma = args.gamma 69 | 70 | model_save_path = args.model_save_path 71 | model_save_name = args.model_save_name 72 | 73 | 74 | ################################################################ 75 | # models 76 | ################################################################ 77 | model = get_model(args, ckpt_dir=Path(ckpt_dir)/dataset_nickname/model_name/time_str) 78 | state_dict = torch.load(Path(ckpt_dir)/dataset_nickname/model_name/time_str/ (model_save_name[:-3]+f'_{milestone}.pt')) 79 | model.load_state_dict(state_dict) 80 | 81 | 82 | ################################################################ 83 | # load data and data normalization 84 | ################################################################ 85 | train_dataset = SmokeDatasetMemory(args, split='train') 86 | test_dataset = SmokeDatasetMemory(args, split='test') 87 | train_loader = train_dataset.loader() 88 | test_loader = test_dataset.loader() 89 | 90 | boundary = torch.ones(s1, s2, s3) 91 | boundary[1:-1, 1:-1, 1:-1] = 0 92 | domain = 1 - boundary.clone().detach() 93 | if 'DeepLag' in args.model: 94 | model.set_bdydom(boundary, domain) 95 | if args.resample_strategy == 'uniform' or args.resample_strategy == 'learned': 96 | model.num_samples = min(model.num_samples, s1*s2*s3//model.pixel_per_patch) 97 | elif args.resample_strategy == 'boundary': 98 | model.num_samples = min(model.num_samples, model.coo_boundary_ms[0].shape[0]) 99 | elif args.resample_strategy == 'domain': 100 | model.num_samples = min(model.num_samples, model.coo_domain_ms[0].shape[0]) 101 | 102 | 103 | ################################################################ 104 | # evaluation 105 | ################################################################ 106 | myloss = LpLoss(size_average=False, channel_wise=False) 107 | mseloss = nn.MSELoss() 108 | 109 | step = 1 110 | min_test_l2_full = 114514 111 | 112 | t1 = default_timer() 113 | test_l2_step = 0 114 | test_l2_full = 0 115 | test_vor_step = 0 116 | test_vor_full = 0 117 | with torch.no_grad(): 118 | for batch_idx, (xx, yy) in 
enumerate(tqdm(test_loader)): 119 | loss = 0 120 | vor_loss = 0 121 | xx = xx.to(device) # B Z H W T*C 122 | yy = yy.to(device) # B Z H W T*C 123 | if 'DeepLag' in args.model: 124 | h_x_q, h_coo_q, h_coo_offset_q = [], [], [] 125 | for i in range(model.num_layers): 126 | if args.resample_strategy == 'uniform': 127 | num_samples = model.num_samples // (8**i) 128 | coo_q = torch.cat([ 129 | torch.randint(0,model.img_z_layers[i]-1,(batch_size,num_samples,1)), 130 | torch.randint(0,model.img_h_layers[i]-1,(batch_size,num_samples,1)), 131 | torch.randint(0,model.img_w_layers[i]-1,(batch_size,num_samples,1)) 132 | ], dim=-1).to(torch.float32) # b k 3 133 | elif args.resample_strategy == 'boundary': 134 | num_samples = min(model.num_samples//(8**i), model.coo_boundary_ms[i].shape[0]) 135 | idx_coo_sample = torch.multinomial(1./torch.ones(model.coo_boundary_ms[i].shape[0]), num_samples, replacement=False) # k 136 | coo_q = model.coo_boundary_ms[i][idx_coo_sample][None, ...].repeat(batch_size,1,1).to(torch.float32) # b k 3 137 | elif args.resample_strategy == 'domain': 138 | num_samples = min(model.num_samples//(8**i), model.coo_domain_ms[i].shape[0]) 139 | idx_coo_sample = torch.multinomial(1./torch.ones(model.coo_domain_ms[i].shape[0]), num_samples, replacement=False) # k 140 | coo_q = model.coo_domain_ms[i][idx_coo_sample][None, ...].repeat(batch_size,1,1).to(torch.float32) # b k 3 141 | elif args.resample_strategy == 'learned': 142 | num_samples = model.num_samples // (8**i) 143 | coo_q = None 144 | num_chan = args.d_model*(2**i) if i < model.num_layers-1 else args.d_model*(2**(i-1)) 145 | h_x_q.append(torch.zeros(batch_size, num_samples, num_chan).to(device)) 146 | h_coo_q.append(coo_q.to(device) if args.resample_strategy != 'learned' else None) 147 | h_coo_offset_q.append(torch.zeros(batch_size, num_samples, 3).to(device)) 148 | 149 | for t in range(0, T_out, step): 150 | y = yy[..., t*args.out_var : (t + step)*args.out_var] # B Z H W C_out=V_out 151 | if 'DeepLag' in args.model: 152 | im, h_x_q, h_coo_q, h_coo_offset_q, coo_offset_zxys = model(xx, h_x_q, h_coo_q, h_coo_offset_q) # B Z H W C_out=V_out 153 | else: 154 | im = model(xx) 155 | loss += myloss(im.reshape(batch_size, -1), y.reshape(batch_size, -1)) 156 | vor_loss += mseloss(vorticity_3d(im[..., -3], im[..., -2], im[..., -1]), vorticity_3d(y[..., -3], y[..., -2], y[..., -1])) 157 | 158 | if t == 0: 159 | pred = im 160 | else: 161 | pred = torch.cat((pred, im), -1) 162 | 163 | xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) 164 | 165 | test_l2_step += loss.item() 166 | test_l2_full += myloss(pred.reshape(batch_size, -1), yy.reshape(batch_size, -1)).item() 167 | test_vor_step += vor_loss.item() 168 | test_vor_full += mseloss(vorticity_3d(pred[..., 1::args.in_var], pred[..., 2::args.in_var], pred[..., 3::args.in_var]), vorticity_3d(yy[..., 1::args.in_var], yy[..., 2::args.in_var], yy[..., 3::args.in_var])).item() 169 | 170 | t2 = default_timer() 171 | if test_l2_full / ntest < min_test_l2_full: 172 | print(t2 - t1, 173 | 'test_rel_l2:', 174 | test_l2_step / ntest / (T_out / step), 175 | test_l2_full / ntest, 176 | 'test_vor:', 177 | test_vor_step / ntest / (T_out / step), 178 | test_vor_full / ntest) 179 | logger.info(f'{t2 - t1} ' + \ 180 | f'test_rel_l2: {test_l2_step / ntest / (T_out / step)} {test_l2_full / ntest} ' + \ 181 | f'test_vor: {test_vor_step / ntest / (T_out / step)} {test_vor_full / ntest}') 182 | pd = pred[-1, :, :, :, -4:].detach().cpu().numpy() 183 | gt = yy[-1, :, :, :, -4:].detach().cpu().numpy() 184 |
vars = ['field', 'ux', 'uy', 'uz'] 185 | for i in range(4): 186 | visual_zoy(pd[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_pred.png')) 187 | visual_zoy(gt[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_gt.png')) 188 | visual_zoy(np.abs(gt-pd)[...,i], os.path.join(test_save_path, f'{milestone}_{vars[i]}_err.png')) 189 | else: 190 | raise Exception('Abnormal loss!') 191 | -------------------------------------------------------------------------------- /utils/adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from typing import List, Optional 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | def adam(params: List[Tensor], 9 | grads: List[Tensor], 10 | exp_avgs: List[Tensor], 11 | exp_avg_sqs: List[Tensor], 12 | max_exp_avg_sqs: List[Tensor], 13 | state_steps: List[int], 14 | *, 15 | amsgrad: bool, 16 | beta1: float, 17 | beta2: float, 18 | lr: float, 19 | weight_decay: float, 20 | eps: float): 21 | r"""Functional API that performs Adam algorithm computation. 22 | See :class:`~torch.optim.Adam` for details. 23 | """ 24 | 25 | for i, param in enumerate(params): 26 | 27 | grad = grads[i] 28 | exp_avg = exp_avgs[i] 29 | exp_avg_sq = exp_avg_sqs[i] 30 | step = state_steps[i] 31 | 32 | bias_correction1 = 1 - beta1 ** step 33 | bias_correction2 = 1 - beta2 ** step 34 | 35 | if weight_decay != 0: 36 | grad = grad.add(param, alpha=weight_decay) 37 | 38 | # Decay the first and second moment running average coefficient 39 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 40 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) 41 | if amsgrad: 42 | # Maintains the maximum of all 2nd moment running avg. till now 43 | torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 44 | # Use the max. for normalizing running avg. of gradient 45 | denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) 46 | else: 47 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 48 | 49 | step_size = lr / bias_correction1 50 | 51 | param.addcdiv_(exp_avg, denom, value=-step_size) 52 | 53 | 54 | class Adam(Optimizer): 55 | r"""Implements Adam algorithm. 56 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 57 | The implementation of the L2 penalty follows changes proposed in 58 | `Decoupled Weight Decay Regularization`_. 59 | Args: 60 | params (iterable): iterable of parameters to optimize or dicts defining 61 | parameter groups 62 | lr (float, optional): learning rate (default: 1e-3) 63 | betas (Tuple[float, float], optional): coefficients used for computing 64 | running averages of gradient and its square (default: (0.9, 0.999)) 65 | eps (float, optional): term added to the denominator to improve 66 | numerical stability (default: 1e-8) 67 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 68 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 69 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 70 | (default: False) 71 | .. _Adam\: A Method for Stochastic Optimization: 72 | https://arxiv.org/abs/1412.6980 73 | .. _Decoupled Weight Decay Regularization: 74 | https://arxiv.org/abs/1711.05101 75 | .. 
_On the Convergence of Adam and Beyond: 76 | https://openreview.net/forum?id=ryQu7f-RZ 77 | """ 78 | 79 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 80 | weight_decay=0, amsgrad=False): 81 | if not 0.0 <= lr: 82 | raise ValueError("Invalid learning rate: {}".format(lr)) 83 | if not 0.0 <= eps: 84 | raise ValueError("Invalid epsilon value: {}".format(eps)) 85 | if not 0.0 <= betas[0] < 1.0: 86 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 87 | if not 0.0 <= betas[1] < 1.0: 88 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 89 | if not 0.0 <= weight_decay: 90 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 91 | defaults = dict(lr=lr, betas=betas, eps=eps, 92 | weight_decay=weight_decay, amsgrad=amsgrad) 93 | super(Adam, self).__init__(params, defaults) 94 | 95 | def __setstate__(self, state): 96 | super(Adam, self).__setstate__(state) 97 | for group in self.param_groups: 98 | group.setdefault('amsgrad', False) 99 | 100 | @torch.no_grad() 101 | def step(self, closure=None): 102 | """Performs a single optimization step. 103 | Args: 104 | closure (callable, optional): A closure that reevaluates the model 105 | and returns the loss. 106 | """ 107 | loss = None 108 | if closure is not None: 109 | with torch.enable_grad(): 110 | loss = closure() 111 | 112 | for group in self.param_groups: 113 | params_with_grad = [] 114 | grads = [] 115 | exp_avgs = [] 116 | exp_avg_sqs = [] 117 | max_exp_avg_sqs = [] 118 | state_steps = [] 119 | beta1, beta2 = group['betas'] 120 | 121 | for p in group['params']: 122 | if p.grad is not None: 123 | params_with_grad.append(p) 124 | if p.grad.is_sparse: 125 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 126 | grads.append(p.grad) 127 | 128 | state = self.state[p] 129 | # Lazy state initialization 130 | if len(state) == 0: 131 | state['step'] = 0 132 | # Exponential moving average of gradient values 133 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 134 | # Exponential moving average of squared gradient values 135 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 136 | if group['amsgrad']: 137 | # Maintains max of all exp. moving avg. of sq. grad. 
values 138 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 139 | 140 | exp_avgs.append(state['exp_avg']) 141 | exp_avg_sqs.append(state['exp_avg_sq']) 142 | 143 | if group['amsgrad']: 144 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 145 | 146 | # update the step count for each param in this group 147 | state['step'] += 1 148 | # record the step after the update 149 | state_steps.append(state['step']) 150 | 151 | adam(params_with_grad, 152 | grads, 153 | exp_avgs, 154 | exp_avg_sqs, 155 | max_exp_avg_sqs, 156 | state_steps, 157 | amsgrad=group['amsgrad'], 158 | beta1=beta1, 159 | beta2=beta2, 160 | lr=group['lr'], 161 | weight_decay=group['weight_decay'], 162 | eps=group['eps']) 163 | return loss 164 | -------------------------------------------------------------------------------- /utils/split_merge_npy_file.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | 5 | def split_npy_file(input_file, output_dir, chunk_size=1000): 6 | # create output dir 7 | if not os.path.exists(output_dir): 8 | os.makedirs(output_dir) 9 | 10 | # read large .npy 11 | filename = os.path.basename(input_file) 12 | large_array = np.load(input_file) 13 | n, h, w = large_array.shape 14 | 15 | # split by chunk_size and save 16 | for i in range(0, n, chunk_size): 17 | split_array = large_array[i:i+chunk_size] 18 | output_file = os.path.join(output_dir, f'{filename[:-4]}_split_{i // chunk_size}.npy') 19 | np.save(output_file, split_array) 20 | print(f'Saved {output_file}') 21 | 22 | 23 | def merge_npy_files(input_dir, output_file): 24 | # collect all split .npy files in the dir, sorted by name (lexicographic: zero-pad chunk indices if there are 10 or more) 25 | split_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.npy') and 'split' in f]) 26 | 27 | # read files and concat arrays 28 | arrays = [np.load(f) for f in split_files] 29 | merged_array = np.concatenate(arrays, axis=0) 30 | 31 | # save merged file 32 | np.save(output_file, merged_array) 33 | print(f'Saved merged file as {output_file}') 34 | 35 | 36 | # example usage 37 | large_npy_file = '/home/miaoshangchen/NAS/Bounded_NS/re0.25_4c_gray_10000.npy' 38 | split_save_dir = os.path.dirname(large_npy_file) 39 | # split_npy_file(large_npy_file, split_save_dir, chunk_size=2500) 40 | merge_npy_files(split_save_dir, large_npy_file) 41 | --------------------------------------------------------------------------------
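The test scripts above all share one autoregressive rollout pattern: the model predicts a single step, step-wise and full-trajectory relative L2 errors are accumulated, and the prediction is slid into the input window in place of the oldest frame. Below is a minimal, self-contained sketch of that loop. The toy persistence model, the tensor shapes, the channel counts, and the simplified rel_l2 (a stand-in for the scripts' LpLoss) are illustrative assumptions, not values taken from this repository.

import torch

B, H, W = 2, 64, 64                 # batch, height, width (hypothetical)
T_in, T_out, step = 10, 10, 1       # input window length, rollout length, stride
in_var = out_var = 1                # variables (channels) per time step

def model(x):                       # toy stand-in: persistence forecast
    return x[..., -out_var:]

def rel_l2(a, b):                   # simplified relative L2 error
    return (torch.norm(a - b) / torch.norm(b)).item()

xx = torch.randn(B, H, W, T_in * in_var)    # input window, channels-last
yy = torch.randn(B, H, W, T_out * out_var)  # ground-truth rollout

loss_step = 0.0
for t in range(0, T_out, step):
    y = yy[..., t * out_var:(t + step) * out_var]  # target frame(s) for this step
    im = model(xx)                                 # one-step prediction
    loss_step += rel_l2(im.reshape(B, -1), y.reshape(B, -1))
    pred = im if t == 0 else torch.cat((pred, im), dim=-1)
    # slide the window: drop the oldest frame, append the newest prediction
    xx = torch.cat((xx[..., step * in_var:], im), dim=-1)

loss_full = rel_l2(pred.reshape(B, -1), yy.reshape(B, -1))
print(loss_step / (T_out / step), loss_full)

With step == 1 and in_var == out_var, the window update reduces to dropping the oldest frame and appending the newest prediction, which is exactly the xx = torch.cat((xx[..., step*args.in_var:], im), dim=-1) line that appears in every test script; feeding predictions back in this way is what makes long-rollout error accumulate, and it is what the timewise and full-trajectory metrics in these scripts measure.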