├── .gitignore ├── README.md ├── YParams.py ├── baseline_utlis ├── __init__.py ├── rigid_neighbor.py └── test_rigid_neightbour.py ├── config ├── RB_config.yaml └── ssl_ns_elastic.yaml ├── data_utils ├── __init__.py ├── data_loaders.py └── data_utils.py ├── exps_FSI.sh ├── exps_RB.sh ├── fsi_animation_dx.gif ├── fsi_animation_pressue.gif ├── layers ├── __init__.py ├── codano_block_2D.py ├── codano_block_nd.py ├── fino_2D.py ├── fino_nd.py ├── gnn_layer.py ├── gno_layer.py ├── regrider.py ├── regular_transformer.py ├── unet3d.py ├── unet_sublayer.py └── variable_encoding.py ├── main.py ├── models ├── __init__.py ├── codano.py ├── codano_gino.py ├── deeponet.py ├── fno_gino.py ├── get_models.py ├── gnn.py ├── model_helpers.py ├── unet.py └── vit.py ├── requirements.txt ├── test ├── __init__.py └── evaluations.py ├── train ├── __init__.py ├── new_adam.py └── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | wandb/* 3 | *__pycache__/* 4 | */__pycache__/* 5 | *.out 6 | .ipynb_cache/* 7 | */.ipynb_cache/* 8 | .ipynb_checkpoints/* 9 | */.ipynb_checkpoints/* 10 | *key.txt 11 | 12 | # Virtual environment(s): 13 | .venv/* 14 | weights/* 15 | config/wandb_api_key.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pretraining Codomain Attention Neural Operators for Solving Multiphysics PDEs 2 | 3 | > [Paper Link](https://arxiv.org/pdf/2403.12553.pdf) 4 | 5 | > **🚀🚀 HOW TO USE THE CoDA-NO MODEL WITH `neuraloperator`** [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1W6Qy5Mk_vEjZgrA0tWMespXqKEYDOdc6?usp=sharing) 6 | 7 | ## Model Architecture 8 |

9 | <!-- model architecture figure --> 10 | 11 | *Architecture of the Codomain Attention Neural Operator* 12 |

13 | Each physical variable (or co-domain) of the input function is concatenated with a variable-specific positional encoding (VSPE). Each variable, along with its VSPE, is passed through a GNO layer, which maps from the given non-uniform geometry to a latent regular grid. Then, the output on a uniform grid 14 | is passed through a series of CoDA-NO layers. Lastly, the output of the stacked CoDA-NO layers is mapped onto the domain of the 15 | output geometry for each query point using another GNO layer. 16 | 17 | At each CoDA-NO layer, the input function is tokenized codomain-wise to generate token functions. Each token function is passed through the K, Q, and V operators to get key, query, and value functions. The output function is calculated by extending the self-attention mechanism to function space. 18 | 19 | 20 | ## Navier-Stokes + Elastic Wave and Navier-Stokes Datasets 21 | 22 | The fluid-solid interaction dataset is available on [HuggingFace](https://huggingface.co/datasets/ashiq24/FSI-pde-dataset). To download it, please use the following code: 23 | ```python 24 | from huggingface_hub import snapshot_download 25 | 26 | folder_path = snapshot_download( 27 | repo_id="ashiq24/FSI-pde-dataset", 28 | repo_type="dataset", 29 | allow_patterns=["fsi-data/*"] 30 | ) 31 | ``` 32 | ### Dataset Structure 33 | 34 | **Displacement Field** 35 | ![Animation](https://github.com/neuraloperator/CoDA-NO/blob/main/fsi_animation_dx.gif?raw=true) 36 | 37 | **Fluid-Structure Interaction (NS + Elastic Wave)** 38 | The `fsi-data` folder contains simulation data organized by the parameters `mu`, `x1`, and `x2`, where `mu` determines the viscosity and `x1` and `x2` are the parameters of the inlet condition. The dataset includes files for the mesh, displacement, velocity, and pressure. 39 | 40 | The dataset structure is detailed below: 41 | 42 | ```plaintext 43 | fsi-data/ 44 | ├── mesh.h5 # Initial mesh 45 | ├── mu=1.0/ # Simulation results for mu = 1.0 46 | │ ├── x1=-4/ # Inlet parameter x1 = -4 47 | │ │ ├── x2=-4/ # Inlet parameter x2 = -4 48 | │ │ │ └── visualization/ 49 | │ │ │ ├── displacement.h5 # Displacements for mu=1.0, x1=-4, x2=-4 50 | │ │ │ ├── velocity.h5 # Velocity field for mu=1.0, x1=-4, x2=-4 51 | │ │ │ └── pressure.h5 # Pressure field for mu=1.0, x1=-4, x2=-4 52 | │ │ ├── x2=-2/ 53 | │ │ │ └── visualization/ 54 | │ │ │ ├── displacement.h5 55 | │ │ │ ├── velocity.h5 56 | │ │ │ └── pressure.h5 57 | │ │ └── ... # Other x2 values for x1 = -4 58 | │ ├── x1=-2/ 59 | │ │ ├── x2=-4/ 60 | │ │ │ └── visualization/ 61 | │ │ │ ├── displacement.h5 62 | │ │ │ ├── velocity.h5 63 | │ │ │ └── pressure.h5 64 | │ │ └── ... # Other x2 values for x1 = -2 65 | │ └── ... # Other x1 values for mu = 1.0 66 | ├── mu=5.0/ # Simulation results for mu = 5.0 67 | │ └── ... # Similar structure as mu=1.0 68 | └── mu=10.0/ # Simulation results for mu = 10.0 69 | └── ... # Similar structure as mu=1.0 70 | ``` 71 | The dataset comes with a dataloader and visualization code. The `NsElasticDataset` class in `data_utils/data_loaders.py` automatically loads the data for all specified `mu` values and inlet conditions (`x1` and `x2`). 72 | 73 | **Fluid Motion with a Non-deformable Solid (NS)** is stored in `cfd-data`. 74 | 75 | ## Rayleigh–Bénard Convection 76 | Huggingface dataset link: [Rayleigh_Benard_Convection](https://huggingface.co/datasets/ashiq24/Rayleigh_Benard_Convection) 77 | 78 | 79 | 80 | 81 | ## Experiments 82 | 83 | > ⚠️ **Note:** This repository uses an older version of the `neuralop` library.
For a version compatible with the latest `neuralop` library, please refer to the implementations linked below: 84 | 85 | > **The codomain attention layer is now available through the `neuraloperator` library** ([implementation](https://github.com/neuraloperator/neuraloperator/blob/main/neuralop/layers/coda_layer.py)). 86 | 87 | > **The full model is also available through the `neuraloperator` library; see** [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1W6Qy5Mk_vEjZgrA0tWMespXqKEYDOdc6?usp=sharing) 88 | 89 | ### Installation 90 | The configurations for all the experiments are in `config/ssl_ns_elastic.yaml` (for fluid-structure interaction) and `config/RB_config.yaml` (for the Rayleigh–Bénard system). 91 | 92 | To set up the environment and install the dependencies, please run the following command: 93 | ```bash 94 | pip install -r requirements.txt 95 | ``` 96 | It requires `python=3.11.9`, and the `torch` installation needs to be tailored to your machine's specific CUDA version. The `torch_geometric` and `torch_scatter` installations must likewise match the local machine's CUDA version. See the [installation guide](https://pytorch-geometric.readthedocs.io/en/latest/) for details. 97 | 98 | **Shortcut:** If you already use the `neuraloperator` package, most of the dependencies are already installed. You then only need to execute the following line to roll back to a compatible version: 99 | 100 | ``` 101 | pip install -e git+https://github.com/ashiq24/neuraloperator.git@codano_rep#egg=neuraloperator 102 | ``` 103 | 104 | The CoDA-NO layers and models are now also available as part of the `neuraloperator` library (see the notes above). 105 | 106 | ### Running Experiments 107 | To run the experiments, download the datasets, update "input_mesh_location" and "data_location" in the config file, update the Wandb credentials, and execute the following command: 108 | 109 | ``` 110 | python main.py --exp (FSI/RB) --config "config name" --ntrain N 111 | ``` 112 | 113 | `--exp`: Which experiment to run, 'FSI' (fluid-structure interaction) or 'RB' (Rayleigh–Bénard) 114 | 115 | `--config`: Which configuration to use from the config file (`config/ssl_ns_elastic.yaml` or `config/RB_config.yaml`). 116 | 117 | `--ntrain`: Number of training data points.
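Each `--config` value names a top-level entry in the corresponding YAML file. These entries can be loaded with the `YParams` helper shipped in this repo (`YParams.py`). Below is a minimal sketch, assuming it is run from the repository root, using the `codano_RB` entry of `config/RB_config.yaml`:

```python
from YParams import YParams

# Load the `codano_RB` entry from the RB config file and print it.
params = YParams("config/RB_config.yaml", "codano_RB", print_params=True)

print(params.lr, params.batch_size)  # attribute-style access
print(params["scheduler_type"])      # dict-style access also works
```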
118 | 119 | ## Scripts 120 | For training the CoDA-NO architecture on the NS/NS+EW (FSI) and Rayleigh–Bénard convection datasets (covering both pre-training and fine-tuning), please execute the following scripts: 121 | ``` 122 | exps_FSI.sh 123 | exps_RB.sh 124 | ``` 125 | 126 | 127 | ## Reference 128 | If you find this paper and code useful in your research, please consider citing: 129 | ```bibtex 130 | @article{rahman2024pretraining, 131 | title={Pretraining Codomain Attention Neural Operators for Solving Multiphysics PDEs}, 132 | author={Rahman, Md Ashiqur and George, Robert Joseph and Elleithy, Mogab and Leibovici, Daniel and Li, Zongyi and Bonev, Boris and White, Colin and Berner, Julius and Yeh, Raymond A and Kossaifi, Jean and Azizzadenesheli, Kamyar and Anandkumar, Anima}, 133 | journal={Advances in Neural Information Processing Systems}, 134 | volume={37}, 135 | year={2024} 136 | } 137 | ``` -------------------------------------------------------------------------------- /YParams.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import pprint 4 | 5 | from ruamel.yaml import YAML 6 | 7 | 8 | class ParamsBase: 9 | """Convenience wrapper around a dictionary 10 | 11 | Allows referring to dictionary items as attributes, and tracking which 12 | attributes are modified. 13 | """ 14 | 15 | def __init__(self): 16 | self._original_attrs = None 17 | self.params = {} 18 | self._original_attrs = list(self.__dict__) 19 | 20 | def __getitem__(self, key): 21 | return self.params[key] 22 | 23 | def __setitem__(self, key, val): 24 | self.params[key] = val 25 | self.__setattr__(key, val) 26 | 27 | def __contains__(self, key): 28 | return key in self.params 29 | 30 | def get(self, key, default=None): 31 | if hasattr(self, key): 32 | return getattr(self, key) 33 | else: 34 | return self.params.get(key, default) 35 | 36 | def to_dict(self): 37 | new_attrs = { 38 | key: val for key, val in vars(self).items() 39 | if key not in self._original_attrs 40 | } 41 | return {**self.params, **new_attrs} 42 | 43 | @staticmethod 44 | def from_json(path: str) -> "ParamsBase": 45 | with open(path) as f: 46 | c = json.load(f) 47 | params = ParamsBase() 48 | params.update_params(c) 49 | return params 50 | 51 | def update_params(self, config): 52 | for key, val in config.items(): 53 | if val == 'None': 54 | val = None 55 | 56 | if isinstance(val, dict): 57 | child = ParamsBase() 58 | child.update_params(val) 59 | val = child 60 | 61 | self.params[key] = val 62 | self.__setattr__(key, val) 63 | 64 | 65 | class YParams(ParamsBase): 66 | def __init__(self, yaml_filename, config_name, print_params=False): 67 | """Open parameters stored with ``config_name`` in the yaml file ``yaml_filename``""" 68 | super().__init__() 69 | self._yaml_filename = yaml_filename 70 | self._config_name = config_name 71 | 72 | with open(yaml_filename) as _file: 73 | d = YAML().load(_file)[config_name] 74 | 75 | self.update_params(d) 76 | 77 | if print_params: 78 | print("------------------ Configuration ------------------") 79 | for k, v in d.items(): 80 | print(k, end='=') 81 | pprint.pprint(v) 82 | print("---------------------------------------------------") 83 | 84 | def log(self): 85 | logging.info("------------------ Configuration ------------------") 86 | logging.info("Configuration file: " + str(self._yaml_filename)) 87 | logging.info("Configuration name: " + str(self._config_name)) 88 | for key, val in self.to_dict().items(): 89 | logging.info(str(key) + ' ' + str(val)) 90 |
logging.info("---------------------------------------------------") -------------------------------------------------------------------------------- /baseline_utlis/__init__.py: -------------------------------------------------------------------------------- 1 | from .rigid_neighbor import * 2 | -------------------------------------------------------------------------------- /baseline_utlis/rigid_neighbor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def simple_neighbor_search( 6 | data: torch.Tensor, 7 | queries: torch.Tensor, 8 | n_neigbor: float): 9 | """ 10 | 11 | Parameters 12 | ---------- 13 | Density-Based Spatial Clustering of Applications with Noise 14 | data : torch.Tensor 15 | vector of data points from which to find neighbors 16 | queries : torch.Tensor 17 | centers of neighborhoods 18 | 19 | """ 20 | 21 | # shaped num query points x num data points 22 | dists = torch.cdist(queries, data).to(queries.device) 23 | sorted_dist, _ = torch.sort(dists, dim=1) 24 | k = sorted_dist[:, n_neigbor] 25 | dists = dists - k[:, None] 26 | in_nbr = torch.where(dists < 0, 1., 0.) # i,j is one if j is i's neighbor 27 | # only keep the column indices 28 | nbr_indices = in_nbr.nonzero()[:, 1:].reshape(-1,) 29 | # num points in each neighborhood, summed cumulatively 30 | nbrhd_sizes = torch.cumsum(torch.sum(in_nbr, dim=1), dim=0) 31 | splits = torch.cat((torch.tensor([0.]).to(queries.device), nbrhd_sizes)) 32 | nbr_dict = {} 33 | nbr_dict['neighbors_index'] = nbr_indices.long().to(queries.device) 34 | nbr_dict['neighbors_row_splits'] = splits.long() 35 | return nbr_dict 36 | 37 | 38 | class FixedNeighborSearch(nn.Module): 39 | """Neighbor search within a ball of a given radius 40 | 41 | Parameters 42 | ---------- 43 | use_open3d : bool 44 | Whether to use open3d or torch_cluster 45 | NOTE: open3d implementation requires 3d data 46 | """ 47 | 48 | def __init__(self, use_open3d=True, use_torch_cluster=False): 49 | super().__init__() 50 | self.search_fn = simple_neighbor_search 51 | self.use_open3d = False 52 | 53 | def forward(self, data, queries, n_neigbor): 54 | """Find the neighbors, in data, of each point in queries 55 | within a ball of radius. Returns in CRS format. 56 | 57 | Parameters 58 | ---------- 59 | data : torch.Tensor of shape [n, d] 60 | Search space of possible neighbors 61 | NOTE: open3d requires d=3 62 | queries : torch.Tensor of shape [m, d] 63 | Point for which to find neighbors 64 | NOTE: open3d requires d=3 65 | radius : float 66 | Radius of each ball: B(queries[j], radius) 67 | 68 | Output 69 | ---------- 70 | return_dict : dict 71 | Dictionary with keys: neighbors_index, neighbors_row_splits 72 | neighbors_index: torch.Tensor with dtype=torch.int64 73 | Index of each neighbor in data for every point 74 | in queries. Neighbors are ordered in the same orderings 75 | as the points in queries. Open3d and torch_cluster 76 | implementations can differ by a permutation of the 77 | neighbors for every point. 78 | neighbors_row_splits: torch.Tensor of shape [m+1] with dtype=torch.int64 79 | The value at index j is the sum of the number of 80 | neighbors up to query point j-1. First element is 0 81 | and last element is the total number of neighbors. 
82 | """ 83 | return_dict = {} 84 | 85 | return_dict = self.search_fn(data, queries, n_neigbor) 86 | 87 | return return_dict 88 | -------------------------------------------------------------------------------- /baseline_utlis/test_rigid_neightbour.py: -------------------------------------------------------------------------------- 1 | from rigid_neighbor import * 2 | import torch 3 | 4 | in_ = torch.randn((10, 3)) 5 | out = torch.randn((10, 3)) 6 | 7 | NS = FixedNeighborSearch(use_open3d=False) 8 | 9 | neighbour = NS(in_, in_, 3) 10 | print(neighbour) 11 | -------------------------------------------------------------------------------- /config/RB_config.yaml: -------------------------------------------------------------------------------- 1 | base_config: &BASE_CONFIG 2 | ## SSL parameters 3 | random_seed: 42 4 | config: " " 5 | nettype: 'transformer' 6 | evaluation_channel_drop: 1 # number of varibales dropped for prediction task with partially observed data 7 | drop_type : 'zeros' # dropped values are replaced by zeros 8 | grid_type: 'uniform' 9 | max_block : 0.3 #Block size for dropping pixels during data augmentation 10 | drop_pix: 0.5 #percentage of pixels dropped during data augmentation 11 | channel_per: 0.5 # percentage of channels affected by data augmentation 12 | channel_drop_per: 0.5 # percentage of affected channels to be dropped completly 13 | validation_aug: !!bool False 14 | max_block_val : 0.3 15 | drop_pix_val: 0.5 16 | channel_per_val: 0.2 17 | channel_drop_per_val: 1.0 18 | 19 | 20 | 21 | scheduler_type: 'step' # lr scheduler type 22 | batch_size: 4 23 | use_variable_encoding: !!bool True 24 | n_variables: 5 25 | masking: !!bool True # if true, it will perform data augmentation for SSL 26 | in_token_codim_en: 1 27 | kqv_non_linear: !!bool False 28 | 29 | hidden_token_codim_en: 6 30 | lifting_token_codim_en: 12 31 | lifting_token_codim_pred: 12 32 | out_token_codim_pred: 1 33 | n_layers_en: 2 34 | n_heads_en: [2,2] 35 | n_layers_dec: 2 36 | n_heads_dec: [2,2] 37 | n_layers_pred: 3 38 | n_heads_pred: [2,2,2] 39 | 40 | scalings_pred: [[1,1], [1,1], [1,1]] 41 | scalings_en: [[1,1], [1,1],[1,1]] 42 | scalings_dec: [[1,1],[1,1]] 43 | 44 | n_modes_en: [[100,100], [100,100]] 45 | n_modes_dec: [[100,100], [100,100]] 46 | n_modes_pred: [[100,100], [100,100],[100,100]] 47 | 48 | per_channel_attention: !!bool True 49 | 50 | 51 | #fft_type: 'fft' # Duplicate should be removed 52 | transform_type: 'fft' # might be also spherical harmonics transform or 'sht' 53 | 54 | tno_integral_op: 'fno' 55 | 56 | var_encoding: !!bool True 57 | n_encoding_channels: 3 58 | reconstruction: !!bool True 59 | enable_cls_token: !!bool True 60 | 61 | # add_static_feature: !!bool False 62 | 63 | pretrain_ssl : !!bool True #if true we pretrain the model by SSL 64 | 65 | # if True, it will fine tune the encoder during SL 66 | # otherwise it will freeze the weight of the encoder 67 | # which is trained by SSL 68 | 69 | ## training Hyeper parameters 70 | training_stage: 'regular' # can be regular or fine_tune 71 | freeze_encoder : !!bool False # if true, it will freeze the encoder during SL 72 | lr: 0.03 73 | weight_decay: 0.0000 74 | scheduler_step: 50 75 | scheduler_gamma: 0.5 76 | epochs: 50 77 | clip_gradient: !!bool True 78 | gradient_clip_value: 0.1 79 | ssl_only: !!bool False # if true, it will only train the model by SSL and will not be followed by SL 80 | weight_path: "./weights/" 81 | weight_saving_interval: 3 82 | 83 | # Weights and biases 84 | wandb_log: True 85 | wandb_name: 'codano-RB' 86 | 
wandb_group: 'neuraloperator' 87 | wandb_project: 'CoDA-NO_neurips' 88 | wandb_entity: 'ashiq24' 89 | wandb_log_test_interval: 1 90 | 91 | dataset: "" 92 | 93 | ## incremental learning 94 | incremental: False 95 | buffer_modes: 5 96 | grad_explained_ratio_threshold: 0.9999 97 | max_iter: 1 98 | grad_max_iter: 1 99 | 100 | incremental_loss_gap: False 101 | eps: 0.1 102 | 103 | incremental_resolution: False 104 | epoch_gap: 150 105 | horizontal_skip: !!bool False 106 | 107 | codano_NS2: &CODANO_NS2 108 | <<: *BASE_CONFIG 109 | # dataset hyper 110 | 111 | n_train: 40 112 | n_dim: 2 113 | equation_dict: { "NS": 2} 114 | n_test: 40 115 | 116 | subsampling_rate: 2 117 | 118 | ### 119 | dt: 2 120 | skip_start: 250 121 | 122 | data_location: ["../../../../../raid/ashiq/ns_vel/NS_data_re5000.pt", "../../../../../raid/ashiq/ns_vel/NS_data_re500.pt"] 123 | dataset: "NS" 124 | 125 | # training. hypers 126 | epochs: 30 127 | lr: 0.01 128 | weight_decay: 0.0000 129 | scheduler_type: 'rdp' 130 | scheduler_step: 3 131 | scheduler_gamma: 0.5 132 | gradient_clip_value: 5.0 133 | 134 | enable_cls_token: !!bool True 135 | n_encoding_channels: 16 136 | 137 | n_layers_en: 4 138 | n_heads_en: [2,2,2,2] 139 | n_layers_dec: 4 140 | n_heads_dec: [2,2,2,2] 141 | n_layers_pred: 3 142 | n_heads_pred: [2,2,2] 143 | 144 | scalings_en: [[1,1], [1,1], [1,1], [1,1]] 145 | scalings_dec: [[1,1],[1,1],[1,1], [1,1]] 146 | scalings_pred: [[1,1], [1,1], [1,1]] 147 | 148 | max_n_modes_en: [[32,32], [32,32], [32,32], [32,32]] 149 | max_n_modes_dec: [[32,32], [32,32], [32,32], [32,32]] 150 | max_n_modes_pred: [[32,32], [32,32], [32,32]] 151 | 152 | n_modes_en: [[32,32], [32,32], [32,32], [32,32]] 153 | n_modes_dec: [[32,32], [32,32], [32,32], [32,32]] 154 | n_modes_pred: [[32,32], [32,32], [32,32]] 155 | 156 | hidden_token_codim_en: 64 157 | lifting_token_codim_en: 128 158 | lifting_token_codim_pred: 128 159 | 160 | ## varibale encoder 161 | encoding_modes_x : 32 162 | encoding_modes_y : 32 163 | encoding_modes_t : 20 164 | basis : 'fft' 165 | 166 | 167 | n_static_channels: 2 168 | in_token_codim_en: 1 169 | out_token_codim_pred: 1 170 | 171 | 172 | pretrain_ssl : !!bool True 173 | freeze_encoder : !!bool False 174 | ssl_only: !!bool True 175 | positional_encoding_dim: 4 176 | 177 | #data augmentation 178 | channel_per: 1.0 179 | channel_drop_per: 0.0 180 | max_block : 0.5 181 | drop_pix: 0.5 182 | masking: !!bool True 183 | 184 | codano_big: &CODANO_BIG 185 | <<: *CODANO_NS2 186 | n_train: 100 187 | n_test: 100 188 | max_block : 0.5 189 | drop_pix: 0.6 190 | batch_size: 8 191 | 192 | n_encoding_channels: 4 193 | 194 | n_layers_en: 3 195 | n_heads_en: [32, 32, 32] 196 | n_layers_dec: 3 197 | n_heads_dec: [32,32,32] 198 | n_layers_pred: 1 199 | n_heads_pred: [16] 200 | 201 | scalings_en: [[1,1], [1,1], [1,1]] 202 | scalings_dec: [[1,1],[1,1],[1,1]] 203 | scalings_pred: [[1,1]] 204 | 205 | max_n_modes_en: [[64,64], [64,64], [64,64]] 206 | max_n_modes_dec: [[64,64], [64,64], [64,64]] 207 | max_n_modes_pred: [[64,64]] 208 | 209 | n_modes_en: [[64,64], [64,64], [64,64]] 210 | n_modes_dec: [[64,64], [64,64], [64,64]] 211 | n_modes_pred: [[64,64]] 212 | 213 | hidden_token_codim_en: 64 214 | lifting_token_codim_en: 128 215 | lifting_token_codim_pred: 128 216 | 217 | ## varibale encoder 218 | encoding_modes_x : 64 219 | encoding_modes_y : 64 220 | encoding_modes_t : 20 221 | weight_saving_interval: 1 222 | 223 | codano_RB: &CODANO_RB 224 | <<: *CODANO_NS2 225 | 226 | n_train: 40 227 | n_dim: 2 228 | batch_size: 4 229 | equation_dict: { 
"NS": 2, "T": 1} 230 | n_test: 150 231 | subsampling_rate: 2 232 | data_location: "../../../RB_Data/RB_data/data_test_2500.npz" 233 | dataset: "RB" 234 | 235 | pretrain_ssl : !!bool False 236 | masking: !!bool False 237 | 238 | 239 | n_encoding_channels: 16 240 | 241 | n_layers_en: 3 242 | n_heads_en: [16,16,16] 243 | n_layers_dec: 4 244 | n_heads_dec: [16,16,16,16] 245 | n_layers_pred: 3 246 | n_heads_pred: [16,16,16] 247 | 248 | scalings_en: [[1,1], [1,1], [1,1]] 249 | scalings_dec: [[1,1],[1,1],[1,1], [1,1]] 250 | scalings_pred: [[1,1], [1,1], [1,1]] 251 | 252 | max_n_modes_en: [[64,64], [64,64], [64,64]] 253 | max_n_modes_dec: [[64,64], [64,64], [64,64], [64,64]] 254 | max_n_modes_pred: [[64,64], [64,64], [64,64]] 255 | 256 | n_modes_en: [[64,64], [64,64], [64,64]] 257 | n_modes_dec: [[64,64], [64,64], [64,64], [64,64]] 258 | n_modes_pred: [[64,64], [64,64], [64,64]] 259 | 260 | hidden_token_codim_en: 64 261 | lifting_token_codim_en: 128 262 | lifting_token_codim_pred: 128 263 | 264 | ## varibale encoder 265 | encoding_modes_x : 32 266 | encoding_modes_y : 32 267 | encoding_modes_t : 20 268 | 269 | scheduler_step: 10 270 | scheduler_gamma: 0.5 271 | lr: 0.01 272 | 273 | ft_codano_RB: &FT_CODANO_RB 274 | <<: *CODANO_BIG 275 | 276 | batch_size: 5 277 | equation_dict: { "NS": 2, "T": 1} 278 | n_test: 150 279 | data_location: "../../../RB_Data/RB_data/data_test_2500.npz" 280 | dataset: "RB" 281 | 282 | training_stage: 'fine_tune' 283 | pretrain_ssl : !!bool False 284 | masking: !!bool False 285 | 286 | n_layers_pred: 1 287 | n_heads_pred: [1] 288 | 289 | scalings_pred: [[1,1]] 290 | 291 | n_modes_pred: [[64,64]] 292 | max_n_modes_pred: [[64,64]] 293 | pretrain_weight: "../../../RB_weights/pre_trained_weights/codano_big_ssl_encoder_7.pt" 294 | NS_variable_encoder_path: "../../../RB_weights/pre_trained_weights/codano_big_variable_encoder_7_NS.pt" 295 | T_variable_encoder_path: None 296 | 297 | scheduler_step: 2 298 | freeze_encoder : !!bool False 299 | clip_gradient: !!bool True 300 | scheduler_gamma: 0.5 301 | scheduler_type: 'rdp' 302 | lr: 0.05 303 | horizontal_skip: !!bool False 304 | weight_decay: 0.000000 305 | 306 | ft_codano_RB_small: &FT_CODANO_RB_SMALL 307 | <<: *CODANO_BIG 308 | batch_size: 5 309 | equation_dict: { "NS": 2, "T": 1} 310 | n_test: 150 311 | data_location: "../../RB_Data/RB_data/data_test_2500.npz" 312 | dataset: "RB" 313 | 314 | training_stage: 'fine_tune' 315 | pretrain_ssl : !!bool False 316 | masking: !!bool False 317 | 318 | n_layers_pred: 3 319 | n_heads_pred: [16 ,16,32] #,2] 320 | 321 | scalings_pred: [[1,1], [1,1], [1,1]] #, [1,1]] 322 | 323 | n_modes_pred: [[32,32], [32,32], [32,32]] #, [32,32]] 324 | max_n_modes_pred: [[32,32], [32,32], [32,32]] #,[32,32]] 325 | pretrain_weight: None 326 | NS_variable_encoder_path: None 327 | T_variable_encoder_path: None 328 | 329 | scheduler_step: 5 330 | freeze_encoder : !!bool False 331 | scheduler_gamma: 0.5 332 | lr: 0.005 333 | 334 | ft_codano_RB_test: &FT_CODANO_RB_TEST 335 | <<: *CODANO_RB 336 | n_train: 5 337 | batc_size: 2 338 | n_test: 5 339 | epochs: 10 340 | 341 | 342 | unet: &UNET 343 | <<: *CODANO_RB 344 | batch_size: 5 345 | nettype: 'unet' 346 | n_test: 150 347 | in_dim: 3 348 | out_dim: 3 349 | init_features: 64 350 | 351 | lr: 0.01 352 | scheduler_step: 5 353 | scheduler_gamma: 0.5 354 | 355 | fno: &FNO 356 | <<: *UNET 357 | nettype: 'fno' 358 | hidden_features: 32 359 | lifting_features: 64 360 | 361 | n_modes: [32,32] 362 | max_n_modes: [32,32] 363 | hidden_dim: 32 364 | in_dim: 3 365 | out_dim: 3 366 
| lifting_dim: 64 367 | projection_dim: 64 368 | n_layers: 4 -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import * 2 | from .data_loaders import * 3 | -------------------------------------------------------------------------------- /data_utils/data_loaders.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import h5py 4 | import os 5 | import torch 6 | import numpy as np 7 | from torchvision.transforms import Normalize 8 | from torch.utils.data import ConcatDataset, random_split, DataLoader 9 | import itertools 10 | from utils import * 11 | from neuralop.datasets.tensor_dataset import TensorDataset 12 | from data_utils import get_mesh_displacement 13 | 14 | 15 | # data loaders for the Rayleigh-Bénard and Navier-Stokes datasets 16 | 17 | 18 | class regularMeshTensorDataset(TensorDataset): 19 | def __init__( 20 | self, 21 | x, 22 | y, 23 | transform_x=None, 24 | transform_y=None, 25 | equation=None): 26 | ''' 27 | data in format 28 | x: samples x channels x x_grid_size x y_grid_size 29 | ''' 30 | super().__init__(x, y, transform_x, transform_y) 31 | self.equation = equation 32 | x_grid_size = x[0].shape[-2] 33 | y_grid_size = x[0].shape[-1] 34 | 35 | self.static_features = self.generate_static_grid( 36 | x_grid_size, y_grid_size) 37 | 38 | def generate_static_grid(self, x_grid_size, y_grid_size): 39 | ''' 40 | create a coordinate grid of resolution x_grid_size x y_grid_size 41 | ''' 42 | x = torch.linspace(-1, 1, x_grid_size) 43 | y = torch.linspace(-1, 1, y_grid_size) 44 | x, y = torch.meshgrid(x, y, indexing='ij') 45 | return torch.permute(torch.stack([x, y], dim=-1), (2, 0, 1)) 46 | 47 | def __getitem__(self, index): 48 | x = self.x[index] 49 | y = self.y[index] 50 | 51 | if self.transform_x is not None: 52 | x = self.transform_x(x) 53 | 54 | if self.transform_y is not None: 55 | y = self.transform_y(y) 56 | 57 | return {'x': x, 'y': y, 'equation': self.equation, 58 | 'static_features': self.static_features} 59 | 60 | 61 | def load_NS_dataset(path, n_samples, subsamplingrate=2): 62 | data_1 = torch.load(path[0]) 63 | data_2 = torch.load(path[1]) 64 | 65 | x1_t0 = data_1[:, :-1, :, :] 66 | x1_t1 = data_1[:, 1:, :, :] 67 | x2_t0 = data_2[:, :-1, :, :] 68 | x2_t1 = data_2[:, 1:, :, :] 69 | 70 | # flatten the data 71 | x1_t0 = x1_t0.reshape(-1, x1_t0.shape[-3], 72 | x1_t0.shape[-2], x1_t0.shape[-1]) 73 | x1_t1 = x1_t1.reshape(-1, x1_t1.shape[-3], 74 | x1_t1.shape[-2], x1_t1.shape[-1]) 75 | x2_t0 = x2_t0.reshape(-1, x2_t0.shape[-3], 76 | x2_t0.shape[-2], x2_t0.shape[-1]) 77 | x2_t1 = x2_t1.reshape(-1, x2_t1.shape[-3], 78 | x2_t1.shape[-2], x2_t1.shape[-1]) 79 | 80 | # shuffle each dataset (x2 uses its own permutation) 81 | indices = torch.randperm(x1_t0.shape[0]) 82 | x1_t0 = x1_t0[indices] 83 | x1_t1 = x1_t1[indices] 84 | idx = torch.randperm(x2_t0.shape[0]) 85 | x2_t0 = x2_t0[idx] 86 | x2_t1 = x2_t1[idx] 87 | 88 | ntrain1 = int(n_samples / 2) 89 | ntrain2 = n_samples - ntrain1 90 | 91 | x1_train = x1_t0[:ntrain1] 92 | y1_train = x1_t1[:ntrain1] 93 | x2_train = x2_t0[:ntrain2] 94 | y2_train = x2_t1[:ntrain2] 95 | 96 | # concatenate the data 97 | x = torch.cat([x1_train, x2_train], dim=0) 98 | y = torch.cat([y1_train, y2_train], dim=0) 99 | 100 | return x.permute(0, 3, 1, 2), y.permute(0, 3, 1, 2) 101 | 102 | 103 | def load_NS_dataset_hdf5(path, n_samples, subsamplingrate=2): 104 | ''' 105 | loading only the velocity field data 106 | '''
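    # Assumed on-disk layout, inferred from the slicing below: each file
    # `data_<i>.hdf5` stores one array under the key 'data' with shape
    # [H, W, T, C], where the first two channels (C) hold the velocity
    # components and consecutive time indices (T) form input/target pairs.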
107 | samples = 0 108 | i = 0 109 | data_x = [] 110 | data_y = [] 111 | 112 | while samples < n_samples + 100: 113 | if i > 999: 114 | break 115 | data_path = path[0] + 'data_' + str(i) + '.hdf5' 116 | x = h5py.File(data_path, "r")['data'] 117 | data_x.append(x[:: subsamplingrate, :: subsamplingrate, :-1, :2]) 118 | data_y.append(x[:: subsamplingrate, :: subsamplingrate, 1:, :2]) 119 | 120 | data_path = path[1] + 'data_' + str(i) + '.hdf5' 121 | x = h5py.File(data_path, "r")['data'] 122 | data_x.append(x[:: subsamplingrate, :: subsamplingrate, :-1, :2]) 123 | data_y.append(x[:: subsamplingrate, :: subsamplingrate, 1:, :2]) 124 | 125 | i += 1 126 | samples += (x.shape[2] - 1) 127 | 128 | x = torch.tensor(np.concatenate(data_x, axis=2)) 129 | y = torch.tensor(np.concatenate(data_y, axis=2)) 130 | 131 | print("Data Loaded: ", x.shape, y.shape) 132 | 133 | return torch.permute(x, (2, 3, 0, 1)), torch.permute(y, (2, 3, 0, 1)) 134 | 135 | 136 | def get_NS_dataloader(params): 137 | n_samples = params.n_train + params.n_test 138 | x, y = load_NS_dataset(params.data_location, 139 | n_samples, params.subsampling_rate) 140 | 141 | # shuffle and split into train and test 142 | indices = torch.randperm(x.shape[0]) 143 | x, y = x[indices], y[indices] # keep input/target pairs aligned 144 | x_train = x[:params.n_train] 145 | y_train = y[:params.n_train] 146 | 147 | x_test = x[params.n_train:] 148 | y_test = y[params.n_train:] 149 | 150 | # get the per-channel mean and std of the data 151 | mean = torch.mean(x_train, dim=(0, 2, 3)) 152 | std = torch.std(x_train, dim=(0, 2, 3)) 153 | normalizer = Normalize(mean, std) 154 | 155 | dataset_train = regularMeshTensorDataset( 156 | x_train, 157 | y_train, 158 | transform_x=normalizer, 159 | transform_y=normalizer, 160 | equation=['NS']) 161 | dataset_test = regularMeshTensorDataset( 162 | x_test, 163 | y_test, 164 | transform_x=normalizer, 165 | transform_y=normalizer, 166 | equation=['NS']) 167 | 168 | dat_train = DataLoader( 169 | dataset_train, batch_size=params.batch_size, shuffle=True) 170 | dat_test = DataLoader( 171 | dataset_test, batch_size=params.batch_size, shuffle=False) 172 | return dat_train, dat_test 173 | 174 | 175 | def get_RB_dataloader(params): 176 | data = np.load(params.data_location) 177 | # files ['vx', 'vy', 'temp', 'time'] 178 | vx = torch.tensor(data['vx']).type(torch.float) 179 | vy = torch.tensor(data['vy']).type(torch.float) 180 | temp = torch.tensor(data['temp']).type(torch.float) 181 | time = torch.tensor(data['time']).type(torch.float) 182 | 183 | # stack the data 184 | data = torch.stack([vx, vy, temp], dim=1) 185 | data = data[params.skip_start:] 186 | x = data[:int(-1 * params.dt)] 187 | y = data[int(params.dt):] 188 | 189 | # shuffle and split into train and test 190 | # fix the seed 191 | torch.manual_seed(params.random_seed) 192 | indices = torch.randperm(x.shape[0]) 193 | x = x[indices] 194 | y = y[indices] 195 | 196 | x_train = x[:params.n_train, :, 197 | ::params.subsampling_rate, ::params.subsampling_rate] 198 | y_train = y[:params.n_train, :, 199 | ::params.subsampling_rate, ::params.subsampling_rate] 200 | 201 | x_test = x[-params.n_test:] 202 | y_test = y[-params.n_test:] 203 | print("len test data", len(x_test), len(y_test), params.n_test, x.shape) 204 | 205 | # get the per-channel mean and std of the data 206 | mean = torch.mean(x_train, dim=(0, 2, 3)) 207 | std = torch.std(x_train, dim=(0, 2, 3)) 208 | normalizer = Normalize(mean, std) 209 | 210 | dataset_train = regularMeshTensorDataset( 211 | x_train, 212 | y_train, 213 | transform_x=normalizer, 214 | transform_y=normalizer, 215 |
equation=['ES']) 216 | dataset_test = regularMeshTensorDataset( 217 | x_test, 218 | y_test, 219 | transform_x=normalizer, 220 | transform_y=normalizer, 221 | equation=['ES']) 222 | 223 | dat_train = DataLoader( 224 | dataset_train, batch_size=params.batch_size, shuffle=True) 225 | dat_test = DataLoader( 226 | dataset_test, batch_size=params.batch_size, shuffle=False) 227 | 228 | return dat_train, dat_test 229 | 230 | 231 | # dataloader for Fluid Structure Interaction (FSI) problems 232 | 233 | class IrregularMeshTensorDataset(TensorDataset): 234 | def __init__( 235 | self, 236 | x, 237 | y, 238 | transform_x=None, 239 | transform_y=None, 240 | equation=None, 241 | x1=0, 242 | x2=0, 243 | mu=0.1, 244 | mesh=None): 245 | super().__init__(x, y, transform_x, transform_y) 246 | self.x1 = x1 247 | self.x2 = x2 248 | 249 | self.mu = mu 250 | self.mesh = mesh 251 | self.equation = equation 252 | print("Inside Dataset :", self.mesh.dtype, x.dtype, y.dtype) 253 | self._create_static_features() 254 | 255 | def _create_static_features(self, d_grid=None): 256 | ''' 257 | creating static channels for the inlet and Reynolds number 258 | ''' 259 | n_grid_points = self.x.shape[1] 260 | if len(self.equation) == 1: 261 | # equation can be either ['NS'] or ['NS', 'ES'] 262 | # of 3 or 5 channels/variables 263 | n_variables = 3 264 | else: 265 | n_variables = self.x.shape[-1] 266 | if d_grid is not None: 267 | positional_enco = self.mesh + d_grid 268 | else: 269 | positional_enco = self.mesh 270 | 271 | reynolds = torch.ones(n_grid_points, 1) * self.mu 272 | inlet = ((-self.x1 / 2 + positional_enco[:, 1]) * 273 | (-self.x2 / 2 + positional_enco[:, 1]))[:, None]**2 274 | 275 | self.static_features = torch.cat( 276 | [reynolds, inlet, positional_enco], dim=-1).repeat(1, n_variables) 277 | 278 | def __getitem__(self, index): 279 | x = self.x[index] 280 | y = self.y[index] 281 | 282 | d_grid_x = get_mesh_displacement(x) 283 | d_grid_y = get_mesh_displacement(y) 284 | 285 | self._create_static_features(d_grid_x) 286 | 287 | if self.transform_x is not None: 288 | x = self.transform_x(x) 289 | 290 | if self.transform_y is not None: 291 | y = self.transform_y(y) 292 | 293 | if len(self.equation) == 1: 294 | x = x[:, :3] 295 | y = y[:, :3] 296 | 297 | return {'x': x, 'y': y, 'd_grid_x': d_grid_x, 298 | 'd_grid_y': d_grid_y, 'static_features': self.static_features, 299 | 'equation': self.equation} 300 | 301 | 302 | class Normalizer(): 303 | def __init__(self, mean, std, eps=1e-6, persample=False): 304 | self.persample = persample 305 | self.mean = mean 306 | self.std = std 307 | self.eps = eps 308 | 309 | def __call__(self, data): 310 | if self.persample: 311 | self.mean = torch.mean(data, dim=(0)) 312 | self.std = torch.var(data, dim=(0))**0.5 313 | return (data - self.mean) / (self.std + self.eps) 314 | 315 | def denormalize(self, data): 316 | return data * (self.std + self.eps) + self.mean 317 | 318 | def cuda(self): 319 | if self.mean is not None and self.std is not None: 320 | self.mean = self.mean.cuda() 321 | self.std = self.std.cuda() 322 | 323 | 324 | class NsElasticDataset(): 325 | def __init__(self, location, equation, mesh_location, params): 326 | self.location = location 327 | 328 | # _x1 and _x2 are the parameters for the inlet conditions 329 | # _mu is the viscosity 330 | self._x1 = [-4.0, -2.0, 0.0, 2.0, 4.0, 6.0] 331 | self._x2 = [-4.0, -2.0, 0, 2.0, 4.0, 6.0] 332 | self._mu = [0.1, 0.01, 0.5, 5, 1, 10] 333 | if params.data_partition == 'supervised': 334 | # two inlets are held out for finetuning; 335 | # they are not
used in the self-supervised 336 | # pretraining 337 | self._x1 = params.supervised_inlets_x1 338 | self._x2 = params.supervised_inlets_x2 339 | elif params.data_partition == 'self-supervised': 340 | self._x1 = list(set(self._x1) - set(params.supervised_inlets_x1)) 341 | self._x2 = list(set(self._x2) - set(params.supervised_inlets_x2)) 342 | else: 343 | raise ValueError( 344 | f"Data partition {params.data_partition} not supported") 345 | 346 | self.equation = equation 347 | 348 | mesh = get_mesh(params) 349 | self.input_mesh = torch.from_numpy(mesh).type(torch.float) 350 | print("Mesh Shape: ", self.input_mesh.shape) 351 | self.params = params 352 | 353 | self.normalizer = Normalizer(None, None, persample=True) 354 | 355 | def _readh5(self, h5f, dtype=torch.float32): 356 | a_dset_keys = list(h5f['VisualisationVector'].keys()) 357 | size = len(a_dset_keys) 358 | readings = [None for i in range(size)] 359 | for dset in a_dset_keys: 360 | ds_data = (h5f['VisualisationVector'][dset]) 361 | readings[int(dset)] = torch.tensor(np.array(ds_data), dtype=dtype) 362 | 363 | readings_tensor = torch.stack(readings, dim=0) 364 | print(f"Loaded tensor Size: {readings_tensor.shape}") 365 | return readings_tensor 366 | 367 | def get_data(self, mu, x1, x2): 368 | if mu not in self._mu: 369 | raise ValueError(f"Value of mu must be one of {self._mu}") 370 | if x1 not in self._x1 or x2 not in self._x2: 371 | raise ValueError( 372 | f"Values of x1 and x2 must be one of {self._x1} and {self._x2}") 373 | if mu == 0.5: 374 | path = os.path.join( 375 | self.location, 376 | 'mu=' + str(mu), 377 | 'x1=' + str(-2.0), 378 | 'x2=' + str(x2), 379 | '1', 380 | 'Visualization') 381 | print(path) 382 | else: 383 | path = os.path.join( 384 | self.location, 385 | 'mu=' + str(mu), 386 | 'x1=' + str(x1), 387 | 'x2=' + str(x2), 388 | 'Visualization') 389 | 390 | filename = os.path.join(path, 'displacement.h5') 391 | 392 | h5f = h5py.File(filename, 'r') 393 | displacements_tensor = self._readh5(h5f) 394 | 395 | filename = os.path.join(path, 'pressure.h5') 396 | h5f = h5py.File(filename, 'r') 397 | pressure_tensor = self._readh5(h5f) 398 | 399 | filename = os.path.join(path, 'velocity.h5') 400 | h5f = h5py.File(filename, 'r') 401 | velocity_tensor = self._readh5(h5f) 402 | 403 | return velocity_tensor, pressure_tensor, displacements_tensor 404 | 405 | def get_data_txt(self, mu, x1, x2): 406 | if mu not in self._mu: 407 | raise ValueError(f"Value of mu must be one of {self._mu}") 408 | if x1 not in self._x1 or x2 not in self._x2: 409 | raise ValueError( 410 | f"Values of x1 and x2 must be one of {self._x1} and {self._x2}") 411 | path = os.path.join( 412 | self.location, 413 | 'mu=' + str(mu), 414 | 'x1=' + str(x1), 415 | 'x2=' + str(x2), 416 | '1') 417 | 418 | velocity_x = torch.tensor(np.loadtxt(os.path.join(path, 'vel_x.txt'))) 419 | velocity_y = torch.tensor(np.loadtxt(os.path.join(path, 'vel_y.txt'))) 420 | if len(self.params.equation_dict) != 1: 421 | dis_x = torch.tensor(np.loadtxt(os.path.join(path, 'dis_x.txt'))) 422 | dis_y = torch.tensor(np.loadtxt(os.path.join(path, 'dis_y.txt'))) 423 | pressure = torch.tensor(np.loadtxt(os.path.join(path, 'pres.txt'))) 424 | else: 425 | # just copying values as a placeholder when only the NS equation is used 426 | dis_x = velocity_x 427 | dis_y = velocity_y 428 | pressure = velocity_x 429 | 430 | # reshape each tensor into 2d by keeping 876 entries in each row 431 | dis_x = dis_x.view(-1, 876, 1) 432 | dis_y = dis_y.view(-1, 876, 1) 433 | pressure = pressure.view(-1, 876, 1) 434 |
velocity_x = velocity_x.view(-1, 876, 1) 435 | velocity_y = velocity_y.view(-1, 876, 1) 436 | 437 | velocity = torch.cat([velocity_x, velocity_y], dim=-1) 438 | displacement = torch.cat([dis_x, dis_y], dim=-1) 439 | 440 | return velocity.to( 441 | torch.float), pressure.to( 442 | torch.float), displacement.to( 443 | torch.float) 444 | 445 | def get_dataloader( 446 | self, 447 | mu_list, 448 | dt, 449 | normalize=True, 450 | batch_size=1, 451 | train_test_split=0.2, 452 | sample_per_inlet=200, 453 | ntrain=None, 454 | ntest=None, 455 | data_loader_kwargs={'num_workers': 2}): 456 | 457 | train_datasets = [] 458 | test_datasets = [] 459 | 460 | for mu in mu_list: 461 | train, test = self.get_tensor_dataset( 462 | mu, dt, normalize, train_test_split=train_test_split, sample_per_inlet=sample_per_inlet) 463 | train_datasets.append(train) 464 | test_datasets.append(test) 465 | train_dataset = ConcatDataset(train_datasets) 466 | test_dataset = ConcatDataset(test_datasets) 467 | print("****Train dataset size***: ", len(train_dataset)) 468 | print("***Test dataset size***: ", len(test_dataset)) 469 | if ntrain is not None: 470 | train_dataset = random_split( 471 | train_dataset, [ntrain, len(train_dataset) - ntrain])[0] 472 | if ntest is not None: 473 | test_dataset = random_split( 474 | test_dataset, [ntest, len(test_dataset) - ntest])[0] 475 | 476 | train_dataloader = DataLoader( 477 | train_dataset, batch_size=batch_size, **data_loader_kwargs) 478 | test_dataloader = DataLoader( 479 | test_dataset, batch_size=batch_size, **data_loader_kwargs) 480 | 481 | return train_dataloader, test_dataloader 482 | 483 | def get_tensor_dataset( 484 | self, 485 | mu, 486 | dt, 487 | normalize=True, 488 | min_max_normalize=False, 489 | train_test_split=0.2, 490 | sample_per_inlet=200, 491 | x1_list=None, 492 | x2_list=None): 493 | 494 | if x1_list is None: 495 | x1_list = self._x1 496 | if x2_list is None: 497 | x2_list = self._x2 498 | train_datasets = [] 499 | test_datasets = [] 500 | # for the given mu 501 | # loop over all given inlets 502 | for x1 in x1_list: 503 | for x2 in x2_list: 504 | try: 505 | if mu == 0.5: 506 | velocities, pressure, displacements = self.get_data_txt( 507 | mu, x1, x2) 508 | else: 509 | velocities, pressure, displacements = self.get_data( 510 | mu, x1, x2) 511 | except FileNotFoundError as e: 512 | print(e) 513 | print( 514 | f"Original file not found for mu={mu}, x1={x1}, x2={x2}") 515 | continue 516 | 517 | # keeping vx, vy, P, dx, dy 518 | variable_indices = [0, 1, 3, 4, 5] 519 | if mu == 0.5: 520 | combined = torch.cat( 521 | [velocities, pressure, displacements], dim=-1)[:sample_per_inlet, :, :] 522 | else: 523 | combined = torch.cat( 524 | [velocities, pressure, displacements], dim=-1)[:sample_per_inlet, :, variable_indices] 525 | 526 | if hasattr( 527 | self.params, 528 | 'sub_sample_size') and self.params.sub_sample_size is not None: 529 | mesh_size = combined.shape[1] 530 | indices = [i for i in range(mesh_size)] 531 | np.random.seed(self.params.random_seed) 532 | sub_indices = np.random.choice( 533 | indices, self.params.sub_sample_size, replace=False) 534 | combined = combined[:, sub_indices, :] 535 | 536 | if self.params.super_resolution: 537 | new_queries = torch.cat(self.get_data_txt( 538 | mu, x1, x2), dim=-1).to(dtype=combined.dtype) # assumed fix: cat the (velocity, pressure, displacement) tuple channel-wise 539 | new_queries = new_queries[:sample_per_inlet, :] 540 | 541 | print("shape of old data", combined.shape) 542 | print("shape of new data", new_queries.shape) 543 | 544 | combined = torch.cat([combined, new_queries], dim=-2) 545 | print("shape of combined
data", combined.shape) 546 | 547 | step_t0 = combined[:-dt, ...] 548 | step_t1 = combined[dt:, ...] 549 | 550 | indexs = [i for i in range(step_t0.shape[0])] 551 | 552 | ntrain = int((1 - train_test_split) * len(indexs)) 553 | ntest = len(indexs) - ntrain 554 | 555 | random.shuffle(indexs) 556 | train_t0, test_t0 = step_t0[indexs[:ntrain] 557 | ], step_t0[indexs[ntrain:ntrain + ntest]] 558 | train_t1, test_t1 = step_t1[indexs[:ntrain] 559 | ], step_t1[indexs[ntrain:ntrain + ntest]] 560 | 561 | if not normalize: 562 | normalizer = None 563 | else: 564 | if not min_max_normalize: 565 | mean, var = torch.mean(train_t0, dim=( 566 | 0, 1)), torch.var(train_t0, dim=(0, 1))**0.5 567 | else: 568 | mean = torch.min( 569 | train_t0.view(-1, train_t0.shape[-1]), dim=0)[0] 570 | var = torch.max(train_t0.view(-1, 571 | train_t0.shape[-1]), 572 | dim=0)[0] - torch.min(train_t0.view(-1, 573 | train_t0.shape[-1]), 574 | dim=0)[0] 575 | 576 | normalizer = Normalizer(mean, var) 577 | 578 | train_datasets.append( 579 | IrregularMeshTensorDataset( 580 | train_t0, 581 | train_t1, 582 | normalizer, 583 | normalizer, 584 | x1=x1, 585 | x2=x2, 586 | mu=mu, 587 | equation=self.equation, 588 | mesh=self.input_mesh)) 589 | test_datasets.append( 590 | IrregularMeshTensorDataset( 591 | test_t0, 592 | test_t1, 593 | normalizer, 594 | normalizer, 595 | x1=x1, 596 | x2=x2, 597 | mu=mu, 598 | equation=self.equation, 599 | mesh=self.input_mesh)) 600 | 601 | return ConcatDataset(train_datasets), ConcatDataset(test_datasets) 602 | -------------------------------------------------------------------------------- /data_utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import sys 4 | from typing import Optional, Tuple 5 | from utils import * 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset 10 | 11 | sys.path.append('..') 12 | 13 | 14 | class ResizeDataset(Dataset): 15 | def __init__(self, parent_dataset, resolution): 16 | self.parent_dataset = parent_dataset 17 | self.resolution = resolution 18 | 19 | def __len__(self,): 20 | return self.parent_dataset.__len__() 21 | 22 | def __getitem__(self, global_idx): 23 | x, y = self.parent_dataset.__getitem__(global_idx) 24 | xp = F.interpolate( 25 | x[None, ...] if len(x.shape) < 4 else x, 26 | size=self.resolution, 27 | mode='bicubic', 28 | align_corners=True, 29 | ) 30 | yp = F.interpolate( 31 | y[None, ...] if len(y.shape) < 4 else y, 32 | size=self.resolution, 33 | mode='bicubic', 34 | align_corners=True, 35 | ) 36 | return torch.squeeze(xp), torch.squeeze(yp) 37 | 38 | 39 | class MaskerUniform: 40 | """Performs masking on datasets with regular meshes. 41 | 42 | For masking with data points from data sets with irregular meshes, 43 | use ``MaskerNonuniformMesh`` (below). 44 | """ 45 | 46 | def __init__( 47 | self, 48 | drop_type='zeros', 49 | max_block=0.7, 50 | drop_pix=0.3, 51 | channel_per=0.5, 52 | channel_drop_per=0.2, 53 | device='cpu', 54 | min_block=30, 55 | ): 56 | self.drop_type = drop_type 57 | self.max_block = max_block 58 | self.drop_pix = drop_pix 59 | self.channel_per = channel_per 60 | self.channel_drop_per = channel_drop_per 61 | self.device = device 62 | self.min_block = min_block 63 | 64 | def __call__(self, size): 65 | """Returns a mask to be multiplied into a data tensor. 66 | 67 | Generates a binary mask of 0s and 1s to be point-wise multiplied into a 68 | data tensor to create a masked sample. 
By training on masked data, we 69 | expect the model to be resilient to missing data. 70 | 71 | TODO arg ``max_block_sz`` is outdated 72 | Parameters 73 | --- 74 | max_block_sz: percentage of the maximum block to be dropped 75 | """ 76 | 77 | np.random.seed() 78 | C, H, W = size 79 | mask = torch.ones(size, device=self.device) 80 | drop_t = self.drop_type 81 | if self.channel_per == 1.0: 82 | augmented_channels = [i for i in range(C)] 83 | else: 84 | augmented_channels = np.random.choice( 85 | C, math.ceil(C * self.channel_per)) 86 | drop_len = int(self.channel_drop_per * math.ceil(C * self.channel_per)) 87 | mask[augmented_channels[:drop_len], :, :] = 0.0 88 | for i in augmented_channels[drop_len:]: 89 | n_drop_pix = self.drop_pix * H * W 90 | mx_blk_height = int(H * self.max_block) 91 | mx_blk_width = int(W * self.max_block) 92 | 93 | repetation = 3 94 | while n_drop_pix > 0 and repetation > 0: 95 | rnd_r = random.randint(0, H - 2) 96 | rnd_c = random.randint(0, W - 2) 97 | 98 | rnd_h = min( 99 | random.randint(self.min_block, mx_blk_height), 100 | H - rnd_r 101 | ) 102 | rnd_w = min( 103 | random.randint(self.min_block, mx_blk_width), 104 | W - rnd_c 105 | ) 106 | mask[i, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = 0 107 | n_drop_pix -= rnd_h * rnd_c 108 | repetation -= 1 109 | return None, mask 110 | 111 | 112 | def batched_masker( 113 | data: torch.Tensor, # with a batch dimension 114 | aug, # instance of a Masker like above 115 | batched_channels: Optional[Tuple[Tuple[int]]] = None, 116 | ): 117 | mask = [] 118 | aug.device = data.device 119 | for i in range(data.shape[0]): 120 | _, m = aug(data[i].shape) 121 | mask.append(m) 122 | 123 | masks = torch.stack(mask, dim=0) 124 | return data * masks, masks 125 | 126 | 127 | # Nearest neighbour type masking for non-uniform 128 | # mesh 129 | 130 | class MaskerNonuniformMesh(object): 131 | def __init__( 132 | self, 133 | grid_non_uni, 134 | gird_uni, 135 | radius, 136 | drop_type='zeros', 137 | drop_pix=0.3, 138 | channel_aug_rate=0.7, 139 | channel_drop_rate=0.2, 140 | device='cpu', 141 | max_block=10, 142 | verbose=None 143 | ): 144 | """ 145 | Parameters 146 | --- 147 | drop_type : dropped pixels are filled with zeros 148 | drop_pix: Percentage of pixels to be dropped 149 | channel_aug_rate: Percentage of channels to be augmented 150 | channel_drop_rate: Percentage of (channel_aug_rate) channels to be masked completely (drop all pxiels) 151 | max_block: Maximum number of regions to be dropped. 
152 | """ 153 | self.grid_non_uni = grid_non_uni 154 | self.grid_uni = gird_uni 155 | dists = torch.cdist(gird_uni, grid_non_uni).to( 156 | gird_uni.device) # shaped num query points x num data points 157 | self.in_nbr = torch.where(dists <= radius, 1., 0.).long() 158 | 159 | self.drop_type = drop_type 160 | self.drop_pix = drop_pix 161 | self.channel_aug_rate = channel_aug_rate 162 | self.channel_drop_rate = channel_drop_rate 163 | self.device = device 164 | self.max_block = max_block 165 | self.verbose = verbose 166 | 167 | def __call__(self, size): 168 | 169 | L, C = size 170 | mask = torch.ones(size, device=self.device) 171 | drop_t = self.drop_type # no effect now 172 | 173 | augmented_channels = np.random.choice( 174 | C, math.ceil(C * self.channel_aug_rate)) 175 | 176 | p = random.random() 177 | # with 50% probability, drop all pixels in very few channels 178 | # or doing masking to many channels 179 | if p < 0.5: 180 | drop_len = int( 181 | self.channel_drop_rate * 182 | math.ceil( 183 | C * 184 | self.channel_aug_rate)) 185 | mask[:, augmented_channels[:drop_len]] = 0.0 186 | return None, mask 187 | else: 188 | drop_len = 0 189 | # print('masking', augmented_channels[drop_len:]) 190 | for i in augmented_channels[drop_len:]: 191 | n_drop_pix = self.drop_pix * L 192 | max_location = self.max_block 193 | while n_drop_pix > 0: 194 | # python random is inclusive of low and high 195 | j = random.randint(0, self.in_nbr.shape[0] - 1) 196 | mask[self.in_nbr[j] == 1, i] = 0 197 | n_drop_pix -= sum(self.in_nbr[j]).float() 198 | max_location -= 1 199 | if max_location == 0: 200 | break 201 | return None, mask 202 | 203 | 204 | def get_meshes(params, grid_size): 205 | mesh = get_mesh(params) 206 | input_mesh = torch.from_numpy(mesh).type(torch.float).cuda() 207 | 208 | minx, maxx = np.min(mesh[:, 0]), np.max(mesh[:, 0]) 209 | miny, maxy = np.min(mesh[:, 1]), np.max(mesh[:, 1]) 210 | 211 | size_x, size_y = grid_size 212 | idx_x = torch.arange(start=minx, 213 | end=maxx + (maxx - minx) / size_x - 1e-5, 214 | step=(maxx - minx) / (size_x - 1)) 215 | idx_y = torch.arange(start=miny, 216 | end=maxy + (maxy - miny) / size_y - 1e-5, 217 | step=(maxy - miny) / (size_y - 1)) 218 | x, y = torch.meshgrid(idx_x, idx_y, indexing='ij') 219 | output_mesh = torch.transpose(torch.stack( 220 | [x.flatten(), y.flatten()]), 0, 1).type(torch.float).cuda() 221 | 222 | return input_mesh, output_mesh 223 | 224 | def get_mesh_displacement(x): 225 | """ 226 | Only returns the displacement field. 
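    Assumes the displacement components (dx, dy) are stored in the last
    two channels of ``x``, matching the FSI variable ordering
    (vx, vy, P, dx, dy) used by the loaders above.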
227 | """ 228 | return x[:, -2:].clone().detach() 229 | -------------------------------------------------------------------------------- /exps_FSI.sh: -------------------------------------------------------------------------------- 1 | # SSL - 2 | python main.py --exp FSI --config codano_gno_NS_ES --ntrain 8000 3 | python main.py --exp FSI --config codano_gno_NS --ntrain 8000 4 | 5 | 6 | #Finetuning 7 | ## Re = 400 8 | python main.py --exp FSI --config ft_NSES_NSES_5 --ntrain 5 --epochs 50 --scheduler_step 10 9 | python main.py --exp FSI --config ft_NSES_NSES_5 --ntrain 25 10 | python main.py --exp FSI --config ft_NSES_NSES_5 --ntrain 100 11 | 12 | python main.py --exp FSI --config ft_NS_NSES_5 --ntrain 5 --epochs 50 --scheduler_step 10 13 | python main.py --exp FSI --config ft_NS_NSES_5 --ntrain 25 14 | python main.py --exp FSI --config ft_NS_NSES_5 --ntrain 100 15 | 16 | python main.py --exp FSI --config ft_NSES_NS_5 --ntrain 5 --epochs 50 --scheduler_step 10 17 | python main.py --exp FSI --config ft_NSES_NS_5 --ntrain 25 18 | python main.py --exp FSI --config ft_NSES_NS_5 --ntrain 100 19 | 20 | python main.py --exp FSI --config ft_NS_NS_5 --ntrain 5 --epochs 50 --scheduler_step 10 21 | python main.py --exp FSI --config ft_NS_NS_5 --ntrain 25 22 | python main.py --exp FSI --config ft_NS_NS_5 --ntrain 100 23 | 24 | # ## Re = 4000 25 | python main.py --exp FSI --config ft_NSES_NSES_0.5 --ntrain 5 26 | python main.py --exp FSI --config ft_NSES_NSES_0.5 --ntrain 25 27 | python main.py --exp FSI --config ft_NSES_NSES_0.5 --ntrain 100 28 | 29 | python main.py --exp FSI --config ft_NS_NSES_0.5 --ntrain 5 30 | python main.py --exp FSI --config ft_NS_NSES_0.5 --ntrain 25 31 | python main.py --exp FSI --config ft_NS_NSES_0.5 --ntrain 100 32 | 33 | -------------------------------------------------------------------------------- /exps_RB.sh: -------------------------------------------------------------------------------- 1 | #SSL 2 | python main.py --exp RB --config codano_big --ntrain 40000 3 | 4 | # Finetuning 5 | python main.py --exp RB --config ft_codano_RB --ntrain 5 --epochs 50 --batch_size 1 6 | python main.py --exp RB --config ft_codano_RB --ntrain 10 --epochs 50 --batch_size 2 7 | python main.py --exp RB --config ft_codano_RB --ntrain 25 --epochs 50 --batch_size 5 8 | 9 | python main.py --exp RB --config unet --ntrain 5 --epochs 50 --batch_size 1 10 | python main.py --exp RB --config unet --ntrain 10 --epochs 50 --batch_size 2 11 | python main.py --exp RB --config unet --ntrain 25 --epochs 50 --batch_size 5 12 | 13 | python main.py --exp RB --config fno --ntrain 5 --epochs 50 --batch_size 1 14 | python main.py --exp RB --config fno --ntrain 10 --epochs 50 --batch_size 2 15 | python main.py --exp RB --config fno --ntrain 25 --epochs 50 --batch_size 5 16 | 17 | python main.py --exp RB --config codano_RB --ntrain 5 --epochs 50 --batch_size 1 18 | python main.py --exp RB --config codano_RB --ntrain 10 --epochs 50 --batch_size 2 19 | python main.py --exp RB --config codano_RB --ntrain 25 --epochs 50 --batch_size 5 -------------------------------------------------------------------------------- /fsi_animation_dx.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/fsi_animation_dx.gif -------------------------------------------------------------------------------- /fsi_animation_pressue.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/fsi_animation_pressue.gif -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/layers/__init__.py -------------------------------------------------------------------------------- /layers/codano_block_2D.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | import logging 4 | import numpy as np 5 | from einops import rearrange 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from neuralop.layers.fno_block import FNOBlocks 10 | from .fino_2D import SpectralConvKernel2d 11 | 12 | # Implementation of 2d Codano Blocks 13 | 14 | 15 | def NO_OP(x, *_args, **_kwargs): 16 | return x 17 | 18 | 19 | AffineNormalizer2D = partial(nn.InstanceNorm2d, affine=True) 20 | 21 | 22 | class CodanoBlocks(nn.Module): 23 | def __init__( 24 | self, 25 | n_modes, 26 | n_head=1, 27 | token_codimension=1, 28 | output_scaling_factor=None, 29 | incremental_n_modes=None, 30 | head_codimension=None, 31 | use_mlp=False, 32 | mlp=None, 33 | non_linearity=F.gelu, 34 | preactivation=False, 35 | fno_skip='linear', 36 | mlp_skip='soft-gating', 37 | mlp_expansion=1.0, 38 | separable=False, 39 | factorization='tucker', 40 | rank=1.0, 41 | SpectralConvolution=None, 42 | Normalizer=None, 43 | joint_factorization=False, 44 | fixed_rank_modes=False, 45 | implementation='factorized', 46 | decomposition_kwargs=None, 47 | fft_norm='forward', 48 | codimension_size=None, 49 | per_channel_attention=True, 50 | permutation_eq=True, 51 | temperature=1.0, 52 | kqv_non_linear=False, 53 | **_kwargs, 54 | ): 55 | super().__init__() 56 | 57 | # Co-dimension of each variable/token. The token embedding space is 58 | # identical to the variable space, so their dimensionalities are equal. 
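        # Each head attends over `head_codimension` function channels, so
        # the K/Q/V operators below map `token_codimension` input channels
        # to `n_head * head_codimension` output channels; `self.proj` maps
        # the result back to `token_codimension` when the two differ.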
59 | self.variable_codimension = token_codimension 60 | self.token_codimension = token_codimension 61 | 62 | # codimension of the attention output from each head (may be recalculated below) 63 | self.head_codimension = (head_codimension 64 | if head_codimension is not None 65 | else token_codimension) 66 | self.n_head = n_head # number of heads 67 | self.output_scaling_factor = output_scaling_factor 68 | self.temperature = temperature 69 | 70 | # K, Q, and V operators, with or without a non-linearity 71 | if kqv_non_linear: 72 | kqv_activation = non_linearity 73 | else: 74 | kqv_activation = NO_OP 75 | 76 | self.permutation_eq = permutation_eq 77 | 78 | if self.n_head is not None: 79 | # recalculate `head_codimension` from the number of heads 80 | self.head_codimension = max(token_codimension // self.n_head, 1) 81 | 82 | self.codimension_size = codimension_size 83 | self.mixer_token_codimension = token_codimension 84 | 85 | if per_channel_attention: 86 | # for per-channel attention, force the token and head codimensions to 1 87 | self.token_codimension = 1 88 | self.head_codimension = 1 89 | 90 | # this scale is used for downsampling the Q and K functions 91 | scale = 2 if per_channel_attention else 1 92 | scale = min(self.n_head, scale) 93 | 94 | mixer_modes = [i // scale for i in n_modes] 95 | 96 | 97 | if decomposition_kwargs is None: 98 | decomposition_kwargs = {} 99 | common_args = dict( 100 | use_mlp=use_mlp, 101 | mlp=mlp, 102 | preactivation=preactivation, 103 | mlp_skip=mlp_skip, 104 | mlp_dropout=0, 105 | incremental_n_modes=incremental_n_modes, 106 | rank=rank, 107 | fft_norm=fft_norm, 108 | mlp_expansion=mlp_expansion, 109 | fixed_rank_modes=fixed_rank_modes, 110 | implementation=implementation, 111 | separable=separable, 112 | factorization=factorization, 113 | decomposition_kwargs=decomposition_kwargs, 114 | joint_factorization=joint_factorization, 115 | ) 116 | 117 | kqv_args = dict( 118 | in_channels=self.token_codimension, 119 | out_channels=self.n_head * self.head_codimension, 120 | n_modes=mixer_modes, 121 | # args below are shared with the projection block 122 | non_linearity=kqv_activation, 123 | fno_skip='linear', 124 | norm=None, 125 | apply_skip=True, 126 | n_layers=1, 127 | ) 128 | self.K = FNOBlocks( 129 | output_scaling_factor=1 / scale, 130 | SpectralConv=partial( 131 | SpectralConvolution, 132 | rank=0.5, 133 | factorization=None, 134 | ), 135 | **kqv_args, 136 | **common_args, 137 | ) 138 | self.Q = FNOBlocks( 139 | output_scaling_factor=1 / scale, 140 | SpectralConv=partial( 141 | SpectralConvolution, 142 | rank=0.5, 143 | factorization=None, 144 | ), 145 | **kqv_args, 146 | **common_args, 147 | ) 148 | self.V = FNOBlocks( 149 | output_scaling_factor=1, 150 | SpectralConv=partial( 151 | SpectralConvolution, 152 | rank=0.5, 153 | factorization=None, 154 | ), 155 | **kqv_args, 156 | **common_args, 157 | ) 158 | 159 | if self.n_head * self.head_codimension != self.token_codimension: 160 | self.proj = FNOBlocks( 161 | in_channels=self.n_head * self.head_codimension, 162 | out_channels=self.token_codimension, 163 | n_modes=n_modes, 164 | output_scaling_factor=1, 165 | # args below are shared with the KQV blocks 166 | apply_skip=True, 167 | non_linearity=NO_OP, 168 | fno_skip='linear', 169 | norm=None, 170 | SpectralConv=partial( 171 | SpectralConvolution, 172 | rank=0.5, 173 | factorization=None, 174 | ), 175 | n_layers=1, 176 | **common_args, 177 | ) 178 | else: 179 | self.proj = None 180 | 181 | self.attention_normalizer = Normalizer(self.token_codimension) 182 | 183 | mixer_args = dict( 184 | n_modes=n_modes, 185 | output_scaling_factor=1, 186 |
non_linearity=non_linearity, 187 | norm='instance_norm', 188 | fno_skip=fno_skip, 189 | SpectralConv=partial( 190 | SpectralConvolution, 191 | rank=0.5, 192 | factorization=None, 193 | bias=True, 194 | ), 195 | ) 196 | # The last operator (the analogue of the MLP in a regular Transformer 197 | # block) can optionally be made permutation equivariant, i.e., applied 198 | # per variable, instead of acting on the whole channel dimension at once 199 | # (like a regular FNO). 200 | if permutation_eq: 201 | self.mixer = FNOBlocks( 202 | in_channels=self.mixer_token_codimension, 203 | out_channels=self.mixer_token_codimension, 204 | apply_skip=True, 205 | n_layers=2, 206 | **mixer_args, 207 | **common_args, 208 | ) 209 | self.norm1 = Normalizer(self.token_codimension) 210 | self.norm2 = Normalizer(self.mixer_token_codimension) 211 | self.mixer_out_normalizer = Normalizer( 212 | self.mixer_token_codimension) 213 | 214 | else: 215 | self.mixer = FNOBlocks( 216 | in_channels=codimension_size, 217 | out_channels=codimension_size, 218 | n_layers=2, 219 | **mixer_args, 220 | **common_args, 221 | ) 222 | self.norm1 = Normalizer(codimension_size) 223 | self.norm2 = Normalizer(codimension_size) 224 | self.mixer_out_normalizer = Normalizer(codimension_size) 225 | 226 | def forward(self, *args): 227 | raise NotImplementedError( 228 | "Use a proper subclass of CodanoBlocks (e.g., CodanoBlocks2d).") 229 | 230 | 231 | class CodanoBlocks2d(CodanoBlocks): 232 | def __init__(self, *args, **kwargs): 233 | Normalizer = kwargs.get("Normalizer") 234 | if Normalizer is None: 235 | Normalizer = AffineNormalizer2D 236 | kwargs["Normalizer"] = Normalizer 237 | 238 | Convolution = kwargs.get("SpectralConvolution") 239 | if Convolution is None: 240 | Convolution = SpectralConvKernel2d 241 | kwargs["SpectralConvolution"] = Convolution 242 | 243 | super().__init__(*args, **kwargs) 244 | 245 | 246 | def compute_attention(self, xa, batch_size): 247 | """Compute the key-query-value variant of the attention matrix. 248 | 249 | Assumes input ``xa`` has been normalized.
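Shape sketch (illustrative): K, Q, and V are themselves FNO blocks, so ``k``, ``q``, and ``v`` below are discretized functions. Each is reshaped to ``(batch, n_head, t, d * h * w)`` for ``t`` tokens, and standard scaled dot-product attention is applied over the flattened function values.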
250 | """ 251 | k = self.K(xa) 252 | q = self.Q(xa) 253 | v = self.V(xa) 254 | 255 | value_x, value_y = v.shape[-2], v.shape[-1] 256 | 257 | rearrangement = dict( 258 | pattern='(b t) (a d) h w -> b a t (d h w)', 259 | b=batch_size, 260 | a=self.n_head, 261 | ) 262 | k = rearrange(k, **rearrangement) 263 | q = rearrange(q, **rearrangement) 264 | v = rearrange(v, **rearrangement) 265 | 266 | dprod = (torch.matmul(q, k.transpose(-1, -2)) / 267 | (np.sqrt(k.shape[-1]) * self.temperature)) 268 | dprod = F.softmax(dprod, dim=-1) 269 | 270 | attention = torch.matmul(dprod, v) 271 | attention = rearrange( 272 | attention, 273 | 'b a t (d h w) -> b t a d h w', 274 | d=self.head_codimension, 275 | h=value_x, 276 | w=value_y, 277 | ) 278 | attention = rearrange(attention, 'b t a d h w -> (b t) (a d) h w') 279 | return attention 280 | 281 | def forward(self, x, output_shape=None): 282 | if self.permutation_eq: 283 | return self._forward_equivariant(x) 284 | else: 285 | return self._forward_non_equivariant(x) 286 | 287 | def _forward_equivariant(self, x): 288 | batch_size = x.shape[0] 289 | output_shape = x.shape[-2:] 290 | 291 | assert x.shape[1] % self.token_codimension == 0 292 | 293 | xa = rearrange(x, 'b (t d) h w -> (b t) d h w', 294 | d=self.token_codimension) 295 | xa_norm = self.norm1(xa) 296 | 297 | attention = self.compute_attention(xa_norm, batch_size) 298 | if self.proj is not None: 299 | attention = self.proj(attention) 300 | 301 | attention = self.attention_normalizer(attention + xa) 302 | attention = rearrange( 303 | attention, '(b t) d h w -> b (t d) h w', b=batch_size) 304 | # print("{attention.shape=}") 305 | attention = rearrange( 306 | attention, 307 | 'b (t d) h w -> (b t) d h w', 308 | d=self.mixer_token_codimension) 309 | # print("{attention.shape=}") 310 | 311 | attention_normalized = self.norm2(attention) 312 | output = self.mixer(attention_normalized, output_shape=output_shape) 313 | 314 | output = self.mixer_out_normalizer(output) + attention 315 | # print(f"{output.shape=}") 316 | output = rearrange(output, '(b t) d h w -> b (t d) h w', b=batch_size) 317 | 318 | return output 319 | 320 | def _forward_non_equivariant(self, x): 321 | batch_size = x.shape[0] 322 | output_shape = x.shape[-2:] 323 | 324 | assert x.shape[1] % self.token_codimension == 0 325 | 326 | x_norm = self.norm1(x) 327 | xa = rearrange(x_norm, 'b (t d) h w -> (b t) d h w', 328 | d=self.token_codimension) 329 | 330 | attention = self.compute_attention(xa, batch_size) 331 | if self.proj is not None: 332 | attention = self.proj(attention) 333 | 334 | attention = rearrange( 335 | attention, '(b t) d h w -> b (t d) h w', b=batch_size) 336 | attention_normalized = self.norm2(attention) 337 | output = self.mixer(attention_normalized, output_shape=output_shape) 338 | 339 | return output 340 | -------------------------------------------------------------------------------- /layers/codano_block_nd.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import logging 3 | from typing import Optional, Callable, Union, Dict 4 | 5 | import numpy as np 6 | from einops import rearrange 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from neuralop.layers.fno_block import FNOBlocks 11 | from .fino_nd import SpectralConvKernel2d, SpectralConvKernel1d, SpectralConvKernel3d 12 | 13 | 14 | # Implementation of generic N dimentional Codano Block 15 | AffineNormalizer1D = partial(nn.InstanceNorm1d, affine=True) 16 | 
AffineNormalizer2D = partial(nn.InstanceNorm2d, affine=True) 17 | AffineNormalizer3D = partial(nn.InstanceNorm3d, affine=True) 18 | # For higher dimensions (>=4), a custom normalizer needs to be implemented 19 | 20 | def Identity(x, *args, **kwargs): 21 | return x 22 | 23 | 24 | class CodanoBlock(nn.Module): 25 | def __init__( 26 | self, 27 | n_modes: Union[int, tuple], 28 | n_head: int = 1, 29 | token_codimension: int = 1, 30 | output_scaling_factor: Optional[float] = None, 31 | max_n_modes: Optional[Union[int, tuple]] = None, 32 | head_codimension: Optional[int] = None, 33 | use_mlp: bool = False, 34 | mlp: Optional[nn.Module] = None, 35 | non_linearity: Callable = F.gelu, 36 | preactivation: bool = False, 37 | fno_skip: str = 'linear', 38 | mlp_skip: str = 'linear', 39 | mlp_expansion: float = 1.0, 40 | separable: bool = False, 41 | factorization: str = None, 42 | rank: float = 1.0, 43 | SpectralConvolution: Optional[Callable] = None, 44 | Normalizer: Optional[Callable] = None, 45 | joint_factorization: bool = False, 46 | fixed_rank_modes: bool = False, 47 | implementation: str = 'reconstructed', 48 | decomposition_kwargs: Optional[Dict] = None, 49 | fft_norm: str = 'forward', 50 | codimension_size: Optional[int] = None, 51 | per_channel_attention: bool = True, 52 | permutation_eq: bool = True, 53 | temperature: float = 1.0, 54 | kqv_non_linear: bool = False, 55 | num_dims: int = 2, 56 | **_kwargs, 57 | ): 58 | super().__init__() 59 | 60 | self.variable_codimension = token_codimension 61 | self.token_codimension = token_codimension 62 | self.head_codimension = head_codimension or token_codimension 63 | self.n_head = n_head 64 | self.output_scaling_factor = output_scaling_factor 65 | self.temperature = temperature 66 | self.num_dims = num_dims 67 | 68 | if kqv_non_linear: 69 | kqv_activation = non_linearity 70 | else: 71 | kqv_activation = Identity 72 | 73 | self.permutation_eq = permutation_eq 74 | 75 | if self.n_head is not None: 76 | self.head_codimension = max(token_codimension // self.n_head, 1) 77 | 78 | self.codimension_size = codimension_size 79 | self.mixer_token_codimension = token_codimension 80 | 81 | if per_channel_attention: 82 | self.token_codimension = 1 83 | self.head_codimension = 1 84 | 85 | scale = 1 # unlike the 2D block, the ND version applies no downsampling to the K/Q functions 86 | 87 | mixer_modes = [i // scale for i in n_modes] 88 | mixer_n_modes = [i // scale for i in max_n_modes] 89 | 90 | decomposition_kwargs = decomposition_kwargs or {} 91 | common_args = dict( 92 | use_mlp=use_mlp, 93 | mlp=mlp, 94 | preactivation=preactivation, 95 | mlp_skip=mlp_skip, 96 | mlp_dropout=0, 97 | rank=rank, 98 | fft_norm=fft_norm, 99 | mlp_expansion=mlp_expansion, 100 | fixed_rank_modes=fixed_rank_modes, 101 | implementation=implementation, 102 | separable=separable, 103 | factorization=factorization, 104 | decomposition_kwargs=decomposition_kwargs, 105 | joint_factorization=joint_factorization, 106 | ) 107 | 108 | kqv_args = dict( 109 | in_channels=self.token_codimension, 110 | out_channels=self.n_head * self.head_codimension, 111 | n_modes=mixer_modes, 112 | max_n_modes=mixer_n_modes, 113 | non_linearity=kqv_activation, 114 | fno_skip='linear', 115 | norm=None, 116 | apply_skip=True, 117 | n_layers=1, 118 | ) 119 | 120 | rank = 1.0 # NOTE: shadows the constructor's `rank` for the K/Q/V and mixer spectral convolutions 121 | conv_kwargs = dict(rank=rank, factorization=None) 122 | self.K = FNOBlocks( 123 | output_scaling_factor=1 / scale, 124 | SpectralConv=partial( 125 | SpectralConvolution, 126 | **conv_kwargs 127 | ), 128 | **kqv_args, 129 | **common_args, 130 | ) 131 | self.Q = FNOBlocks( 132 |
output_scaling_factor=1 / scale, 133 | SpectralConv=partial( 134 | SpectralConvolution, 135 | **conv_kwargs, 136 | ), 137 | **kqv_args, 138 | **common_args, 139 | ) 140 | self.V = FNOBlocks( 141 | output_scaling_factor=1, 142 | SpectralConv=partial( 143 | SpectralConvolution, 144 | **conv_kwargs, 145 | ), 146 | **kqv_args, 147 | **common_args, 148 | ) 149 | 150 | if self.n_head * self.head_codimension != self.token_codimension: 151 | self.proj = FNOBlocks( 152 | in_channels=self.n_head * self.head_codimension, 153 | out_channels=self.token_codimension, 154 | n_modes=n_modes, 155 | max_n_modes=max_n_modes, 156 | output_scaling_factor=1, 157 | apply_skip=True, 158 | non_linearity=Identity, 159 | fno_skip='linear', 160 | norm=None, 161 | SpectralConv=partial( 162 | SpectralConvolution, 163 | rank=1.0, 164 | factorization=None, 165 | ), 166 | n_layers=1, 167 | **common_args, 168 | ) 169 | else: 170 | self.proj = None 171 | 172 | self.attention_normalizer = Normalizer(self.token_codimension) 173 | 174 | mixer_args = dict( 175 | n_modes=n_modes, 176 | max_n_modes=max_n_modes, 177 | output_scaling_factor=1, 178 | non_linearity=non_linearity, 179 | norm='instance_norm', 180 | fno_skip=fno_skip, 181 | SpectralConv=partial( 182 | SpectralConvolution, 183 | rank=rank, 184 | factorization=None, 185 | bias=True, 186 | ), 187 | ) 188 | 189 | if self.permutation_eq: 190 | self.mixer = FNOBlocks( 191 | in_channels=self.mixer_token_codimension, 192 | out_channels=self.mixer_token_codimension, 193 | apply_skip=True, 194 | n_layers=2, 195 | **mixer_args, 196 | **common_args, 197 | ) 198 | self.norm1 = Normalizer(self.token_codimension) 199 | self.norm2 = Normalizer(self.mixer_token_codimension) 200 | self.mixer_out_normalizer = Normalizer( 201 | self.mixer_token_codimension) 202 | else: 203 | self.mixer = FNOBlocks( 204 | in_channels=codimension_size, 205 | out_channels=codimension_size, 206 | n_layers=2, 207 | **mixer_args, 208 | **common_args, 209 | ) 210 | self.norm1 = Normalizer(codimension_size) 211 | self.norm2 = Normalizer(codimension_size) 212 | self.mixer_out_normalizer = Normalizer(codimension_size) 213 | 214 | def forward(self, *args): 215 | raise NotImplementedError( 216 | "Use a proper subclass of CodanoBlock (i.e. 
CodanoBlockND or CodanoBlock3D).") 217 | 218 | 219 | class CodanoBlockND(CodanoBlock): 220 | def __init__(self, *args, **kwargs): 221 | if kwargs["num_dims"] == 1: 222 | Normalizer = kwargs.get("Normalizer", AffineNormalizer1D) 223 | kwargs["Normalizer"] = Normalizer 224 | 225 | Convolution = kwargs.get( 226 | "SpectralConvolution", SpectralConvKernel1d) 227 | kwargs["SpectralConvolution"] = Convolution 228 | elif kwargs["num_dims"] == 2: 229 | Normalizer = kwargs.get("Normalizer", AffineNormalizer2D) 230 | kwargs["Normalizer"] = Normalizer 231 | 232 | Convolution = kwargs.get( 233 | "SpectralConvolution", SpectralConvKernel2d) 234 | kwargs["SpectralConvolution"] = Convolution 235 | elif kwargs["num_dims"] == 3: 236 | Normalizer = kwargs.get("Normalizer", AffineNormalizer3D) 237 | kwargs["Normalizer"] = Normalizer 238 | 239 | Convolution = kwargs.get( 240 | "SpectralConvolution", SpectralConvKernel3d) 241 | kwargs["SpectralConvolution"] = Convolution 242 | 243 | super().__init__(*args, **kwargs) 244 | 245 | def compute_attention(self, xa, batch_size): 246 | k = self.K(xa) 247 | q = self.Q(xa) 248 | v = self.V(xa) 249 | 250 | v_shape = v.shape[-self.num_dims:] 251 | 252 | rearrangement = dict( 253 | pattern=f'(b k) (a d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> b a k (d {" ".join(f"d{i}" for i in range(self.num_dims))})', 254 | b=batch_size, 255 | a=self.n_head, 256 | ) 257 | k = rearrange(k, **rearrangement) 258 | q = rearrange(q, **rearrangement) 259 | v = rearrange(v, **rearrangement) 260 | 261 | dprod = torch.matmul(q, k.transpose(-1, -2)) 262 | dprod = dprod / (self.temperature * np.sqrt(k.shape[-1])) 263 | dprod = F.softmax(dprod, dim=-1) 264 | 265 | attention = torch.matmul(dprod, v) 266 | rearrange_args = dict( 267 | pattern=f'b a k (d {" ".join(f"d{i}" for i in range(self.num_dims))}) -> b k a d {" ".join(f"d{i}" for i in range(self.num_dims))}', 268 | d=self.head_codimension, 269 | ) 270 | rearrange_args.update( 271 | {f'd{i}': v_shape[i] for i in range(self.num_dims)}) 272 | attention = rearrange(attention, **rearrange_args) 273 | attention = rearrange( 274 | attention, 275 | f'b k a d {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) (a d) {" ".join(f"d{i}" for i in range(self.num_dims))}') 276 | return attention 277 | 278 | def forward(self, x, output_shape=None): 279 | if self.permutation_eq: 280 | return self._forward_equivariant(x) 281 | else: 282 | return self._forward_non_equivariant(x) 283 | 284 | def _forward_equivariant(self, x): 285 | batch_size = x.shape[0] 286 | output_shape = x.shape[-self.num_dims:] 287 | 288 | assert x.shape[1] % self.token_codimension == 0 289 | 290 | xa = rearrange( 291 | x, 292 | f'b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) d {" ".join(f"d{i}" for i in range(self.num_dims))}', 293 | d=self.token_codimension) 294 | xa_norm = self.norm1(xa) 295 | 296 | attention = self.compute_attention(xa_norm, batch_size) 297 | if self.proj is not None: 298 | attention = self.proj(attention) 299 | 300 | attention = self.attention_normalizer(attention + xa) 301 | attention = rearrange( 302 | attention, 303 | f'(b k) d {" ".join(f"d{i}" for i in range(self.num_dims))} -> b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))}', 304 | b=batch_size) 305 | attention = rearrange( 306 | attention, 307 | f'b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) d {" ".join(f"d{i}" for i in range(self.num_dims))}', 308 | d=self.mixer_token_codimension) 309 | 310 | attention_normalized = self.norm2(attention) 311 
| output = self.mixer(attention_normalized, output_shape=output_shape) 312 | 313 | output = self.mixer_out_normalizer(output) + attention 314 | output = rearrange( 315 | output, 316 | f'(b k) d {" ".join(f"d{i}" for i in range(self.num_dims))} -> b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))}', 317 | b=batch_size) 318 | 319 | return output 320 | 321 | def _forward_non_equivariant(self, x): 322 | batch_size = x.shape[0] 323 | output_shape = x.shape[-self.num_dims:] 324 | 325 | assert x.shape[1] % self.token_codimension == 0 326 | 327 | x_norm = self.norm1(x) 328 | xa = rearrange( 329 | x_norm, 330 | f'b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) d {" ".join(f"d{i}" for i in range(self.num_dims))}', 331 | d=self.token_codimension) 332 | 333 | attention = self.compute_attention(xa, batch_size) 334 | if self.proj is not None: 335 | attention = self.proj(attention) 336 | 337 | attention = rearrange( 338 | attention, 339 | f'(b k) d {" ".join(f"d{i}" for i in range(self.num_dims))} -> b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))}', 340 | b=batch_size) 341 | attention_normalized = self.norm2(attention) 342 | output = self.mixer(attention_normalized, output_shape=output_shape) 343 | 344 | return output 345 | -------------------------------------------------------------------------------- /layers/fino_2D.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | import torch_harmonics as th 8 | from neuralop.layers.spectral_convolution import SpectralConv 9 | 10 | 11 | class SpectralConvKernel2d(SpectralConv): 12 | """ 13 | Parameters 14 | --- 15 | transform_type : {'sht', 'fft'} 16 | * If "sht" it uses the Spherical Fourier Transform. 17 | * If "fft" it uses the Fast Fourier Transform. 
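(Here "sht" denotes the spherical harmonic transform implemented by ``torch_harmonics``, and "fft" uses ``torch.fft``.)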
18 | """ 19 | 20 | def __init__( 21 | self, 22 | in_channels, 23 | out_channels, 24 | n_modes, 25 | incremental_n_modes=None, 26 | bias=True, 27 | n_layers=1, 28 | separable=False, 29 | output_scaling_factor=None, 30 | rank=0.5, 31 | factorization='dense', 32 | implementation='reconstructed', 33 | fno_block_precision='full', 34 | fixed_rank_modes=False, 35 | joint_factorization=False, 36 | decomposition_kwargs=None, 37 | init_std='auto', 38 | fft_norm='forward', 39 | transform_type="sht", 40 | sht_nlat=180, 41 | sht_nlon=360, 42 | sht_grid="legendre-gauss", 43 | isht_grid="legendre-gauss", 44 | sht_norm="backward", 45 | frequency_mixer=False, 46 | verbose=True, 47 | logger=None 48 | ): 49 | self.verbose = verbose 50 | 51 | if decomposition_kwargs is None: 52 | decomposition_kwargs = {} 53 | super().__init__( 54 | in_channels, 55 | out_channels, 56 | n_modes, 57 | incremental_n_modes, 58 | bias=bias, 59 | n_layers=n_layers, 60 | separable=separable, 61 | output_scaling_factor=output_scaling_factor, 62 | fno_block_precision=fno_block_precision, 63 | rank=rank, 64 | factorization=factorization, 65 | implementation=implementation, 66 | fixed_rank_modes=fixed_rank_modes, 67 | joint_factorization=joint_factorization, 68 | decomposition_kwargs=decomposition_kwargs, 69 | init_std=init_std, 70 | fft_norm=fft_norm, 71 | ) 72 | 73 | # self.shared = shared 74 | 75 | # readjusting initialization 76 | if init_std == "auto": 77 | init_std = 1 / np.sqrt(in_channels * n_modes[-1] * n_modes[-2]) 78 | else: 79 | init_std = init_std 80 | 81 | for w in self.weight: 82 | w.normal_(0, init_std) 83 | 84 | # weights for frequency mixers 85 | 86 | hm1, hm2 = self.half_n_modes[0], self.half_n_modes[1] 87 | if frequency_mixer: 88 | # if frequency mixer is true 89 | # then initializing weights for frequncy mixing 90 | # otherwise it is just a regular FNO or SFNO layer 91 | print("using Mixer") 92 | self.W1r = nn.Parameter( 93 | torch.empty( 94 | hm1, 95 | hm2, 96 | hm1, 97 | hm2, 98 | dtype=torch.float)) 99 | self.W2r = nn.Parameter( 100 | torch.empty( 101 | hm1, 102 | hm2, 103 | hm1, 104 | hm2, 105 | dtype=torch.float)) 106 | self.W1i = nn.Parameter( 107 | torch.empty( 108 | hm1, 109 | hm2, 110 | hm1, 111 | hm2, 112 | dtype=torch.float)) 113 | self.W2i = nn.Parameter( 114 | torch.empty( 115 | hm1, 116 | hm2, 117 | hm1, 118 | hm2, 119 | dtype=torch.float)) 120 | self.reset_parameter() 121 | 122 | self.sht_grid = sht_grid 123 | self.isht_grid = isht_grid 124 | self.sht_norm = sht_norm 125 | self.transform_type = transform_type 126 | self.frequency_mixer = frequency_mixer 127 | 128 | 129 | if self.output_scaling_factor is not None: 130 | out_nlat = round(sht_nlat * self.output_scaling_factor[0][0]) 131 | out_nlon = round(sht_nlon * self.output_scaling_factor[0][1]) 132 | else: 133 | out_nlat = sht_nlat 134 | out_nlon = sht_nlon 135 | 136 | if self.transform_type == "sht": 137 | self.forward_sht = th.RealSHT( 138 | sht_nlat, 139 | sht_nlon, 140 | grid=self.sht_grid, 141 | norm=self.sht_norm, 142 | ) 143 | self.inverse_sht = th.InverseRealSHT( 144 | out_nlat, 145 | out_nlon, 146 | grid=self.isht_grid, 147 | norm=self.sht_norm, 148 | ) 149 | 150 | def reset_parameter(self): 151 | # Initial model parameters. 
152 | scaling_factor = ((1 / self.in_channels)**0.5) / \ 153 | (self.half_n_modes[0] * self.half_n_modes[1]) 154 | torch.nn.init.normal_(self.W1r, mean=0.0, std=scaling_factor) 155 | torch.nn.init.normal_(self.W2r, mean=0.0, std=scaling_factor) 156 | torch.nn.init.normal_(self.W1i, mean=0.0, std=scaling_factor) 157 | torch.nn.init.normal_(self.W2i, mean=0.0, std=scaling_factor) 158 | 159 | @staticmethod 160 | def mode_mixer(x, weights): 161 | return torch.einsum("bimn,mnop->biop", x, weights) 162 | 163 | def forward_transform(self, x): 164 | height, width = x.shape[-2:] 165 | if self.transform_type == "fft": 166 | return torch.fft.rfft2(x.float(), norm=self.fft_norm) 167 | 168 | if self.transform_type == "sht": 169 | # The SHT is expensive to initialize, and during training we expect 170 | # the data to all be of the same shape. If we have a correct SHT, 171 | # let's use it: 172 | if ( 173 | self.forward_sht.nlat == height and 174 | self.forward_sht.nlon == width 175 | ): 176 | return self.forward_sht(x.double()).to(dtype=torch.cfloat) 177 | 178 | # Otherwise, initialize a new SHT: 179 | self.forward_sht = th.RealSHT( 180 | height, 181 | width, 182 | grid=self.sht_grid, 183 | norm=self.sht_norm, 184 | ).to(x.device) 185 | return self.forward_sht(x.double()).to(dtype=torch.cfloat) 186 | 187 | raise ValueError( 188 | 'Expected `transform_type` to be one of "fft" or "sht"; ' 189 | f'Got {self.transform_type=}' 190 | ) 191 | 192 | # Although a previous implementation kept an initialized 193 | # ``th.InverseRealSHT`` in its state, it always checked if its lat/lon grid 194 | # size matched the input's 195 | # resolution. Thus, it never really mattered that an object was in state. 196 | def inverse_transform( 197 | self, 198 | x: torch.Tensor, 199 | target_height: int, 200 | target_width: int, 201 | device, 202 | ): 203 | source_height, source_width = x.shape[-2:] 204 | if self.transform_type == "fft": 205 | return torch.fft.irfft2( 206 | x, 207 | s=(target_height, target_width), 208 | dim=(-2, -1), 209 | norm=self.fft_norm, 210 | ) 211 | 212 | if self.transform_type == "sht": 213 | # The SHT is expensive to initialize, and during training we expect 214 | # the data to all be of the same shape. If we have a correct SHT, 215 | # let's use it: 216 | if ( 217 | self.inverse_sht.lmax == source_height and 218 | self.inverse_sht.mmax == source_width and 219 | self.inverse_sht.nlat == target_height and 220 | self.inverse_sht.nlon == target_width 221 | ): 222 | return self.inverse_sht(x.to(dtype=torch.cdouble)).float() 223 | 224 | # Otherwise, initialize a new SHT: 225 | self.inverse_sht = th.InverseRealSHT( 226 | target_height, 227 | target_width, 228 | lmax=source_height, 229 | mmax=source_width, 230 | grid=self.sht_grid, 231 | norm=self.sht_norm, 232 | ).to(device) 233 | return self.inverse_sht(x.to(dtype=torch.cdouble)).float() 234 | 235 | raise ValueError( 236 | 'Expected `transform_type` to be one of "fft" or "sht"; ' 237 | f'Got {self.transform_type=}' 238 | ) 239 | 240 | def forward(self, x, indices=0, output_shape=None): 241 | batch_size, channels, height, width = x.shape 242 | 243 | x = self.forward_transform(x) 244 | 245 | upper_modes = [ 246 | slice(None), 247 | slice(None), 248 | slice(None, self.half_n_modes[0]), 249 | slice(None, self.half_n_modes[1]), 250 | ] 251 | """Slice for upper frequency modes. 
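(With ``rfft2``, the first spatial axis carries both non-negative and negative frequencies; this slice keeps the non-negative half, while ``lower_modes`` below keeps the negative half. The last axis of the ``rfft2`` output contains only non-negative frequencies.)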
252 | 253 | Equivalent to: ``x[:, :, :self.half_n_modes[0], :self.half_n_modes[1]]`` 254 | """ 255 | 256 | lower_modes = [ 257 | slice(None), 258 | slice(None), 259 | slice(-self.half_n_modes[0], None), 260 | slice(None, self.half_n_modes[1]), 261 | ] 262 | """Slice for lower frequency modes. 263 | 264 | Equivalent to: ``x[:, :, -self.half_n_modes[0]:, :self.half_n_modes[1]]`` 265 | """ 266 | 267 | # mode mixer 268 | # uses separate MLP to mix mode along each co-dim/channels 269 | if self.frequency_mixer: 270 | W1 = self.W1r + 1.0j * self.W1i 271 | W2 = self.W2r + 1.0j * self.W2i 272 | 273 | x[upper_modes] = self.mode_mixer(x[upper_modes].clone(), W1) 274 | x[lower_modes] = self.mode_mixer(x[lower_modes].clone(), W2) 275 | 276 | # spectral conv / channel mixer 277 | 278 | # The output will be of size: 279 | # (batch_size, self.out_channels, x.size(-2), x.size(-1)//2 + 1) 280 | out_fft = torch.zeros( 281 | [batch_size, self.out_channels, height, width // 2 + 1], 282 | dtype=x.dtype, 283 | device=x.device, 284 | ) 285 | 286 | # Upper block (truncate high frequencies): 287 | out_fft[upper_modes] = self._contract( 288 | x[upper_modes], 289 | self._get_weight(2 * indices), 290 | separable=self.separable, 291 | ) 292 | # Lower block (truncate low frequencies): 293 | out_fft[lower_modes] = self._contract( 294 | x[lower_modes], 295 | self._get_weight(2 * indices + 1), 296 | separable=self.separable, 297 | ) 298 | 299 | if self.output_scaling_factor is not None and output_shape is None: 300 | height = round(height * self.output_scaling_factor[indices][0]) 301 | width = round(width * self.output_scaling_factor[indices][1]) 302 | 303 | if output_shape is not None: 304 | height = output_shape[0] 305 | width = output_shape[1] 306 | 307 | x = self.inverse_transform(out_fft, height, width, x.device) 308 | 309 | if self.bias is not None: 310 | x = x + self.bias[indices, ...] 311 | 312 | return x 313 | -------------------------------------------------------------------------------- /layers/gnn_layer.py: -------------------------------------------------------------------------------- 1 | from baseline_utlis import FixedNeighborSearch 2 | from neuralop.layers.integral_transform import IntegralTransform 3 | from neuralop.layers.mlp import MLPLinear 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class GnnLayer(nn.Module): 9 | def __init__(self, in_dim, out_dim, 10 | input_grid, output_grid, mlp_layers, projection_hidden_dim, 11 | n_neigbor): 12 | ''' 13 | var_num: number of variables 14 | in_dim: Input Condim/channel per variables 15 | out_dim: Output Condim/channel per variables 16 | input_grid: Input grid (points) 17 | output_grid: Output grid (points) 18 | mlp_layers: MLP layers (for integral operator) 19 | projection_hidden_dim: Before applying integral operator we have pointwise MLP. 
This parameter 20 | determines the width of the multi-layered MLP 21 | n_neigbor: number of neighbours to consider 22 | ''' 23 | super().__init__() 24 | 25 | n_dim = input_grid.shape[-1] 26 | self.in_dim = in_dim 27 | self.out_dim = out_dim 28 | self.input_grid = input_grid 29 | self.output_grid = output_grid 30 | self.mlp_layers = [2 * n_dim + out_dim] + mlp_layers + [out_dim] 31 | self.n_neigbor = n_neigbor 32 | # project to higher dim 33 | self.projection = MLPLinear([self.in_dim, 34 | projection_hidden_dim, out_dim]) 35 | print([self.in_dim, projection_hidden_dim, out_dim]) 36 | print(self.mlp_layers) 37 | # apply GNO to get uniform grid 38 | NS = FixedNeighborSearch(use_open3d=False) 39 | 40 | self.neighbour = NS( 41 | input_grid.clone().cpu(), 42 | output_grid.clone().cpu(), 43 | n_neigbor=n_neigbor) 44 | 45 | for key, value in self.neighbour.items(): 46 | self.neighbour[key] = self.neighbour[key].cuda() 47 | 48 | self.it = IntegralTransform( 49 | mlp_layers=self.mlp_layers, transform_type='nonlinear') 50 | 51 | self.normalize = nn.LayerNorm(out_dim) 52 | 53 | def update_grid( 54 | self, 55 | input_grid=None, 56 | output_grid=None 57 | ): 58 | 59 | if input_grid is None: 60 | input_grid = self.input_grid 61 | if output_grid is None: 62 | output_grid = self.output_grid 63 | 64 | input_grid = input_grid.clone() 65 | self.input_grid = self.input_grid[:input_grid.shape[0], :] 66 | self.output_grid = self.output_grid[:output_grid.shape[0], :] 67 | 68 | NS = FixedNeighborSearch(use_open3d=False) 69 | 70 | self.neighbour = NS( 71 | input_grid.clone(), 72 | output_grid.clone(), 73 | n_neigbor=self.n_neigbor) 74 | for key, value in self.neighbour.items(): 75 | self.neighbour[key] = self.neighbour[key].cuda() 76 | 77 | def forward(self, inp): 78 | ''' 79 | inp : (batch_size, n_points, in_dims/Channels) 80 | ''' 81 | x = inp 82 | x = self.projection(x) 83 | out = self.it(self.input_grid, self.neighbour, 84 | self.output_grid, x) 85 | 86 | 87 | if out.shape == x.shape: 88 | out = out + x 89 | out = self.normalize(out) 90 | return out 91 | -------------------------------------------------------------------------------- /layers/gno_layer.py: -------------------------------------------------------------------------------- 1 | from neuralop.layers.neighbor_search import NeighborSearch 2 | from neuralop.layers.integral_transform import IntegralTransform 3 | from neuralop.layers.mlp import MLPLinear 4 | from baseline_utlis import FixedNeighborSearch 5 | from einops import rearrange 6 | from neuralop.layers.embeddings import PositionalEmbedding 7 | import torch.nn as nn 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | class GnoPremEq(nn.Module): 13 | def __init__( 14 | self, 15 | var_num, 16 | in_dim, out_dim, 17 | input_grid, 18 | output_grid, 19 | mlp_layers, 20 | projection_hidden_dim, 21 | radius, 22 | var_encoding=False, 23 | n_neigbor=10, 24 | fixed_neighbour=False, 25 | var_encoding_channels=1, 26 | n_layers=2, 27 | postional_em_dim=4, # always even 28 | end_projection=False, 29 | end_projection_outdim=None, 30 | ): 31 | ''' 32 | var_num: number of variables 33 | in_dim: Input Condim/channel per variables 34 | out_dim: Output Condim/channel per variables 35 | input_grid: Input grid (points) 36 | output_grid: Output grid (points) 37 | mlp_layers: MLP layers (for integral operator) 38 | projection_hidden_dim: Before applying integral operator we have pointwise MLP. 
This parameter 39 | determines the width of the multi-layered MLP 40 | radius: radius of the neighbourhood 41 | var_encoding: whether to use variable encoding 42 | var_encoding_channels: number of channels for variable encoding 43 | ''' 44 | super().__init__() 45 | assert postional_em_dim % 2 == 0 46 | n_dim = input_grid.shape[-1] 47 | self.radius = radius 48 | self.fixed_neighbour = fixed_neighbour 49 | self.n_neigbor = n_neigbor 50 | self.var_num = var_num 51 | self.in_dim = in_dim 52 | self.out_dim = out_dim 53 | self.input_grid = input_grid 54 | self.output_grid = output_grid 55 | self.mlp_layers = [2 * n_dim + self.out_dim] + mlp_layers + [out_dim] 56 | self.var_encoding = var_encoding 57 | self.postional_em_dim = postional_em_dim 58 | self.var_encoding_channels = var_encoding_channels 59 | self.n_layers = n_layers 60 | self.end_projection = end_projection 61 | self.end_projection_outdim = end_projection_outdim 62 | 63 | # get varibale encoding 64 | if self.var_encoding: 65 | self.var_encoder = MLPLinear( 66 | [n_dim + 2 * postional_em_dim, self.var_encoding_channels * var_num]) 67 | self.PE = PositionalEmbedding(postional_em_dim) 68 | self.variable_channels = [ 69 | i * (var_encoding_channels + self.in_dim) for i in range(var_num)] 70 | self.encoding_channels = list(set([i for i in range( 71 | (var_encoding_channels + 1) * var_num)]) - set(self.variable_channels)) 72 | else: 73 | self.var_encoding_channels = 0 74 | 75 | # project to higher dim 76 | self.projection = MLPLinear([self.var_encoding_channels + self.in_dim, 77 | projection_hidden_dim, out_dim], 78 | non_linearity=F.gelu) 79 | 80 | # apply GNO to get uniform grid 81 | 82 | self.neighbour = None 83 | self.neighbour_last = None 84 | self.update_grid() 85 | 86 | self.it = torch.nn.ModuleList() 87 | for i in range(n_layers): 88 | self.it.append(IntegralTransform( 89 | mlp_layers=self.mlp_layers, 90 | transform_type='nonlinear', 91 | mlp_non_linearity=F.gelu)) 92 | 93 | if self.end_projection: 94 | self.end_projector = MLPLinear([self.out_dim, 95 | projection_hidden_dim, self.end_projection_outdim], 96 | non_linearity=F.gelu) 97 | 98 | def update_grid( 99 | self, 100 | input_grid=None, 101 | output_grid=None 102 | ): 103 | if input_grid is None: 104 | input_grid = self.input_grid 105 | if output_grid is None: 106 | output_grid = self.output_grid 107 | 108 | input_grid = input_grid.clone() 109 | self.input_grid = self.input_grid[:input_grid.shape[0], :] 110 | self.output_grid = self.output_grid[:output_grid.shape[0], :] 111 | if self.fixed_neighbour: 112 | NS = FixedNeighborSearch(use_open3d=False) 113 | self.neighbour = NS( 114 | input_grid.clone().cpu(), 115 | input_grid.clone().cpu(), 116 | n_neigbor=self.n_neigbor) 117 | else: 118 | NS = NeighborSearch(use_open3d=False) 119 | self.neighbour = NS( 120 | input_grid.clone().cpu(), 121 | input_grid.clone().cpu(), 122 | radius=self.radius) 123 | 124 | for key, value in self.neighbour.items(): 125 | self.neighbour[key] = self.neighbour[key].cuda() 126 | 127 | NS_last = FixedNeighborSearch(use_open3d=False) 128 | self.neighbour_last = NS_last( 129 | input_grid.clone().cpu(), 130 | output_grid.clone().cpu(), 131 | n_neigbor=self.n_neigbor) 132 | 133 | for key, value in self.neighbour_last.items(): 134 | self.neighbour_last[key] = self.neighbour_last[key].cuda() 135 | 136 | def _intergral_transform(self, x): 137 | for i in range(self.n_layers): 138 | if i == self.n_layers - 1: 139 | x = self.it[i](self.input_grid, self.neighbour_last, 140 | self.output_grid, x) 141 | if 
self.end_projection: 142 | x = self.end_projector(x) 143 | else: 144 | x = self.it[i](self.input_grid, self.neighbour, 145 | self.input_grid, x) + x 146 | return x 147 | 148 | def forward(self, inp): 149 | ''' 150 | inp : (batch_size, n_points, in_dims/Channels) 151 | ''' 152 | 153 | if self.var_encoding: 154 | x = torch.zeros((inp.shape[0], inp.shape[1], len( 155 | self.variable_channels) + len(self.encoding_channels)), device=inp.device, dtype=inp.dtype) 156 | 157 | pe = self.PE(self.input_grid.reshape(-1)) 158 | pe = pe.reshape(self.input_grid.shape[0], -1) 159 | grid_pe = torch.cat([self.input_grid, pe], axis=1) 160 | var_encoding = self.var_encoder(grid_pe).to(x.device) 161 | x[:, :, self.variable_channels] = inp 162 | x[:, :, self.encoding_channels] = var_encoding[None, 163 | :, :].repeat(x.shape[0], 1, 1) 164 | else: 165 | x = inp 166 | 167 | 168 | x = rearrange( 169 | x, 170 | 'b n (v c) -> (b n) v c', 171 | c=self.in_dim + 172 | self.var_encoding_channels) 173 | x = self.projection(x) 174 | 175 | out = None 176 | 177 | for i in range(x.shape[-2]): 178 | # print(i) 179 | 180 | temp = self._intergral_transform(x[:, i, :]) 181 | if out is None: 182 | out = temp[None, ...] 183 | else: 184 | out = torch.cat([out, temp[None, ...]], dim=2) 185 | 186 | 187 | return out 188 | 189 | 190 | class GNO(nn.Module): 191 | def __init__(self, in_dim, out_dim, 192 | input_grid, output_grid, mlp_layers, projection_hidden_dim, 193 | radius, fixed_neighbour=False, n_neigbor=10): 194 | ''' 195 | var_num: number of variables 196 | in_dim: Input Condim/channel per variables 197 | out_dim: Output Condim/channel per variables 198 | input_grid: Input grid (points) 199 | output_grid: Output grid (points) 200 | mlp_layers: MLP layers (for integral operator) 201 | projection_hidden_dim: Before applying integral operator we have pointwise MLP. 
This parameter 202 | determines the width of the multi-layered MLP 203 | radius: radius of the neighbourhood 204 | ''' 205 | super().__init__() 206 | 207 | n_dim = input_grid.shape[-1] 208 | self.in_dim = in_dim 209 | self.out_dim = out_dim 210 | self.input_grid = input_grid 211 | self.output_grid = output_grid 212 | self.mlp_layers = [2 * n_dim] + mlp_layers + [out_dim] 213 | self.fixed_neighbour = fixed_neighbour 214 | self.n_neigbor = n_neigbor 215 | self.radius = radius 216 | # project to higher dim 217 | self.projection = MLPLinear([self.in_dim, 218 | projection_hidden_dim, out_dim]) 219 | 220 | self.neighbour = None 221 | self.update_grid() 222 | 223 | for key, value in self.neighbour.items(): 224 | self.neighbour[key] = self.neighbour[key].cuda() 225 | 226 | self.it = IntegralTransform(mlp_layers=self.mlp_layers) 227 | 228 | def update_grid( 229 | self, 230 | input_grid=None, 231 | output_grid=None 232 | ): 233 | if input_grid is None: 234 | input_grid = self.input_grid 235 | if output_grid is None: 236 | output_grid = self.output_grid 237 | input_grid = input_grid.clone() 238 | self.input_grid = self.input_grid[:input_grid.shape[0], :] 239 | self.output_grid = self.output_grid[:output_grid.shape[0], :] 240 | 241 | if self.fixed_neighbour: 242 | NS = FixedNeighborSearch(use_open3d=False) 243 | self.neighbour = NS( 244 | input_grid.clone(), 245 | output_grid.clone(), 246 | n_neigbor=self.n_neigbor) 247 | else: 248 | NS = NeighborSearch(use_open3d=False) 249 | self.neighbour = NS( 250 | input_grid.clone(), 251 | output_grid.clone(), 252 | radius=self.radius) 253 | 254 | def forward(self, inp): 255 | ''' 256 | inp : (batch_size, n_points, in_dims/Channels) 257 | ''' 258 | 259 | x = inp 260 | x = self.projection(x) 261 | 262 | out = self.it(self.input_grid, self.neighbour, 263 | self.output_grid, x[0, ...]) 264 | 265 | return out[None, ...] 
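# A minimal usage sketch for GNO (hypothetical shapes and values, not part of
# the original code; a CUDA device is assumed, since the layer moves the
# neighbor indices to the GPU):
#   in_grid = torch.rand(1000, 2)       # 1000 scattered input points in 2D
#   out_grid = torch.rand(64 * 64, 2)   # query points of a 64 x 64 grid
#   gno = GNO(in_dim=3, out_dim=32, input_grid=in_grid, output_grid=out_grid,
#             mlp_layers=[64, 64], projection_hidden_dim=64, radius=0.05)
#   y = gno(torch.rand(1, 1000, 3))     # -> (1, 64 * 64, 32)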
266 | -------------------------------------------------------------------------------- /layers/regrider.py: -------------------------------------------------------------------------------- 1 | import torch_harmonics as th 2 | import torch.nn as nn 3 | 4 | 5 | class Regird(nn.Module): 6 | def __init__( 7 | self, 8 | input_grid, 9 | output_grid, 10 | sht_nlat=128, 11 | sht_nlon=256, 12 | output_scaling_factor=None): 13 | super().__init__() 14 | self.input_transform = th.RealSHT( 15 | sht_nlat, sht_nlon, grid=input_grid, norm='backward').float() 16 | self.input_grid = input_grid 17 | self.output_grid = output_grid 18 | self.output_transform = th.InverseRealSHT( 19 | sht_nlat, sht_nlon, grid=output_grid, norm='backward').float() 20 | 21 | def forward(self, x): 22 | if self.input_transform.nlat != x.shape[-2] or self.input_transform.nlon != x.shape[-1]: 23 | self.input_transform = th.RealSHT(x.shape[-2], x.shape[-1], grid=self.input_grid, 24 | norm='backward').to(x.device, dtype=x.dtype) 25 | self.output_transform = th.InverseRealSHT(x.shape[-2], x.shape[-1], grid=self.output_grid, 26 | norm='backward').to(x.device, dtype=x.dtype) 27 | 28 | return self.output_transform(self.input_transform(x)) 29 | -------------------------------------------------------------------------------- /layers/regular_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange, repeat 4 | from vit_pytorch import ViT 5 | 6 | 7 | class vision_transformer(ViT): 8 | def forward(self, img): 9 | x = self.to_patch_embedding(img) 10 | b, n, _ = x.shape 11 | 12 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) 13 | x = torch.cat((cls_tokens, x), dim=1) 14 | x += self.pos_embedding[:, :(n + 1)] 15 | x = self.dropout(x) 16 | 17 | x = self.transformer(x) 18 | return x[:, 1:, :] -------------------------------------------------------------------------------- /layers/unet3d.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class UNet3d(nn.Module): 8 | def __init__(self, in_channels=3, out_channels=1, init_features=32): 9 | super(UNet3d, self).__init__() 10 | 11 | features = init_features 12 | self.encoder1 = UNet3d._block(in_channels, features, name="enc1") 13 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 14 | self.encoder2 = UNet3d._block(features, features * 2, name="enc2") 15 | self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)) 16 | self.bottleneck = UNet3d._block( 17 | features * 2, features * 4, name="bottleneck") 18 | self.upconv2 = nn.ConvTranspose3d( 19 | features * 4, features * 2, kernel_size=(1, 2, 2), stride=(1, 2, 2) 20 | ) 21 | self.decoder2 = UNet3d._block( 22 | (features * 2) * 2, features * 2, name="dec2") 23 | self.upconv1 = nn.ConvTranspose3d( 24 | features * 2, features, kernel_size=(1, 2, 2), stride=(1, 2, 2) 25 | ) 26 | self.decoder1 = UNet3d._block(features * 2, features, name="dec1") 27 | 28 | self.conv = nn.Conv3d( 29 | in_channels=features, out_channels=out_channels, kernel_size=1 30 | ) 31 | def forward(self, x): 32 | enc1 = self.encoder1(x) 33 | enc2 = self.encoder2(self.pool1(enc1)) 34 | bottleneck = self.bottleneck(self.pool2(enc2)) 35 | 36 | dec2 = self.upconv2(bottleneck) 37 | dec2 = torch.cat((dec2, enc2), dim=1) 38 | dec2 = self.decoder2(dec2) 39 | dec1 = self.upconv1(dec2) 40 | dec1 = torch.cat((dec1, enc1), 
dim=1) 41 | dec1 = self.decoder1(dec1) 42 | return self.conv(dec1) 43 | 44 | @staticmethod 45 | def _block(in_channels, features, name): 46 | return nn.Sequential( 47 | OrderedDict( 48 | [ 49 | ( 50 | name + "conv1", 51 | nn.Conv3d( 52 | in_channels=in_channels, 53 | out_channels=features, 54 | kernel_size=3, 55 | padding=1, 56 | bias=False, 57 | ), 58 | ), 59 | (name + "norm1", nn.BatchNorm3d(num_features=features)), 60 | (name + "relu1", nn.ReLU(inplace=True)), 61 | ( 62 | name + "conv2", 63 | nn.Conv3d( 64 | in_channels=features, 65 | out_channels=features, 66 | kernel_size=3, 67 | padding=1, 68 | bias=False, 69 | ), 70 | ), 71 | (name + "norm2", nn.BatchNorm3d(num_features=features)), 72 | (name + "relu2", nn.ReLU(inplace=True)), 73 | ] 74 | ) 75 | ) 76 | -------------------------------------------------------------------------------- /layers/unet_sublayer.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class UNet2d(nn.Module): 14 | 15 | def __init__(self, in_channels=3, out_channels=1, init_features=32): 16 | super(UNet2d, self).__init__() 17 | 18 | features = init_features 19 | self.encoder1 = UNet2d._block(in_channels, features, name="enc1") 20 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 21 | self.encoder2 = UNet2d._block(features, features * 2, name="enc2") 22 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 23 | self.bottleneck = UNet2d._block( 24 | features * 2, features * 4, name="bottleneck") 25 | self.upconv2 = nn.ConvTranspose2d( 26 | features * 4, features * 2, kernel_size=2, stride=2 27 | ) 28 | self.decoder2 = UNet2d._block( 29 | (features * 2) * 2, features * 2, name="dec2") 30 | self.upconv1 = nn.ConvTranspose2d( 31 | features * 2, features, kernel_size=2, stride=2 32 | ) 33 | self.decoder1 = UNet2d._block(features * 2, features, name="dec1") 34 | 35 | self.conv = nn.Conv2d( 36 | in_channels=features, out_channels=out_channels, kernel_size=1 37 | ) 38 | 39 | def forward(self, x, **kwargs): 40 | enc1 = self.encoder1(x) 41 | enc2 = self.encoder2(self.pool1(enc1)) 42 | 43 | bottleneck = self.bottleneck(self.pool2(enc2)) 44 | dec2 = self.upconv2(bottleneck) 45 | dec2 = torch.cat((dec2, enc2), dim=1) 46 | dec2 = self.decoder2(dec2) 47 | dec1 = self.upconv1(dec2) 48 | dec1 = torch.cat((dec1, enc1), dim=1) 49 | dec1 = self.decoder1(dec1) 50 | return self.conv(dec1) 51 | 52 | @staticmethod 53 | def _block(in_channels, features, name): 54 | return nn.Sequential( 55 | OrderedDict( 56 | [ 57 | ( 58 | name + "conv1", 59 | nn.Conv2d( 60 | in_channels=in_channels, 61 | out_channels=features, 62 | kernel_size=3, 63 | padding=1, 64 | bias=False, 65 | ), 66 | ), 67 | (name + "norm1", nn.BatchNorm2d(num_features=features)), 68 | (name + "tanh1", nn.ReLU()), 69 | ( 70 | name + "conv2", 71 | nn.Conv2d( 72 | in_channels=features, 73 | out_channels=features, 74 | kernel_size=3, 75 | padding=1, 76 | bias=False, 77 | ), 78 | ), 79 | (name + "norm2", nn.BatchNorm2d(num_features=features)), 80 | (name + "tanh2", nn.ReLU()), 81 | ] 82 | ) 83 | ) 84 | -------------------------------------------------------------------------------- /layers/variable_encoding.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from typing import Tuple 3 | from neuralop.layers.mlp import MLPLinear 4 | 
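# The encoders below implement learned, resolution-agnostic variable
# encodings: VariableEncoding2d learns spectral (FFT or SHT) coefficients and
# evaluates them on a uniform grid of any size, while
# VariableEncodingIrregularMesh maps positional embeddings of mesh
# coordinates to per-point encodings.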
import numpy as np 5 | import torch 6 | from torch import nn 7 | import torch_harmonics as th 8 | from neuralop.layers.embeddings import PositionalEmbedding 9 | 10 | 11 | class VariableEncoding2d(nn.Module): 12 | def __init__(self, 13 | n_variables: int, 14 | variable_encoding_size: int, 15 | modes: Tuple[int, ...], 16 | basis='fft') -> None: 17 | super().__init__() 18 | self.modes = modes 19 | channel = n_variables * variable_encoding_size 20 | self.coefficients_r = nn.Parameter( 21 | torch.empty(channel, *modes)) 22 | self.coefficients_i = nn.Parameter( 23 | torch.empty(channel, *modes)) 24 | self.reset_parameters() 25 | self.basis = basis 26 | if basis == 'fft': 27 | self.transform = torch.fft.ifft2 28 | elif basis == 'sht': 29 | self.transform = th.InverseRealSHT( 30 | *modes, 31 | lmax=modes[-2], 32 | mmax=modes[-1], 33 | grid="legendre-gauss", 34 | norm="backward", 35 | ) 36 | 37 | def reset_parameters(self): 38 | std = (1 / (self.modes[-1] * self.modes[-2]))**0.5 39 | torch.nn.init.normal_(self.coefficients_r, mean=0.0, std=std) 40 | torch.nn.init.normal_(self.coefficients_i, mean=0.0, std=std) 41 | 42 | def forward(self, x): 43 | """Takes a tensor and outputs the variable encodings at its spatial resolution.""" 44 | size_x, size_y = x.shape[-2], x.shape[-1] 45 | if self.basis == 'sht': 46 | if self.transform.nlat == size_x and self.transform.nlon == size_y: 47 | return self.transform( 48 | self.coefficients_r + 1.0j * self.coefficients_i) 49 | 50 | self.transform = th.InverseRealSHT( 51 | size_x, 52 | size_y, 53 | lmax=self.modes[-2], 54 | mmax=self.modes[-1], 55 | grid="legendre-gauss", 56 | norm='backward' 57 | ).to( 58 | device=self.coefficients_i.device, 59 | dtype=self.coefficients_i.dtype 60 | ) 61 | return self.transform( 62 | self.coefficients_r + 1.0j * self.coefficients_i 63 | # note: th.InverseRealSHT takes no `s` argument; its output size is fixed at construction 64 | ).real 65 | 66 | else: 67 | return self.transform( 68 | self.coefficients_r + 1.0j * self.coefficients_i, 69 | s=(size_x, size_y) 70 | ).real 71 | 72 | 73 | # SHT doesn't make sense for 3-dimensional data 74 | class FourierVariableEncoding3D(nn.Module): 75 | def __init__(self, n_features: int, modes: Tuple[int, ...]) -> None: 76 | super().__init__() 77 | if len(modes) != 3: 78 | raise ValueError( 79 | f"Expected 3 frequency modes, but got {len(modes)} modes:\n{modes=}") 80 | 81 | self.modes = modes 82 | self.weights_re = nn.Parameter(torch.empty(n_features, *modes)) 83 | self.weights_im = nn.Parameter(torch.empty(n_features, *modes)) 84 | self.reset_parameters() 85 | # self.transform = torch.fft.ifftn 86 | 87 | def reset_parameters(self): 88 | std = 1 / np.sqrt(reduce(lambda a, b: a * b, self.modes)) 89 | torch.nn.init.normal_(self.weights_re, mean=0.0, std=std) 90 | torch.nn.init.normal_(self.weights_im, mean=0.0, std=std) 91 | 92 | def forward(self, x_shape): 93 | """Takes a shape and outputs the variable encodings at that resolution.""" 94 | _, size_t, size_x, size_y = x_shape[0], x_shape[1], x_shape[2], x_shape[3] 95 | s = (size_t, size_x, size_y) 96 | # check that s is a tuple of ints 97 | if not all(isinstance(i, int) for i in s): 98 | raise ValueError( 99 | f"Expected a tuple of integers, but got {s=}") 100 | else: 101 | s = tuple(s) 102 | return torch.fft.ifftn( 103 | self.weights_re + 1.0j * self.weights_im, 104 | s=s, 105 | norm="forward", # don't multiply by any normalization factor 106 | ).real 107 | 108 | 109 | class VariableEncodingIrregularMesh(nn.Module): 110 | def __init__( 111 | self, 112 | n_variables: int, 113 | variable_encoding_size: int, 114 | n_dim: int = 2, 115 |
positional_encoding_dim: int = 8 116 | ) -> None: 117 | super().__init__() 118 | self.n_variables = n_variables 119 | self.variable_encoding_size = variable_encoding_size 120 | self.n_dim = n_dim 121 | self.positional_encoding_dim = positional_encoding_dim 122 | self.var_encoder = MLPLinear( 123 | [n_dim + self.n_dim * positional_encoding_dim, self.variable_encoding_size * n_variables]) 124 | self.PE = PositionalEmbedding(positional_encoding_dim) 125 | 126 | def forward(self, grid_poits): 127 | pe = self.PE(grid_poits.reshape(-1)) 128 | pe = pe.reshape(grid_poits.shape[0], -1) 129 | grid_pe = torch.cat([grid_poits, pe], axis=1) 130 | var_encoding = self.var_encoder(grid_pe) 131 | return var_encoding 132 | 133 | 134 | class VariableEncodingWrapper(nn.Module): 135 | def __init__( 136 | self, 137 | equation_dict: dict, 138 | variable_encoding_size: int, 139 | n_dim: int = 2, 140 | positional_encoding_dim: int = 8, 141 | varibale_encoding_modes: Tuple[int, ...] = (32, 32), 142 | basis='fft', 143 | uniform=False) -> None: 144 | ''' 145 | For each equation in the equation_dict, we create a VariableEncodingIrregularMesh 146 | dic is of form {"Equation": n_variables, ...} 147 | ''' 148 | super().__init__() 149 | self.n_dim = n_dim 150 | self.equation_dict = equation_dict 151 | self.variable_encoding_size = variable_encoding_size 152 | self.model_dict = nn.ModuleDict() 153 | self.uniform = uniform 154 | for i in equation_dict.keys(): 155 | if not uniform: 156 | self.model_dict[i] = VariableEncodingIrregularMesh( 157 | n_variables=equation_dict[i], 158 | variable_encoding_size=self.variable_encoding_size, 159 | n_dim=n_dim, 160 | positional_encoding_dim=positional_encoding_dim 161 | ) 162 | else: 163 | self.model_dict[i] = VariableEncoding2d( 164 | n_variables=equation_dict[i], 165 | variable_encoding_size=self.variable_encoding_size, 166 | modes=varibale_encoding_modes, 167 | basis=basis 168 | ) 169 | 170 | def load_encoder(self, equation: str, path: str): 171 | self.model_dict[equation].load_state_dict( 172 | torch.load(path, map_location=torch.device('cpu'))) 173 | 174 | def save_encoder(self, equation: str, path: str): 175 | torch.save(self.model_dict[equation].state_dict(), path) 176 | 177 | def save_all_encoder(self, path: str): 178 | for i in self.equation_dict.keys(): 179 | torch.save(self.model_dict[i].state_dict(), path + f"_{i}" + ".pt") 180 | 181 | def freeze(self, equation: str): 182 | for param in self.model_dict[equation].parameters(): 183 | param.requires_grad = False 184 | 185 | def forward(self, grid_poits, equation: str = None): 186 | ''' 187 | grid_poits: (n_points, n_dim) or for uniform mesh input tensor of shape (D, channels, H, W) 188 | ''' 189 | encoding_list = [] 190 | if equation is None: 191 | equation = list(self.equation_dict.keys()) 192 | for i in equation: 193 | encoding_list.append(self.model_dict[i](grid_poits)) 194 | 195 | if self.uniform: 196 | return torch.cat(encoding_list, axis=0).unsqueeze(0) 197 | else: 198 | return torch.cat(encoding_list, axis=1).unsqueeze(0) 199 | 200 | 201 | def get_variable_encoder(params): 202 | return VariableEncodingWrapper( 203 | params.equation_dict, 204 | variable_encoding_size=params.n_encoding_channels, 205 | n_dim=params.n_dim, 206 | positional_encoding_dim=params.positional_encoding_dim, 207 | uniform=params.grid_type == 'uniform', 208 | varibale_encoding_modes=( 209 | params.encoding_modes_x, 210 | params.encoding_modes_y) if hasattr( 211 | params, 212 | 'encoding_modes_y') else None) 213 | 214 | 215 | 216 | class 
TokenExpansion(nn.Module): 217 | def __init__( 218 | self, 219 | n_variables: int, 220 | n_encoding_channels, 221 | n_static_channels: int, 222 | uniform_grid=False) -> None: 223 | """ 224 | Stack the variables and the corresponding encodings together. 225 | 226 | Args: 227 | n_variables (int): number of variables 228 | n_encoding_channels (int): number of encoding channels 229 | n_static_channels (int): number of static channels 230 | uniform_grid (bool): whether the grid is uniform 231 | 232 | """ 233 | super().__init__() 234 | self.n_variables = n_variables 235 | self.n_encoding_channels = n_encoding_channels 236 | self.n_static_channels = n_static_channels 237 | self.uniform_grid = uniform_grid 238 | 239 | expansion_factor = 1 + self.n_static_channels + self.n_encoding_channels 240 | 241 | self.variable_channels = [ 242 | i * expansion_factor for i in range(n_variables)] 243 | self.static_channels = [] 244 | if self.n_static_channels != 0: 245 | for v in self.variable_channels: 246 | self.static_channels.extend( 247 | range(v + 1, v + self.n_static_channels + 1)) 248 | self.encoding_channels = [] 249 | if self.n_encoding_channels != 0: 250 | self.encoding_channels = sorted(list( 251 | set(range(n_variables * expansion_factor)) 252 | - set(self.variable_channels) 253 | - set(self.static_channels) 254 | )) 255 | 256 | print(self.variable_channels) 257 | print(self.static_channels) 258 | print(self.encoding_channels) 259 | 260 | def forward( 261 | self, 262 | inp: torch.Tensor, 263 | variable_encodings: torch.Tensor, 264 | static_channels: torch.Tensor) -> torch.Tensor: 265 | """ 266 | inp: (batch_size, points, n_variables) 267 | """ 268 | if not self.uniform_grid: 269 | x = torch.zeros((inp.shape[0], inp.shape[1], len(self.variable_channels) + len( 270 | self.encoding_channels) + len(self.static_channels)), device=inp.device, dtype=inp.dtype) 271 | x[:, :, self.variable_channels] = inp 272 | if self.n_static_channels != 0: 273 | x[:, :, self.static_channels] = static_channels.repeat( 274 | x.shape[0], 1, 1) 275 | if self.n_encoding_channels != 0: 276 | x[:, :, self.encoding_channels] = variable_encodings.repeat( 277 | 278 | x.shape[0], 1, 1) 279 | 280 | else: 281 | # currently only 2D is supported 282 | 283 | x = torch.zeros((inp.shape[0], len( 284 | self.variable_channels) + len(self.encoding_channels) + len(self.static_channels), 285 | inp.shape[-2], inp.shape[-1]), device=inp.device, dtype=inp.dtype) 286 | x[:, self.variable_channels, :, :] = inp 287 | if self.n_static_channels != 0: 288 | x[:, self.static_channels, :, :] = static_channels.repeat( 289 | 1, self.n_variables, 1, 1) 290 | if self.n_encoding_channels != 0: 291 | x[:, self.encoding_channels, :, :] = variable_encodings.repeat( 292 | x.shape[0], 1, 1, 1) 293 | 294 | return x 295 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from YParams import YParams 2 | import wandb 3 | import argparse 4 | import torch 5 | import numpy as np 6 | from data_utils.data_loaders import * 7 | from data_utils.data_utils import MaskerNonuniformMesh, get_meshes 8 | from layers.variable_encoding import * 9 | from models.get_models import * 10 | from train.trainer import trainer 11 | from utils import * 12 | from models.model_helpers import count_parameters 13 | from test.evaluations import missing_variable_testing 14 | import random 15 | 16 | 17 | if __name__ == "__main__": 18 | 19 | parser = argparse.ArgumentParser() 20 |
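# Example invocations (see exps_FSI.sh and exps_RB.sh):
#   python main.py --exp FSI --config codano_gno_NS_ES --ntrain 8000
#   python main.py --exp RB --config ft_codano_RB --ntrain 5 --epochs 50 --batch_size 1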
parser.add_argument("--exp", nargs="?", default="FSI", type=str) 21 | parser.add_argument("--config", nargs="?", default="base_config", type=str) 22 | parser.add_argument("--ntrain", nargs="?", default=None, type=int) 23 | parser.add_argument("--epochs", nargs="?", default=None, type=int) 24 | parser.add_argument("--random_seed", nargs="?", default=42, type=int) 25 | parser.add_argument("--scheduler_step", nargs="?", default=None, type=int) 26 | parser.add_argument("--batch_size", nargs="?", default=None, type=int) 27 | parsed_args = parser.parse_args() 28 | 29 | if parsed_args.exp == "FSI": 30 | config_file = './config/ssl_ns_elastic.yaml' 31 | elif parsed_args.exp == "RB": 32 | config_file = './config/RB_config.yaml' 33 | else: 34 | raise ValueError("Unknown experiment type") 35 | 36 | config = parsed_args.config 37 | print("Loading config", config) 38 | params = YParams(config_file, config, print_params=True) 39 | 40 | if parsed_args.ntrain is not None: 41 | params.ntrain = parsed_args.ntrain 42 | print("Overriding ntrain to", params.ntrain) 43 | if parsed_args.random_seed is not None: 44 | params.random_seed = parsed_args.random_seed 45 | print("Overriding random seed to", params.random_seed) 46 | if parsed_args.epochs is not None: 47 | params.epochs = parsed_args.epochs 48 | print("Overriding epochs to", params.epochs) 49 | if parsed_args.scheduler_step is not None: 50 | params.scheduler_step = parsed_args.scheduler_step 51 | print("Overriding scheduler step to", params.scheduler_step) 52 | if parsed_args.batch_size is not None: 53 | params.batch_size = parsed_args.batch_size 54 | print("Overriding batch size to", params.batch_size) 55 | 56 | torch.manual_seed(params.random_seed) 57 | random.seed(params.random_seed) 58 | np.random.seed(params.random_seed) 59 | 60 | params.config = config 61 | 62 | # Set up WandB logging 63 | params.wandb_name = config 64 | params.wandb_group = params.nettype 65 | if params.wandb_log: 66 | wandb.login(key=get_wandb_api_key()) 67 | wandb.init( 68 | config=params, 69 | name=params.wandb_name, 70 | group=params.wandb_group, 71 | project=params.wandb_project, 72 | entity=params.wandb_entity) 73 | 74 | # stage of training: reconstructive (Self supervised traning) 75 | # or predictive (Supervised training either finetuning or from scratch) 76 | if params.pretrain_ssl: 77 | stage = StageEnum.RECONSTRUCTIVE 78 | else: 79 | stage = StageEnum.PREDICTIVE 80 | 81 | variable_encoder = None 82 | token_expander = None 83 | 84 | if params.nettype == 'transformer': 85 | if params.grid_type == 'uniform': 86 | encoder, decoder, contrastive, predictor = get_ssl_models_codano( 87 | params) 88 | input_mesh = None 89 | else: 90 | encoder, decoder, contrastive, predictor = get_ssl_models_codano_gino( 91 | params) 92 | 93 | if params.use_variable_encoding: 94 | variable_encoder = get_variable_encoder(params) 95 | token_expander = TokenExpansion(sum([params.equation_dict[i] for i in params.equation_dict.keys( 96 | )]), params.n_encoding_channels, params.n_static_channels, params.grid_type == 'uniform') 97 | 98 | print("Parameters Encoder", count_parameters(encoder), "x10^6") 99 | print("Parameters Decoder", count_parameters(decoder), "x10^6") 100 | print("Parameters Perdictor", count_parameters(predictor), "x10^6") 101 | if params.wandb_log: 102 | wandb.log( 103 | {'Encoder #parameters': count_parameters(encoder)}, commit=True) 104 | wandb.log( 105 | {'Decoder #parameters': count_parameters(decoder)}, commit=True) 106 | wandb.log( 107 | {'Predictor #parameters': 
107 |                 {'Predictor #parameters': count_parameters(predictor)}, commit=True)
108 | 
109 |         model = SSLWrapper(
110 |             params,
111 |             encoder,
112 |             decoder,
113 |             contrastive,
114 |             predictor,
115 |             stage=stage)
116 | 
117 |         if params.grid_type != 'uniform':
118 |             print("Setting the Grid")
119 |             mesh = get_mesh(params)
120 |             input_mesh = torch.from_numpy(mesh).type(torch.float).cuda()
121 |             model.set_initial_mesh(input_mesh)
122 | 
123 |     elif params.nettype in ['simple', 'gnn', 'deeponet', 'vit', 'unet', 'fno']:
124 |         model = get_baseline_model(params)
125 |         print("Parameters Model", count_parameters(model), "x10^6")
126 |         if params.wandb_log: wandb.log({'Model #parameters': count_parameters(model)}, commit=True)
127 |         input_mesh = None
128 | 
129 |     print("PDE list", *list(params.equation_dict.keys()))
130 | 
131 |     if parsed_args.exp == 'FSI':
132 |         # loading the Fluid-Structure Interaction dataset
133 |         dataset = NsElasticDataset(
134 |             params.data_location,
135 |             equation=list(params.equation_dict.keys()),
136 |             mesh_location=params.input_mesh_location,
137 |             params=params)
138 |         train, test = dataset.get_dataloader(params.mu_list, params.dt, ntrain=params.get(
139 |             'ntrain'), ntest=params.get('ntest'), sample_per_inlet=params.sample_per_inlet)
140 |     elif parsed_args.exp == 'RB':
141 |         # loading the Rayleigh-Benard dataset
142 |         train, test = get_RB_dataloader(params)
143 | 
144 |     if getattr(params, 'evaluate_only', False):
145 |         # set the stage to predictive for evaluation
146 |         # and load the model weights
147 |         stage = StageEnum.PREDICTIVE
148 |         model.load_state_dict(torch.load(params.model_path), strict=False)
149 |         if params.nettype == 'transformer' and params.use_variable_encoding:
150 |             if "NS" in params.equation_dict.keys():
151 |                 print("Loading NS variable encoder")
152 |                 variable_encoder.load_encoder(
153 |                     "NS", params.NS_variable_encoder_path)
154 |             if "ES" in params.equation_dict.keys() and params.ES_variable_encoder_path is not None:
155 |                 print("Loading ES variable encoder")
156 |                 variable_encoder.load_encoder(
157 |                     "ES", params.ES_variable_encoder_path)
158 |             if "T" in params.equation_dict.keys() and params.T_variable_encoder_path is not None:
159 |                 print("Loading T variable encoder")
160 |                 variable_encoder.load_encoder(
161 |                     "T", params.T_variable_encoder_path)
162 |                 if params.freeze_encoder:
163 |                     variable_encoder.freeze("T")
164 | 
165 |     elif params.training_stage == 'fine_tune':
166 |         # load only the encoder and variable encoder (VSPE) weights
167 |         print(f"Loading Pretrained weights from {params.pretrain_weight}")
168 |         model.encoder.load_state_dict(torch.load(
169 |             params.pretrain_weight), strict=True)
170 |         if params.use_variable_encoding:
171 |             print(
172 |                 f"Loading Pretrained weights from {params.NS_variable_encoder_path}")
173 | 
174 |             if "NS" in params.equation_dict.keys():
175 |                 print("Loading NS variable encoder")
176 |                 variable_encoder.load_encoder(
177 |                     "NS", params.NS_variable_encoder_path)
178 |                 if params.freeze_encoder:
179 |                     variable_encoder.freeze("NS")
180 | 
181 |             if "ES" in params.equation_dict.keys() and params.ES_variable_encoder_path is not None:
182 |                 print("Loading ES variable encoder")
183 |                 variable_encoder.load_encoder(
184 |                     "ES", params.ES_variable_encoder_path)
185 |                 if params.freeze_encoder:
186 |                     variable_encoder.freeze("ES")
187 | 
188 |             if "T" in params.equation_dict.keys() and params.T_variable_encoder_path is not None:
189 |                 print("Loading T variable encoder")
190 |                 variable_encoder.load_encoder(
191 |                     "T", params.T_variable_encoder_path)
192 |                 if params.freeze_encoder:
193 |                     variable_encoder.freeze("T")
194 | 
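    # The three near-identical loading blocks above could equivalently be
    # written as a loop; a minimal sketch (same `<EQ>_variable_encoder_path`
    # naming as above, with a None-check applied to every equation):
    #
    #     for eq in params.equation_dict:
    #         path = getattr(params, f"{eq}_variable_encoder_path", None)
    #         if path is not None:
    #             print(f"Loading {eq} variable encoder")
    #             variable_encoder.load_encoder(eq, path)
    #             if params.freeze_encoder:
    #                 variable_encoder.freeze(eq)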
195 |     # Move model and encoders to GPU
196 |     model = model.cuda()
197 |     if variable_encoder is not None:
198 |         variable_encoder.cuda()
199 |     if token_expander is not None:
200 |         token_expander.cuda()
201 | 
202 |     if not getattr(params, 'evaluate_only', False):
203 |         # Train the model
204 |         trainer(
205 |             model,
206 |             train,
207 |             test,
208 |             params,
209 |             wandb_log=params.wandb_log,
210 |             log_test_interval=params.wandb_log_test_interval,
211 |             stage=stage,
212 |             variable_encoder=variable_encoder,
213 |             token_expander=token_expander,
214 |             initial_mesh=input_mesh)
215 | 
216 |     if getattr(params, 'missing_var_test', False):
217 |         # evaluate on missing and partially
218 |         # observed variables
219 | 
220 |         grid_non, grid_uni = get_meshes(
221 |             params, params.grid_size)
222 |         test_augmenter = MaskerNonuniformMesh(
223 |             grid_non_uni=grid_non.clone().detach(),
224 |             gird_uni=grid_uni.clone().detach(),
225 |             radius=params.masking_radius,
226 |             drop_type=params.drop_type,
227 |             drop_pix=params.drop_pix_val,
228 |             channel_aug_rate=params.channel_per_val,
229 |             channel_drop_rate=params.channel_drop_per_val,
230 |             verbose=True)
231 |         missing_variable_testing(
232 |             model,
233 |             test,
234 |             test_augmenter,
235 |             'sl',
236 |             params,
237 |             variable_encoder=variable_encoder,
238 |             token_expander=token_expander,
239 |             initial_mesh=input_mesh)
240 |     else:
241 |         # Evaluate the model on the
242 |         # unaugmented test set
243 | 
244 |         missing_variable_testing(
245 |             model,
246 |             test,
247 |             augmenter=None,
248 |             stage=stage,
249 |             params=params,
250 |             variable_encoder=variable_encoder,
251 |             token_expander=token_expander,
252 |             initial_mesh=input_mesh)
253 | 
254 |     if params.wandb_log:
255 |         wandb.finish()
256 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/models/__init__.py
--------------------------------------------------------------------------------
/models/codano.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import logging
3 | from typing import Literal, NamedTuple, Optional
4 | import numpy as np
5 | from einops import rearrange
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 | from neuralop.layers.padding import DomainPadding
10 | from layers.codano_block_nd import CodanoBlockND
11 | from layers.fino_nd import SpectralConvKernel2d
12 | from layers.variable_encoding import VariableEncoding2d
13 | 
14 | 
15 | # TODO replace with neuralop.MLP module
16 | class PermEqProjection(nn.Module):
17 |     def __init__(
18 |         self,
19 |         in_channels,
20 |         out_channels,
21 |         hidden_channels=None,
22 |         n_dim=2,
23 |         non_linearity=F.gelu,
24 |         permutation_invariant=False,
25 |     ):
26 |         """Permutation invariant projection layer.
27 | 
28 |         Performs linear projections on each channel separately.
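        When `permutation_invariant=True`, the channel dimension is first
        reshaped from (batch, n_tokens * in_channels, H, W) to
        (batch * n_tokens, in_channels, H, W), so every token is projected
        by the same 1x1 convolutions and the layer commutes with
        permutations of the tokens.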
29 | """ 30 | super().__init__() 31 | self.in_channels = in_channels 32 | self.out_channels = out_channels 33 | self.hidden_channels = (in_channels 34 | if hidden_channels is None else 35 | hidden_channels) 36 | self.non_linearity = non_linearity 37 | Conv = getattr(nn, f'Conv{n_dim}d') 38 | 39 | self.permutation_invariant = permutation_invariant 40 | 41 | self.fc1 = Conv(in_channels, hidden_channels, 1) 42 | self.norm = nn.InstanceNorm2d(hidden_channels, affine=True) 43 | self.fc2 = Conv(hidden_channels, out_channels, 1) 44 | 45 | def forward(self, x): 46 | batch = x.shape[0] 47 | if self.permutation_invariant: 48 | assert x.shape[1] % self.in_channels == 0, \ 49 | "Total Number of Channels is not divisible by number of tokens" 50 | x = rearrange(x, 'b (g c) h w -> (b g) c h w', c=self.in_channels) 51 | 52 | x = self.fc1(x) 53 | x = self.norm(x) 54 | x = self.non_linearity(x) 55 | x = self.fc2(x) 56 | if self.permutation_invariant: 57 | x = rearrange(x, '(b g) c h w -> b (g c) h w', b=batch) 58 | return x 59 | 60 | 61 | class VariableEncodingArgs(NamedTuple): 62 | basis: Literal["sht", "fft"] 63 | n_channels: int 64 | """Number of extra encoding channels per variable.""" 65 | modes_x: int 66 | modes_y: int 67 | modes_t: Optional[int] = None 68 | 69 | 70 | class CodANO(nn.Module): 71 | """ 72 | Parameters 73 | --- 74 | input_token_codimension : input token codim/number of channel per input token 75 | out_token_codim=None : output token codim/number of channel per output token 76 | hidden_token_codim=None : 77 | lifting_token_codim=None : 78 | var_encoding=False : boolean 79 | if true then it adds variable encoding with each channel 80 | var_num=None : denotes the number of variables 81 | var_enco_basis='sht' : specify the basis funtion for variable encodings 82 | var_enco_channels=1 : number of channels for each variable encoding 83 | var_enco_mode_x=50 : number of x modes for each variable encoding 84 | var_enco_mode_y=50 : number of y models for each variable encoding 85 | enable_cls_token=False : if true, learnable cls token will be added 86 | static_channels_num=0 : 87 | Number of static channels to be concatenated (xy grid, land/sea mask etc) 88 | static_features=None : 89 | The static feature (it will be taken from the Preprocessor while 90 | initializing the model) 91 | integral_operator_top : 92 | Required for the re-grid operation (for example: from equiangular to LG grid.) 
93 | integral_operator_bottom : 94 | Required for the re-grid operation (for example: from LG grid to equiangular) 95 | """ 96 | 97 | def __init__( 98 | self, 99 | input_token_codimension, 100 | output_token_codimension=None, 101 | hidden_token_codimension=None, 102 | lifting_token_codimension=None, 103 | n_layers=4, 104 | n_modes=None, 105 | max_n_modes=None, 106 | scalings=None, 107 | n_heads=1, 108 | non_linearity=F.gelu, 109 | layer_kwargs={'use_mlp': False, 110 | 'mlp_dropout': 0, 111 | 'mlp_expansion': 1.0, 112 | 'non_linearity': F.gelu, 113 | 'norm': None, 114 | 'preactivation': False, 115 | 'fno_skip': 'linear', 116 | 'horizontal_skip': 'linear', 117 | 'mlp_skip': 'linear', 118 | 'separable': False, 119 | 'factorization': None, 120 | 'rank': 1.0, 121 | 'fft_norm': 'forward', 122 | 'normalizer': 'instance_norm', 123 | 'joint_factorization': False, 124 | 'fixed_rank_modes': False, 125 | 'implementation': 'factorized', 126 | 'decomposition_kwargs': dict(), 127 | 'normalizer': False}, 128 | per_channel_attention=True, 129 | operator_block=CodanoBlockND, 130 | integral_operator=SpectralConvKernel2d, 131 | integral_operator_top=partial( 132 | SpectralConvKernel2d, sht_grid="legendre-gauss"), 133 | integral_operator_bottom=partial( 134 | SpectralConvKernel2d, isht_grid="legendre-gauss"), 135 | projection=True, 136 | lifting=True, 137 | domain_padding=0.5, 138 | domain_padding_mode='one-sided', 139 | n_variables=None, 140 | variable_encoding_args: VariableEncodingArgs = None, 141 | enable_cls_token=False, 142 | logger=None, 143 | ): 144 | super().__init__() 145 | self.n_layers = n_layers 146 | assert len( 147 | n_modes) == n_layers, "number of modes for all layers are not given" 148 | assert len(n_heads) == n_layers, \ 149 | "number of Attention head for all layers are not given" 150 | if integral_operator_bottom is None: 151 | integral_operator_bottom = integral_operator 152 | if integral_operator_top is None: 153 | integral_operator_top = integral_operator 154 | self.n_dim = len(n_modes[0]) 155 | self.input_token_codimension = input_token_codimension 156 | # self.n_variables = n_variables 157 | if hidden_token_codimension is None: 158 | hidden_token_codimension = input_token_codimension 159 | if lifting_token_codimension is None: 160 | lifting_token_codimension = input_token_codimension 161 | if output_token_codimension is None: 162 | output_token_codimension = input_token_codimension 163 | 164 | self.hidden_token_codimension = hidden_token_codimension 165 | self.n_modes = n_modes 166 | self.max_n_modes = max_n_modes 167 | self.scalings = scalings 168 | self.non_linearity = non_linearity 169 | self.n_heads = n_heads 170 | self.integral_operator = integral_operator 171 | self.lifting = lifting 172 | self.projection = projection 173 | self.num_dims = len(n_modes[0]) 174 | self.enable_cls_token = enable_cls_token 175 | 176 | if logger is None: 177 | logger = logging.getLogger() 178 | self.logger = logger 179 | 180 | self.layer_kwargs = layer_kwargs 181 | if layer_kwargs is None: 182 | self.layer_kwargs = { 183 | 'incremental_n_modes': None, 184 | 'use_mlp': False, 185 | 'mlp_dropout': 0, 186 | 'mlp_expansion': 1.0, 187 | 'non_linearity': F.gelu, 188 | 'norm': None, 189 | 'preactivation': False, 190 | 'fno_skip': 'linear', 191 | 'horizontal_skip': 'linear', 192 | 'mlp_skip': 'linear', 193 | 'separable': False, 194 | 'factorization': None, 195 | 'rank': 1.0, 196 | 'fft_norm': 'forward', 197 | 'normalizer': 'instance_norm', 198 | 'joint_factorization': False, 199 | 'fixed_rank_modes': False, 
200 | 'implementation': 'factorized', 201 | 'decomposition_kwargs': None, 202 | } 203 | 204 | # self.n_static_channels = n_static_channels 205 | """The number of static channels for all variable channels.""" 206 | 207 | # calculating scaling 208 | if self.scalings is not None: 209 | self.end_to_end_scaling = self.get_output_scaling_factor( 210 | np.ones_like(self.scalings[0]), 211 | self.scalings 212 | ) 213 | else: 214 | self.end_to_end_scaling = 1 215 | self.logger.debug(f"{self.end_to_end_scaling=}") 216 | if isinstance(self.end_to_end_scaling, (float, int)): 217 | self.end_to_end_scaling = [self.end_to_end_scaling] * self.n_dim 218 | 219 | # Setting up domain padding for encoder and reconstructor 220 | if domain_padding is not None and domain_padding > 0: 221 | self.domain_padding = DomainPadding( 222 | domain_padding=domain_padding, 223 | padding_mode=domain_padding_mode, 224 | output_scaling_factor=self.end_to_end_scaling, 225 | ) 226 | else: 227 | self.domain_padding = None 228 | self.domain_padding_mode = domain_padding_mode 229 | 230 | # A variable + it's variable encoding + the static channel(s) 231 | # together constitute a token 232 | # n_lifted_channels = self.input_token_codimension + \ 233 | # variable_encoding_args.n_channels + \ 234 | # self.n_static_channels 235 | if self.lifting: 236 | self.lifting = PermEqProjection( 237 | in_channels=input_token_codimension, 238 | out_channels=hidden_token_codimension, 239 | hidden_channels=lifting_token_codimension, 240 | n_dim=self.n_dim, 241 | non_linearity=self.non_linearity, 242 | permutation_invariant=True, # Permutation 243 | ) 244 | # elif self.use_variable_encoding: 245 | # hidden_token_codimension = n_lifted_channels 246 | 247 | cls_dimension = 1 if enable_cls_token else 0 248 | self.codimension_size = hidden_token_codimension * n_variables + cls_dimension 249 | 250 | self.logger.debug( 251 | f"Expected number of channels: {self.codimension_size=}") 252 | 253 | self.base = nn.ModuleList([]) 254 | for i in range(self.n_layers): 255 | if i == 0 and self.n_layers != 1: 256 | conv_op = integral_operator_top 257 | elif i == self.n_layers - 1 and self.n_layers != 1: 258 | conv_op = integral_operator_bottom 259 | else: 260 | conv_op = self.integral_operator 261 | 262 | self.base.append( 263 | operator_block( 264 | n_modes=self.n_modes[i], 265 | max_n_modes=self.max_n_modes[i], 266 | n_head=self.n_heads[i], 267 | token_codim=hidden_token_codimension, 268 | output_scaling_factor=[self.scalings[i]], 269 | SpectralConvolution=conv_op, 270 | codim_size=self.codimension_size, 271 | per_channel_attention=per_channel_attention, 272 | num_dims=self.num_dims, 273 | logger=self.logger.getChild(f"base[{i}]"), 274 | **self.layer_kwargs, 275 | ) 276 | ) 277 | 278 | if self.projection: 279 | self.projection = PermEqProjection( 280 | in_channels=hidden_token_codimension, 281 | out_channels=output_token_codimension, 282 | hidden_channels=lifting_token_codimension, 283 | n_dim=self.n_dim, 284 | non_linearity=self.non_linearity, 285 | permutation_invariant=True, # Permutation 286 | ) 287 | 288 | if enable_cls_token: 289 | self.cls_token = VariableEncoding2d( 290 | 1, 291 | hidden_token_codimension, 292 | (variable_encoding_args.modes_x, 293 | variable_encoding_args.modes_y), 294 | basis=variable_encoding_args.basis) 295 | 296 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer): 297 | for k in scalings_per_layer: 298 | initial_scale = np.multiply(initial_scale, k) 299 | initial_scale = initial_scale.tolist() 300 | if 
len(initial_scale) == 1: 301 | initial_scale = initial_scale[0] 302 | return initial_scale 303 | 304 | def get_device(self,): 305 | return self.cls_token.coefficients_r.device 306 | 307 | def forward(self, x: torch.Tensor): 308 | if self.lifting: 309 | x = self.lifting(x) 310 | 311 | if self.enable_cls_token: 312 | cls_token = self.cls_token(x).unsqueeze(0) 313 | repeat_shape = [1 for _ in x.shape] 314 | repeat_shape[0] = x.shape[0] 315 | x = torch.cat( 316 | [ 317 | cls_token.repeat(*repeat_shape), 318 | x, 319 | ], 320 | dim=1, 321 | ) 322 | 323 | if self.domain_padding is not None: 324 | x = self.domain_padding.pad(x) 325 | 326 | output_shape_en = [round(i * j) for (i, 327 | j) in zip(x.shape[-self.n_dim:], 328 | self.end_to_end_scaling)] 329 | 330 | cur_output_shape = None 331 | for layer_idx in range(self.n_layers): 332 | if layer_idx == self.n_layers - 1: 333 | cur_output_shape = output_shape_en 334 | x = self.base[layer_idx](x, output_shape=cur_output_shape) 335 | # self.logger.debug(f"{x.shape} (block[{layer_idx}])") 336 | 337 | if self.domain_padding is not None: 338 | x = self.domain_padding.unpad(x) 339 | 340 | if self.projection: 341 | x = self.projection(x) 342 | # self.logger.debug(f"{x.shape} (projection)") 343 | 344 | return x 345 | 346 | 347 | class CoDANOTemporal: 348 | def __call__(self, x): 349 | pass 350 | -------------------------------------------------------------------------------- /models/codano_gino.py: -------------------------------------------------------------------------------- 1 | from .codano import CodANO 2 | from layers.gno_layer import GnoPremEq 3 | from layers.codano_block_2D import CodanoBlocks2d 4 | from layers.fino_2D import SpectralConvKernel2d 5 | from functools import partial 6 | from einops import rearrange 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from layers.regrider import Regird 10 | from neuralop.layers.padding import DomainPadding 11 | import numpy as np 12 | import torch 13 | from layers.variable_encoding import VariableEncoding2d 14 | 15 | # TODO replace with nerualop.MLP module 16 | 17 | 18 | class Projection(nn.Module): 19 | def __init__( 20 | self, 21 | in_channels, 22 | out_channels, 23 | hidden_channels=None, 24 | n_dim=2, 25 | non_linearity=F.gelu, 26 | permutation_invariant=False, 27 | ): 28 | """Permutation invariant projection layer. 29 | 30 | Performs linear projections on each channel separately. 
31 | """ 32 | super().__init__() 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.hidden_channels = (in_channels 36 | if hidden_channels is None else 37 | hidden_channels) 38 | self.non_linearity = non_linearity 39 | Conv = getattr(nn, f'Conv{n_dim}d') 40 | 41 | self.permutation_invariant = permutation_invariant 42 | 43 | self.fc1 = Conv(in_channels, hidden_channels, 1) 44 | self.norm = nn.InstanceNorm2d(hidden_channels, affine=True) 45 | self.fc2 = Conv(hidden_channels, out_channels, 1) 46 | 47 | def forward(self, x): 48 | batch = x.shape[0] 49 | if self.permutation_invariant: 50 | assert x.shape[1] % self.in_channels == 0, \ 51 | "Total Number of Channels is not divisible by number of tokens" 52 | x = rearrange(x, 'b (g c) h w -> (b g) c h w', c=self.in_channels) 53 | 54 | x = self.fc1(x) 55 | x = self.norm(x) 56 | x = self.non_linearity(x) 57 | x = self.fc2(x) 58 | if self.permutation_invariant: 59 | x = rearrange(x, '(b g) c h w -> b (g c) h w', b=batch) 60 | return x 61 | 62 | 63 | class CondnoGino(nn.Module): 64 | def __init__(self, 65 | in_token_codim, 66 | input_grid, 67 | output_grid=None, 68 | grid_size=None, 69 | radius=None, 70 | n_neigbor=10, 71 | fixed_neighbour=False, 72 | out_token_codim=None, 73 | hidden_token_codim=None, 74 | lifting_token_codim=None, 75 | kqv_non_linear=False, 76 | n_layers=4, 77 | n_modes=None, 78 | scalings=None, 79 | n_heads=1, 80 | layer_kwargs={'incremental_n_modes': None, 81 | 'use_mlp': False, 82 | 'mlp_dropout': 0, 83 | 'mlp_expansion': 1.0, 84 | 'non_linearity': torch.sin, 85 | 'norm': None, 'preactivation': False, 86 | 'fno_skip': 'linear', 87 | 'horizontal_skip': 'linear', 88 | 'mlp_skip': 'linear', 89 | 'separable': False, 90 | 'factorization': None, 91 | 'rank': 1.0, 92 | 'fft_norm': 'forward', 93 | 'normalizer': 'instance_norm', 94 | 'joint_factorization': False, 95 | 'fixed_rank_modes': False, 96 | 'implementation': 'factorized', 97 | 'decomposition_kwargs': dict(), 98 | 'normalizer': False}, 99 | operator_block=CodanoBlocks2d, 100 | per_channel_attention=False, 101 | integral_operator=SpectralConvKernel2d, 102 | integral_operator_top=None, 103 | integral_operator_bottom=None, 104 | re_grid_input=False, 105 | re_grid_output=False, 106 | projection=True, 107 | gno_mlp_layers=None, 108 | lifting=True, 109 | domain_padding=None, 110 | domain_padding_mode='one-sided', 111 | var_encoding=False, 112 | var_num=None, # denotes the number of varibales 113 | var_enco_basis='fft', 114 | var_enco_channels=1, 115 | var_enco_mode_x=20, 116 | var_enco_mode_y=40, 117 | enable_cls_token=False, 118 | static_channels_num=0, 119 | static_features=None, 120 | ): 121 | super().__init__() 122 | self.n_layers = n_layers 123 | assert len( 124 | n_modes) == n_layers, "number of modes for all layers are not given" 125 | assert len( 126 | n_heads) == n_layers, "number of Attention head for all layers are not given" 127 | if output_grid is None: 128 | output_grid = input_grid.clone() 129 | if integral_operator_bottom is None: 130 | integral_operator_bottom = integral_operator 131 | if integral_operator_top is None: 132 | integral_operator_top = integral_operator 133 | self.n_dim = len(n_modes[0]) 134 | self.in_token_codim = in_token_codim 135 | self.var_num = var_num 136 | if hidden_token_codim is None: 137 | hidden_token_codim = in_token_codim 138 | if lifting_token_codim is None: 139 | lifting_token_codim = in_token_codim 140 | if out_token_codim is None: 141 | out_token_codim = in_token_codim 142 | self.re_grid_input = 
re_grid_input 143 | self.re_grid_output = re_grid_output 144 | 145 | if self.re_grid_input: 146 | self.input_regrider = Regird("equiangular", "legendre-gauss") 147 | if self.re_grid_output: 148 | self.output_regrider = Regird("legendre-gauss", "equiangular") 149 | 150 | self.input_grid = input_grid 151 | self.output_grid = output_grid 152 | self.grid_size = grid_size 153 | 154 | self.hidden_token_codim = hidden_token_codim 155 | self.n_modes = n_modes 156 | self.scalings = scalings 157 | self.var_enco_channels = var_enco_channels 158 | self.n_heads = n_heads 159 | self.integral_operator = integral_operator 160 | self.layer_kwargs = layer_kwargs 161 | self.operator_block = operator_block 162 | self.lifting = lifting 163 | self.projection = projection 164 | 165 | self.radius = radius 166 | self.n_neigbor = n_neigbor 167 | self.fixed_neighbour = fixed_neighbour 168 | self.gno_mlp_layers = gno_mlp_layers 169 | self.per_channel_attention = per_channel_attention 170 | 171 | self.register_buffer("static_features", static_features) 172 | self.static_channels_num = static_channels_num 173 | # calculating scaling 174 | if self.scalings is not None: 175 | self.end_to_end_scaling = self.get_output_scaling_factor( 176 | np.ones_like(self.scalings[0]), self.scalings) 177 | print("End to End Scaling", self.end_to_end_scaling) 178 | else: 179 | self.end_to_end_scaling = 1 180 | if isinstance(self.end_to_end_scaling, (float, int)): 181 | self.end_to_end_scaling = [self.end_to_end_scaling] * self.n_dim 182 | 183 | # Setting up domain padding for encoder and reconstructor 184 | 185 | if domain_padding is not None and domain_padding > 0: 186 | self.domain_padding = DomainPadding( 187 | domain_padding=domain_padding, 188 | padding_mode=domain_padding_mode, 189 | output_scaling_factor=self.end_to_end_scaling) 190 | else: 191 | self.domain_padding = None 192 | self.domain_padding_mode = domain_padding_mode 193 | 194 | # Code for varibale encoding 195 | 196 | # initializing Components 197 | if self.lifting: 198 | print('Using lifing Layer') 199 | 200 | # a varibale + it's varibale encoding + the static channen together 201 | # constitute a token 202 | 203 | self.lifting = GnoPremEq( 204 | var_num=var_num, 205 | in_dim=self.in_token_codim, 206 | out_dim=hidden_token_codim, 207 | input_grid=self.input_grid, 208 | output_grid=self.output_grid, 209 | projection_hidden_dim=lifting_token_codim, 210 | mlp_layers=self.gno_mlp_layers, 211 | radius=self.radius, 212 | n_neigbor=n_neigbor, 213 | fixed_neighbour=fixed_neighbour, 214 | var_encoding=var_encoding, 215 | var_encoding_channels=var_enco_channels) 216 | 217 | elif var_encoding: 218 | hidden_token_codim = self.in_token_codim + \ 219 | var_enco_channels + self.static_channels_num 220 | 221 | if enable_cls_token: 222 | count = 1 223 | else: 224 | count = 0 225 | self.codim_size = hidden_token_codim * \ 226 | (var_num + count) # +1 is for the CLS token 227 | 228 | print("expected number of channels", self.codim_size) 229 | 230 | self.base = nn.ModuleList([]) 231 | for i in range(self.n_layers): 232 | if i == 0 and self.n_layers != 1: 233 | conv_op = integral_operator_top 234 | elif i == self.n_layers - 1 and self.n_layers != 1: 235 | conv_op = integral_operator_bottom 236 | else: 237 | conv_op = self.integral_operator 238 | 239 | self.base.append(self.operator_block( 240 | n_modes=self.n_modes[i], 241 | n_head=self.n_heads[i], 242 | token_codimension=hidden_token_codim, 243 | output_scaling_factor=[self.scalings[i]], 244 | SpectralConvolution=conv_op, 245 | 
codimension_size=self.codim_size, 246 | per_channel_attention=self.per_channel_attention, 247 | kqv_non_linear=kqv_non_linear, 248 | **self.layer_kwargs)) 249 | if self.projection: 250 | print("Using Projection Layer") 251 | self.projection = GnoPremEq( 252 | var_num=var_num, 253 | in_dim=self.hidden_token_codim, 254 | out_dim=self.hidden_token_codim, 255 | input_grid=self.output_grid, 256 | output_grid=self.input_grid, 257 | mlp_layers=self.gno_mlp_layers, 258 | radius=self.radius, 259 | n_neigbor=n_neigbor, 260 | fixed_neighbour=fixed_neighbour, 261 | var_encoding=False, 262 | projection_hidden_dim=lifting_token_codim, 263 | var_encoding_channels=0, 264 | end_projection=True, 265 | end_projection_outdim=out_token_codim) 266 | 267 | # Code for varibale encoding 268 | 269 | self.enable_cls_token = enable_cls_token 270 | if enable_cls_token: 271 | print("intializing CLS token") 272 | self.cls_token = VariableEncoding2d( 273 | 1, hidden_token_codim, (var_enco_mode_x, var_enco_mode_y), basis=var_enco_basis) 274 | 275 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer): 276 | for k in scalings_per_layer: 277 | initial_scale = np.multiply(initial_scale, k) 278 | initial_scale = initial_scale.tolist() 279 | if len(initial_scale) == 1: 280 | initial_scale = initial_scale[0] 281 | return initial_scale 282 | 283 | def get_device(self,): 284 | return self.cls_token.coefficients_r.device 285 | 286 | def forward(self, inp): 287 | ''' 288 | inp = (batch_size, n_points, in_dims/Channels) 289 | currenly only batch_size = 1 290 | ''' 291 | inp = inp[0, :, :] 292 | inp = inp[None, ...] 293 | if self.re_grid_input: 294 | inp = self.input_regrider(inp) 295 | if self.lifting: 296 | x = self.lifting(inp) 297 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0]) 298 | else: 299 | x = inp 300 | 301 | if self.enable_cls_token: 302 | cls_token = self.cls_token(x) 303 | x = torch.cat([cls_token[None, :, :, :].repeat( 304 | x.shape[0], 1, 1, 1), x], dim=1) 305 | 306 | if self.domain_padding is not None: 307 | x = self.domain_padding.pad(x) 308 | 309 | output_shape_en = [int(round(i * j)) for (i, 310 | j) in zip(x.shape[-self.n_dim:], 311 | self.end_to_end_scaling)] 312 | 313 | cur_output_shape = None 314 | for layer_idx in range(self.n_layers): 315 | if layer_idx == self.n_layers - 1: 316 | cur_output_shape = output_shape_en 317 | x = self.base[layer_idx](x, output_shape=cur_output_shape) 318 | 319 | if self.re_grid_output: 320 | x = self.output_regrider(x) 321 | if self.projection: 322 | x = rearrange(x, 'b c h w -> b (h w) c') 323 | x = self.projection(x) 324 | 325 | return x 326 | -------------------------------------------------------------------------------- /models/deeponet.py: -------------------------------------------------------------------------------- 1 | from layers.gnn_layer import GnnLayer 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch 5 | from layers.gnn_layer import GnnLayer 6 | from neuralop.layers.embeddings import PositionalEmbedding 7 | 8 | 9 | class DeepONet(nn.Module): 10 | def __init__(self, 11 | in_dim, 12 | out_dim, 13 | input_grid, 14 | output_grid=None, 15 | branch_layers=[128], 16 | trunk_layers=[128], 17 | initial_mesh=None, 18 | positional_encoding_dim=8, 19 | n_neigbor=10, 20 | gno_mlp_layers=None, 21 | ): 22 | super().__init__() 23 | if output_grid is None: 24 | output_grid = input_grid.clone() 25 | self.n_dim = input_grid.shape[-1] 26 | self.n_neigbor = n_neigbor 27 | self.gno_mlp_layers = gno_mlp_layers 28 | self.in_dim = in_dim 29 
| print("in_dim", in_dim) 30 | if out_dim is None: 31 | out_dim = in_dim 32 | self.out_dim = out_dim 33 | self.positional_encoding_dim = positional_encoding_dim 34 | self.input_grid = input_grid 35 | self.output_grid = output_grid 36 | self.initial_mesh = initial_mesh 37 | self.branch_layers = branch_layers 38 | self.trunk_layers = trunk_layers 39 | self.gnn = None 40 | self.branch = self.get_branch() 41 | self.trunk = self.get_trunk() 42 | self.PE = PositionalEmbedding(positional_encoding_dim) 43 | self.bias = nn.Parameter(torch.zeros(out_dim)) 44 | 45 | # Code for varibale encoding 46 | 47 | def get_branch(self, ): 48 | dim1 = self.in_dim + self.n_dim * self.positional_encoding_dim 49 | # self.branch_layers = [dim1] + self.branch_layers 50 | self.gnn = GnnLayer( 51 | dim1, 52 | self.branch_layers[0], 53 | self.initial_mesh, 54 | self.initial_mesh, 55 | self.gno_mlp_layers, 56 | self.branch_layers[0], 57 | self.n_neigbor) 58 | self.layer_norm = nn.LayerNorm(self.branch_layers[0]) 59 | layers = [] 60 | 61 | self.branch_layers[0] = self.branch_layers[0] * \ 62 | self.input_grid.shape[0] 63 | for i in range(len(self.branch_layers) - 1): 64 | layers.append( 65 | nn.Linear(self.branch_layers[i], self.branch_layers[i + 1])) 66 | torch.nn.init.xavier_normal_(layers[-1].weight) 67 | if i != len(self.branch_layers) - 2: 68 | # layers.append(nn.LayerNorm(self.branch_layers[i+1])) 69 | layers.append(nn.ReLU()) 70 | return nn.Sequential(*layers) 71 | 72 | def get_trunk(self, ): 73 | dim1 = self.n_dim + self.positional_encoding_dim * self.n_dim 74 | self.trunk_layers = [dim1] + self.trunk_layers 75 | self.trunk_layers[-1] = self.trunk_layers[-1] * self.out_dim 76 | layers = [] 77 | for i in range(len(self.trunk_layers) - 1): 78 | layers.append( 79 | nn.Linear(self.trunk_layers[i], self.trunk_layers[i + 1])) 80 | torch.nn.init.xavier_normal_(layers[-1].weight) 81 | # if i != len(self.trunk_layers) - 2: 82 | # layers.append(nn.LayerNorm(self.trunk_layers[i+1])) 83 | layers.append(nn.ReLU()) 84 | return nn.Sequential(*layers) 85 | 86 | def get_pe(self, grid): 87 | pe = self.PE(grid.reshape(-1)) 88 | pe = pe.reshape(grid.shape[0], -1) 89 | return pe 90 | 91 | def forward( 92 | self, 93 | inp, 94 | out_grid_displacement=None, 95 | in_grid_displacement=None): 96 | ''' 97 | inp = (batch_size, n_points, in_dims/Channels) 98 | currenly only batch_size = 1 99 | ''' 100 | 101 | if out_grid_displacement is not None: 102 | with torch.no_grad(): 103 | in_grid = self.initial_mesh + in_grid_displacement 104 | out_grid = self.initial_mesh + out_grid_displacement 105 | self.gnn.update_grid(in_grid.clone(), in_grid.clone()) 106 | 107 | in_pe = self.get_pe(in_grid) 108 | in_data = torch.cat([inp, in_pe.unsqueeze(0)], dim=-1) 109 | 110 | bout = self.gnn(in_data[0])[None, ...] 
# (btach , dim) 111 | 112 | bout = self.layer_norm(bout) 113 | 114 | bout = self.branch(bout.reshape(inp.shape[0], -1)) 115 | 116 | bout = bout / np.sqrt(self.branch_layers[-1]) 117 | 118 | pe = self.get_pe(out_grid) # self.PE(out_grid.reshape(-1)) 119 | # pe = pe.reshape(out_grid.shape[0], -1) 120 | grid_pe = torch.cat([out_grid, pe], axis=1) 121 | 122 | tout = self.trunk(grid_pe) # (ngrid, dim * out_dim) 123 | # (ngrid, out_dim, dim) 124 | tout = tout.reshape(out_grid.shape[0], self.out_dim, -1) 125 | 126 | out = torch.einsum('bd,ncd->bnc', bout, tout) 127 | 128 | return out + self.bias 129 | -------------------------------------------------------------------------------- /models/fno_gino.py: -------------------------------------------------------------------------------- 1 | from layers.gno_layer import GNO 2 | from layers.fino_2D import SpectralConvKernel2d 3 | from einops import rearrange 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from layers.regrider import Regird 7 | from neuralop.layers.padding import DomainPadding 8 | from neuralop.layers.fno_block import FNOBlocks 9 | import numpy as np 10 | import torch 11 | from layers.variable_encoding import VariableEncoding2d 12 | 13 | 14 | class FnoGno(nn.Module): 15 | def __init__(self, 16 | in_dim, 17 | out_dim, 18 | input_grid, 19 | output_grid=None, 20 | grid_size=None, 21 | radius=None, 22 | fixed_neighbour=False, 23 | n_neigbor=10, 24 | hidden_dim=None, 25 | lifting_dim=None, 26 | n_layers=4, 27 | max_n_modes=None, 28 | n_modes=None, 29 | scalings=None, 30 | initial_mesh=None, 31 | non_linearity=F.gelu, 32 | layer_kwargs={'use_mlp': False, 33 | 'mlp_dropout': 0, 34 | 'mlp_expansion': 1.0, 35 | 'non_linearity': F.gelu, 36 | 'norm': None, 'preactivation': False, 37 | 'fno_skip': 'linear', 38 | 'horizontal_skip': 'linear', 39 | 'mlp_skip': 'linear', 40 | 'separable': False, 41 | 'factorization': None, 42 | 'rank': 1.0, 43 | 'fft_norm': 'forward', 44 | 'normalizer': 'instance_norm', 45 | 'joint_factorization': False, 46 | 'fixed_rank_modes': False, 47 | 'implementation': 'factorized', 48 | 'decomposition_kwargs': dict(), 49 | 'normalizer': False}, 50 | operator_block=FNOBlocks, 51 | integral_operator=SpectralConvKernel2d, 52 | integral_operator_top=None, 53 | integral_operator_bottom=None, 54 | re_grid_input=False, 55 | re_grid_output=False, 56 | projection=True, 57 | gno_mlp_layers=None, 58 | lifting=True, 59 | domain_padding=None, 60 | domain_padding_mode='one-sided', 61 | ): 62 | super().__init__() 63 | self.n_layers = n_layers 64 | assert len( 65 | n_modes) == n_layers, "number of modes for all layers are not given" 66 | if output_grid is None: 67 | output_grid = input_grid.clone() 68 | if integral_operator_bottom is None: 69 | integral_operator_bottom = integral_operator 70 | if integral_operator_top is None: 71 | integral_operator_top = integral_operator 72 | self.n_dim = len(max_n_modes[0]) 73 | self.in_dim = in_dim 74 | if hidden_dim is None: 75 | hidden_dim = in_dim 76 | if lifting_dim is None: 77 | lifting_dim = in_dim 78 | if out_dim is None: 79 | out_dim = in_dim 80 | self.re_grid_input = re_grid_input 81 | self.re_grid_output = re_grid_output 82 | 83 | if self.re_grid_input: 84 | self.input_regrider = Regird("equiangular", "legendre-gauss") 85 | if self.re_grid_output: 86 | self.output_regrider = Regird("legendre-gauss", "equiangular") 87 | 88 | self.input_grid = input_grid 89 | self.output_grid = output_grid 90 | self.grid_size = grid_size 91 | 92 | self.hidden_dim = hidden_dim 93 | self.n_modes = 
n_modes 94 | self.max_n_modes = max_n_modes 95 | self.scalings = scalings 96 | self.integral_operator = integral_operator 97 | self.layer_kwargs = layer_kwargs 98 | self.operator_block = operator_block 99 | self.lifting = lifting 100 | self.projection = projection 101 | self.radius = radius 102 | self.fixed_neighbour = fixed_neighbour 103 | self.n_neigbor = n_neigbor 104 | self.gno_mlp_layers = gno_mlp_layers 105 | 106 | # calculating scaling 107 | if self.scalings is not None: 108 | self.end_to_end_scaling = self.get_output_scaling_factor( 109 | np.ones_like(self.scalings[0]), self.scalings) 110 | print("End to End Scaling", self.end_to_end_scaling) 111 | else: 112 | self.end_to_end_scaling = 1 113 | if isinstance(self.end_to_end_scaling, (float, int)): 114 | self.end_to_end_scaling = [self.end_to_end_scaling] * self.n_dim 115 | 116 | # Setting up domain padding for encoder and reconstructor 117 | 118 | if domain_padding is not None and domain_padding > 0: 119 | self.domain_padding = DomainPadding( 120 | domain_padding=domain_padding, 121 | padding_mode=domain_padding_mode, 122 | output_scaling_factor=self.end_to_end_scaling) 123 | else: 124 | self.domain_padding = None 125 | self.domain_padding_mode = domain_padding_mode 126 | self.initial_mesh = initial_mesh 127 | # Code for varibale encoding 128 | 129 | # initializing Components 130 | if self.lifting: 131 | print('Using lifing Layer') 132 | 133 | # a varibale + it's varibale encoding + the static channen together 134 | # constitute a token 135 | 136 | self.lifting = GNO( 137 | in_dim=self.in_dim, 138 | out_dim=hidden_dim, 139 | input_grid=self.input_grid, 140 | output_grid=self.output_grid, 141 | projection_hidden_dim=lifting_dim, 142 | mlp_layers=self.gno_mlp_layers, 143 | radius=self.radius, 144 | fixed_neighbour=self.fixed_neighbour, 145 | n_neigbor=self.n_neigbor) 146 | 147 | self.base = nn.ModuleList([]) 148 | for i in range(self.n_layers): 149 | if i == 0 and self.n_layers != 1: 150 | conv_op = integral_operator_top 151 | elif i == self.n_layers - 1 and self.n_layers != 1: 152 | conv_op = integral_operator_bottom 153 | else: 154 | conv_op = self.integral_operator 155 | 156 | self.base.append(self.operator_block( 157 | hidden_dim, 158 | hidden_dim, 159 | max_n_modes=self.max_n_modes[i], 160 | n_modes=self.n_modes[i], 161 | output_scaling_factor=[self.scalings[i]], 162 | SpectralConv=conv_op, 163 | **self.layer_kwargs)) 164 | if self.projection: 165 | # input and output grid is swapped 166 | 167 | print("Using Projection Layer") 168 | self.projection = GNO( 169 | in_dim=self.hidden_dim, 170 | out_dim=out_dim, 171 | input_grid=self.output_grid, 172 | projection_hidden_dim=lifting_dim, 173 | output_grid=self.input_grid, 174 | mlp_layers=self.gno_mlp_layers, 175 | radius=self.radius, 176 | fixed_neighbour=self.fixed_neighbour, 177 | n_neigbor=self.n_neigbor) 178 | 179 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer): 180 | for k in scalings_per_layer: 181 | initial_scale = np.multiply(initial_scale, k) 182 | initial_scale = initial_scale.tolist() 183 | if len(initial_scale) == 1: 184 | initial_scale = initial_scale[0] 185 | return initial_scale 186 | 187 | def get_device(self,): 188 | return self.cls_token.coefficients_r.device 189 | 190 | def forward( 191 | self, 192 | inp, 193 | out_grid_displacement=None, 194 | in_grid_displacement=None): 195 | ''' 196 | inp = (batch_size, n_points, in_dims/Channels) 197 | currenly only batch_size = 1 198 | ''' 199 | if out_grid_displacement is not None: 200 | with 
torch.no_grad(): 201 | self.lifting.update_grid( 202 | self.initial_mesh + in_grid_displacement, None) 203 | self.projection.update_grid( 204 | None, self.initial_mesh + out_grid_displacement) 205 | 206 | if self.re_grid_input: 207 | inp = self.input_regrider(inp) 208 | if self.lifting: 209 | # print("In Lifting") 210 | x = self.lifting(inp) 211 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0]) 212 | else: 213 | x = inp 214 | 215 | if self.domain_padding is not None: 216 | x = self.domain_padding.pad(x) 217 | 218 | output_shape_en = [int(round(i * j)) for (i, 219 | j) in zip(x.shape[-self.n_dim:], 220 | self.end_to_end_scaling)] 221 | 222 | cur_output_shape = None 223 | for layer_idx in range(self.n_layers): 224 | if layer_idx == self.n_layers - 1: 225 | cur_output_shape = output_shape_en 226 | x = self.base[layer_idx](x, output_shape=cur_output_shape) 227 | 228 | if self.re_grid_output: 229 | x = self.output_regrider(x) 230 | if self.projection: 231 | # print("projection") 232 | x = rearrange(x, 'b c h w -> b (h w) c') 233 | x = self.projection(x) 234 | return x 235 | -------------------------------------------------------------------------------- /models/gnn.py: -------------------------------------------------------------------------------- 1 | from layers.gnn_layer import GnnLayer 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from neuralop.layers.mlp import MLPLinear 5 | import numpy as np 6 | import torch 7 | 8 | class GNN(nn.Module): 9 | def __init__(self, 10 | in_dim, 11 | out_dim, 12 | input_grid, 13 | output_grid=None, 14 | n_neigbor=None, 15 | hidden_dim=None, 16 | lifting_dim=None, 17 | n_layers=4, 18 | initial_mesh=None, 19 | non_linearity=F.gelu, 20 | projection=True, 21 | gno_mlp_layers=None, 22 | lifting=True, 23 | ): 24 | super().__init__() 25 | self.n_layers = n_layers 26 | 27 | if output_grid is None: 28 | output_grid = input_grid.clone() 29 | 30 | self.n_dim = input_grid.shape[-1] 31 | 32 | self.in_dim = in_dim 33 | 34 | if hidden_dim is None: 35 | hidden_dim = in_dim 36 | if lifting_dim is None: 37 | lifting_dim = in_dim 38 | if out_dim is None: 39 | out_dim = in_dim 40 | 41 | self.input_grid = input_grid 42 | self.output_grid = output_grid 43 | 44 | self.hidden_dim = hidden_dim 45 | 46 | self.lifting = lifting 47 | self.projection = projection 48 | self.n_neigbor = n_neigbor 49 | self.gno_mlp_layers = gno_mlp_layers 50 | 51 | self.initial_mesh = initial_mesh 52 | # Code for varibale encoding 53 | 54 | # initializing Components 55 | if self.lifting: 56 | print('Using lifing Layer') 57 | self.lifting = MLPLinear( 58 | layers=[self.in_dim, self.hidden_dim], 59 | ) 60 | 61 | self.base = nn.ModuleList([]) 62 | for i in range(self.n_layers): 63 | self.base.append(GnnLayer( 64 | hidden_dim, 65 | hidden_dim, 66 | self.initial_mesh, 67 | self.initial_mesh, 68 | gno_mlp_layers, 69 | lifting_dim, 70 | n_neigbor)) 71 | 72 | if self.projection: 73 | print("Using Projection Layer") 74 | self.projection = MLPLinear( 75 | layers=[self.hidden_dim, out_dim] 76 | ) 77 | 78 | def forward( 79 | self, 80 | inp, 81 | out_grid_displacement=None, 82 | in_grid_displacement=None): 83 | ''' 84 | inp = (batch_size, n_points, in_dims/Channels) 85 | currenly only batch_size = 1 86 | ''' 87 | 88 | if out_grid_displacement is not None: 89 | with torch.no_grad(): 90 | for i in range(self.n_layers): 91 | if i == self.n_layers - 1: 92 | # Doing different mesh for last layer 93 | in_grid = self.initial_mesh + in_grid_displacement 94 | 95 | out_grid = self.initial_mesh + 
out_grid_displacement 96 | else: 97 | in_grid = self.initial_mesh + in_grid_displacement 98 | out_grid = self.initial_mesh + in_grid_displacement 99 | self.base[i].update_grid( 100 | in_grid, 101 | out_grid) 102 | 103 | if self.lifting: 104 | x = self.lifting(inp) 105 | else: 106 | x = inp 107 | x = x[0, ...] 108 | for layer_idx in range(self.n_layers): 109 | x = self.base[layer_idx](x) 110 | 111 | if self.projection: 112 | x = self.projection(x) 113 | return x[None, ...] 114 | -------------------------------------------------------------------------------- /models/model_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def count_parameters(model): 5 | with torch.no_grad(): 6 | total_count = 0 7 | for p in model.parameters(): 8 | if not p.requires_grad: 9 | continue 10 | pcount = torch.tensor(p.numel()) 11 | total_count += int(pcount.item()) 12 | 13 | return total_count / 1e6 14 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | from layers.gno_layer import GNO 2 | from einops import rearrange 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from layers.regrider import Regird 6 | import numpy as np 7 | import torch 8 | from layers.unet_sublayer import UNet2d 9 | 10 | 11 | class UnetGno(nn.Module): 12 | def __init__(self, 13 | in_dim, 14 | out_dim, 15 | input_grid, 16 | output_grid=None, 17 | grid_size=None, 18 | radius=None, 19 | fixed_neighbour=False, 20 | n_neigbor=10, 21 | hidden_dim=None, 22 | lifting_dim=None, 23 | n_layers=4, 24 | pad_to_size=None, 25 | initial_mesh=None, 26 | non_linearity=F.gelu, 27 | re_grid_input=False, 28 | re_grid_output=False, 29 | projection=True, 30 | gno_mlp_layers=None, 31 | lifting=True, 32 | domain_padding=None, 33 | domain_padding_mode='one-sided', 34 | ): 35 | super().__init__() 36 | self.n_layers = n_layers 37 | if output_grid is None: 38 | output_grid = None # input_grid.clone() 39 | 40 | self.in_dim = in_dim 41 | if hidden_dim is None: 42 | hidden_dim = in_dim 43 | if lifting_dim is None: 44 | lifting_dim = in_dim 45 | if out_dim is None: 46 | out_dim = in_dim 47 | self.re_grid_input = re_grid_input 48 | self.re_grid_output = re_grid_output 49 | 50 | if self.re_grid_input: 51 | self.input_regrider = Regird("equiangular", "legendre-gauss") 52 | if self.re_grid_output: 53 | self.output_regrider = Regird("legendre-gauss", "equiangular") 54 | 55 | self.input_grid = input_grid 56 | self.output_grid = output_grid 57 | self.grid_size = grid_size 58 | 59 | self.hidden_dim = hidden_dim 60 | self.lifting = lifting 61 | self.projection = projection 62 | self.radius = radius 63 | self.fixed_neighbour = fixed_neighbour 64 | self.n_neigbor = n_neigbor 65 | self.gno_mlp_layers = gno_mlp_layers 66 | self.pad_to_size = pad_to_size 67 | 68 | self.initial_mesh = initial_mesh 69 | # Code for varibale encoding 70 | 71 | # initializing Components 72 | if self.lifting: 73 | print('Using lifing Layer') 74 | 75 | # a varibale + it's varibale encoding + the static channen together 76 | # constitute a token 77 | 78 | self.lifting = GNO( 79 | in_dim=self.in_dim, 80 | out_dim=hidden_dim, 81 | input_grid=self.input_grid, 82 | output_grid=self.output_grid, 83 | projection_hidden_dim=lifting_dim, 84 | mlp_layers=self.gno_mlp_layers, 85 | radius=self.radius, 86 | fixed_neighbour=self.fixed_neighbour, 87 | n_neigbor=self.n_neigbor) 88 | 89 | self.base = 
UNet2d(in_channels=hidden_dim, 90 | out_channels=hidden_dim, init_features=hidden_dim) 91 | 92 | if self.projection: 93 | # input and output grid is swapped 94 | 95 | print("Using Projection Layer") 96 | self.projection = GNO( 97 | in_dim=self.hidden_dim, 98 | out_dim=out_dim, 99 | input_grid=self.output_grid, 100 | projection_hidden_dim=lifting_dim, 101 | output_grid=self.input_grid, 102 | mlp_layers=self.gno_mlp_layers, 103 | radius=self.radius, 104 | fixed_neighbour=self.fixed_neighbour, 105 | n_neigbor=self.n_neigbor) 106 | 107 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer): 108 | for k in scalings_per_layer: 109 | initial_scale = np.multiply(initial_scale, k) 110 | initial_scale = initial_scale.tolist() 111 | if len(initial_scale) == 1: 112 | initial_scale = initial_scale[0] 113 | return initial_scale 114 | 115 | def get_device(self,): 116 | return self.cls_token.coefficients_r.device 117 | 118 | def forward( 119 | self, 120 | inp, 121 | out_grid_displacement=None, 122 | in_grid_displacement=None): 123 | ''' 124 | inp = (batch_size, n_points, in_dims/Channels) 125 | currenly only batch_size = 1 126 | ''' 127 | if out_grid_displacement is not None: 128 | with torch.no_grad(): 129 | self.lifting.update_grid( 130 | self.initial_mesh + in_grid_displacement, None) 131 | self.projection.update_grid( 132 | None, self.initial_mesh + out_grid_displacement) 133 | 134 | if self.re_grid_input: 135 | inp = self.input_regrider(inp) 136 | if self.lifting: 137 | # print("In Lifting") 138 | x = self.lifting(inp) 139 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0]) 140 | else: 141 | x = inp 142 | 143 | pad_size_1 = self.pad_to_size[-2] - x.shape[-2] 144 | pad_size_2 = self.pad_to_size[-1] - x.shape[-1] 145 | x = nn.functional.pad( 146 | x, 147 | (pad_size_2 // 2, 148 | pad_size_2 - pad_size_2 // 2, 149 | pad_size_1 // 2, 150 | pad_size_1 - pad_size_1 // 2)) 151 | x = self.base(x) 152 | 153 | 154 | 155 | x = x[:, :, pad_size_1 // 2: -(pad_size_1 - pad_size_1 // 2), 156 | pad_size_2 // 2: -(pad_size_2 - pad_size_2 // 2)] 157 | 158 | if self.re_grid_output: 159 | x = self.output_regrider(x) 160 | if self.projection: 161 | x = rearrange(x, 'b c h w -> b (h w) c') 162 | x = self.projection(x) 163 | return x 164 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | from layers.gno_layer import GNO 2 | from functools import partial 3 | from einops import rearrange 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from layers.regrider import Regird 7 | import numpy as np 8 | import torch 9 | from layers.regular_transformer import vision_transformer 10 | 11 | 12 | class VitGno(nn.Module): 13 | def __init__(self, 14 | in_dim, 15 | out_dim, 16 | input_grid, 17 | output_grid=None, 18 | grid_size=None, 19 | radius=None, 20 | fixed_neighbour=False, 21 | n_neigbor=10, 22 | hidden_dim=None, 23 | lifting_dim=None, 24 | n_layers=4, 25 | initial_mesh=None, 26 | non_linearity=F.gelu, 27 | patch_size=(10, 5), 28 | heads=10, 29 | contraction_factor=128, 30 | re_grid_input=False, 31 | re_grid_output=False, 32 | projection=True, 33 | gno_mlp_layers=None, 34 | lifting=True, 35 | domain_padding=None, 36 | domain_padding_mode='one-sided', 37 | ): 38 | super().__init__() 39 | self.n_layers = n_layers 40 | if output_grid is None: 41 | output_grid = input_grid.clone() 42 | 43 | self.in_dim = in_dim 44 | if hidden_dim is None: 45 | hidden_dim = 
in_dim 46 | if lifting_dim is None: 47 | lifting_dim = in_dim 48 | if out_dim is None: 49 | out_dim = in_dim 50 | self.re_grid_input = re_grid_input 51 | self.re_grid_output = re_grid_output 52 | 53 | if self.re_grid_input: 54 | self.input_regrider = Regird("equiangular", "legendre-gauss") 55 | if self.re_grid_output: 56 | self.output_regrider = Regird("legendre-gauss", "equiangular") 57 | 58 | self.input_grid = input_grid 59 | self.output_grid = output_grid 60 | self.grid_size = grid_size 61 | 62 | self.hidden_dim = hidden_dim 63 | self.lifting = lifting 64 | self.projection = projection 65 | self.radius = radius 66 | self.fixed_neighbour = fixed_neighbour 67 | self.n_neigbor = n_neigbor 68 | self.gno_mlp_layers = gno_mlp_layers 69 | 70 | # transformers parameters 71 | self.contraction_factor = contraction_factor 72 | self.grid_size = grid_size 73 | self.patch_size = patch_size 74 | self.heads = heads 75 | 76 | self.initial_mesh = initial_mesh 77 | # Code for varibale encoding 78 | 79 | # initializing Components 80 | if self.lifting: 81 | print('Using lifing Layer') 82 | 83 | # a varibale + it's varibale encoding + the static channen together 84 | # constitute a token 85 | 86 | self.lifting = GNO( 87 | in_dim=self.in_dim, 88 | out_dim=hidden_dim, 89 | input_grid=self.input_grid, 90 | output_grid=self.output_grid, 91 | projection_hidden_dim=lifting_dim, 92 | mlp_layers=self.gno_mlp_layers, 93 | radius=self.radius, 94 | fixed_neighbour=self.fixed_neighbour, 95 | n_neigbor=self.n_neigbor) 96 | 97 | self.base = vision_transformer( 98 | image_size=self.grid_size, 99 | patch_size=self.patch_size, 100 | num_classes=1, 101 | dim=self.patch_size[0] * self.patch_size[1] * 102 | self.hidden_dim // self.contraction_factor, 103 | depth=self.n_layers, 104 | channels=hidden_dim, 105 | heads=self.heads, 106 | mlp_dim=self.patch_size[0] * self.patch_size[1] * 107 | hidden_dim // self.contraction_factor, 108 | dropout=0.0, 109 | emb_dropout=0.0 110 | ) 111 | self.expander = nn.Linear( 112 | self.patch_size[0] * 113 | self.patch_size[1] * 114 | hidden_dim // 115 | self.contraction_factor, 116 | self.grid_size[0] * 117 | self.grid_size[1] * 118 | hidden_dim) 119 | if self.projection: 120 | # input and output grid is swapped 121 | 122 | print("Using Projection Layer") 123 | self.projection = GNO( 124 | in_dim=self.hidden_dim, 125 | out_dim=out_dim, 126 | input_grid=self.output_grid, 127 | projection_hidden_dim=lifting_dim, 128 | output_grid=self.input_grid, 129 | mlp_layers=self.gno_mlp_layers, 130 | radius=self.radius, 131 | fixed_neighbour=self.fixed_neighbour, 132 | n_neigbor=self.n_neigbor) 133 | 134 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer): 135 | for k in scalings_per_layer: 136 | initial_scale = np.multiply(initial_scale, k) 137 | initial_scale = initial_scale.tolist() 138 | if len(initial_scale) == 1: 139 | initial_scale = initial_scale[0] 140 | return initial_scale 141 | 142 | def get_device(self,): 143 | return self.cls_token.coefficients_r.device 144 | 145 | def forward( 146 | self, 147 | inp, 148 | out_grid_displacement=None, 149 | in_grid_displacement=None): 150 | ''' 151 | inp = (batch_size, n_points, in_dims/Channels) 152 | currenly only batch_size = 1 153 | ''' 154 | if out_grid_displacement is not None: 155 | with torch.no_grad(): 156 | self.lifting.update_grid( 157 | self.initial_mesh + in_grid_displacement, None) 158 | self.projection.update_grid( 159 | None, self.initial_mesh + out_grid_displacement) 160 | 161 | if self.re_grid_input: 162 | inp = 
self.input_regrider(inp) 163 | if self.lifting: 164 | # print("In Lifting") 165 | x = self.lifting(inp) 166 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0]) 167 | else: 168 | x = inp 169 | 170 | x = self.base(x) 171 | x = self.expander(x) 172 | x = x.reshape(-1, self.hidden_dim, 173 | self.grid_size[0], self.grid_size[1]) 174 | 175 | if self.re_grid_output: 176 | x = self.output_regrider(x) 177 | if self.projection: 178 | x = rearrange(x, 'b c h w -> b (h w) c') 179 | x = self.projection(x) 180 | return x 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.9.5 2 | einops==0.8.0 3 | filelock==3.15.4 4 | fsspec==2024.6.1 5 | GitPython==3.1.43 6 | h5py==3.11.0 7 | joblib==1.4.2 8 | Jinja2==3.1.4 9 | matplotlib==3.9.1 10 | networkx 11 | -e git+https://github.com/ashiq24/neuraloperator.git@codano_rep#egg=neuraloperator 12 | numpy==2.0.0 13 | opt-einsum==3.3.0 14 | protobuf==5.27.2 15 | PyYAML==6.0.1 16 | requests==2.32.3 17 | scikit-learn==1.5.1 18 | scipy==1.14.0 19 | sympy==1.13.1 20 | tensorly==0.8.1 21 | torch 22 | torchaudio 23 | torchvision 24 | tensorly-torch==0.5.0 25 | threadpoolctl==3.5.0 26 | tqdm==4.66.4 27 | typing_extensions 28 | urllib3==2.2.2 29 | vit-pytorch==1.7.4 30 | wandb==0.17.5 31 | yarl==1.9.4 32 | zarr==2.18.2 33 | torch_geometric==2.5.3 34 | torch-cluster==1.6.3 35 | torch-sparse==0.6.18 36 | torch_scatter==2.1.2 37 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/test/__init__.py -------------------------------------------------------------------------------- /test/evaluations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data_utils.data_utils import * 3 | import torch.nn as nn 4 | from timeit import default_timer 5 | from models.get_models import * 6 | from tqdm import tqdm 7 | import wandb 8 | from utils import * 9 | from train.trainer import * 10 | 11 | 12 | def missing_variable_testing( 13 | model, 14 | test_loader, 15 | augmenter=None, 16 | stage=StageEnum.PREDICTIVE, 17 | params=None, 18 | variable_encoder=None, 19 | token_expander=None, 20 | initial_mesh=None, 21 | wandb_log=False): 22 | print('Evaluating for Stage: ', stage) 23 | model.eval() 24 | with torch.no_grad(): 25 | ntest = 0 26 | test_l2 = 0 27 | test_l1 = 0 28 | loss_p = nn.MSELoss() 29 | loss_l1 = nn.L1Loss() 30 | t1 = default_timer() 31 | predictions = [] 32 | for data in test_loader: 33 | x, y = data['x'].cuda(), data['y'].cuda() 34 | static_features = data['static_features'] 35 | 36 | if augmenter is not None: 37 | x, _ = batched_masker(x, augmenter) 38 | 39 | inp = prepare_input( 40 | x, 41 | static_features, 42 | params, 43 | variable_encoder, 44 | token_expander, 45 | initial_mesh, 46 | data) 47 | out_grid_displacement, in_grid_displacement = get_grid_displacement( 48 | params, stage, data) 49 | 50 | batch_size = x.shape[0] 51 | out = model(inp, out_grid_displacement=out_grid_displacement, 52 | in_grid_displacement=in_grid_displacement) 53 | 54 | if getattr(params, 'horizontal_skip', False): 55 | out = out + x 56 | 57 | if isinstance(out, (list, tuple)): 58 | out = out[0] 59 | 60 | ntest += 1 61 | target = y.clone() 62 | 63 | 
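            # note: `test_l2` / `test_l1` accumulated below are plain MSE and
            # L1 losses averaged over batches, not relative L2/L1 norms.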
/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import *
2 | from .new_adam import *
--------------------------------------------------------------------------------
/train/new_adam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import Tensor
4 | from typing import List, Optional
5 | from torch.optim.optimizer import Optimizer
6 |
7 |
8 | def adam(params: List[Tensor],
9 |          grads: List[Tensor],
10 |          exp_avgs: List[Tensor],
11 |          exp_avg_sqs: List[Tensor],
12 |          max_exp_avg_sqs: List[Tensor],
13 |          state_steps: List[int],
14 |          *,
15 |          amsgrad: bool,
16 |          beta1: float,
17 |          beta2: float,
18 |          lr: float,
19 |          weight_decay: float,
20 |          eps: float):
21 |     r"""Functional API that performs Adam algorithm computation.
22 |     See :class:`~torch.optim.Adam` for details.
23 |     """
24 |
25 |     for i, param in enumerate(params):
26 |
27 |         grad = grads[i]
28 |         exp_avg = exp_avgs[i]
29 |         exp_avg_sq = exp_avg_sqs[i]
30 |         step = state_steps[i]
31 |
32 |         bias_correction1 = 1 - beta1 ** step
33 |         bias_correction2 = 1 - beta2 ** step
34 |
35 |         if weight_decay != 0:
36 |             grad = grad.add(param, alpha=weight_decay)
37 |
38 |         # Decay the first and second moment running average coefficient
39 |         exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
40 |         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
41 |         if amsgrad:
42 |             # Maintains the maximum of all 2nd moment running avg. till now
43 |             torch.maximum(
44 |                 max_exp_avg_sqs[i],
45 |                 exp_avg_sq,
46 |                 out=max_exp_avg_sqs[i])
47 |             # Use the max. for normalizing running avg. of gradient
48 |             denom = (
49 |                 max_exp_avg_sqs[i].sqrt() /
50 |                 math.sqrt(bias_correction2)).add_(eps)
51 |         else:
52 |             denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
53 |
54 |         step_size = lr / bias_correction1
55 |
56 |         param.addcdiv_(exp_avg, denom, value=-step_size)
57 |
58 |
59 | class Adam(Optimizer):
60 |     r"""Implements Adam algorithm.
61 |     It has been proposed in `Adam: A Method for Stochastic Optimization`_.
62 |     The implementation of the L2 penalty follows changes proposed in
63 |     `Decoupled Weight Decay Regularization`_.
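
    Example (sketch; ``model``, ``criterion``, ``x``, ``y`` are assumed to
    exist)::

        optimizer = Adam(model.parameters(), lr=1e-3)
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()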
64 | Args: 65 | params (iterable): iterable of parameters to optimize or dicts defining 66 | parameter groups 67 | lr (float, optional): learning rate (default: 1e-3) 68 | betas (Tuple[float, float], optional): coefficients used for computing 69 | running averages of gradient and its square (default: (0.9, 0.999)) 70 | eps (float, optional): term added to the denominator to improve 71 | numerical stability (default: 1e-8) 72 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 73 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 74 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 75 | (default: False) 76 | .. _Adam\: A Method for Stochastic Optimization: 77 | https://arxiv.org/abs/1412.6980 78 | .. _Decoupled Weight Decay Regularization: 79 | https://arxiv.org/abs/1711.05101 80 | .. _On the Convergence of Adam and Beyond: 81 | https://openreview.net/forum?id=ryQu7f-RZ 82 | """ 83 | 84 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 85 | weight_decay=0, amsgrad=False): 86 | if not 0.0 <= lr: 87 | raise ValueError("Invalid learning rate: {}".format(lr)) 88 | if not 0.0 <= eps: 89 | raise ValueError("Invalid epsilon value: {}".format(eps)) 90 | if not 0.0 <= betas[0] < 1.0: 91 | raise ValueError( 92 | "Invalid beta parameter at index 0: {}".format( 93 | betas[0])) 94 | if not 0.0 <= betas[1] < 1.0: 95 | raise ValueError( 96 | "Invalid beta parameter at index 1: {}".format( 97 | betas[1])) 98 | if not 0.0 <= weight_decay: 99 | raise ValueError( 100 | "Invalid weight_decay value: {}".format(weight_decay)) 101 | defaults = dict(lr=lr, betas=betas, eps=eps, 102 | weight_decay=weight_decay, amsgrad=amsgrad) 103 | super(Adam, self).__init__(params, defaults) 104 | 105 | def __setstate__(self, state): 106 | super(Adam, self).__setstate__(state) 107 | for group in self.param_groups: 108 | group.setdefault('amsgrad', False) 109 | 110 | @torch.no_grad() 111 | def step(self, closure=None): 112 | """Performs a single optimization step. 113 | Args: 114 | closure (callable, optional): A closure that reevaluates the model 115 | and returns the loss. 116 | """ 117 | loss = None 118 | if closure is not None: 119 | with torch.enable_grad(): 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | params_with_grad = [] 124 | grads = [] 125 | exp_avgs = [] 126 | exp_avg_sqs = [] 127 | max_exp_avg_sqs = [] 128 | state_steps = [] 129 | beta1, beta2 = group['betas'] 130 | 131 | for p in group['params']: 132 | if p.grad is not None: 133 | params_with_grad.append(p) 134 | if p.grad.is_sparse: 135 | raise RuntimeError( 136 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 137 | grads.append(p.grad) 138 | 139 | state = self.state[p] 140 | # Lazy state initialization 141 | if len(state) == 0: 142 | state['step'] = 0 143 | # Exponential moving average of gradient values 144 | state['exp_avg'] = torch.zeros_like( 145 | p, memory_format=torch.preserve_format) 146 | # Exponential moving average of squared gradient values 147 | state['exp_avg_sq'] = torch.zeros_like( 148 | p, memory_format=torch.preserve_format) 149 | if group['amsgrad']: 150 | # Maintains max of all exp. moving avg. of sq. 151 | # grad. 
values 152 | state['max_exp_avg_sq'] = torch.zeros_like( 153 | p, memory_format=torch.preserve_format) 154 | 155 | exp_avgs.append(state['exp_avg']) 156 | exp_avg_sqs.append(state['exp_avg_sq']) 157 | 158 | if group['amsgrad']: 159 | max_exp_avg_sqs.append(state['max_exp_avg_sq']) 160 | 161 | # update the steps for each param group update 162 | state['step'] += 1 163 | # record the step after step update 164 | state_steps.append(state['step']) 165 | 166 | adam(params_with_grad, 167 | grads, 168 | exp_avgs, 169 | exp_avg_sqs, 170 | max_exp_avg_sqs, 171 | state_steps, 172 | amsgrad=group['amsgrad'], 173 | beta1=beta1, 174 | beta2=beta2, 175 | lr=group['lr'], 176 | weight_decay=group['weight_decay'], 177 | eps=group['eps']) 178 | return loss 179 | -------------------------------------------------------------------------------- /train/trainer.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import logging 3 | from timeit import default_timer 4 | import tqdm 5 | import wandb 6 | from data_utils.data_utils import * 7 | import torch 8 | from models.get_models import * 9 | from torch import nn 10 | from torch.utils import data 11 | from .new_adam import Adam 12 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau 13 | from utils import * 14 | 15 | def get_grid_displacement(params, stage, data): 16 | if params.grid_type == "non uniform": 17 | with torch.no_grad(): 18 | if stage == StageEnum.RECONSTRUCTIVE: 19 | out_grid_displacement = data['d_grid_x'].cuda()[0] 20 | in_grid_displacement = data['d_grid_x'].cuda()[0] 21 | else: 22 | out_grid_displacement = data['d_grid_y'].cuda()[0] 23 | in_grid_displacement = data['d_grid_x'].cuda()[0] 24 | else: 25 | out_grid_displacement = None 26 | in_grid_displacement = None 27 | return out_grid_displacement, in_grid_displacement 28 | 29 | def trainer( 30 | model, 31 | train_loader, 32 | test_loader, 33 | params, 34 | wandb_log=False, 35 | log_test_interval=1, 36 | stage=StageEnum.RECONSTRUCTIVE, 37 | variable_encoder=None, 38 | token_expander=None, 39 | initial_mesh=None, 40 | ): 41 | lr = params.lr 42 | weight_decay = params.weight_decay 43 | scheduler_step = params.scheduler_step 44 | scheduler_gamma = params.scheduler_gamma 45 | epochs = params.epochs 46 | weight_path = params.weight_path 47 | optimizer = Adam(model.parameters(), lr=lr, 48 | weight_decay=weight_decay, amsgrad=False) 49 | if params.scheduler_type == 'step': 50 | scheduler = StepLR( 51 | optimizer, step_size=scheduler_step, gamma=scheduler_gamma) 52 | else: 53 | scheduler = ReduceLROnPlateau( 54 | optimizer, patience=scheduler_step, factor=scheduler_gamma) 55 | 56 | loss_p = nn.MSELoss(reduction='sum') 57 | 58 | for ep in range(epochs): 59 | model.train() 60 | t1 = default_timer() 61 | train_l2 = 0 62 | train_count = 0 63 | train_loader_iter = tqdm.tqdm( 64 | train_loader, desc=f'Epoch {ep}/{epochs}', leave=False, ncols=100) 65 | 66 | for data in train_loader_iter: 67 | optimizer.zero_grad() 68 | x, y = data['x'].cuda(), data['y'].cuda() 69 | static_features = data['static_features'] 70 | if stage == StageEnum.RECONSTRUCTIVE and params.masking: 71 | x = model.do_mask(x) 72 | 73 | inp = prepare_input( 74 | x, 75 | static_features, 76 | params, 77 | variable_encoder, 78 | token_expander, 79 | initial_mesh, 80 | data) 81 | batch_size = x.shape[0] 82 | if params.grid_type == "non uniform": 83 | out_grid_displacement, in_grid_displacement = get_grid_displacement( 84 | params, stage, data) 85 | elif params.grid_type == "uniform": 86 | 
out_grid_displacement = None 87 | in_grid_displacement = None 88 | 89 | out = model(inp, out_grid_displacement=out_grid_displacement, 90 | in_grid_displacement=in_grid_displacement) 91 | 92 | if isinstance(out, (list, tuple)): 93 | out = out[0] 94 | if getattr(params, 'horizontal_skip', False): 95 | out = out + x 96 | 97 | train_count += 1 98 | target = x.clone() if stage == StageEnum.RECONSTRUCTIVE else y.clone() 99 | loss = loss_p(target.reshape( 100 | batch_size, -1), out.reshape(batch_size, -1)) / (x.shape[0] * x.shape[-1] * x.shape[-2]) 101 | loss.backward() 102 | 103 | if params.clip_gradient: 104 | nn.utils.clip_grad_value_( 105 | model.parameters(), params.gradient_clip_value) 106 | 107 | optimizer.step() 108 | train_l2 += loss.item() 109 | del x, y, out, loss 110 | gc.collect() 111 | 112 | torch.cuda.empty_cache() 113 | avg_train_l2 = train_l2 / train_count 114 | 115 | if params.scheduler_type != 'step': 116 | scheduler.step(avg_train_l2) 117 | else: 118 | scheduler.step() 119 | 120 | t2 = default_timer() 121 | epoch_train_time = t2 - t1 122 | 123 | if ep % log_test_interval == 0: 124 | values_to_log = dict(train_err=avg_train_l2, time=epoch_train_time) 125 | print( 126 | f"Epoch {ep}: Time: {epoch_train_time:.3f}s, Loss: {avg_train_l2:.7f}") 127 | if wandb_log: 128 | wandb.log(values_to_log, commit=True) 129 | 130 | if ep % params.weight_saving_interval == 0 or ep == epochs - 1: 131 | stage_string = 'ssl' if stage == StageEnum.RECONSTRUCTIVE else 'sl' 132 | if params.nettype != 'transformer': 133 | torch.save(model.state_dict(), weight_path + 134 | params.config + "_" + str(ep) + '.pt') 135 | else: 136 | weight_path_model_encoder = weight_path + params.config + \ 137 | "_" + stage_string + '_encoder_' + str(ep) + '.pt' 138 | weight_path_model_decoder = weight_path + params.config + \ 139 | "_" + stage_string + '_decoder_' + str(ep) + '.pt' 140 | weight_path_whole_model = weight_path + params.config + \ 141 | "_" + stage_string + '_whole_model_' + str(ep) + '.pt' 142 | torch.save(model.encoder.state_dict(), 143 | weight_path_model_encoder) 144 | torch.save(model.decoder.state_dict(), 145 | weight_path_model_decoder) 146 | torch.save(model.state_dict(), weight_path_whole_model) 147 | if variable_encoder is not None: 148 | variable_path = weight_path + params.config + \ 149 | "_variable_encoder_" + str(ep) 150 | variable_encoder.save_all_encoder(variable_path) 151 | 152 | model.eval() 153 | test_l2 = 0.0 154 | ntest = 0 155 | loss_p = nn.MSELoss(reduction='sum') 156 | with torch.no_grad(): 157 | for data in test_loader: 158 | x, y = data['x'].cuda(), data['y'].cuda() 159 | static_features = data['static_features'] 160 | inp = prepare_input( 161 | x, 162 | static_features, 163 | params, 164 | variable_encoder, 165 | token_expander, 166 | initial_mesh, 167 | data) 168 | out_grid_displacement, in_grid_displacement = get_grid_displacement( 169 | params, stage, data) 170 | batch_size = x.shape[0] 171 | out = model(inp, in_grid_displacement=in_grid_displacement, 172 | out_grid_displacement=out_grid_displacement) 173 | 174 | if isinstance(out, (list, tuple)): 175 | out = out[0] 176 | if getattr(params, 'horizontal_skip', False): 177 | out = out + x 178 | 179 | ntest += x.shape[0] 180 | target = x.clone() if stage == StageEnum.RECONSTRUCTIVE else y.clone() 181 | test_l2 += loss_p(target.reshape(batch_size, -1), 182 | out.reshape(batch_size, -1)).item() 183 | 184 | test_l2 /= (ntest * x.shape[-1] * x.shape[-2]) 185 | t2 = default_timer() 186 | 187 | if wandb_log: 188 | stage_string = 'ssl' 
if stage == StageEnum.RECONSTRUCTIVE else 'sl'
188 |         wandb.log({'test_error_' + stage_string: test_l2}, commit=True)
189 |     print("Test Error " + stage_string + ":", test_l2)
190 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import os
4 | import pathlib
5 | import psutil
6 | import re
7 | import signal
8 |
9 | from typing import List
10 |
11 | import h5py
12 | # from haikunator import Haikunator
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 |
17 | # HAIKU = haikunator.Haikunator()
18 |
19 |
20 | def prepare_input(
21 |         x,
22 |         static_features,
23 |         params,
24 |         variable_encoder,
25 |         token_expander,
26 |         initial_mesh,
27 |         data):
28 |     if variable_encoder is not None and token_expander is not None:
29 |         if params.grid_type == 'uniform':
30 |             inp = token_expander(x, variable_encoder(x),
31 |                                  static_features.cuda())
32 |         elif params.grid_type == 'non uniform':
33 |             initial_mesh = initial_mesh.cuda()
34 |             equation = [i[0] for i in data['equation']]
35 |             inp = token_expander(
36 |                 x,
37 |                 variable_encoder(
38 |                     initial_mesh +
39 |                     data['d_grid_x'].cuda()[0],
40 |                     equation),
41 |                 static_features.cuda())
42 |     elif params.n_static_channels > 0 and params.grid_type == 'non uniform':
43 |         inp = torch.cat(
44 |             [x, static_features[:, :, :params.n_static_channels].cuda()], dim=-1)
45 |     else:
46 |         inp = x
47 |     return inp
48 |
49 |
50 | def get_wandb_api_key(api_key_file="config/wandb_api_key.txt"):
51 |     try:
52 |         return os.environ["WANDB_API_KEY"]
53 |     except KeyError:
54 |         with open(api_key_file, "r") as f:
55 |             key = f.read()
56 |         return key.strip()
57 |
58 |
59 | def get_mesh(params):
60 |     """Get the mesh from a location."""
61 |     if hasattr(params, "text_mesh") and params.text_mesh:
62 |         # load mesh_x and mesh_y from txt files as np arrays
63 |         mesh_x = np.loadtxt(params.mesh_x)
64 |         mesh_y = np.loadtxt(params.mesh_y)
65 |         # create mesh from mesh_x and mesh_y
66 |         mesh = np.zeros((mesh_x.shape[0], 2))
67 |         mesh[:, 0] = mesh_x
68 |         mesh[:, 1] = mesh_y
69 |     else:
70 |         h5f = h5py.File(params.input_mesh_location, 'r')
71 |         mesh = h5f['mesh/coordinates']
72 |
73 |     if params.super_resolution:
74 |         # load the super-resolution mesh_x and mesh_y from txt files
75 |         mesh_x = np.loadtxt(params.super_resolution_mesh_x)
76 |         mesh_y = np.loadtxt(params.super_resolution_mesh_y)
77 |         mesh_sup = np.zeros((mesh_x.shape[0], 2))
78 |         mesh_sup[:, 0] = mesh_x
79 |         mesh_sup[:, 1] = mesh_y
80 |         # merge it with the original mesh
81 |         mesh = np.concatenate((mesh, mesh_sup), axis=0)
82 |
83 |         print("Super Resolution Mesh Shape", mesh.shape)
84 |
85 |     if hasattr(
86 |             params,
87 |             'sub_sample_size') and params.sub_sample_size is not None:
88 |         mesh_size = mesh.shape[0]
89 |         indexs = [i for i in range(mesh_size)]
90 |         np.random.seed(params.random_seed)
91 |         sub_indexs = np.random.choice(
92 |             indexs, params.sub_sample_size, replace=False)
93 |         mesh = mesh[sub_indexs, :]
94 |
95 |     return mesh[:]
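# Illustrative sketch (hypothetical paths): the params object consumed by
# get_mesh above would carry, e.g.,
#
#     params.text_mesh = True
#     params.mesh_x = "meshes/mesh_x.txt"
#     params.mesh_y = "meshes/mesh_y.txt"
#     params.super_resolution = False
#     mesh = get_mesh(params)  # -> (n_points, 2) array of node coordinates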
96 |
97 |
98 | # TODO add collision checks
99 | # TODO add opts to toggle haiku and date fixes
100 | def save_model(
101 |         model,
102 |         directory: pathlib.Path,
103 |         stage=None,
104 |         sep='_',
105 |         name=None,
106 |         config=None):
107 |     """Saves a model with a unique prefix/suffix
108 |
109 |     The model is prefixed with its date (formatted like YYMMDDHHmm)
110 |     and suffixed with a "Heroku-like" name (for verbal reference).
111 |
112 |     Params:
113 |     ---
114 |     stage: None | StageEnum
115 |         Controls the infix of the model name according to the following mapping:
116 |         | None -> "model"
117 |         | RECONSTRUCTIVE -> "reconstructive"
118 |         | PREDICTIVE -> "predictive"
119 |     """
120 |     prefix = datetime.datetime.utcnow().strftime("%y%m%d%H%M")
121 |     infix = "model"
122 |     if stage is not None:
123 |         infix = stage.value.lower()
124 |     # suffix = Haikunator.haikunate(token_length=0, delimiter=sep)
125 |
126 |     torch.save(model.state_dict(), directory / f"{name}{sep}{config}{sep}.pth")
127 |
128 |
129 | def extract_pids(message) -> List[int]:
130 |     # Assume `message` has a preamble followed by a sequence of tokens like
131 |     # "Process \d+" with extra characters in between such tokens,
132 |     # e.g. "killed: Process 1234 (oom), Process 5678" -> [1234, 5678].
133 |     pattern = re.compile("(Process \\d+)")
134 |     # Contains "Process" tokens and extra characters, interleaved:
135 |     tokens = pattern.split(message)
136 |
137 |     pattern2 = re.compile("(\\d+)")
138 |     pids = [int(pattern2.search(t)[0]) for t in tokens[1::2]]
139 |
140 |     return pids
141 |
142 |
143 | # https://psutil.readthedocs.io/en/latest/#kill-process-tree
144 | def signal_process_tree(
145 |         pid,
146 |         sig=signal.SIGTERM,
147 |         include_parent=True,
148 |         timeout=None,
149 |         on_terminate=None,
150 |         logger=None,
151 | ):
152 |     """Kill a process tree (including grandchildren) with signal ``sig``
153 |
154 |     Return a (gone, still_alive) tuple.
155 |
156 |     Parameters
157 |     ---
158 |     timeout: float
159 |         Time, in seconds, to wait on each signaled process.
160 |     on_terminate: Optional[Callable]
161 |         A callback function which is called as soon as a child terminates.
162 |         Optional.
163 |     """
164 |     assert pid != os.getpid(), "won't kill myself"
165 |     parent = psutil.Process(pid)
166 |     children = parent.children(recursive=True)
167 |     if include_parent:
168 |         children.append(parent)
169 |     if logger is None:
170 |         logger = logging.getLogger()
171 |
172 |     wait_children = []
173 |     for p in children:
174 |         try:
175 |             p.send_signal(sig)
176 |             wait_children.append(p)
177 |         except psutil.AccessDenied:
178 |             # Log the PID of the child we failed to signal, not the parent's.
179 |             logger.error(f"Unable to terminate Process {p.pid} (AccessDenied)")
180 |         except psutil.NoSuchProcess:
181 |             pass
182 |
183 |     gone, alive = psutil.wait_procs(
184 |         wait_children,
185 |         timeout=timeout,
186 |         callback=on_terminate,
187 |     )
188 |     return (gone, alive)
189 |
190 |
191 | def count_model_params(model):
192 |     """Returns the total number of parameters of a PyTorch model
193 |
194 |     Notes
195 |     -----
196 |     One complex number is counted as two parameters (we count real and imaginary parts).
197 |     """
198 |     return sum(
199 |         [p.numel() * 2 if p.is_complex() else p.numel()
200 |          for p in model.parameters()]
201 |     )
202 |
203 |
204 | def signal_my_processes(
205 |         username,
206 |         pids,
207 |         sig=signal.SIGTERM,
208 |         logger=None,
209 | ):
210 |     if logger is None:
211 |         logger = logging.getLogger()
212 |     my_pids = []
213 |     for pid in pids:
214 |         p = psutil.Process(pid)
215 |         with p.oneshot():
216 |             p = p.as_dict(attrs=["username", "status"])
217 |
218 |         # TODO add other states to the filter
219 |         if p["username"] == username and p["status"] == "sleeping":
220 |             my_pids.append(pid)
221 |         else:
222 |             _p = {"pid": pid, **p}
223 |             logger.warning(f"Cannot signal process: {_p}")
224 |
225 |     # Signal each selected process (previously this reused the stale `pid`
226 |     # loop variable instead of `my_pid`).
227 |     for my_pid in my_pids:
228 |         gone, alive = signal_process_tree(my_pid, sig, timeout=60, logger=logger)
229 |         logger.info(f"{gone=}, {alive=}")
--------------------------------------------------------------------------------
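Putting the pieces together, a hedged sketch of the two-stage training flow implemented by `trainer` above (all objects are assumed to be constructed in `main.py` from the YAML configs; names are illustrative):

```python
from train.trainer import trainer
from data_utils.data_utils import StageEnum  # assumed location of StageEnum

# Stage 1: self-supervised reconstruction (inputs are masked when
# params.masking is set); weights are written under params.weight_path.
trainer(model, train_loader, test_loader, params,
        stage=StageEnum.RECONSTRUCTIVE)

# Stage 2: supervised prediction on the downstream task.
trainer(model, train_loader, test_loader, params,
        stage=StageEnum.PREDICTIVE)

# A final robustness check with masked variables is provided by
# missing_variable_testing in test/evaluations.py.
```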