├── .gitignore
├── README.md
├── YParams.py
├── baseline_utlis
│   ├── __init__.py
│   ├── rigid_neighbor.py
│   └── test_rigid_neightbour.py
├── config
│   ├── RB_config.yaml
│   └── ssl_ns_elastic.yaml
├── data_utils
│   ├── __init__.py
│   ├── data_loaders.py
│   └── data_utils.py
├── exps_FSI.sh
├── exps_RB.sh
├── fsi_animation_dx.gif
├── fsi_animation_pressue.gif
├── layers
│   ├── __init__.py
│   ├── codano_block_2D.py
│   ├── codano_block_nd.py
│   ├── fino_2D.py
│   ├── fino_nd.py
│   ├── gnn_layer.py
│   ├── gno_layer.py
│   ├── regrider.py
│   ├── regular_transformer.py
│   ├── unet3d.py
│   ├── unet_sublayer.py
│   └── variable_encoding.py
├── main.py
├── models
│   ├── __init__.py
│   ├── codano.py
│   ├── codano_gino.py
│   ├── deeponet.py
│   ├── fno_gino.py
│   ├── get_models.py
│   ├── gnn.py
│   ├── model_helpers.py
│   ├── unet.py
│   └── vit.py
├── requirements.txt
├── test
│   ├── __init__.py
│   └── evaluations.py
├── train
│   ├── __init__.py
│   ├── new_adam.py
│   └── trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ipynb
2 | wandb/*
3 | *__pycache__/*
4 | */__pycache__/*
5 | *.out
6 | .ipynb_cache/*
7 | .ipynb_checkpoints/*
8 | *key.txt
9 | */.ipynb_cache/*
10 | */.ipynb_checkpoints/*
11 |
12 | # Virtual environment(s):
13 | .venv/*
14 | weights/*
15 | config/wandb_api_key.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Pretraining Codomain Attention Neural Operators for Solving Multiphysics PDEs
2 |
3 | > [Paper Link](https://arxiv.org/pdf/2403.12553.pdf)
4 |
5 | > **🚀🚀 HOW TO USE THE CoDA-NO MODEL WITH `neuraloperator`:** [Open in Colab](https://colab.research.google.com/drive/1W6Qy5Mk_vEjZgrA0tWMespXqKEYDOdc6?usp=sharing)
6 |
7 | ## Model Architecture
8 |
9 |
10 |
11 | Architecture of the Codomain Attention Neural Operator
12 |
13 | Each physical variable (or co-domain) of the input function is concatenated with variable-specific positional encoding (VSPE). Each variable, along with the VSPE, is passed through a GNO layer, which maps from the given non-uniform geometry to a latent regular grid. Then, the output on a uniform grid
14 | is passed through a series of CoDA-NO layers. Lastly, the output of the stacked CoDA-NO layers is mapped onto the domain of the
15 | output geometry for each query point using another GNO layer.
16 |
17 | At each CoDA-NO layer, the input function is tokenized codomain-wise to generate token functions. Each token function is passed through the K, Q, and V operators to get key, query, and value functions. The output function is calculated by extending the self-attention mechanism to the function space.
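Schematically, with token functions `u_1, …, u_n`, keys `k_i = K(u_i)`, queries `q_i = Q(u_i)`, and values `v_i = V(u_i)`, the `i`-th output token function is `o_i = Σ_j softmax_j(⟨q_i, k_j⟩ / τ) · v_j`, where `⟨·,·⟩` denotes an inner product of functions and `τ` a temperature. (This is a schematic summary of the mechanism described above, not the paper's exact notation.)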
18 |
19 |
20 | ## Navier-Stokes + Elastic Wave and Navier-Stokes Datasets
21 |
22 | The fluid-solid interaction dataset is available at [HuggingFace](https://huggingface.co/datasets/ashiq24/FSI-pde-dataset). To download it, use the following code:
23 | ```python
24 | from huggingface_hub import snapshot_download
25 |
26 | folder_path = snapshot_download(
27 | repo_id="ashiq24/FSI-pde-dataset",
28 | repo_type="dataset",
29 | allow_patterns=["fsi-data/*"]
30 | )
31 | ```
32 | ### Dataset Structure
33 |
34 | **Displacement Field**
35 | ![Displacement field animation](fsi_animation_dx.gif)
36 |
37 | **Fluid-Structure Interaction (NS + Elastic Wave)**
38 | The `fsi-data` folder contains simulation data organized by the parameters `mu`, `x1`, and `x2`, where `mu` is the viscosity and `x1` and `x2` parameterize the inlet condition. The dataset includes files for the mesh, displacement, velocity, and pressure.
39 |
40 | This dataset structure is detailed below:
41 |
42 | ```plaintext
43 | fsi-data/
44 | ├── mesh.h5 # Initial mesh
45 | ├── mu=1.0/ # Simulation results for mu = 1.0
46 | │ ├── x1=-4/ # Inlet parameter x1 = -4
47 | │ │ ├── x2=-4/ # Inlet parameter for x2 = -4
48 | │ │ │ └── visualization/
49 | │ │ │ ├── displacement.h5 # Displacements for mu=1.0, x1=-4, x2=-4
50 | │ │ │ ├── velocity.h5 # Velocity field for mu=1.0, x1=-4, x2=-4
51 | │ │ │ └── pressure.h5 # Pressure field for mu=1.0, x1=-4, x2=-4
52 | │ │ ├── x2=-2/
53 | │ │ │ └── visualization/
54 | │ │ │ ├── displacement.h5
55 | │ │ │ ├── velocity.h5
56 | │ │ │ └── pressure.h5
57 | │ │ └── ... # Other x2 values for x1 = -4
58 | │ ├── x1=-2/
59 | │ │ ├── x2=-4/
60 | │ │ │ └── visualization/
61 | │ │ │ ├── displacement.h5
62 | │ │ │ ├── velocity.h5
63 | │ │ │ └── pressure.h5
64 | │ │ └── ... # Other x2 values for x1 = -2
65 | │ └── ... # Other x1 values for mu = 1.0
66 | ├── mu=5.0/ # Simulation results for mu = 5.0
67 | │ └── ... # Similar structure as mu=1.0
68 | └── mu=10.0/ # Simulation results for mu = 10.0
69 | └── ... # Similar structure as mu=1.0
70 | ```
71 | The repository includes a dataloader and visualization code for this dataset. In particular, the `NsElasticDataset` class in `data_utils/data_loaders.py` automatically loads the data for all specified viscosities (`mu`) and inlet conditions (`x1` and `x2`).
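
For example, a minimal sketch of loading the FSI data with `NsElasticDataset` (the config name `codano_gno_NS_ES` comes from `exps_FSI.sh`; the exact config fields, e.g. `input_mesh_location` and `data_location`, are assumed to be set as described in the Experiments section below):

```python
from YParams import YParams
from data_utils.data_loaders import NsElasticDataset

params = YParams('config/ssl_ns_elastic.yaml', 'codano_gno_NS_ES')
dataset = NsElasticDataset(
    params.data_location,              # root of the fsi-data folder
    equation=['NS', 'ES'],             # fluid + elastic-wave variables
    mesh_location=params.input_mesh_location,
    params=params,
)
# Dataloaders over the requested viscosities with a time offset of dt steps
train_loader, test_loader = dataset.get_dataloader(
    mu_list=[1.0, 5.0], dt=1, batch_size=params.batch_size)
```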
72 |
73 | **Fluid Motion with a Non-deformable Solid (NS)** is stored in the `cfd-data` folder.
74 |
75 | ## Rayleigh–Bénard convection
76 | Huggingface dataset link: [Rayleigh_Benard_Convection](https://huggingface.co/datasets/ashiq24/Rayleigh_Benard_Convection)
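
The same `snapshot_download` pattern shown above for the FSI dataset can be used here as well (a sketch; the `repo_id` comes from the link above):

```python
from huggingface_hub import snapshot_download

folder_path = snapshot_download(
    repo_id="ashiq24/Rayleigh_Benard_Convection",
    repo_type="dataset",
)
```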
77 |
78 |
79 |
80 |
81 | ## Experiments
82 |
83 | > ⚠️ **Note:** This repository uses an older version of the `neuralop` library. For a version compatible with the latest `neuralop` library, please refer to the following implementation:
84 |
85 | > **The codomain attention layer is now available through the `neuraloperator` library** ([implementation](https://github.com/neuraloperator/neuraloperator/blob/main/neuralop/layers/coda_layer.py)).
86 |
87 | > **Also, the model is available through the `neuraloperator` library; see** [this Colab notebook](https://colab.research.google.com/drive/1W6Qy5Mk_vEjZgrA0tWMespXqKEYDOdc6?usp=sharing)
88 |
89 | ### Installation
90 | The configurations for all experiments are in `config/ssl_ns_elastic.yaml` (for fluid-structure interaction) and `config/RB_config.yaml` (for the Rayleigh–Bénard system).
91 |
92 | To set up the environments and install the dependencies, please run the following command:
93 | ```bash
94 | pip install -r requirements.txt
95 | ```
96 | It requires `python=3.11.9`, and the `torch` installation must match your machine's CUDA version. Likewise, `torch_geometric` and `torch_scatter` must be built for the same CUDA version; see the [installation guide](https://pytorch-geometric.readthedocs.io/en/latest/) for details.
97 |
98 | **Shortcut:** If you already use the `neuraloperator` package, most of the dependencies are already installed. You then only need to execute the following line to roll back to a compatible version.
99 |
100 | ```
101 | pip install -e git+https://github.com/ashiq24/neuraloperator.git@codano_rep#egg=neuraloperator
102 | ```
103 |
104 | The CoDA-NO layers and models are also being released as part of the `neuraloperator` library (see the links above).
105 |
106 | ### Running Experiments
107 | To run the experiments, download the datasets, update the `input_mesh_location` and `data_location` entries in the config file, update the wandb credentials, and execute the following command:
108 |
109 | ```
110 | python main.py --exp (FSI/RB) --config "config name" --ntrain N
111 | ```
112 |
113 | `--exp`: Determines which experiment to run: 'FSI' (fluid-structure interaction) or 'RB' (Rayleigh–Bénard).
114 |
115 | `--config`: Determines which configuration to use from the corresponding config file (`config/ssl_ns_elastic.yaml` or `config/RB_config.yaml`).
116 |
117 | `--ntrain`: Determines the number of training data points.
118 |
119 | ## Scripts
120 | To train the CoDA-NO architecture on the NS/NS+EW (FSI) and Rayleigh–Bénard convection datasets (both pre-training and fine-tuning), execute the following scripts:
121 | ```
122 | exps_FSI.sh
123 | exps_RB.sh
124 | ```
125 |
126 |
127 | ## Reference
128 | If you find this paper and code useful in your research, please consider citing:
129 | ```bibtex
130 | @article{rahman2024pretraining,
131 | title={Pretraining Codomain Attention Neural Operators for Solving Multiphysics PDEs},
132 | author={Rahman, Md Ashiqur and George, Robert Joseph and Elleithy, Mogab and Leibovici, Daniel and Li, Zongyi and Bonev, Boris and White, Colin and Berner, Julius and Yeh, Raymond A and Kossaifi, Jean and Azizzadenesheli, Kamyar and Anandkumar, Anima},
133 | journal={Advances in Neural Information Processing Systems},
134 |   volume={37},
135 | year={2024}
136 | }
137 | ```
138 |
--------------------------------------------------------------------------------
/YParams.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import pprint
4 |
5 | from ruamel.yaml import YAML
6 |
7 |
8 | class ParamsBase:
9 | """Convenience wrapper around a dictionary
10 |
11 | Allows referring to dictionary items as attributes, and tracking which
12 | attributes are modified.
13 | """
14 |
15 | def __init__(self):
16 | self._original_attrs = None
17 | self.params = {}
18 | self._original_attrs = list(self.__dict__)
19 |
20 | def __getitem__(self, key):
21 | return self.params[key]
22 |
23 | def __setitem__(self, key, val):
24 | self.params[key] = val
25 | self.__setattr__(key, val)
26 |
27 | def __contains__(self, key):
28 | return key in self.params
29 |
30 | def get(self, key, default=None):
31 | if hasattr(self, key):
32 | return getattr(self, key)
33 | else:
34 | return self.params.get(key, default)
35 |
36 | def to_dict(self):
37 | new_attrs = {
38 | key: val for key, val in vars(self).items()
39 | if key not in self._original_attrs
40 | }
41 | return {**self.params, **new_attrs}
42 |
43 | @staticmethod
44 | def from_json(path: str) -> "ParamsBase":
45 | with open(path) as f:
46 | c = json.load(f)
47 | params = ParamsBase()
48 | params.update_params(c)
49 | return params
50 |
51 | def update_params(self, config):
52 | for key, val in config.items():
53 | if val == 'None':
54 | val = None
55 |
56 |             if isinstance(val, dict):
57 | child = ParamsBase()
58 | child.update_params(val)
59 | val = child
60 |
61 | self.params[key] = val
62 | self.__setattr__(key, val)
63 |
64 |
65 | class YParams(ParamsBase):
66 | def __init__(self, yaml_filename, config_name, print_params=False):
67 | """Open parameters stored with ``config_name`` in the yaml file ``yaml_filename``"""
68 | super().__init__()
69 | self._yaml_filename = yaml_filename
70 | self._config_name = config_name
71 |
72 | with open(yaml_filename) as _file:
73 | d = YAML().load(_file)[config_name]
74 |
75 | self.update_params(d)
76 |
77 | if print_params:
78 | print("------------------ Configuration ------------------")
79 | for k, v in d.items():
80 | print(k, end='=')
81 | pprint.pprint(v)
82 | print("---------------------------------------------------")
83 |
84 | def log(self):
85 | logging.info("------------------ Configuration ------------------")
86 | logging.info("Configuration file: " + str(self._yaml_filename))
87 | logging.info("Configuration name: " + str(self._config_name))
88 | for key, val in self.to_dict().items():
89 | logging.info(str(key) + ' ' + str(val))
90 | logging.info("---------------------------------------------------")
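

# A minimal usage sketch (not part of the original module): load the `fno`
# configuration from `config/RB_config.yaml` and read a few hyperparameters.
if __name__ == "__main__":
    params = YParams("config/RB_config.yaml", "fno", print_params=True)
    print(params.lr, params["batch_size"], params.get("missing_key", "n/a"))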
--------------------------------------------------------------------------------
/baseline_utlis/__init__.py:
--------------------------------------------------------------------------------
1 | from .rigid_neighbor import *
2 |
--------------------------------------------------------------------------------
/baseline_utlis/rigid_neighbor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def simple_neighbor_search(
6 |         data: torch.Tensor,
7 |         queries: torch.Tensor,
8 |         n_neigbor: int):
9 |     """Find the `n_neigbor` nearest data points for each query point.
10 |
11 |     Parameters
12 |     ----------
13 |     data : torch.Tensor
14 |         vector of data points from which to find neighbors
15 |     queries : torch.Tensor
16 |         centers of neighborhoods
17 |     n_neigbor : int
18 |         number of nearest neighbors to keep for each query
19 |     """
20 |
21 | # shaped num query points x num data points
22 | dists = torch.cdist(queries, data).to(queries.device)
23 | sorted_dist, _ = torch.sort(dists, dim=1)
24 | k = sorted_dist[:, n_neigbor]
25 | dists = dists - k[:, None]
26 | in_nbr = torch.where(dists < 0, 1., 0.) # i,j is one if j is i's neighbor
27 | # only keep the column indices
28 | nbr_indices = in_nbr.nonzero()[:, 1:].reshape(-1,)
29 | # num points in each neighborhood, summed cumulatively
30 | nbrhd_sizes = torch.cumsum(torch.sum(in_nbr, dim=1), dim=0)
31 | splits = torch.cat((torch.tensor([0.]).to(queries.device), nbrhd_sizes))
32 | nbr_dict = {}
33 | nbr_dict['neighbors_index'] = nbr_indices.long().to(queries.device)
34 | nbr_dict['neighbors_row_splits'] = splits.long()
35 | return nbr_dict
36 |
37 |
38 | class FixedNeighborSearch(nn.Module):
39 |     """Neighbor search returning a fixed number of nearest neighbors.
40 |
41 |     Parameters
42 |     ----------
43 |     use_open3d : bool
44 |         Kept for API compatibility; this implementation always uses the
45 |         torch-based `simple_neighbor_search`.
46 |     """
47 |
48 | def __init__(self, use_open3d=True, use_torch_cluster=False):
49 | super().__init__()
50 | self.search_fn = simple_neighbor_search
51 | self.use_open3d = False
52 |
53 |     def forward(self, data, queries, n_neigbor):
54 |         """Find the `n_neigbor` nearest neighbors, in data, of each point
55 |         in queries. Returns the result in CRS format.
56 |
57 |         Parameters
58 |         ----------
59 |         data : torch.Tensor of shape [n, d]
60 |             Search space of possible neighbors
61 |         queries : torch.Tensor of shape [m, d]
62 |             Points for which to find neighbors
63 |         n_neigbor : int
64 |             Number of nearest neighbors to return for each query
65 |
66 |         Output
67 |         ----------
68 |         return_dict : dict
69 |             Dictionary with keys: neighbors_index, neighbors_row_splits
70 |             neighbors_index: torch.Tensor with dtype=torch.int64
71 |                 Index of each neighbor in data for every point in queries.
72 |                 Neighbors are ordered in the same ordering as the points
73 |                 in queries.
74 |             neighbors_row_splits: torch.Tensor of shape [m+1] with dtype=torch.int64
75 |                 The value at index j is the total number of neighbors of
76 |                 query points 0..j-1. The first element is 0 and the last
77 |                 element is the total number of neighbors.
78 |         """
79 |         return self.search_fn(data, queries, n_neigbor)
80 |
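
# A usage sketch (not part of the original module): decode the CRS output for
# one query point. Shapes here are illustrative.
if __name__ == "__main__":
    data = torch.randn(100, 2)
    queries = torch.randn(10, 2)
    nbrs = simple_neighbor_search(data, queries, n_neigbor=3)
    j = 0
    start = nbrs['neighbors_row_splits'][j]
    end = nbrs['neighbors_row_splits'][j + 1]
    # indices into `data` of query j's 3 nearest neighbors
    print(nbrs['neighbors_index'][start:end])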
--------------------------------------------------------------------------------
/baseline_utlis/test_rigid_neightbour.py:
--------------------------------------------------------------------------------
1 | from rigid_neighbor import *
2 | import torch
3 |
4 | in_ = torch.randn((10, 3))
5 | out = torch.randn((10, 3))
6 |
7 | NS = FixedNeighborSearch(use_open3d=False)
8 |
9 | neighbour = NS(in_, in_, 3)
10 | print(neighbour)
11 |
--------------------------------------------------------------------------------
/config/RB_config.yaml:
--------------------------------------------------------------------------------
1 | base_config: &BASE_CONFIG
2 | ## SSL parameters
3 | random_seed: 42
4 | config: " "
5 | nettype: 'transformer'
6 |   evaluation_channel_drop: 1 # number of variables dropped for the prediction task with partially observed data
7 | drop_type : 'zeros' # dropped values are replaced by zeros
8 | grid_type: 'uniform'
9 | max_block : 0.3 #Block size for dropping pixels during data augmentation
10 | drop_pix: 0.5 #percentage of pixels dropped during data augmentation
11 | channel_per: 0.5 # percentage of channels affected by data augmentation
12 |   channel_drop_per: 0.5 # percentage of affected channels to be dropped completely
13 | validation_aug: !!bool False
14 | max_block_val : 0.3
15 | drop_pix_val: 0.5
16 | channel_per_val: 0.2
17 | channel_drop_per_val: 1.0
18 |
19 |
20 |
21 | scheduler_type: 'step' # lr scheduler type
22 | batch_size: 4
23 | use_variable_encoding: !!bool True
24 | n_variables: 5
25 | masking: !!bool True # if true, it will perform data augmentation for SSL
26 | in_token_codim_en: 1
27 | kqv_non_linear: !!bool False
28 |
29 | hidden_token_codim_en: 6
30 | lifting_token_codim_en: 12
31 | lifting_token_codim_pred: 12
32 | out_token_codim_pred: 1
33 | n_layers_en: 2
34 | n_heads_en: [2,2]
35 | n_layers_dec: 2
36 | n_heads_dec: [2,2]
37 | n_layers_pred: 3
38 | n_heads_pred: [2,2,2]
39 |
40 | scalings_pred: [[1,1], [1,1], [1,1]]
41 | scalings_en: [[1,1], [1,1],[1,1]]
42 | scalings_dec: [[1,1],[1,1]]
43 |
44 | n_modes_en: [[100,100], [100,100]]
45 | n_modes_dec: [[100,100], [100,100]]
46 | n_modes_pred: [[100,100], [100,100],[100,100]]
47 |
48 | per_channel_attention: !!bool True
49 |
50 |
51 | #fft_type: 'fft' # Duplicate should be removed
52 | transform_type: 'fft' # might be also spherical harmonics transform or 'sht'
53 |
54 | tno_integral_op: 'fno'
55 |
56 | var_encoding: !!bool True
57 | n_encoding_channels: 3
58 | reconstruction: !!bool True
59 | enable_cls_token: !!bool True
60 |
61 | # add_static_feature: !!bool False
62 |
63 | pretrain_ssl : !!bool True #if true we pretrain the model by SSL
64 |
65 | # if True, it will fine tune the encoder during SL
66 | # otherwise it will freeze the weight of the encoder
67 | # which is trained by SSL
68 |
69 |   ## training hyperparameters
70 | training_stage: 'regular' # can be regular or fine_tune
71 | freeze_encoder : !!bool False # if true, it will freeze the encoder during SL
72 | lr: 0.03
73 | weight_decay: 0.0000
74 | scheduler_step: 50
75 | scheduler_gamma: 0.5
76 | epochs: 50
77 | clip_gradient: !!bool True
78 | gradient_clip_value: 0.1
79 | ssl_only: !!bool False # if true, it will only train the model by SSL and will not be followed by SL
80 | weight_path: "./weights/"
81 | weight_saving_interval: 3
82 |
83 | # Weights and biases
84 | wandb_log: True
85 | wandb_name: 'codano-RB'
86 | wandb_group: 'neuraloperator'
87 | wandb_project: 'CoDA-NO_neurips'
88 | wandb_entity: 'ashiq24'
89 | wandb_log_test_interval: 1
90 |
91 | dataset: ""
92 |
93 | ## incremental learning
94 | incremental: False
95 | buffer_modes: 5
96 | grad_explained_ratio_threshold: 0.9999
97 | max_iter: 1
98 | grad_max_iter: 1
99 |
100 | incremental_loss_gap: False
101 | eps: 0.1
102 |
103 | incremental_resolution: False
104 | epoch_gap: 150
105 | horizontal_skip: !!bool False
106 |
107 | codano_NS2: &CODANO_NS2
108 | <<: *BASE_CONFIG
109 | # dataset hyper
110 |   # dataset hyperparameters
111 | n_train: 40
112 | n_dim: 2
113 | equation_dict: { "NS": 2}
114 | n_test: 40
115 |
116 | subsampling_rate: 2
117 |
118 | ###
119 | dt: 2
120 | skip_start: 250
121 |
122 | data_location: ["../../../../../raid/ashiq/ns_vel/NS_data_re5000.pt", "../../../../../raid/ashiq/ns_vel/NS_data_re500.pt"]
123 | dataset: "NS"
124 |
125 |   # training hyperparameters
126 | epochs: 30
127 | lr: 0.01
128 | weight_decay: 0.0000
129 | scheduler_type: 'rdp'
130 | scheduler_step: 3
131 | scheduler_gamma: 0.5
132 | gradient_clip_value: 5.0
133 |
134 | enable_cls_token: !!bool True
135 | n_encoding_channels: 16
136 |
137 | n_layers_en: 4
138 | n_heads_en: [2,2,2,2]
139 | n_layers_dec: 4
140 | n_heads_dec: [2,2,2,2]
141 | n_layers_pred: 3
142 | n_heads_pred: [2,2,2]
143 |
144 | scalings_en: [[1,1], [1,1], [1,1], [1,1]]
145 | scalings_dec: [[1,1],[1,1],[1,1], [1,1]]
146 | scalings_pred: [[1,1], [1,1], [1,1]]
147 |
148 | max_n_modes_en: [[32,32], [32,32], [32,32], [32,32]]
149 | max_n_modes_dec: [[32,32], [32,32], [32,32], [32,32]]
150 | max_n_modes_pred: [[32,32], [32,32], [32,32]]
151 |
152 | n_modes_en: [[32,32], [32,32], [32,32], [32,32]]
153 | n_modes_dec: [[32,32], [32,32], [32,32], [32,32]]
154 | n_modes_pred: [[32,32], [32,32], [32,32]]
155 |
156 | hidden_token_codim_en: 64
157 | lifting_token_codim_en: 128
158 | lifting_token_codim_pred: 128
159 |
160 |   ## variable encoder
161 | encoding_modes_x : 32
162 | encoding_modes_y : 32
163 | encoding_modes_t : 20
164 | basis : 'fft'
165 |
166 |
167 | n_static_channels: 2
168 | in_token_codim_en: 1
169 | out_token_codim_pred: 1
170 |
171 |
172 | pretrain_ssl : !!bool True
173 | freeze_encoder : !!bool False
174 | ssl_only: !!bool True
175 | positional_encoding_dim: 4
176 |
177 | #data augmentation
178 | channel_per: 1.0
179 | channel_drop_per: 0.0
180 | max_block : 0.5
181 | drop_pix: 0.5
182 | masking: !!bool True
183 |
184 | codano_big: &CODANO_BIG
185 | <<: *CODANO_NS2
186 | n_train: 100
187 | n_test: 100
188 | max_block : 0.5
189 | drop_pix: 0.6
190 | batch_size: 8
191 |
192 | n_encoding_channels: 4
193 |
194 | n_layers_en: 3
195 | n_heads_en: [32, 32, 32]
196 | n_layers_dec: 3
197 | n_heads_dec: [32,32,32]
198 | n_layers_pred: 1
199 | n_heads_pred: [16]
200 |
201 | scalings_en: [[1,1], [1,1], [1,1]]
202 | scalings_dec: [[1,1],[1,1],[1,1]]
203 | scalings_pred: [[1,1]]
204 |
205 | max_n_modes_en: [[64,64], [64,64], [64,64]]
206 | max_n_modes_dec: [[64,64], [64,64], [64,64]]
207 | max_n_modes_pred: [[64,64]]
208 |
209 | n_modes_en: [[64,64], [64,64], [64,64]]
210 | n_modes_dec: [[64,64], [64,64], [64,64]]
211 | n_modes_pred: [[64,64]]
212 |
213 | hidden_token_codim_en: 64
214 | lifting_token_codim_en: 128
215 | lifting_token_codim_pred: 128
216 |
217 |   ## variable encoder
218 | encoding_modes_x : 64
219 | encoding_modes_y : 64
220 | encoding_modes_t : 20
221 | weight_saving_interval: 1
222 |
223 | codano_RB: &CODANO_RB
224 | <<: *CODANO_NS2
225 |
226 | n_train: 40
227 | n_dim: 2
228 | batch_size: 4
229 | equation_dict: { "NS": 2, "T": 1}
230 | n_test: 150
231 | subsampling_rate: 2
232 | data_location: "../../../RB_Data/RB_data/data_test_2500.npz"
233 | dataset: "RB"
234 |
235 | pretrain_ssl : !!bool False
236 | masking: !!bool False
237 |
238 |
239 | n_encoding_channels: 16
240 |
241 | n_layers_en: 3
242 | n_heads_en: [16,16,16]
243 | n_layers_dec: 4
244 | n_heads_dec: [16,16,16,16]
245 | n_layers_pred: 3
246 | n_heads_pred: [16,16,16]
247 |
248 | scalings_en: [[1,1], [1,1], [1,1]]
249 | scalings_dec: [[1,1],[1,1],[1,1], [1,1]]
250 | scalings_pred: [[1,1], [1,1], [1,1]]
251 |
252 | max_n_modes_en: [[64,64], [64,64], [64,64]]
253 | max_n_modes_dec: [[64,64], [64,64], [64,64], [64,64]]
254 | max_n_modes_pred: [[64,64], [64,64], [64,64]]
255 |
256 | n_modes_en: [[64,64], [64,64], [64,64]]
257 | n_modes_dec: [[64,64], [64,64], [64,64], [64,64]]
258 | n_modes_pred: [[64,64], [64,64], [64,64]]
259 |
260 | hidden_token_codim_en: 64
261 | lifting_token_codim_en: 128
262 | lifting_token_codim_pred: 128
263 |
264 |   ## variable encoder
265 | encoding_modes_x : 32
266 | encoding_modes_y : 32
267 | encoding_modes_t : 20
268 |
269 | scheduler_step: 10
270 | scheduler_gamma: 0.5
271 | lr: 0.01
272 |
273 | ft_codano_RB: &FT_CODANO_RB
274 | <<: *CODANO_BIG
275 |
276 | batch_size: 5
277 | equation_dict: { "NS": 2, "T": 1}
278 | n_test: 150
279 | data_location: "../../../RB_Data/RB_data/data_test_2500.npz"
280 | dataset: "RB"
281 |
282 | training_stage: 'fine_tune'
283 | pretrain_ssl : !!bool False
284 | masking: !!bool False
285 |
286 | n_layers_pred: 1
287 | n_heads_pred: [1]
288 |
289 | scalings_pred: [[1,1]]
290 |
291 | n_modes_pred: [[64,64]]
292 | max_n_modes_pred: [[64,64]]
293 | pretrain_weight: "../../../RB_weights/pre_trained_weights/codano_big_ssl_encoder_7.pt"
294 | NS_variable_encoder_path: "../../../RB_weights/pre_trained_weights/codano_big_variable_encoder_7_NS.pt"
295 | T_variable_encoder_path: None
296 |
297 | scheduler_step: 2
298 | freeze_encoder : !!bool False
299 | clip_gradient: !!bool True
300 | scheduler_gamma: 0.5
301 | scheduler_type: 'rdp'
302 | lr: 0.05
303 | horizontal_skip: !!bool False
304 | weight_decay: 0.000000
305 |
306 | ft_codano_RB_small: &FT_CODANO_RB_SMALL
307 | <<: *CODANO_BIG
308 | batch_size: 5
309 | equation_dict: { "NS": 2, "T": 1}
310 | n_test: 150
311 | data_location: "../../RB_Data/RB_data/data_test_2500.npz"
312 | dataset: "RB"
313 |
314 | training_stage: 'fine_tune'
315 | pretrain_ssl : !!bool False
316 | masking: !!bool False
317 |
318 | n_layers_pred: 3
319 | n_heads_pred: [16 ,16,32] #,2]
320 |
321 | scalings_pred: [[1,1], [1,1], [1,1]] #, [1,1]]
322 |
323 | n_modes_pred: [[32,32], [32,32], [32,32]] #, [32,32]]
324 | max_n_modes_pred: [[32,32], [32,32], [32,32]] #,[32,32]]
325 | pretrain_weight: None
326 | NS_variable_encoder_path: None
327 | T_variable_encoder_path: None
328 |
329 | scheduler_step: 5
330 | freeze_encoder : !!bool False
331 | scheduler_gamma: 0.5
332 | lr: 0.005
333 |
334 | ft_codano_RB_test: &FT_CODANO_RB_TEST
335 | <<: *CODANO_RB
336 | n_train: 5
337 |   batch_size: 2
338 | n_test: 5
339 | epochs: 10
340 |
341 |
342 | unet: &UNET
343 | <<: *CODANO_RB
344 | batch_size: 5
345 | nettype: 'unet'
346 | n_test: 150
347 | in_dim: 3
348 | out_dim: 3
349 | init_features: 64
350 |
351 | lr: 0.01
352 | scheduler_step: 5
353 | scheduler_gamma: 0.5
354 |
355 | fno: &FNO
356 | <<: *UNET
357 | nettype: 'fno'
358 | hidden_features: 32
359 | lifting_features: 64
360 |
361 | n_modes: [32,32]
362 | max_n_modes: [32,32]
363 | hidden_dim: 32
364 | in_dim: 3
365 | out_dim: 3
366 | lifting_dim: 64
367 | projection_dim: 64
368 | n_layers: 4
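
# Usage note (illustrative): every top-level key above is a named config that
# `main.py` selects via `--config`, e.g.
#   python main.py --exp RB --config fno --ntrain 100
# YAML anchors (&BASE_CONFIG) and merge keys (<<:) let each config inherit and
# override the base settings.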
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_utils import *
2 | from .data_loaders import *
3 |
--------------------------------------------------------------------------------
/data_utils/data_loaders.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import random
3 | import h5py
4 | import os
5 | import torch
6 | import numpy as np
7 | from torchvision.transforms import Normalize
8 | from torch.utils.data import ConcatDataset, random_split, DataLoader
9 | import itertools
10 | from utils import *
11 | from neuralop.datasets.tensor_dataset import TensorDataset
12 | from data_utils import get_mesh_displacement
13 |
14 |
15 | # Data loaders for Rayleigh-Benard and Navier-Stokes
16 |
17 |
18 | class regularMeshTensorDataset(TensorDataset):
19 | def __init__(
20 | self,
21 | x,
22 | y,
23 | transform_x=None,
24 | transform_y=None,
25 | equation=None):
26 | '''
27 | data in format
28 | x: samples x channels x x_grid_size x y_grid_size
29 | '''
30 | super().__init__(x, y, transform_x, transform_y)
31 | self.equation = equation
32 | x_grid_size = x[0].shape[-2]
33 | y_grid_size = x[0].shape[-1]
34 |
35 | self.static_features = self.generate_static_grid(
36 | x_grid_size, y_grid_size)
37 |
38 |     def generate_static_grid(self, x_grid_size, y_grid_size):
39 |         '''
40 |         create a positional grid of resolution x_grid_size x y_grid_size
41 |         '''
42 |         x = torch.linspace(-1, 1, x_grid_size)
43 |         y = torch.linspace(-1, 1, y_grid_size)
44 |         x, y = torch.meshgrid(x, y, indexing='ij')
45 | return torch.permute(torch.stack([x, y], dim=-1), (2, 0, 1))
46 |
47 | def __getitem__(self, index):
48 | x = self.x[index]
49 | y = self.y[index]
50 |
51 | if self.transform_x is not None:
52 | x = self.transform_x(x)
53 |
54 | if self.transform_y is not None:
55 | y = self.transform_y(y)
56 |
57 | return {'x': x, 'y': y, 'equation': self.equation,
58 | 'static_features': self.static_features}
59 |
60 |
61 | def load_NS_dataset(path, n_samples, subsamplingrate=2):
62 | data_1 = torch.load(path[0])
63 | data_2 = torch.load(path[1])
64 |
65 | x1_t0 = data_1[:, :-1, :, :]
66 | x1_t1 = data_1[:, 1:, :, :]
67 | x2_t0 = data_2[:, :-1, :, :]
68 | x2_t1 = data_2[:, 1:, :, :]
69 |
70 | # flatten the data
71 |     x1_t0 = x1_t0.reshape(-1, x1_t0.shape[-3],
72 |                           x1_t0.shape[-2], x1_t0.shape[-1])
73 |     x1_t1 = x1_t1.reshape(-1, x1_t1.shape[-3],
74 |                           x1_t1.shape[-2], x1_t1.shape[-1])
75 |     x2_t0 = x2_t0.reshape(-1, x2_t0.shape[-3],
76 |                           x2_t0.shape[-2], x2_t0.shape[-1])
77 |     x2_t1 = x2_t1.reshape(-1, x2_t1.shape[-3],
78 |                           x2_t1.shape[-2], x2_t1.shape[-1])
79 |
80 |     # shuffle the data (each dataset gets its own permutation)
81 |     indices = torch.randperm(x1_t0.shape[0])
82 |     x1_t0 = x1_t0[indices]
83 |     x1_t1 = x1_t1[indices]
84 |     idx = torch.randperm(x2_t0.shape[0])
85 |     x2_t0 = x2_t0[idx]
86 |     x2_t1 = x2_t1[idx]
87 |
88 |     ntrain1 = int(n_samples / 2)
89 |     ntrain2 = n_samples - ntrain1
90 |
91 |     x1_train = x1_t0[:ntrain1]
92 |     y1_train = x1_t1[:ntrain1]
93 |     x2_train = x2_t0[:ntrain2]
94 |     y2_train = x2_t1[:ntrain2]
95 |
96 | # concatenate the data
97 | x = torch.cat([x1_train, x2_train], dim=0)
98 | y = torch.cat([y1_train, y2_train], dim=0)
99 |
100 | return x.permute(0, 3, 1, 2), y.permute(0, 3, 1, 2)
101 |
102 |
103 | def load_NS_dataset_hdf5(path, n_samples, subsamplingrate=2):
104 | '''
105 | loading only velocity field data
106 | '''
107 | samples = 0
108 | i = 0
109 | data_x = []
110 | data_y = []
111 |
112 | while samples < n_samples + 100:
113 | if i > 999:
114 | break
115 | data_path = path[0] + 'data_' + str(i) + '.hdf5'
116 | x = h5py.File(data_path, "r")['data']
117 | data_x.append(x[:: subsamplingrate, :: subsamplingrate, :-1, :2])
118 | data_y.append(x[:: subsamplingrate, :: subsamplingrate, 1:, :2])
119 |
120 | data_path = path[1] + 'data_' + str(i) + '.hdf5'
121 | x = h5py.File(data_path, "r")['data']
122 | data_x.append(x[:: subsamplingrate, :: subsamplingrate, :-1, :2])
123 | data_y.append(x[:: subsamplingrate, :: subsamplingrate, 1:, :2])
124 |
125 | i += 1
126 | samples += (x.shape[2] - 1)
127 |
128 | x = torch.tensor(np.concatenate(data_x, axis=2))
129 | y = torch.tensor(np.concatenate(data_y, axis=2))
130 |
131 | print("Data Loaded: ", x.shape, y.shape)
132 |
133 | return torch.permute(x, (2, 3, 0, 1)), torch.permute(y, (2, 3, 0, 1))
134 |
135 |
136 | def get_NS_dataloader(params):
137 | n_samples = params.n_train + params.n_test
138 | x, y = load_NS_dataset(params.data_location,
139 | n_samples, params.subsampling_rate)
140 |
141 |     # shuffle and split into train/test (permute x and y together)
142 |     indices = torch.randperm(x.shape[0])
143 |     x, y = x[indices], y[indices]
144 | x_train = x[:params.n_train]
145 | y_train = y[:params.n_train]
146 |
147 | x_test = x[params.n_train:]
148 | y_test = y[params.n_train:]
149 |
150 | # get the mean and std of the data
151 | mean = torch.mean(x_train, dim=(0, 1, 2, 3))
152 | std = torch.std(x_train, dim=(0, 1, 2, 3))
153 | normalizer = Normalize(mean, std)
154 |
155 | dataset_train = regularMeshTensorDataset(
156 | x_train,
157 | y_train,
158 | transform_x=normalizer,
159 | transform_y=normalizer,
160 | equation=['NS'])
161 | dataset_test = regularMeshTensorDataset(
162 | x_test,
163 | y_test,
164 | transform_x=normalizer,
165 | transform_y=normalizer,
166 | equation=['NS'])
167 |
168 | dat_train = DataLoader(
169 | dataset_train, batch_size=params.batch_size, shuffle=True)
170 | dat_test = DataLoader(
171 | dataset_test, batch_size=params.batch_size, shuffle=False)
172 | return dat_train, dat_test
173 |
174 |
175 | def get_RB_dataloader(params):
176 | data = np.load(params.data_location)
177 | # files ['vx', 'vy', 'temp', 'time']
178 | vx = torch.tensor(data['vx']).type(torch.float)
179 | vy = torch.tensor(data['vy']).type(torch.float)
180 | temp = torch.tensor(data['temp']).type(torch.float)
181 | time = torch.tensor(data['time']).type(torch.float)
182 |
183 | # stack the data
184 | data = torch.stack([vx, vy, temp], dim=1)
185 | data = data[params.skip_start:]
186 | x = data[:int(-1 * params.dt)]
187 | y = data[int(params.dt):]
188 |
189 |     # shuffle and split into train and test
190 | # fix the seed
191 | torch.manual_seed(params.random_seed)
192 | indices = torch.randperm(x.shape[0])
193 | x = x[indices]
194 | y = y[indices]
195 |
196 | x_train = x[:params.n_train, :,
197 | ::params.subsampling_rate, ::params.subsampling_rate]
198 | y_train = y[:params.n_train, :,
199 | ::params.subsampling_rate, ::params.subsampling_rate]
200 |
201 | x_test = x[-params.n_test:]
202 | y_test = y[-params.n_test:]
203 | print("len test data", len(x_test), len(y_test), params.n_test, x.shape)
204 |
205 | # get the mean and std of the data
206 | mean = torch.mean(x_train, dim=(0, 2, 3))
207 | std = torch.std(x_train, dim=(0, 2, 3))
208 | normalizer = Normalize(mean, std)
209 |
210 | dataset_train = regularMeshTensorDataset(
211 | x_train,
212 | y_train,
213 | transform_x=normalizer,
214 | transform_y=normalizer,
215 | equation=['ES'])
216 | dataset_test = regularMeshTensorDataset(
217 | x_test,
218 | y_test,
219 | transform_x=normalizer,
220 | transform_y=normalizer,
221 | equation=['ES'])
222 |
223 | dat_train = DataLoader(
224 | dataset_train, batch_size=params.batch_size, shuffle=True)
225 | dat_test = DataLoader(
226 | dataset_test, batch_size=params.batch_size, shuffle=False)
227 |
228 | return dat_train, dat_test
229 |
230 |
231 | # dataloaders for Fluid-Structure Interaction (FSI) problems
232 |
233 | class IrregularMeshTensorDataset(TensorDataset):
234 | def __init__(
235 | self,
236 | x,
237 | y,
238 | transform_x=None,
239 | transform_y=None,
240 | equation=None,
241 | x1=0,
242 | x2=0,
243 | mu=0.1,
244 | mesh=None):
245 | super().__init__(x, y, transform_x, transform_y)
246 | self.x1 = x1
247 | self.x2 = x2
248 |
249 | self.mu = mu
250 | self.mesh = mesh
251 | self.equation = equation
252 |         print("Inside Dataset :", self.mesh.dtype, x.dtype, y.dtype)
253 |         self._create_static_features()
254 |
255 |     def _create_static_features(self, d_grid=None):
256 |         '''
257 |         create static channels for the inlet condition and Reynolds number
258 |         '''
259 | n_grid_points = self.x.shape[1]
260 | if len(self.equation) == 1:
261 |             # equation can be either ['NS'] or ['NS', 'ES'],
262 |             # i.e., with 3 or 5 channels/variables
263 | n_variables = 3
264 | else:
265 | n_variables = self.x.shape[-1]
266 | if d_grid is not None:
267 | positional_enco = self.mesh + d_grid
268 | else:
269 | positional_enco = self.mesh
270 |
271 |         reynolds = torch.ones(n_grid_points, 1) * self.mu
272 |         inlet = ((-self.x1 / 2 + positional_enco[:, 1]) *
273 |                  (-self.x2 / 2 + positional_enco[:, 1]))[:, None]**2
274 |
275 |         self.static_features = torch.cat(
276 |             [reynolds, inlet, positional_enco], dim=-1).repeat(1, n_variables)
277 |
278 | def __getitem__(self, index):
279 | x = self.x[index]
280 | y = self.y[index]
281 |
282 | d_grid_x = get_mesh_displacement(x)
283 | d_grid_y = get_mesh_displacement(y)
284 |
285 |         self._create_static_features(d_grid_x)
286 |
287 | if self.transform_x is not None:
288 | x = self.transform_x(x)
289 |
290 | if self.transform_y is not None:
291 | y = self.transform_y(y)
292 |
293 | if len(self.equation) == 1:
294 | x = x[:, :3]
295 | y = y[:, :3]
296 |
297 | return {'x': x, 'y': y, 'd_grid_x': d_grid_x,
298 | 'd_grid_y': d_grid_y, 'static_features': self.static_features,
299 | 'equation': self.equation}
300 |
301 |
302 | class Normalizer():
303 | def __init__(self, mean, std, eps=1e-6, persample=False):
304 | self.persample = persample
305 | self.mean = mean
306 | self.std = std
307 | self.eps = eps
308 |
309 | def __call__(self, data):
310 | if self.persample:
311 | self.mean = torch.mean(data, dim=(0))
312 | self.std = torch.var(data, dim=(0))**0.5
313 | return (data - self.mean) / (self.std + self.eps)
314 |
315 | def denormalize(self, data):
316 | return data * (self.std + self.eps) + self.mean
317 |
318 | def cuda(self,):
319 | if self.mean is not None and self.std is not None:
320 | self.mean = self.mean.cuda()
321 | self.std = self.std.cuda()
322 |
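# A minimal sketch (not part of the original module): per-sample normalization
# round trip. In `persample` mode the statistics are recomputed on every call,
# so `denormalize` uses the stats of the most recently normalized tensor.
def _demo_normalizer():
    data = torch.randn(876, 5)                 # (mesh points, variables)
    norm = Normalizer(None, None, persample=True)
    restored = norm.denormalize(norm(data))
    print(torch.allclose(restored, data, atol=1e-4))  # True
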
323 |
324 | class NsElasticDataset():
325 | def __init__(self, location, equation, mesh_location, params):
326 | self.location = location
327 |
328 |         # _x1 and _x2 are the parameters of the inlet conditions
329 |         # _mu is the viscosity
330 | self._x1 = [-4.0, -2.0, 0.0, 2.0, 4.0, 6.0]
331 | self._x2 = [-4.0, -2.0, 0, 2.0, 4.0, 6.0]
332 | self._mu = [0.1, 0.01, 0.5, 5, 1, 10]
333 | if params.data_partition == 'supervised':
334 |             # two inlet conditions are held out for fine-tuning;
335 |             # they are not introduced during the self-supervised
336 |             # pretraining
337 | self._x1 = params.supervised_inlets_x1
338 | self._x2 = params.supervised_inlets_x2
339 | elif params.data_partition == 'self-supervised':
340 | self._x1 = list(set(self._x1) - set(params.supervised_inlets_x1))
341 | self._x2 = list(set(self._x2) - set(params.supervised_inlets_x2))
342 | else:
343 | raise ValueError(
344 | f"Data partition {params.data_partition} not supported")
345 |
346 | self.equation = equation
347 |
348 | mesh = get_mesh(params)
349 | self.input_mesh = torch.from_numpy(mesh).type(torch.float)
350 | print("Mesh Shape: ", self.input_mesh.shape)
351 | self.params = params
352 |
353 | self.normalizer = Normalizer(None, None, persample=True)
354 |
355 | def _readh5(self, h5f, dtype=torch.float32):
356 | a_dset_keys = list(h5f['VisualisationVector'].keys())
357 | size = len(a_dset_keys)
358 | readings = [None for i in range(size)]
359 | for dset in a_dset_keys:
360 | ds_data = (h5f['VisualisationVector'][dset])
361 | readings[int(dset)] = torch.tensor(np.array(ds_data), dtype=dtype)
362 |
363 | readings_tensor = torch.stack(readings, dim=0)
364 | print(f"Loaded tensor Size: {readings_tensor.shape}")
365 | return readings_tensor
366 |
367 | def get_data(self, mu, x1, x2):
368 | if mu not in self._mu:
369 | raise ValueError(f"Value of mu must be one of {self._mu}")
370 |         if x1 not in self._x1 or x2 not in self._x2:
371 |             raise ValueError(
372 |                 f"x1 must be one of {self._x1} and x2 must be one of {self._x2}")
373 | if mu == 0.5:
374 | path = os.path.join(
375 | self.location,
376 | 'mu=' + str(mu),
377 | 'x1=' + str(-2.0),
378 | 'x2=' + str(x2),
379 | '1',
380 | 'Visualization')
381 | print(path)
382 | else:
383 | path = os.path.join(
384 | self.location,
385 | 'mu=' + str(mu),
386 | 'x1=' + str(x1),
387 | 'x2=' + str(x2),
388 | 'Visualization')
389 |
390 | filename = os.path.join(path, 'displacement.h5')
391 |
392 | h5f = h5py.File(filename, 'r')
393 | displacements_tensor = self._readh5(h5f)
394 |
395 | filename = os.path.join(path, 'pressure.h5')
396 | h5f = h5py.File(filename, 'r')
397 | pressure_tensor = self._readh5(h5f)
398 |
399 | filename = os.path.join(path, 'velocity.h5')
400 | h5f = h5py.File(filename, 'r')
401 | velocity_tensor = self._readh5(h5f)
402 |
403 | return velocity_tensor, pressure_tensor, displacements_tensor
404 |
405 | def get_data_txt(self, mu, x1, x2):
406 | if mu not in self._mu:
407 | raise ValueError(f"Value of mu must be one of {self._mu}")
408 |         if x1 not in self._x1 or x2 not in self._x2:
409 |             raise ValueError(
410 |                 f"x1 must be one of {self._x1} and x2 must be one of {self._x2}")
411 | path = os.path.join(
412 | self.location,
413 | 'mu=' + str(mu),
414 | 'x1=' + str(x1),
415 | 'x2=' + str(x2),
416 | '1')
417 |
418 | velocity_x = torch.tensor(np.loadtxt(os.path.join(path, 'vel_x.txt')))
419 | velocity_y = torch.tensor(np.loadtxt(os.path.join(path, 'vel_y.txt')))
420 | if len(self.params.equation_dict) != 1:
421 | dis_x = torch.tensor(np.loadtxt(os.path.join(path, 'dis_x.txt')))
422 | dis_y = torch.tensor(np.loadtxt(os.path.join(path, 'dis_y.txt')))
423 | pressure = torch.tensor(np.loadtxt(os.path.join(path, 'pres.txt')))
424 | else:
425 |             # just copying values as a placeholder when only the NS equation is used
426 | dis_x = velocity_x
427 | dis_y = velocity_y
428 | pressure = velocity_x
429 |
430 | # reshape each tensor into 2d by keeping 876 entries in each row
431 | dis_x = dis_x.view(-1, 876, 1)
432 | dis_y = dis_y.view(-1, 876, 1)
433 | pressure = pressure.view(-1, 876, 1)
434 | velocity_x = velocity_x.view(-1, 876, 1)
435 | velocity_y = velocity_y.view(-1, 876, 1)
436 |
437 | velocity = torch.cat([velocity_x, velocity_y], dim=-1)
438 | displacement = torch.cat([dis_x, dis_y], dim=-1)
439 |
440 | return velocity.to(
441 | torch.float), pressure.to(
442 | torch.float), displacement.to(
443 | torch.float)
444 |
445 | def get_dataloader(
446 | self,
447 | mu_list,
448 | dt,
449 | normalize=True,
450 | batch_size=1,
451 | train_test_split=0.2,
452 | sample_per_inlet=200,
453 | ntrain=None,
454 | ntest=None,
455 | data_loader_kwargs={'num_workers': 2}):
456 |
457 | train_datasets = []
458 | test_datasets = []
459 |
460 | for mu in mu_list:
461 | train, test = self.get_tensor_dataset(
462 | mu, dt, normalize, train_test_split=train_test_split, sample_per_inlet=sample_per_inlet)
463 | train_datasets.append(train)
464 | test_datasets.append(test)
465 | train_dataset = ConcatDataset(train_datasets)
466 | test_dataset = ConcatDataset(test_datasets)
467 | print("****Train dataset size***: ", len(train_dataset))
468 | print("***Test dataset size***: ", len(test_dataset))
469 | if ntrain is not None:
470 | train_dataset = random_split(
471 | train_dataset, [ntrain, len(train_dataset) - ntrain])[0]
472 | if ntest is not None:
473 | test_dataset = random_split(
474 | test_dataset, [ntest, len(test_dataset) - ntest])[0]
475 |
476 | train_dataloader = DataLoader(
477 | train_dataset, batch_size=batch_size, **data_loader_kwargs)
478 | test_dataloader = DataLoader(
479 | test_dataset, batch_size=batch_size, **data_loader_kwargs)
480 |
481 | return train_dataloader, test_dataloader
482 |
483 | def get_tensor_dataset(
484 | self,
485 | mu,
486 | dt,
487 | normalize=True,
488 | min_max_normalize=False,
489 | train_test_split=0.2,
490 | sample_per_inlet=200,
491 | x1_list=None,
492 | x2_list=None):
493 |
494 | if x1_list is None:
495 | x1_list = self._x1
496 | if x2_list is None:
497 | x2_list = self._x2
498 | train_datasets = []
499 | test_datasets = []
500 | # for the given mu
501 | # loop over all given inlets
502 | for x1 in x1_list:
503 | for x2 in x2_list:
504 | try:
505 | if mu == 0.5:
506 | velocities, pressure, displacements = self.get_data_txt(
507 | mu, x1, x2)
508 | else:
509 | velocities, pressure, displacements = self.get_data(
510 | mu, x1, x2)
511 | except FileNotFoundError as e:
512 | print(e)
513 | print(
514 | f"Original file not found for mu={mu}, x1={x1}, x2={x2}")
515 | continue
516 |
517 |                 # keeping vx, vy, P, dx, dy
518 |                 variable_indices = [0, 1, 3, 4, 5]
519 |                 if mu == 0.5:
520 |                     combined = torch.cat(
521 |                         [velocities, pressure, displacements], dim=-1)[:sample_per_inlet, :, :]
522 |                 else:
523 |                     combined = torch.cat(
524 |                         [velocities, pressure, displacements], dim=-1)[:sample_per_inlet, :, variable_indices]
525 |
526 | if hasattr(
527 | self.params,
528 | 'sub_sample_size') and self.params.sub_sample_size is not None:
529 | mesh_size = combined.shape[1]
530 |                     indices = list(range(mesh_size))
531 |                     np.random.seed(self.params.random_seed)
532 |                     sub_indices = np.random.choice(
533 |                         indices, self.params.sub_sample_size, replace=False)
534 |                     combined = combined[:, sub_indices, :]
535 |
536 | if self.params.super_resolution:
537 | new_quieries = self.get_data_txt(
538 | mu, x1, x2).to(dtype=combined.dtype)
539 | new_quieries = new_quieries[:sample_per_inlet, :]
540 |
541 | print("shape of old data", combined.shape)
542 | print("shape of new data", new_quieries.shape)
543 |
544 | combined = torch.cat([combined, new_quieries], dim=-2)
545 | print("shape of combined data", combined.shape)
546 |
547 | step_t0 = combined[:-dt, ...]
548 | step_t1 = combined[dt:, ...]
549 |
550 |                 indices = list(range(step_t0.shape[0]))
551 |
552 |                 ntrain = int((1 - train_test_split) * len(indices))
553 |                 ntest = len(indices) - ntrain
554 |
555 |                 random.shuffle(indices)
556 |                 train_t0 = step_t0[indices[:ntrain]]
557 |                 test_t0 = step_t0[indices[ntrain:ntrain + ntest]]
558 |                 train_t1 = step_t1[indices[:ntrain]]
559 |                 test_t1 = step_t1[indices[ntrain:ntrain + ntest]]
560 |
561 | if not normalize:
562 | normalizer = None
563 | else:
564 | if not min_max_normalize:
565 | mean, var = torch.mean(train_t0, dim=(
566 | 0, 1)), torch.var(train_t0, dim=(0, 1))**0.5
567 | else:
568 | mean = torch.min(
569 | train_t0.view(-1, train_t0.shape[-1]), dim=0)[0]
570 | var = torch.max(train_t0.view(-1,
571 | train_t0.shape[-1]),
572 | dim=0)[0] - torch.min(train_t0.view(-1,
573 | train_t0.shape[-1]),
574 | dim=0)[0]
575 |
576 | normalizer = Normalizer(mean, var)
577 |
578 | train_datasets.append(
579 | IrregularMeshTensorDataset(
580 | train_t0,
581 | train_t1,
582 | normalizer,
583 | normalizer,
584 | x1=x1,
585 | x2=x2,
586 | mu=mu,
587 | equation=self.equation,
588 | mesh=self.input_mesh))
589 | test_datasets.append(
590 | IrregularMeshTensorDataset(
591 | test_t0,
592 | test_t1,
593 | normalizer,
594 | normalizer,
595 | x1=x1,
596 | x2=x2,
597 | mu=mu,
598 | equation=self.equation,
599 | mesh=self.input_mesh))
600 |
601 | return ConcatDataset(train_datasets), ConcatDataset(test_datasets)
602 |
--------------------------------------------------------------------------------
/data_utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import sys
4 | from typing import Optional, Tuple
5 | from utils import *
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset
10 |
11 | sys.path.append('..')
12 |
13 |
14 | class ResizeDataset(Dataset):
15 | def __init__(self, parent_dataset, resolution):
16 | self.parent_dataset = parent_dataset
17 | self.resolution = resolution
18 |
19 | def __len__(self,):
20 | return self.parent_dataset.__len__()
21 |
22 | def __getitem__(self, global_idx):
23 | x, y = self.parent_dataset.__getitem__(global_idx)
24 | xp = F.interpolate(
25 | x[None, ...] if len(x.shape) < 4 else x,
26 | size=self.resolution,
27 | mode='bicubic',
28 | align_corners=True,
29 | )
30 | yp = F.interpolate(
31 | y[None, ...] if len(y.shape) < 4 else y,
32 | size=self.resolution,
33 | mode='bicubic',
34 | align_corners=True,
35 | )
36 | return torch.squeeze(xp), torch.squeeze(yp)
37 |
38 |
39 | class MaskerUniform:
40 | """Performs masking on datasets with regular meshes.
41 |
42 | For masking with data points from data sets with irregular meshes,
43 | use ``MaskerNonuniformMesh`` (below).
44 | """
45 |
46 | def __init__(
47 | self,
48 | drop_type='zeros',
49 | max_block=0.7,
50 | drop_pix=0.3,
51 | channel_per=0.5,
52 | channel_drop_per=0.2,
53 | device='cpu',
54 | min_block=30,
55 | ):
56 | self.drop_type = drop_type
57 | self.max_block = max_block
58 | self.drop_pix = drop_pix
59 | self.channel_per = channel_per
60 | self.channel_drop_per = channel_drop_per
61 | self.device = device
62 | self.min_block = min_block
63 |
64 | def __call__(self, size):
65 | """Returns a mask to be multiplied into a data tensor.
66 |
67 | Generates a binary mask of 0s and 1s to be point-wise multiplied into a
68 | data tensor to create a masked sample. By training on masked data, we
69 | expect the model to be resilient to missing data.
70 |
71 | TODO arg ``max_block_sz`` is outdated
72 | Parameters
73 | ---
74 | max_block_sz: percentage of the maximum block to be dropped
75 | """
76 |
77 | np.random.seed()
78 | C, H, W = size
79 | mask = torch.ones(size, device=self.device)
80 | drop_t = self.drop_type
81 | if self.channel_per == 1.0:
82 | augmented_channels = [i for i in range(C)]
83 | else:
84 | augmented_channels = np.random.choice(
85 | C, math.ceil(C * self.channel_per))
86 | drop_len = int(self.channel_drop_per * math.ceil(C * self.channel_per))
87 | mask[augmented_channels[:drop_len], :, :] = 0.0
88 | for i in augmented_channels[drop_len:]:
89 | n_drop_pix = self.drop_pix * H * W
90 | mx_blk_height = int(H * self.max_block)
91 | mx_blk_width = int(W * self.max_block)
92 |
93 |             repetition = 3
94 |             while n_drop_pix > 0 and repetition > 0:
95 |                 rnd_r = random.randint(0, H - 2)
96 |                 rnd_c = random.randint(0, W - 2)
97 |
98 |                 rnd_h = min(
99 |                     random.randint(self.min_block, mx_blk_height),
100 |                     H - rnd_r
101 |                 )
102 |                 rnd_w = min(
103 |                     random.randint(self.min_block, mx_blk_width),
104 |                     W - rnd_c
105 |                 )
106 |                 mask[i, rnd_r:rnd_r + rnd_h, rnd_c:rnd_c + rnd_w] = 0
107 |                 n_drop_pix -= rnd_h * rnd_w  # area of the dropped block
108 |                 repetition -= 1
109 | return None, mask
110 |
111 |
112 | def batched_masker(
113 | data: torch.Tensor, # with a batch dimension
114 | aug, # instance of a Masker like above
115 | batched_channels: Optional[Tuple[Tuple[int]]] = None,
116 | ):
117 | mask = []
118 | aug.device = data.device
119 | for i in range(data.shape[0]):
120 | _, m = aug(data[i].shape)
121 | mask.append(m)
122 |
123 | masks = torch.stack(mask, dim=0)
124 | return data * masks, masks
125 |
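# A minimal sketch (not part of the original module): apply MaskerUniform to a
# random batch. `min_block` is lowered from its default so that dropped blocks
# fit inside a 64x64 grid; all shapes here are illustrative.
def _demo_batched_masker():
    masker = MaskerUniform(max_block=0.3, drop_pix=0.5, channel_per=0.5,
                           channel_drop_per=0.5, min_block=4)
    batch = torch.randn(8, 3, 64, 64)
    masked, masks = batched_masker(batch, masker)
    # masked == batch * masks; both keep the batch's shape
    print(masked.shape, masks.shape)
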
126 |
127 | # Nearest neighbour type masking for non-uniform
128 | # mesh
129 |
130 | class MaskerNonuniformMesh(object):
131 | def __init__(
132 | self,
133 | grid_non_uni,
134 | gird_uni,
135 | radius,
136 | drop_type='zeros',
137 | drop_pix=0.3,
138 | channel_aug_rate=0.7,
139 | channel_drop_rate=0.2,
140 | device='cpu',
141 | max_block=10,
142 | verbose=None
143 | ):
144 | """
145 | Parameters
146 | ---
147 | drop_type : dropped pixels are filled with zeros
148 | drop_pix: Percentage of pixels to be dropped
149 | channel_aug_rate: Percentage of channels to be augmented
150 |         channel_drop_rate: Percentage of the augmented (channel_aug_rate) channels to be masked completely (drop all pixels)
151 | max_block: Maximum number of regions to be dropped.
152 | """
153 | self.grid_non_uni = grid_non_uni
154 | self.grid_uni = gird_uni
155 | dists = torch.cdist(gird_uni, grid_non_uni).to(
156 | gird_uni.device) # shaped num query points x num data points
157 | self.in_nbr = torch.where(dists <= radius, 1., 0.).long()
158 |
159 | self.drop_type = drop_type
160 | self.drop_pix = drop_pix
161 | self.channel_aug_rate = channel_aug_rate
162 | self.channel_drop_rate = channel_drop_rate
163 | self.device = device
164 | self.max_block = max_block
165 | self.verbose = verbose
166 |
167 | def __call__(self, size):
168 |
169 | L, C = size
170 | mask = torch.ones(size, device=self.device)
171 | drop_t = self.drop_type # no effect now
172 |
173 | augmented_channels = np.random.choice(
174 | C, math.ceil(C * self.channel_aug_rate))
175 |
176 |         p = random.random()
177 |         # with 50% probability, either drop all pixels in a few channels
178 |         # or mask random regions in many channels
179 | if p < 0.5:
180 | drop_len = int(
181 | self.channel_drop_rate *
182 | math.ceil(
183 | C *
184 | self.channel_aug_rate))
185 | mask[:, augmented_channels[:drop_len]] = 0.0
186 | return None, mask
187 | else:
188 | drop_len = 0
189 | # print('masking', augmented_channels[drop_len:])
190 | for i in augmented_channels[drop_len:]:
191 | n_drop_pix = self.drop_pix * L
192 | max_location = self.max_block
193 | while n_drop_pix > 0:
194 | # python random is inclusive of low and high
195 | j = random.randint(0, self.in_nbr.shape[0] - 1)
196 | mask[self.in_nbr[j] == 1, i] = 0
197 | n_drop_pix -= sum(self.in_nbr[j]).float()
198 | max_location -= 1
199 | if max_location == 0:
200 | break
201 | return None, mask
202 |
203 |
204 | def get_meshes(params, grid_size):
205 | mesh = get_mesh(params)
206 | input_mesh = torch.from_numpy(mesh).type(torch.float).cuda()
207 |
208 | minx, maxx = np.min(mesh[:, 0]), np.max(mesh[:, 0])
209 | miny, maxy = np.min(mesh[:, 1]), np.max(mesh[:, 1])
210 |
211 | size_x, size_y = grid_size
212 | idx_x = torch.arange(start=minx,
213 | end=maxx + (maxx - minx) / size_x - 1e-5,
214 | step=(maxx - minx) / (size_x - 1))
215 | idx_y = torch.arange(start=miny,
216 | end=maxy + (maxy - miny) / size_y - 1e-5,
217 | step=(maxy - miny) / (size_y - 1))
218 | x, y = torch.meshgrid(idx_x, idx_y, indexing='ij')
219 | output_mesh = torch.transpose(torch.stack(
220 | [x.flatten(), y.flatten()]), 0, 1).type(torch.float).cuda()
221 |
222 | return input_mesh, output_mesh
223 |
224 | def get_mesh_displacement(x):
225 | """
226 | Only returns the displacement field.
227 | """
228 | return x[:, -2:].clone().detach()
229 |
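
# A minimal sketch (not part of the original module): mask variables living on
# an irregular mesh. Both point clouds and the radius are illustrative.
def _demo_nonuniform_masker():
    non_uni = torch.rand(500, 2)   # irregular mesh points
    uni = torch.rand(100, 2)       # uniform anchor points used to pick regions
    masker = MaskerNonuniformMesh(non_uni, uni, radius=0.1,
                                  drop_pix=0.3, channel_aug_rate=0.7,
                                  channel_drop_rate=0.2, max_block=10)
    _, mask = masker((500, 5))     # (num mesh points, num variables)
    print(mask.shape)              # torch.Size([500, 5])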
--------------------------------------------------------------------------------
/exps_FSI.sh:
--------------------------------------------------------------------------------
1 | # SSL pretraining
2 | python main.py --exp FSI --config codano_gno_NS_ES --ntrain 8000
3 | python main.py --exp FSI --config codano_gno_NS --ntrain 8000
4 |
5 |
6 | # Fine-tuning
7 | ## Re = 400
8 | python main.py --exp FSI --config ft_NSES_NSES_5 --ntrain 5 --epochs 50 --scheduler_step 10
9 | python main.py --exp FSI --config ft_NSES_NSES_5 --ntrain 25
10 | python main.py --exp FSI --config ft_NSES_NSES_5 --ntrain 100
11 |
12 | python main.py --exp FSI --config ft_NS_NSES_5 --ntrain 5 --epochs 50 --scheduler_step 10
13 | python main.py --exp FSI --config ft_NS_NSES_5 --ntrain 25
14 | python main.py --exp FSI --config ft_NS_NSES_5 --ntrain 100
15 |
16 | python main.py --exp FSI --config ft_NSES_NS_5 --ntrain 5 --epochs 50 --scheduler_step 10
17 | python main.py --exp FSI --config ft_NSES_NS_5 --ntrain 25
18 | python main.py --exp FSI --config ft_NSES_NS_5 --ntrain 100
19 |
20 | python main.py --exp FSI --config ft_NS_NS_5 --ntrain 5 --epochs 50 --scheduler_step 10
21 | python main.py --exp FSI --config ft_NS_NS_5 --ntrain 25
22 | python main.py --exp FSI --config ft_NS_NS_5 --ntrain 100
23 |
24 | # ## Re = 4000
25 | python main.py --exp FSI --config ft_NSES_NSES_0.5 --ntrain 5
26 | python main.py --exp FSI --config ft_NSES_NSES_0.5 --ntrain 25
27 | python main.py --exp FSI --config ft_NSES_NSES_0.5 --ntrain 100
28 |
29 | python main.py --exp FSI --config ft_NS_NSES_0.5 --ntrain 5
30 | python main.py --exp FSI --config ft_NS_NSES_0.5 --ntrain 25
31 | python main.py --exp FSI --config ft_NS_NSES_0.5 --ntrain 100
32 |
33 |
--------------------------------------------------------------------------------
/exps_RB.sh:
--------------------------------------------------------------------------------
1 | # SSL pretraining
2 | python main.py --exp RB --config codano_big --ntrain 40000
3 |
4 | # Finetuning
5 | python main.py --exp RB --config ft_codano_RB --ntrain 5 --epochs 50 --batch_size 1
6 | python main.py --exp RB --config ft_codano_RB --ntrain 10 --epochs 50 --batch_size 2
7 | python main.py --exp RB --config ft_codano_RB --ntrain 25 --epochs 50 --batch_size 5
8 |
9 | python main.py --exp RB --config unet --ntrain 5 --epochs 50 --batch_size 1
10 | python main.py --exp RB --config unet --ntrain 10 --epochs 50 --batch_size 2
11 | python main.py --exp RB --config unet --ntrain 25 --epochs 50 --batch_size 5
12 |
13 | python main.py --exp RB --config fno --ntrain 5 --epochs 50 --batch_size 1
14 | python main.py --exp RB --config fno --ntrain 10 --epochs 50 --batch_size 2
15 | python main.py --exp RB --config fno --ntrain 25 --epochs 50 --batch_size 5
16 |
17 | python main.py --exp RB --config codano_RB --ntrain 5 --epochs 50 --batch_size 1
18 | python main.py --exp RB --config codano_RB --ntrain 10 --epochs 50 --batch_size 2
19 | python main.py --exp RB --config codano_RB --ntrain 25 --epochs 50 --batch_size 5
--------------------------------------------------------------------------------
/fsi_animation_dx.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/fsi_animation_dx.gif
--------------------------------------------------------------------------------
/fsi_animation_pressue.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/fsi_animation_pressue.gif
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/layers/__init__.py
--------------------------------------------------------------------------------
/layers/codano_block_2D.py:
--------------------------------------------------------------------------------
1 |
2 | from functools import partial
3 | import logging
4 | import numpy as np
5 | from einops import rearrange
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 | from neuralop.layers.fno_block import FNOBlocks
10 | from .fino_2D import SpectralConvKernel2d
11 |
12 | # Implementation of 2D CoDA-NO blocks
13 |
14 |
15 | def NO_OP(x, *_args, **_kwargs):
16 | return x
17 |
18 |
19 | AffineNormalizer2D = partial(nn.InstanceNorm2d, affine=True)
20 |
21 |
22 | class CodanoBlocks(nn.Module):
23 | def __init__(
24 | self,
25 | n_modes,
26 | n_head=1,
27 | token_codimension=1,
28 | output_scaling_factor=None,
29 | incremental_n_modes=None,
30 | head_codimension=None,
31 | use_mlp=False,
32 | mlp=None,
33 | non_linearity=F.gelu,
34 | preactivation=False,
35 | fno_skip='linear',
36 | mlp_skip='soft-gating',
37 | mlp_expansion=1.0,
38 | separable=False,
39 | factorization='tucker',
40 | rank=1.0,
41 | SpectralConvolution=None,
42 | Normalizer=None,
43 | joint_factorization=False,
44 | fixed_rank_modes=False,
45 | implementation='factorized',
46 | decomposition_kwargs=None,
47 | fft_norm='forward',
48 | codimension_size=None,
49 | per_channel_attention=True,
50 | permutation_eq=True,
51 | temperature=1.0,
52 | kqv_non_linear=False,
53 | **_kwargs,
54 | ):
55 | super().__init__()
56 |
57 | # Co-dimension of each variable/token. The token embedding space is
58 | # identical to the variable space, so their dimensionalities are equal.
59 | self.variable_codimension = token_codimension
60 | self.token_codimension = token_codimension
61 |
62 |         # codimension of each attention head's output (may be recomputed below)
63 | self.head_codimension = (head_codimension
64 | if head_codimension is not None
65 | else token_codimension)
66 | self.n_head = n_head # number of heads
67 | self.output_scaling_factor = output_scaling_factor
68 | self.temperature = temperature
69 |
70 | # K,Q,V operator with or without non_lin
71 | if kqv_non_linear:
72 | kqv_activation = non_linearity
73 | else:
74 | kqv_activation = NO_OP
75 |
76 | self.permutation_eq = permutation_eq
77 |
78 | if self.n_head is not None:
79 | # recalculating the value of `head_codim`
80 | self.head_codimension = max(token_codimension // self.n_head, 1)
81 |
82 | self.codimension_size = codimension_size
83 | self.mixer_token_codimension = token_codimension
84 |
85 | if per_channel_attention:
86 | # for per channel attention, forcing the values of token dims
87 | self.token_codimension = 1
88 | self.head_codimension = 1
89 |
90 |         # this scale is used for downsampling the Q and K functions
91 | scale = 2 if per_channel_attention else 1
92 | scale = min(self.n_head, scale)
93 |
94 | mixer_modes = [i // scale for i in n_modes]
95 |
96 |
97 | if decomposition_kwargs is None:
98 | decomposition_kwargs = {}
99 | common_args = dict(
100 | use_mlp=use_mlp,
101 | mlp=mlp,
102 | preactivation=preactivation,
103 | mlp_skip=mlp_skip,
104 | mlp_dropout=0,
105 | incremental_n_modes=incremental_n_modes,
106 | rank=rank,
107 | fft_norm=fft_norm,
108 | mlp_expansion=mlp_expansion,
109 | fixed_rank_modes=fixed_rank_modes,
110 | implementation=implementation,
111 | separable=separable,
112 | factorization=factorization,
113 | decomposition_kwargs=decomposition_kwargs,
114 | joint_factorization=joint_factorization,
115 | )
116 |
117 | kqv_args = dict(
118 | in_channels=self.token_codimension,
119 | out_channels=self.n_head * self.head_codimension,
120 | n_modes=mixer_modes,
121 | # args below are shared with Projection block
122 | non_linearity=kqv_activation,
123 | fno_skip='linear',
124 | norm=None,
125 | apply_skip=True,
126 | n_layers=1,
127 | )
128 | self.K = FNOBlocks(
129 | output_scaling_factor=1 / scale,
130 | SpectralConv=partial(
131 | SpectralConvolution,
132 | rank=0.5,
133 | factorization=None,
134 | ),
135 | **kqv_args,
136 | **common_args,
137 | )
138 | self.Q = FNOBlocks(
139 | output_scaling_factor=1 / scale,
140 | SpectralConv=partial(
141 | SpectralConvolution,
142 | rank=0.5,
143 | factorization=None,
144 | ),
145 | **kqv_args,
146 | **common_args,
147 | )
148 | self.V = FNOBlocks(
149 | output_scaling_factor=1,
150 | SpectralConv=partial(
151 | SpectralConvolution,
152 | rank=0.5,
153 | factorization=None,
154 | ),
155 | **kqv_args,
156 | **common_args,
157 | )
158 |
159 | if self.n_head * self.head_codimension != self.token_codimension:
160 | self.proj = FNOBlocks(
161 | in_channels=self.n_head * self.head_codimension,
162 | out_channels=self.token_codimension,
163 | n_modes=n_modes,
164 | output_scaling_factor=1,
165 | # args below are shared with KQV blocks
166 | apply_skip=True,
167 | non_linearity=NO_OP,
168 | fno_skip='linear',
169 | norm=None,
170 | SpectralConv=partial(
171 | SpectralConvolution,
172 | rank=0.5,
173 | factorization=None,
174 | ),
175 | n_layers=1,
176 | **common_args,
177 | )
178 | else:
179 | self.proj = None
180 |
181 | self.attention_normalizer = Normalizer(self.token_codimension)
182 |
183 | mixer_args = dict(
184 | n_modes=n_modes,
185 | output_scaling_factor=1,
186 | non_linearity=non_linearity,
187 | norm='instance_norm',
188 | fno_skip=fno_skip,
189 | SpectralConv=partial(
190 | SpectralConvolution,
191 | rank=0.5,
192 | factorization=None,
193 | bias=True,
194 | ),
195 | )
196 | # We have an option to make the last operator (MLP in regular
197 | # Transformer block) permutation equivariant. i.e., applying the
198 | # operator per variable or applying the operator on the whole channel
199 | # (like regular FNO).
200 | if permutation_eq:
201 | self.mixer = FNOBlocks(
202 | in_channels=self.mixer_token_codimension,
203 | out_channels=self.mixer_token_codimension,
204 | apply_skip=True,
205 | n_layers=2,
206 | **mixer_args,
207 | **common_args,
208 | )
209 | self.norm1 = Normalizer(self.token_codimension)
210 | self.norm2 = Normalizer(self.mixer_token_codimension)
211 | self.mixer_out_normalizer = Normalizer(
212 | self.mixer_token_codimension)
213 |
214 | else:
215 | self.mixer = FNOBlocks(
216 | in_channels=codimension_size,
217 | out_channels=codimension_size,
218 | n_layers=2,
219 | **mixer_args,
220 | **common_args,
221 | )
222 | self.norm1 = Normalizer(codimension_size)
223 | self.norm2 = Normalizer(codimension_size)
224 | self.mixer_out_normalizer = Normalizer(codimension_size)
225 |
226 | def forward(self, *args):
227 | raise NotImplementedError(
228 |             "Use a proper subclass of CodanoBlocks (e.g. CodanoBlocks2d).")
229 |
230 |
231 | class CodanoBlocks2d(CodanoBlocks):
232 | def __init__(self, *args, **kwargs):
233 | Normalizer = kwargs.get("Normalizer")
234 | if Normalizer is None:
235 | Normalizer = AffineNormalizer2D
236 | kwargs["Normalizer"] = Normalizer
237 |
238 | Convolution = kwargs.get("SpectralConvolution")
239 | if Convolution is None:
240 | Convolution = SpectralConvKernel2d
241 | kwargs["SpectralConvolution"] = Convolution
242 |
243 | super().__init__(*args, **kwargs)
244 |
245 | # XXX rewrite comments on TNO*3D
246 | def compute_attention(self, xa, batch_size):
247 | """Compute the key-query-value variant of the attention matrix.
248 |
249 | Assumes input ``xa`` has been normalized.
250 | """
251 | k = self.K(xa)
252 | q = self.Q(xa)
253 | v = self.V(xa)
254 |
255 | value_x, value_y = v.shape[-2], v.shape[-1]
256 |
257 | rearrangement = dict(
258 | pattern='(b t) (a d) h w -> b a t (d h w)',
259 | b=batch_size,
260 | a=self.n_head,
261 | )
262 | k = rearrange(k, **rearrangement)
263 | q = rearrange(q, **rearrangement)
264 | v = rearrange(v, **rearrangement)
265 |
266 | dprod = (torch.matmul(q, k.transpose(-1, -2)) /
267 | (np.sqrt(k.shape[-1]) * self.temperature))
268 | dprod = F.softmax(dprod, dim=-1)
269 |
270 | attention = torch.matmul(dprod, v)
271 | attention = rearrange(
272 | attention,
273 | 'b a t (d h w) -> b t a d h w',
274 | d=self.head_codimension,
275 | h=value_x,
276 | w=value_y,
277 | )
278 | attention = rearrange(attention, 'b t a d h w -> (b t) (a d) h w')
279 | return attention
280 |
281 | def forward(self, x, output_shape=None):
282 | if self.permutation_eq:
283 | return self._forward_equivariant(x)
284 | else:
285 | return self._forward_non_equivariant(x)
286 |
287 | def _forward_equivariant(self, x):
288 | batch_size = x.shape[0]
289 | output_shape = x.shape[-2:]
290 |
291 | assert x.shape[1] % self.token_codimension == 0
292 |
293 | xa = rearrange(x, 'b (t d) h w -> (b t) d h w',
294 | d=self.token_codimension)
295 | xa_norm = self.norm1(xa)
296 |
297 | attention = self.compute_attention(xa_norm, batch_size)
298 | if self.proj is not None:
299 | attention = self.proj(attention)
300 |
301 | attention = self.attention_normalizer(attention + xa)
302 | attention = rearrange(
303 | attention, '(b t) d h w -> b (t d) h w', b=batch_size)
304 |
305 | attention = rearrange(
306 | attention,
307 | 'b (t d) h w -> (b t) d h w',
308 | d=self.mixer_token_codimension)
309 |
310 |
311 | attention_normalized = self.norm2(attention)
312 | output = self.mixer(attention_normalized, output_shape=output_shape)
313 |
314 | output = self.mixer_out_normalizer(output) + attention
315 |
316 | output = rearrange(output, '(b t) d h w -> b (t d) h w', b=batch_size)
317 |
318 | return output
319 |
320 | def _forward_non_equivariant(self, x):
321 | batch_size = x.shape[0]
322 | output_shape = x.shape[-2:]
323 |
324 | assert x.shape[1] % self.token_codimension == 0
325 |
326 | x_norm = self.norm1(x)
327 | xa = rearrange(x_norm, 'b (t d) h w -> (b t) d h w',
328 | d=self.token_codimension)
329 |
330 | attention = self.compute_attention(xa, batch_size)
331 | if self.proj is not None:
332 | attention = self.proj(attention)
333 |
334 | attention = rearrange(
335 | attention, '(b t) d h w -> b (t d) h w', b=batch_size)
336 | attention_normalized = self.norm2(attention)
337 | output = self.mixer(attention_normalized, output_shape=output_shape)
338 |
339 | return output
340 |
--------------------------------------------------------------------------------
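The heart of `compute_attention` above is ordinary scaled dot-product attention applied to whole token functions: each variable's discretized field is flattened into one key/query/value vector per head. Below is a minimal sketch of that tensor manipulation, with random tensors standing in for the K/Q/V FNO-block outputs (all sizes and names are illustrative, not part of the repo's API):

```python
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange

# Batch b, tokens t (one per variable), heads a, per-head codimension d,
# spatial grid h x w; the real code gets k, q, v from FNO blocks.
b, t, a, d, h, w = 2, 3, 2, 1, 16, 16
k = torch.randn(b * t, a * d, h, w)
q = torch.randn(b * t, a * d, h, w)
v = torch.randn(b * t, a * d, h, w)

# Flatten each token function into a single vector per head.
pattern = '(b t) (a d) h w -> b a t (d h w)'
k, q, v = (rearrange(x, pattern, b=b, a=a) for x in (k, q, v))

# Scaled dot-product attention between whole functions (temperature = 1).
dprod = F.softmax(
    torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(k.shape[-1]), dim=-1)
attention = torch.matmul(dprod, v)

# Fold the heads back into codimension channels on the spatial grid.
attention = rearrange(
    attention, 'b a t (d h w) -> (b t) (a d) h w', d=d, h=h, w=w)
print(attention.shape)  # torch.Size([6, 2, 16, 16])
```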
/layers/codano_block_nd.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import logging
3 | from typing import Optional, Callable, Union, Dict
4 |
5 | import numpy as np
6 | from einops import rearrange
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from neuralop.layers.fno_block import FNOBlocks
11 | from .fino_nd import SpectralConvKernel2d, SpectralConvKernel1d, SpectralConvKernel3d
12 |
13 |
14 | # Implementation of a generic N-dimensional CoDA-NO block
15 | AffineNormalizer1D = partial(nn.InstanceNorm1d, affine=True)
16 | AffineNormalizer2D = partial(nn.InstanceNorm2d, affine=True)
17 | AffineNormalizer3D = partial(nn.InstanceNorm3d, affine=True)
18 | # For higher dimensions (>=4), a custom normalizer needs to be implemented
19 |
20 | def Identity(x, *args, **kwargs):
21 | return x
22 |
23 |
24 | class CodanoBlock(nn.Module):
25 | def __init__(
26 | self,
27 | n_modes: Union[int, tuple],
28 | n_head: int = 1,
29 | token_codimension: int = 1,
30 | output_scaling_factor: Optional[float] = None,
31 | max_n_modes: Optional[Union[int, tuple]] = None,
32 | head_codimension: Optional[int] = None,
33 | use_mlp: bool = False,
34 | mlp: Optional[nn.Module] = None,
35 | non_linearity: Callable = F.gelu,
36 | preactivation: bool = False,
37 | fno_skip: str = 'linear',
38 | mlp_skip: str = 'linear',
39 | mlp_expansion: float = 1.0,
40 | separable: bool = False,
41 | factorization: str = None,
42 | rank: float = 1.0,
43 | SpectralConvolution: Optional[Callable] = None,
44 | Normalizer: Optional[Callable] = None,
45 | joint_factorization: bool = False,
46 | fixed_rank_modes: bool = False,
47 | implementation: str = 'reconstructed',
48 | decomposition_kwargs: Optional[Dict] = None,
49 | fft_norm: str = 'forward',
50 | codimension_size: Optional[int] = None,
51 | per_channel_attention: bool = True,
52 | permutation_eq: bool = True,
53 | temperature: float = 1.0,
54 | kqv_non_linear: bool = False,
55 | num_dims: int = 2,
56 | **_kwargs,
57 | ):
58 | super().__init__()
59 |
60 | self.variable_codimension = token_codimension
61 | self.token_codimension = token_codimension
62 | self.head_codimension = head_codimension or token_codimension
63 | self.n_head = n_head
64 | self.output_scaling_factor = output_scaling_factor
65 | self.temperature = temperature
66 | self.num_dims = num_dims
67 |
68 | if kqv_non_linear:
69 | kqv_activation = non_linearity
70 | else:
71 | kqv_activation = Identity
72 |
73 | self.permutation_eq = permutation_eq
74 |
75 | if self.n_head is not None:
76 | self.head_codimension = max(token_codimension // self.n_head, 1)
77 |
78 | self.codimension_size = codimension_size
79 | self.mixer_token_codimension = token_codimension
80 |
81 | if per_channel_attention:
82 | self.token_codimension = 1
83 | self.head_codimension = 1
84 |
85 |         scale = min(self.n_head, 2 if per_channel_attention else 1)  # downsampling for Q, K
86 |
87 | mixer_modes = [i // scale for i in n_modes]
88 | mixer_n_modes = [i // scale for i in max_n_modes]
89 |
90 | decomposition_kwargs = decomposition_kwargs or {}
91 | common_args = dict(
92 | use_mlp=use_mlp,
93 | mlp=mlp,
94 | preactivation=preactivation,
95 | mlp_skip=mlp_skip,
96 | mlp_dropout=0,
97 | rank=rank,
98 | fft_norm=fft_norm,
99 | mlp_expansion=mlp_expansion,
100 | fixed_rank_modes=fixed_rank_modes,
101 | implementation=implementation,
102 | separable=separable,
103 | factorization=factorization,
104 | decomposition_kwargs=decomposition_kwargs,
105 | joint_factorization=joint_factorization,
106 | )
107 |
108 | kqv_args = dict(
109 | in_channels=self.token_codimension,
110 | out_channels=self.n_head * self.head_codimension,
111 | n_modes=mixer_modes,
112 | max_n_modes=mixer_n_modes,
113 | non_linearity=kqv_activation,
114 | fno_skip='linear',
115 | norm=None,
116 | apply_skip=True,
117 | n_layers=1,
118 | )
119 |
120 |         rank = 1.0  # note: overrides the ``rank`` argument for the K/Q/V and mixer convolutions
121 | conv_kwargs = dict(rank=rank, factorization=None)
122 | self.K = FNOBlocks(
123 | output_scaling_factor=1 / scale,
124 | SpectralConv=partial(
125 | SpectralConvolution,
126 | **conv_kwargs
127 | ),
128 | **kqv_args,
129 | **common_args,
130 | )
131 | self.Q = FNOBlocks(
132 | output_scaling_factor=1 / scale,
133 | SpectralConv=partial(
134 | SpectralConvolution,
135 | **conv_kwargs,
136 | ),
137 | **kqv_args,
138 | **common_args,
139 | )
140 | self.V = FNOBlocks(
141 | output_scaling_factor=1,
142 | SpectralConv=partial(
143 | SpectralConvolution,
144 | **conv_kwargs,
145 | ),
146 | **kqv_args,
147 | **common_args,
148 | )
149 |
150 | if self.n_head * self.head_codimension != self.token_codimension:
151 | self.proj = FNOBlocks(
152 | in_channels=self.n_head * self.head_codimension,
153 | out_channels=self.token_codimension,
154 | n_modes=n_modes,
155 | max_n_modes=max_n_modes,
156 | output_scaling_factor=1,
157 | apply_skip=True,
158 | non_linearity=Identity,
159 | fno_skip='linear',
160 | norm=None,
161 | SpectralConv=partial(
162 | SpectralConvolution,
163 | rank=1.0,
164 | factorization=None,
165 | ),
166 | n_layers=1,
167 | **common_args,
168 | )
169 | else:
170 | self.proj = None
171 |
172 | self.attention_normalizer = Normalizer(self.token_codimension)
173 |
174 | mixer_args = dict(
175 | n_modes=n_modes,
176 | max_n_modes=max_n_modes,
177 | output_scaling_factor=1,
178 | non_linearity=non_linearity,
179 | norm='instance_norm',
180 | fno_skip=fno_skip,
181 | SpectralConv=partial(
182 | SpectralConvolution,
183 | rank=rank,
184 | factorization=None,
185 | bias=True,
186 | ),
187 | )
188 |
189 | if self.permutation_eq:
190 | self.mixer = FNOBlocks(
191 | in_channels=self.mixer_token_codimension,
192 | out_channels=self.mixer_token_codimension,
193 | apply_skip=True,
194 | n_layers=2,
195 | **mixer_args,
196 | **common_args,
197 | )
198 | self.norm1 = Normalizer(self.token_codimension)
199 | self.norm2 = Normalizer(self.mixer_token_codimension)
200 | self.mixer_out_normalizer = Normalizer(
201 | self.mixer_token_codimension)
202 | else:
203 | self.mixer = FNOBlocks(
204 | in_channels=codimension_size,
205 | out_channels=codimension_size,
206 | n_layers=2,
207 | **mixer_args,
208 | **common_args,
209 | )
210 | self.norm1 = Normalizer(codimension_size)
211 | self.norm2 = Normalizer(codimension_size)
212 | self.mixer_out_normalizer = Normalizer(codimension_size)
213 |
214 | def forward(self, *args):
215 | raise NotImplementedError(
216 |             "Use a proper subclass of CodanoBlock (e.g. CodanoBlockND).")
217 |
218 |
219 | class CodanoBlockND(CodanoBlock):
220 | def __init__(self, *args, **kwargs):
221 | if kwargs["num_dims"] == 1:
222 | Normalizer = kwargs.get("Normalizer", AffineNormalizer1D)
223 | kwargs["Normalizer"] = Normalizer
224 |
225 | Convolution = kwargs.get(
226 | "SpectralConvolution", SpectralConvKernel1d)
227 | kwargs["SpectralConvolution"] = Convolution
228 | elif kwargs["num_dims"] == 2:
229 | Normalizer = kwargs.get("Normalizer", AffineNormalizer2D)
230 | kwargs["Normalizer"] = Normalizer
231 |
232 | Convolution = kwargs.get(
233 | "SpectralConvolution", SpectralConvKernel2d)
234 | kwargs["SpectralConvolution"] = Convolution
235 | elif kwargs["num_dims"] == 3:
236 | Normalizer = kwargs.get("Normalizer", AffineNormalizer3D)
237 | kwargs["Normalizer"] = Normalizer
238 |
239 | Convolution = kwargs.get(
240 | "SpectralConvolution", SpectralConvKernel3d)
241 | kwargs["SpectralConvolution"] = Convolution
242 |
243 | super().__init__(*args, **kwargs)
244 |
245 | def compute_attention(self, xa, batch_size):
246 | k = self.K(xa)
247 | q = self.Q(xa)
248 | v = self.V(xa)
249 |
250 | v_shape = v.shape[-self.num_dims:]
251 |
252 | rearrangement = dict(
253 | pattern=f'(b k) (a d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> b a k (d {" ".join(f"d{i}" for i in range(self.num_dims))})',
254 | b=batch_size,
255 | a=self.n_head,
256 | )
257 | k = rearrange(k, **rearrangement)
258 | q = rearrange(q, **rearrangement)
259 | v = rearrange(v, **rearrangement)
260 |
261 | dprod = torch.matmul(q, k.transpose(-1, -2))
262 | dprod = dprod / (self.temperature * np.sqrt(k.shape[-1]))
263 | dprod = F.softmax(dprod, dim=-1)
264 |
265 | attention = torch.matmul(dprod, v)
266 | rearrange_args = dict(
267 | pattern=f'b a k (d {" ".join(f"d{i}" for i in range(self.num_dims))}) -> b k a d {" ".join(f"d{i}" for i in range(self.num_dims))}',
268 | d=self.head_codimension,
269 | )
270 | rearrange_args.update(
271 | {f'd{i}': v_shape[i] for i in range(self.num_dims)})
272 | attention = rearrange(attention, **rearrange_args)
273 | attention = rearrange(
274 | attention,
275 | f'b k a d {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) (a d) {" ".join(f"d{i}" for i in range(self.num_dims))}')
276 | return attention
277 |
278 | def forward(self, x, output_shape=None):
279 | if self.permutation_eq:
280 | return self._forward_equivariant(x)
281 | else:
282 | return self._forward_non_equivariant(x)
283 |
284 | def _forward_equivariant(self, x):
285 | batch_size = x.shape[0]
286 | output_shape = x.shape[-self.num_dims:]
287 |
288 | assert x.shape[1] % self.token_codimension == 0
289 |
290 | xa = rearrange(
291 | x,
292 | f'b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) d {" ".join(f"d{i}" for i in range(self.num_dims))}',
293 | d=self.token_codimension)
294 | xa_norm = self.norm1(xa)
295 |
296 | attention = self.compute_attention(xa_norm, batch_size)
297 | if self.proj is not None:
298 | attention = self.proj(attention)
299 |
300 | attention = self.attention_normalizer(attention + xa)
301 | attention = rearrange(
302 | attention,
303 | f'(b k) d {" ".join(f"d{i}" for i in range(self.num_dims))} -> b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))}',
304 | b=batch_size)
305 | attention = rearrange(
306 | attention,
307 | f'b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) d {" ".join(f"d{i}" for i in range(self.num_dims))}',
308 | d=self.mixer_token_codimension)
309 |
310 | attention_normalized = self.norm2(attention)
311 | output = self.mixer(attention_normalized, output_shape=output_shape)
312 |
313 | output = self.mixer_out_normalizer(output) + attention
314 | output = rearrange(
315 | output,
316 | f'(b k) d {" ".join(f"d{i}" for i in range(self.num_dims))} -> b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))}',
317 | b=batch_size)
318 |
319 | return output
320 |
321 | def _forward_non_equivariant(self, x):
322 | batch_size = x.shape[0]
323 | output_shape = x.shape[-self.num_dims:]
324 |
325 | assert x.shape[1] % self.token_codimension == 0
326 |
327 | x_norm = self.norm1(x)
328 | xa = rearrange(
329 | x_norm,
330 | f'b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))} -> (b k) d {" ".join(f"d{i}" for i in range(self.num_dims))}',
331 | d=self.token_codimension)
332 |
333 | attention = self.compute_attention(xa, batch_size)
334 | if self.proj is not None:
335 | attention = self.proj(attention)
336 |
337 | attention = rearrange(
338 | attention,
339 | f'(b k) d {" ".join(f"d{i}" for i in range(self.num_dims))} -> b (k d) {" ".join(f"d{i}" for i in range(self.num_dims))}',
340 | b=batch_size)
341 | attention_normalized = self.norm2(attention)
342 | output = self.mixer(attention_normalized, output_shape=output_shape)
343 |
344 | return output
345 |
--------------------------------------------------------------------------------
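The N-dimensional block builds its `einops` patterns with f-strings so the same code serves 1D, 2D, and 3D grids. A quick way to see what those patterns expand to (pure string manipulation, mirroring the construction in `CodanoBlockND.compute_attention`):

```python
def kqv_pattern(num_dims: int) -> str:
    # Same construction as in CodanoBlockND.compute_attention.
    dims = " ".join(f"d{i}" for i in range(num_dims))
    return f'(b k) (a d) {dims} -> b a k (d {dims})'

print(kqv_pattern(2))  # (b k) (a d) d0 d1 -> b a k (d d0 d1)
print(kqv_pattern(3))  # (b k) (a d) d0 d1 d2 -> b a k (d d0 d1 d2)
```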
/layers/fino_2D.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | import torch_harmonics as th
8 | from neuralop.layers.spectral_convolution import SpectralConv
9 |
10 |
11 | class SpectralConvKernel2d(SpectralConv):
12 | """
13 | Parameters
14 | ---
15 | transform_type : {'sht', 'fft'}
16 | * If "sht" it uses the Spherical Fourier Transform.
17 | * If "fft" it uses the Fast Fourier Transform.
18 | """
19 |
20 | def __init__(
21 | self,
22 | in_channels,
23 | out_channels,
24 | n_modes,
25 | incremental_n_modes=None,
26 | bias=True,
27 | n_layers=1,
28 | separable=False,
29 | output_scaling_factor=None,
30 | rank=0.5,
31 | factorization='dense',
32 | implementation='reconstructed',
33 | fno_block_precision='full',
34 | fixed_rank_modes=False,
35 | joint_factorization=False,
36 | decomposition_kwargs=None,
37 | init_std='auto',
38 | fft_norm='forward',
39 | transform_type="sht",
40 | sht_nlat=180,
41 | sht_nlon=360,
42 | sht_grid="legendre-gauss",
43 | isht_grid="legendre-gauss",
44 | sht_norm="backward",
45 | frequency_mixer=False,
46 | verbose=True,
47 | logger=None
48 | ):
49 | self.verbose = verbose
50 |
51 | if decomposition_kwargs is None:
52 | decomposition_kwargs = {}
53 | super().__init__(
54 | in_channels,
55 | out_channels,
56 | n_modes,
57 | incremental_n_modes,
58 | bias=bias,
59 | n_layers=n_layers,
60 | separable=separable,
61 | output_scaling_factor=output_scaling_factor,
62 | fno_block_precision=fno_block_precision,
63 | rank=rank,
64 | factorization=factorization,
65 | implementation=implementation,
66 | fixed_rank_modes=fixed_rank_modes,
67 | joint_factorization=joint_factorization,
68 | decomposition_kwargs=decomposition_kwargs,
69 | init_std=init_std,
70 | fft_norm=fft_norm,
71 | )
72 |
73 | # self.shared = shared
74 |
75 | # readjusting initialization
76 | if init_std == "auto":
77 | init_std = 1 / np.sqrt(in_channels * n_modes[-1] * n_modes[-2])
78 |         # otherwise the numeric ``init_std`` provided by the caller is
79 |         # used as-is
80 |
81 | for w in self.weight:
82 | w.normal_(0, init_std)
83 |
84 | # weights for frequency mixers
85 |
86 | hm1, hm2 = self.half_n_modes[0], self.half_n_modes[1]
87 | if frequency_mixer:
88 |             # If ``frequency_mixer`` is True, initialize weights for
89 |             # frequency mixing; otherwise this is just a regular FNO or
90 |             # SFNO layer.
91 | print("using Mixer")
92 | self.W1r = nn.Parameter(
93 | torch.empty(
94 | hm1,
95 | hm2,
96 | hm1,
97 | hm2,
98 | dtype=torch.float))
99 | self.W2r = nn.Parameter(
100 | torch.empty(
101 | hm1,
102 | hm2,
103 | hm1,
104 | hm2,
105 | dtype=torch.float))
106 | self.W1i = nn.Parameter(
107 | torch.empty(
108 | hm1,
109 | hm2,
110 | hm1,
111 | hm2,
112 | dtype=torch.float))
113 | self.W2i = nn.Parameter(
114 | torch.empty(
115 | hm1,
116 | hm2,
117 | hm1,
118 | hm2,
119 | dtype=torch.float))
120 | self.reset_parameter()
121 |
122 | self.sht_grid = sht_grid
123 | self.isht_grid = isht_grid
124 | self.sht_norm = sht_norm
125 | self.transform_type = transform_type
126 | self.frequency_mixer = frequency_mixer
127 |
128 |
129 | if self.output_scaling_factor is not None:
130 | out_nlat = round(sht_nlat * self.output_scaling_factor[0][0])
131 | out_nlon = round(sht_nlon * self.output_scaling_factor[0][1])
132 | else:
133 | out_nlat = sht_nlat
134 | out_nlon = sht_nlon
135 |
136 | if self.transform_type == "sht":
137 | self.forward_sht = th.RealSHT(
138 | sht_nlat,
139 | sht_nlon,
140 | grid=self.sht_grid,
141 | norm=self.sht_norm,
142 | )
143 | self.inverse_sht = th.InverseRealSHT(
144 | out_nlat,
145 | out_nlon,
146 | grid=self.isht_grid,
147 | norm=self.sht_norm,
148 | )
149 |
150 | def reset_parameter(self):
151 | # Initial model parameters.
152 | scaling_factor = ((1 / self.in_channels)**0.5) / \
153 | (self.half_n_modes[0] * self.half_n_modes[1])
154 | torch.nn.init.normal_(self.W1r, mean=0.0, std=scaling_factor)
155 | torch.nn.init.normal_(self.W2r, mean=0.0, std=scaling_factor)
156 | torch.nn.init.normal_(self.W1i, mean=0.0, std=scaling_factor)
157 | torch.nn.init.normal_(self.W2i, mean=0.0, std=scaling_factor)
158 |
159 | @staticmethod
160 | def mode_mixer(x, weights):
161 | return torch.einsum("bimn,mnop->biop", x, weights)
162 |
163 | def forward_transform(self, x):
164 | height, width = x.shape[-2:]
165 | if self.transform_type == "fft":
166 | return torch.fft.rfft2(x.float(), norm=self.fft_norm)
167 |
168 | if self.transform_type == "sht":
169 | # The SHT is expensive to initialize, and during training we expect
170 | # the data to all be of the same shape. If we have a correct SHT,
171 | # let's use it:
172 | if (
173 | self.forward_sht.nlat == height and
174 | self.forward_sht.nlon == width
175 | ):
176 | return self.forward_sht(x.double()).to(dtype=torch.cfloat)
177 |
178 | # Otherwise, initialize a new SHT:
179 | self.forward_sht = th.RealSHT(
180 | height,
181 | width,
182 | grid=self.sht_grid,
183 | norm=self.sht_norm,
184 | ).to(x.device)
185 | return self.forward_sht(x.double()).to(dtype=torch.cfloat)
186 |
187 | raise ValueError(
188 | 'Expected `transform_type` to be one of "fft" or "sht"; '
189 | f'Got {self.transform_type=}'
190 | )
191 |
192 |     # A previous implementation kept an initialized ``th.InverseRealSHT``
193 |     # in its state, but it always checked whether the transform's lat/lon
194 |     # grid size matched the input's resolution, so caching the object
195 |     # never really mattered.
196 | def inverse_transform(
197 | self,
198 | x: torch.Tensor,
199 | target_height: int,
200 | target_width: int,
201 | device,
202 | ):
203 | source_height, source_width = x.shape[-2:]
204 | if self.transform_type == "fft":
205 | return torch.fft.irfft2(
206 | x,
207 | s=(target_height, target_width),
208 | dim=(-2, -1),
209 | norm=self.fft_norm,
210 | )
211 |
212 | if self.transform_type == "sht":
213 | # The SHT is expensive to initialize, and during training we expect
214 | # the data to all be of the same shape. If we have a correct SHT,
215 | # let's use it:
216 | if (
217 | self.inverse_sht.lmax == source_height and
218 | self.inverse_sht.mmax == source_width and
219 | self.inverse_sht.nlat == target_height and
220 | self.inverse_sht.nlon == target_width
221 | ):
222 | return self.inverse_sht(x.to(dtype=torch.cdouble)).float()
223 |
224 | # Otherwise, initialize a new SHT:
225 | self.inverse_sht = th.InverseRealSHT(
226 | target_height,
227 | target_width,
228 | lmax=source_height,
229 | mmax=source_width,
230 | grid=self.sht_grid,
231 | norm=self.sht_norm,
232 | ).to(device)
233 | return self.inverse_sht(x.to(dtype=torch.cdouble)).float()
234 |
235 | raise ValueError(
236 | 'Expected `transform_type` to be one of "fft" or "sht"; '
237 | f'Got {self.transform_type=}'
238 | )
239 |
240 | def forward(self, x, indices=0, output_shape=None):
241 | batch_size, channels, height, width = x.shape
242 |
243 | x = self.forward_transform(x)
244 |
245 | upper_modes = [
246 | slice(None),
247 | slice(None),
248 | slice(None, self.half_n_modes[0]),
249 | slice(None, self.half_n_modes[1]),
250 | ]
251 | """Slice for upper frequency modes.
252 |
253 | Equivalent to: ``x[:, :, :self.half_n_modes[0], :self.half_n_modes[1]]``
254 | """
255 |
256 | lower_modes = [
257 | slice(None),
258 | slice(None),
259 | slice(-self.half_n_modes[0], None),
260 | slice(None, self.half_n_modes[1]),
261 | ]
262 | """Slice for lower frequency modes.
263 |
264 | Equivalent to: ``x[:, :, -self.half_n_modes[0]:, :self.half_n_modes[1]]``
265 | """
266 |
267 | # mode mixer
268 | # uses separate MLP to mix mode along each co-dim/channels
269 | if self.frequency_mixer:
270 | W1 = self.W1r + 1.0j * self.W1i
271 | W2 = self.W2r + 1.0j * self.W2i
272 |
273 | x[upper_modes] = self.mode_mixer(x[upper_modes].clone(), W1)
274 | x[lower_modes] = self.mode_mixer(x[lower_modes].clone(), W2)
275 |
276 | # spectral conv / channel mixer
277 |
278 | # The output will be of size:
279 | # (batch_size, self.out_channels, x.size(-2), x.size(-1)//2 + 1)
280 | out_fft = torch.zeros(
281 | [batch_size, self.out_channels, height, width // 2 + 1],
282 | dtype=x.dtype,
283 | device=x.device,
284 | )
285 |
286 | # Upper block (truncate high frequencies):
287 | out_fft[upper_modes] = self._contract(
288 | x[upper_modes],
289 | self._get_weight(2 * indices),
290 | separable=self.separable,
291 | )
292 | # Lower block (truncate low frequencies):
293 | out_fft[lower_modes] = self._contract(
294 | x[lower_modes],
295 | self._get_weight(2 * indices + 1),
296 | separable=self.separable,
297 | )
298 |
299 | if self.output_scaling_factor is not None and output_shape is None:
300 | height = round(height * self.output_scaling_factor[indices][0])
301 | width = round(width * self.output_scaling_factor[indices][1])
302 |
303 | if output_shape is not None:
304 | height = output_shape[0]
305 | width = output_shape[1]
306 |
307 | x = self.inverse_transform(out_fft, height, width, x.device)
308 |
309 | if self.bias is not None:
310 | x = x + self.bias[indices, ...]
311 |
312 | return x
313 |
--------------------------------------------------------------------------------
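`SpectralConvKernel2d.forward` keeps only the lowest `half_n_modes` frequencies in the upper and lower corners of the rFFT output before transforming back. The truncation logic can be exercised with plain `torch.fft` calls; this is a sketch of the mode slicing only, without the learned contraction weights:

```python
import torch

x = torch.randn(1, 1, 32, 32)
hm0, hm1 = 8, 8  # stand-ins for half_n_modes

xf = torch.fft.rfft2(x, norm='forward')   # shape (1, 1, 32, 17)
out = torch.zeros_like(xf)

# Upper block: low non-negative frequencies along dim -2.
out[:, :, :hm0, :hm1] = xf[:, :, :hm0, :hm1]
# Lower block: low negative frequencies along dim -2.
out[:, :, -hm0:, :hm1] = xf[:, :, -hm0:, :hm1]

x_lowpass = torch.fft.irfft2(out, s=(32, 32), norm='forward')
print(x_lowpass.shape)  # torch.Size([1, 1, 32, 32])
```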
/layers/gnn_layer.py:
--------------------------------------------------------------------------------
1 | from baseline_utlis import FixedNeighborSearch
2 | from neuralop.layers.integral_transform import IntegralTransform
3 | from neuralop.layers.mlp import MLPLinear
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class GnnLayer(nn.Module):
9 | def __init__(self, in_dim, out_dim,
10 | input_grid, output_grid, mlp_layers, projection_hidden_dim,
11 | n_neigbor):
12 | '''
13 |         in_dim: input codimension/channels per variable
14 |         out_dim: output codimension/channels per variable
15 |         input_grid: input grid (points)
16 |         output_grid: output grid (points)
17 |         mlp_layers: hidden-layer widths of the MLP in the integral operator
18 |         projection_hidden_dim: width of the pointwise MLP applied
19 |             before the integral operator (a multi-layered pointwise
20 |             projection to a higher-dimensional space)
21 |         n_neigbor: number of neighbours to consider
22 | '''
23 | super().__init__()
24 |
25 | n_dim = input_grid.shape[-1]
26 | self.in_dim = in_dim
27 | self.out_dim = out_dim
28 | self.input_grid = input_grid
29 | self.output_grid = output_grid
30 | self.mlp_layers = [2 * n_dim + out_dim] + mlp_layers + [out_dim]
31 | self.n_neigbor = n_neigbor
32 | # project to higher dim
33 | self.projection = MLPLinear([self.in_dim,
34 | projection_hidden_dim, out_dim])
35 | print([self.in_dim, projection_hidden_dim, out_dim])
36 | print(self.mlp_layers)
37 | # apply GNO to get uniform grid
38 | NS = FixedNeighborSearch(use_open3d=False)
39 |
40 | self.neighbour = NS(
41 | input_grid.clone().cpu(),
42 | output_grid.clone().cpu(),
43 | n_neigbor=n_neigbor)
44 |
45 | for key, value in self.neighbour.items():
46 | self.neighbour[key] = self.neighbour[key].cuda()
47 |
48 | self.it = IntegralTransform(
49 | mlp_layers=self.mlp_layers, transform_type='nonlinear')
50 |
51 | self.normalize = nn.LayerNorm(out_dim)
52 |
53 | def update_grid(
54 | self,
55 | input_grid=None,
56 | output_grid=None
57 | ):
58 |
59 | if input_grid is None:
60 | input_grid = self.input_grid
61 | if output_grid is None:
62 | output_grid = self.output_grid
63 |
64 | input_grid = input_grid.clone()
65 | self.input_grid = self.input_grid[:input_grid.shape[0], :]
66 | self.output_grid = self.output_grid[:output_grid.shape[0], :]
67 |
68 | NS = FixedNeighborSearch(use_open3d=False)
69 |
70 | self.neighbour = NS(
71 | input_grid.clone(),
72 | output_grid.clone(),
73 | n_neigbor=self.n_neigbor)
74 | for key, value in self.neighbour.items():
75 | self.neighbour[key] = self.neighbour[key].cuda()
76 |
77 | def forward(self, inp):
78 | '''
79 | inp : (batch_size, n_points, in_dims/Channels)
80 | '''
81 | x = inp
82 | x = self.projection(x)
83 | out = self.it(self.input_grid, self.neighbour,
84 | self.output_grid, x)
85 |
86 |
87 | if out.shape == x.shape:
88 | out = out + x
89 | out = self.normalize(out)
90 | return out
91 |
--------------------------------------------------------------------------------
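`FixedNeighborSearch` (from `baseline_utlis/rigid_neighbor.py`) supplies, for every output point, its `n_neigbor` nearest input points as a dict of index tensors that `neuralop`'s `IntegralTransform` consumes. A brute-force k-nearest-neighbour sketch in plain torch follows; the CSR-style key names (`neighbors_index`, `neighbors_row_splits`) are an assumption about that format, for illustration only:

```python
import torch

def knn_neighbors(input_grid, output_grid, k):
    """Brute-force k-NN returning a CSR-style neighbour dict.

    Assumed layout: neighbors_index holds k flat input indices per query
    point; neighbors_row_splits[j] marks where query point j's neighbours
    start in that flat list.
    """
    dists = torch.cdist(output_grid, input_grid)   # (n_out, n_in)
    idx = dists.topk(k, largest=False).indices     # (n_out, k)
    return {
        'neighbors_index': idx.reshape(-1),
        'neighbors_row_splits': torch.arange(output_grid.shape[0] + 1) * k,
    }

inp = torch.rand(100, 2)   # irregular input points
out = torch.rand(64, 2)    # latent query points
nbrs = knn_neighbors(inp, out, k=10)
print(nbrs['neighbors_index'].shape)     # torch.Size([640])
print(nbrs['neighbors_row_splits'][:3])  # tensor([ 0, 10, 20])
```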
/layers/gno_layer.py:
--------------------------------------------------------------------------------
1 | from neuralop.layers.neighbor_search import NeighborSearch
2 | from neuralop.layers.integral_transform import IntegralTransform
3 | from neuralop.layers.mlp import MLPLinear
4 | from baseline_utlis import FixedNeighborSearch
5 | from einops import rearrange
6 | from neuralop.layers.embeddings import PositionalEmbedding
7 | import torch.nn as nn
8 | import torch
9 | import torch.nn.functional as F
10 |
11 |
12 | class GnoPremEq(nn.Module):
13 | def __init__(
14 | self,
15 | var_num,
16 | in_dim, out_dim,
17 | input_grid,
18 | output_grid,
19 | mlp_layers,
20 | projection_hidden_dim,
21 | radius,
22 | var_encoding=False,
23 | n_neigbor=10,
24 | fixed_neighbour=False,
25 | var_encoding_channels=1,
26 | n_layers=2,
27 |         postional_em_dim=4,  # must be even
28 | end_projection=False,
29 | end_projection_outdim=None,
30 | ):
31 | '''
32 |         var_num: number of variables
33 |         in_dim: input codimension/channels per variable
34 |         out_dim: output codimension/channels per variable
35 |         input_grid: input grid (points)
36 |         output_grid: output grid (points)
37 |         mlp_layers: hidden-layer widths of the MLP in the integral operator
38 |         projection_hidden_dim: width of the pointwise MLP applied
39 |             before the integral operator
40 |         radius: radius of the neighbourhood
41 |         var_encoding: whether to use variable encoding
42 |         var_encoding_channels: number of channels per variable encoding
43 | '''
44 | super().__init__()
45 | assert postional_em_dim % 2 == 0
46 | n_dim = input_grid.shape[-1]
47 | self.radius = radius
48 | self.fixed_neighbour = fixed_neighbour
49 | self.n_neigbor = n_neigbor
50 | self.var_num = var_num
51 | self.in_dim = in_dim
52 | self.out_dim = out_dim
53 | self.input_grid = input_grid
54 | self.output_grid = output_grid
55 | self.mlp_layers = [2 * n_dim + self.out_dim] + mlp_layers + [out_dim]
56 | self.var_encoding = var_encoding
57 | self.postional_em_dim = postional_em_dim
58 | self.var_encoding_channels = var_encoding_channels
59 | self.n_layers = n_layers
60 | self.end_projection = end_projection
61 | self.end_projection_outdim = end_projection_outdim
62 |
63 |         # get variable encoding
64 | if self.var_encoding:
65 | self.var_encoder = MLPLinear(
66 | [n_dim + 2 * postional_em_dim, self.var_encoding_channels * var_num])
67 | self.PE = PositionalEmbedding(postional_em_dim)
68 | self.variable_channels = [
69 | i * (var_encoding_channels + self.in_dim) for i in range(var_num)]
70 | self.encoding_channels = list(set([i for i in range(
71 | (var_encoding_channels + 1) * var_num)]) - set(self.variable_channels))
72 | else:
73 | self.var_encoding_channels = 0
74 |
75 | # project to higher dim
76 | self.projection = MLPLinear([self.var_encoding_channels + self.in_dim,
77 | projection_hidden_dim, out_dim],
78 | non_linearity=F.gelu)
79 |
80 | # apply GNO to get uniform grid
81 |
82 | self.neighbour = None
83 | self.neighbour_last = None
84 | self.update_grid()
85 |
86 | self.it = torch.nn.ModuleList()
87 | for i in range(n_layers):
88 | self.it.append(IntegralTransform(
89 | mlp_layers=self.mlp_layers,
90 | transform_type='nonlinear',
91 | mlp_non_linearity=F.gelu))
92 |
93 | if self.end_projection:
94 | self.end_projector = MLPLinear([self.out_dim,
95 | projection_hidden_dim, self.end_projection_outdim],
96 | non_linearity=F.gelu)
97 |
98 | def update_grid(
99 | self,
100 | input_grid=None,
101 | output_grid=None
102 | ):
103 | if input_grid is None:
104 | input_grid = self.input_grid
105 | if output_grid is None:
106 | output_grid = self.output_grid
107 |
108 | input_grid = input_grid.clone()
109 | self.input_grid = self.input_grid[:input_grid.shape[0], :]
110 | self.output_grid = self.output_grid[:output_grid.shape[0], :]
111 | if self.fixed_neighbour:
112 | NS = FixedNeighborSearch(use_open3d=False)
113 | self.neighbour = NS(
114 | input_grid.clone().cpu(),
115 | input_grid.clone().cpu(),
116 | n_neigbor=self.n_neigbor)
117 | else:
118 | NS = NeighborSearch(use_open3d=False)
119 | self.neighbour = NS(
120 | input_grid.clone().cpu(),
121 | input_grid.clone().cpu(),
122 | radius=self.radius)
123 |
124 | for key, value in self.neighbour.items():
125 | self.neighbour[key] = self.neighbour[key].cuda()
126 |
127 | NS_last = FixedNeighborSearch(use_open3d=False)
128 | self.neighbour_last = NS_last(
129 | input_grid.clone().cpu(),
130 | output_grid.clone().cpu(),
131 | n_neigbor=self.n_neigbor)
132 |
133 | for key, value in self.neighbour_last.items():
134 | self.neighbour_last[key] = self.neighbour_last[key].cuda()
135 |
136 | def _intergral_transform(self, x):
137 | for i in range(self.n_layers):
138 | if i == self.n_layers - 1:
139 | x = self.it[i](self.input_grid, self.neighbour_last,
140 | self.output_grid, x)
141 | if self.end_projection:
142 | x = self.end_projector(x)
143 | else:
144 | x = self.it[i](self.input_grid, self.neighbour,
145 | self.input_grid, x) + x
146 | return x
147 |
148 | def forward(self, inp):
149 | '''
150 | inp : (batch_size, n_points, in_dims/Channels)
151 | '''
152 |
153 | if self.var_encoding:
154 | x = torch.zeros((inp.shape[0], inp.shape[1], len(
155 | self.variable_channels) + len(self.encoding_channels)), device=inp.device, dtype=inp.dtype)
156 |
157 | pe = self.PE(self.input_grid.reshape(-1))
158 | pe = pe.reshape(self.input_grid.shape[0], -1)
159 | grid_pe = torch.cat([self.input_grid, pe], axis=1)
160 | var_encoding = self.var_encoder(grid_pe).to(x.device)
161 | x[:, :, self.variable_channels] = inp
162 | x[:, :, self.encoding_channels] = var_encoding[None,
163 | :, :].repeat(x.shape[0], 1, 1)
164 | else:
165 | x = inp
166 |
167 |
168 | x = rearrange(
169 | x,
170 | 'b n (v c) -> (b n) v c',
171 | c=self.in_dim +
172 | self.var_encoding_channels)
173 | x = self.projection(x)
174 |
175 | out = None
176 |
177 | for i in range(x.shape[-2]):
178 |
179 |
180 | temp = self._intergral_transform(x[:, i, :])
181 | if out is None:
182 | out = temp[None, ...]
183 | else:
184 | out = torch.cat([out, temp[None, ...]], dim=2)
185 |
186 |
187 | return out
188 |
189 |
190 | class GNO(nn.Module):
191 | def __init__(self, in_dim, out_dim,
192 | input_grid, output_grid, mlp_layers, projection_hidden_dim,
193 | radius, fixed_neighbour=False, n_neigbor=10):
194 | '''
195 | var_num: number of variables
196 | in_dim: Input Condim/channel per variables
197 | out_dim: Output Condim/channel per variables
198 | input_grid: Input grid (points)
199 | output_grid: Output grid (points)
200 | mlp_layers: MLP layers (for integral operator)
201 | projection_hidden_dim: Before applying integral operator we have pointwise MLP. This parameter
202 | determines the width of the multi-layered MLP
203 | radius: radius of the neighbourhood
204 | '''
205 | super().__init__()
206 |
207 | n_dim = input_grid.shape[-1]
208 | self.in_dim = in_dim
209 | self.out_dim = out_dim
210 | self.input_grid = input_grid
211 | self.output_grid = output_grid
212 | self.mlp_layers = [2 * n_dim] + mlp_layers + [out_dim]
213 | self.fixed_neighbour = fixed_neighbour
214 | self.n_neigbor = n_neigbor
215 | self.radius = radius
216 | # project to higher dim
217 | self.projection = MLPLinear([self.in_dim,
218 | projection_hidden_dim, out_dim])
219 |
220 | self.neighbour = None
221 | self.update_grid()
222 |
223 | for key, value in self.neighbour.items():
224 | self.neighbour[key] = self.neighbour[key].cuda()
225 |
226 | self.it = IntegralTransform(mlp_layers=self.mlp_layers)
227 |
228 | def update_grid(
229 | self,
230 | input_grid=None,
231 | output_grid=None
232 | ):
233 | if input_grid is None:
234 | input_grid = self.input_grid
235 | if output_grid is None:
236 | output_grid = self.output_grid
237 | input_grid = input_grid.clone()
238 | self.input_grid = self.input_grid[:input_grid.shape[0], :]
239 | self.output_grid = self.output_grid[:output_grid.shape[0], :]
240 |
241 | if self.fixed_neighbour:
242 | NS = FixedNeighborSearch(use_open3d=False)
243 | self.neighbour = NS(
244 | input_grid.clone(),
245 | output_grid.clone(),
246 | n_neigbor=self.n_neigbor)
247 | else:
248 | NS = NeighborSearch(use_open3d=False)
249 | self.neighbour = NS(
250 | input_grid.clone(),
251 | output_grid.clone(),
252 | radius=self.radius)
253 |
254 | def forward(self, inp):
255 | '''
256 | inp : (batch_size, n_points, in_dims/Channels)
257 | '''
258 |
259 | x = inp
260 | x = self.projection(x)
261 |
262 | out = self.it(self.input_grid, self.neighbour,
263 | self.output_grid, x[0, ...])
264 |
265 | return out[None, ...]
266 |
--------------------------------------------------------------------------------
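When `var_encoding=True`, `GnoPremEq` interleaves each variable's data channel with its encoding channels before the per-variable projection. The index bookkeeping can be checked in isolation (note the complement is taken over `(var_encoding_channels + 1) * var_num` slots in the original, which implicitly assumes `in_dim == 1`):

```python
var_num, in_dim, var_encoding_channels = 3, 1, 2

# One data channel per variable, spaced by the number of encoding channels.
variable_channels = [
    i * (var_encoding_channels + in_dim) for i in range(var_num)]
# Everything else is an encoding channel.
encoding_channels = sorted(
    set(range((var_encoding_channels + in_dim) * var_num))
    - set(variable_channels))

print(variable_channels)  # [0, 3, 6]
print(encoding_channels)  # [1, 2, 4, 5, 7, 8]
```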
/layers/regrider.py:
--------------------------------------------------------------------------------
1 | import torch_harmonics as th
2 | import torch.nn as nn
3 |
4 |
5 | class Regird(nn.Module):
6 | def __init__(
7 | self,
8 | input_grid,
9 | output_grid,
10 | sht_nlat=128,
11 | sht_nlon=256,
12 | output_scaling_factor=None):
13 | super().__init__()
14 | self.input_transform = th.RealSHT(
15 | sht_nlat, sht_nlon, grid=input_grid, norm='backward').float()
16 | self.input_grid = input_grid
17 | self.output_grid = output_grid
18 | self.output_transform = th.InverseRealSHT(
19 | sht_nlat, sht_nlon, grid=output_grid, norm='backward').float()
20 |
21 | def forward(self, x):
22 | if self.input_transform.nlat != x.shape[-2] or self.input_transform.nlon != x.shape[-1]:
23 | self.input_transform = th.RealSHT(x.shape[-2], x.shape[-1], grid=self.input_grid,
24 | norm='backward').to(x.device, dtype=x.dtype)
25 | self.output_transform = th.InverseRealSHT(x.shape[-2], x.shape[-1], grid=self.output_grid,
26 | norm='backward').to(x.device, dtype=x.dtype)
27 |
28 | return self.output_transform(self.input_transform(x))
29 |
--------------------------------------------------------------------------------
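`Regird` round-trips a field through a spherical harmonic transform to resample it from one latitude grid convention to another, rebuilding both transforms on the fly if the input resolution changes. A usage sketch (requires `torch_harmonics`; the grid names follow its conventions):

```python
import torch
from layers.regrider import Regird

# Resample from an equiangular lat-lon grid onto Legendre-Gauss nodes.
regrid = Regird("equiangular", "legendre-gauss")

x = torch.randn(2, 4, 128, 256)  # (batch, channels, nlat, nlon)
y = regrid(x)                    # same shape, on the target grid
print(y.shape)                   # torch.Size([2, 4, 128, 256])
```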
/layers/regular_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from einops import rearrange, repeat
4 | from vit_pytorch import ViT
5 |
6 |
7 | class vision_transformer(ViT):
8 | def forward(self, img):
9 | x = self.to_patch_embedding(img)
10 | b, n, _ = x.shape
11 |
12 | cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
13 | x = torch.cat((cls_tokens, x), dim=1)
14 | x += self.pos_embedding[:, :(n + 1)]
15 | x = self.dropout(x)
16 |
17 | x = self.transformer(x)
18 | return x[:, 1:, :]
--------------------------------------------------------------------------------
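The subclass above bypasses `vit_pytorch`'s classification head and drops the CLS token, so the ViT returns one token per patch that downstream code can reshape back onto the grid. A usage sketch, assuming `vit_pytorch`'s standard constructor arguments (all values illustrative):

```python
import torch
from layers.regular_transformer import vision_transformer

model = vision_transformer(
    image_size=64, patch_size=8,  # 8x8 patches -> 64 tokens
    num_classes=1,                # unused: the head is bypassed
    dim=128, depth=2, heads=4, mlp_dim=256, channels=3,
)
x = torch.randn(2, 3, 64, 64)
tokens = model(x)
print(tokens.shape)  # torch.Size([2, 64, 128]) -- CLS token dropped
```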
/layers/unet3d.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class UNet3d(nn.Module):
8 | def __init__(self, in_channels=3, out_channels=1, init_features=32):
9 | super(UNet3d, self).__init__()
10 |
11 | features = init_features
12 | self.encoder1 = UNet3d._block(in_channels, features, name="enc1")
13 | self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
14 | self.encoder2 = UNet3d._block(features, features * 2, name="enc2")
15 | self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
16 | self.bottleneck = UNet3d._block(
17 | features * 2, features * 4, name="bottleneck")
18 | self.upconv2 = nn.ConvTranspose3d(
19 | features * 4, features * 2, kernel_size=(1, 2, 2), stride=(1, 2, 2)
20 | )
21 | self.decoder2 = UNet3d._block(
22 | (features * 2) * 2, features * 2, name="dec2")
23 | self.upconv1 = nn.ConvTranspose3d(
24 | features * 2, features, kernel_size=(1, 2, 2), stride=(1, 2, 2)
25 | )
26 | self.decoder1 = UNet3d._block(features * 2, features, name="dec1")
27 |
28 | self.conv = nn.Conv3d(
29 | in_channels=features, out_channels=out_channels, kernel_size=1
30 | )
31 | def forward(self, x):
32 | enc1 = self.encoder1(x)
33 | enc2 = self.encoder2(self.pool1(enc1))
34 | bottleneck = self.bottleneck(self.pool2(enc2))
35 |
36 | dec2 = self.upconv2(bottleneck)
37 | dec2 = torch.cat((dec2, enc2), dim=1)
38 | dec2 = self.decoder2(dec2)
39 | dec1 = self.upconv1(dec2)
40 | dec1 = torch.cat((dec1, enc1), dim=1)
41 | dec1 = self.decoder1(dec1)
42 | return self.conv(dec1)
43 |
44 | @staticmethod
45 | def _block(in_channels, features, name):
46 | return nn.Sequential(
47 | OrderedDict(
48 | [
49 | (
50 | name + "conv1",
51 | nn.Conv3d(
52 | in_channels=in_channels,
53 | out_channels=features,
54 | kernel_size=3,
55 | padding=1,
56 | bias=False,
57 | ),
58 | ),
59 | (name + "norm1", nn.BatchNorm3d(num_features=features)),
60 | (name + "relu1", nn.ReLU(inplace=True)),
61 | (
62 | name + "conv2",
63 | nn.Conv3d(
64 | in_channels=features,
65 | out_channels=features,
66 | kernel_size=3,
67 | padding=1,
68 | bias=False,
69 | ),
70 | ),
71 | (name + "norm2", nn.BatchNorm3d(num_features=features)),
72 | (name + "relu2", nn.ReLU(inplace=True)),
73 | ]
74 | )
75 | )
76 |
--------------------------------------------------------------------------------
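Note that this 3D U-Net pools and upsamples only the two spatial axes (`kernel_size=(1, 2, 2)`), so the leading depth/time dimension passes through unchanged, while H and W must be divisible by 4. A quick shape check:

```python
import torch
from layers.unet3d import UNet3d

model = UNet3d(in_channels=3, out_channels=1, init_features=32)
x = torch.randn(2, 3, 5, 64, 64)  # (batch, channels, depth, H, W)
y = model(x)
print(y.shape)  # torch.Size([2, 1, 5, 64, 64]) -- depth preserved
```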
/layers/unet_sublayer.py:
--------------------------------------------------------------------------------
1 |
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | # 2D counterpart of layers/unet3d.py: the same encoder-decoder blocks
9 | # with 2D convolutions and pooling.
10 |
11 |
12 |
13 | class UNet2d(nn.Module):
14 |
15 | def __init__(self, in_channels=3, out_channels=1, init_features=32):
16 | super(UNet2d, self).__init__()
17 |
18 | features = init_features
19 | self.encoder1 = UNet2d._block(in_channels, features, name="enc1")
20 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
21 | self.encoder2 = UNet2d._block(features, features * 2, name="enc2")
22 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
23 | self.bottleneck = UNet2d._block(
24 | features * 2, features * 4, name="bottleneck")
25 | self.upconv2 = nn.ConvTranspose2d(
26 | features * 4, features * 2, kernel_size=2, stride=2
27 | )
28 | self.decoder2 = UNet2d._block(
29 | (features * 2) * 2, features * 2, name="dec2")
30 | self.upconv1 = nn.ConvTranspose2d(
31 | features * 2, features, kernel_size=2, stride=2
32 | )
33 | self.decoder1 = UNet2d._block(features * 2, features, name="dec1")
34 |
35 | self.conv = nn.Conv2d(
36 | in_channels=features, out_channels=out_channels, kernel_size=1
37 | )
38 |
39 | def forward(self, x, **kwargs):
40 | enc1 = self.encoder1(x)
41 | enc2 = self.encoder2(self.pool1(enc1))
42 |
43 | bottleneck = self.bottleneck(self.pool2(enc2))
44 | dec2 = self.upconv2(bottleneck)
45 | dec2 = torch.cat((dec2, enc2), dim=1)
46 | dec2 = self.decoder2(dec2)
47 | dec1 = self.upconv1(dec2)
48 | dec1 = torch.cat((dec1, enc1), dim=1)
49 | dec1 = self.decoder1(dec1)
50 | return self.conv(dec1)
51 |
52 | @staticmethod
53 | def _block(in_channels, features, name):
54 | return nn.Sequential(
55 | OrderedDict(
56 | [
57 | (
58 | name + "conv1",
59 | nn.Conv2d(
60 | in_channels=in_channels,
61 | out_channels=features,
62 | kernel_size=3,
63 | padding=1,
64 | bias=False,
65 | ),
66 | ),
67 | (name + "norm1", nn.BatchNorm2d(num_features=features)),
68 |                     (name + "relu1", nn.ReLU()),
69 | (
70 | name + "conv2",
71 | nn.Conv2d(
72 | in_channels=features,
73 | out_channels=features,
74 | kernel_size=3,
75 | padding=1,
76 | bias=False,
77 | ),
78 | ),
79 | (name + "norm2", nn.BatchNorm2d(num_features=features)),
80 |                     (name + "relu2", nn.ReLU()),
81 | ]
82 | )
83 | )
84 |
--------------------------------------------------------------------------------
/layers/variable_encoding.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | from typing import Tuple
3 | from neuralop.layers.mlp import MLPLinear
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | import torch_harmonics as th
8 | from neuralop.layers.embeddings import PositionalEmbedding
9 |
10 |
11 | class VariableEncoding2d(nn.Module):
12 | def __init__(self,
13 | n_variables: int,
14 | variable_encoding_size: int,
15 | modes: Tuple[int, ...],
16 | basis='fft') -> None:
17 | super().__init__()
18 | self.modes = modes
19 | channel = n_variables * variable_encoding_size
20 | self.coefficients_r = nn.Parameter(
21 | torch.empty(channel, *modes))
22 | self.coefficients_i = nn.Parameter(
23 | torch.empty(channel, *modes))
24 | self.reset_parameters()
25 | self.basis = basis
26 | if basis == 'fft':
27 | self.transform = torch.fft.ifft2
28 | elif basis == 'sht':
29 | self.transform = th.InverseRealSHT(
30 | *modes,
31 | lmax=modes[-2],
32 | mmax=modes[-1],
33 | grid="legendre-gauss",
34 | norm="backward",
35 | )
36 |
37 | def reset_parameters(self):
38 | std = (1 / (self.modes[-1] * self.modes[-2]))**0.5
39 | torch.nn.init.normal_(self.coefficients_r, mean=0.0, std=std)
40 | torch.nn.init.normal_(self.coefficients_i, mean=0.0, std=std)
41 |
42 | def forward(self, x):
43 | """Take a resolution and outputs the positional encodings"""
44 | size_x, size_y = x.shape[-2], x.shape[-1]
45 | if self.basis == 'sht':
46 | if self.transform.nlat == size_x and self.transform.nlon == size_y:
47 | return self.transform(
48 | self.coefficients_r + 1.0j * self.coefficients_i)
49 |
50 | self.transform = th.InverseRealSHT(
51 | size_x,
52 | size_y,
53 | lmax=self.modes[-2],
54 | mmax=self.modes[-1],
55 | grid="legendre-gauss",
56 | norm='backward'
57 | ).to(
58 | device=self.coefficients_i.device,
59 | dtype=self.coefficients_i.dtype
60 | )
61 |             return self.transform(
62 |                 self.coefficients_r + 1.0j * self.coefficients_i
63 |             ).real
64 |
65 |
66 | else:
67 | return self.transform(
68 | self.coefficients_r + 1.0j * self.coefficients_i,
69 | s=(size_x, size_y)
70 | ).real
71 |
72 |
73 | # SHT doesn't make sense for 3-dimensional data
74 | class FourierVariableEncoding3D(nn.Module):
75 | def __init__(self, n_features: int, modes: Tuple[int, ...]) -> None:
76 | super().__init__()
77 | if len(modes) != 3:
78 | raise ValueError(
79 | f"Expected 3 frequency modes, but got {len(modes)} modes:\n{modes=}")
80 |
81 | self.modes = modes
82 | self.weights_re = nn.Parameter(torch.empty(n_features, *modes))
83 | self.weights_im = nn.Parameter(torch.empty(n_features, *modes))
84 | self.reset_parameters()
85 | # self.transform = torch.fft.ifftn
86 |
87 | def reset_parameters(self):
88 | std = 1 / np.sqrt(reduce(lambda a, b: a * b, self.modes))
89 | torch.nn.init.normal_(self.weights_re, mean=0.0, std=std)
90 | torch.nn.init.normal_(self.weights_im, mean=0.0, std=std)
91 |
92 | def forward(self, x_shape):
93 | """Take a resolution and outputs the positional encodings"""
94 | _, size_t, size_x, size_y = x_shape[0], x_shape[1], x_shape[2], x_shape[3]
95 | s = (size_t, size_x, size_y)
96 | # now check if s is a tuple of ints
97 | if not all(isinstance(i, int) for i in s):
98 | raise ValueError(
99 | f"Expected a tuple of integers, but got {s=}")
100 | else:
101 | s = tuple(s)
102 | return torch.fft.ifftn(
103 | self.weights_re + 1.0j * self.weights_im,
104 | s=s,
105 | norm="forward", # don't multiply by any normalization factor
106 | ).real
107 |
108 |
109 | class VariableEncodingIrregularMesh(nn.Module):
110 | def __init__(
111 | self,
112 | n_variables: int,
113 | variable_encoding_size: int,
114 | n_dim: int = 2,
115 | positional_encoding_dim: int = 8
116 | ) -> None:
117 | super().__init__()
118 | self.n_variables = n_variables
119 | self.variable_encoding_size = variable_encoding_size
120 | self.n_dim = n_dim
121 | self.positional_encoding_dim = positional_encoding_dim
122 | self.var_encoder = MLPLinear(
123 | [n_dim + self.n_dim * positional_encoding_dim, self.variable_encoding_size * n_variables])
124 | self.PE = PositionalEmbedding(positional_encoding_dim)
125 |
126 | def forward(self, grid_poits):
127 | pe = self.PE(grid_poits.reshape(-1))
128 | pe = pe.reshape(grid_poits.shape[0], -1)
129 | grid_pe = torch.cat([grid_poits, pe], axis=1)
130 | var_encoding = self.var_encoder(grid_pe)
131 | return var_encoding
132 |
133 |
134 | class VariableEncodingWrapper(nn.Module):
135 | def __init__(
136 | self,
137 | equation_dict: dict,
138 | variable_encoding_size: int,
139 | n_dim: int = 2,
140 | positional_encoding_dim: int = 8,
141 | varibale_encoding_modes: Tuple[int, ...] = (32, 32),
142 | basis='fft',
143 | uniform=False) -> None:
144 | '''
145 |         For each equation in ``equation_dict`` we create a variable encoder.
146 |         The dict is of the form {"Equation": n_variables, ...}
147 | '''
148 | super().__init__()
149 | self.n_dim = n_dim
150 | self.equation_dict = equation_dict
151 | self.variable_encoding_size = variable_encoding_size
152 | self.model_dict = nn.ModuleDict()
153 | self.uniform = uniform
154 | for i in equation_dict.keys():
155 | if not uniform:
156 | self.model_dict[i] = VariableEncodingIrregularMesh(
157 | n_variables=equation_dict[i],
158 | variable_encoding_size=self.variable_encoding_size,
159 | n_dim=n_dim,
160 | positional_encoding_dim=positional_encoding_dim
161 | )
162 | else:
163 | self.model_dict[i] = VariableEncoding2d(
164 | n_variables=equation_dict[i],
165 | variable_encoding_size=self.variable_encoding_size,
166 | modes=varibale_encoding_modes,
167 | basis=basis
168 | )
169 |
170 | def load_encoder(self, equation: str, path: str):
171 | self.model_dict[equation].load_state_dict(
172 | torch.load(path, map_location=torch.device('cpu')))
173 |
174 | def save_encoder(self, equation: str, path: str):
175 | torch.save(self.model_dict[equation].state_dict(), path)
176 |
177 | def save_all_encoder(self, path: str):
178 | for i in self.equation_dict.keys():
179 | torch.save(self.model_dict[i].state_dict(), path + f"_{i}" + ".pt")
180 |
181 | def freeze(self, equation: str):
182 | for param in self.model_dict[equation].parameters():
183 | param.requires_grad = False
184 |
185 | def forward(self, grid_poits, equation: str = None):
186 | '''
187 | grid_poits: (n_points, n_dim) or for uniform mesh input tensor of shape (D, channels, H, W)
188 | '''
189 |         encoding_list = []
190 |         if equation is None:
191 |             equation = list(self.equation_dict.keys())
192 |         for i in ([equation] if isinstance(equation, str) else equation):
193 |             encoding_list.append(self.model_dict[i](grid_poits))
194 |
195 | if self.uniform:
196 | return torch.cat(encoding_list, axis=0).unsqueeze(0)
197 | else:
198 | return torch.cat(encoding_list, axis=1).unsqueeze(0)
199 |
200 |
201 | def get_variable_encoder(params):
202 | return VariableEncodingWrapper(
203 | params.equation_dict,
204 | variable_encoding_size=params.n_encoding_channels,
205 | n_dim=params.n_dim,
206 | positional_encoding_dim=params.positional_encoding_dim,
207 | uniform=params.grid_type == 'uniform',
208 | varibale_encoding_modes=(
209 | params.encoding_modes_x,
210 | params.encoding_modes_y) if hasattr(
211 | params,
212 | 'encoding_modes_y') else None)
213 |
214 |
215 |
216 | class TokenExpansion(nn.Module):
217 | def __init__(
218 | self,
219 | n_variables: int,
220 | n_encoding_channels,
221 | n_static_channels: int,
222 | uniform_grid=False) -> None:
223 | """
224 |         Stack the variables and the corresponding encodings together.
225 |
226 | Args:
227 | n_variables (int): number of variables
228 | n_encoding_channels (int): number of encoding channels
229 | n_static_channels (int): number of static channels
230 |             uniform_grid (bool): whether the grid is uniform
231 |
232 | """
233 | super().__init__()
234 | self.n_variables = n_variables
235 | self.n_encoding_channels = n_encoding_channels
236 | self.n_static_channels = n_static_channels
237 | self.uniform_grid = uniform_grid
238 |
239 | expansion_factor = 1 + self.n_static_channels + self.n_encoding_channels
240 |
241 | self.variable_channels = [
242 | i * expansion_factor for i in range(n_variables)]
243 | self.static_channels = []
244 | if self.n_static_channels != 0:
245 | for v in self.variable_channels:
246 | self.static_channels.extend(
247 | range(v + 1, v + self.n_static_channels + 1))
248 | self.encoding_channels = []
249 | if self.n_encoding_channels != 0:
250 | self.encoding_channels = sorted(list(
251 | set(range(n_variables * expansion_factor))
252 | - set(self.variable_channels)
253 | - set(self.static_channels)
254 | ))
255 |
256 | print(self.variable_channels)
257 | print(self.static_channels)
258 | print(self.encoding_channels)
259 |
260 | def forward(
261 | self,
262 | inp: torch.Tensor,
263 | variable_encodings: torch.tensor,
264 | static_channels: torch.tensor) -> torch.Tensor:
265 | """
266 | x: (batch_size, points, n_variables)
267 | """
268 | if not self.uniform_grid:
269 | x = torch.zeros((inp.shape[0], inp.shape[1], len(self.variable_channels) + len(
270 | self.encoding_channels) + len(self.static_channels)), device=inp.device, dtype=inp.dtype)
271 | x[:, :, self.variable_channels] = inp
272 | if self.n_static_channels != 0:
273 | x[:, :, self.static_channels] = static_channels.repeat(
274 | x.shape[0], 1, 1)
275 | if self.n_encoding_channels != 0:
276 |                 x[:, :, self.encoding_channels] = variable_encodings.repeat(
277 |                     x.shape[0], 1, 1)
278 |
279 |
280 | else:
281 |             # currently only 2D grids are supported
282 |
283 |             n_channels = len(self.variable_channels) + len(self.encoding_channels) + len(self.static_channels)
284 |             x = torch.zeros((inp.shape[0], n_channels, inp.shape[-2], inp.shape[-1]),
285 |                             device=inp.device, dtype=inp.dtype)
286 | x[:, self.variable_channels, :, :] = inp
287 | if self.n_static_channels != 0:
288 | x[:, self.static_channels, :, :] = static_channels.repeat(
289 | 1, self.n_variables, 1, 1)
290 | if self.n_encoding_channels != 0:
291 | x[:, self.encoding_channels, :, :] = variable_encodings.repeat(
292 | x.shape[0], 1, 1, 1)
293 |
294 | return x
295 |
--------------------------------------------------------------------------------
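The channel layout built by `TokenExpansion` above can be checked in isolation. A minimal sketch that re-computes the index logic with made-up sizes (nothing here imports the repo):

```python
# Recompute TokenExpansion's channel layout for 2 variables,
# 1 static channel, and 2 encoding channels per variable.
n_variables, n_static, n_encoding = 2, 1, 2
expansion = 1 + n_static + n_encoding  # channels per token
variable_channels = [i * expansion for i in range(n_variables)]
static_channels = [v + 1 + j for v in variable_channels for j in range(n_static)]
encoding_channels = sorted(
    set(range(n_variables * expansion))
    - set(variable_channels)
    - set(static_channels))
print(variable_channels)  # [0, 4]
print(static_channels)    # [1, 5]
print(encoding_channels)  # [2, 3, 6, 7]
```

Each token therefore occupies a contiguous block of `expansion` channels: the variable itself, then its static channel(s), then its encoding channels.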
/main.py:
--------------------------------------------------------------------------------
1 | from YParams import YParams
2 | import wandb
3 | import argparse
4 | import torch
5 | import numpy as np
6 | from data_utils.data_loaders import *
7 | from data_utils.data_utils import MaskerNonuniformMesh, get_meshes
8 | from layers.variable_encoding import *
9 | from models.get_models import *
10 | from train.trainer import trainer
11 | from utils import *
12 | from models.model_helpers import count_parameters
13 | from test.evaluations import missing_variable_testing
14 | import random
15 |
16 |
17 | if __name__ == "__main__":
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--exp", nargs="?", default="FSI", type=str)
21 | parser.add_argument("--config", nargs="?", default="base_config", type=str)
22 | parser.add_argument("--ntrain", nargs="?", default=None, type=int)
23 | parser.add_argument("--epochs", nargs="?", default=None, type=int)
24 | parser.add_argument("--random_seed", nargs="?", default=42, type=int)
25 | parser.add_argument("--scheduler_step", nargs="?", default=None, type=int)
26 | parser.add_argument("--batch_size", nargs="?", default=None, type=int)
27 | parsed_args = parser.parse_args()
28 |
29 | if parsed_args.exp == "FSI":
30 | config_file = './config/ssl_ns_elastic.yaml'
31 | elif parsed_args.exp == "RB":
32 | config_file = './config/RB_config.yaml'
33 | else:
34 | raise ValueError("Unknown experiment type")
35 |
36 | config = parsed_args.config
37 | print("Loading config", config)
38 | params = YParams(config_file, config, print_params=True)
39 |
40 | if parsed_args.ntrain is not None:
41 | params.ntrain = parsed_args.ntrain
42 | print("Overriding ntrain to", params.ntrain)
43 | if parsed_args.random_seed is not None:
44 | params.random_seed = parsed_args.random_seed
45 | print("Overriding random seed to", params.random_seed)
46 | if parsed_args.epochs is not None:
47 | params.epochs = parsed_args.epochs
48 | print("Overriding epochs to", params.epochs)
49 | if parsed_args.scheduler_step is not None:
50 | params.scheduler_step = parsed_args.scheduler_step
51 | print("Overriding scheduler step to", params.scheduler_step)
52 | if parsed_args.batch_size is not None:
53 | params.batch_size = parsed_args.batch_size
54 | print("Overriding batch size to", params.batch_size)
55 |
56 | torch.manual_seed(params.random_seed)
57 | random.seed(params.random_seed)
58 | np.random.seed(params.random_seed)
59 |
60 | params.config = config
61 |
62 | # Set up WandB logging
63 | params.wandb_name = config
64 | params.wandb_group = params.nettype
65 | if params.wandb_log:
66 | wandb.login(key=get_wandb_api_key())
67 | wandb.init(
68 | config=params,
69 | name=params.wandb_name,
70 | group=params.wandb_group,
71 | project=params.wandb_project,
72 | entity=params.wandb_entity)
73 |
74 |     # Stage of training: reconstructive (self-supervised training)
75 |     # or predictive (supervised training, either fine-tuning or from scratch)
76 | if params.pretrain_ssl:
77 | stage = StageEnum.RECONSTRUCTIVE
78 | else:
79 | stage = StageEnum.PREDICTIVE
80 |
81 | variable_encoder = None
82 | token_expander = None
83 |
84 | if params.nettype == 'transformer':
85 | if params.grid_type == 'uniform':
86 | encoder, decoder, contrastive, predictor = get_ssl_models_codano(
87 | params)
88 | input_mesh = None
89 | else:
90 | encoder, decoder, contrastive, predictor = get_ssl_models_codano_gino(
91 | params)
92 |
93 | if params.use_variable_encoding:
94 | variable_encoder = get_variable_encoder(params)
95 | token_expander = TokenExpansion(sum([params.equation_dict[i] for i in params.equation_dict.keys(
96 | )]), params.n_encoding_channels, params.n_static_channels, params.grid_type == 'uniform')
97 |
98 | print("Parameters Encoder", count_parameters(encoder), "x10^6")
99 | print("Parameters Decoder", count_parameters(decoder), "x10^6")
100 |         print("Parameters Predictor", count_parameters(predictor), "x10^6")
101 | if params.wandb_log:
102 | wandb.log(
103 | {'Encoder #parameters': count_parameters(encoder)}, commit=True)
104 | wandb.log(
105 | {'Decoder #parameters': count_parameters(decoder)}, commit=True)
106 | wandb.log(
107 | {'Predictor #parameters': count_parameters(predictor)}, commit=True)
108 |
109 | model = SSLWrapper(
110 | params,
111 | encoder,
112 | decoder,
113 | contrastive,
114 | predictor,
115 | stage=stage)
116 |
117 | if params.grid_type != 'uniform':
118 | print("Setting the Grid")
119 | mesh = get_mesh(params)
120 | input_mesh = torch.from_numpy(mesh).type(torch.float).cuda()
121 | model.set_initial_mesh(input_mesh)
122 |
123 | elif params.nettype in ['simple', 'gnn', 'deeponet', 'vit', 'unet', 'fno']:
124 | model = get_baseline_model(params)
125 |         print("Parameters Model", count_parameters(model), "x10^6")
126 |         if params.wandb_log:
127 |             wandb.log({'Model #parameters': count_parameters(model)}, commit=True)
128 |         input_mesh = None
129 | print("PDE list", *list(params.equation_dict.keys()))
130 |
131 | if parsed_args.exp == 'FSI':
132 |         # loading the Fluid Structure Interaction dataset
133 | dataset = NsElasticDataset(
134 | params.data_location,
135 | equation=list(params.equation_dict.keys()),
136 | mesh_location=params.input_mesh_location,
137 | params=params)
138 | train, test = dataset.get_dataloader(params.mu_list, params.dt, ntrain=params.get(
139 | 'ntrain'), ntest=params.get('ntest'), sample_per_inlet=params.sample_per_inlet)
140 | elif parsed_args.exp == 'RB':
141 |         # loading the Rayleigh-Bénard dataset
142 | train, test = get_RB_dataloader(params)
143 |
144 | if getattr(params, 'evaluate_only', False):
145 |         # Set the stage to predictive for evaluation
146 |         # and load the model weights
147 | stage = StageEnum.PREDICTIVE
148 | model.load_state_dict(torch.load(params.model_path), strict=False)
149 | if params.nettype == 'transformer' and params.use_variable_encoding:
150 | if "NS" in params.equation_dict.keys():
151 | print("Loading NS variable encoder")
152 | variable_encoder.load_encoder(
153 | "NS", params.NS_variable_encoder_path)
154 | if "ES" in params.equation_dict.keys() and params.ES_variable_encoder_path is not None:
155 | print("Loading ES variable encoder")
156 | variable_encoder.load_encoder(
157 | "ES", params.ES_variable_encoder_path)
158 | if "T" in params.equation_dict.keys() and params.T_variable_encoder_path is not None:
159 | print("Loading T variable encoder")
160 | variable_encoder.load_encoder(
161 | "T", params.T_variable_encoder_path)
162 | if params.freeze_encoder:
163 | variable_encoder.freeze("T")
164 |
165 | elif params.training_stage == 'fine_tune':
166 |         # load only the encoder and variable encoder (VSPE) weights
167 | print(f"Loading Pretrained weights from {params.pretrain_weight}")
168 | model.encoder.load_state_dict(torch.load(
169 | params.pretrain_weight), strict=True)
170 | if params.use_variable_encoding:
171 | print(
172 | f"Loading Pretrained weights from {params.NS_variable_encoder_path}")
173 |
174 | if "NS" in params.equation_dict.keys():
175 | print("Loading NS variable encoder")
176 | variable_encoder.load_encoder(
177 | "NS", params.NS_variable_encoder_path)
178 | if params.freeze_encoder:
179 | variable_encoder.freeze("NS")
180 |
181 | if "ES" in params.equation_dict.keys() and params.ES_variable_encoder_path is not None:
182 | print("Loading ES variable encoder")
183 | variable_encoder.load_encoder(
184 | "ES", params.ES_variable_encoder_path)
185 | if params.freeze_encoder:
186 | variable_encoder.freeze("ES")
187 |
188 | if "T" in params.equation_dict.keys() and params.T_variable_encoder_path is not None:
189 | print("Loading T variable encoder")
190 | variable_encoder.load_encoder(
191 | "T", params.T_variable_encoder_path)
192 | if params.freeze_encoder:
193 | variable_encoder.freeze("T")
194 |
195 | # Move model and encoders to GPU
196 | model = model.cuda()
197 | if variable_encoder is not None:
198 | variable_encoder.cuda()
199 | if token_expander is not None:
200 | token_expander.cuda()
201 |
202 | if not getattr(params, 'evaluate_only', False):
203 | # Train the model
204 | trainer(
205 | model,
206 | train,
207 | test,
208 | params,
209 | wandb_log=params.wandb_log,
210 | log_test_interval=params.wandb_log_test_interval,
211 | stage=stage,
212 | variable_encoder=variable_encoder,
213 | token_expander=token_expander,
214 | initial_mesh=input_mesh)
215 |
216 | if getattr(params, 'missing_var_test', False):
217 |         # evaluate on missing variables and partially
218 |         # observed variables
219 |
220 | grid_non, grid_uni = get_meshes(
221 | params, params.grid_size)
222 | test_augmenter = MaskerNonuniformMesh(
223 | grid_non_uni=grid_non.clone().detach(),
224 | gird_uni=grid_uni.clone().detach(),
225 | radius=params.masking_radius,
226 | drop_type=params.drop_type,
227 | drop_pix=params.drop_pix_val,
228 | channel_aug_rate=params.channel_per_val,
229 | channel_drop_rate=params.channel_drop_per_val,
230 | verbose=True)
231 | missing_variable_testing(
232 | model,
233 | test,
234 | test_augmenter,
235 | 'sl',
236 | params,
237 | variable_encoder=variable_encoder,
238 | token_expander=token_expander,
239 | initial_mesh=input_mesh)
240 | else:
241 |         # Evaluate the model on the
242 |         # unaugmented test set
243 |
244 | missing_variable_testing(
245 | model,
246 | test,
247 | augmenter=None,
248 | normalizer=None,
249 | stage=stage,
250 | params=params,
251 | variable_encoder=variable_encoder,
252 | token_expander=token_expander,
253 | initial_mesh=input_mesh)
254 |
255 | if params.wandb_log:
256 | wandb.finish()
257 |
--------------------------------------------------------------------------------
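A note on the evaluate-only path in `main.py`: weights are loaded with `strict=False`, which tolerates checkpoints whose head layers differ between the SSL and supervised stages. A minimal sketch of that behavior with a toy model (the scratch path is just for illustration):

```python
import torch
import torch.nn as nn

# strict=False returns the mismatched keys instead of raising,
# so shared submodules load even if some heads are absent.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4))
torch.save(model.state_dict(), '/tmp/toy_model.pt')
result = model.load_state_dict(
    torch.load('/tmp/toy_model.pt', map_location='cpu'), strict=False)
print(result.missing_keys, result.unexpected_keys)  # [] []
```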
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/models/__init__.py
--------------------------------------------------------------------------------
/models/codano.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import logging
3 | from typing import Literal, NamedTuple, Optional
4 | import numpy as np
5 | from einops import rearrange
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 | from neuralop.layers.padding import DomainPadding
10 | from layers.codano_block_nd import CodanoBlockND
11 | from layers.fino_nd import SpectralConvKernel2d
12 | from layers.variable_encoding import VariableEncoding2d
13 |
14 |
15 | # TODO replace with neuralop.MLP module
16 | class PermEqProjection(nn.Module):
17 | def __init__(
18 | self,
19 | in_channels,
20 | out_channels,
21 | hidden_channels=None,
22 | n_dim=2,
23 | non_linearity=F.gelu,
24 | permutation_invariant=False,
25 | ):
26 | """Permutation invariant projection layer.
27 |
28 | Performs linear projections on each channel separately.
29 | """
30 | super().__init__()
31 | self.in_channels = in_channels
32 | self.out_channels = out_channels
33 | self.hidden_channels = (in_channels
34 | if hidden_channels is None else
35 | hidden_channels)
36 | self.non_linearity = non_linearity
37 | Conv = getattr(nn, f'Conv{n_dim}d')
38 |
39 | self.permutation_invariant = permutation_invariant
40 |
41 | self.fc1 = Conv(in_channels, hidden_channels, 1)
42 | self.norm = nn.InstanceNorm2d(hidden_channels, affine=True)
43 | self.fc2 = Conv(hidden_channels, out_channels, 1)
44 |
45 | def forward(self, x):
46 | batch = x.shape[0]
47 | if self.permutation_invariant:
48 | assert x.shape[1] % self.in_channels == 0, \
49 | "Total Number of Channels is not divisible by number of tokens"
50 | x = rearrange(x, 'b (g c) h w -> (b g) c h w', c=self.in_channels)
51 |
52 | x = self.fc1(x)
53 | x = self.norm(x)
54 | x = self.non_linearity(x)
55 | x = self.fc2(x)
56 | if self.permutation_invariant:
57 | x = rearrange(x, '(b g) c h w -> b (g c) h w', b=batch)
58 | return x
59 |
60 |
61 | class VariableEncodingArgs(NamedTuple):
62 | basis: Literal["sht", "fft"]
63 | n_channels: int
64 | """Number of extra encoding channels per variable."""
65 | modes_x: int
66 | modes_y: int
67 | modes_t: Optional[int] = None
68 |
69 |
70 | class CodANO(nn.Module):
71 | """
72 | Parameters
73 | ---
74 |     input_token_codimension : input token codimension/number of channels per input token
75 |     out_token_codim=None : output token codimension/number of channels per output token
76 |     hidden_token_codim=None :
77 |     lifting_token_codim=None :
78 |     var_encoding=False : boolean
79 |         if true, adds a variable encoding to each channel
80 |     var_num=None : denotes the number of variables
81 |     var_enco_basis='sht' : specifies the basis function for variable encodings
82 |     var_enco_channels=1 : number of channels for each variable encoding
83 |     var_enco_mode_x=50 : number of x modes for each variable encoding
84 |     var_enco_mode_y=50 : number of y modes for each variable encoding
85 | enable_cls_token=False : if true, learnable cls token will be added
86 | static_channels_num=0 :
87 | Number of static channels to be concatenated (xy grid, land/sea mask etc)
88 | static_features=None :
89 | The static feature (it will be taken from the Preprocessor while
90 | initializing the model)
91 | integral_operator_top :
92 | Required for the re-grid operation (for example: from equiangular to LG grid.)
93 | integral_operator_bottom :
94 | Required for the re-grid operation (for example: from LG grid to equiangular)
95 | """
96 |
97 | def __init__(
98 | self,
99 | input_token_codimension,
100 | output_token_codimension=None,
101 | hidden_token_codimension=None,
102 | lifting_token_codimension=None,
103 | n_layers=4,
104 | n_modes=None,
105 | max_n_modes=None,
106 | scalings=None,
107 | n_heads=1,
108 | non_linearity=F.gelu,
109 | layer_kwargs={'use_mlp': False,
110 | 'mlp_dropout': 0,
111 | 'mlp_expansion': 1.0,
112 | 'non_linearity': F.gelu,
113 | 'norm': None,
114 | 'preactivation': False,
115 | 'fno_skip': 'linear',
116 | 'horizontal_skip': 'linear',
117 | 'mlp_skip': 'linear',
118 | 'separable': False,
119 | 'factorization': None,
120 | 'rank': 1.0,
121 | 'fft_norm': 'forward',
122 |                       'joint_factorization': False,
123 |                       'fixed_rank_modes': False,
124 |                       'implementation': 'factorized',
125 |                       'decomposition_kwargs': dict(),
126 |                       'normalizer': False,
127 |                       },
128 | per_channel_attention=True,
129 | operator_block=CodanoBlockND,
130 | integral_operator=SpectralConvKernel2d,
131 | integral_operator_top=partial(
132 | SpectralConvKernel2d, sht_grid="legendre-gauss"),
133 | integral_operator_bottom=partial(
134 | SpectralConvKernel2d, isht_grid="legendre-gauss"),
135 | projection=True,
136 | lifting=True,
137 | domain_padding=0.5,
138 | domain_padding_mode='one-sided',
139 | n_variables=None,
140 | variable_encoding_args: VariableEncodingArgs = None,
141 | enable_cls_token=False,
142 | logger=None,
143 | ):
144 | super().__init__()
145 | self.n_layers = n_layers
146 |         assert len(
147 |             n_modes) == n_layers, "number of modes must be given for each layer"
148 |         assert len(n_heads) == n_layers, \
149 |             "number of attention heads must be given for each layer"
150 | if integral_operator_bottom is None:
151 | integral_operator_bottom = integral_operator
152 | if integral_operator_top is None:
153 | integral_operator_top = integral_operator
154 | self.n_dim = len(n_modes[0])
155 | self.input_token_codimension = input_token_codimension
156 | # self.n_variables = n_variables
157 | if hidden_token_codimension is None:
158 | hidden_token_codimension = input_token_codimension
159 | if lifting_token_codimension is None:
160 | lifting_token_codimension = input_token_codimension
161 | if output_token_codimension is None:
162 | output_token_codimension = input_token_codimension
163 |
164 | self.hidden_token_codimension = hidden_token_codimension
165 | self.n_modes = n_modes
166 | self.max_n_modes = max_n_modes
167 | self.scalings = scalings
168 | self.non_linearity = non_linearity
169 | self.n_heads = n_heads
170 | self.integral_operator = integral_operator
171 | self.lifting = lifting
172 | self.projection = projection
173 | self.num_dims = len(n_modes[0])
174 | self.enable_cls_token = enable_cls_token
175 |
176 | if logger is None:
177 | logger = logging.getLogger()
178 | self.logger = logger
179 |
180 | self.layer_kwargs = layer_kwargs
181 | if layer_kwargs is None:
182 | self.layer_kwargs = {
183 | 'incremental_n_modes': None,
184 | 'use_mlp': False,
185 | 'mlp_dropout': 0,
186 | 'mlp_expansion': 1.0,
187 | 'non_linearity': F.gelu,
188 | 'norm': None,
189 | 'preactivation': False,
190 | 'fno_skip': 'linear',
191 | 'horizontal_skip': 'linear',
192 | 'mlp_skip': 'linear',
193 | 'separable': False,
194 | 'factorization': None,
195 | 'rank': 1.0,
196 | 'fft_norm': 'forward',
197 | 'normalizer': 'instance_norm',
198 | 'joint_factorization': False,
199 | 'fixed_rank_modes': False,
200 | 'implementation': 'factorized',
201 | 'decomposition_kwargs': None,
202 | }
203 |
204 | # self.n_static_channels = n_static_channels
205 | """The number of static channels for all variable channels."""
206 |
207 | # calculating scaling
208 | if self.scalings is not None:
209 | self.end_to_end_scaling = self.get_output_scaling_factor(
210 | np.ones_like(self.scalings[0]),
211 | self.scalings
212 | )
213 | else:
214 | self.end_to_end_scaling = 1
215 | self.logger.debug(f"{self.end_to_end_scaling=}")
216 | if isinstance(self.end_to_end_scaling, (float, int)):
217 | self.end_to_end_scaling = [self.end_to_end_scaling] * self.n_dim
218 |
219 | # Setting up domain padding for encoder and reconstructor
220 | if domain_padding is not None and domain_padding > 0:
221 | self.domain_padding = DomainPadding(
222 | domain_padding=domain_padding,
223 | padding_mode=domain_padding_mode,
224 | output_scaling_factor=self.end_to_end_scaling,
225 | )
226 | else:
227 | self.domain_padding = None
228 | self.domain_padding_mode = domain_padding_mode
229 |
230 |         # A variable + its variable encoding + the static channel(s)
231 | # together constitute a token
232 | # n_lifted_channels = self.input_token_codimension + \
233 | # variable_encoding_args.n_channels + \
234 | # self.n_static_channels
235 | if self.lifting:
236 | self.lifting = PermEqProjection(
237 | in_channels=input_token_codimension,
238 | out_channels=hidden_token_codimension,
239 | hidden_channels=lifting_token_codimension,
240 | n_dim=self.n_dim,
241 | non_linearity=self.non_linearity,
242 |                 permutation_invariant=True,
243 | )
244 | # elif self.use_variable_encoding:
245 | # hidden_token_codimension = n_lifted_channels
246 |
247 | cls_dimension = 1 if enable_cls_token else 0
248 | self.codimension_size = hidden_token_codimension * n_variables + cls_dimension
249 |
250 | self.logger.debug(
251 | f"Expected number of channels: {self.codimension_size=}")
252 |
253 | self.base = nn.ModuleList([])
254 | for i in range(self.n_layers):
255 | if i == 0 and self.n_layers != 1:
256 | conv_op = integral_operator_top
257 | elif i == self.n_layers - 1 and self.n_layers != 1:
258 | conv_op = integral_operator_bottom
259 | else:
260 | conv_op = self.integral_operator
261 |
262 | self.base.append(
263 | operator_block(
264 | n_modes=self.n_modes[i],
265 | max_n_modes=self.max_n_modes[i],
266 | n_head=self.n_heads[i],
267 | token_codim=hidden_token_codimension,
268 | output_scaling_factor=[self.scalings[i]],
269 | SpectralConvolution=conv_op,
270 | codim_size=self.codimension_size,
271 | per_channel_attention=per_channel_attention,
272 | num_dims=self.num_dims,
273 | logger=self.logger.getChild(f"base[{i}]"),
274 | **self.layer_kwargs,
275 | )
276 | )
277 |
278 | if self.projection:
279 | self.projection = PermEqProjection(
280 | in_channels=hidden_token_codimension,
281 | out_channels=output_token_codimension,
282 | hidden_channels=lifting_token_codimension,
283 | n_dim=self.n_dim,
284 | non_linearity=self.non_linearity,
285 |                 permutation_invariant=True,
286 | )
287 |
288 | if enable_cls_token:
289 | self.cls_token = VariableEncoding2d(
290 | 1,
291 | hidden_token_codimension,
292 | (variable_encoding_args.modes_x,
293 | variable_encoding_args.modes_y),
294 | basis=variable_encoding_args.basis)
295 |
296 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer):
297 | for k in scalings_per_layer:
298 | initial_scale = np.multiply(initial_scale, k)
299 | initial_scale = initial_scale.tolist()
300 | if len(initial_scale) == 1:
301 | initial_scale = initial_scale[0]
302 | return initial_scale
303 |
304 | def get_device(self,):
305 | return self.cls_token.coefficients_r.device
306 |
307 | def forward(self, x: torch.Tensor):
308 | if self.lifting:
309 | x = self.lifting(x)
310 |
311 | if self.enable_cls_token:
312 | cls_token = self.cls_token(x).unsqueeze(0)
313 | repeat_shape = [1 for _ in x.shape]
314 | repeat_shape[0] = x.shape[0]
315 | x = torch.cat(
316 | [
317 | cls_token.repeat(*repeat_shape),
318 | x,
319 | ],
320 | dim=1,
321 | )
322 |
323 | if self.domain_padding is not None:
324 | x = self.domain_padding.pad(x)
325 |
326 | output_shape_en = [round(i * j) for (i,
327 | j) in zip(x.shape[-self.n_dim:],
328 | self.end_to_end_scaling)]
329 |
330 | cur_output_shape = None
331 | for layer_idx in range(self.n_layers):
332 | if layer_idx == self.n_layers - 1:
333 | cur_output_shape = output_shape_en
334 | x = self.base[layer_idx](x, output_shape=cur_output_shape)
335 | # self.logger.debug(f"{x.shape} (block[{layer_idx}])")
336 |
337 | if self.domain_padding is not None:
338 | x = self.domain_padding.unpad(x)
339 |
340 | if self.projection:
341 | x = self.projection(x)
342 | # self.logger.debug(f"{x.shape} (projection)")
343 |
344 | return x
345 |
346 |
347 | class CoDANOTemporal:
348 | def __call__(self, x):
349 | pass
350 |
--------------------------------------------------------------------------------
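The `end_to_end_scaling` used throughout `CodANO` is just the elementwise product of the per-layer scaling factors. A standalone sketch with made-up scalings for a four-layer stack that halves and then restores the resolution:

```python
import numpy as np

# Elementwise product of per-layer scalings, mirroring
# get_output_scaling_factor (toy values).
scalings = [[0.5, 0.5], [1.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
scale = np.ones_like(scalings[0])  # start from 1 in every dimension
for k in scalings:
    scale = np.multiply(scale, k)
print(scale.tolist())  # [1.0, 1.0] -> output resolution equals input
```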
/models/codano_gino.py:
--------------------------------------------------------------------------------
1 | from .codano import CodANO
2 | from layers.gno_layer import GnoPremEq
3 | from layers.codano_block_2D import CodanoBlocks2d
4 | from layers.fino_2D import SpectralConvKernel2d
5 | from functools import partial
6 | from einops import rearrange
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from layers.regrider import Regird
10 | from neuralop.layers.padding import DomainPadding
11 | import numpy as np
12 | import torch
13 | from layers.variable_encoding import VariableEncoding2d
14 |
15 | # TODO replace with neuralop.MLP module
16 |
17 |
18 | class Projection(nn.Module):
19 | def __init__(
20 | self,
21 | in_channels,
22 | out_channels,
23 | hidden_channels=None,
24 | n_dim=2,
25 | non_linearity=F.gelu,
26 | permutation_invariant=False,
27 | ):
28 | """Permutation invariant projection layer.
29 |
30 | Performs linear projections on each channel separately.
31 | """
32 | super().__init__()
33 | self.in_channels = in_channels
34 | self.out_channels = out_channels
35 | self.hidden_channels = (in_channels
36 | if hidden_channels is None else
37 | hidden_channels)
38 | self.non_linearity = non_linearity
39 | Conv = getattr(nn, f'Conv{n_dim}d')
40 |
41 | self.permutation_invariant = permutation_invariant
42 |
43 | self.fc1 = Conv(in_channels, hidden_channels, 1)
44 | self.norm = nn.InstanceNorm2d(hidden_channels, affine=True)
45 | self.fc2 = Conv(hidden_channels, out_channels, 1)
46 |
47 | def forward(self, x):
48 | batch = x.shape[0]
49 | if self.permutation_invariant:
50 | assert x.shape[1] % self.in_channels == 0, \
51 | "Total Number of Channels is not divisible by number of tokens"
52 | x = rearrange(x, 'b (g c) h w -> (b g) c h w', c=self.in_channels)
53 |
54 | x = self.fc1(x)
55 | x = self.norm(x)
56 | x = self.non_linearity(x)
57 | x = self.fc2(x)
58 | if self.permutation_invariant:
59 | x = rearrange(x, '(b g) c h w -> b (g c) h w', b=batch)
60 | return x
61 |
62 |
63 | class CondnoGino(nn.Module):
64 | def __init__(self,
65 | in_token_codim,
66 | input_grid,
67 | output_grid=None,
68 | grid_size=None,
69 | radius=None,
70 | n_neigbor=10,
71 | fixed_neighbour=False,
72 | out_token_codim=None,
73 | hidden_token_codim=None,
74 | lifting_token_codim=None,
75 | kqv_non_linear=False,
76 | n_layers=4,
77 | n_modes=None,
78 | scalings=None,
79 | n_heads=1,
80 | layer_kwargs={'incremental_n_modes': None,
81 | 'use_mlp': False,
82 | 'mlp_dropout': 0,
83 | 'mlp_expansion': 1.0,
84 | 'non_linearity': torch.sin,
85 | 'norm': None, 'preactivation': False,
86 | 'fno_skip': 'linear',
87 | 'horizontal_skip': 'linear',
88 | 'mlp_skip': 'linear',
89 | 'separable': False,
90 | 'factorization': None,
91 | 'rank': 1.0,
92 | 'fft_norm': 'forward',
93 |                       'joint_factorization': False,
94 |                       'fixed_rank_modes': False,
95 |                       'implementation': 'factorized',
96 |                       'decomposition_kwargs': dict(),
97 |                       'normalizer': False,
98 |                       },
99 | operator_block=CodanoBlocks2d,
100 | per_channel_attention=False,
101 | integral_operator=SpectralConvKernel2d,
102 | integral_operator_top=None,
103 | integral_operator_bottom=None,
104 | re_grid_input=False,
105 | re_grid_output=False,
106 | projection=True,
107 | gno_mlp_layers=None,
108 | lifting=True,
109 | domain_padding=None,
110 | domain_padding_mode='one-sided',
111 | var_encoding=False,
112 |                  var_num=None,  # denotes the number of variables
113 | var_enco_basis='fft',
114 | var_enco_channels=1,
115 | var_enco_mode_x=20,
116 | var_enco_mode_y=40,
117 | enable_cls_token=False,
118 | static_channels_num=0,
119 | static_features=None,
120 | ):
121 | super().__init__()
122 | self.n_layers = n_layers
123 |         assert len(
124 |             n_modes) == n_layers, "number of modes must be given for each layer"
125 |         assert len(
126 |             n_heads) == n_layers, "number of attention heads must be given for each layer"
127 | if output_grid is None:
128 | output_grid = input_grid.clone()
129 | if integral_operator_bottom is None:
130 | integral_operator_bottom = integral_operator
131 | if integral_operator_top is None:
132 | integral_operator_top = integral_operator
133 | self.n_dim = len(n_modes[0])
134 | self.in_token_codim = in_token_codim
135 | self.var_num = var_num
136 | if hidden_token_codim is None:
137 | hidden_token_codim = in_token_codim
138 | if lifting_token_codim is None:
139 | lifting_token_codim = in_token_codim
140 | if out_token_codim is None:
141 | out_token_codim = in_token_codim
142 | self.re_grid_input = re_grid_input
143 | self.re_grid_output = re_grid_output
144 |
145 | if self.re_grid_input:
146 | self.input_regrider = Regird("equiangular", "legendre-gauss")
147 | if self.re_grid_output:
148 | self.output_regrider = Regird("legendre-gauss", "equiangular")
149 |
150 | self.input_grid = input_grid
151 | self.output_grid = output_grid
152 | self.grid_size = grid_size
153 |
154 | self.hidden_token_codim = hidden_token_codim
155 | self.n_modes = n_modes
156 | self.scalings = scalings
157 | self.var_enco_channels = var_enco_channels
158 | self.n_heads = n_heads
159 | self.integral_operator = integral_operator
160 | self.layer_kwargs = layer_kwargs
161 | self.operator_block = operator_block
162 | self.lifting = lifting
163 | self.projection = projection
164 |
165 | self.radius = radius
166 | self.n_neigbor = n_neigbor
167 | self.fixed_neighbour = fixed_neighbour
168 | self.gno_mlp_layers = gno_mlp_layers
169 | self.per_channel_attention = per_channel_attention
170 |
171 | self.register_buffer("static_features", static_features)
172 | self.static_channels_num = static_channels_num
173 | # calculating scaling
174 | if self.scalings is not None:
175 | self.end_to_end_scaling = self.get_output_scaling_factor(
176 | np.ones_like(self.scalings[0]), self.scalings)
177 | print("End to End Scaling", self.end_to_end_scaling)
178 | else:
179 | self.end_to_end_scaling = 1
180 | if isinstance(self.end_to_end_scaling, (float, int)):
181 | self.end_to_end_scaling = [self.end_to_end_scaling] * self.n_dim
182 |
183 | # Setting up domain padding for encoder and reconstructor
184 |
185 | if domain_padding is not None and domain_padding > 0:
186 | self.domain_padding = DomainPadding(
187 | domain_padding=domain_padding,
188 | padding_mode=domain_padding_mode,
189 | output_scaling_factor=self.end_to_end_scaling)
190 | else:
191 | self.domain_padding = None
192 | self.domain_padding_mode = domain_padding_mode
193 |
194 |         # Code for variable encoding
195 |
196 | # initializing Components
197 | if self.lifting:
198 |             print('Using Lifting Layer')
199 |
200 |             # a variable + its variable encoding + the static channel(s)
201 |             # together constitute a token
202 |
203 | self.lifting = GnoPremEq(
204 | var_num=var_num,
205 | in_dim=self.in_token_codim,
206 | out_dim=hidden_token_codim,
207 | input_grid=self.input_grid,
208 | output_grid=self.output_grid,
209 | projection_hidden_dim=lifting_token_codim,
210 | mlp_layers=self.gno_mlp_layers,
211 | radius=self.radius,
212 | n_neigbor=n_neigbor,
213 | fixed_neighbour=fixed_neighbour,
214 | var_encoding=var_encoding,
215 | var_encoding_channels=var_enco_channels)
216 |
217 | elif var_encoding:
218 | hidden_token_codim = self.in_token_codim + \
219 | var_enco_channels + self.static_channels_num
220 |
221 | if enable_cls_token:
222 | count = 1
223 | else:
224 | count = 0
225 |         self.codim_size = hidden_token_codim * \
226 |             (var_num + count)  # `count` accounts for the optional CLS token
227 |
228 | print("expected number of channels", self.codim_size)
229 |
230 | self.base = nn.ModuleList([])
231 | for i in range(self.n_layers):
232 | if i == 0 and self.n_layers != 1:
233 | conv_op = integral_operator_top
234 | elif i == self.n_layers - 1 and self.n_layers != 1:
235 | conv_op = integral_operator_bottom
236 | else:
237 | conv_op = self.integral_operator
238 |
239 | self.base.append(self.operator_block(
240 | n_modes=self.n_modes[i],
241 | n_head=self.n_heads[i],
242 | token_codimension=hidden_token_codim,
243 | output_scaling_factor=[self.scalings[i]],
244 | SpectralConvolution=conv_op,
245 | codimension_size=self.codim_size,
246 | per_channel_attention=self.per_channel_attention,
247 | kqv_non_linear=kqv_non_linear,
248 | **self.layer_kwargs))
249 | if self.projection:
250 | print("Using Projection Layer")
251 | self.projection = GnoPremEq(
252 | var_num=var_num,
253 | in_dim=self.hidden_token_codim,
254 | out_dim=self.hidden_token_codim,
255 | input_grid=self.output_grid,
256 | output_grid=self.input_grid,
257 | mlp_layers=self.gno_mlp_layers,
258 | radius=self.radius,
259 | n_neigbor=n_neigbor,
260 | fixed_neighbour=fixed_neighbour,
261 | var_encoding=False,
262 | projection_hidden_dim=lifting_token_codim,
263 | var_encoding_channels=0,
264 | end_projection=True,
265 | end_projection_outdim=out_token_codim)
266 |
267 |         # Code for variable encoding
268 |
269 | self.enable_cls_token = enable_cls_token
270 | if enable_cls_token:
271 |             print("initializing CLS token")
272 | self.cls_token = VariableEncoding2d(
273 | 1, hidden_token_codim, (var_enco_mode_x, var_enco_mode_y), basis=var_enco_basis)
274 |
275 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer):
276 | for k in scalings_per_layer:
277 | initial_scale = np.multiply(initial_scale, k)
278 | initial_scale = initial_scale.tolist()
279 | if len(initial_scale) == 1:
280 | initial_scale = initial_scale[0]
281 | return initial_scale
282 |
283 | def get_device(self,):
284 | return self.cls_token.coefficients_r.device
285 |
286 | def forward(self, inp):
287 | '''
288 | inp = (batch_size, n_points, in_dims/Channels)
289 |         currently only batch_size = 1
290 | '''
291 | inp = inp[0, :, :]
292 | inp = inp[None, ...]
293 | if self.re_grid_input:
294 | inp = self.input_regrider(inp)
295 | if self.lifting:
296 | x = self.lifting(inp)
297 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0])
298 | else:
299 | x = inp
300 |
301 | if self.enable_cls_token:
302 | cls_token = self.cls_token(x)
303 | x = torch.cat([cls_token[None, :, :, :].repeat(
304 | x.shape[0], 1, 1, 1), x], dim=1)
305 |
306 | if self.domain_padding is not None:
307 | x = self.domain_padding.pad(x)
308 |
309 | output_shape_en = [int(round(i * j)) for (i,
310 | j) in zip(x.shape[-self.n_dim:],
311 | self.end_to_end_scaling)]
312 |
313 | cur_output_shape = None
314 | for layer_idx in range(self.n_layers):
315 | if layer_idx == self.n_layers - 1:
316 | cur_output_shape = output_shape_en
317 | x = self.base[layer_idx](x, output_shape=cur_output_shape)
318 |
319 | if self.re_grid_output:
320 | x = self.output_regrider(x)
321 | if self.projection:
322 | x = rearrange(x, 'b c h w -> b (h w) c')
323 | x = self.projection(x)
324 |
325 | return x
326 |
--------------------------------------------------------------------------------
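The hand-off between the GNO lifting and the spectral blocks in `CondnoGino.forward` hinges on a single `rearrange`: point features on the regular latent grid become an image-like tensor, and the projection reverses it. A toy-shape sketch:

```python
import torch
from einops import rearrange

# GNO lifting yields (batch, n_latent_points, channels);
# the CoDA-NO blocks expect (batch, channels, H, W).
h, w, c = 32, 16, 8
x = torch.randn(1, h * w, c)
x = rearrange(x, 'b (h w) c -> b c h w', h=h)
print(x.shape)  # torch.Size([1, 8, 32, 16])
x = rearrange(x, 'b c h w -> b (h w) c')  # back to a point cloud
print(x.shape)  # torch.Size([1, 512, 8])
```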
/models/deeponet.py:
--------------------------------------------------------------------------------
1 | from layers.gnn_layer import GnnLayer
2 | from neuralop.layers.embeddings import PositionalEmbedding
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch
6 |
7 |
8 |
9 | class DeepONet(nn.Module):
10 | def __init__(self,
11 | in_dim,
12 | out_dim,
13 | input_grid,
14 | output_grid=None,
15 | branch_layers=[128],
16 | trunk_layers=[128],
17 | initial_mesh=None,
18 | positional_encoding_dim=8,
19 | n_neigbor=10,
20 | gno_mlp_layers=None,
21 | ):
22 | super().__init__()
23 | if output_grid is None:
24 | output_grid = input_grid.clone()
25 | self.n_dim = input_grid.shape[-1]
26 | self.n_neigbor = n_neigbor
27 | self.gno_mlp_layers = gno_mlp_layers
28 | self.in_dim = in_dim
29 | print("in_dim", in_dim)
30 | if out_dim is None:
31 | out_dim = in_dim
32 | self.out_dim = out_dim
33 | self.positional_encoding_dim = positional_encoding_dim
34 | self.input_grid = input_grid
35 | self.output_grid = output_grid
36 | self.initial_mesh = initial_mesh
37 | self.branch_layers = branch_layers
38 | self.trunk_layers = trunk_layers
39 | self.gnn = None
40 | self.branch = self.get_branch()
41 | self.trunk = self.get_trunk()
42 | self.PE = PositionalEmbedding(positional_encoding_dim)
43 | self.bias = nn.Parameter(torch.zeros(out_dim))
44 |
45 |         # Code for variable encoding
46 |
47 | def get_branch(self, ):
48 | dim1 = self.in_dim + self.n_dim * self.positional_encoding_dim
49 | # self.branch_layers = [dim1] + self.branch_layers
50 | self.gnn = GnnLayer(
51 | dim1,
52 | self.branch_layers[0],
53 | self.initial_mesh,
54 | self.initial_mesh,
55 | self.gno_mlp_layers,
56 | self.branch_layers[0],
57 | self.n_neigbor)
58 | self.layer_norm = nn.LayerNorm(self.branch_layers[0])
59 | layers = []
60 |
61 | self.branch_layers[0] = self.branch_layers[0] * \
62 | self.input_grid.shape[0]
63 | for i in range(len(self.branch_layers) - 1):
64 | layers.append(
65 | nn.Linear(self.branch_layers[i], self.branch_layers[i + 1]))
66 | torch.nn.init.xavier_normal_(layers[-1].weight)
67 | if i != len(self.branch_layers) - 2:
68 | # layers.append(nn.LayerNorm(self.branch_layers[i+1]))
69 | layers.append(nn.ReLU())
70 | return nn.Sequential(*layers)
71 |
72 | def get_trunk(self, ):
73 | dim1 = self.n_dim + self.positional_encoding_dim * self.n_dim
74 | self.trunk_layers = [dim1] + self.trunk_layers
75 | self.trunk_layers[-1] = self.trunk_layers[-1] * self.out_dim
76 | layers = []
77 | for i in range(len(self.trunk_layers) - 1):
78 | layers.append(
79 | nn.Linear(self.trunk_layers[i], self.trunk_layers[i + 1]))
80 | torch.nn.init.xavier_normal_(layers[-1].weight)
81 | # if i != len(self.trunk_layers) - 2:
82 | # layers.append(nn.LayerNorm(self.trunk_layers[i+1]))
83 | layers.append(nn.ReLU())
84 | return nn.Sequential(*layers)
85 |
86 | def get_pe(self, grid):
87 | pe = self.PE(grid.reshape(-1))
88 | pe = pe.reshape(grid.shape[0], -1)
89 | return pe
90 |
91 | def forward(
92 | self,
93 | inp,
94 | out_grid_displacement=None,
95 | in_grid_displacement=None):
96 | '''
97 | inp = (batch_size, n_points, in_dims/Channels)
98 |         currently only batch_size = 1
99 | '''
100 |
101 | if out_grid_displacement is not None:
102 | with torch.no_grad():
103 | in_grid = self.initial_mesh + in_grid_displacement
104 | out_grid = self.initial_mesh + out_grid_displacement
105 | self.gnn.update_grid(in_grid.clone(), in_grid.clone())
106 |
107 | in_pe = self.get_pe(in_grid)
108 | in_data = torch.cat([inp, in_pe.unsqueeze(0)], dim=-1)
109 |
110 |         bout = self.gnn(in_data[0])[None, ...]  # (batch, dim)
111 |
112 | bout = self.layer_norm(bout)
113 |
114 | bout = self.branch(bout.reshape(inp.shape[0], -1))
115 |
116 | bout = bout / np.sqrt(self.branch_layers[-1])
117 |
118 | pe = self.get_pe(out_grid) # self.PE(out_grid.reshape(-1))
119 | # pe = pe.reshape(out_grid.shape[0], -1)
120 | grid_pe = torch.cat([out_grid, pe], axis=1)
121 |
122 | tout = self.trunk(grid_pe) # (ngrid, dim * out_dim)
123 | # (ngrid, out_dim, dim)
124 | tout = tout.reshape(out_grid.shape[0], self.out_dim, -1)
125 |
126 | out = torch.einsum('bd,ncd->bnc', bout, tout)
127 |
128 | return out + self.bias
129 |
--------------------------------------------------------------------------------
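The branch-trunk combination at the end of `DeepONet.forward` is one contraction over the latent dimension: the branch supplies a latent vector per batch element, the trunk a `(out_dim, latent)` matrix per query point. A toy-shape sketch of that einsum:

```python
import torch

batch, n_points, out_dim, latent = 1, 100, 3, 128
bout = torch.randn(batch, latent)              # branch output
tout = torch.randn(n_points, out_dim, latent)  # trunk output
out = torch.einsum('bd,ncd->bnc', bout, tout)
print(out.shape)  # torch.Size([1, 100, 3])
```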
/models/fno_gino.py:
--------------------------------------------------------------------------------
1 | from layers.gno_layer import GNO
2 | from layers.fino_2D import SpectralConvKernel2d
3 | from einops import rearrange
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from layers.regrider import Regird
7 | from neuralop.layers.padding import DomainPadding
8 | from neuralop.layers.fno_block import FNOBlocks
9 | import numpy as np
10 | import torch
11 | from layers.variable_encoding import VariableEncoding2d
12 |
13 |
14 | class FnoGno(nn.Module):
15 | def __init__(self,
16 | in_dim,
17 | out_dim,
18 | input_grid,
19 | output_grid=None,
20 | grid_size=None,
21 | radius=None,
22 | fixed_neighbour=False,
23 | n_neigbor=10,
24 | hidden_dim=None,
25 | lifting_dim=None,
26 | n_layers=4,
27 | max_n_modes=None,
28 | n_modes=None,
29 | scalings=None,
30 | initial_mesh=None,
31 | non_linearity=F.gelu,
32 | layer_kwargs={'use_mlp': False,
33 | 'mlp_dropout': 0,
34 | 'mlp_expansion': 1.0,
35 | 'non_linearity': F.gelu,
36 | 'norm': None, 'preactivation': False,
37 | 'fno_skip': 'linear',
38 | 'horizontal_skip': 'linear',
39 | 'mlp_skip': 'linear',
40 | 'separable': False,
41 | 'factorization': None,
42 | 'rank': 1.0,
43 | 'fft_norm': 'forward',
44 |                'joint_factorization': False,
45 |                'fixed_rank_modes': False,
46 |                'implementation': 'factorized',
47 |                'decomposition_kwargs': dict(),
48 |                'normalizer': False,
49 |                },
50 | operator_block=FNOBlocks,
51 | integral_operator=SpectralConvKernel2d,
52 | integral_operator_top=None,
53 | integral_operator_bottom=None,
54 | re_grid_input=False,
55 | re_grid_output=False,
56 | projection=True,
57 | gno_mlp_layers=None,
58 | lifting=True,
59 | domain_padding=None,
60 | domain_padding_mode='one-sided',
61 | ):
62 | super().__init__()
63 | self.n_layers = n_layers
64 |         assert len(
65 |             n_modes) == n_layers, "number of modes must be given for each layer"
66 | if output_grid is None:
67 | output_grid = input_grid.clone()
68 | if integral_operator_bottom is None:
69 | integral_operator_bottom = integral_operator
70 | if integral_operator_top is None:
71 | integral_operator_top = integral_operator
72 | self.n_dim = len(max_n_modes[0])
73 | self.in_dim = in_dim
74 | if hidden_dim is None:
75 | hidden_dim = in_dim
76 | if lifting_dim is None:
77 | lifting_dim = in_dim
78 | if out_dim is None:
79 | out_dim = in_dim
80 | self.re_grid_input = re_grid_input
81 | self.re_grid_output = re_grid_output
82 |
83 | if self.re_grid_input:
84 | self.input_regrider = Regird("equiangular", "legendre-gauss")
85 | if self.re_grid_output:
86 | self.output_regrider = Regird("legendre-gauss", "equiangular")
87 |
88 | self.input_grid = input_grid
89 | self.output_grid = output_grid
90 | self.grid_size = grid_size
91 |
92 | self.hidden_dim = hidden_dim
93 | self.n_modes = n_modes
94 | self.max_n_modes = max_n_modes
95 | self.scalings = scalings
96 | self.integral_operator = integral_operator
97 | self.layer_kwargs = layer_kwargs
98 | self.operator_block = operator_block
99 | self.lifting = lifting
100 | self.projection = projection
101 | self.radius = radius
102 | self.fixed_neighbour = fixed_neighbour
103 | self.n_neigbor = n_neigbor
104 | self.gno_mlp_layers = gno_mlp_layers
105 |
106 | # calculating scaling
107 | if self.scalings is not None:
108 | self.end_to_end_scaling = self.get_output_scaling_factor(
109 | np.ones_like(self.scalings[0]), self.scalings)
110 | print("End to End Scaling", self.end_to_end_scaling)
111 | else:
112 | self.end_to_end_scaling = 1
113 | if isinstance(self.end_to_end_scaling, (float, int)):
114 | self.end_to_end_scaling = [self.end_to_end_scaling] * self.n_dim
115 |
116 | # Setting up domain padding for encoder and reconstructor
117 |
118 | if domain_padding is not None and domain_padding > 0:
119 | self.domain_padding = DomainPadding(
120 | domain_padding=domain_padding,
121 | padding_mode=domain_padding_mode,
122 | output_scaling_factor=self.end_to_end_scaling)
123 | else:
124 | self.domain_padding = None
125 | self.domain_padding_mode = domain_padding_mode
126 | self.initial_mesh = initial_mesh
127 |         # Code for variable encoding
128 |
129 | # initializing Components
130 | if self.lifting:
131 |             print('Using Lifting Layer')
132 |
133 |             # a variable + its variable encoding + the static channel(s)
134 |             # together constitute a token
135 |
136 | self.lifting = GNO(
137 | in_dim=self.in_dim,
138 | out_dim=hidden_dim,
139 | input_grid=self.input_grid,
140 | output_grid=self.output_grid,
141 | projection_hidden_dim=lifting_dim,
142 | mlp_layers=self.gno_mlp_layers,
143 | radius=self.radius,
144 | fixed_neighbour=self.fixed_neighbour,
145 | n_neigbor=self.n_neigbor)
146 |
147 | self.base = nn.ModuleList([])
148 | for i in range(self.n_layers):
149 | if i == 0 and self.n_layers != 1:
150 | conv_op = integral_operator_top
151 | elif i == self.n_layers - 1 and self.n_layers != 1:
152 | conv_op = integral_operator_bottom
153 | else:
154 | conv_op = self.integral_operator
155 |
156 | self.base.append(self.operator_block(
157 | hidden_dim,
158 | hidden_dim,
159 | max_n_modes=self.max_n_modes[i],
160 | n_modes=self.n_modes[i],
161 | output_scaling_factor=[self.scalings[i]],
162 | SpectralConv=conv_op,
163 | **self.layer_kwargs))
164 | if self.projection:
165 | # input and output grid is swapped
166 |
167 | print("Using Projection Layer")
168 | self.projection = GNO(
169 | in_dim=self.hidden_dim,
170 | out_dim=out_dim,
171 | input_grid=self.output_grid,
172 | projection_hidden_dim=lifting_dim,
173 | output_grid=self.input_grid,
174 | mlp_layers=self.gno_mlp_layers,
175 | radius=self.radius,
176 | fixed_neighbour=self.fixed_neighbour,
177 | n_neigbor=self.n_neigbor)
178 |
179 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer):
180 | for k in scalings_per_layer:
181 | initial_scale = np.multiply(initial_scale, k)
182 | initial_scale = initial_scale.tolist()
183 | if len(initial_scale) == 1:
184 | initial_scale = initial_scale[0]
185 | return initial_scale
186 |
187 | def get_device(self,):
188 | return self.cls_token.coefficients_r.device
189 |
190 | def forward(
191 | self,
192 | inp,
193 | out_grid_displacement=None,
194 | in_grid_displacement=None):
195 | '''
196 | inp = (batch_size, n_points, in_dims/Channels)
197 |         currently only batch_size = 1
198 | '''
199 | if out_grid_displacement is not None:
200 | with torch.no_grad():
201 | self.lifting.update_grid(
202 | self.initial_mesh + in_grid_displacement, None)
203 | self.projection.update_grid(
204 | None, self.initial_mesh + out_grid_displacement)
205 |
206 | if self.re_grid_input:
207 | inp = self.input_regrider(inp)
208 | if self.lifting:
209 | # print("In Lifting")
210 | x = self.lifting(inp)
211 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0])
212 | else:
213 | x = inp
214 |
215 | if self.domain_padding is not None:
216 | x = self.domain_padding.pad(x)
217 |
218 | output_shape_en = [int(round(i * j)) for (i,
219 | j) in zip(x.shape[-self.n_dim:],
220 | self.end_to_end_scaling)]
221 |
222 | cur_output_shape = None
223 | for layer_idx in range(self.n_layers):
224 | if layer_idx == self.n_layers - 1:
225 | cur_output_shape = output_shape_en
226 | x = self.base[layer_idx](x, output_shape=cur_output_shape)
227 |
228 | if self.re_grid_output:
229 | x = self.output_regrider(x)
230 | if self.projection:
231 | # print("projection")
232 | x = rearrange(x, 'b c h w -> b (h w) c')
233 | x = self.projection(x)
234 | return x
235 |
--------------------------------------------------------------------------------
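The explicit `output_shape` handed to the final block in `FnoGno.forward` is obtained by scaling the padded latent resolution by the end-to-end factor. A one-line sketch with toy numbers:

```python
# Resolution bookkeeping as in FnoGno.forward (toy values).
shape = (64, 64)                 # padded latent resolution
end_to_end_scaling = [0.5, 0.5]  # product of per-layer scalings
output_shape = [int(round(i * j)) for i, j in zip(shape, end_to_end_scaling)]
print(output_shape)  # [32, 32]
```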
/models/gnn.py:
--------------------------------------------------------------------------------
1 | from layers.gnn_layer import GnnLayer
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from neuralop.layers.mlp import MLPLinear
5 | import numpy as np
6 | import torch
7 |
8 | class GNN(nn.Module):
9 | def __init__(self,
10 | in_dim,
11 | out_dim,
12 | input_grid,
13 | output_grid=None,
14 | n_neigbor=None,
15 | hidden_dim=None,
16 | lifting_dim=None,
17 | n_layers=4,
18 | initial_mesh=None,
19 | non_linearity=F.gelu,
20 | projection=True,
21 | gno_mlp_layers=None,
22 | lifting=True,
23 | ):
24 | super().__init__()
25 | self.n_layers = n_layers
26 |
27 | if output_grid is None:
28 | output_grid = input_grid.clone()
29 |
30 | self.n_dim = input_grid.shape[-1]
31 |
32 | self.in_dim = in_dim
33 |
34 | if hidden_dim is None:
35 | hidden_dim = in_dim
36 | if lifting_dim is None:
37 | lifting_dim = in_dim
38 | if out_dim is None:
39 | out_dim = in_dim
40 |
41 | self.input_grid = input_grid
42 | self.output_grid = output_grid
43 |
44 | self.hidden_dim = hidden_dim
45 |
46 | self.lifting = lifting
47 | self.projection = projection
48 | self.n_neigbor = n_neigbor
49 | self.gno_mlp_layers = gno_mlp_layers
50 |
51 | self.initial_mesh = initial_mesh
52 |         # Code for variable encoding
53 |
54 | # initializing Components
55 | if self.lifting:
56 |             print('Using Lifting Layer')
57 | self.lifting = MLPLinear(
58 | layers=[self.in_dim, self.hidden_dim],
59 | )
60 |
61 | self.base = nn.ModuleList([])
62 | for i in range(self.n_layers):
63 | self.base.append(GnnLayer(
64 | hidden_dim,
65 | hidden_dim,
66 | self.initial_mesh,
67 | self.initial_mesh,
68 | gno_mlp_layers,
69 | lifting_dim,
70 | n_neigbor))
71 |
72 | if self.projection:
73 | print("Using Projection Layer")
74 | self.projection = MLPLinear(
75 | layers=[self.hidden_dim, out_dim]
76 | )
77 |
78 | def forward(
79 | self,
80 | inp,
81 | out_grid_displacement=None,
82 | in_grid_displacement=None):
83 | '''
84 | inp = (batch_size, n_points, in_dims/Channels)
85 |         currently only batch_size = 1
86 | '''
87 |
88 | if out_grid_displacement is not None:
89 | with torch.no_grad():
90 | for i in range(self.n_layers):
91 | if i == self.n_layers - 1:
92 |                     # Use a different output mesh for the last layer
93 | in_grid = self.initial_mesh + in_grid_displacement
94 |
95 | out_grid = self.initial_mesh + out_grid_displacement
96 | else:
97 | in_grid = self.initial_mesh + in_grid_displacement
98 | out_grid = self.initial_mesh + in_grid_displacement
99 | self.base[i].update_grid(
100 | in_grid,
101 | out_grid)
102 |
103 | if self.lifting:
104 | x = self.lifting(inp)
105 | else:
106 | x = inp
107 | x = x[0, ...]
108 | for layer_idx in range(self.n_layers):
109 | x = self.base[layer_idx](x)
110 |
111 | if self.projection:
112 | x = self.projection(x)
113 | return x[None, ...]
114 |
--------------------------------------------------------------------------------
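All the mesh-moving baselines share the same grid update: node positions are the fixed reference mesh plus the current displacement field, after which the graph layers rebuild their neighborhoods. A standalone sketch of the update (shapes are made up; the FSI meshes are 2-D point clouds):

```python
import torch

initial_mesh = torch.rand(500, 2)                  # reference node coordinates
in_grid_displacement = 0.01 * torch.randn(500, 2)  # per-step displacement
with torch.no_grad():  # grid updates are not differentiated through
    in_grid = initial_mesh + in_grid_displacement
print(in_grid.shape)  # torch.Size([500, 2])
```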
/models/model_helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def count_parameters(model):
5 | with torch.no_grad():
6 | total_count = 0
7 | for p in model.parameters():
8 | if not p.requires_grad:
9 | continue
10 | pcount = torch.tensor(p.numel())
11 | total_count += int(pcount.item())
12 |
13 | return total_count / 1e6
14 |
--------------------------------------------------------------------------------
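A quick sanity check of `count_parameters`, which reports trainable parameters in millions (the import assumes the repo root is on `PYTHONPATH`):

```python
import torch.nn as nn
from models.model_helpers import count_parameters

# nn.Linear(1000, 1000) has 1000*1000 weights + 1000 biases,
# so the count in millions should be 1.001.
model = nn.Linear(1000, 1000)
print(count_parameters(model))  # 1.001
```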
/models/unet.py:
--------------------------------------------------------------------------------
1 | from layers.gno_layer import GNO
2 | from einops import rearrange
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from layers.regrider import Regird
6 | import numpy as np
7 | import torch
8 | from layers.unet_sublayer import UNet2d
9 |
10 |
11 | class UnetGno(nn.Module):
12 | def __init__(self,
13 | in_dim,
14 | out_dim,
15 | input_grid,
16 | output_grid=None,
17 | grid_size=None,
18 | radius=None,
19 | fixed_neighbour=False,
20 | n_neigbor=10,
21 | hidden_dim=None,
22 | lifting_dim=None,
23 | n_layers=4,
24 | pad_to_size=None,
25 | initial_mesh=None,
26 | non_linearity=F.gelu,
27 | re_grid_input=False,
28 | re_grid_output=False,
29 | projection=True,
30 | gno_mlp_layers=None,
31 | lifting=True,
32 | domain_padding=None,
33 | domain_padding_mode='one-sided',
34 | ):
35 | super().__init__()
36 | self.n_layers = n_layers
37 | if output_grid is None:
38 | output_grid = None # input_grid.clone()
39 |
40 | self.in_dim = in_dim
41 | if hidden_dim is None:
42 | hidden_dim = in_dim
43 | if lifting_dim is None:
44 | lifting_dim = in_dim
45 | if out_dim is None:
46 | out_dim = in_dim
47 | self.re_grid_input = re_grid_input
48 | self.re_grid_output = re_grid_output
49 |
50 | if self.re_grid_input:
51 | self.input_regrider = Regird("equiangular", "legendre-gauss")
52 | if self.re_grid_output:
53 | self.output_regrider = Regird("legendre-gauss", "equiangular")
54 |
55 | self.input_grid = input_grid
56 | self.output_grid = output_grid
57 | self.grid_size = grid_size
58 |
59 | self.hidden_dim = hidden_dim
60 | self.lifting = lifting
61 | self.projection = projection
62 | self.radius = radius
63 | self.fixed_neighbour = fixed_neighbour
64 | self.n_neigbor = n_neigbor
65 | self.gno_mlp_layers = gno_mlp_layers
66 | self.pad_to_size = pad_to_size
67 |
68 | self.initial_mesh = initial_mesh
69 |         # Code for variable encoding
70 |
71 | # initializing Components
72 | if self.lifting:
73 |             print('Using Lifting Layer')
74 |
75 |             # a variable + its variable encoding + the static channel(s)
76 |             # together constitute a token
77 |
78 | self.lifting = GNO(
79 | in_dim=self.in_dim,
80 | out_dim=hidden_dim,
81 | input_grid=self.input_grid,
82 | output_grid=self.output_grid,
83 | projection_hidden_dim=lifting_dim,
84 | mlp_layers=self.gno_mlp_layers,
85 | radius=self.radius,
86 | fixed_neighbour=self.fixed_neighbour,
87 | n_neigbor=self.n_neigbor)
88 |
89 | self.base = UNet2d(in_channels=hidden_dim,
90 | out_channels=hidden_dim, init_features=hidden_dim)
91 |
92 | if self.projection:
93 | # input and output grid is swapped
94 |
95 | print("Using Projection Layer")
96 | self.projection = GNO(
97 | in_dim=self.hidden_dim,
98 | out_dim=out_dim,
99 | input_grid=self.output_grid,
100 | projection_hidden_dim=lifting_dim,
101 | output_grid=self.input_grid,
102 | mlp_layers=self.gno_mlp_layers,
103 | radius=self.radius,
104 | fixed_neighbour=self.fixed_neighbour,
105 | n_neigbor=self.n_neigbor)
106 |
107 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer):
108 | for k in scalings_per_layer:
109 | initial_scale = np.multiply(initial_scale, k)
110 | initial_scale = initial_scale.tolist()
111 | if len(initial_scale) == 1:
112 | initial_scale = initial_scale[0]
113 | return initial_scale
114 |
115 | def get_device(self,):
116 | return self.cls_token.coefficients_r.device
117 |
118 | def forward(
119 | self,
120 | inp,
121 | out_grid_displacement=None,
122 | in_grid_displacement=None):
123 | '''
124 | inp = (batch_size, n_points, in_dims/Channels)
125 |         currently only batch_size = 1
126 | '''
127 | if out_grid_displacement is not None:
128 | with torch.no_grad():
129 | self.lifting.update_grid(
130 | self.initial_mesh + in_grid_displacement, None)
131 | self.projection.update_grid(
132 | None, self.initial_mesh + out_grid_displacement)
133 |
134 | if self.re_grid_input:
135 | inp = self.input_regrider(inp)
136 | if self.lifting:
137 | # print("In Lifting")
138 | x = self.lifting(inp)
139 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0])
140 | else:
141 | x = inp
142 |
143 | pad_size_1 = self.pad_to_size[-2] - x.shape[-2]
144 | pad_size_2 = self.pad_to_size[-1] - x.shape[-1]
145 | x = nn.functional.pad(
146 | x,
147 | (pad_size_2 // 2,
148 | pad_size_2 - pad_size_2 // 2,
149 | pad_size_1 // 2,
150 | pad_size_1 - pad_size_1 // 2))
151 | x = self.base(x)
152 |
153 |
154 |
155 | x = x[:, :, pad_size_1 // 2: -(pad_size_1 - pad_size_1 // 2),
156 | pad_size_2 // 2: -(pad_size_2 - pad_size_2 // 2)]
157 |
158 | if self.re_grid_output:
159 | x = self.output_regrider(x)
160 | if self.projection:
161 | x = rearrange(x, 'b c h w -> b (h w) c')
162 | x = self.projection(x)
163 | return x
164 |
--------------------------------------------------------------------------------
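Because the U-Net downsamples by fixed strides, `UnetGno` pads the latent grid up to `pad_to_size` and crops the padding back off afterwards. A standalone sketch of that symmetric pad/crop arithmetic (toy sizes; like the layer above, it assumes nonzero padding in both dimensions):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 30, 22)
pad_to_size = (32, 24)
p1 = pad_to_size[-2] - x.shape[-2]  # rows to add
p2 = pad_to_size[-1] - x.shape[-1]  # columns to add
x = F.pad(x, (p2 // 2, p2 - p2 // 2, p1 // 2, p1 - p1 // 2))
print(x.shape)  # torch.Size([1, 8, 32, 24])
x = x[:, :, p1 // 2: -(p1 - p1 // 2), p2 // 2: -(p2 - p2 // 2)]
print(x.shape)  # torch.Size([1, 8, 30, 22])
```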
/models/vit.py:
--------------------------------------------------------------------------------
1 | from layers.gno_layer import GNO
2 | from functools import partial
3 | from einops import rearrange
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from layers.regrider import Regird
7 | import numpy as np
8 | import torch
9 | from layers.regular_transformer import vision_transformer
10 |
11 |
12 | class VitGno(nn.Module):
13 | def __init__(self,
14 | in_dim,
15 | out_dim,
16 | input_grid,
17 | output_grid=None,
18 | grid_size=None,
19 | radius=None,
20 | fixed_neighbour=False,
21 | n_neigbor=10,
22 | hidden_dim=None,
23 | lifting_dim=None,
24 | n_layers=4,
25 | initial_mesh=None,
26 | non_linearity=F.gelu,
27 | patch_size=(10, 5),
28 | heads=10,
29 | contraction_factor=128,
30 | re_grid_input=False,
31 | re_grid_output=False,
32 | projection=True,
33 | gno_mlp_layers=None,
34 | lifting=True,
35 | domain_padding=None,
36 | domain_padding_mode='one-sided',
37 | ):
38 | super().__init__()
39 | self.n_layers = n_layers
40 | if output_grid is None:
41 | output_grid = input_grid.clone()
42 |
43 | self.in_dim = in_dim
44 | if hidden_dim is None:
45 | hidden_dim = in_dim
46 | if lifting_dim is None:
47 | lifting_dim = in_dim
48 | if out_dim is None:
49 | out_dim = in_dim
50 | self.re_grid_input = re_grid_input
51 | self.re_grid_output = re_grid_output
52 |
53 | if self.re_grid_input:
54 | self.input_regrider = Regird("equiangular", "legendre-gauss")
55 | if self.re_grid_output:
56 | self.output_regrider = Regird("legendre-gauss", "equiangular")
57 |
58 | self.input_grid = input_grid
59 | self.output_grid = output_grid
60 | self.grid_size = grid_size
61 |
62 | self.hidden_dim = hidden_dim
63 | self.lifting = lifting
64 | self.projection = projection
65 | self.radius = radius
66 | self.fixed_neighbour = fixed_neighbour
67 | self.n_neigbor = n_neigbor
68 | self.gno_mlp_layers = gno_mlp_layers
69 |
70 | # transformers parameters
71 | self.contraction_factor = contraction_factor
72 | self.grid_size = grid_size
73 | self.patch_size = patch_size
74 | self.heads = heads
75 |
76 | self.initial_mesh = initial_mesh
77 |         # Code for variable encoding
78 |
79 | # initializing Components
80 | if self.lifting:
81 |             print('Using Lifting Layer')
82 |
83 |             # a variable + its variable encoding + the static channel(s)
84 |             # together constitute a token
85 |
86 | self.lifting = GNO(
87 | in_dim=self.in_dim,
88 | out_dim=hidden_dim,
89 | input_grid=self.input_grid,
90 | output_grid=self.output_grid,
91 | projection_hidden_dim=lifting_dim,
92 | mlp_layers=self.gno_mlp_layers,
93 | radius=self.radius,
94 | fixed_neighbour=self.fixed_neighbour,
95 | n_neigbor=self.n_neigbor)
96 |
97 | self.base = vision_transformer(
98 | image_size=self.grid_size,
99 | patch_size=self.patch_size,
100 | num_classes=1,
101 | dim=self.patch_size[0] * self.patch_size[1] *
102 | self.hidden_dim // self.contraction_factor,
103 | depth=self.n_layers,
104 | channels=hidden_dim,
105 | heads=self.heads,
106 | mlp_dim=self.patch_size[0] * self.patch_size[1] *
107 | hidden_dim // self.contraction_factor,
108 | dropout=0.0,
109 | emb_dropout=0.0
110 | )
111 | self.expander = nn.Linear(
112 | self.patch_size[0] *
113 | self.patch_size[1] *
114 | hidden_dim //
115 | self.contraction_factor,
116 | self.grid_size[0] *
117 | self.grid_size[1] *
118 | hidden_dim)
119 | if self.projection:
120 |             # input and output grids are swapped
121 |
122 | print("Using Projection Layer")
123 | self.projection = GNO(
124 | in_dim=self.hidden_dim,
125 | out_dim=out_dim,
126 | input_grid=self.output_grid,
127 | projection_hidden_dim=lifting_dim,
128 | output_grid=self.input_grid,
129 | mlp_layers=self.gno_mlp_layers,
130 | radius=self.radius,
131 | fixed_neighbour=self.fixed_neighbour,
132 | n_neigbor=self.n_neigbor)
133 |
134 | def get_output_scaling_factor(self, initial_scale, scalings_per_layer):
135 | for k in scalings_per_layer:
136 | initial_scale = np.multiply(initial_scale, k)
137 | initial_scale = initial_scale.tolist()
138 | if len(initial_scale) == 1:
139 | initial_scale = initial_scale[0]
140 | return initial_scale
141 |
142 | def get_device(self,):
143 |         return next(self.parameters()).device  # VitGno has no cls_token
144 |
145 | def forward(
146 | self,
147 | inp,
148 | out_grid_displacement=None,
149 | in_grid_displacement=None):
150 | '''
151 |         inp: (batch_size, n_points, in_dim/channels)
152 |         currently only batch_size = 1 is supported
153 | '''
154 | if out_grid_displacement is not None:
155 | with torch.no_grad():
156 | self.lifting.update_grid(
157 | self.initial_mesh + in_grid_displacement, None)
158 | self.projection.update_grid(
159 | None, self.initial_mesh + out_grid_displacement)
160 |
161 | if self.re_grid_input:
162 | inp = self.input_regrider(inp)
163 | if self.lifting:
164 | # print("In Lifting")
165 | x = self.lifting(inp)
166 | x = rearrange(x, 'b (h w) c -> b c h w', h=self.grid_size[0])
167 | else:
168 | x = inp
169 |
170 | x = self.base(x)
171 | x = self.expander(x)
172 | x = x.reshape(-1, self.hidden_dim,
173 | self.grid_size[0], self.grid_size[1])
174 |
175 | if self.re_grid_output:
176 | x = self.output_regrider(x)
177 | if self.projection:
178 | x = rearrange(x, 'b c h w -> b (h w) c')
179 | x = self.projection(x)
180 | return x
181 |
--------------------------------------------------------------------------------
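
The patch and contraction bookkeeping in `VitGno` above is easy to misread, so here is a small self-contained shape sketch in plain PyTorch. All sizes are illustrative assumptions for this sketch only, not values from any config in this repo:

```python
import torch

# Illustrative sizes (assumptions for this sketch only)
grid_size = (40, 20)      # latent uniform grid (H, W)
patch_size = (10, 5)      # ViT patch size, matching the default above
hidden_dim = 32           # channels on the latent grid
contraction_factor = 16   # shrinks the per-patch token width

# Width of each token fed to the transformer (the `dim`/`mlp_dim`
# arguments passed to `vision_transformer` above)
token_dim = patch_size[0] * patch_size[1] * hidden_dim // contraction_factor
print(token_dim)  # 10 * 5 * 32 // 16 = 100

# `expander` maps back to one value per grid point and channel, which
# `forward` then reshapes to (batch, hidden_dim, H, W)
out_features = grid_size[0] * grid_size[1] * hidden_dim
x = torch.randn(1, out_features)
x = x.reshape(-1, hidden_dim, grid_size[0], grid_size[1])
print(x.shape)  # torch.Size([1, 32, 40, 20])
```
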
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.9.5
2 | einops==0.8.0
3 | filelock==3.15.4
4 | fsspec==2024.6.1
5 | GitPython==3.1.43
6 | h5py==3.11.0
7 | joblib==1.4.2
8 | Jinja2==3.1.4
9 | matplotlib==3.9.1
10 | networkx
11 | -e git+https://github.com/ashiq24/neuraloperator.git@codano_rep#egg=neuraloperator
12 | numpy==2.0.0
13 | opt-einsum==3.3.0
14 | protobuf==5.27.2
15 | PyYAML==6.0.1
16 | requests==2.32.3
17 | scikit-learn==1.5.1
18 | scipy==1.14.0
19 | sympy==1.13.1
20 | tensorly==0.8.1
21 | torch
22 | torchaudio
23 | torchvision
24 | tensorly-torch==0.5.0
25 | threadpoolctl==3.5.0
26 | tqdm==4.66.4
27 | typing_extensions
28 | urllib3==2.2.2
29 | vit-pytorch==1.7.4
30 | wandb==0.17.5
31 | yarl==1.9.4
32 | zarr==2.18.2
33 | torch_geometric==2.5.3
34 | torch-cluster==1.6.3
35 | torch-sparse==0.6.18
36 | torch_scatter==2.1.2
37 |
--------------------------------------------------------------------------------
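
The pins above are strict, so environment drift is a real failure mode; a quick spot check with the standard-library `importlib.metadata` (a sketch; the distribution names follow the pins above) can catch mismatches before a run:

```python
from importlib.metadata import version

# Spot-check a few pinned distributions against the installed environment
for pkg in ("einops", "wandb", "vit-pytorch", "torch-geometric"):
    print(pkg, version(pkg))
```
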
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neuraloperator/CoDA-NO/21ea8ae9ec70f49a05b7902dc2d51613a1b05618/test/__init__.py
--------------------------------------------------------------------------------
/test/evaluations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from data_utils.data_utils import *
3 | import torch.nn as nn
4 | from timeit import default_timer
5 | from models.get_models import *
6 | from tqdm import tqdm
7 | import wandb
8 | from utils import *
9 | from train.trainer import *
10 |
11 |
12 | def missing_variable_testing(
13 | model,
14 | test_loader,
15 | augmenter=None,
16 | stage=StageEnum.PREDICTIVE,
17 | params=None,
18 | variable_encoder=None,
19 | token_expander=None,
20 | initial_mesh=None,
21 | wandb_log=False):
22 | print('Evaluating for Stage: ', stage)
23 | model.eval()
24 | with torch.no_grad():
25 | ntest = 0
26 | test_l2 = 0
27 | test_l1 = 0
28 | loss_p = nn.MSELoss()
29 | loss_l1 = nn.L1Loss()
30 | t1 = default_timer()
31 | predictions = []
32 | for data in test_loader:
33 | x, y = data['x'].cuda(), data['y'].cuda()
34 | static_features = data['static_features']
35 |
36 | if augmenter is not None:
37 | x, _ = batched_masker(x, augmenter)
38 |
39 | inp = prepare_input(
40 | x,
41 | static_features,
42 | params,
43 | variable_encoder,
44 | token_expander,
45 | initial_mesh,
46 | data)
47 | out_grid_displacement, in_grid_displacement = get_grid_displacement(
48 | params, stage, data)
49 |
50 | batch_size = x.shape[0]
51 | out = model(inp, out_grid_displacement=out_grid_displacement,
52 | in_grid_displacement=in_grid_displacement)
53 |
54 |             if isinstance(out, (list, tuple)):
55 |                 out = out[0]
56 |
57 |             if getattr(params, 'horizontal_skip', False):
58 |                 out = out + x
59 |
60 | ntest += 1
61 | target = y.clone()
62 |
63 | predictions.append((out, target))
64 |
65 | test_l2 += loss_p(target.reshape(batch_size, -1),
66 | out.reshape(batch_size, -1)).item()
67 | test_l1 += loss_l1(target.reshape(batch_size, -1),
68 | out.reshape(batch_size, -1)).item()
69 |
70 | test_l2 /= ntest
71 | test_l1 /= ntest
72 | t2 = default_timer()
73 | avg_time = (t2 - t1) / ntest
74 |
75 |         if wandb_log:
76 |             wandb.log({'Augmented test_error_l2': test_l2,
77 |                        'Augmented test_error_l1': test_l1, 'Avg test_time': avg_time}, commit=True)
78 | print(f"Augmented Test Error {stage}: ", test_l2)
79 |
80 | if hasattr(params, 'save_predictions') and params.save_predictions:
81 | torch.save(predictions[:50], f'../xy/predictions_{params.config}.pt')
82 |
--------------------------------------------------------------------------------
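
For orientation, a hedged usage sketch of the evaluation entry point above. It is untested here: `model`, `test_loader`, and `params` are assumed to be built as in `main.py`, and the import path of `StageEnum` is an assumption (the module above pulls it in via wildcard imports):

```python
from data_utils.data_utils import StageEnum  # assumed home of StageEnum
from test.evaluations import missing_variable_testing

# Evaluate a trained model; pass an augmenter to mask variables and
# probe robustness to missing inputs.
missing_variable_testing(
    model,
    test_loader,
    augmenter=None,
    stage=StageEnum.PREDICTIVE,
    params=params,
)
```
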
/train/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import *
2 | from .new_adam import *
3 |
--------------------------------------------------------------------------------
/train/new_adam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import Tensor
4 | from typing import List, Optional
5 | from torch.optim.optimizer import Optimizer
6 |
7 |
8 | def adam(params: List[Tensor],
9 | grads: List[Tensor],
10 | exp_avgs: List[Tensor],
11 | exp_avg_sqs: List[Tensor],
12 | max_exp_avg_sqs: List[Tensor],
13 | state_steps: List[int],
14 | *,
15 | amsgrad: bool,
16 | beta1: float,
17 | beta2: float,
18 | lr: float,
19 | weight_decay: float,
20 | eps: float):
21 | r"""Functional API that performs Adam algorithm computation.
22 | See :class:`~torch.optim.Adam` for details.
23 | """
24 |
25 | for i, param in enumerate(params):
26 |
27 | grad = grads[i]
28 | exp_avg = exp_avgs[i]
29 | exp_avg_sq = exp_avg_sqs[i]
30 | step = state_steps[i]
31 |
32 | bias_correction1 = 1 - beta1 ** step
33 | bias_correction2 = 1 - beta2 ** step
34 |
35 | if weight_decay != 0:
36 | grad = grad.add(param, alpha=weight_decay)
37 |
38 | # Decay the first and second moment running average coefficient
39 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
40 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
41 | if amsgrad:
42 | # Maintains the maximum of all 2nd moment running avg. till now
43 | torch.maximum(
44 | max_exp_avg_sqs[i],
45 | exp_avg_sq,
46 | out=max_exp_avg_sqs[i])
47 | # Use the max. for normalizing running avg. of gradient
48 | denom = (
49 | max_exp_avg_sqs[i].sqrt() /
50 | math.sqrt(bias_correction2)).add_(eps)
51 | else:
52 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
53 |
54 | step_size = lr / bias_correction1
55 |
56 | param.addcdiv_(exp_avg, denom, value=-step_size)
57 |
58 |
59 | class Adam(Optimizer):
60 | r"""Implements Adam algorithm.
61 | It has been proposed in `Adam: A Method for Stochastic Optimization`_.
62 | The implementation of the L2 penalty follows changes proposed in
63 | `Decoupled Weight Decay Regularization`_.
64 | Args:
65 | params (iterable): iterable of parameters to optimize or dicts defining
66 | parameter groups
67 | lr (float, optional): learning rate (default: 1e-3)
68 | betas (Tuple[float, float], optional): coefficients used for computing
69 | running averages of gradient and its square (default: (0.9, 0.999))
70 | eps (float, optional): term added to the denominator to improve
71 | numerical stability (default: 1e-8)
72 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
73 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this
74 | algorithm from the paper `On the Convergence of Adam and Beyond`_
75 | (default: False)
76 | .. _Adam\: A Method for Stochastic Optimization:
77 | https://arxiv.org/abs/1412.6980
78 | .. _Decoupled Weight Decay Regularization:
79 | https://arxiv.org/abs/1711.05101
80 | .. _On the Convergence of Adam and Beyond:
81 | https://openreview.net/forum?id=ryQu7f-RZ
82 | """
83 |
84 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
85 | weight_decay=0, amsgrad=False):
86 | if not 0.0 <= lr:
87 | raise ValueError("Invalid learning rate: {}".format(lr))
88 | if not 0.0 <= eps:
89 | raise ValueError("Invalid epsilon value: {}".format(eps))
90 | if not 0.0 <= betas[0] < 1.0:
91 | raise ValueError(
92 | "Invalid beta parameter at index 0: {}".format(
93 | betas[0]))
94 | if not 0.0 <= betas[1] < 1.0:
95 | raise ValueError(
96 | "Invalid beta parameter at index 1: {}".format(
97 | betas[1]))
98 | if not 0.0 <= weight_decay:
99 | raise ValueError(
100 | "Invalid weight_decay value: {}".format(weight_decay))
101 | defaults = dict(lr=lr, betas=betas, eps=eps,
102 | weight_decay=weight_decay, amsgrad=amsgrad)
103 | super(Adam, self).__init__(params, defaults)
104 |
105 | def __setstate__(self, state):
106 | super(Adam, self).__setstate__(state)
107 | for group in self.param_groups:
108 | group.setdefault('amsgrad', False)
109 |
110 | @torch.no_grad()
111 | def step(self, closure=None):
112 | """Performs a single optimization step.
113 | Args:
114 | closure (callable, optional): A closure that reevaluates the model
115 | and returns the loss.
116 | """
117 | loss = None
118 | if closure is not None:
119 | with torch.enable_grad():
120 | loss = closure()
121 |
122 | for group in self.param_groups:
123 | params_with_grad = []
124 | grads = []
125 | exp_avgs = []
126 | exp_avg_sqs = []
127 | max_exp_avg_sqs = []
128 | state_steps = []
129 | beta1, beta2 = group['betas']
130 |
131 | for p in group['params']:
132 | if p.grad is not None:
133 | params_with_grad.append(p)
134 | if p.grad.is_sparse:
135 | raise RuntimeError(
136 | 'Adam does not support sparse gradients, please consider SparseAdam instead')
137 | grads.append(p.grad)
138 |
139 | state = self.state[p]
140 | # Lazy state initialization
141 | if len(state) == 0:
142 | state['step'] = 0
143 | # Exponential moving average of gradient values
144 | state['exp_avg'] = torch.zeros_like(
145 | p, memory_format=torch.preserve_format)
146 | # Exponential moving average of squared gradient values
147 | state['exp_avg_sq'] = torch.zeros_like(
148 | p, memory_format=torch.preserve_format)
149 | if group['amsgrad']:
150 | # Maintains max of all exp. moving avg. of sq.
151 | # grad. values
152 | state['max_exp_avg_sq'] = torch.zeros_like(
153 | p, memory_format=torch.preserve_format)
154 |
155 | exp_avgs.append(state['exp_avg'])
156 | exp_avg_sqs.append(state['exp_avg_sq'])
157 |
158 | if group['amsgrad']:
159 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
160 |
161 | # update the steps for each param group update
162 | state['step'] += 1
163 | # record the step after step update
164 | state_steps.append(state['step'])
165 |
166 | adam(params_with_grad,
167 | grads,
168 | exp_avgs,
169 | exp_avg_sqs,
170 | max_exp_avg_sqs,
171 | state_steps,
172 | amsgrad=group['amsgrad'],
173 | beta1=beta1,
174 | beta2=beta2,
175 | lr=group['lr'],
176 | weight_decay=group['weight_decay'],
177 | eps=group['eps'])
178 | return loss
179 |
--------------------------------------------------------------------------------
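
This class re-implements stock Adam through a functional API and is used as a drop-in replacement for `torch.optim.Adam` in `train/trainer.py`. A minimal, runnable sketch of one optimization step:

```python
import torch
import torch.nn as nn

from train.new_adam import Adam

model = nn.Linear(4, 4)
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=1e-4, amsgrad=False)

# One optimization step on a dummy quadratic loss
loss = model(torch.randn(2, 4)).pow(2).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
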
/train/trainer.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import logging
3 | from timeit import default_timer
4 | import tqdm
5 | import wandb
6 | from data_utils.data_utils import *
7 | import torch
8 | from models.get_models import *
9 | from torch import nn
10 | from torch.utils import data
11 | from .new_adam import Adam
12 | from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
13 | from utils import *
14 |
15 | def get_grid_displacement(params, stage, data):
16 | if params.grid_type == "non uniform":
17 | with torch.no_grad():
18 | if stage == StageEnum.RECONSTRUCTIVE:
19 | out_grid_displacement = data['d_grid_x'].cuda()[0]
20 | in_grid_displacement = data['d_grid_x'].cuda()[0]
21 | else:
22 | out_grid_displacement = data['d_grid_y'].cuda()[0]
23 | in_grid_displacement = data['d_grid_x'].cuda()[0]
24 | else:
25 | out_grid_displacement = None
26 | in_grid_displacement = None
27 | return out_grid_displacement, in_grid_displacement
28 |
29 | def trainer(
30 | model,
31 | train_loader,
32 | test_loader,
33 | params,
34 | wandb_log=False,
35 | log_test_interval=1,
36 | stage=StageEnum.RECONSTRUCTIVE,
37 | variable_encoder=None,
38 | token_expander=None,
39 | initial_mesh=None,
40 | ):
41 | lr = params.lr
42 | weight_decay = params.weight_decay
43 | scheduler_step = params.scheduler_step
44 | scheduler_gamma = params.scheduler_gamma
45 | epochs = params.epochs
46 | weight_path = params.weight_path
47 | optimizer = Adam(model.parameters(), lr=lr,
48 | weight_decay=weight_decay, amsgrad=False)
49 | if params.scheduler_type == 'step':
50 | scheduler = StepLR(
51 | optimizer, step_size=scheduler_step, gamma=scheduler_gamma)
52 | else:
53 | scheduler = ReduceLROnPlateau(
54 | optimizer, patience=scheduler_step, factor=scheduler_gamma)
55 |
56 | loss_p = nn.MSELoss(reduction='sum')
57 |
58 | for ep in range(epochs):
59 | model.train()
60 | t1 = default_timer()
61 | train_l2 = 0
62 | train_count = 0
63 | train_loader_iter = tqdm.tqdm(
64 | train_loader, desc=f'Epoch {ep}/{epochs}', leave=False, ncols=100)
65 |
66 | for data in train_loader_iter:
67 | optimizer.zero_grad()
68 | x, y = data['x'].cuda(), data['y'].cuda()
69 | static_features = data['static_features']
70 | if stage == StageEnum.RECONSTRUCTIVE and params.masking:
71 | x = model.do_mask(x)
72 |
73 | inp = prepare_input(
74 | x,
75 | static_features,
76 | params,
77 | variable_encoder,
78 | token_expander,
79 | initial_mesh,
80 | data)
81 | batch_size = x.shape[0]
82 | if params.grid_type == "non uniform":
83 | out_grid_displacement, in_grid_displacement = get_grid_displacement(
84 | params, stage, data)
85 | elif params.grid_type == "uniform":
86 | out_grid_displacement = None
87 | in_grid_displacement = None
88 |
89 | out = model(inp, out_grid_displacement=out_grid_displacement,
90 | in_grid_displacement=in_grid_displacement)
91 |
92 | if isinstance(out, (list, tuple)):
93 | out = out[0]
94 | if getattr(params, 'horizontal_skip', False):
95 | out = out + x
96 |
97 | train_count += 1
98 | target = x.clone() if stage == StageEnum.RECONSTRUCTIVE else y.clone()
99 | loss = loss_p(target.reshape(
100 | batch_size, -1), out.reshape(batch_size, -1)) / (x.shape[0] * x.shape[-1] * x.shape[-2])
101 | loss.backward()
102 |
103 | if params.clip_gradient:
104 | nn.utils.clip_grad_value_(
105 | model.parameters(), params.gradient_clip_value)
106 |
107 | optimizer.step()
108 | train_l2 += loss.item()
109 | del x, y, out, loss
110 | gc.collect()
111 |
112 | torch.cuda.empty_cache()
113 | avg_train_l2 = train_l2 / train_count
114 |
115 | if params.scheduler_type != 'step':
116 | scheduler.step(avg_train_l2)
117 | else:
118 | scheduler.step()
119 |
120 | t2 = default_timer()
121 | epoch_train_time = t2 - t1
122 |
123 | if ep % log_test_interval == 0:
124 | values_to_log = dict(train_err=avg_train_l2, time=epoch_train_time)
125 | print(
126 | f"Epoch {ep}: Time: {epoch_train_time:.3f}s, Loss: {avg_train_l2:.7f}")
127 | if wandb_log:
128 | wandb.log(values_to_log, commit=True)
129 |
130 | if ep % params.weight_saving_interval == 0 or ep == epochs - 1:
131 | stage_string = 'ssl' if stage == StageEnum.RECONSTRUCTIVE else 'sl'
132 | if params.nettype != 'transformer':
133 | torch.save(model.state_dict(), weight_path +
134 | params.config + "_" + str(ep) + '.pt')
135 | else:
136 | weight_path_model_encoder = weight_path + params.config + \
137 | "_" + stage_string + '_encoder_' + str(ep) + '.pt'
138 | weight_path_model_decoder = weight_path + params.config + \
139 | "_" + stage_string + '_decoder_' + str(ep) + '.pt'
140 | weight_path_whole_model = weight_path + params.config + \
141 | "_" + stage_string + '_whole_model_' + str(ep) + '.pt'
142 | torch.save(model.encoder.state_dict(),
143 | weight_path_model_encoder)
144 | torch.save(model.decoder.state_dict(),
145 | weight_path_model_decoder)
146 | torch.save(model.state_dict(), weight_path_whole_model)
147 | if variable_encoder is not None:
148 | variable_path = weight_path + params.config + \
149 | "_variable_encoder_" + str(ep)
150 | variable_encoder.save_all_encoder(variable_path)
151 |
152 | model.eval()
153 | test_l2 = 0.0
154 | ntest = 0
155 | loss_p = nn.MSELoss(reduction='sum')
156 | with torch.no_grad():
157 | for data in test_loader:
158 | x, y = data['x'].cuda(), data['y'].cuda()
159 | static_features = data['static_features']
160 | inp = prepare_input(
161 | x,
162 | static_features,
163 | params,
164 | variable_encoder,
165 | token_expander,
166 | initial_mesh,
167 | data)
168 | out_grid_displacement, in_grid_displacement = get_grid_displacement(
169 | params, stage, data)
170 | batch_size = x.shape[0]
171 | out = model(inp, in_grid_displacement=in_grid_displacement,
172 | out_grid_displacement=out_grid_displacement)
173 |
174 | if isinstance(out, (list, tuple)):
175 | out = out[0]
176 | if getattr(params, 'horizontal_skip', False):
177 | out = out + x
178 |
179 | ntest += x.shape[0]
180 | target = x.clone() if stage == StageEnum.RECONSTRUCTIVE else y.clone()
181 | test_l2 += loss_p(target.reshape(batch_size, -1),
182 | out.reshape(batch_size, -1)).item()
183 |
184 | test_l2 /= (ntest * x.shape[-1] * x.shape[-2])
185 | t2 = default_timer()
186 |
187 |         stage_string = 'ssl' if stage == StageEnum.RECONSTRUCTIVE else 'sl'
188 |         if wandb_log:
189 |             wandb.log({'test_error_' + stage_string: test_l2}, commit=True)
190 |         print("Test Error " + stage_string + ":", test_l2)
191 |
--------------------------------------------------------------------------------
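
A hedged sketch of a typical call to `trainer` above. It is untested here: the model, loaders, `params`, and `initial_mesh` are assumed to be built as in `main.py`, and the `StageEnum` import path is an assumption:

```python
from data_utils.data_utils import StageEnum  # assumed import path
from train.trainer import trainer

# Supervised fine-tuning; use StageEnum.RECONSTRUCTIVE for the SSL
# pretraining stage, where the target is the (masked) input itself.
trainer(
    model,
    train_loader,
    test_loader,
    params,
    wandb_log=False,
    stage=StageEnum.PREDICTIVE,
    variable_encoder=None,
    token_expander=None,
    initial_mesh=initial_mesh,
)
```
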
/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import os
4 | import pathlib
5 |
6 | import re
7 | import signal
8 |
9 | from typing import List
10 |
11 | import h5py
12 | # from haikunator import Haikunator
13 | import numpy as np
14 | import psutil
15 | import torch
16 | import torch.nn as nn
17 |
18 | # HAIKU = haikunator.Haikunator()
19 |
20 |
21 | def prepare_input(
22 | x,
23 | static_features,
24 | params,
25 | variable_encoder,
26 | token_expander,
27 | initial_mesh,
28 | data):
29 | if variable_encoder is not None and token_expander is not None:
30 | if params.grid_type == 'uniform':
31 | inp = token_expander(x, variable_encoder(x),
32 | static_features.cuda())
33 | elif params.grid_type == 'non uniform':
34 | initial_mesh = initial_mesh.cuda()
35 | equation = [i[0] for i in data['equation']]
36 | inp = token_expander(
37 | x,
38 | variable_encoder(
39 | initial_mesh +
40 | data['d_grid_x'].cuda()[0],
41 | equation),
42 | static_features.cuda())
43 | elif params.n_static_channels > 0 and params.grid_type == 'non uniform':
44 | inp = torch.cat(
45 | [x, static_features[:, :, :params.n_static_channels].cuda()], dim=-1)
46 | else:
47 | inp = x
48 | return inp
49 |
50 |
51 | def get_wandb_api_key(api_key_file="config/wandb_api_key.txt"):
52 | try:
53 | return os.environ["WANDB_API_KEY"]
54 | except KeyError:
55 | with open(api_key_file, "r") as f:
56 | key = f.read()
57 | return key.strip()
58 |
59 | def get_mesh(params):
60 | """Get the mesh from a location."""
61 | if hasattr(params, "text_mesh") and params.text_mesh:
62 | # load mesh_x and mesh_y from txt file as np array
63 | mesh_x = np.loadtxt(params.mesh_x)
64 | mesh_y = np.loadtxt(params.mesh_y)
65 | # create mesh from mesh_x and mesh_y
66 | mesh = np.zeros((mesh_x.shape[0], 2))
67 | mesh[:, 0] = mesh_x
68 | mesh[:, 1] = mesh_y
69 | else:
70 | h5f = h5py.File(params.input_mesh_location, 'r')
71 | mesh = h5f['mesh/coordinates']
72 |
73 | if params.super_resolution:
74 | # load mesh_x and mesh_y from txt file as np array
75 | mesh_x = np.loadtxt(params.super_resolution_mesh_x)
76 | mesh_y = np.loadtxt(params.super_resolution_mesh_y)
77 | # create mesh from mesh_x and mesh_y
78 | mesh_sup = np.zeros((mesh_x.shape[0], 2))
79 | mesh_sup[:, 0] = mesh_x
80 | mesh_sup[:, 1] = mesh_y
81 | # merge it with the original mesh
82 | mesh = np.concatenate((mesh, mesh_sup), axis=0)
83 |
84 | print("Super Resolution Mesh Shape", mesh.shape)
85 |
86 | if hasattr(
87 | params,
88 | 'sub_sample_size') and params.sub_sample_size is not None:
89 | mesh_size = mesh.shape[0]
90 | indexs = [i for i in range(mesh_size)]
91 | np.random.seed(params.random_seed)
92 | sub_indexs = np.random.choice(
93 | indexs, params.sub_sample_size, replace=False)
94 | mesh = mesh[sub_indexs, :]
95 |
96 | return mesh[:]
97 |
98 |
99 | # TODO add collision checks
100 | # TODO add opts to toggle haiku and date fixes
101 | def save_model(
102 | model,
103 | directory: pathlib.Path,
104 | stage=None,
105 | sep='_',
106 | name=None,
107 | config=None):
108 | """Saves a model with a unique prefix/suffix
109 |
110 |     The model is prefixed with its date (formatted like YYMMDDHHmm)
111 | and suffixed with a "Heroku-like" name (for verbal reference).
112 |
113 | Params:
114 | ---
115 | stage: None | StageEnum
116 | Controls the infix of the model name according to the following mapping:
117 | | None -> "model"
118 | | RECONSTRUCTIVE -> "reconstructive"
119 | | PREDICTIVE -> "predictive"
120 | """
121 | prefix = datetime.datetime.utcnow().strftime("%y%m%d%H%M")
122 | infix = "model"
123 | if stage is not None:
124 | infix = stage.value.lower()
125 | # suffix = Haikunator.haikunate(token_length=0, delimiter=sep)
126 |
127 |     torch.save(model.state_dict(), directory / f"{prefix}{sep}{infix}{sep}{name}{sep}{config}.pth")
128 |
129 |
130 | def extract_pids(message) -> List[int]:
131 | # Assume `message` has a preamble followed by a sequence of tokens like
132 | # "Process \d+" with extra characters in between such tokens.
133 |
134 | pattern = re.compile("(Process \\d+)")
135 | # Contains "Process" tokens and extra characters, interleaved:
136 | tokens = pattern.split(message)
137 | # print('\n'.join(map(repr, zip(split[1::2], split[2::2]))))
138 |
139 | pattern2 = re.compile("(\\d+)")
140 | # print('\n'.join([repr((s, pattern2.search(t)[0])) for t in tokens[1::2]]))
141 | pids = [int(pattern2.search(t)[0]) for t in tokens[1::2]]
142 |
143 | return pids
144 |
145 |
146 | # https://psutil.readthedocs.io/en/latest/#kill-process-tree
147 | def signal_process_tree(
148 | pid,
149 | sig=signal.SIGTERM,
150 | include_parent=True,
151 | timeout=None,
152 | on_terminate=None,
153 | logger=None,
154 | ):
155 | """Kill a process tree (including grandchildren) with signal ``sig``
156 |
157 | Return a (gone, still_alive) tuple.
158 |
159 | Parameters
160 | ---
161 | timeout: float
162 | Time, in seconds, to wait on each signaled process.
163 | on_terminate: Optional[Callable]
164 | A callback function which is called as soon as a child terminates.
165 | Optional.
166 | """
167 | assert pid != os.getpid(), "won't kill myself"
168 | parent = psutil.Process(pid)
169 | children = parent.children(recursive=True)
170 | if include_parent:
171 | children.append(parent)
172 | if logger is None:
173 | logger = logging.getLogger()
174 |
175 | wait_children = []
176 | for p in children:
177 | try:
178 | p.send_signal(sig)
179 | wait_children.append(p)
180 | except psutil.AccessDenied:
181 | logger.error(f"Unable to terminate Process {pid} (AccessDenied)")
182 | except psutil.NoSuchProcess:
183 | pass
184 |
185 | gone, alive = psutil.wait_procs(
186 | wait_children,
187 | timeout=timeout,
188 | callback=on_terminate,
189 | )
190 | return (gone, alive)
191 |
192 |
193 | def count_model_params(model):
194 | """Returns the total number of parameters of a PyTorch model
195 |
196 | Notes
197 | -----
198 |     One complex number is counted as two parameters (we count real and imaginary parts)
199 | """
200 | return sum(
201 | [p.numel() * 2 if p.is_complex() else p.numel()
202 | for p in model.parameters()]
203 | )
204 |
205 |
206 | def signal_my_processes(
207 | username,
208 | pids,
209 | sig=signal.SIGTERM,
210 | logger=None,
211 | ):
212 | if logger is None:
213 | logger = logging.getLogger()
214 | my_pids = []
215 | for pid in pids:
216 | p = psutil.Process(pid)
217 | with p.oneshot():
218 | p = p.as_dict(attrs=["username", "status"])
219 |
220 | # TODO add other states to the filter
221 | if p["username"] == username and p["status"] == "sleeping":
222 | my_pids.append(pid)
223 | else:
224 | _p = {"pid": pid, **p}
225 | logger.warning(f"Cannot signal process: {_p}")
226 |
227 | for my_pid in my_pids:
228 |         gone, alive = signal_process_tree(my_pid, sig, timeout=60, logger=logger)
229 | logger.info(f"{gone=}, {alive=}")
230 |
--------------------------------------------------------------------------------
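
One detail of `count_model_params` worth noting is that a complex parameter counts as two reals; for an all-real model the count is the usual weights-plus-biases total, as in this small runnable check:

```python
import torch.nn as nn

from utils import count_model_params

model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 2))
# (8*16 + 16) + (16*2 + 2) = 144 + 34 = 178 real-valued parameters
print(count_model_params(model))  # 178
```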