├── utils ├── ads.npy ├── lambda.npy ├── sweeps.py ├── domains.py ├── loss_utils.py ├── optimizer_utils.py ├── logging_utils.py ├── YParams.py ├── get_scale.ipynb ├── data_utils.py ├── misc_utils.py ├── gen_data_helmholtz.py ├── gen_data_poisson.py ├── inferencer.py ├── gen_data_advdiff.py └── trainer.py ├── assets └── overview.png ├── requirements.txt ├── export_DDP_vars.sh ├── config ├── sweep_config.yaml ├── operators_ad.yaml ├── operators_poisson.yaml └── operators_helmholtz.yaml ├── run_gen_data.sh ├── LICENSE ├── run.sh ├── train.py ├── models ├── ffn.py ├── fno.py └── basics.py ├── eval.py ├── .gitignore └── README.md /utils/ads.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShashankSubramanian/neuraloperators-TL-scaling/HEAD/utils/ads.npy -------------------------------------------------------------------------------- /utils/lambda.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShashankSubramanian/neuraloperators-TL-scaling/HEAD/utils/lambda.npy -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShashankSubramanian/neuraloperators-TL-scaling/HEAD/assets/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==3.7.0 2 | matplotlib==3.5.2 3 | mpi4py==3.1.3 4 | numpy==1.22.3 5 | ruamel.yaml==0.17.21 6 | scipy==1.6.3 7 | torch==1.12.0 8 | torchvision==0.13.0 9 | wandb==0.15.3 10 | -------------------------------------------------------------------------------- /export_DDP_vars.sh: -------------------------------------------------------------------------------- 1 | export RANK=$SLURM_PROCID 2 | export 
WORLD_RANK=$SLURM_PROCID 3 | export LOCAL_RANK=$SLURM_LOCALID 4 | export WORLD_SIZE=$SLURM_NTASKS 5 | export MASTER_PORT=29500 # default from torch launcher 6 | export WANDB_START_METHOD="thread" 7 | -------------------------------------------------------------------------------- /config/sweep_config.yaml: -------------------------------------------------------------------------------- 1 | name: name_of_sweep 2 | entity: your_wb_entity 3 | project: your_wb_project 4 | program: dummy 5 | method: grid 6 | metric: 7 | name: val_err 8 | goal: minimize 9 | parameters: # total 65 jobs 10 | lr: # sample these lrs 11 | values: [1E-5, 5E-5, 1E-4, 5E-4, 1E-3] 12 | subsample: # subsample downstream dataset by this amount 13 | values: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] 14 | plot_figs: # don't plot figs in W&B 15 | value: !!bool False 16 | -------------------------------------------------------------------------------- /utils/sweeps.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from decimal import Decimal 3 | 4 | def format_lr(lr): 5 | sci = '%.0E'%lr 6 | return sci.replace('E-0', 'em') # assumes 1e-10 < lr < 1e0 7 | 8 | def sweep_name_suffix(params, sweep_id): 9 | ''' 10 | Return a unique name for each sweep trial in a given sweep 11 | Allows custom naming of sweep trials, e.g.
according to trial hyperparams 12 | Naming scheme is chosen based on the given sweep id and trial parameters 13 | ''' 14 | if sweep_id in ['jponn9sj']: 15 | return 'lr%s_s%d'%(format_lr(params.lr), params.subsample) 16 | else: 17 | dt = datetime.now() 18 | return dt.strftime("%Y%m%d-%H-%M-%S") 19 | -------------------------------------------------------------------------------- /utils/domains.py: -------------------------------------------------------------------------------- 1 | ''' 2 | domain classes 3 | ''' 4 | import torch 5 | import random 6 | import numpy as np 7 | from utils.misc_utils import normalize, softmax, show 8 | import matplotlib.pyplot as plt 9 | 10 | class DomainXY(): 11 | """ 12 | Creates a uniform grid of 2D spatial points 13 | """ 14 | def __init__(self, params): 15 | self.params = params 16 | dx = params.Lx / params.nx 17 | dy = params.Ly / params.ny 18 | self.dx = dx 19 | self.dy = dy 20 | self.x = np.arange(0, params.Lx, dx) 21 | self.y = np.arange(0, params.Ly, dy) 22 | x_g, y_g = np.meshgrid(self.x, self.y) 23 | self.x_g, self.y_g = x_g, y_g 24 | self.grid = np.column_stack((x_g.flatten(), y_g.flatten())) 25 | 26 | -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | loss functions 3 | """ 4 | import torch 5 | import logging 6 | import numpy as np 7 | import time 8 | import torchvision 9 | 10 | 11 | class LossMSE(): 12 | """ mse loss """ 13 | def __init__(self, params, model): 14 | self.params = params 15 | self.model = model 16 | 17 | def data(self, inputs, pred, target): 18 | if self.params.loss_style == 'mean': 19 | loss = torch.mean((target - pred)**2) 20 | elif self.params.loss_style == 'sum': 21 | loss = torch.sum((target - pred)**2)/pred.shape[0] 22 | return loss 23 | 24 | def bc(self, inputs, pred, targets): 25 | # currently no BC 26 | return torch.tensor(0.).to(self.params.device, 
dtype=torch.float32) 27 | 28 | def pde(self, inputs, pred, targets): 29 | # currently no PDE loss 30 | return torch.tensor(0.).to(self.params.device, dtype=torch.float32) 31 | 32 | -------------------------------------------------------------------------------- /utils/optimizer_utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import logging 3 | import torch 4 | import torch.optim as optim 5 | from torch.optim import lr_scheduler 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau 8 | 9 | def set_scheduler(args, optimizer): 10 | """ set the lr scheduler """ 11 | if args.scheduler == 'reducelr': 12 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=args.patience, verbose=True, min_lr=1e-3*1e-5, factor=0.2) 13 | elif args.scheduler == 'cosine': 14 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_cosine_lr_epochs) 15 | else: 16 | scheduler = None 17 | return scheduler 18 | 19 | def set_optimizer(args, net): 20 | """ set the optimizer """ 21 | if args.optimizer == "adam": 22 | optimizer = optim.Adam(net.parameters(), lr=args.lr) 23 | elif args.optimizer == "sgd": 24 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9) 25 | return optimizer 26 | 27 | -------------------------------------------------------------------------------- /run_gen_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ntrain=32768 4 | nval=4096 5 | ntest=4096 6 | ng=144 7 | datapath='/path/to/data/poisson' 8 | 9 | e1=1 # Poisson diffusion eigenvalue range 10 | e2=5 11 | 12 | adr1=0.2 # advection to diffusion ratio range 13 | adr2=1 14 | # for AD ratio we saved a set of velocity scales that correspond to AD ratio in utils/*.npy.
See python script for details 15 | 16 | o1=1 # helmholtz wave number range 17 | o2=10 18 | 19 | # create poissons examples 20 | python utils/gen_data_poisson.py --ntrain=$ntrain --nval=$nval --ntest=$ntest \ 21 | --ng=$ng --sparse --n 128 --datapath $datapath --e1 $e1 --e2 $e2 22 | # create AD examples 23 | #python utils/gen_data_advdiff.py --ntrain=$ntrain --nval=$nval --ntest=$ntest \ 24 | # --ng=$ng --sparse --n 128 --datapath $datapath --adr1 $adr1 --adr2 $adr2 25 | # create Helm examples 26 | #python utils/gen_data_helmholtz.py --ntrain=$ntrain --nval=$nval --ntest=$ntest \ 27 | # --ng=$ng --sparse --n 128 --datapath $datapath --o1 $o1 --o2 $o2 28 | 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shashank Subramanian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | _format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 5 | 6 | def config_logger(log_level=logging.INFO): 7 | logging.basicConfig(format=_format, level=log_level) 8 | 9 | def log_to_file(logger_name=None, log_level=logging.INFO, log_filename='tensorflow.log'): 10 | 11 | if not os.path.exists(os.path.dirname(log_filename)): 12 | os.makedirs(os.path.dirname(log_filename)) 13 | 14 | if logger_name is not None: 15 | log = logging.getLogger(logger_name) 16 | else: 17 | log = logging.getLogger() 18 | 19 | fh = logging.FileHandler(log_filename) 20 | fh.setLevel(log_level) 21 | fh.setFormatter(logging.Formatter(_format)) 22 | log.addHandler(fh) 23 | 24 | def log_versions(): 25 | import torch 26 | import subprocess 27 | 28 | logging.info('--------------- Versions ---------------') 29 | logging.info('git branch: ' + str(subprocess.check_output(['git', 'branch']).strip())) 30 | logging.info('git hash: ' + str(subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip())) 31 | logging.info('Torch: ' + str(torch.__version__)) 32 | logging.info('----------------------------------------') 33 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # this run script assumes slurm job scheduler, but similar run cmd can be used elsewhere 4 | # for DDP (slurm vars setting) 5 | export MASTER_ADDR=$(hostname) 6 | 7 | # number of gpus 8 
| ngpu=4 9 | 10 | # yaml file 11 | config_file=./config/operators_poisson.yaml 12 | # config name to run 13 | config="poisson-scale-k1_5" 14 | # sub run number 15 | run_num="test" 16 | 17 | # where to store results 18 | scratch="/path/to/results/" 19 | 20 | # run command 21 | cmd="python train.py --yaml_config=$config_file --config=$config --run_num=$run_num --root_dir=$scratch" 22 | 23 | # source DDP vars first for data-parallel training (if not srun, just source and then run cmd; see pytorch docs for DDP) 24 | srun -l -n $ngpu --cpus-per-task=10 --gpus-per-node $ngpu bash -c "source export_DDP_vars.sh && $cmd" 25 | 26 | # for inference run the following commands to use eval.py (single gpu is sufficient, no logging by default) 27 | # pass the model weights to use 28 | #weights_for_inference=$scratch/expts/$config/$run_num/checkpoints/ckpt_best.tar 29 | #cmd_inf="python eval.py --yaml_config=$config_file --config=$config --run_num=$run_num --root_dir=$scratch --weights=$weights_for_inference" 30 | #bash -c "$cmd_inf" 31 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | import argparse 3 | import torch 4 | import wandb 5 | import matplotlib.pyplot as plt 6 | import logging 7 | import torch.distributed as dist 8 | from utils import logging_utils 9 | logging_utils.config_logger() 10 | from utils.YParams import YParams 11 | from utils.trainer import Trainer 12 | 13 | if __name__ == '__main__': 14 | # parsers 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--yaml_config", default='./config/operators.yaml', type=str) 17 | parser.add_argument("--config", default='default', type=str) 18 | parser.add_argument("--root_dir", default='./', type=str, help='root dir to store results') 19 | parser.add_argument("--run_num", default='0', type=str, help='sub run config') 20 | parser.add_argument("--sweep_id", 
default=None, type=str, help='sweep config from ./configs/sweeps.yaml') 21 | args = parser.parse_args() 22 | params = YParams(os.path.abspath(args.yaml_config), args.config) 23 | trainer = Trainer(params, args) 24 | 25 | if args.sweep_id and trainer.world_rank==0: 26 | logging.disable(logging.CRITICAL) 27 | wandb.agent(args.sweep_id, function=trainer.launch, count=1, entity=trainer.params.entity, project=trainer.params.project) 28 | else: 29 | trainer.launch() 30 | 31 | if dist.is_initialized(): 32 | dist.barrier() 33 | 34 | logging.info('DONE') 35 | -------------------------------------------------------------------------------- /models/ffn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from utils.misc_utils import set_activation 5 | 6 | class FeedForward(nn.Module): 7 | ''' An n-layer-feed-forward-layer module ''' 8 | def __init__(self, in_dim=2, out_dim=1, depth=5, hidden_dim=50, activation='tanh'): 9 | super().__init__() 10 | self.depth = depth 11 | self.activation = set_activation(activation) 12 | self.ff_in = nn.Linear(in_dim, hidden_dim) 13 | self.linears = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for i in range(self.depth-2)]) 14 | self.ff_out = nn.Linear(hidden_dim, out_dim) 15 | self.apply(self._init_weights) 16 | 17 | def _init_weights(self, m): 18 | ''' Xavier Normal Initialization ''' 19 | if isinstance(m, nn.Linear): 20 | nn.init.xavier_normal_(m.weight.data, gain=1.0) 21 | if isinstance(m, nn.Linear) and m.bias is not None: 22 | nn.init.zeros_(m.bias.data) 23 | 24 | def forward(self, x): 25 | x = self.ff_in(x) 26 | x = self.activation(x) 27 | for i in range(self.depth-2): 28 | x = self.linears[i](x) 29 | x = self.activation(x) 30 | x = self.ff_out(x) 31 | return x 32 | 33 | def ffn_pinns(params): 34 | return FeedForward(in_dim=params.in_dim, out_dim=params.out_dim, 35 | depth=params.depth, 36 | hidden_dim=params.hidden_dim, 37 | 
activation='tanh') 38 | -------------------------------------------------------------------------------- /utils/YParams.py: -------------------------------------------------------------------------------- 1 | from ruamel.yaml import YAML 2 | import logging 3 | 4 | class YParams(): 5 | """ Yaml file parser """ 6 | def __init__(self, yaml_filename, config_name, print_params=False): 7 | self._yaml_filename = yaml_filename 8 | self._config_name = config_name 9 | self.params = {} 10 | 11 | if print_params: 12 | print("------------------ Configuration ------------------") 13 | 14 | with open(yaml_filename) as _file: 15 | 16 | for key, val in YAML().load(_file)[config_name].items(): 17 | if print_params: print(key, val) 18 | if val =='None': val = None 19 | 20 | self.params[key] = val 21 | self.__setattr__(key, val) 22 | 23 | if print_params: 24 | print("---------------------------------------------------") 25 | 26 | def __getitem__(self, key): 27 | return self.params[key] 28 | 29 | def __setitem__(self, key, val): 30 | self.params[key] = val 31 | self.__setattr__(key, val) 32 | 33 | def __contains__(self, key): 34 | return (key in self.params) 35 | 36 | def update_params(self, config): 37 | for key, val in config.items(): 38 | self.params[key] = val 39 | self.__setattr__(key, val) 40 | 41 | def log(self): 42 | logging.info("------------------ Configuration ------------------") 43 | logging.info("Configuration file: "+str(self._yaml_filename)) 44 | logging.info("Configuration name: "+str(self._config_name)) 45 | for key, val in self.params.items(): 46 | logging.info(str(key) + ' ' + str(val)) 47 | logging.info("---------------------------------------------------") 48 | -------------------------------------------------------------------------------- /config/operators_ad.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | num_data_workers: 1 3 | # model 4 | model: 'fno' 5 | depth: 5 6 | in_dim: 2 7 | out_dim: 1 8 
| dropout: 0 9 | # data/domain 10 | Lx: !!float 1.0 11 | Ly: !!float 1.0 12 | nx: 256 13 | ny: 256 14 | # optimization 15 | loss_style: 'mean' 16 | loss_func: 'mse' 17 | optimizer: 'adam' 18 | scheduler: 'none' 19 | lr: !!float 1.0 20 | max_epochs: 500 21 | max_cosine_lr_epochs: 500 22 | batch_size: 25 23 | # misc 24 | log_to_screen: !!bool True 25 | save_checkpoint: !!bool False 26 | seed: 0 27 | plot_figs: !!bool False 28 | pack_data: !!bool False 29 | # Weights & Biases 30 | # Weights & Biases 31 | entity: 'your_wandb_entity' 32 | project: 'your_wandb_project' 33 | distill: !!bool False 34 | subsample: 1 35 | 36 | advdiff: &advdiff 37 | <<: *DEFAULT 38 | batch_size: 128 39 | valid_batch_size: 128 40 | nx: 128 41 | ny: 128 42 | log_to_wandb: !!bool True 43 | save_checkpoint: !!bool True 44 | max_epochs: 500 45 | scheduler: 'cosine' 46 | plot_figs: !!bool True 47 | loss_style: 'sum' 48 | 49 | model: 'fno' 50 | layers: [64, 64, 64, 64, 64] 51 | modes1: [65, 65, 65, 65] 52 | modes2: [65, 65, 65, 65] 53 | fc_dim: 128 54 | 55 | in_dim: 6 56 | out_dim: 1 57 | mode_cut: 16 58 | embed_cut: 64 59 | fc_cut: 2 60 | 61 | optimizer: 'adam' 62 | 63 | lr: 1E-3 64 | pack_data: !!bool False 65 | 66 | ad-scale-adr0p2_1: &ad_scale_0p2_1 67 | <<: *advdiff 68 | train_path: '/path/to/data/train_adr0p2_1_32k.h5' 69 | val_path: '/path/to/data/val_adr0p2_1_4k.h5' 70 | test_path: '/path/to/data/test_adr0p2_1_4k.h5' 71 | scales_path: '/path/to/data/train_adr0p2_1_scales.npy' 72 | batch_size: 128 73 | valid_batch_size: 128 74 | log_to_wandb: !!bool False 75 | mode_cut: 32 76 | embed_cut: 64 77 | fc_cut: 2 78 | -------------------------------------------------------------------------------- /config/operators_poisson.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | num_data_workers: 1 3 | # model 4 | model: 'fno' 5 | depth: 5 6 | in_dim: 2 7 | out_dim: 1 8 | dropout: 0 9 | # data/domain 10 | Lx: !!float 1.0 11 | Ly: !!float 1.0 
12 | nx: 256 13 | ny: 256 14 | # optimization 15 | loss_style: 'mean' 16 | loss_func: 'mse' 17 | optimizer: 'adam' 18 | scheduler: 'none' 19 | lr: !!float 1.0 20 | max_epochs: 500 21 | max_cosine_lr_epochs: 500 22 | batch_size: 25 23 | # misc 24 | log_to_screen: !!bool True 25 | save_checkpoint: !!bool False 26 | seed: 0 27 | plot_figs: !!bool False 28 | pack_data: !!bool False 29 | # Weights & Biases 30 | entity: 'your_wandb_entity' 31 | project: 'your_wandb_project' 32 | log_to_wandb: !!bool False 33 | distill: !!bool False 34 | subsample: 1 35 | 36 | poisson: &poisson 37 | <<: *DEFAULT 38 | batch_size: 512 39 | valid_batch_size: 512 40 | nx: 128 41 | ny: 128 42 | log_to_wandb: !!bool False 43 | save_checkpoint: !!bool True 44 | max_epochs: 500 45 | scheduler: 'cosine' 46 | plot_figs: !!bool True 47 | loss_style: 'sum' 48 | system: 'poisson' 49 | 50 | model: 'fno' 51 | layers: [64, 64, 64, 64, 64] 52 | modes1: [65, 65, 65, 65] 53 | modes2: [65, 65, 65, 65] 54 | fc_dim: 128 55 | 56 | in_dim: 4 57 | out_dim: 1 58 | mode_cut: 16 59 | embed_cut: 64 60 | fc_cut: 2 61 | 62 | optimizer: 'adam' 63 | 64 | lr: 1E-3 65 | pack_data: !!bool False 66 | 67 | poisson-scale-k1_5: &poisson_scale_k1_5 68 | <<: *poisson 69 | train_path: '/path/to/data/_train_k1_5_32k.h5' 70 | val_path: '/path/to/data/_val_k1_5_4k.h5' 71 | test_path: '/path/to/data/_test_k1_5_4k.h5' 72 | scales_path: '/path/to/data/_train_k1_5_scales.npy' 73 | batch_size: 128 74 | valid_batch_size: 128 75 | log_to_wandb: !!bool False 76 | mode_cut: 32 77 | embed_cut: 64 78 | fc_cut: 2 79 | -------------------------------------------------------------------------------- /config/operators_helmholtz.yaml: -------------------------------------------------------------------------------- 1 | default: &DEFAULT 2 | num_data_workers: 1 3 | # model 4 | model: 'fno' 5 | depth: 5 6 | in_dim: 2 7 | out_dim: 1 8 | dropout: 0 9 | # data/domain 10 | Lx: !!float 1.0 11 | Ly: !!float 1.0 12 | nx: 256 13 | ny: 256 14 | # optimization 
15 | loss_style: 'mean' 16 | loss_func: 'mse' 17 | optimizer: 'adam' 18 | scheduler: 'none' 19 | lr: !!float 1.0 20 | max_epochs: 500 21 | max_cosine_lr_epochs: 500 22 | batch_size: 25 23 | # misc 24 | log_to_screen: !!bool True 25 | save_checkpoint: !!bool False 26 | seed: 0 27 | plot_figs: !!bool False 28 | pack_data: !!bool False 29 | # Weights & Biases 30 | entity: 'pinns' 31 | project: 'neuraloperators' 32 | log_to_wandb: !!bool False 33 | distill: !!bool False 34 | subsample: 1 35 | 36 | helmholtz: &helmholtz 37 | <<: *DEFAULT 38 | batch_size: 128 39 | valid_batch_size: 128 40 | nx: 128 41 | ny: 128 42 | log_to_wandb: !!bool True 43 | save_checkpoint: !!bool True 44 | max_epochs: 500 45 | scheduler: 'cosine' 46 | plot_figs: !!bool True 47 | loss_style: 'sum' 48 | 49 | model: 'fno' 50 | layers: [64, 64, 64, 64, 64] 51 | modes1: [65, 65, 65, 65] 52 | modes2: [65, 65, 65, 65] 53 | fc_dim: 128 54 | 55 | in_dim: 2 56 | out_dim: 1 57 | mode_cut: 32 58 | embed_cut: 64 59 | fc_cut: 2 60 | 61 | optimizer: 'adam' 62 | 63 | lr: 1E-3 64 | pack_data: !!bool False 65 | 66 | helm-scale-o1_10: &helm_o1_10 67 | <<: *helmholtz 68 | train_path: '/path/to/data/train_o1_10_32k.h5' 69 | val_path: '/path/to/data/val_o1_10_4k.h5' 70 | test_path: '/path/to/data/test_o1_10_4k.h5' 71 | scales_path: '/path/to/data/train_o1_10_scales.npy' 72 | batch_size: 128 73 | valid_batch_size: 128 74 | log_to_wandb: !!bool False 75 | in_dim: 3 76 | out_dim: 1 77 | mode_cut: 32 78 | embed_cut: 64 79 | fc_cut: 2 80 | lr: 1E-3 81 | subsample: 1 82 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | import argparse 3 | import torch 4 | import wandb 5 | import matplotlib.pyplot as plt 6 | import logging 7 | import torch.distributed as dist 8 | from utils import logging_utils 9 | logging_utils.config_logger() 10 | from utils.YParams import YParams 11 | from 
utils.inferencer import Inferencer 12 | from ruamel.yaml import YAML 13 | from ruamel.yaml.comments import CommentedMap as ruamelDict 14 | import numpy as np 15 | 16 | 17 | if __name__ == '__main__': 18 | # parsers 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--yaml_config", default='./config/operators.yaml', type=str) 21 | parser.add_argument("--config", default='default', type=str) 22 | parser.add_argument("--root_dir", default='./', type=str, help='root dir to store results') 23 | parser.add_argument("--run_num", default='0', type=str, help='sub run config') 24 | parser.add_argument("--sweep", default='none', type=str) 25 | parser.add_argument("--weights", default='./ckpt.tar', type=str) 26 | args = parser.parse_args() 27 | 28 | 29 | params = YParams(os.path.abspath(args.yaml_config), args.config) 30 | logging.info("Starting config {}".format(args.config)) 31 | 32 | params['weights'] = args.weights 33 | 34 | if hasattr(params, 'weights'): 35 | logging.info("with weights {}".format(params.weights)) 36 | else: 37 | assert(False), "no model weights provided" 38 | 39 | inferencer = Inferencer(params, args) 40 | if inferencer.world_rank == 0: 41 | hparams = ruamelDict() 42 | yaml = YAML() 43 | for key, value in params.params.items(): 44 | hparams[str(key)] = str(value) 45 | with open(os.path.join(params['experiment_dir'], 'hyperparams.yaml'), 'w') as hpfile: 46 | yaml.dump(hparams, hpfile ) 47 | inferencer.launch() 48 | if dist.is_initialized(): 49 | dist.barrier() 50 | logging.info("Finished config {}".format(args.config)) 51 | 52 | 53 | logging.info('DONE') 54 | -------------------------------------------------------------------------------- /utils/get_scale.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1b1572df-976c-4991-974d-a6400217eaf8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import 
os\n", 11 | "import sys\n", 12 | "import numpy as np\n", 13 | "import matplotlib\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "from mpl_toolkits.axes_grid1 import make_axes_locatable\n", 16 | "from matplotlib.colors import Normalize\n", 17 | "import h5py" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "a9a57d31-6dce-491c-8bbc-2f6ab754cf40", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "datapath = \"/path/to/data/poisson\"" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "b5c821cf-83e2-4306-80e2-104d4f19dacc", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "with h5py.File(os.path.join(datapath, \"_train_k1_5_32k.h5\"), \"r\") as f:\n", 38 | " x_train = f['fields'][:]\n", 39 | " x_tensor = f['tensor'][:]\n", 40 | " print(list(f.keys()))" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "6ef7a2fa-f7d8-47e1-8f8d-42abffb23358", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "source_norm = []\n", 51 | "sol_max = []\n", 52 | "tensor_max = []\n", 53 | "nx = ny = 128\n", 54 | "lx = ly = 1\n", 55 | "num_ten = x_tensor.shape[1]\n", 56 | "\n", 57 | "for i in range(x_train.shape[0]):\n", 58 | " sn = np.linalg.norm(x_train[i,0]) * lx/nx * ly/ny\n", 59 | " source_norm.append(sn)\n", 60 | " sol_max.append(np.max(np.abs(x_train[i,1])))\n", 61 | " tensor_max.append([np.abs(x_tensor[i,t_idx]) for t_idx in range(num_ten)])\n", 62 | "\n", 63 | "tensor_max = np.array(tensor_max)\n", 64 | "source_scale = np.median(source_norm)\n", 65 | "sol_scale = np.median(sol_max)\n", 66 | "tensor_scale = [np.median(tensor_max[:,j]) for j in range(num_ten)]\n", 67 | "\n", 68 | "scale = [source_scale] + tensor_scale + [sol_scale] + [lx, ly]\n", 69 | "print(scale)\n" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "41fee5b9-327d-4d8b-8a7b-ddc363e42baa", 76 | "metadata": {}, 77 | 
"outputs": [], 78 | "source": [ 79 | "np.save(os.path.join(datapath, \"train_k1_5_scales.npy\"), scale)" 80 | ] 81 | } 82 | ], 83 | "metadata": { 84 | "kernelspec": { 85 | "display_name": "pytorch-1.9.0", 86 | "language": "python", 87 | "name": "pytorch-1.9.0" 88 | }, 89 | "language_info": { 90 | "codemirror_mode": { 91 | "name": "ipython", 92 | "version": 3 93 | }, 94 | "file_extension": ".py", 95 | "mimetype": "text/x-python", 96 | "name": "python", 97 | "nbconvert_exporter": "python", 98 | "pygments_lexer": "ipython3", 99 | "version": "3.8.11" 100 | } 101 | }, 102 | "nbformat": 4, 103 | "nbformat_minor": 5 104 | } 105 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /models/fno.py: -------------------------------------------------------------------------------- 1 | ''' from original FNO repo ''' 2 | import torch 3 | import torch.nn as nn 4 | from .basics import SpectralConv2dV2, _get_act 5 | 6 | 7 | class FNN2d(nn.Module): 8 | def __init__(self, modes1, modes2, 9 | width=64, fc_dim=128, 10 | layers=None, 11 | in_dim=3, out_dim=1, 12 | dropout=0, 13 | activation='tanh', 14 | mean_constraint=False): 15 | super(FNN2d, self).__init__() 16 | 17 | """ 18 | The overall network. 
It contains 4 layers of the Fourier layer. 19 | 1. Lift the input to the desire channel dimension by self.fc0 . 20 | 2. 4 layers of the integral operators u' = (W + K)(u). 21 | W defined by self.w; K defined by self.conv . 22 | 3. Project from the channel space to the output space by self.fc1 and self.fc2 . 23 | 24 | input: the solution of the coefficient function and locations (a(x, y), x, y) 25 | input shape: (batchsize, x=s, y=s, c=3) 26 | output: the solution 27 | output shape: (batchsize, x=s, y=s, c=1) 28 | """ 29 | 30 | self.modes1 = modes1 31 | self.modes2 = modes2 32 | self.width = width 33 | # input channel is 3: (a(x, y), x, y) 34 | if layers is None: 35 | self.layers = [width] * 4 36 | else: 37 | self.layers = layers 38 | self.fc0 = nn.Linear(in_dim, self.layers[0]) 39 | 40 | self.sp_convs = nn.ModuleList([SpectralConv2dV2( 41 | in_size, out_size, mode1_num, mode2_num) 42 | for in_size, out_size, mode1_num, mode2_num 43 | in zip(self.layers, self.layers[1:], self.modes1, self.modes2)]) 44 | 45 | self.dropout = nn.Dropout(p=dropout) 46 | 47 | self.ws = nn.ModuleList([nn.Conv1d(in_size, out_size, 1) 48 | for in_size, out_size in zip(self.layers, self.layers[1:])]) 49 | 50 | self.fc1 = nn.Linear(layers[-1], fc_dim) 51 | self.fc2 = nn.Linear(fc_dim, out_dim) 52 | self.activation = _get_act(activation) 53 | self.mean_constraint = mean_constraint 54 | 55 | def forward(self, x): 56 | ''' 57 | (b,c,h,w) -> (b,1,h,w) 58 | ''' 59 | length = len(self.ws) 60 | batchsize = x.shape[0] 61 | size_x, size_y = x.shape[2], x.shape[3] 62 | 63 | x = x.permute(0, 2, 3, 1) 64 | x = self.fc0(x) # project 65 | x = x.permute(0, 3, 1, 2) 66 | 67 | for i, (speconv, w) in enumerate(zip(self.sp_convs, self.ws)): 68 | x1 = speconv(x) 69 | x2 = w(x.view(batchsize, self.layers[i], -1)).view(batchsize, self.layers[i+1], size_x, size_y) 70 | x = x1 + x2 71 | if i != length - 1: 72 | x = self.activation(x) 73 | x = self.dropout(x) 74 | x = x.permute(0, 2, 3, 1) 75 | x = self.fc1(x) 76 | x 
= self.activation(x) 77 | x = self.dropout(x) 78 | x = self.fc2(x) 79 | x = self.dropout(x) 80 | x = x.permute(0, 3, 1, 2) 81 | 82 | if self.mean_constraint: 83 | x = x - torch.mean(x, dim=(-2,-1), keepdim=True) 84 | 85 | return x 86 | 87 | def fno(params): 88 | if params.mode_cut > 0: 89 | params.modes1 = [params.mode_cut]*len(params.modes1) 90 | params.modes2 = [params.mode_cut]*len(params.modes2) 91 | 92 | if params.embed_cut > 0: 93 | params.layers = [params.embed_cut]*len(params.layers) 94 | 95 | if params.fc_cut > 0 and params.embed_cut > 0: 96 | params.fc_dim = params.embed_cut * params.fc_cut 97 | 98 | input_dim = params.in_dim 99 | 100 | return FNN2d(params.modes1, params.modes2, layers=params.layers, fc_dim=params.fc_dim, 101 | in_dim=input_dim, out_dim=params.out_dim, dropout=params.dropout, 102 | activation='gelu', mean_constraint=(params.loss_func == 'pde')) 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Characterizing Scaling and Transfer Learning of Neural Networks in SciML 2 | This repository contains PyTorch code for training the Fourier Neural Operator on different PDE systems to characterize its scaling and transfer learning behavior on different downstream tasks. 3 | ![overview](assets/overview.png) 4 | 5 | ## Environment 6 | The necessary packages for this code can be installed via 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ## Data Generation 12 | We consider three PDE systems: Poisson's, Advection-Diffusion, Helmholtz. For the data generation, we sample source functions and PDE coefficients to create train-val-test data splits. The layout for data generation is: 13 | 14 | - `utils/gen_data_[pde].py` generates the data for the PDE system [pde] for the three PDEs. 15 | - `run_gen_data.sh` is an example run script to generate data. 
See the comments in the script for different hyperparameters that control the data generation process. The ``--help`` option can also be used on the python scripts for more information. The paths to data and PDE coefficient ranges need to be set here. 16 | - The data is stored in ``hdf5`` format with ``fields`` value storing the source function (``0 index``) and the target (PDE) solution (``1 index``) and ``tensor`` storing the PDE coefficient values. For example, in the Poisson's PDE: 17 | ``` 18 | with h5py.File(os.path.join(path_to_data, "<filename>.h5"), "r") as f: 19 | x_train = f['fields'][:] 20 | x_tensor = f['tensor'][:] 21 | ``` 22 | stores ``x_train`` of shape ``(n_train, 2, 128, 128)`` for source and solution functions on grid size ``128x128`` and ``x_tensor`` of shape ``(n_train, 3)`` for storing three diffusion coefficient values corresponding to the diffusion tensor diagonal and off-diagonal values. 23 | - Once the data is generated, the scales for input normalization can be generated using ``utils/get_scale.ipynb``. The data paths and scales paths are passed to the configuration files for training. 24 | 25 | ## Training and Inference 26 | - Configuration files (in YAML format) are in `config/` for different PDE systems. For example, the config for Poisson's is in `config/operators_poisson.yaml`. The main configs for the three systems are ``poisson-scale-k1_5``, ``ad-scale_adr0p2_1`` and ``helm-scale-o1_10``. The data paths and scales paths need to be set here. For example, the config at [config/operators_poisson.yaml](config/operators_poisson.yaml) has the data setup and minimal hyperparameters as follows: 27 | ``` 28 | poisson-scale-k1_5: &poisson_scale_k1_5 # sampled eigenvalues are in (1,5) for diffusion 29 | <<: *poisson 30 | ... # can change other default configs from poisson if needed # 31 | ...
32 | train_path: # path to train data 33 | val_path: # path to validation data 34 | test_path: # path to test data 35 | scales_path: # path to train data scales for input normalization 36 | batch_size: # batch size for training 37 | valid_batch_size: # batch size for validation 38 | log_to_wandb: # switch on for logging to weights&biases 39 | mode_cut: # number of fourier modes to use 40 | embed_cut: # embedding dimension of FNO 41 | fc_cut: # multiplier for the hidden fc layer width (fc_dim = embed_cut * fc_cut; only applied when embed_cut > 0) 42 | ``` 43 | - Data, trainer, and other miscellaneous utilities are in `utils/`. We use standard PyTorch dataloaders and models wrapped with DDP for distributed data-parallel training with checkpointing. 44 | - The FNO model is the standard model and is in `models/`. The hyperparameters used are in the config files for the respective PDE systems. 45 | - Environment variables for DDP (local rank, master port, etc.) are set in `export_DDP_vars.sh` to be sourced before running any distributed training. See the [PyTorch DDP tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) for more details on using DDP. There are other ways to implement this, but our run script is specifically for Slurm systems. 46 | - An example run script is in `run.sh` (4-GPU DDP training). ``train.py`` is the training script (``utils/trainer.py`` is the trainer class), and ``eval.py`` can be used for inference (``utils/inferencer.py`` is the inference class). See the run scripts for details. 47 | 48 | ### HPO tuning and scaling with W&B 49 | We use [Weights & Biases](https://wandb.ai/site) (see their docs for details) for logging and tuning of all experiments. The ``log_to_wandb`` flag in the configuration files can be set to ``True`` for this once you have a W&B login and project set up. For model scaling and data scaling, we use their [HPO sweep feature](https://docs.wandb.ai/guides/sweeps).
An example sweep config that sweeps over different downstream dataset sizes and learning rates is in ``config/sweep_config.yaml``. We use the ``subsample`` hyperparameter to sub-sample our downstream dataset for the dataset scaling. For example, ``subsample=2`` implies only half of the data is used for training, etc. 50 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | data loaders 3 | """ 4 | import re 5 | import time 6 | import os, sys 7 | import logging 8 | import h5py 9 | import glob 10 | import torch 11 | import random 12 | import numpy as np 13 | from torch.utils.data import DataLoader, Dataset, TensorDataset 14 | from torch.utils.data.distributed import DistributedSampler 15 | 16 | 17 | def get_data_loader(params, location, distributed, train=True, pack=False): 18 | transform = torch.from_numpy 19 | dataset = PDESolns(params, location, transform, train) 20 | sampler = DistributedSampler(dataset, shuffle=train) if distributed else None 21 | if train: 22 | batch_size = params.local_batch_size 23 | else: 24 | batch_size = params.local_valid_batch_size 25 | if not pack: 26 | dataloader = DataLoader(dataset, 27 | batch_size=int(batch_size), 28 | num_workers=params.num_data_workers, 29 | shuffle=False,#(sampler is None), 30 | sampler=sampler, 31 | drop_last=True, 32 | pin_memory=torch.cuda.is_available()) 33 | else: 34 | # data is small, pack it all onto the gpu 35 | X = dataset.data[:,0:dataset.in_channels] 36 | y = dataset.data[:,dataset.in_channels:] 37 | X = torch.tensor(X, requires_grad=True).float().to(params.device) 38 | y = torch.tensor(y, requires_grad=True).float().to(params.device) 39 | tensor_dataset = TensorDataset(X, y) 40 | dataloader = torch.utils.data.DataLoader(tensor_dataset, batch_size=int(batch_size), shuffle=True) 41 | return dataloader, dataset, sampler 42 | 43 | 44 | class PDESolns(Dataset): 45 
| def __init__(self, params, location, transform, train): 46 | self.transform = transform 47 | self.params = params 48 | self.location = location 49 | self.train = train 50 | if hasattr(self.params, "subsample") and (self.train): 51 | self.subsample = self.params.subsample 52 | else: 53 | self.subsample = 1 # subsample only if training 54 | self.scales = None 55 | self._get_files_stats() 56 | file = self._open_file(self.location) 57 | self.data = file['fields'] 58 | if 'tensor' in list(file.keys()): 59 | self.tensor = file['tensor'] 60 | else: 61 | self.tensor = None 62 | 63 | def _get_files_stats(self): 64 | self.file = self.location 65 | with h5py.File(self.file, 'r') as _f: 66 | logging.info("Getting file stats from {}".format(self.file)) 67 | self.n_samples = _f['fields'].shape[0] 68 | self.img_shape_x = _f['fields'].shape[2] 69 | self.img_shape_y = _f['fields'].shape[3] 70 | self.in_channels = _f['fields'].shape[1]-1 71 | if 'tensor' in list(_f.keys()): 72 | self.tensor_shape = _f['tensor'].shape[1] 73 | else: 74 | self.tensor_shape = 0 75 | self.n_samples /= self.subsample 76 | self.n_samples = int(self.n_samples) 77 | logging.info("Found data at path {}. Number of examples: {}. 
Image Shape: {} x {}".format(self.location, self.n_samples, self.img_shape_x, self.img_shape_y)) 78 | if hasattr(self.params, "scales_path"): 79 | self.scales = np.load(self.params.scales_path) 80 | self.scales = np.array([s if s != 0 else 1 for s in self.scales]) 81 | self.scales = self.scales.astype('float32') 82 | measure_x = self.scales[-2] / self.img_shape_x 83 | measure_y = self.scales[-1] / self.img_shape_y 84 | self.measure = measure_x * measure_y 85 | logging.info("Scales for PDE are (source, tensor, sol, domain): {}".format(self.scales)) 86 | logging.info("Measure of the set is lx/nx * ly/ny = {}/{} * {}/{}".format(self.scales[-2], self.img_shape_x, self.scales[-1], self.img_shape_y)) 87 | 88 | def __len__(self): 89 | return self.n_samples 90 | 91 | def _open_file(self, path): 92 | return h5py.File(path, 'r') 93 | 94 | def __getitem__(self, idx): 95 | local_idx = int(idx*self.subsample) 96 | X = (self.data[local_idx,0:self.in_channels]) 97 | if self.tensor: # append coefficient tensor to channels 98 | tensor = [] 99 | for tidx in range(self.tensor_shape): 100 | coef = np.full((1, self.img_shape_x, self.img_shape_y), self.tensor[local_idx,tidx]) 101 | tensor.append(coef) 102 | X = np.concatenate([X] + tensor, axis=0).astype('float32') 103 | 104 | if self.scales is not None: 105 | f_norm = np.linalg.norm(X[0]) * self.measure 106 | f_scaling = f_norm / self.scales[0] 107 | X = X / f_scaling # ensures that 10f and 10k for example, have the same input 108 | # scale the tensors 109 | X[self.in_channels:] = X[self.in_channels:] / self.scales[self.in_channels:(self.in_channels + self.tensor_shape), None, None] 110 | 111 | X = self.transform(X) 112 | y = self.transform(self.data[local_idx,self.in_channels:]) 113 | return X, y 114 | 115 | 116 | -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | misc utils 3 | """ 4 | 5 | 
import numpy as np 6 | import scipy as sc 7 | import scipy.ndimage as nd 8 | import torch 9 | import torch.nn as nn 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | from mpl_toolkits.axes_grid1 import make_axes_locatable 13 | from matplotlib.colors import Normalize 14 | from matplotlib import cm 15 | from torch.nn import functional as F 16 | 17 | def fft_coef(n, is_torch=False): 18 | if not is_torch: 19 | ikx_pos = 1j * np.arange(0, n//2+1, 1) 20 | ikx_neg = 1j * np.arange(-n//2+1, 0, 1) 21 | ikx = np.concatenate((ikx_pos, ikx_neg)) 22 | else: 23 | ikx_pos = 1j * torch.arange(0, n/2+1, 1) 24 | ikx_neg = 1j * torch.arange(-n/2+1, 0, 1) 25 | ikx = torch.cat((ikx_pos, ikx_neg)) 26 | return ikx 27 | 28 | def get_fft_coef(u, is_torch=False): 29 | nx, ny = u.shape 30 | ikx = fft_coef(nx, is_torch).reshape(1,nx) 31 | iky = fft_coef(ny, is_torch).reshape(ny,1) 32 | if not is_torch: 33 | ikx = np.repeat(ikx, ny, axis=0) 34 | iky = np.repeat(iky, nx, axis=1) 35 | else: 36 | # need to check if this is the same.. 
37 | ikx = torch.repeat_interleave(ikx, ny, dim=0) 38 | iky = torch.repeat_interleave(iky, nx, dim=1) 39 | return ikx, iky 40 | 41 | def laplacian(u, params): 42 | ikx, iky = get_fft_coef(u) 43 | ikx2 = ikx**2 44 | iky2 = iky**2 45 | u_hat = np.fft.fft2(u) 46 | u_hat *= (ikx2+iky2) * (4.0 * np.pi**2)/(params.Lx*params.Ly) 47 | return np.real(np.fft.ifft2(u_hat)) 48 | 49 | def grad(u, params): 50 | ikx, iky = get_fft_coef(u) 51 | u_hat = np.fft.fft2(u) 52 | ux = np.real(np.fft.ifft2(u_hat * ikx)) * (2.0 * np.pi)/params.Lx 53 | uy = np.real(np.fft.ifft2(u_hat * iky)) * (2.0 * np.pi)/params.Ly 54 | return ux, uy 55 | 56 | def gradabs(u, params): 57 | ux, uy = grad(u, params) 58 | return np.sqrt(ux**2 + uy**2) 59 | 60 | def div(ux, uy, params): 61 | ikx, iky = get_fft_coef(ux) 62 | ux_hat = np.fft.fft2(ux) 63 | uy_hat = np.fft.fft2(uy) 64 | u1 = np.real(np.fft.ifft2(ux_hat * ikx)) * (2.0 * np.pi)/params.Lx 65 | u2 = np.real(np.fft.ifft2(uy_hat * iky)) * (2.0 * np.pi)/params.Ly 66 | return (u1 + u2) 67 | 68 | def diffusion_coef(x, y, freq, scale, torch_tensor=False): 69 | if not torch_tensor: 70 | return scale * (1 + np.sin(2*np.pi*freq*x) * np.sin(2*np.pi*freq*y)) 71 | else: 72 | return scale * (1 + torch.sin(2*np.pi*freq*x) * torch.sin(2*np.pi*freq*y)) 73 | 74 | def diffusion_op(u, k, K, nx, ny, params): 75 | u = u.reshape((nx, ny)) 76 | gradux, graduy = grad(u, params) 77 | # diff tensor 78 | Kux = K['k11']*gradux + K['k12']*graduy 79 | Kuy = K['k22']*graduy + K['k12']*gradux 80 | # heterogeneous 81 | Kux *= k 82 | Kuy *= k 83 | return div(Kux, Kuy, params) 84 | 85 | def rbf(x, y, p, sigma=1/64, spacing=2/64, ng=256, center=(0.5,0.5), torch_tensor=False): 86 | """ create an rbf grid basis functions """ 87 | num = np.sqrt(ng) # num of centers in each direction 88 | l = (num - 1) * spacing # length of grid in each direction 89 | 90 | centers_x = np.arange(center[0]-l/2, center[0]+l/2+spacing, spacing) 91 | centers_y = np.arange(center[1]-l/2, center[1]+l/2+spacing, 
spacing) 92 | centers_x, centers_y = np.meshgrid(centers_x, centers_y) 93 | ratio = [] 94 | c = [] 95 | 96 | for cx, cy in zip(centers_x.flatten(), centers_y.flatten()): 97 | r = (x - cx)**2 + (y - cy)**2 98 | R = 2 * sigma**2 99 | ratio.append(r/R) 100 | c.append(1./(2*np.pi*sigma**2)) # normalizing factor 101 | 102 | source = 0*ratio[0] 103 | # compute for data; need dc component here 104 | idx = 0 105 | for r, ci in zip(ratio, c): 106 | source += p[idx]*ci*np.exp(-r) 107 | idx += 1 108 | dc = np.mean(source.flatten()) 109 | source = source - dc 110 | 111 | return source 112 | 113 | 114 | def softmax(x): 115 | return np.exp(x)/sum(np.exp(x)) 116 | 117 | def normalize(x): 118 | return (x - np.min(x)) / (np.max(x) - np.min(x)) 119 | 120 | def compute_grad_norm(p_list): 121 | grad_norm = 0 122 | for p in p_list: 123 | param_g_norm = p.grad.detach().data.norm(2) 124 | grad_norm += param_g_norm.item()**2 125 | grad_norm = grad_norm**0.5 126 | return grad_norm 127 | 128 | def l2_err(pred, target): 129 | x = torch.sum((pred-target)**2, dim=(-1,-2))/torch.sum(target**2, dim=(-1,-2)) 130 | x = torch.sqrt(x) 131 | return torch.mean(x, dim=0) 132 | 133 | def compute_err(output, target): 134 | err = output - target 135 | return np.linalg.norm(err[:])/np.linalg.norm(target[:]) 136 | 137 | def show(u, ax, fig, rescale=None): 138 | if u is not None: 139 | if rescale is None: 140 | h = ax.imshow(u.T, interpolation='nearest', cmap='rainbow', 141 | origin='lower', aspect='auto') 142 | else: 143 | h = ax.imshow(u.T, interpolation='nearest', cmap='rainbow', 144 | origin='lower', aspect='auto', vmin=rescale[0], vmax=rescale[1]) 145 | divider = make_axes_locatable(ax) 146 | cax = divider.append_axes("right", size="5%", pad=0.10) 147 | cbar = fig.colorbar(h, cax=cax) 148 | cbar.ax.tick_params(labelsize=15) 149 | ax.tick_params(labelsize=15) 150 | 151 | 152 | def vis_fields(fields, params, domain): 153 | source, target, pred, pde_res, temp = fields 154 | err = np.abs(pred - target) 155 
| x_g = domain.x_g 156 | y_g = domain.y_g 157 | scale = [params.ny/params.Ly, params.nx/params.Lx] 158 | 159 | fx = 17 160 | fy = 8 161 | fig = plt.figure(figsize=(fx,fy)) 162 | ax1 = fig.add_subplot(2,3,1) 163 | show(source, ax1, fig) 164 | ax1.contour(y_g*scale[0], x_g*scale[1], source, 15) 165 | ax1.set_title("source") 166 | ax2 = fig.add_subplot(2,3,2) 167 | show(target, ax2, fig) 168 | ax2.set_title("target") 169 | ax2.contour(y_g*scale[0], x_g*scale[1], target, 15) 170 | ax3 = fig.add_subplot(2,3,3) 171 | show(pred, ax3, fig) 172 | ax3.set_title("pred") 173 | ax3.contour(y_g*scale[0], x_g*scale[1], pred, 15) 174 | ax4 = fig.add_subplot(2,3,4) 175 | show(pde_res, ax4, fig) 176 | ax4.set_title("pde-res") 177 | # ax4.contour(y_g*scale[0], x_g*scale[1], temp, 15) 178 | ax5 = fig.add_subplot(2,3,5) 179 | show(temp, ax5, fig) 180 | ax5.set_title("temp") 181 | # ax5.contour(y_g*scale[0], x_g*scale[1], temp, 15) 182 | ax6 = fig.add_subplot(2,3,6) 183 | show(err, ax6, fig) 184 | ax6.set_title("err") 185 | ax6.contour(y_g*scale[0], x_g*scale[1], err, 15) 186 | 187 | fig.tight_layout() 188 | 189 | return fig 190 | 191 | def vis_fields_many(fields, params, domain): 192 | x_g = domain.x_g 193 | y_g = domain.y_g 194 | scale = [params.ny/params.Ly, params.nx/params.Lx] 195 | 196 | fx = 30 197 | fy = 17 198 | fig = plt.figure(figsize=(fx,fy)) 199 | num_examples = len(fields)//3 200 | n_cols = 4 201 | for i in range(num_examples): 202 | ax1 = fig.add_subplot(num_examples,n_cols,n_cols*i+1) 203 | source = fields[3*i] 204 | show(source, ax1, fig) 205 | ax1.contour(y_g*scale[0], x_g*scale[1], source, 15) 206 | ax1.set_title("source") 207 | ax2 = fig.add_subplot(num_examples,n_cols,n_cols*i+2) 208 | target = fields[3*i+1] 209 | show(target, ax2, fig) 210 | ax2.set_title("target") 211 | ax2.contour(y_g*scale[0], x_g*scale[1], target, 15) 212 | ax3 = fig.add_subplot(num_examples,n_cols,n_cols*i+3) 213 | pred = fields[3*i+2] 214 | show(pred, ax3, fig) 215 | ax3.contour(y_g*scale[0], 
x_g*scale[1], pred, 15) 216 | ax3.set_title("pred") 217 | ax4 = fig.add_subplot(num_examples,n_cols,n_cols*i+4) 218 | err = np.abs(pred-target) 219 | show(err, ax4, fig) 220 | ax4.contour(y_g*scale[0], x_g*scale[1], err, 15) 221 | ax4.set_title("err") 222 | 223 | fig.tight_layout() 224 | 225 | return fig 226 | 227 | def set_activation(activation): 228 | if activation == 'identity': 229 | return nn.Identity() 230 | elif activation == 'tanh': 231 | return nn.Tanh() 232 | elif activation == 'relu': 233 | return nn.ReLU() 234 | elif activation == 'gelu': 235 | return nn.GELU() 236 | else: 237 | print("WARNING: invalid activation function!") 238 | return -1 239 | -------------------------------------------------------------------------------- /utils/gen_data_helmholtz.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | from matplotlib.colors import Normalize 10 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../') 11 | from utils.misc_utils import show, fft_coef, grad, div, laplacian 12 | from types import SimpleNamespace 13 | import random 14 | import h5py 15 | import time 16 | 17 | def rbf(x, y, p, sigma=1/64, spacing=2/64, ng=256, center=(0.5,0.5)): 18 | """ create an rbf grid basis functions """ 19 | num = np.sqrt(ng) # num of centers in each direction 20 | l = (num - 1) * spacing # length of grid in each direction 21 | 22 | centers_x = np.arange(center[0]-l/2, center[0]+l/2+spacing, spacing) 23 | centers_y = np.arange(center[1]-l/2, center[1]+l/2+spacing, spacing) 24 | centers_x, centers_y = np.meshgrid(centers_x, centers_y) 25 | ratio = [] 26 | c = [] 27 | 28 | 29 | for cx, cy in zip(centers_x.flatten(), centers_y.flatten()): 30 | r = (x - cx)**2 + (y - cy)**2 31 | R = 2 * sigma**2 32 
| ratio.append(r/R) 33 | c.append(1./(2*np.pi*sigma**2)) # normalizing factor 34 | 35 | source = 0*ratio[0] 36 | 37 | idx = 0 38 | for r, ci in zip(ratio, c): 39 | source += p[idx]*1*np.exp(-r) 40 | idx += 1 41 | 42 | source = source/np.max(source) 43 | 44 | return source 45 | 46 | def helm_op(u, omega, nx, ny, params): 47 | u = u.reshape((nx, ny)) 48 | lap = params.diff_coef_scale * laplacian(u, params) 49 | lap += omega * u 50 | return lap 51 | 52 | 53 | def helmholtz(x, y, std, space, vf, params): 54 | x_g, y_g = np.meshgrid(x, y) 55 | 56 | sol_max = 100 57 | while (sol_max > 2): # ignore solutions with large values 58 | ng = params.ng 59 | p = np.zeros(ng) 60 | min_act = 1E-3 # avoid zeros 61 | max_act = 1 62 | 63 | omega = np.random.randint(params.o1, params.o2+1) 64 | 65 | if not params.sparse: 66 | p = np.random.rand(ng,1) 67 | for i in range(ng): 68 | alpha = np.random.rand() 69 | if alpha > vf: 70 | p[i] = (max_act - min_act) * np.random.rand() + min_act 71 | 72 | all_zeros = not np.any(p) 73 | if all_zeros: 74 | randidx = np.random.randint(0, len(p)) 75 | p[randidx] = (max_act - min_act) * np.random.rand() + min_act 76 | 77 | source = rbf(x_g, y_g, p, ng=ng, sigma=std, spacing=space) 78 | 79 | nx = x.shape[0] 80 | ny = y.shape[0] 81 | ikx = fft_coef(nx).reshape(1,nx) 82 | ikx = np.repeat(ikx, ny, axis=0) 83 | iky = fft_coef(ny).reshape(ny,1) 84 | iky = np.repeat(iky, nx, axis=1) 85 | ikx2 = ikx**2 86 | iky2 = iky**2 87 | 88 | f_hat = np.fft.fft2(source) 89 | 90 | ik_factor = ikx2 + iky2 91 | ik_factor *= (4.0 * np.pi**2) / (params.Lx * params.Ly) * params.diff_coef_scale 92 | factor = (omega + ik_factor) 93 | condn = (factor == 0) 94 | 95 | f_hat = np.where(condn, 0, f_hat) 96 | source = source if np.all(~condn[:]) else np.real(np.fft.ifft2(f_hat)) 97 | 98 | factor = np.where(condn, 0, -1/factor) # set to zero in undefined places in freq space 99 | u_hat = factor * f_hat 100 | 101 | u = np.real(np.fft.ifft2(u_hat)) 102 | 103 | # check the result by seeing 
norm(LHS) == norm(RHS) 104 | lhs = helm_op(u, omega, nx, ny, params) 105 | lhs_norm = np.linalg.norm(lhs) 106 | rhs_norm = np.linalg.norm(source) 107 | if (np.abs(lhs_norm - rhs_norm) > 1E-5): 108 | print("INACCURATE SOLUTION!") 109 | 110 | sol_max = np.max(np.abs(u[:])) 111 | 112 | 113 | return u, source, omega 114 | 115 | 116 | def create_hdf5(path, dat, ten): 117 | with h5py.File(path, "a") as f: 118 | try: 119 | f.create_dataset('fields', dat.shape, dtype=' vf: 87 | p[i] = (max_act - min_act) * np.random.rand() + min_act 88 | 89 | all_zeros = not np.any(p) 90 | if all_zeros: 91 | randidx = np.random.randint(0, len(p)) 92 | p[randidx] = (max_act - min_act) * np.random.rand() + min_act 93 | 94 | source = rbf(x_g, y_g, p, ng=ng, sigma=std, spacing=space) 95 | 96 | nx = x.shape[0] 97 | ny = y.shape[0] 98 | ikx = fft_coef(nx).reshape(1,nx) 99 | ikx = np.repeat(ikx, ny, axis=0) 100 | iky = fft_coef(ny).reshape(ny,1) 101 | iky = np.repeat(iky, nx, axis=1) 102 | ikx2 = ikx**2 103 | iky2 = iky**2 104 | 105 | f_hat = np.fft.fft2(source) 106 | 107 | diff_factor = ikx2*K['k11'] + iky2*K['k22'] + 2*ikx*iky*K['k12'] 108 | diff_factor *= (4.0 * np.pi**2) / (params.Lx * params.Ly) # not a 2pi domain, but unit cube 109 | factor = params.diff_coef_scale * diff_factor 110 | factor = np.where(factor == 0, 0, -1/factor) # zeroth mode; set to zero 111 | u_hat = factor * f_hat 112 | 113 | u = np.real(np.fft.ifft2(u_hat)) 114 | u = u - np.mean(u.flatten()) # remove the dc component 115 | 116 | k = np.array([K['k11'], K['k22'], K['k12']]) 117 | 118 | return u, source, k 119 | 120 | 121 | def create_hdf5(path, dat, ten): 122 | with h5py.File(path, "a") as f: 123 | try: 124 | f.create_dataset('fields', dat.shape, dtype=' 0: 42 | torch.cuda.manual_seed_all(seed) 43 | 44 | def count_parameters(model): 45 | params = sum(p.numel() for p in model.parameters() if p.requires_grad) 46 | return params/1000000 47 | 48 | class Inferencer(): 49 | """ trainer class """ 50 | def __init__(self, params, 
args): 51 | self.sweep = args.sweep 52 | self.root_dir = args.root_dir 53 | self.config = args.config 54 | self.run_num = args.run_num 55 | self.world_size = 1 56 | if 'WORLD_SIZE' in os.environ: 57 | self.world_size = int(os.environ['WORLD_SIZE']) 58 | 59 | self.local_rank = 0 60 | self.world_rank = 0 61 | if self.world_size > 1: 62 | dist.init_process_group(backend='nccl', 63 | init_method='env://') 64 | self.world_rank = dist.get_rank() 65 | self.local_rank = int(os.environ["LOCAL_RANK"]) 66 | 67 | if torch.cuda.is_available(): 68 | torch.cuda.set_device(self.local_rank) 69 | torch.backends.cudnn.benchmark = True 70 | 71 | self.log_to_screen = params.log_to_screen and self.world_rank==0 72 | self.log_to_wandb = False # turn off for inference; params.log_to_wandb and self.world_rank==0 73 | params['name'] = args.config + '_' + args.run_num 74 | params['group'] = 'op_' + args.config 75 | if torch.cuda.is_available(): 76 | self.device = torch.cuda.current_device() 77 | else: 78 | self.device = torch.device('cpu') 79 | self.params = params 80 | self.params.device = self.device 81 | self._build() 82 | 83 | def _build(self): 84 | # init wandb 85 | if self.sweep != 'none': 86 | exp_dir = os.path.join(*[self.root_dir, 'sweeps', self.sweep, self.config, 'inference']) 87 | else: 88 | exp_dir = os.path.join(*[self.root_dir, 'expts', self.config, self.run_num]) 89 | 90 | if self.world_rank==0: 91 | if not os.path.isdir(exp_dir): 92 | os.makedirs(exp_dir) 93 | os.makedirs(os.path.join(exp_dir, 'wandb/')) 94 | 95 | self.params['experiment_dir'] = os.path.abspath(exp_dir) 96 | if self.log_to_wandb: 97 | wandb.init(dir=os.path.join(exp_dir, "wandb"), 98 | config=self.params.params, name=self.params.name, group=self.params.group, project=self.params.project, 99 | entity=self.params.entity) 100 | 101 | set_seed(self.params, self.world_size) 102 | 103 | self.params['global_batch_size'] = self.params.batch_size 104 | self.params['local_batch_size'] = 
int(self.params.batch_size//self.world_size) 105 | self.params['global_valid_batch_size'] = self.params.valid_batch_size 106 | self.params['local_valid_batch_size'] = int(self.params.valid_batch_size//self.world_size) 107 | 108 | if self.world_rank==0: 109 | self.params.log() 110 | 111 | def launch(self): 112 | self.test_data_loader, self.test_dataset, _ = get_data_loader(self.params, self.params.test_path, dist.is_initialized(), train=False, pack=self.params.pack_data) 113 | 114 | # domain grid 115 | self.domain = DomainXY(self.params) 116 | 117 | if self.params.model == 'fno': 118 | self.model = models.fno.fno(self.params).to(self.device) 119 | else: 120 | assert(False), "Error, model arch invalid." 121 | 122 | if dist.is_initialized(): 123 | self.model = DistributedDataParallel(self.model, 124 | device_ids=[self.local_rank], 125 | output_device=[self.local_rank]) 126 | 127 | if self.params.loss_func == "mse": 128 | self.loss_func = LossMSE(self.params, self.model) 129 | else: 130 | assert(False), "Error, loss func invalid." 
131 | 132 | if hasattr(self.params, 'weights'): 133 | self.params['checkpoint_path'] = self.params['weights'] 134 | logging.info("Loading checkpoint %s"%self.params.checkpoint_path) 135 | self.restore_checkpoint(self.params.checkpoint_path) 136 | 137 | if self.log_to_screen: 138 | print("model wt norm = {}".format(self.get_model_wt_norm(self.model))) 139 | 140 | self.logs = {} 141 | test_time, fields = self.test() 142 | logging.info("testing time = {}".format(test_time)) 143 | if self.log_to_wandb: 144 | # log visualizations every epoch 145 | fig = vis_fields_many(fields, self.params, self.domain) 146 | self.logs['vis'] = wandb.Image(fig) 147 | plt.close(fig) 148 | wandb.log(self.logs, step=1) 149 | wandb.finish() 150 | 151 | def get_model_wt_norm(self, model): 152 | n = 0 153 | for p in model.parameters(): 154 | p_norm = p.data.detach().norm(2) 155 | n += p_norm.item()**2 156 | n = n**0.5 157 | return n 158 | 159 | def test(self): 160 | self.model.eval() 161 | #self.model.train() # need gradients 162 | test_start = time.time() 163 | 164 | logs_buff = torch.zeros((2), dtype=torch.float32, device=self.device) 165 | self.logs['test_err'] = logs_buff[0].view(-1) 166 | self.logs['test_loss'] = logs_buff[1].view(-1) 167 | 168 | num_examples = 3 169 | idx = [np.random.randint(0, len(self.test_data_loader)) for _ in range(num_examples)] # index of batch 170 | img_idx = [np.random.randint(0, self.params.local_valid_batch_size) for _ in range(num_examples)] # index within batch 171 | fields = [] 172 | ii = 0 173 | 174 | bs = self.params.local_valid_batch_size 175 | 176 | with torch.no_grad(): 177 | for i, (inputs, targets) in enumerate(self.test_data_loader): 178 | if not self.params.pack_data: 179 | inputs, targets = inputs.to(self.device), targets.to(self.device) 180 | u = self.model(inputs) 181 | loss_data = self.loss_func.data(inputs, u, targets) 182 | loss_pde = self.loss_func.pde(inputs, u, targets) 183 | loss_bc = self.loss_func.bc(inputs, u, targets) 184 | loss = 
loss_data + loss_bc + loss_pde 185 | 186 | self.logs['test_err'] += l2_err(u.detach(), targets.detach()) # computes rel l2 err of each image and averages across batches 187 | self.logs['test_loss'] += loss.detach() 188 | if i in idx: 189 | source = inputs[img_idx[ii],0].detach().cpu().numpy() 190 | soln = targets[img_idx[ii],0].detach().cpu().numpy() 191 | pred = u[img_idx[ii],0].detach().cpu().numpy() 192 | fields.extend([source, soln, pred]) 193 | ii += 1 194 | 195 | self.logs['test_err'] /= len(self.test_data_loader) 196 | self.logs['test_loss'] /= len(self.test_data_loader) 197 | 198 | if dist.is_initialized(): 199 | for key in ['test_loss', 'test_err']: 200 | dist.all_reduce(self.logs[key].detach()) 201 | self.logs[key] = float(self.logs[key]/dist.get_world_size()) 202 | else: 203 | for key in ['test_loss', 'test_err']: 204 | self.logs[key] = float(self.logs[key]) 205 | 206 | if self.log_to_screen: 207 | print(self.logs) 208 | 209 | #self.save_logs(tag="_ckpt") 210 | self.save_logs(tag="_best") 211 | 212 | test_time = time.time() - test_start 213 | 214 | return test_time, fields 215 | 216 | def restore_checkpoint(self, checkpoint_path): 217 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank)) 218 | try: 219 | self.model.load_state_dict(checkpoint['model_state']) 220 | except: 221 | new_state_dict = OrderedDict() 222 | for key, val in checkpoint['model_state'].items(): 223 | name = key[7:] 224 | new_state_dict[name] = val 225 | self.model.load_state_dict(new_state_dict) 226 | 227 | 228 | def save_logs(self, tag=""): 229 | with open(os.path.join(self.params.experiment_dir, "logs"+tag+".txt"), "w") as f: 230 | for k, v in self.logs.items(): 231 | f.write("{},{}\n".format(k,v)) 232 | -------------------------------------------------------------------------------- /utils/gen_data_advdiff.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | import argparse 3 | import 
numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | from matplotlib.colors import Normalize 10 | sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../') 11 | from utils.misc_utils import show, fft_coef, grad, div 12 | from types import SimpleNamespace 13 | import random 14 | import h5py 15 | import time 16 | 17 | def rbf(x, y, p, sigma=1/64, spacing=2/64, ng=256, center=(0.5,0.5)): 18 | """ create an rbf grid basis functions """ 19 | num = np.sqrt(ng) # num of centers in each direction 20 | l = (num - 1) * spacing # length of grid in each direction 21 | 22 | centers_x = np.arange(center[0]-l/2, center[0]+l/2+spacing, spacing) 23 | centers_y = np.arange(center[1]-l/2, center[1]+l/2+spacing, spacing) 24 | centers_x, centers_y = np.meshgrid(centers_x, centers_y) 25 | ratio = [] 26 | c = [] 27 | 28 | 29 | for cx, cy in zip(centers_x.flatten(), centers_y.flatten()): 30 | r = (x - cx)**2 + (y - cy)**2 31 | R = 2 * sigma**2 32 | ratio.append(r/R) 33 | c.append(1./(2*np.pi*sigma**2)) # normalizing factor 34 | 35 | source = 0*ratio[0] 36 | 37 | idx = 0 38 | for r, ci in zip(ratio, c): 39 | source += p[idx]*1*np.exp(-r) 40 | idx += 1 41 | 42 | source = source/np.max(source) 43 | dc = np.mean(source.flatten()) 44 | source = source - dc # remove integral of source 45 | 46 | return source 47 | 48 | def get_random_diffusion_tensor(): 49 | ''' create a random diff tensor by controlling eigenvalues ''' 50 | e1 = 1 51 | e2 = 5 52 | a1 = 1 53 | a4 = e1 + np.random.rand() * (e2 - e1) # random btw eigvals e1 and e2 54 | 55 | theta = np.random.rand() * 2 * np.pi # random rotation 56 | rot = np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]]) 57 | theta_neg = -1 * theta 58 | rot_neg = np.array([[np.cos(theta_neg),-np.sin(theta_neg)],[np.sin(theta_neg),np.cos(theta_neg)]]) 59 | A = np.array([[a1, 0], [0, a4]]) 60 | A = rot_neg @ 
(A @ rot) # similarity transf to preserve eigs 61 | 62 | return A 63 | 64 | def get_random_velocity_vector(): 65 | ''' create a random velocity vector by sampling on a circle ''' 66 | theta = np.random.rand() * 2 * np.pi 67 | r = 1 68 | return [r*np.cos(theta), r*np.sin(theta)] 69 | 70 | def advection_op(u, v, nx, ny, params): 71 | ''' does v.grad(u) ''' 72 | u = u.reshape((nx, ny)) 73 | gradux, graduy = grad(u, params) 74 | return (v['v1']*gradux + v['v2']*graduy) 75 | 76 | def diffusion_op(u, k, K, nx, ny, params): 77 | ''' does div(Kgradu) ''' 78 | u = u.reshape((nx, ny)) 79 | gradux, graduy = grad(u, params) 80 | # diff tensor 81 | Kux = K['k11']*gradux + K['k12']*graduy 82 | Kuy = K['k22']*graduy + K['k12']*gradux 83 | # heterogeneous 84 | Kux *= k 85 | Kuy *= k 86 | return div(Kux, Kuy, params) 87 | 88 | def advdiff(x, y, std, space, vf, params): 89 | """ -v.\gradu + \lapu = -f """ 90 | k_mat = get_random_diffusion_tensor() 91 | K = {'k11': k_mat[0,0], 'k22': k_mat[1,1], 'k12': k_mat[0,1]} 92 | vel = get_random_velocity_vector() 93 | v = {'v1': vel[0], 'v2': vel[1]} # just for readability make a dict 94 | 95 | ad1 = params.ad1 96 | ad2 = params.ad2 97 | adr = ad1 + np.random.rand() * (ad2 - ad1) 98 | # sample lamda based on adr 99 | lams = np.load("utils/lambda.npy") 100 | ads = np.load("utils/ads.npy") 101 | means = np.mean(ads, axis=1) 102 | lam = lams[closest(means, adr)] 103 | 104 | params.lam = lam 105 | 106 | x_g, y_g = np.meshgrid(x, y) 107 | 108 | # create a source function 109 | ng = params.ng 110 | p = np.zeros(ng) 111 | 112 | min_act = 1E-3 # avoid zeros 113 | max_act = 1 114 | 115 | for i in range(ng): 116 | alpha = np.random.rand() 117 | if alpha > vf: 118 | p[i] = (max_act - min_act) * np.random.rand() + min_act 119 | 120 | all_zeros = not np.any(p) 121 | if all_zeros: 122 | randidx = np.random.randint(0, len(p)) 123 | p[randidx] = (max_act - min_act) * np.random.rand() + min_act 124 | 125 | 126 | source = rbf(x_g, y_g, p, ng=ng, sigma=std, 
spacing=space) # linear combination of rbfs 127 | 128 | nx = x.shape[0] 129 | ny = y.shape[0] 130 | ikx = fft_coef(nx).reshape(1,nx) 131 | ikx = np.repeat(ikx, ny, axis=0) 132 | iky = fft_coef(ny).reshape(ny,1) 133 | iky = np.repeat(iky, nx, axis=1) 134 | ikx2 = ikx**2 135 | iky2 = iky**2 136 | 137 | 138 | f_hat = np.fft.fft2(source) # RHS 139 | 140 | diff_factor = ikx2*K['k11'] + iky2*K['k22'] + 2*ikx*iky*K['k12'] 141 | diff_factor *= (4.0 * np.pi**2) / (params.Lx * params.Ly) # not a 2pi domain, but unit cube 142 | adv_factor = v['v1']*ikx + v['v2']*iky 143 | adv_factor *= (2.0 * np.pi) / (params.Lx) # implicity assumed Lx = Ly; TODO: be careful 144 | factor = (1 - params.lam) * params.diff_coef_scale * diff_factor - params.lam * params.adv_coef_scale * adv_factor 145 | 146 | factor = np.where(factor == 0, 0, -1/factor) # zeroth mode; set to zero 147 | u_hat = factor * f_hat 148 | 149 | u = np.real(np.fft.ifft2(u_hat)) 150 | u = u - np.mean(u.flatten()) # remove the dc component 151 | 152 | au = (params.lam) * params.adv_coef_scale * advection_op(u, v, nx, ny, params) 153 | du = (1 - params.lam) * params.diff_coef_scale * diffusion_op(u, 1, K, nx, ny, params) 154 | au = np.linalg.norm(au) 155 | du = np.linalg.norm(du) 156 | ratio = au/du 157 | 158 | 159 | k = np.array([K['k11'], K['k22'], K['k12']]) 160 | v = np.array([v['v1'], v['v2']]) # for dataset creating and interfacing with training code 161 | 162 | return u, source, k, v, ratio 163 | 164 | 165 | def create_hdf5(path, dat, ten, other): 166 | with h5py.File(path, "a") as f: 167 | try: 168 | f.create_dataset('fields', dat.shape, dtype=' (batch, out_channel, x) 27 | return torch.einsum("bix,iox->box", a, b) 28 | 29 | 30 | def compl_mul2d(a, b): 31 | # (batch, in_channel, x,y,t ), (in_channel, out_channel, x,y,t) -> (batch, out_channel, x,y,t) 32 | return torch.einsum("bixy,ioxy->boxy", a, b) 33 | 34 | 35 | def compl_mul3d(a, b): 36 | return torch.einsum("bixyz,ioxyz->boxyz", a, b) 37 | 38 | @torch.jit.script 
def compl_mul2d_v2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Complex channel mixing on real-view spectra.

    ``a`` is (batch, in, x, y, 2) and ``b`` is (in, out, x, y, 2), with the
    trailing axis holding (real, imag) as produced by ``torch.view_as_real``.
    Returns the complex product in the same layout, (batch, out, x, y, 2).
    """
    # tmp[s, t] = sum_i a[..., s] * b[..., t] for every real/imag pair (s, t)
    tmp = torch.einsum("bixys,ioxyt->stboxy", a, b)
    # (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i*(ai*br + ar*bi)
    return torch.stack([tmp[0,0,:,:,:,:] - tmp[1,1,:,:,:,:], tmp[1,0,:,:,:,:] + tmp[0,1,:,:,:,:]], dim=-1)

################################################################
# 1d fourier layer
################################################################


class SpectralConv1d(nn.Module):
    """1D Fourier layer: FFT, linear transform of the lowest modes, inverse FFT."""

    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1

        self.scale = (1 / (in_channels*out_channels))
        # Complex weights (in, out, modes1): compl_mul1d contracts "bix,iox->box"
        # directly against the complex spectrum, so a real (..., 2) layout
        # would not match.
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, in_channels, N); returns (batch, out_channels, N)."""
        batchsize = x.shape[0]
        n = x.size(-1)
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x, dim=2)

        # Multiply relevant Fourier modes; the output spectrum has out_channels channels
        out_ft = torch.zeros(batchsize, self.out_channels, n // 2 + 1, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes1] = compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)

        # Return to physical space (irfft takes the signal length as `n`, not `s`)
        return torch.fft.irfft(out_ft, n=n, dim=2)

################################################################
# 2d fourier layer
################################################################


class SpectralConv2d(nn.Module):
    """2D Fourier layer, optionally evaluated at arbitrary query points.

    Only the lowest ``modes1`` frequencies (positive and negative) along
    dim -2 and the lowest ``modes2`` along dim -1 are mixed; the rest of
    the spectrum is zeroed.
    """

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes1 = modes1
        self.modes2 = modes2

        self.scale = (1 / (in_channels * out_channels))
        # weights1 acts on the leading (positive) modes along dim -2, weights2 on the trailing (negative) ones
        self.weights1 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x, gridy=None):
        """Apply the layer to ``x`` of shape (batch, in_channels, X, Y).

        With ``gridy=None`` the result lives on the input grid; otherwise it
        is evaluated at the query locations ``gridy`` through :meth:`ifft2d`.
        """
        batchsize = x.shape[0]
        size1 = x.shape[-2]
        size2 = x.shape[-1]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[2, 3])

        if gridy is None:
            # Multiply relevant Fourier modes
            out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, device=x.device,
                                 dtype=torch.cfloat)
            out_ft[:, :, :self.modes1, :self.modes2] = \
                compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
            out_ft[:, :, -self.modes1:, :self.modes2] = \
                compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)

            # Return to physical space
            x = torch.fft.irfftn(out_ft, s=(x.size(-2), x.size(-1)), dim=[2, 3])
        else:
            factor1 = compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
            factor2 = compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)
            # rfftn is unnormalized, so divide by the number of grid points here
            x = self.ifft2d(gridy, factor1, factor2, self.modes1, self.modes2) / (size1 * size2)
        return x

    def ifft2d(self, gridy, coeff1, coeff2, k1, k2):
        """Evaluate the truncated inverse Fourier series at arbitrary points.

        gridy: (batch, N, 2) locations in [0,1]*[0,1].
        coeff1, coeff2: (batch, channels, k1, k2) coefficients for the positive
        and negative modes along the first axis.
        Returns the real part, (batch, channels, N).
        """
        batchsize = gridy.shape[0]
        N = gridy.shape[1]
        device = gridy.device
        m1 = 2 * k1
        m2 = 2 * k2 - 1

        # wavenumber (m1, m2)
        k_x1 = torch.cat((torch.arange(start=0, end=k1, step=1), \
                          torch.arange(start=-(k1), end=0, step=1)), 0).reshape(m1,1).repeat(1,m2).to(device)
        k_x2 = torch.cat((torch.arange(start=0, end=k2, step=1), \
                          torch.arange(start=-(k2-1), end=0, step=1)), 0).reshape(1,m2).repeat(m1,1).to(device)

        # K = <y, k> (location . wavenumber), (batch, N, m1, m2)
        K1 = torch.outer(gridy[:,:,0].view(-1), k_x1.view(-1)).reshape(batchsize, N, m1, m2)
        K2 = torch.outer(gridy[:,:,1].view(-1), k_x2.view(-1)).reshape(batchsize, N, m1, m2)
        K = K1 + K2

        # basis (batch, N, m1, m2)
        basis = torch.exp( 1j * 2* np.pi * K).to(device)

        # Complete the half spectrum with conjugate-symmetric copies -> coeff (batch, channels, m1, m2)
        coeff3 = coeff1[:,:,1:,1:].flip(-1, -2).conj()
        coeff4 = torch.cat([coeff1[:,:,0:1,1:].flip(-1).conj(), coeff2[:,:,:,1:].flip(-1, -2).conj()], dim=-2)
        coeff12 = torch.cat([coeff1, coeff2], dim=-2)
        coeff43 = torch.cat([coeff4, coeff3], dim=-2)
        coeff = torch.cat([coeff12, coeff43], dim=-1)

        # Y (batch, channels, N)
        Y = torch.einsum("bcxy,bnxy->bcn", coeff, basis)
        Y = Y.real
        return Y


class SpectralConv2dV2(nn.Module):
    """2D Fourier layer operating on real-view spectra (ortho-normalized FFTs)."""

    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2dV2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.scale = (1 / (in_channels * out_channels))
        # real/imag stored in a trailing axis of size 2, consumed by compl_mul2d_v2
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, 2))

    def forward(self, x: torch.Tensor):
        """Apply the layer to ``x`` of shape (batch, in_channels, X, Y); output keeps ``x``'s dtype."""
        size_0 = x.size(-2)
        size_1 = x.size(-1)
        batchsize = x.shape[0]
        in_dtype = x.dtype

        # FFTs run in float32 regardless of the input dtype
        x_ft = torch.fft.rfft2(x.float(), dim=(-2,-1), norm='ortho')
        x_ft = torch.view_as_real(x_ft)

        out_ft = torch.zeros(batchsize, self.out_channels, size_0, size_1//2 + 1, 2, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = \
            compl_mul2d_v2(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = \
            compl_mul2d_v2(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)
        out_ft = torch.view_as_complex(out_ft)

        #Return to physical space
        x = torch.fft.irfft2(out_ft, dim=(-2,-1), norm='ortho', s=(size_0, size_1)).to(in_dtype)

        return x


class SpectralConv3d(nn.Module):
    """3D Fourier layer mixing the four low-frequency corner blocks of the rfft spectrum."""

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3):
        super(SpectralConv3d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1 #Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.modes2 = modes2
        self.modes3 = modes3

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights3 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))
        self.weights4 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, self.modes3, dtype=torch.cfloat))

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, in_channels, X, Y, T)."""
        batchsize = x.shape[0]
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.fft.rfftn(x, dim=[2,3,4])
        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(2), x.size(3), x.size(4)//2 + 1, device=x.device,
                             dtype=torch.cfloat)
        out_ft[:, :, :self.modes1, :self.modes2, :self.modes3] = \
            compl_mul3d(x_ft[:, :, :self.modes1, :self.modes2, :self.modes3], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2, :self.modes3] = \
            compl_mul3d(x_ft[:, :, -self.modes1:, :self.modes2, :self.modes3], self.weights2)
        out_ft[:, :, :self.modes1, -self.modes2:, :self.modes3] = \
            compl_mul3d(x_ft[:, :, :self.modes1, -self.modes2:, :self.modes3], self.weights3)
        out_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3] = \
            compl_mul3d(x_ft[:, :, -self.modes1:, -self.modes2:, :self.modes3], self.weights4)

        #Return to physical space
        x = torch.fft.irfftn(out_ft, s=(x.size(2), x.size(3), x.size(4)), dim=[2,3,4])
        return x


class FourierBlock(nn.Module):
    """SpectralConv3d plus a pointwise (1x1) linear path, followed by an optional activation."""

    def __init__(self, in_channels, out_channels, modes1, modes2, modes3, activation='tanh'):
        super(FourierBlock, self).__init__()
        self.in_channel = in_channels
        self.out_channel = out_channels
        self.speconv = SpectralConv3d(in_channels, out_channels, modes1, modes2, modes3)
        self.linear = nn.Conv1d(in_channels, out_channels, 1)
        if activation == 'tanh':
            self.activation = torch.tanh_
        elif activation == 'gelu':
            # a function, not the nn.GELU class: calling the class would build a module, not activate
            self.activation = nn.functional.gelu
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError(f'{activation} is not supported')

    def forward(self, x):
        '''
        input x: (batchsize, channel width, x_grid, y_grid, t_grid)
        '''
        x1 = self.speconv(x)
        x2 = self.linear(x.view(x.shape[0], self.in_channel, -1))
        out = x1 + x2.view(x.shape[0], self.out_channel, x.shape[2], x.shape[3], x.shape[4])
        if self.activation is not None:
            out = self.activation(out)
        return out
# --------------------------------------------------------------------------------
# /utils/trainer.py:
# --------------------------------------------------------------------------------
import os, sys, time
import numpy as np
import argparse
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import wandb
import matplotlib.pyplot as plt
from datetime import datetime
import logging
from utils import logging_utils
logging_utils.config_logger()
from utils.YParams import YParams
from utils.data_utils import get_data_loader
from utils.optimizer_utils import set_scheduler, set_optimizer
from utils.loss_utils import LossMSE
from utils.misc_utils import compute_grad_norm, vis_fields, l2_err
from utils.domains import DomainXY
from utils.sweeps import sweep_name_suffix
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
from collections import OrderedDict

# models
import models.ffn
import models.fno


def print_mem():
    """Print CUDA memory statistics (in GB) for device 0."""
    print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
    print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
    print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))


def set_seed(params, world_size):
    """Seed python, numpy and torch RNGs with ``params.seed`` (a random seed if it is None)."""
    seed = params.seed
    if seed is None:
        seed = np.random.randint(10000)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if world_size > 0:
        torch.cuda.manual_seed_all(seed)


def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions."""
    params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params/1000000


class Trainer():
    """Train a neural operator, optionally with DDP and/or as a W&B sweep trial.

    ``params`` is the experiment config object (read and updated in place);
    ``args`` supplies ``sweep_id``, ``root_dir``, ``config`` and ``run_num``.
    """

    def __init__(self, params, args):
        self.sweep_id = args.sweep_id
        self.root_dir = args.root_dir
        self.config = args.config
        self.run_num = args.run_num
        self.world_size = 1
        if 'WORLD_SIZE' in os.environ:
            self.world_size = int(os.environ['WORLD_SIZE'])

        self.local_rank = 0
        self.world_rank = 0
        if self.world_size > 1:
            dist.init_process_group(backend='nccl',
                                    init_method='env://')
            self.world_rank = dist.get_rank()
            self.local_rank = int(os.environ["LOCAL_RANK"])

        if torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
            torch.backends.cudnn.benchmark = True

        # only rank 0 logs to screen / W&B
        self.log_to_screen = params.log_to_screen and self.world_rank==0
        self.log_to_wandb = params.log_to_wandb and self.world_rank==0
        params['name'] = args.config + '_' + args.run_num
        params['group'] = 'op_' + args.config
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
        else:
            self.device = torch.device('cpu')
        self.params = params
        self.params.device = self.device

    def init_exp_dir(self, exp_dir):
        """Create the experiment directory tree (rank 0) and record its paths in params."""
        if self.world_rank==0:
            # exist_ok so a partially created tree (e.g. from an interrupted launch) gets completed
            os.makedirs(exp_dir, exist_ok=True)
            os.makedirs(os.path.join(exp_dir, 'checkpoints/'), exist_ok=True)
            os.makedirs(os.path.join(exp_dir, 'wandb/'), exist_ok=True)
        self.params['experiment_dir'] = os.path.abspath(exp_dir)
        self.params['checkpoint_path'] = os.path.join(exp_dir, 'checkpoints/ckpt.tar')
        self.params['resuming'] = os.path.isfile(self.params.checkpoint_path)

    def launch(self):
        """Entry point: set up the run (sweep trial or plain experiment) and train."""
        if self.sweep_id:
            if self.world_rank==0:
                with wandb.init() as run:
                    hpo_config = wandb.config
                    self.params.update_params(hpo_config)
                    self.modify_bs_for_subsampling()
                    logging.info(self.params.name+'_'+sweep_name_suffix(self.params, self.sweep_id))
                    run.name = self.params.name+'_'+sweep_name_suffix(self.params, self.sweep_id)
                    self.name = run.name
                    self.params.name = self.name
                    exp_dir = os.path.join(*[self.root_dir, 'sweeps', self.sweep_id, self.name])
                    self.init_exp_dir(exp_dir)
                    logging.info('HPO sweep %s, trial cfg %s'%(self.sweep_id, self.name))
                    self.build_and_run()
            else:
                # non-zero ranks receive the sweep config by broadcast inside build_and_run
                self.build_and_run()

        else:
            self.modify_bs_for_subsampling()
            exp_dir = os.path.join(*[self.root_dir, 'expts', self.config, self.run_num])
            self.init_exp_dir(exp_dir)
            if self.log_to_wandb:
                wandb.init(dir=os.path.join(exp_dir, "wandb"),
                           config=self.params.params, name=self.params.name, group=self.params.group, project=self.params.project,
                           entity=self.params.entity, resume=self.params.resuming)
            self.build_and_run()

    def build_and_run(self):
        """Build data loaders, model, optimizer, scheduler and loss, restore state, then train."""
        if self.sweep_id and dist.is_initialized():
            # Broadcast sweep config to other ranks
            from mpi4py import MPI
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()
            assert self.world_rank == rank
            if rank != 0:
                self.params = None
            self.params = comm.bcast(self.params, root=0)
            self.params.device = self.device # dont broadcast 0s device

        if self.world_rank == 0:
            logging.info(self.params.log())

        set_seed(self.params, self.world_size)

        self.params['global_batch_size'] = self.params.batch_size
        self.params['local_batch_size'] = int(self.params.batch_size//self.world_size)
        self.params['global_valid_batch_size'] = self.params.valid_batch_size
        self.params['local_valid_batch_size'] = int(self.params.valid_batch_size//self.world_size)

        # dump the yaml used
        if self.world_rank == 0:
            hparams = ruamelDict()
            yaml = YAML()
            for key, value in self.params.params.items():
                hparams[str(key)] = str(value)
            with open(os.path.join(self.params['experiment_dir'], 'hyperparams.yaml'), 'w') as hpfile:
                yaml.dump(hparams, hpfile)

        self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(self.params, self.params.train_path, dist.is_initialized(), train=True, pack=self.params.pack_data)
        self.val_data_loader, self.val_dataset, self.valid_sampler = get_data_loader(self.params, self.params.val_path, dist.is_initialized(), train=False, pack=self.params.pack_data)

        # domain grid
        self.domain = DomainXY(self.params)

        if self.params.model == 'fno':
            self.model = models.fno.fno(self.params).to(self.device)
        else:
            raise ValueError("Error, model arch invalid.")

        if dist.is_initialized():
            # output_device takes a single device index, not a list
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[self.local_rank],
                                                 output_device=self.local_rank)

        self.optimizer = set_optimizer(self.params, self.model)

        self.scheduler = set_scheduler(self.params, self.optimizer)

        if self.params.loss_func == "mse":
            self.loss_func = LossMSE(self.params, self.model)
        else:
            raise ValueError("Error, loss func invalid.")

        self.iters = 0
        self.startEpoch = 0

        # initial weights take precedence over resuming from the run's own checkpoint
        if hasattr(self.params, 'weights'):
            self.params.resuming = False
            logging.info("Loading IC weights %s"%self.params.weights)
            self.load_model(self.params.weights)

        if self.params.resuming:
            logging.info("Loading checkpoint %s"%self.params.checkpoint_path)
            self.restore_checkpoint(self.params.checkpoint_path)

        self.epoch = self.startEpoch
        self.logs = {}
        self.train_loss = self.data_loss = self.bc_loss = self.pde_loss = self.grad = 0.0
        n_params = count_parameters(self.model)
        if self.log_to_screen:
            logging.info(self.model)
            logging.info('number of model parameters: {}'.format(n_params))

        # launch training
        self.train()

    def train(self):
        """Run the epoch loop: train, validate, step the scheduler, checkpoint and log."""
        if self.log_to_screen:
            logging.info("Starting training loop...")
        best_loss = np.inf

        best_epoch = 0
        best_err = 1
        self.logs['best_epoch'] = best_epoch
        plot_figs = self.params.plot_figs

        for epoch in range(self.startEpoch, self.params.max_epochs):
            self.epoch = epoch
            if dist.is_initialized():
                # shuffles data before every epoch
                self.train_sampler.set_epoch(epoch)
            start = time.time()

            # train
            tr_time = self.train_one_epoch()
            val_time, fields = self.val_one_epoch()
            self.logs['wt_norm'] = self.get_model_wt_norm(self.model)

            if self.params.scheduler == 'reducelr':
                self.scheduler.step(self.logs['train_loss'])
            elif self.params.scheduler == 'cosine':
                self.scheduler.step()

            is_best_loss = self.logs['val_loss'] <= best_loss
            if is_best_loss:
                best_loss = self.logs['val_loss']
                best_err = self.logs['val_err']
                best_epoch = self.epoch
            self.logs['best_val_loss'] = best_loss
            self.logs['best_val_err'] = best_err
            self.logs['best_epoch'] = best_epoch

            if self.params.save_checkpoint and self.world_rank == 0:
                #checkpoint at the end of every epoch
                if is_best_loss:
                    self.save_logs(tag="_best")
                self.save_checkpoint(self.params.checkpoint_path, is_best=is_best_loss)

            if self.log_to_wandb:
                # log visualizations every epoch
                if plot_figs:
                    fig = vis_fields(fields, self.params, self.domain)
                    self.logs['vis'] = wandb.Image(fig)
                    plt.close(fig)
                self.logs['learning_rate'] = self.optimizer.param_groups[0]['lr']
                self.logs['time_per_epoch'] = tr_time
                wandb.log(self.logs, step=self.epoch+1)

            if self.log_to_screen:
                logging.info('Time taken for epoch {} is {} sec; with {}/{} in tr/val'.format(self.epoch+1, time.time()-start, tr_time, val_time))
                logging.info('Loss (total = data + bc + pde) {} = {} + {} + {}'.format(self.logs['train_loss'], self.logs['data_loss'],
                                                                                     self.logs['bc_loss'], self.logs['pde_loss']))

        if self.log_to_wandb:
            wandb.finish()

    def get_model_wt_norm(self, model):
        """Return the global L2 norm of all parameters of ``model``."""
        n = 0
        for p in model.parameters():
            p_norm = p.data.detach().norm(2)
            n += p_norm.item()**2
        n = n**0.5
        return n

    def save_logs(self, tag=""):
        """Write the current epoch and ``self.logs`` as ``key,value`` lines to logs<tag>.txt."""
        with open(os.path.join(self.params.experiment_dir, "logs"+tag+".txt"), "w") as f:
            f.write("epoch,{}\n".format(self.epoch))
            for k, v in self.logs.items():
                f.write("{},{}\n".format(k,v))

    def _reduce_logs(self, keys):
        """Average ``self.logs[key]`` over ranks (when distributed) and convert to float.

        Converting in the single-process case too keeps the logged/saved values
        plain numbers rather than device tensors.
        """
        for key in keys:
            if dist.is_initialized():
                # all_reduce works in place on the shared buffer behind the view
                dist.all_reduce(self.logs[key].detach())
                self.logs[key] = float(self.logs[key]/dist.get_world_size())
            else:
                self.logs[key] = float(self.logs[key])

    def train_one_epoch(self):
        """Train for one epoch; fills per-batch-averaged losses in ``self.logs``.

        Returns the accumulated forward/backward/step time in seconds.
        """
        tr_time = 0
        self.model.train()

        # buffers for logs
        logs_buff = torch.zeros((6), dtype=torch.float32, device=self.device)
        self.logs['train_loss'] = logs_buff[0].view(-1)
        self.logs['data_loss'] = logs_buff[1].view(-1)
        self.logs['bc_loss'] = logs_buff[2].view(-1)
        self.logs['pde_loss'] = logs_buff[3].view(-1)
        self.logs['grad'] = logs_buff[4].view(-1)
        self.logs['tr_err'] = logs_buff[5].view(-1)

        for inputs, targets in self.train_data_loader:
            self.iters += 1
            if not self.params.pack_data: # send to gpu if not already packed in the dataloader
                inputs, targets = inputs.to(self.device), targets.to(self.device)
            tr_start = time.time()

            self.model.zero_grad()
            u = self.model(inputs)

            loss_data = self.loss_func.data(inputs, u, targets)
            loss_pde = self.loss_func.pde(inputs, u, targets)
            loss_bc = self.loss_func.bc(inputs, u, targets)
            loss = loss_data + loss_bc + loss_pde

            loss.backward()
            self.optimizer.step()

            grad_norm = compute_grad_norm(self.model.parameters())
            tr_err = l2_err(u.detach(), targets.detach())

            # add all the minibatch losses
            self.logs['train_loss'] += loss.detach()
            self.logs['data_loss'] += loss_data.detach()
            self.logs['bc_loss'] += loss_bc.detach()
            self.logs['pde_loss'] += loss_pde.detach()
            self.logs['grad'] += grad_norm
            self.logs['tr_err'] += tr_err

            tr_time += time.time() - tr_start

        logs_to_reduce = ['train_loss', 'data_loss', 'bc_loss', 'pde_loss', 'grad', 'tr_err']
        n_batches = len(self.train_data_loader)
        for key in logs_to_reduce:
            self.logs[key] /= n_batches

        # todo change loss to unscaled
        self._reduce_logs(logs_to_reduce)

        return tr_time

    def val_one_epoch(self):
        """Evaluate on the validation set without gradients.

        Returns ``(val_time, fields)`` where ``fields`` is
        ``[source, soln, pred, pde_res, temp]`` for one randomly chosen sample.
        """
        self.model.eval()
        val_start = time.time()

        logs_buff = torch.zeros((2), dtype=torch.float32, device=self.device)
        self.logs['val_err'] = logs_buff[0].view(-1)
        self.logs['val_loss'] = logs_buff[1].view(-1)
        # random (batch, sample) pair to visualize
        idx = np.random.randint(0, len(self.val_data_loader))
        img_idx = np.random.randint(0, self.params.local_valid_batch_size)
        fields = None
        with torch.no_grad():
            for i, (inputs, targets) in enumerate(self.val_data_loader):
                if not self.params.pack_data:
                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                u = self.model(inputs)
                loss_data = self.loss_func.data(inputs, u, targets)
                loss_pde = self.loss_func.pde(inputs, u, targets)
                loss_bc = self.loss_func.bc(inputs, u, targets)
                loss = loss_data + loss_bc + loss_pde
                self.logs['val_err'] += l2_err(u.detach(), targets.detach())
                self.logs['val_loss'] += loss.detach()
                if i == idx:
                    # the final batch may be smaller than local_valid_batch_size
                    k = min(img_idx, inputs.shape[0] - 1)
                    source = inputs[k,0].detach().cpu().numpy()
                    soln = targets[k,0].detach().cpu().numpy()
                    pred = u[k,0].detach().cpu().numpy()
                    pde_res = 0*pred
                    temp = 0*pred

                    fields = [source, soln, pred, pde_res, temp]

        self.logs['val_loss'] /= len(self.val_data_loader)
        self.logs['val_err'] /= len(self.val_data_loader)
        self._reduce_logs(['val_loss', 'val_err'])

        val_time = time.time() - val_start

        return val_time, fields

    def save_checkpoint(self, checkpoint_path, is_best=False, model=None):
        """Save model/optimizer/scheduler state; also write a ``_best`` copy when ``is_best``."""
        if model is None:
            model = self.model
        state = {'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(),
                 'optimizer_state_dict': self.optimizer.state_dict(),
                 'scheduler_state_dict': (self.scheduler.state_dict() if self.scheduler is not None else None)}
        torch.save(state, checkpoint_path)
        if is_best:
            torch.save(state, checkpoint_path.replace('.tar', '_best.tar'))

    def _map_location(self):
        """torch.load target: this rank's GPU, or the CPU when CUDA is unavailable."""
        if torch.cuda.is_available():
            return 'cuda:{}'.format(self.local_rank)
        return 'cpu'

    def _load_model_state(self, state_dict):
        """Load ``state_dict`` into ``self.model``.

        On a key mismatch, retry with any DDP ``module.`` prefix stripped so
        checkpoints saved from a DDP-wrapped model load into a plain one.
        """
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError:
            prefix = 'module.'
            new_state_dict = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, val)
                for key, val in state_dict.items())
            self.model.load_state_dict(new_state_dict)

    def restore_checkpoint(self, checkpoint_path):
        """Resume model, optimizer, scheduler and iteration/epoch counters from a checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location=self._map_location())
        self._load_model_state(checkpoint['model_state'])

        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch'] + 1
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    def load_model(self, checkpoint_path):
        """Load only the model weights from a checkpoint (e.g. pretrained initial weights)."""
        checkpoint = torch.load(checkpoint_path, map_location=self._map_location())
        self._load_model_state(checkpoint['model_state'])

    def switch_off_grad(self, model):
        """Freeze every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def modify_bs_for_subsampling(self):
        '''Reduce batchsize for very small datasets'''
        sz = self.params.subsample
        if sz >= 512:
            # halve the batch size per doubling of subsample beyond 256: 512 -> 64, 1024 -> 32, ...
            fac = np.log2(sz) - 8
            self.params.batch_size = int(128/2**fac)
# --------------------------------------------------------------------------------