├── .gitignore ├── gnn_bvp_solver ├── fem_dataset │ ├── __init__.py │ ├── modules │ │ ├── __init__.py │ │ ├── circle.py │ │ ├── fem_problem.py │ │ ├── plate.py │ │ ├── boundary_conditions.py │ │ ├── magnetostatics_problem.py │ │ ├── electrostatics_problem.py │ │ └── linear_elasticity_problem.py │ ├── data_generators │ │ ├── __init__.py │ │ ├── generator_registry.py │ │ ├── extend_solution.py │ │ ├── generator_base.py │ │ ├── elasticity_fixed_line.py │ │ ├── rotating_charges.py │ │ ├── magnetics_random_current.py │ │ ├── electrics_random_charge.py │ │ ├── plates_and_charges.py │ │ └── condenser_plates.py │ ├── mesh_generators │ │ ├── mesh_base.py │ │ ├── convert_mesh.py │ │ ├── square_mesh.py │ │ ├── disk_precomputed.py │ │ ├── u_mesh_precomputed.py │ │ ├── l_mesh_precomputed.py │ │ ├── cylinder_precomputed.py │ │ ├── mesh_registry.py │ │ ├── disk_mesh.py │ │ ├── u_mesh.py │ │ ├── cylinder_mesh.py │ │ └── l_mesh.py │ ├── utils.py │ ├── msg_dataset.py │ ├── fem_driver.py │ ├── recursive_user_expression.py │ └── lightning_datamodule.py ├── fem_trainer │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── gnn_identity.py │ │ ├── two_layer_mlp.py │ │ ├── gcn_processor_weights.py │ │ ├── gcn_processor.py │ │ └── main_model.py │ ├── graph_training_module.py │ └── trainer.py ├── preprocessing │ ├── __init__.py │ ├── gather_data.py │ └── split_and_normalize.py ├── visualization │ ├── __init__.py │ ├── vis_inputs.py │ ├── vis_quantities.py │ └── plot_graph.py ├── __init__.py ├── tricks │ └── dropouts.py └── app.py ├── requirements.txt ├── configs ├── task_shape │ ├── es_ma.json │ ├── es_no_ma.json │ ├── ms_ma.json │ └── ms_no_ma.json ├── task_sup │ ├── es_ma.json │ ├── es_no_ma.json │ ├── ms_ma.json │ └── ms_no_ma.json └── ablation │ ├── drop_edges.json │ ├── drop_nodes.json │ └── drop_features.json ├── LICENSE └── readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | wandb 3 | *.ckpt 
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_dataset/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_trainer/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_dataset/modules/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_trainer/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_dataset/data_generators/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | wandb
 2 | squirrel-datasets-core
 3 | matplotlib
 4 | numpy
 5 | torch==1.11.0
 6 | pytorch-lightning
 7 | torchmetrics
 8 | 
 9 | # adjust the torch version / "+cpu" variant in the URL below to match your installed pytorch build
10 | --find-links https://data.pyg.org/whl/torch-1.11.0+cpu.html
11 | torch-geometric
12 | torch-sparse
13 | torch-scatter
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_trainer/models/gnn_identity.py:
--------------------------------------------------------------------------------
 1 | from torch.nn import Module
 2 | import torch
 3 | 
 4 | 
 5 | class GNNIdentity(Module):
 6 |     def __init__(self, *args, **kwargs):
 7 |         """Identity stand-in for a GNN processor; all constructor arguments are accepted and ignored"""
 8 |         super(GNNIdentity, self).__init__()
 9 | 
10 |     def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
11 |         """Return the node features x unchanged; edge_index is accepted but unused"""
12 |         return x
13 | 
--------------------------------------------------------------------------------
/configs/task_shape/es_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_train_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_val_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_test_shape",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["shape_gen", "electrostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": ["mesh_aug"],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/task_shape/es_no_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_train_no_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_val_no_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_test_shape",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["shape_gen", "electrostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": [],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/task_sup/es_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_train_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_val_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_test_sup",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["superposition", "electrostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": ["mesh_aug"],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/task_sup/es_no_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_train_no_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_val_no_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_test_sup",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["superposition", "electrostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": [],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/task_shape/ms_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_train_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_val_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_test_shape",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["shape_gen", "magnetostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": ["mesh_aug"],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/task_shape/ms_no_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_train_no_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_val_no_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_test_shape",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["shape_gen", "magnetostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": [],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/task_sup/ms_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_train_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_val_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_test_sup",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["superposition", "magnetostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": ["mesh_aug"],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/task_sup/ms_no_ma.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_train_no_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_val_no_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/MagneticsRandomCurrentGenerator/norm_test_sup",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["superposition", "magnetostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": [],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/ablation/drop_edges.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_train_no_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_val_no_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_test_shape",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["shape_gen", "electrostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": ["drop_edges"],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/ablation/drop_nodes.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_train_no_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_val_no_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_test_shape",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["shape_gen", "electrostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": ["drop_nodes"],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/configs/ablation/drop_features.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "data_train": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_train_no_ma",
 3 |   "data_val": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_val_no_ma",
 4 |   "data_test": "gs://squirrel-core-public-data/gnn_bvp_solver/ElectricsRandomChargeGenerator/norm_test_shape",
 5 |   "processor": "gcnch3w",
 6 |   "epochs": 150,
 7 |   "dim": [11, 3],
 8 |   "tags": ["shape_gen", "electrostatics"],
 9 |   "batch_size": 32,
10 |   "augmentation": ["embedding_dropout"],
11 |   "remove_pos": true
12 | }
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_trainer/models/two_layer_mlp.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch.nn import ReLU, Sequential, Linear, Identity
 3 | 
 4 | 
 5 | class TwoLayerMLP(torch.nn.Module):
 6 |     def __init__(self, n_input: int = 128, n_hidden: int = 128, n_output: int = 128, is_output: bool = False):
 7 |         """Two linear layers with a ReLU in between; a final ReLU is applied unless is_output is True"""
 8 |         super().__init__()
 9 | 
10 |         self.layers = Sequential(
11 |             Linear(n_input, n_hidden),
12 |             ReLU(),
13 |             Linear(n_hidden, n_output),
14 |             Identity() if is_output else ReLU(),
15 |         )
16 | 
17 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
18 |         """Apply the layer stack to x"""
19 |         return self.layers(x)
20 | 
--------------------------------------------------------------------------------
/gnn_bvp_solver/fem_dataset/mesh_generators/mesh_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Any, Callable 4 | 5 | 6 | @dataclass 7 | class BaseMeshConfig: 8 | resolution: int 9 | mesh_name: str 10 | 11 | 12 | @dataclass 13 | class ParametrizedMesh: 14 | mesh: Any 15 | default_bc: Callable 16 | visible_area: Callable[[float, float], bool] 17 | 18 | 19 | class MeshGeneratorBase(ABC): 20 | @abstractmethod 21 | def __call__(self) -> BaseMeshConfig: 22 | """Generate a mesh config""" 23 | pass 24 | 25 | @staticmethod 26 | @abstractmethod 27 | def solve_config(config: BaseMeshConfig) -> ParametrizedMesh: 28 | """Generate fenics mesh from mesh config""" 29 | pass 30 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/modules/circle.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | 4 | class CircleModule: 5 | def __init__(self, center: Tuple[float, float], radius: float, value: float): 6 | """Create a circular change for one the physical quantities 7 | 8 | Args: 9 | center (Tuple[float, float]): center of the circle 10 | radius (float): radius of the circle 11 | value (float): value to change to in this range 12 | """ 13 | self.center = center 14 | self.radius = radius 15 | self.value = value 16 | 17 | def __call__(self, x: float, y: float) -> float: 18 | """Call the module to evaluate at a given position. 19 | 20 | Args: 21 | x (float): x coordinate 22 | y (float): y coordinate 23 | 24 | Returns: 25 | float: value of ph. 
quantity 26 | """ 27 | x_c, y_c = self.center 28 | 29 | if (x - x_c) ** 2 + (y - y_c) ** 2 < self.radius: 30 | return self.value 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Merantix Momentum (Merantix Labs GmbH) 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/convert_mesh.py: -------------------------------------------------------------------------------- 1 | import fenics as fn 2 | from pathlib import Path 3 | import random 4 | import string 5 | import os 6 | import subprocess 7 | 8 | import meshio 9 | 10 | 11 | def save_msh_to_file(mesh: meshio.Mesh) -> str: 12 | """Convert a mesh in (py)gmsh format and save it as fenics mesh""" 13 | N = 10 14 | name = "".join(random.choices(string.ascii_uppercase + string.digits, k=N)) 15 | 16 | Path("temp").mkdir(exist_ok=True) 17 | mesh.write(f"temp/{name}.msh", file_format="gmsh22") 18 | subprocess.run( 19 | ["gmsh", "-2", "-format", "msh2", f"temp/{name}.msh"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL 20 | ) 21 | subprocess.run( 22 | ["dolfin-convert", f"temp/{name}.msh", f"temp/{name}.xml"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL 23 | ) 24 | 25 | os.remove(f"temp/{name}.msh") 26 | return f"temp/{name}.xml" 27 | 28 | 29 | def convert_msh_to_fenics(mesh: meshio.Mesh) -> fn.Mesh: 30 | """Convert a mesh in (py)gmsh format to a fenics mesh""" 31 | filename = save_msh_to_file(mesh) 32 | res = fn.Mesh(filename) 33 | os.remove(filename) 34 | 35 | return res 36 | -------------------------------------------------------------------------------- /gnn_bvp_solver/tricks/dropouts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.utils import subgraph 4 | from torch_geometric.utils.num_nodes import maybe_num_nodes 5 | 6 | 7 | class DropNode(nn.Module): 8 | """ 9 | DropNode: Sampling node using a uniform distribution. 
10 | Based on https://github.com/VITA-Group/Deep_GCN_Benchmarking 11 | """ 12 | 13 | def __init__(self, drop_rate: float): 14 | """Set dropout rate rate""" 15 | super(DropNode, self).__init__() 16 | self.drop_rate = drop_rate 17 | 18 | def forward( 19 | self, 20 | edge_index: torch.Tensor, 21 | edge_attr: torch.Tensor = None, 22 | num_nodes: int = None, 23 | ) -> torch.Tensor: 24 | """Randomly drop nodes at specified rate""" 25 | if not self.training: 26 | return edge_index, edge_attr 27 | 28 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 29 | nodes = torch.arange(num_nodes, dtype=torch.int64) 30 | mask = torch.full_like(nodes, 1 - self.drop_rate, dtype=torch.float32) 31 | mask = torch.bernoulli(mask).to(torch.bool) 32 | subnodes = nodes[mask] 33 | 34 | return subgraph(subnodes, edge_index, edge_attr=edge_attr, num_nodes=num_nodes) 35 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import torch 3 | import numpy as np 4 | import fenics 5 | 6 | 7 | def stack_data(data: Iterable[np.array]) -> torch.tensor: 8 | """Helper function for converting data to pytorch and stacking it 9 | 10 | Args: 11 | data (Iterable[np.array]): List of numpy arrays 12 | 13 | Returns: 14 | torch.Tensor: Aggregated data 15 | """ 16 | data = np.stack([d.astype(np.float32).flatten() for d in data], axis=1) 17 | return torch.tensor(data, dtype=torch.float) 18 | 19 | 20 | def extract_edges_from_triangle_mesh(mesh: fenics.Mesh) -> np.array: 21 | """Extract an unfiltered list of edges from a triangle mesh. 22 | 23 | Args: 24 | mesh (fenics.Mesh): The mesh to iterate over. 25 | 26 | Returns: 27 | np.array: Array containing list of edges indicated by indexes. 
28 | """ 29 | edges_source = [] 30 | edges_sink = [] 31 | for c in mesh.cells(): 32 | # note: we assume a triangle mesh 33 | edges_source += [c[0], c[1], c[2], c[1], c[2], c[0]] 34 | edges_sink += [c[1], c[2], c[0], c[0], c[1], c[2]] 35 | 36 | # note that this array is still unfiltered and might contain duplicates 37 | # we have a convenient way to filter with pt geometric 38 | return np.array([edges_source, edges_sink], dtype=np.int64) 39 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_trainer/models/gcn_processor_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import GCNConv, Sequential 3 | from typing import Type, List, Dict, Tuple 4 | from gnn_bvp_solver.fem_trainer.models.gcn_processor import GraphNetMP 5 | 6 | 7 | class WeightedGraphNetMP(GraphNetMP): 8 | def __init__( 9 | self, 10 | n_gcn_layers: int, 11 | n_input: int = 128, 12 | n_hidden: int = 128, 13 | n_output: int = 128, 14 | gcn_type: Type = GCNConv, 15 | gcn_kwargs: Dict = None, 16 | processor_dropout: bool = False, 17 | ): 18 | """Simple graph net with n message passing layers and weighted edges""" 19 | super().__init__(n_gcn_layers, n_input, n_hidden, n_output, gcn_type, gcn_kwargs, processor_dropout) 20 | 21 | def _get_gcn_layer(self, gcn_type: Type, n_in: int, n_out: int, gcn_kwargs: Dict) -> Tuple: 22 | return gcn_type(n_in, n_out, **gcn_kwargs), "x, edge_index, edge_weights -> x" 23 | 24 | def _init_processor_layers(self, layers: List) -> torch.nn.Module: 25 | self.conv_layers = Sequential( 26 | "x, edge_index, edge_weights", 27 | layers, 28 | ) 29 | 30 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_weights: torch.Tensor) -> torch.Tensor: 31 | """Model forward""" 32 | return self.conv_layers(x, edge_index, edge_weights) 33 | -------------------------------------------------------------------------------- 
/gnn_bvp_solver/fem_dataset/modules/fem_problem.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Iterable, Tuple 3 | from gnn_bvp_solver.fem_dataset.recursive_user_expression import RecursiveUserExpression 4 | 5 | import fenics as fn 6 | import numpy as np 7 | 8 | 9 | class FemProblem(ABC): 10 | def __init__(self, mesh: fn.Mesh): 11 | """Init FEM Problem with given mesh. 12 | 13 | Args: 14 | mesh: FEM mesh to evaluate on. 15 | """ 16 | self.boundary_conditions = RecursiveUserExpression(False) 17 | self.mesh = mesh 18 | 19 | @abstractmethod 20 | def solve(self) -> Dict[str, np.array]: 21 | """Using fenics to find the solution for an fem problem 22 | 23 | Returns: 24 | Dict[str, np.array]: solution 25 | """ 26 | pass 27 | 28 | @staticmethod 29 | @abstractmethod 30 | def input_output_mapping() -> Dict[str, Iterable[str]]: 31 | """Specify which quantities are input or output. 32 | 33 | Returns: 34 | Dict[str, Iterable[str]]: input / output values of this problem 35 | """ 36 | pass 37 | 38 | @staticmethod 39 | @abstractmethod 40 | def physical_quantities() -> Iterable[Tuple]: 41 | """Return physical quantities present in this simulation. 
42 | 43 | Returns: 44 | Iterable: all quantities (can be 1d or nd) 45 | """ 46 | pass 47 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/msg_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from torch_geometric.data import Data 4 | from squirrel.driver.msgpack import MessagepackDriver 5 | 6 | 7 | class MsgIterableDataset(torch.utils.data.IterableDataset): 8 | def __init__(self, path: str, split: str = None, dry_run: bool = False, shuffle: bool = False) -> None: 9 | """Creates a PyTorch iterable dataset from a squirrel messagepack driver""" 10 | if split is None: 11 | self.driver = MessagepackDriver(path) 12 | else: 13 | self.driver = MessagepackDriver(f"{path}/norm_{split}") 14 | self.dry_run = dry_run 15 | self.shuffle = shuffle 16 | 17 | def _mapping_f(self, item: Dict) -> Data: 18 | """Map numpy arrays from squirrel to pt geometric Data""" 19 | edge_index = torch.tensor(item["edge_index"]) 20 | return Data(x=torch.tensor(item["data_x"]), y=torch.tensor(item["data_y"]), edge_index=edge_index) 21 | 22 | def __iter__(self): 23 | """Iterate dataset""" 24 | if self.dry_run: 25 | it = self.driver.get_iter(max_workers=1, prefetch_buffer=1, shuffle_key_buffer=1, shuffle_item_buffer=1) 26 | elif self.shuffle: 27 | it = self.driver.get_iter(max_workers=4, prefetch_buffer=5, shuffle_key_buffer=100, shuffle_item_buffer=100) 28 | else: 29 | it = self.driver.get_iter(max_workers=4, prefetch_buffer=5, shuffle_key_buffer=1, shuffle_item_buffer=1) 30 | 31 | for i in it.map(self._mapping_f): 32 | yield i 33 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/modules/plate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class PlateModule: 5 | def __init__(self, distance: float, height: float, value: 
float, rotation: float, resolution: int): 6 | """Create a change in form of a thin plate for one the physical quantities 7 | 8 | Args: 9 | distance (float): distance from center 10 | height (float): height of plat 11 | value (float): value of ph. quantity to set 12 | rotation (float): rotation of plate 13 | resolution (int): resolution of mesh, needed to avoid rounding errors 14 | """ 15 | self.distance = distance 16 | self.height = height 17 | self.value = value 18 | self.resolution = resolution 19 | self.rotation = rotation 20 | 21 | def __call__(self, x: float, y: float): 22 | """Call the module to evaluate at a given position. 23 | 24 | Args: 25 | x (float): x coordinate 26 | y (float): y coordinate 27 | 28 | Returns: 29 | float: value of ph. quantity 30 | """ 31 | dx = (x - 0.5) * np.cos(self.rotation) - (y - 0.5) * np.sin(self.rotation) + 0.5 32 | dy = (x - 0.5) * np.sin(self.rotation) + (y - 0.5) * np.cos(self.rotation) + 0.5 33 | 34 | if dy > (0.5 + (self.height / 2.0)) or dy < (0.5 - self.height / 2.0): 35 | return 36 | 37 | # divide by resolution to avoid disappearance if no nodes are hit 38 | if np.abs((dx - 0.5) - self.distance) < 1.0 / self.resolution: 39 | return self.value 40 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/square_mesh.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from functools import partial 3 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, MeshGeneratorBase, ParametrizedMesh 4 | from gnn_bvp_solver.fem_dataset.modules.boundary_conditions import BoundaryConditions 5 | 6 | import fenics as fn 7 | 8 | RES_NOISE_RANGE = [-8.0, 5.0] 9 | 10 | 11 | class UnitSquareGenerator(MeshGeneratorBase): 12 | def __init__(self, resolution: int, randomize: bool = False): 13 | """Create mesh generator for a square mesh""" 14 | self.resolution = resolution 15 | self.randomize 
= randomize 16 | 17 | def __call__(self) -> BaseMeshConfig: 18 | """Generate a mesh config""" 19 | if self.randomize: 20 | res_noise = int(np.random.uniform(RES_NOISE_RANGE[0], RES_NOISE_RANGE[1])) 21 | return BaseMeshConfig(self.resolution + res_noise, "UnitSquareGenerator") 22 | else: 23 | return BaseMeshConfig(self.resolution, "UnitSquareGenerator") 24 | 25 | @staticmethod 26 | def vis_area(x: float, y: float) -> bool: 27 | """Query visible area for this mesh""" 28 | if x < 0 or y < 0 or x > 1.0 or y > 1.0: 29 | return False 30 | return True 31 | 32 | @staticmethod 33 | def solve_config(config: BaseMeshConfig) -> ParametrizedMesh: 34 | """Generate fenics mesh from mesh config""" 35 | return ParametrizedMesh( 36 | fn.UnitSquareMesh(config.resolution, config.resolution), 37 | BoundaryConditions("all").get_f(), 38 | partial(UnitSquareGenerator.vis_area), 39 | ) 40 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/generator_registry.py: -------------------------------------------------------------------------------- 1 | from gnn_bvp_solver.fem_dataset.data_generators.condenser_plates import CondenserPlatesGenerator 2 | from gnn_bvp_solver.fem_dataset.data_generators.elasticity_fixed_line import ElasticityFixedLineGenerator 3 | from gnn_bvp_solver.fem_dataset.data_generators.electrics_random_charge import ElectricsRandomChargeGenerator 4 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import BaseGeneratorConfig 5 | from gnn_bvp_solver.fem_dataset.data_generators.magnetics_random_current import MagneticsRandomCurrentGenerator 6 | from gnn_bvp_solver.fem_dataset.data_generators.plates_and_charges import PlatesAndChargesGenerator 7 | from gnn_bvp_solver.fem_dataset.data_generators.rotating_charges import RotatingChargesGenerator 8 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import GeneratorBase 9 | 10 | 11 | def get_gen_and_solve(config: 
BaseGeneratorConfig) -> GeneratorBase: 12 | """This function stores a list of generators accessibly via names""" 13 | if config.generator_name == "CondenserPlatesGenerator": 14 | return CondenserPlatesGenerator.solve_config(config) 15 | elif config.generator_name == "ElasticityFixedLineGenerator": 16 | return ElasticityFixedLineGenerator.solve_config(config) 17 | elif config.generator_name == "ElectricsRandomChargeGenerator": 18 | return ElectricsRandomChargeGenerator.solve_config(config) 19 | elif config.generator_name == "MagneticsRandomCurrentGenerator": 20 | return MagneticsRandomCurrentGenerator.solve_config(config) 21 | elif config.generator_name == "PlatesAndChargesGenerator": 22 | return PlatesAndChargesGenerator.solve_config(config) 23 | elif config.generator_name == "RotatingChargesGenerator": 24 | return RotatingChargesGenerator.solve_config(config) 25 | else: 26 | ValueError(f"generator {config.generator_name} not found") 27 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/disk_precomputed.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import fenics as fn 4 | import numpy as np 5 | from pyspark import SparkContext, SparkFiles 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import save_msh_to_file 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.disk_mesh import UnitDiskGenerator 8 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, ParametrizedMesh 9 | 10 | 11 | @dataclass 12 | class UnitDiskPrecomputedConfig(BaseMeshConfig): 13 | precomputed_path: str 14 | 15 | 16 | class UnitDiskPrecomputedGenerator(UnitDiskGenerator): 17 | def __init__(self, resolution: int, spark_context: SparkContext, randomize: bool): 18 | """Create mesh generator for a disk and precompute meshes as files""" 19 | super().__init__(resolution, 
randomize) 20 | self.spark_context = spark_context 21 | self.last_config = None 22 | 23 | def __call__(self) -> UnitDiskPrecomputedConfig: 24 | """Generate a mesh config""" 25 | if self.last_config is not None and (np.random.randint(16) < 15 or not self.randomize): 26 | return self.last_config 27 | 28 | config = super().__call__() # generate config 29 | 30 | path = save_msh_to_file(UnitDiskGenerator.config2msh(config)) 31 | self.spark_context.addFile(path) 32 | 33 | self.last_config = UnitDiskPrecomputedConfig(self.resolution, "UnitDiskPrecomputedGenerator", path) 34 | return self.last_config 35 | 36 | @staticmethod 37 | def solve_config(config: UnitDiskPrecomputedConfig) -> fn.Mesh: 38 | """Generate fenics mesh from mesh config""" 39 | return ParametrizedMesh( 40 | fn.Mesh(SparkFiles.get(config.precomputed_path[5:])), 41 | partial(UnitDiskGenerator.vis_area, config=config, is_bc=True), 42 | partial(UnitDiskGenerator.vis_area, config=config, is_bc=False), 43 | ) 44 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/u_mesh_precomputed.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import fenics as fn 4 | from pyspark import SparkContext, SparkFiles 5 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import save_msh_to_file 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.u_mesh import UMeshConfig, UMeshGenerator 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import ParametrizedMesh 8 | import numpy as np 9 | 10 | 11 | @dataclass 12 | class UMeshPrecomputedConfig(UMeshConfig): 13 | precomputed_path: str 14 | 15 | 16 | class UMeshPrecomputedGenerator(UMeshGenerator): 17 | def __init__(self, spark_context: SparkContext, randomize: bool): 18 | """Create mesh generator for a U mesh and precompute meshes as files""" 19 | super().__init__(randomize) 20 | 
self.spark_context = spark_context 21 | self.last_config = None 22 | 23 | def __call__(self) -> UMeshConfig: 24 | """Generate a mesh config""" 25 | if self.last_config is not None and (np.random.randint(16) < 15 or not self.randomize): 26 | return self.last_config 27 | 28 | config = super().__call__() # generate config 29 | 30 | path = save_msh_to_file(UMeshGenerator.config2msh(config)) 31 | self.spark_context.addFile(path) 32 | 33 | self.last_config = UMeshPrecomputedConfig( 34 | self.resolution, "UMeshPrecomputedGenerator", config.cutout_size_x, config.cutout_size_y, path 35 | ) 36 | return self.last_config 37 | 38 | @staticmethod 39 | def solve_config(config: UMeshPrecomputedConfig) -> fn.Mesh: 40 | """Generate fenics mesh from mesh config""" 41 | return ParametrizedMesh( 42 | fn.Mesh(SparkFiles.get(config.precomputed_path[5:])), 43 | partial(UMeshGenerator.vis_area, config=config, is_bc=True), 44 | partial(UMeshGenerator.vis_area, config=config, is_bc=False), 45 | ) 46 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/extend_solution.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import GeneratorSolution 3 | import numpy as np 4 | import scipy.spatial 5 | 6 | 7 | def _min_dist_xy(p: np.array, ref_points: np.array) -> Tuple[np.array, np.array]: 8 | dist = scipy.spatial.distance.cdist(p, ref_points) 9 | closest_border = np.argmin(dist, axis=1) 10 | 11 | dx = ref_points[closest_border][:, 0] - p[:, 0] 12 | dy = ref_points[closest_border][:, 1] - p[:, 1] 13 | 14 | return dx, dy 15 | 16 | 17 | def map_extend(sol: GeneratorSolution) -> GeneratorSolution: 18 | """Extend generator solution by adding features for distance to border and next boundary condition. 
19 | 20 | Args: 21 | sol (GeneratorSolution): solution to extend 22 | 23 | Returns: 24 | GeneratorSolution: extended solution (the same object, modified in place) 25 | """ 26 | p = np.stack([sol.quantities["x"], sol.quantities["y"]], axis=1) 27 | p_bc_active = sol.quantities["bdr_v0"] > 0.01  # mask of nodes whose bdr_v0 exceeds 0.01; treated as active boundary-condition nodes 28 | 29 | def _apply_bc(a: np.array) -> np.array: 30 | return sol.mesh.default_bc(a[0], a[1]) 31 | 32 | p_border = np.apply_along_axis(_apply_bc, axis=1, arr=p)  # per-node result of the mesh's default_bc predicate; presumably boolean since it is used as a mask below 33 | bc_points = p[p_bc_active] 34 | border_points = p[p_border] 35 | 36 | sol.io_mapping["x"].append("dist_border_x")  # append order matters to consumers that stack quantities in io_mapping order 37 | sol.io_mapping["x"].append("dist_border_y") 38 | sol.io_mapping["x"].append("dist_bc_x") 39 | sol.io_mapping["x"].append("dist_bc_y") 40 | sol.io_mapping["x"].append("dist_bc") 41 | sol.io_mapping["x"].append("dist_border") 42 | 43 | dx_border, dy_border = _min_dist_xy(p, border_points) 44 | sol.quantities["dist_border_x"] = dx_border 45 | sol.quantities["dist_border_y"] = dy_border 46 | sol.quantities["dist_border"] = np.sqrt(dx_border**2 + dy_border**2) 47 | 48 | dx_bc, dy_bc = _min_dist_xy(p, bc_points)  # NOTE(review): argmin raises if bc_points is empty (no node above the bdr_v0 threshold) - confirm this cannot happen 49 | sol.quantities["dist_bc_x"] = dx_bc 50 | sol.quantities["dist_bc_y"] = dy_bc 51 | sol.quantities["dist_bc"] = np.sqrt(dx_bc**2 + dy_bc**2) 52 | 53 | return sol 54 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_trainer/models/gcn_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Type, List, Tuple 2 | import torch 3 | from torch_geometric.nn import GCNConv, Sequential 4 | from torch.nn import ReLU, Dropout 5 | 6 | 7 | class GraphNetMP(torch.nn.Module): 8 | def __init__( 9 | self, 10 | n_gcn_layers: int, 11 | n_input: int = 128, 12 | n_hidden: int = 128, 13 | n_output: int = 128, 14 | gcn_type: Type = GCNConv, 15 | gcn_kwargs: Dict = None, 16 | processor_dropout: bool = False, 17 | ): 18 | """Graph net of message-passing layers (n_gcn_layers of them when n_gcn_layers >= 2), each followed by ReLU and optionally preceded by Dropout(0.1).""" 19 | super().__init__() 20 | 21 | if gcn_kwargs is None: 22 |
gcn_kwargs = {} 23 | 24 | if processor_dropout: 25 | layers = [Dropout(0.1), self._get_gcn_layer(gcn_type, n_input, n_hidden, gcn_kwargs), ReLU()] 26 | else: 27 | layers = [self._get_gcn_layer(gcn_type, n_input, n_hidden, gcn_kwargs), ReLU()] 28 | 29 | for _i in range(n_gcn_layers - 2): 30 | if processor_dropout: 31 | layers.append(Dropout(0.1)) 32 | 33 | layers.append(self._get_gcn_layer(gcn_type, n_hidden, n_hidden, gcn_kwargs)) 34 | layers.append(ReLU()) 35 | 36 | if processor_dropout: 37 | layers.append(Dropout(0.1)) 38 | 39 | layers.append(self._get_gcn_layer(gcn_type, n_hidden, n_output, gcn_kwargs)) 40 | layers.append(ReLU()) 41 | self._init_processor_layers(layers) 42 | 43 | def _get_gcn_layer(self, gcn_type: Type, n_in: int, n_out: int, gcn_kwargs: Dict) -> Tuple: 44 | return gcn_type(n_in, n_out, **gcn_kwargs), "x, edge_index -> x" 45 | 46 | def _init_processor_layers(self, layers: List) -> torch.nn.Module: 47 | self.conv_layers = Sequential( 48 | "x, edge_index", 49 | layers, 50 | ) 51 | 52 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: 53 | """Model forward""" 54 | return self.conv_layers(x, edge_index) 55 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/l_mesh_precomputed.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import fenics as fn 4 | from pyspark import SparkContext, SparkFiles 5 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import save_msh_to_file 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.l_mesh import LMeshConfig, LMeshGenerator 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import ParametrizedMesh 8 | import numpy as np 9 | 10 | 11 | @dataclass 12 | class LMeshPrecomputedConfig(LMeshConfig): 13 | precomputed_path: str 14 | 15 | 16 | class 
LMeshPrecomputedGenerator(LMeshGenerator): 17 | def __init__(self, spark_context: SparkContext, randomize: bool): 18 | """Create mesh generator for an L mesh""" 19 | super().__init__(randomize) 20 | self.spark_context = spark_context 21 | self.last_config = None 22 | 23 | def __call__(self) -> LMeshPrecomputedConfig: 24 | """Generate a mesh config""" 25 | if self.last_config is not None and (np.random.randint(16) < 15 or not self.randomize): 26 | return self.last_config 27 | 28 | config = super().__call__() # generate config 29 | 30 | path = save_msh_to_file(LMeshGenerator.config2msh(config)) 31 | self.spark_context.addFile(path) 32 | 33 | self.last_config = LMeshPrecomputedConfig( 34 | self.resolution, 35 | "LMeshPrecomputedGenerator", 36 | config.cutout_location, 37 | config.cutout_size_x, 38 | config.cutout_size_y, 39 | path, 40 | ) 41 | return self.last_config 42 | 43 | @staticmethod 44 | def solve_config(config: LMeshPrecomputedConfig) -> fn.Mesh: 45 | """Generate fenics mesh from mesh config""" 46 | return ParametrizedMesh( 47 | fn.Mesh(SparkFiles.get(config.precomputed_path[5:])), 48 | partial(LMeshGenerator.vis_area, config=config, is_bc=True), 49 | partial(LMeshGenerator.vis_area, config=config, is_bc=False), 50 | ) 51 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/cylinder_precomputed.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import fenics as fn 4 | import numpy as np 5 | from pyspark import SparkContext, SparkFiles 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import save_msh_to_file 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.cylinder_mesh import CylinderGenerator, CylinderMeshConfig 8 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import ParametrizedMesh 9 | 10 | 11 | @dataclass 12 | class 
CylinderMeshPrecomputedConfig(CylinderMeshConfig): 13 | precomputed_path: str 14 | 15 | 16 | class CylinderMeshPrecomputedGenerator(CylinderGenerator): 17 | def __init__(self, spark_context: SparkContext, randomize: bool): 18 | """Create mesh generator for a disk with hole and precompute meshes as files""" 19 | super().__init__(randomize) 20 | self.spark_context = spark_context 21 | self.last_config = None 22 | 23 | def __call__(self) -> CylinderMeshPrecomputedConfig: 24 | """Generate a mesh config""" 25 | if self.last_config is not None and (np.random.randint(16) < 15 or not self.randomize): 26 | return self.last_config 27 | 28 | config = super().__call__() # generate config 29 | path = save_msh_to_file(CylinderGenerator.config2msh(config)) 30 | self.spark_context.addFile(path) 31 | 32 | self.last_config = CylinderMeshPrecomputedConfig( 33 | self.resolution, 34 | "CylinderMeshPrecomputedGenerator", 35 | config.center_x, 36 | config.center_y, 37 | config.inner_radius, 38 | path, 39 | ) 40 | return self.last_config 41 | 42 | @staticmethod 43 | def solve_config(config: CylinderMeshPrecomputedConfig) -> ParametrizedMesh: 44 | """Generate fenics mesh from mesh config""" 45 | return ParametrizedMesh( 46 | fn.Mesh(SparkFiles.get(config.precomputed_path[5:])), 47 | partial(CylinderGenerator.vis_area, config=config, is_bc=True), 48 | partial(CylinderGenerator.vis_area, config=config), 49 | ) 50 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/mesh_registry.py: -------------------------------------------------------------------------------- 1 | from gnn_bvp_solver.fem_dataset.mesh_generators.cylinder_mesh import CylinderGenerator 2 | from gnn_bvp_solver.fem_dataset.mesh_generators.cylinder_precomputed import CylinderMeshPrecomputedGenerator 3 | from gnn_bvp_solver.fem_dataset.mesh_generators.disk_mesh import UnitDiskGenerator 4 | from gnn_bvp_solver.fem_dataset.mesh_generators.disk_precomputed 
import UnitDiskPrecomputedGenerator 5 | from gnn_bvp_solver.fem_dataset.mesh_generators.l_mesh import LMeshGenerator 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.l_mesh_precomputed import LMeshPrecomputedGenerator 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, MeshGeneratorBase 8 | from gnn_bvp_solver.fem_dataset.mesh_generators.square_mesh import UnitSquareGenerator 9 | from gnn_bvp_solver.fem_dataset.mesh_generators.u_mesh import UMeshGenerator 10 | from gnn_bvp_solver.fem_dataset.mesh_generators.u_mesh_precomputed import UMeshPrecomputedGenerator 11 | 12 | 13 | def get_mesh(config: BaseMeshConfig) -> MeshGeneratorBase: 14 | """This function stores a list of mesh generators accessibly via names""" 15 | if config.mesh_name == "UnitSquareGenerator": 16 | return UnitSquareGenerator.solve_config(config) 17 | elif config.mesh_name == "UnitDiskGenerator": 18 | return UnitDiskGenerator.solve_config(config) 19 | elif config.mesh_name == "CylinderGenerator": 20 | return CylinderGenerator.solve_config(config) 21 | elif config.mesh_name == "LMeshGenerator": 22 | return LMeshGenerator.solve_config(config) 23 | elif config.mesh_name == "UMeshGenerator": 24 | return UMeshGenerator.solve_config(config) 25 | elif config.mesh_name == "LMeshPrecomputedGenerator": 26 | return LMeshPrecomputedGenerator.solve_config(config) 27 | elif config.mesh_name == "UMeshPrecomputedGenerator": 28 | return UMeshPrecomputedGenerator.solve_config(config) 29 | elif config.mesh_name == "UnitDiskPrecomputedGenerator": 30 | return UnitDiskPrecomputedGenerator.solve_config(config) 31 | elif config.mesh_name == "CylinderMeshPrecomputedGenerator": 32 | return CylinderMeshPrecomputedGenerator.solve_config(config) 33 | else: 34 | raise ValueError(f"Mesh generator {config.mesh_name} not found") 35 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/disk_mesh.py: 
-------------------------------------------------------------------------------- 1 | from functools import partial 2 | import meshio 3 | import pygmsh 4 | import numpy as np 5 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import convert_msh_to_fenics 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, MeshGeneratorBase, ParametrizedMesh 7 | 8 | RES_NOISE_RANGE = [-8.0, 5.0] 9 | 10 | 11 | class UnitDiskGenerator(MeshGeneratorBase): 12 | def __init__(self, resolution: int, randomize: bool = False): 13 | """Create mesh generator for a disk""" 14 | self.resolution = resolution 15 | self.randomize = randomize 16 | 17 | def __call__(self) -> BaseMeshConfig: 18 | """Generate a mesh config; when randomizing, the resolution is offset by a random integer drawn from RES_NOISE_RANGE""" 19 | if self.randomize: 20 | res_noise = int(np.random.uniform(RES_NOISE_RANGE[0], RES_NOISE_RANGE[1]))  # int() truncates toward zero 21 | return BaseMeshConfig(self.resolution + res_noise, "UnitDiskGenerator") 22 | else: 23 | return BaseMeshConfig(self.resolution, "UnitDiskGenerator") 24 | 25 | @staticmethod 26 | def vis_area(x: float, y: float, config: BaseMeshConfig, is_bc: bool = False) -> bool: 27 | """Query visible area for this mesh; returns is_bc for points in the outer band of the disk and not is_bc otherwise""" 28 | resolution = config.resolution 29 | 30 | if resolution is None: 31 | epsilon = 0.0 32 | else: 33 | epsilon = 2.0 / resolution  # twice the mesh size used by config2msh (mesh_size=1/resolution) 34 | 35 | x_0 = x - 0.5 36 | y_0 = y - 0.5 37 | 38 | r_sq = np.square(x_0) + np.square(y_0) 39 | if r_sq > 0.25 - np.square(epsilon):  # squared distance from the disk centre (0.5, 0.5) compared with 0.5**2 - epsilon**2 40 | return is_bc 41 | 42 | return not is_bc 43 | 44 | @staticmethod 45 | def config2msh(config: BaseMeshConfig) -> meshio.Mesh: 46 | """Build the disk geometry (centre (0.5, 0.5), radius 0.5) with pygmsh and return the generated meshio mesh""" 47 | with pygmsh.occ.Geometry() as geom: 48 | geom.add_disk([0.5, 0.5], 0.5, mesh_size=1.0 / config.resolution) 49 | return geom.generate_mesh() 50 | 51 | @staticmethod 52 | def solve_config(config: BaseMeshConfig) -> ParametrizedMesh: 53 | """Generate fenics mesh from mesh config""" 54 | return ParametrizedMesh( 55 | convert_msh_to_fenics(UnitDiskGenerator.config2msh(config)), 56 |
partial(UnitDiskGenerator.vis_area, config=config, is_bc=True), 57 | partial(UnitDiskGenerator.vis_area, config=config, is_bc=False), 58 | ) 59 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/fem_driver.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from squirrel.driver import IterDriver 3 | from gnn_bvp_solver.fem_dataset.data_generators.extend_solution import map_extend 4 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import BaseGeneratorConfig, GeneratorBase 5 | from squirrel.iterstream import Composable, IterableSource 6 | from gnn_bvp_solver.fem_dataset.data_generators.generator_registry import get_gen_and_solve 7 | 8 | import numpy as np 9 | import fenics 10 | 11 | 12 | class FemDriver(IterDriver): 13 | name = "fem_driver" 14 | 15 | def __init__( 16 | self, 17 | train_generator: GeneratorBase, 18 | val_generator: GeneratorBase = None, 19 | test_generator: GeneratorBase = None, 20 | **kwargs 21 | ) -> None: 22 | """Initialize the FemDriver. 23 | 24 | Args: 25 | train_generator (GeneratorBase): Generator for training data. 26 | val_generator (GeneratorBase): Generator for validation data. 27 | test_generator (GeneratorBase): Generator for testing data. 28 | **kwargs: Other keyword arguments passed to the superclass initializer.
29 | """ 30 | super().__init__(**kwargs) 31 | self.generators = {"train": train_generator, "val": val_generator, "test": test_generator} 32 | 33 | @staticmethod 34 | def map_sample(config: BaseGeneratorConfig) -> Dict: 35 | """Solve a generator config, add distance features, and pack inputs/targets/edges into a dict; expects all generators to output the same quantities""" 36 | fenics.set_log_active(False) 37 | 38 | solution = get_gen_and_solve(config) 39 | solution = map_extend(solution) 40 | 41 | return { 42 | "data_x": np.stack([solution.quantities[k] for k in solution.io_mapping["x"]], axis=1),  # one column per input quantity, in io_mapping order 43 | "data_y": np.stack([solution.quantities[k] for k in solution.io_mapping["y"]], axis=1), 44 | "edge_index": solution.quantities["edge_index"], 45 | } 46 | 47 | def get_iter(self, split: str, solve_pde: bool = True, **kwargs) -> Composable: 48 | """Create an iterstream of generator configs for a dataset split (train, val, test); when solve_pde is True each config is mapped through map_sample.""" 49 | assert split in ["train", "val", "test"]  # NOTE(review): assert is stripped under python -O 50 | 51 | it = IterableSource(iter(self.generators[split]))  # NOTE(review): val/test generators default to None, so iter() raises TypeError if they were not supplied 52 | 53 | if solve_pde: 54 | return it.map(self.map_sample) 55 | else: 56 | return it 57 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/recursive_user_expression.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Iterable, Tuple, Union 2 | import fenics as fn 3 | 4 | 5 | class RecursiveUserExpression(fn.UserExpression): 6 | def __init__(self, default: int = 0, **kwargs): 7 | """Recursive fenics user expression. Defines a function that can be evaluated 8 | and states physical quantities at all locations. 9 | 10 | Args: 11 | default (int, optional): default value for this expression. Defaults to 0.
12 | """ 13 | self.default = default 14 | self.sub_expressions = [] 15 | super().__init__(degree=1, **kwargs) 16 | 17 | def add_subexpression(self, e: Callable[[float, float], Union[float, None]]) -> None: 18 | """Add a subexpression that can change a local part of the quantity. 19 | 20 | Args: 21 | e (Callable[[float, float], Union[float, None]]): Callable taking (x, y, *extra args forwarded by __call__); returns a value, or None to leave the current value unchanged. 22 | 23 | Returns: 24 | None 25 | """ 26 | self.sub_expressions.append(e) 27 | 28 | def eval_cell(self, values: Any, point: Iterable[float], cell: Any) -> float: 29 | """Override base method to tell fenics which values the quantity has at which point. 30 | 31 | Args: 32 | values: value array to store values into 33 | point (Iterable[float]): coordinates of the current point where we evaluate. 34 | cell: not used 35 | 36 | Returns: 37 | float: value of physical quantity 38 | """ 39 | values[0] = self(point) 40 | return values[0] 41 | 42 | def __call__(self, point: Iterable[float], *args) -> float: 43 | """Expression is callable and outputs a value for each point. 44 | 45 | Args: 46 | point (Iterable[float]): coordinates of the current point (indexed as point[0], point[1]). 47 | 48 | Returns: 49 | float: value of physical quantity 50 | """ 51 | result = self.default 52 | x, y = point[0], point[1] 53 | 54 | # check all candidates for that node 55 | # the last one wins if there is a conflict 56 | for e in self.sub_expressions: 57 | candidate = e(x, y, *args) 58 | if candidate is not None:  # None means "no opinion" for this point 59 | result = candidate 60 | 61 | return result 62 | 63 | def value_shape(self) -> Tuple: 64 | """Shape of the ph.
quantity 65 | 66 | Returns: 67 | Tuple: empty tuple to indicate scalar quantity 68 | """ 69 | return () 70 | -------------------------------------------------------------------------------- /gnn_bvp_solver/app.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from gnn_bvp_solver.fem_trainer.trainer import FEMTraining 3 | import argparse 4 | import json 5 | 6 | 7 | def create_training(config: Dict, dry_run: bool, profiling: bool, download_data: bool) -> FEMTraining: 8 | """Create fem trainer from config and other options""" 9 | return FEMTraining( 10 | config["data_train"], 11 | config["data_val"], 12 | config["data_test"], 13 | project="gnn_bvp_solver", 14 | tags=[config["processor"], *config["tags"]], 15 | dry_run=dry_run, 16 | batch_size_train=config["batch_size"], 17 | profiling=profiling, 18 | download_data=download_data, 19 | ) 20 | 21 | 22 | def train(training: FEMTraining, config: Dict, cuda: bool) -> None: 23 | """Call train on trainer using options from config""" 24 | training.train( 25 | config["dim"][0], 26 | config["dim"][1], 27 | config["processor"], 28 | augmentation=config["augmentation"], 29 | epochs=config["epochs"], 30 | cuda=cuda, 31 | remove_pos=config["remove_pos"], 32 | ) 33 | 34 | 35 | def test(training: FEMTraining, artifact: str, cuda: bool) -> None: 36 | """Call test on trainer""" 37 | training.test(artifact, cuda=cuda, project="gnn_bvp_solver") 38 | 39 | 40 | def test_vis(training: FEMTraining, artifact: str, cuda: bool, failure: bool = False) -> None: 41 | """Start visualization (either visualize largest loss or first few samples)""" 42 | training.vis_cases(artifact, cuda=cuda, failure=failure) 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--config", default="") 48 | parser.add_argument("--dry_run", default="") 49 | parser.add_argument("--task", default="train") 50 | parser.add_argument("--artifact", 
default="") 51 | parser.add_argument("--no-gpu", dest="gpu", action="store_false") 52 | parser.add_argument("--profiling", dest="profiling", action="store_true") 53 | parser.add_argument("--download-data", dest="download_data", action="store_true") 54 | parser.set_defaults(gpu=True, profiling=False, download_data=False) 55 | 56 | args = parser.parse_args() 57 | print(args) 58 | 59 | with open(args.config) as f: 60 | config_dict = json.loads(f.read()) 61 | 62 | training = create_training(config_dict, args.dry_run == "dry_run", args.profiling, args.download_data) 63 | if args.task == "train": 64 | train(training, config_dict, args.gpu) 65 | elif args.task == "test": 66 | test(training, args.artifact, args.gpu) 67 | elif args.task == "vis": 68 | test_vis(training, args.artifact, args.gpu) 69 | elif args.task == "vis_failure": 70 | test_vis(training, args.artifact, args.gpu, True) 71 | else: 72 | print("nothing to do") 73 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Source code for paper "Learning the Solution Operator of Boundary Value Problems using Graph Neural Networks". 2 | 3 | > **_NOTE on reproducibility (17 Aug 2023):_** The appendix of the [arxiv version](https://arxiv.org/abs/2206.14092) of our paper has been updated with more extensive results (mean and standard deviation over 5 runs) to increase reproducibility of our results. 4 | 5 | ## Using the data 6 | The data is stored publicly in google buckets in `requester-pays` mode. To access the data and train your models, you need to [include a billing project](https://cloud.google.com/storage/docs/using-requester-pays#using). The raw data can also be found on [huggingface](https://huggingface.co/datasets/winfried/gnn_bvp_solver/tree/main). 
After downloading, adapt the dataset paths in the config files; you can then run the experiments without accessing the cloud. 7 | 8 | ## Use Weights & Biases: 9 | We use Weights & Biases for logging and experiment tracking. You can create your free account [here](https://wandb.ai/). 10 | Use `wandb login` to log in from your Python environment as described [here](https://docs.wandb.ai/quickstart). 11 | 12 | ## Run training: 13 | Make sure all requirements are installed via `pip install -r requirements.txt`. 14 | If you have trouble installing PyTorch Geometric on your machine, make sure to follow the [official instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html). 15 | In the `configs` folder, training configurations (JSON files) are provided for all experiments in the paper. 16 | An example command would be: 17 | 18 | ``` 19 | python -m gnn_bvp_solver.app --config configs/task_shape/es_ma.json --no-gpu 20 | ``` 21 | 22 | ## Test a model: 23 | Look for the model you want to test in the Weights & Biases [artifact store](https://docs.wandb.ai/guides/artifacts). 24 | The best and latest models are automatically tagged for each run. Pass `test` as the task (`--task test`) and the model artifact (`--artifact`). 25 | It is important to use the same config as for training. 26 | An example command would be: 27 | 28 | ``` 29 | python -m gnn_bvp_solver.app --task test --artifact model-aer8oj02:v1 --config configs/task_shape/es_ma.json --no-gpu 30 | ``` 31 | 32 | ## Paper 33 | Please find the full details of our experiments in the [paper](https://arxiv.org/abs/2206.14092). 34 | We published our work at the [AI for Science workshop at ICML 2022](https://openreview.net/forum?id=4vx9FQA7wiC). 35 | 36 | ## License 37 | This software is licensed under the [MIT License](LICENSE). 38 | 39 | ## Cite our work 40 | Please cite the [paper](https://arxiv.org/abs/2206.14092) if you use this code in your own work.
41 | 42 | ``` 43 | @article{lotzsch2022learning, 44 | title={Learning the Solution Operator of Boundary Value Problems using Graph Neural Networks}, 45 | author={L{\"o}tzsch, Winfried and Ohler, Simon and Otterbach, Johannes S}, 46 | journal={ICML 2022 2nd AI for Science Workshop}, 47 | year={2022} 48 | } 49 | ``` 50 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/lightning_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | from torch.utils.data import IterableDataset 3 | from torch_geometric.loader.dataloader import DataLoader 4 | import pytorch_lightning as pl 5 | 6 | from torch_geometric.data import Data 7 | from squirrel.iterstream.torch_composables import TorchIterable 8 | 9 | 10 | data_type = Union[TorchIterable, Iterable[Data], IterableDataset] 11 | 12 | 13 | class FEMDataModule(pl.LightningDataModule): 14 | def __init__( 15 | self, 16 | train_data: data_type, 17 | val_data: data_type = None, 18 | test_data: data_type = None, 19 | batch_size_train: int = 1, 20 | num_workers: int = 0, 21 | ): 22 | """Load fem graph data. 23 | 24 | Args: 25 | train_data (data_type): Training data. 26 | val_data (data_type, optional): Validation data. Defaults to None. 27 | test_data (data_type, optional): Test data. Defaults to None. 28 | batch_size_train (int, optional): Batch size for training. Defaults to 1. 29 | num_workers (int, optional): Number of workers. Defaults to 0.
30 | """ 31 | super().__init__() 32 | 33 | self.batch_size = 1 # need bs 1 for val and test to visualize results 34 | self.batch_size_train = batch_size_train 35 | self.num_workers = num_workers 36 | 37 | self.train_dataset = train_data 38 | self.val_dataset = val_data 39 | self.test_dataset = test_data 40 | 41 | def _dataloader(self, data: data_type, batch_size: int = None, num_workers: int = None) -> DataLoader: 42 | """Internal function to create dataloaders. 43 | 44 | Args: 45 | data (data_type): data source 46 | batch_size (int): batch size; falls back to self.batch_size when None 47 | num_workers (int): number of dataloader workers; falls back to self.num_workers when None 48 | 49 | Returns: 50 | DataLoader: torch_geometric DataLoader over data, without shuffling and with pinned memory 51 | """ 52 | 53 | if num_workers is None: 54 | num_workers = self.num_workers 55 | if batch_size is None: 56 | batch_size = self.batch_size 57 | 58 | # shuffle is handled outside 59 | return DataLoader(data, batch_size=batch_size, num_workers=num_workers, pin_memory=True) 60 | 61 | def train_dataloader(self) -> DataLoader: 62 | """Get Train Dataloader. 63 | 64 | Returns: 65 | DataLoader: Train Dataloader 66 | """ 67 | return self._dataloader(self.train_dataset, batch_size=self.batch_size_train) 68 | 69 | def val_dataloader(self) -> DataLoader: 70 | """Get Val Dataloader. 71 | 72 | Returns: 73 | DataLoader: Val Dataloader 74 | """ 75 | return self._dataloader(self.val_dataset) 76 | 77 | def test_dataloader(self) -> DataLoader: 78 | """Get Test Dataloader.
79 | 80 | Returns: 81 | DataLoader: Test Dataloader 82 | """ 83 | return self._dataloader(self.test_dataset) 84 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/generator_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 4 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, ParametrizedMesh 5 | from gnn_bvp_solver.fem_dataset.modules.fem_problem import FemProblem 6 | from typing import Iterable, Iterator, Dict 7 | 8 | import numpy as np 9 | 10 | 11 | i_o_mapping_type = Dict[str, Iterable[str]] 12 | solution_type = Dict[str, np.array] 13 | 14 | 15 | @dataclass 16 | class BaseGeneratorConfig: 17 | mesh_config: BaseMeshConfig 18 | generator_name: str 19 | 20 | 21 | @dataclass 22 | class GeneratorSolution: 23 | quantities: Dict 24 | io_mapping: Dict 25 | visible_area: np.array 26 | mesh: ParametrizedMesh 27 | 28 | 29 | def build_solution(problem: FemProblem, parametrized_mesh: ParametrizedMesh) -> GeneratorSolution: 30 | """Construct a generator solution by solving an fem problem on a mesh. 31 | 32 | Args: 33 | problem (FemProblem): fem problem to solve 34 | parametrized_mesh (ParametrizedMesh): mesh info to be stored for future reference 35 | 36 | Returns: 37 | GeneratorSolution: generated solution 38 | """ 39 | return GeneratorSolution( 40 | problem.solve(), problem.input_output_mapping(), parametrized_mesh.visible_area, parametrized_mesh 41 | ) 42 | 43 | 44 | class GeneratorBase(ABC): 45 | def __init__(self, mesh_generator: MeshGeneratorBase): 46 | """Creates the base class for all fem generators.
47 | 48 | Args: 49 | mesh_generator (MeshGeneratorBase): each fem generator needs access to a mesh generator 50 | """ 51 | self.mesh_generator = mesh_generator 52 | 53 | @abstractmethod 54 | def __iter__(self) -> Iterator[BaseGeneratorConfig]: 55 | """Generate fem configs that can be turned into problems and solved. 56 | 57 | Returns: 58 | Iterator[BaseGeneratorConfig]: Generated configs 59 | """ 60 | pass 61 | 62 | @staticmethod 63 | def distribute_on_mesh(coords: np.array, n_samples: int) -> np.array: 64 | """Randomly select n_samples rows of coords without replacement. 65 | 66 | Args: 67 | coords (np.array): coordinates to sample from (one coordinate per row) 68 | n_samples (int): number of coordinates to sample; must not exceed the number of rows 69 | 70 | Returns: 71 | np.array: sampled coordinates 72 | """ 73 | return coords[np.random.choice(coords.shape[0], n_samples, replace=False), :] 74 | 75 | @staticmethod 76 | @abstractmethod 77 | def solve_config(config: BaseGeneratorConfig) -> GeneratorSolution: 78 | """Generate an fem problem for a given config and solve it.
79 | 80 | Args: 81 | config (BaseGeneratorConfig): config to solve 82 | 83 | Returns: 84 | GeneratorSolution: fem solution containing the mesh and visible area 85 | """ 86 | pass 87 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/u_mesh.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import fenics as fn 4 | import meshio 5 | import pygmsh 6 | import numpy as np 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import convert_msh_to_fenics 8 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, MeshGeneratorBase, ParametrizedMesh 9 | 10 | 11 | @dataclass 12 | class UMeshConfig(BaseMeshConfig): 13 | cutout_size_x: float 14 | cutout_size_y: float 15 | 16 | 17 | UM_CONST_RES = 9 18 | CUTOUT_RANGE_X = [0.2, 0.6] 19 | CUTOUT_RANGE_Y = [0.2, 0.8] 20 | 21 | 22 | class UMeshGenerator(MeshGeneratorBase): 23 | def __init__(self, randomize: bool = False): 24 | """Create mesh generator for a U mesh""" 25 | # fixed resolution 26 | self.resolution = UM_CONST_RES 27 | self.randomize = randomize 28 | 29 | def __call__(self) -> UMeshConfig: 30 | """Generate a mesh config; cutout sizes are drawn uniformly from CUTOUT_RANGE_X/Y when randomizing, else fixed at 0.5 x 0.75""" 31 | if self.randomize: 32 | cutout_size_x = np.random.uniform(CUTOUT_RANGE_X[0], CUTOUT_RANGE_X[1]) 33 | cutout_size_y = np.random.uniform(CUTOUT_RANGE_Y[0], CUTOUT_RANGE_Y[1]) 34 | else: 35 | cutout_size_x = 0.5 36 | cutout_size_y = 0.75 37 | 38 | return UMeshConfig(self.resolution, "UMeshGenerator", cutout_size_x, cutout_size_y) 39 | 40 | @staticmethod 41 | def vis_area(x: float, y: float, config: UMeshConfig, is_bc: bool = False) -> bool: 42 | """Query visible area for this mesh; returns is_bc outside the unit square or inside the cutout, not is_bc otherwise""" 43 | cutout_size_x = config.cutout_size_x 44 | cutout_size_y = config.cutout_size_y 45 | 46 | epsilon = 0.001 if is_bc else 0.0  # widen the boundary band slightly when querying boundary points 47 | 48 | # outside bounding box 49 | if x < 0 + epsilon or y < 0 +
epsilon or x > 1.0 - epsilon or y > 1.0 - epsilon: 50 | return is_bc 51 | 52 | # inside cutout 53 | cond_x = (x > (0.5 - cutout_size_x / 2.0) - epsilon) and (x < (0.5 + cutout_size_x / 2.0) + epsilon) 54 | cond_y = y > 1.0 - cutout_size_y - epsilon 55 | if cond_x and cond_y: 56 | return is_bc 57 | 58 | return not is_bc 59 | 60 | @staticmethod 61 | def config2msh(config: UMeshConfig) -> meshio.Mesh: 62 | """Convert mesh config to pygmsh mesh""" 63 | with pygmsh.occ.Geometry() as geom: 64 | cutout_size_x = config.cutout_size_x 65 | cutout_size_y = config.cutout_size_y 66 | 67 | g1 = geom.add_rectangle([0.0, 0.0, 0.0], 1.0, 1.0) 68 | g2 = geom.add_rectangle([0.5 - cutout_size_x / 2.0, 1.0 - cutout_size_y, 0.0], cutout_size_x, cutout_size_y) 69 | 70 | geom.boolean_difference(g1, g2) 71 | return geom.generate_mesh() 72 | 73 | @staticmethod 74 | def solve_config(config: UMeshConfig) -> fn.Mesh: 75 | """Generate fenics mesh from mesh config""" 76 | mesh = UMeshGenerator.config2msh(config) 77 | return ParametrizedMesh( 78 | convert_msh_to_fenics(mesh), 79 | partial(UMeshGenerator.vis_area, config=config, is_bc=True), 80 | partial(UMeshGenerator.vis_area, config=config, is_bc=False), 81 | ) 82 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/elasticity_fixed_line.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from typing import Iterator 4 | from dataclasses import dataclass 5 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import ( 6 | BaseGeneratorConfig, 7 | GeneratorBase, 8 | GeneratorSolution, 9 | build_solution, 10 | ) 11 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_registry import get_mesh 12 | from gnn_bvp_solver.fem_dataset.modules.boundary_conditions import BoundaryConditions 13 | from gnn_bvp_solver.fem_dataset.modules.linear_elasticity_problem import ElasticityProblem 14 | from 
gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 15 | 16 | 17 | @dataclass 18 | class ElasticityFixedLineConfig(BaseGeneratorConfig): 19 | n_fix_lines_min: int 20 | n_fix_lines_max: int 21 | 22 | 23 | class ElasticityFixedLineGenerator(GeneratorBase): 24 | def __init__( 25 | self, n_samples: int, mesh_generator: MeshGeneratorBase, n_fix_lines_max: int = 1, n_fix_lines_min: int = 1 26 | ): 27 | """Creates a set of linear elasticity problems whose boundary conditions fix one or more random vertical lines. 28 | 29 | Args: 30 | n_samples (int): number of samples in dataset to generate 31 | mesh_generator (MeshGeneratorBase): underlying mesh generator 32 | n_fix_lines_max (int, optional): Max number of fixed vertical lines in the BCs. Defaults to 1. 33 | n_fix_lines_min (int, optional): Min number of fixed vertical lines in the BCs. Defaults to 1. 34 | """ 35 | self.n_samples = n_samples 36 | self.n_fix_lines_min = n_fix_lines_min 37 | self.n_fix_lines_max = n_fix_lines_max 38 | 39 | super().__init__(mesh_generator) 40 | 41 | @staticmethod 42 | def solve_config(config: ElasticityFixedLineConfig, debug: bool = False) -> GeneratorSolution: 43 | """Solve elasticity config: sample vertical line positions uniformly in [0, 1) and fix the mesh along them. 44 | 45 | Args: 46 | config (ElasticityFixedLineConfig): config to solve 47 | debug (bool, optional): Use n_fix_lines_max lines instead of a random count (line positions stay random). Defaults to False.
48 | 49 | Returns: 50 | GeneratorSolution: fem solution containing the mesh and visible area 51 | """ 52 | parametrized_mesh = get_mesh(config.mesh_config) 53 | vsize = ( 54 | config.n_fix_lines_max if debug else np.random.randint(config.n_fix_lines_min, config.n_fix_lines_max + 1) 55 | ) 56 | v_lines = np.random.uniform(size=(vsize,)) # x positions of the fixed lines, uniform in [0, 1) 57 | 58 | problem = ElasticityProblem(parametrized_mesh.mesh) 59 | problem.boundary_conditions.add_subexpression( 60 | BoundaryConditions("vertical_lines", x_line=v_lines, resolution=config.mesh_config.resolution).get_f() 61 | ) 62 | 63 | return build_solution(problem, parametrized_mesh) 64 | 65 | def __iter__(self) -> Iterator[ElasticityFixedLineConfig]: 66 | """Generate one config per sample; the problems are built and solved later by solve_config. 67 | 68 | Returns: 69 | Iterator[ElasticityFixedLineConfig]: Generated configs 70 | """ 71 | for _i in range(self.n_samples): 72 | yield ElasticityFixedLineConfig( 73 | mesh_config=self.mesh_generator(), 74 | n_fix_lines_min=self.n_fix_lines_min, 75 | n_fix_lines_max=self.n_fix_lines_max, 76 | generator_name="ElasticityFixedLineGenerator", 77 | ) 78 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/rotating_charges.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterator, Tuple 3 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import BaseGeneratorConfig, GeneratorBase 4 | from gnn_bvp_solver.fem_dataset.modules.electrostatics_problem import ElectrostaticsProblem 5 | from gnn_bvp_solver.fem_dataset.modules.circle import CircleModule 6 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import solution_type, i_o_mapping_type 7 | 8 | import numpy as np 9 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import build_solution 10 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 11 | 12 | from
gnn_bvp_solver.fem_dataset.mesh_generators.mesh_registry import get_mesh 13 | 14 | 15 | @dataclass 16 | class RotatingChargesConfig(BaseGeneratorConfig): 17 | theta: float 18 | distance: float 19 | n_charges: int 20 | 21 | 22 | class RotatingChargesGenerator(GeneratorBase): 23 | def __init__( 24 | self, 25 | n_samples: int, 26 | mesh_generator: MeshGeneratorBase, 27 | n_charges: int = 1, 28 | distance_r: float = 0.1, 29 | ): 30 | """Creates a set of rotating charges with equal charge 31 | 32 | Args: 33 | n_samples (int, optional): number of rotating steps. 34 | mesh_generator (MeshGeneratorBase): underlying mesh generator 35 | n_charges (int, optional): number of rotating charges. Defaults to 1. 36 | distance_r (float, optional): radial distance to center. Defaults to 0.1. 37 | """ 38 | self.n_charges = n_charges 39 | self.n_samples = n_samples 40 | self.distance = distance_r / np.sqrt(2.0) 41 | 42 | super().__init__(mesh_generator) 43 | 44 | @staticmethod 45 | def solve_config(config: RotatingChargesConfig) -> Tuple[solution_type, i_o_mapping_type]: 46 | """Solve config combining plates and charges.""" 47 | parametrized_mesh = get_mesh(config.mesh_config) 48 | 49 | problem = ElectrostaticsProblem(parametrized_mesh.mesh) 50 | problem.boundary_conditions.add_subexpression(parametrized_mesh.default_bc) 51 | 52 | for i in range(config.n_charges): 53 | theta_charge = config.theta + (i * 2.0 * np.pi) / config.n_charges 54 | 55 | dx = config.distance * np.cos(theta_charge) - config.distance * np.sin(theta_charge) 56 | dy = config.distance * np.sin(theta_charge) + config.distance * np.cos(theta_charge) 57 | 58 | problem.charge_density.add_subexpression(CircleModule((0.5 + dx, 0.5 + dy), 0.005, 1.0)) 59 | 60 | return build_solution(problem, parametrized_mesh) 61 | 62 | def __iter__(self) -> Iterator[float]: 63 | """Generate and solve the problems. Align n charges on a circle and rotate. 
64 | 65 | Returns: 66 | Iterator[Tuple[solution_type, i_o_mapping_type]]: Generated solutions 67 | """ 68 | for theta in np.linspace(0, 2 * np.pi, num=self.n_samples, endpoint=False): 69 | yield RotatingChargesConfig( 70 | theta=theta, 71 | distance=self.distance, 72 | n_charges=self.n_charges, 73 | mesh_config=self.mesh_generator(), 74 | generator_name="RotatingChargesGenerator", 75 | ) 76 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/magnetics_random_current.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterator 3 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import BaseGeneratorConfig, GeneratorBase 4 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_registry import get_mesh 5 | from gnn_bvp_solver.fem_dataset.modules.circle import CircleModule 6 | from gnn_bvp_solver.fem_dataset.modules.magnetostatics_problem import MagnetostaticsProblem 7 | 8 | import numpy as np 9 | 10 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import GeneratorSolution, build_solution 11 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 12 | 13 | 14 | @dataclass 15 | class MagneticsRandomCurrentConfig(BaseGeneratorConfig): 16 | n_currents_min: int 17 | n_currents_max: int 18 | 19 | 20 | class MagneticsRandomCurrentGenerator(GeneratorBase): 21 | def __init__( 22 | self, 23 | n_samples: int, 24 | mesh_generator: MeshGeneratorBase, 25 | n_currents_max: int = 1, 26 | n_currents_min: int = 1, 27 | ): 28 | """Create a set of radomly distributed flowing currents of equal magitude. 29 | 30 | Args: 31 | n_samples (int): number of samples in dataset to generate 32 | mesh_generator (MeshGeneratorBase): underlying mesh generator 33 | n_currents_max (int, optional): max number of charges that will be randomly distributed. Defaults to 1. 
34 | n_currents_min (int, optional): min number of charges that will be randomly distributed. Defaults to 1. 35 | resolution (int, optional): resolution of the underlying mesh. Defaults to 35. 36 | """ 37 | self.n_samples = n_samples 38 | self.n_currents_max = n_currents_max 39 | self.n_currents_min = n_currents_min 40 | 41 | super().__init__(mesh_generator) 42 | 43 | @staticmethod 44 | def solve_config(config: MagneticsRandomCurrentConfig, debug: bool = False) -> GeneratorSolution: 45 | """Solve random current config.""" 46 | parametrized_mesh = get_mesh(config.mesh_config) 47 | coords = parametrized_mesh.mesh.coordinates() 48 | 49 | problem = MagnetostaticsProblem(parametrized_mesh.mesh) 50 | problem.boundary_conditions.add_subexpression(parametrized_mesh.default_bc) 51 | 52 | range_mx = ( 53 | config.n_currents_max if debug else np.random.randint(config.n_currents_min, config.n_currents_max + 1) 54 | ) 55 | pos = GeneratorBase.distribute_on_mesh(coords, range_mx) 56 | 57 | for j in range(range_mx): 58 | problem.current_density.add_subexpression(CircleModule((pos[j, 0], pos[j, 1]), 0.005, 1.0)) 59 | 60 | return build_solution(problem, parametrized_mesh) 61 | 62 | def __iter__(self) -> Iterator[MagneticsRandomCurrentConfig]: 63 | """Generate and solve the problems. Align n currents randomly. 
64 | 65 | Returns: 66 | Iterator[MagneticsRandomCurrentConfig]: Generated solutions 67 | """ 68 | for _i in range(self.n_samples): 69 | yield MagneticsRandomCurrentConfig( 70 | mesh_config=self.mesh_generator(), 71 | n_currents_max=self.n_currents_max, 72 | n_currents_min=self.n_currents_min, 73 | generator_name="MagneticsRandomCurrentGenerator", 74 | ) 75 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/cylinder_mesh.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import meshio 4 | import pygmsh 5 | import numpy as np 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import convert_msh_to_fenics 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, MeshGeneratorBase, ParametrizedMesh 8 | 9 | 10 | CM_CONST_RES = 8 11 | CM_CONST_INNER_RADIUS_RANGE = [0.05, 0.25] 12 | CM_CONST_CENTER_X_RANGE = [0.35, 0.65] 13 | CM_CONST_CENTER_Y_RANGE = [0.35, 0.65] 14 | 15 | 16 | @dataclass 17 | class CylinderMeshConfig(BaseMeshConfig): 18 | center_x: float 19 | center_y: float 20 | inner_radius: float 21 | 22 | 23 | class CylinderGenerator(MeshGeneratorBase): 24 | def __init__(self, randomize: bool = False): 25 | """Create mesh generator for a disk with hole""" 26 | # fixed resolution 27 | self.resolution = CM_CONST_RES 28 | self.randomize = randomize 29 | 30 | def __call__(self) -> CylinderMeshConfig: 31 | """Generate a mesh config""" 32 | if self.randomize: 33 | inner_radius = np.random.uniform(CM_CONST_INNER_RADIUS_RANGE[0], CM_CONST_INNER_RADIUS_RANGE[1]) 34 | center_x = np.random.uniform(CM_CONST_CENTER_X_RANGE[0], CM_CONST_CENTER_X_RANGE[1]) 35 | center_y = np.random.uniform(CM_CONST_CENTER_Y_RANGE[0], CM_CONST_CENTER_Y_RANGE[1]) 36 | else: 37 | inner_radius = 0.12 38 | center_x = 0.5 39 | center_y = 0.5 40 | 41 | return 
CylinderMeshConfig(self.resolution, "CylinderGenerator", center_x, center_y, inner_radius) 42 | 43 | @staticmethod 44 | def vis_area(x: float, y: float, config: CylinderMeshConfig, is_bc: bool = False) -> bool: 45 | """Query visible area for this mesh""" 46 | resolution = config.resolution 47 | center_x = config.center_x 48 | center_y = config.center_y 49 | inner_radius = config.inner_radius 50 | 51 | epsilon = 1.0 / resolution if is_bc else 0.0 52 | 53 | x_0 = x - center_x 54 | y_0 = y - center_y 55 | 56 | r_sq = np.square(x_0) + np.square(y_0) 57 | r_sq_outer = np.square(x - 0.5) + np.square(y - 0.5) 58 | 59 | if r_sq_outer > 0.25 - np.square(epsilon): 60 | return is_bc 61 | if r_sq < np.square(inner_radius) + np.square(epsilon) * 0.5: 62 | return is_bc 63 | 64 | return not is_bc 65 | 66 | @staticmethod 67 | def config2msh(config: CylinderMeshConfig) -> meshio.Mesh: 68 | """Convert mesh config to pygmsh mesh""" 69 | center_x = config.center_x 70 | center_y = config.center_y 71 | inner_radius = config.inner_radius 72 | 73 | with pygmsh.occ.Geometry() as geom: 74 | g1 = geom.add_disk([0.5, 0.5], 0.5) 75 | g2 = geom.add_disk([center_x, center_y], inner_radius) 76 | geom.boolean_difference(g1, g2) 77 | return geom.generate_mesh() 78 | 79 | @staticmethod 80 | def solve_config(config: CylinderMeshConfig) -> ParametrizedMesh: 81 | """Generate fenics mesh from mesh config""" 82 | return ParametrizedMesh( 83 | convert_msh_to_fenics(CylinderGenerator.config2msh(config)), 84 | partial(CylinderGenerator.vis_area, config=config, is_bc=True), 85 | partial(CylinderGenerator.vis_area, config=config), 86 | ) 87 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/electrics_random_charge.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterator 3 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base 
import BaseGeneratorConfig, GeneratorBase, GeneratorSolution 4 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 5 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_registry import get_mesh 6 | from gnn_bvp_solver.fem_dataset.modules.electrostatics_problem import ElectrostaticsProblem 7 | from gnn_bvp_solver.fem_dataset.modules.circle import CircleModule 8 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import build_solution 9 | 10 | import numpy as np 11 | 12 | 13 | @dataclass 14 | class ElectricsRandomChargeConfig(BaseGeneratorConfig): 15 | n_charges_min: int 16 | n_charges_max: int 17 | 18 | 19 | class ElectricsRandomChargeGenerator(GeneratorBase): 20 | def __init__( 21 | self, 22 | n_samples: int, 23 | mesh_generator: MeshGeneratorBase, 24 | n_charges_max: int = 1, 25 | n_charges_min: int = 1, 26 | ): 27 | """Create a set of randomly distributed charges with equal density. 28 | 29 | Args: 30 | n_samples (int): number of samples in dataset to generate 31 | mesh_generator (MeshGeneratorBase): underlying mesh generator 32 | n_charges_max (int, optional): max number of charges that will be randomly distributed. Defaults to 1. 33 | n_charges_min (int, optional): min number of charges that will be randomly distributed. Defaults to 1. 34 | """ 35 | self.n_samples = n_samples 36 | self.n_charges_max = n_charges_max 37 | self.n_charges_min = n_charges_min 38 | 39 | super().__init__(mesh_generator) 40 | 41 | @staticmethod 42 | def solve_config(config: ElectricsRandomChargeConfig, debug: bool = False) -> GeneratorSolution: 43 | """Solve random charge config. 44 | 45 | Args: 46 | config (ElectricsRandomChargeConfig): config to solve 47 | debug (bool, optional): Disable randomization for debugging. Defaults to False. 
48 | 49 | Returns: 50 | GeneratorSolution: fem solution containing the mesh and visible area 51 | """ 52 | parametrized_mesh = get_mesh(config.mesh_config) 53 | coords = parametrized_mesh.mesh.coordinates() 54 | 55 | problem = ElectrostaticsProblem(parametrized_mesh.mesh) 56 | problem.boundary_conditions.add_subexpression(parametrized_mesh.default_bc) 57 | range_mx = config.n_charges_max if debug else np.random.randint(config.n_charges_min, config.n_charges_max + 1) 58 | pos = GeneratorBase.distribute_on_mesh(coords, range_mx) 59 | 60 | for j in range(range_mx): 61 | problem.charge_density.add_subexpression(CircleModule((pos[j, 0], pos[j, 1]), 0.005, 1.0)) 62 | 63 | return build_solution(problem, parametrized_mesh) 64 | 65 | def __iter__(self) -> Iterator[ElectricsRandomChargeConfig]: 66 | """Generate and solve the problems. Align n currents randomly. 67 | 68 | Returns: 69 | Iterator[ElectricsRandomChargeConfig]: Generated solutions 70 | """ 71 | for _i in range(self.n_samples): 72 | yield ElectricsRandomChargeConfig( 73 | mesh_config=self.mesh_generator(), 74 | n_charges_min=self.n_charges_min, 75 | n_charges_max=self.n_charges_max, 76 | generator_name="ElectricsRandomChargeGenerator", 77 | ) 78 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/modules/boundary_conditions.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable 2 | import numpy as np 3 | 4 | 5 | class BoundaryConditions: 6 | def __init__(self, name: str, **kwargs): 7 | """Define different types of boundary conditions""" 8 | self.name = name 9 | self.kwargs = kwargs 10 | 11 | def get_f(self) -> Callable: 12 | """Get predefined boundary conditions (mostly for square mesh)""" 13 | if self.name == "all": 14 | return self.all_boundaries 15 | elif self.name == "left_and_right": 16 | return self.left_and_right 17 | elif self.name == "left": 18 | return self.left 19 | elif 
self.name == "right": 20 | return self.right 21 | elif self.name == "top_and_bottom": 22 | return self.top_and_bottom 23 | elif self.name == "vertical_lines": 24 | return self.f_vertical_lines(**self.kwargs) 25 | else: 26 | raise ValueError(f"name {self.name} not found") 27 | 28 | @staticmethod 29 | def all_boundaries(x: float, y: float) -> bool: 30 | """Set all boundaries of a square mesh to true. 31 | 32 | Args: 33 | x (float): x coordinate 34 | y (float): y coordinate 35 | 36 | Returns: 37 | bool: whether boundary is active at this point 38 | """ 39 | return x < 0.00001 or x > 0.99999 or y < 0.00001 or y > 0.99999 40 | 41 | @staticmethod 42 | def left_and_right(x: float, y: float) -> bool: 43 | """Set left and right boundaries of a square mesh to true. 44 | 45 | Args: 46 | x (float): x coordinate 47 | y (float): y coordinate 48 | 49 | Returns: 50 | bool: whether boundary is active at this point 51 | """ 52 | return x < 0.00001 or x > 0.99999 53 | 54 | @staticmethod 55 | def left(x: float, y: float) -> bool: 56 | """Set left boundaries of a square mesh to true. 57 | 58 | Args: 59 | x (float): x coordinate 60 | y (float): y coordinate 61 | 62 | Returns: 63 | bool: whether boundary is active at this point 64 | """ 65 | return x < 0.00001 66 | 67 | @staticmethod 68 | def right(x: float, y: float, bdr: bool) -> bool: 69 | """Set right boundaries of a square mesh to true. 70 | 71 | Args: 72 | x (float): x coordinate 73 | y (float): y coordinate 74 | 75 | Returns: 76 | bool: whether boundary is active at this point 77 | """ 78 | return x > 0.99999 79 | 80 | @staticmethod 81 | def top_and_bottom(x: float, y: float) -> bool: 82 | """Set top and bottom boundaries of a square mesh to true. 
83 | 84 | Args: 85 | x (float): x coordinate 86 | y (float): y coordinate 87 | 88 | Returns: 89 | bool: whether boundary is active at this point 90 | """ 91 | return y < 0.00001 or y > 0.99999 92 | 93 | @staticmethod 94 | def f_vertical_lines(x_line: Iterable[float], resolution: int) -> Callable[[float, float], bool]: 95 | """Create boundary condition as a vertical line. 96 | 97 | Args: 98 | x_line (float): x_line coordinate of line 99 | resolution (int): resolution of mesh, needed to avoid rounding errors 100 | 101 | Returns: 102 | bool: whether boundary is active at this point 103 | """ 104 | x_line = np.array(x_line) 105 | return lambda x, y: np.min(np.abs(x - x_line)) < 0.5 / resolution 106 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/modules/magnetostatics_problem.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterable, Tuple 2 | import fenics as fn 3 | import numpy as np 4 | from gnn_bvp_solver.fem_dataset.modules.fem_problem import FemProblem 5 | from gnn_bvp_solver.fem_dataset.recursive_user_expression import RecursiveUserExpression 6 | from gnn_bvp_solver.fem_dataset.utils import extract_edges_from_triangle_mesh 7 | 8 | 9 | class MagnetostaticsProblem(FemProblem): 10 | def __init__(self, mesh: fn.Mesh): 11 | """Init Magnetostatics Problem. Generate recursive expressions for all quantities to 12 | enable adding / removing components flexibly. Fem mesh is fixed to unit square for now. 13 | 14 | Args: 15 | mesh: FEM mesh to evaluate on. 
16 | """ 17 | self.current_density = RecursiveUserExpression(0.0) 18 | self.permeability = RecursiveUserExpression(1.0) 19 | super().__init__(mesh) 20 | 21 | def solve(self) -> Dict[str, np.array]: 22 | """Using fenics to find the solution for an magnetostatics problem 23 | 24 | Returns: 25 | FEMSolution: solution for magnetic potential and field 26 | """ 27 | V = fn.FunctionSpace(self.mesh, "P", 2) 28 | 29 | class TempClass: 30 | @staticmethod 31 | def _bdr(x: Tuple[float, float]) -> bool: 32 | """Recursively define the boundary condition (for now only enforce V=0) 33 | It has to be a static method so we use this hack to make it work. 34 | Args: 35 | x (Tuple[float, float]): point to evaluate 36 | 37 | Returns: 38 | bool: True if BC is active otherwise False 39 | """ 40 | return self.boundary_conditions(x) 41 | 42 | bdr_bc = fn.DirichletBC(V, fn.Constant(0.0), TempClass._bdr) 43 | bcs = [bdr_bc] 44 | 45 | A = fn.TrialFunction(V) 46 | v = fn.TestFunction(V) 47 | a = fn.dot(fn.grad(A), fn.grad(v)) * fn.dx 48 | L = self.permeability * self.current_density * v * fn.dx 49 | A_res = fn.Function(V) 50 | fn.solve(a == L, A_res, bcs) 51 | 52 | B_x = fn.project(A_res.dx(1)).compute_vertex_values() 53 | B_y = fn.project(-A_res.dx(0)).compute_vertex_values() 54 | 55 | return { 56 | "x": np.array(self.mesh.coordinates()[:, 0], dtype=np.float32), 57 | "y": np.array(self.mesh.coordinates()[:, 1], dtype=np.float32), 58 | "I": np.array(fn.project(self.current_density, V).compute_vertex_values(), dtype=np.float32), 59 | "A": np.array(A_res.compute_vertex_values(), dtype=np.float32), 60 | "B_x": np.array(B_x, dtype=np.float32), 61 | "B_y": np.array(B_y, dtype=np.float32), 62 | "mu": np.array(fn.project(self.permeability, V).compute_vertex_values(), dtype=np.float32), 63 | "bdr_v0": np.array(fn.project(self.boundary_conditions, V).compute_vertex_values(), dtype=np.float32), 64 | "edge_index": np.array(extract_edges_from_triangle_mesh(self.mesh), dtype=np.int64), 65 | } 66 | 67 | 
@staticmethod 68 | def input_output_mapping() -> Dict[str, Iterable[str]]: 69 | """Specify which quantities are input or output. 70 | 71 | Returns: 72 | Dict[str, Iterable[str]]: input / output values of this problem 73 | """ 74 | # x and y must be first by convention 75 | return {"x": ["x", "y", "I", "bdr_v0", "mu"], "y": ["A", "B_x", "B_y"]} 76 | 77 | @staticmethod 78 | def physical_quantities() -> Iterable[Tuple]: 79 | """Return physical quantities present in this simulation. 80 | 81 | Returns: 82 | Iterable: all quantities (can be 1d or nd) 83 | """ 84 | return [("I",), ("mu",), ("A",), ("B_x", "B_y")] 85 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/modules/electrostatics_problem.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterable, Tuple 2 | import fenics as fn 3 | import numpy as np 4 | from gnn_bvp_solver.fem_dataset.modules.fem_problem import FemProblem 5 | from gnn_bvp_solver.fem_dataset.recursive_user_expression import RecursiveUserExpression 6 | from gnn_bvp_solver.fem_dataset.utils import extract_edges_from_triangle_mesh 7 | 8 | 9 | class ElectrostaticsProblem(FemProblem): 10 | def __init__(self, mesh: fn.Mesh): 11 | """Init Electrostatics Problem. Generate recursive expressions for all quantities to 12 | enable adding / removing components flexibly. Fem mesh is fixed to unit square for now. 13 | 14 | Args: 15 | mesh: FEM mesh to evaluate on. 
16 | """ 17 | self.charge_density = RecursiveUserExpression(0.0) 18 | self.permittivity = RecursiveUserExpression(1.0) 19 | super().__init__(mesh) 20 | 21 | def solve(self) -> Dict[str, np.array]: 22 | """Using fenics to find the solution for an electrostatics problem 23 | 24 | Returns: 25 | FEMSolution: solution for potential and electric field 26 | """ 27 | V = fn.FunctionSpace(self.mesh, "P", 2) 28 | 29 | class TempClass: 30 | @staticmethod 31 | def _bdr(x: Tuple[float, float]) -> bool: 32 | """Recursively define the boundary condition (for now only enforce V=0) 33 | It has to be a static method so we use this hack to make it work. 34 | Args: 35 | x (Tuple[float, float]): point to evaluate 36 | 37 | Returns: 38 | bool: True if BC is active otherwise False 39 | """ 40 | return self.boundary_conditions(x) 41 | 42 | bdr_bc = fn.DirichletBC(V, fn.Constant(0.0), TempClass._bdr) 43 | bcs = [bdr_bc] 44 | 45 | u = fn.TrialFunction(V) 46 | v = fn.TestFunction(V) 47 | a = fn.dot(fn.grad(u), fn.grad(v)) * self.permittivity * fn.dx 48 | L = self.charge_density * v * fn.dx 49 | u = fn.Function(V) 50 | fn.solve(a == L, u, bcs) 51 | 52 | electric_field = fn.project(-fn.grad(u)) 53 | e_result = electric_field.compute_vertex_values().reshape((2, -1)) 54 | 55 | return { 56 | "x": np.array(self.mesh.coordinates()[:, 0], dtype=np.float32), 57 | "y": np.array(self.mesh.coordinates()[:, 1], dtype=np.float32), 58 | "rho": np.array(fn.project(self.charge_density, V).compute_vertex_values(), dtype=np.float32), 59 | "u": np.array(u.compute_vertex_values(), dtype=np.float32), 60 | "E_x": np.array(e_result[0], dtype=np.float32), 61 | "E_y": np.array(e_result[1], dtype=np.float32), 62 | "epsilon": np.array(fn.project(self.permittivity, V).compute_vertex_values(), dtype=np.float32), 63 | "bdr_v0": np.array(fn.project(self.boundary_conditions, V).compute_vertex_values(), dtype=np.float32), 64 | "edge_index": np.array(extract_edges_from_triangle_mesh(self.mesh), dtype=np.int64), 65 | } 66 | 67 
@staticmethod 68 | def input_output_mapping() -> Dict[str, Iterable[str]]: 69 | """Specify which quantities are input or output. 70 | 71 | Returns: 72 | Dict[str, Iterable[str]]: quantity names used as inputs (key "x") and as outputs (key "y") of this problem 73 | """ 74 | # x and y must be first by convention 75 | return {"x": ["x", "y", "rho", "bdr_v0", "epsilon"], "y": ["u", "E_x", "E_y"]} 76 | 77 | @staticmethod 78 | def physical_quantities() -> Iterable[Tuple]: 79 | """Return physical quantities present in this simulation. 80 | 81 | Returns: 82 | Iterable: all quantities; multi-component ones (the field E_x, E_y) are grouped in one tuple 83 | """ 84 | return [("rho",), ("epsilon",), ("u",), ("E_x", "E_y")] 85 | -------------------------------------------------------------------------------- /gnn_bvp_solver/preprocessing/gather_data.py: -------------------------------------------------------------------------------- 1 | from gnn_bvp_solver.fem_dataset.data_generators.elasticity_fixed_line import ElasticityFixedLineGenerator 2 | from gnn_bvp_solver.fem_dataset.data_generators.electrics_random_charge import ElectricsRandomChargeGenerator 3 | from gnn_bvp_solver.fem_dataset.data_generators.magnetics_random_current import MagneticsRandomCurrentGenerator 4 | from gnn_bvp_solver.fem_dataset.fem_driver import FemDriver 5 | from gnn_bvp_solver.fem_dataset.fem_driver import GeneratorBase 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.cylinder_precomputed import CylinderMeshPrecomputedGenerator 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.disk_precomputed import UnitDiskPrecomputedGenerator 8 | from gnn_bvp_solver.fem_dataset.mesh_generators.l_mesh_precomputed import LMeshPrecomputedGenerator 9 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 10 | from gnn_bvp_solver.fem_dataset.mesh_generators.square_mesh import UnitSquareGenerator 11 | from gnn_bvp_solver.fem_dataset.mesh_generators.u_mesh_precomputed import UMeshPrecomputedGenerator 12 | from squirrel_datasets_core.spark.setup_spark import get_spark 13 | from
squirrel_datasets_core.preprocessing.save_shards import save_composable_to_shards 14 | from pyspark.sql import SparkSession 15 | 16 | 17 | SAMPLES = 2500 18 | SHARDS = 25 19 | 20 | fem_generators = [ 21 | lambda mg: ElectricsRandomChargeGenerator(SAMPLES, mg, 3), 22 | lambda mg: MagneticsRandomCurrentGenerator(SAMPLES, mg, 3), 23 | lambda mg: ElasticityFixedLineGenerator(SAMPLES, mg, 3), 24 | ] 25 | 26 | 27 | extra_generators = [ 28 | lambda mg: ElectricsRandomChargeGenerator(SAMPLES, mg, 5, 4), 29 | lambda mg: MagneticsRandomCurrentGenerator(SAMPLES, mg, 5, 4), 30 | lambda mg: ElasticityFixedLineGenerator(SAMPLES, mg, 5, 4), 31 | ] 32 | 33 | 34 | def generate_config(session: SparkSession, generator: GeneratorBase, mesh_g: MeshGeneratorBase) -> None: 35 | """Save shards using squirrel for one combination of fem and mesh generator""" 36 | key = f"{type(generator).__name__}_{mesh_g}" 37 | 38 | fem_driver = FemDriver(generator) 39 | iter = fem_driver.get_iter("train", solve_pde=False) 40 | 41 | path = f"gs://squirrel-core-public-data/gnn_bvp_solver/{key}" 42 | # path = f"local/{key}" 43 | 44 | save_composable_to_shards( 45 | src_it=iter, num_shards=SHARDS, out_url=path, session=session, hooks=[fem_driver.map_sample] 46 | ) 47 | 48 | 49 | def generate_spark() -> None: 50 | """Use spark to generate fem simulations on multiple meshes""" 51 | session = get_spark("gnn-bvp-preprocessing") 52 | 53 | mesh_generators = { 54 | "square": UnitSquareGenerator(15, False), 55 | "disk": UnitDiskPrecomputedGenerator(15, session.sparkContext, False), 56 | "cylinder": CylinderMeshPrecomputedGenerator(session.sparkContext, False), 57 | "l_mesh": LMeshPrecomputedGenerator(session.sparkContext, False), 58 | "u_mesh": UMeshPrecomputedGenerator(session.sparkContext, False), 59 | } 60 | 61 | mesh_generators_rand = { 62 | "square_rand": UnitSquareGenerator(15, True), 63 | "disk_rand": UnitDiskPrecomputedGenerator(15, session.sparkContext, True), 64 | "cylinder_rand": 
CylinderMeshPrecomputedGenerator(session.sparkContext, True), 65 | "l_mesh_rand": LMeshPrecomputedGenerator(session.sparkContext, True), 66 | "u_mesh_rand": UMeshPrecomputedGenerator(session.sparkContext, True), 67 | } 68 | 69 | for fem_g in fem_generators: 70 | for mesh_g in mesh_generators: 71 | generate_config(session, fem_g(mesh_generators[mesh_g]), mesh_g) 72 | for mesh_g in mesh_generators_rand: 73 | generate_config(session, fem_g(mesh_generators_rand[mesh_g]), mesh_g) 74 | 75 | # create extra 76 | for fem_g in extra_generators: 77 | for mesh_g in mesh_generators: 78 | generate_config(session, fem_g(mesh_generators[mesh_g]), mesh_g + "_extra") 79 | 80 | 81 | if __name__ == "__main__": 82 | generate_spark() 83 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/modules/linear_elasticity_problem.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, Tuple 2 | import fenics as fn 3 | import numpy as np 4 | from gnn_bvp_solver.fem_dataset.modules.fem_problem import FemProblem 5 | from gnn_bvp_solver.fem_dataset.utils import extract_edges_from_triangle_mesh 6 | 7 | 8 | class ElasticityProblem(FemProblem): 9 | def __init__(self, mesh: fn.Mesh): 10 | """Init Linear Elasticity Problem. 11 | 12 | Args: 13 | mesh: FEM mesh to evaluate on. 14 | """ 15 | super().__init__(mesh) 16 | 17 | def solve(self) -> Dict[str, np.array]: 18 | """Using fenics to find the solution for an elasticity problem 19 | 20 | Returns: 21 | FEMSolution: solution 22 | """ 23 | # V0 = fn.FunctionSpace(self.mesh, "DG", 0) 24 | V = fn.VectorFunctionSpace(self.mesh, "Lagrange", 2) 25 | 26 | class TempClass: 27 | @staticmethod 28 | def _bdr(x: Tuple[float, float]) -> bool: 29 | """Recursively define the boundary condition (for now only enforce V=0) 30 | It has to be a static method so we use this hack to make it work. 
31 | Args: 32 | x (Tuple[float, float]): point to evaluate 33 | 34 | Returns: 35 | bool: True if BC is active otherwise False 36 | """ 37 | return self.boundary_conditions(x) 38 | 39 | bdr_bc = fn.DirichletBC(V, fn.Constant((0.0, 0.0)), TempClass._bdr) 40 | bcs = [bdr_bc] 41 | 42 | def _eps(v: Any) -> Any: 43 | return fn.sym(fn.grad(v)) 44 | 45 | def _sigma(u: Any) -> Any: 46 | return 10e1 * fn.tr(_eps(u)) * fn.Identity(2) + 2 * _eps(u) 47 | 48 | f = fn.Constant((0, -1.0)) 49 | u = fn.TrialFunction(V) 50 | v = fn.TestFunction(V) 51 | a = fn.inner(_sigma(u), _eps(v)) * fn.dx 52 | L = fn.inner(f, v) * fn.dx 53 | u_res = fn.Function(V) 54 | fn.solve(a == L, u_res, bcs) 55 | 56 | displacement = u_res.compute_vertex_values().reshape(2, -1) 57 | # fn.plot(u_res, title='Displacement', mode='displacement') 58 | 59 | s = _sigma(u_res) - (1.0 / 3) * fn.tr(_sigma(u_res)) * fn.Identity(2) # deviatoric stress 60 | von_Mises = fn.sqrt(3.0 / 2 * fn.inner(s, s)) 61 | V = fn.FunctionSpace(self.mesh, "P", 1) 62 | von_Mises = fn.project(von_Mises, V) 63 | # fn.plot(von_Mises, title='Stress intensity') 64 | 65 | u_magnitude = fn.sqrt(fn.dot(u_res, u_res)) 66 | u_magnitude = fn.project(u_magnitude, V) 67 | # fn.plot(u_magnitude, 'Displacement magnitude') 68 | 69 | return { 70 | "x": np.array(self.mesh.coordinates()[:, 0], dtype=np.float32), 71 | "y": np.array(self.mesh.coordinates()[:, 1], dtype=np.float32), 72 | "u_x": np.array(displacement[0], dtype=np.float32), 73 | "u_y": np.array(displacement[1], dtype=np.float32), 74 | "m": np.array(u_magnitude.compute_vertex_values(), dtype=np.float32), 75 | "stress": np.array(von_Mises.compute_vertex_values(), dtype=np.float32), 76 | "bdr_v0": np.array(fn.project(self.boundary_conditions, V).compute_vertex_values(), dtype=np.float32), 77 | "edge_index": np.array(extract_edges_from_triangle_mesh(self.mesh), dtype=np.int64), 78 | } 79 | 80 | @staticmethod 81 | def input_output_mapping() -> Dict[str, Iterable[str]]: 82 | """Specify which 
quantities are input or output. 83 | 84 | Returns: 85 | Dict[str, Iterable[str]]: input / output values of this problem 86 | """ 87 | # x and y must be first by convention 88 | return {"x": ["x", "y", "bdr_v0"], "y": ["m", "u_x", "u_y"]} 89 | 90 | @staticmethod 91 | def physical_quantities() -> Iterable[Tuple]: 92 | """Return physical quantities present in this simulation. 93 | 94 | Returns: 95 | Iterable: all quantities (can be 1d or nd) 96 | """ 97 | return [("m",), ("stress",), ("u_x", "u_y")] 98 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/plates_and_charges.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterator, Tuple 3 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import BaseGeneratorConfig, GeneratorBase 4 | from gnn_bvp_solver.fem_dataset.modules import boundary_conditions 5 | from gnn_bvp_solver.fem_dataset.modules.electrostatics_problem import ElectrostaticsProblem 6 | from gnn_bvp_solver.fem_dataset.modules.plate import PlateModule 7 | from gnn_bvp_solver.fem_dataset.modules.circle import CircleModule 8 | 9 | import numpy as np 10 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import GeneratorSolution, build_solution 11 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 12 | 13 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_registry import get_mesh 14 | 15 | 16 | @dataclass 17 | class PlatesAndChargesConfig(BaseGeneratorConfig): 18 | distance: float 19 | height: float 20 | theta: float 21 | 22 | 23 | class PlatesAndChargesGenerator(GeneratorBase): 24 | def __init__( 25 | self, 26 | n_samples: int, 27 | mesh_generator: MeshGeneratorBase, 28 | height_range: Tuple[float, float], 29 | distance_range: Tuple[float, float], 30 | n_rotations: int = 1, 31 | ): 32 | """Creates a set of condenser plates and 
additional point charges. 33 | 34 | Args: 35 | n_samples (int): number of samples in dataset to generate 36 | mesh_generator (MeshGeneratorBase): underlying mesh generator 37 | height_range (Tuple): min and max height 38 | distance_range (Tuple): min and max distance 39 | n_rotations (int, optional): Rotations of the condenser plates. Defaults to 1. 40 | """ 41 | self.n_rotations = n_rotations 42 | self.height_range = height_range 43 | self.distance_range = distance_range 44 | self.n_samples = n_samples // n_rotations 45 | self.boundary_conditions = boundary_conditions 46 | 47 | super().__init__(mesh_generator) 48 | 49 | @staticmethod 50 | def solve_config(config: PlatesAndChargesConfig) -> GeneratorSolution: 51 | """Solve config combining plates and charges.""" 52 | parametrized_mesh = get_mesh(config.mesh_config) 53 | 54 | problem = ElectrostaticsProblem(parametrized_mesh.mesh) 55 | problem.boundary_conditions.add_subexpression(parametrized_mesh.default_bc) 56 | problem.charge_density.add_subexpression( 57 | PlateModule( 58 | config.distance, config.height, 1.0, rotation=config.theta, resolution=config.mesh_config.resolution 59 | ) 60 | ) 61 | problem.charge_density.add_subexpression( 62 | PlateModule( 63 | -config.distance, config.height, -1.0, rotation=config.theta, resolution=config.mesh_config.resolution 64 | ) 65 | ) 66 | 67 | problem.charge_density.add_subexpression( 68 | CircleModule((0.5, 0.5 + max(config.height, config.distance) + 0.1), 0.005, -1.0) 69 | ) 70 | problem.charge_density.add_subexpression( 71 | CircleModule((0.5, 0.5 - max(config.height, config.distance) - 0.1), 0.005, 1.0) 72 | ) 73 | 74 | return build_solution(problem, parametrized_mesh) 75 | 76 | def __iter__(self) -> Iterator[PlatesAndChargesConfig]: 77 | """Generate and solve the problems. 
78 | 79 | Returns: 80 | Iterator[PlatesAndChargesConfig]: Generated solutions 81 | """ 82 | samples_height = int(np.sqrt(self.n_samples)) 83 | samples_distance = self.n_samples // samples_height 84 | 85 | for height in np.linspace(self.height_range[0], self.height_range[1], num=samples_height): 86 | for distance in np.linspace(self.distance_range[0], self.distance_range[1], num=samples_distance): 87 | for theta in np.linspace(0, 2 * np.pi, num=self.n_rotations, endpoint=False): 88 | yield PlatesAndChargesConfig( 89 | distance=distance, 90 | height=height, 91 | theta=theta, 92 | generator_name="PlatesAndChargesGenerator", 93 | mesh_config=self.mesh_generator(), 94 | ) 95 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/mesh_generators/l_mesh.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import partial 3 | import meshio 4 | import pygmsh 5 | import numpy as np 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.convert_mesh import convert_msh_to_fenics 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import BaseMeshConfig, MeshGeneratorBase, ParametrizedMesh 8 | 9 | 10 | @dataclass 11 | class LMeshConfig(BaseMeshConfig): 12 | cutout_location: int 13 | cutout_size_x: float 14 | cutout_size_y: float 15 | 16 | 17 | LM_CONST_RES = 9 18 | CUTOUT_RANGE = [0.2, 0.8] 19 | 20 | 21 | class LMeshGenerator(MeshGeneratorBase): 22 | def __init__(self, randomize: bool = False): 23 | """Create mesh generator for an L mesh and precompute meshes as files""" 24 | # fixed resolution 25 | self.resolution = LM_CONST_RES 26 | self.randomize = randomize 27 | 28 | def __call__(self) -> LMeshConfig: 29 | """Generate a mesh config""" 30 | if self.randomize: 31 | cutout_location = np.random.randint(4) 32 | cutout_size_x = np.random.uniform(CUTOUT_RANGE[0], CUTOUT_RANGE[1]) 33 | cutout_size_y = 
np.random.uniform(CUTOUT_RANGE[0], CUTOUT_RANGE[1]) 34 | else: 35 | cutout_location = 0 36 | cutout_size_x = 0.5 37 | cutout_size_y = 0.5 38 | 39 | return LMeshConfig(self.resolution, "LMeshGenerator", cutout_location, cutout_size_x, cutout_size_y) 40 | 41 | @staticmethod 42 | def vis_area(x: float, y: float, config: LMeshConfig, is_bc: bool = False) -> bool: 43 | """Query visible area for this mesh""" 44 | cutout_size_x = config.cutout_size_x 45 | cutout_size_y = config.cutout_size_y 46 | cutout_location = config.cutout_location 47 | 48 | epsilon = 0.001 if is_bc else 0.0 49 | 50 | # outside bounding box 51 | if x < 0 + epsilon or y < 0 + epsilon or x > 1.0 - epsilon or y > 1.0 - epsilon: 52 | return is_bc 53 | 54 | if cutout_location == 0: 55 | if x < cutout_size_x + epsilon and y < cutout_size_y + epsilon: 56 | return is_bc 57 | elif cutout_location == 1: 58 | if x > (1.0 - cutout_size_x) - epsilon and y < cutout_size_y + epsilon: 59 | return is_bc 60 | elif cutout_location == 2: 61 | if x < cutout_size_x + epsilon and y > (1 - cutout_size_y) - epsilon: 62 | return is_bc 63 | else: 64 | if x > (1.0 - cutout_size_x) - epsilon and y > (1 - cutout_size_y) - epsilon: 65 | return is_bc 66 | 67 | return not is_bc 68 | 69 | @staticmethod 70 | def config2msh(config: LMeshConfig) -> meshio.Mesh: 71 | """Convert mesh config to pygmsh mesh""" 72 | with pygmsh.occ.Geometry() as geom: 73 | cutout_size_x = config.cutout_size_x 74 | cutout_size_y = config.cutout_size_y 75 | cutout_location = config.cutout_location 76 | 77 | # mesh_size = 1.25 / config.resolution 78 | # ignore resolution 79 | 80 | g1 = geom.add_rectangle([0.0, 0.0, 0.0], 1.0, 1.0) 81 | 82 | diff_x = 0.5 - cutout_size_x 83 | diff_y = 0.5 - cutout_size_y 84 | 85 | if cutout_location == 0: 86 | g2 = geom.add_rectangle([0.0, 0.0, 0.0], cutout_size_x, cutout_size_y) 87 | elif cutout_location == 1: 88 | g2 = geom.add_rectangle([0.5 + diff_x, 0.0, 0.0], cutout_size_x, cutout_size_y) 89 | elif cutout_location == 2: 90 
| g2 = geom.add_rectangle([0.0, 0.5 + diff_y, 0.0], cutout_size_x, cutout_size_y) 91 | else: 92 | g2 = geom.add_rectangle([0.5 + diff_x, 0.5 + diff_y, 0.0], cutout_size_x, cutout_size_y) 93 | 94 | geom.boolean_difference(g1, g2) 95 | return geom.generate_mesh() 96 | 97 | @staticmethod 98 | def solve_config(config: LMeshConfig) -> ParametrizedMesh: 99 | """Generate fenics mesh from mesh config""" 100 | mesh = LMeshGenerator.config2msh(config) 101 | return ParametrizedMesh( 102 | convert_msh_to_fenics(mesh), 103 | partial(LMeshGenerator.vis_area, config=config, is_bc=True), 104 | partial(LMeshGenerator.vis_area, config=config, is_bc=False), 105 | ) 106 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_dataset/data_generators/condenser_plates.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterator, Tuple 3 | from gnn_bvp_solver.fem_dataset.data_generators.generator_base import ( 4 | BaseGeneratorConfig, 5 | GeneratorBase, 6 | GeneratorSolution, 7 | build_solution, 8 | ) 9 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_base import MeshGeneratorBase 10 | from gnn_bvp_solver.fem_dataset.mesh_generators.mesh_registry import get_mesh 11 | from gnn_bvp_solver.fem_dataset.modules.circle import CircleModule 12 | from gnn_bvp_solver.fem_dataset.modules.plate import PlateModule 13 | from gnn_bvp_solver.fem_dataset.modules.electrostatics_problem import ElectrostaticsProblem 14 | 15 | import numpy as np 16 | 17 | 18 | @dataclass 19 | class CondenserPlatesConfig(BaseGeneratorConfig): 20 | height: float 21 | distance: float 22 | theta: float 23 | conductor: bool 24 | 25 | 26 | class CondenserPlatesGenerator(GeneratorBase): 27 | def __init__( 28 | self, 29 | n_samples: int, 30 | mesh_generator: MeshGeneratorBase, 31 | height_range: Tuple[float, float], 32 | distance_range: Tuple[float, float], 33 | n_rotations: int = 1, 34 | 
conductor: bool = False, 35 | ): 36 | """Creates a set of condenser plates and optionally a conductor in between. 37 | 38 | Args: 39 | n_samples (int): number of samples in dataset to generate 40 | mesh_generator (MeshGeneratorBase): underlying mesh generator 41 | height_range (Tuple): min and max height 42 | distance_range (Tuple): min and max distance 43 | n_rotations (int): number of rotations 44 | conductor (bool, optional): Include a conductor between the plates. Defaults to False. 45 | """ 46 | self.n_rotations = n_rotations 47 | self.height_range = height_range 48 | self.distance_range = distance_range 49 | self.conductor = conductor 50 | self.n_samples = n_samples // n_rotations 51 | 52 | super().__init__(mesh_generator) 53 | 54 | @staticmethod 55 | def solve_config(config: CondenserPlatesConfig) -> GeneratorSolution: 56 | """Solve condenser config. 57 | 58 | Args: 59 | config (CondenserPlatesConfig): config to solve 60 | 61 | Returns: 62 | GeneratorSolution: fem solution containing the mesh and visible area 63 | """ 64 | parametrized_mesh = get_mesh(config.mesh_config) 65 | problem = ElectrostaticsProblem(parametrized_mesh.mesh) 66 | 67 | problem.boundary_conditions.add_subexpression(parametrized_mesh.default_bc) 68 | problem.charge_density.add_subexpression( 69 | PlateModule( 70 | config.distance, config.height, 1.0, rotation=config.theta, resolution=config.mesh_config.resolution 71 | ) 72 | ) 73 | problem.charge_density.add_subexpression( 74 | PlateModule( 75 | -config.distance, config.height, -1.0, rotation=config.theta, resolution=config.mesh_config.resolution 76 | ) 77 | ) 78 | 79 | if config.conductor: 80 | problem.permittivity.add_subexpression(CircleModule((0.5, 0.5), 0.005, 1000000.0)) 81 | 82 | return build_solution(problem, parametrized_mesh) 83 | 84 | def __iter__(self) -> Iterator[CondenserPlatesConfig]: 85 | """Generate and solve the problems rotating condenser plates and varying height and distance. 
86 | 87 | Returns: 88 | Iterator[CondenserPlatesConfig]: Generated solutions 89 | """ 90 | samples_height = int(np.sqrt(self.n_samples)) 91 | samples_distance = self.n_samples // samples_height 92 | 93 | for height in np.linspace(self.height_range[0], self.height_range[1], num=samples_height): 94 | for distance in np.linspace(self.distance_range[0], self.distance_range[1], num=samples_distance): 95 | for theta in np.linspace(0, 2 * np.pi, num=self.n_rotations, endpoint=False): 96 | yield CondenserPlatesConfig( 97 | mesh_config=self.mesh_generator(), 98 | height=height, 99 | distance=distance, 100 | theta=theta, 101 | conductor=self.conductor, 102 | generator_name="CondenserPlatesGenerator", 103 | ) 104 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_trainer/models/main_model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from gnn_bvp_solver.fem_trainer.models.gcn_processor import GraphNetMP 3 | from gnn_bvp_solver.fem_trainer.models.gcn_processor_weights import WeightedGraphNetMP 4 | from gnn_bvp_solver.fem_trainer.models.gnn_identity import GNNIdentity 5 | from gnn_bvp_solver.fem_trainer.models.two_layer_mlp import TwoLayerMLP 6 | from torch_geometric.nn import GraphUNet, ChebConv, GCN2Conv 7 | from torch_geometric.utils import dropout_adj 8 | from torch_geometric.utils import coalesce, to_undirected 9 | from gnn_bvp_solver.tricks.dropouts import DropNode 10 | 11 | import torch.nn.functional as F 12 | import torch 13 | 14 | 15 | class MainModel(torch.nn.Module): 16 | def __init__( 17 | self, 18 | dim_in: int, 19 | mlp_hidden: int, 20 | processor_hidden: int, 21 | dim_out: int, 22 | processor: str, 23 | augmentation: List, 24 | remove_pos: bool, 25 | ): 26 | """Main model consisting of encoder - processor - decoder""" 27 | super().__init__() 28 | 29 | self.remove_pos = remove_pos 30 | if self.remove_pos: 31 | dim_in -= 2 32 | 33 | 
self.pass_edge_weights = False 34 | self.augmentation = [] if augmentation is None else augmentation 35 | processor_dropout = "processor_dropout" in self.augmentation 36 | 37 | if processor == "unet3": 38 | self.processor = GraphUNet(mlp_hidden, processor_hidden, mlp_hidden, depth=3) 39 | elif processor == "unet6": 40 | self.processor = GraphUNet(mlp_hidden, processor_hidden, mlp_hidden, depth=6) 41 | elif processor == "gcn9": 42 | self.processor = GraphNetMP( 43 | 9, mlp_hidden, processor_hidden, mlp_hidden, processor_dropout=processor_dropout 44 | ) 45 | elif processor == "gcn18": 46 | self.processor = GraphNetMP( 47 | 18, mlp_hidden, processor_hidden, mlp_hidden, processor_dropout=processor_dropout 48 | ) 49 | elif processor == "gcn9w": 50 | self.processor = WeightedGraphNetMP( 51 | 9, mlp_hidden, processor_hidden, mlp_hidden, processor_dropout=processor_dropout 52 | ) 53 | self.pass_edge_weights = True 54 | elif processor == "gcnii9w": 55 | self.processor = WeightedGraphNetMP( 56 | 9, mlp_hidden, processor_hidden, mlp_hidden, gcn_type=GCN2Conv, processor_dropout=processor_dropout 57 | ) 58 | self.pass_edge_weights = True 59 | elif processor == "gcnch9w": 60 | self.processor = WeightedGraphNetMP( 61 | 9, 62 | mlp_hidden, 63 | processor_hidden, 64 | mlp_hidden, 65 | gcn_type=ChebConv, 66 | gcn_kwargs={"K": 5}, 67 | processor_dropout=processor_dropout, 68 | ) 69 | self.pass_edge_weights = True 70 | elif processor == "gcnch3w": 71 | self.processor = WeightedGraphNetMP( 72 | 3, 73 | mlp_hidden, 74 | processor_hidden, 75 | mlp_hidden, 76 | gcn_type=ChebConv, 77 | gcn_kwargs={"K": 5}, 78 | processor_dropout=processor_dropout, 79 | ) 80 | self.pass_edge_weights = True 81 | elif processor == "gcnch3w_noew": 82 | self.processor = GraphNetMP( 83 | 3, 84 | mlp_hidden, 85 | processor_hidden, 86 | mlp_hidden, 87 | gcn_type=ChebConv, 88 | gcn_kwargs={"K": 5}, 89 | processor_dropout=processor_dropout, 90 | ) 91 | elif processor == "none": 92 | self.processor = GNNIdentity() 93 
| else: 94 | raise ValueError("processor") 95 | 96 | self.encoder = TwoLayerMLP(dim_in, mlp_hidden, mlp_hidden) 97 | self.decoder = TwoLayerMLP(mlp_hidden, mlp_hidden, dim_out, is_output=True) 98 | 99 | if "drop_nodes" in self.augmentation: 100 | self.drop_nodes = DropNode(0.1) 101 | 102 | def compute_edge_weights(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: 103 | """Compute edge weights as relative distances between nodes.""" 104 | nodes_s = x[edge_index[0]] 105 | nodes_t = x[edge_index[1]] 106 | 107 | x_dif = nodes_s[:, 0] - nodes_t[:, 0] 108 | y_dif = nodes_s[:, 1] - nodes_t[:, 1] 109 | 110 | return torch.sqrt(torch.square(x_dif) + torch.square(y_dif)) 111 | 112 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: 113 | """Forward pass""" 114 | # should be fast here as is on gpu 115 | if self.training: 116 | edge_index = coalesce(to_undirected(edge_index)) 117 | 118 | if "drop_edges" in self.augmentation and self.training: 119 | # dropout from adj matrix 120 | edge_index, _ = dropout_adj(edge_index, p=0.2, force_undirected=True, training=self.training) 121 | 122 | if "drop_nodes" in self.augmentation and self.training: 123 | edge_index, _ = self.drop_nodes(edge_index) 124 | 125 | # compute edge weights if needed 126 | if self.pass_edge_weights: 127 | edge_weights = self.compute_edge_weights(x, edge_index) 128 | 129 | if self.remove_pos: 130 | x = x[:, 2:] # remove positional information 131 | 132 | if "embedding_dropout" in self.augmentation and self.training: 133 | x = F.dropout(x, p=0.2, training=self.training) # perform less aggressive dropout 134 | # like here: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_unet.py 135 | 136 | x = self.encoder(x) 137 | 138 | if self.pass_edge_weights: 139 | x = self.processor(x, edge_index, edge_weights=edge_weights) 140 | else: 141 | x = self.processor(x, edge_index) 142 | 143 | return self.decoder(x) 144 | 
# -------------------------------------------------------------------------------- /gnn_bvp_solver/visualization/vis_inputs.py: --------------------------------------------------------------------------------
from typing import List, Union
from matplotlib.colors import ListedColormap
from gnn_bvp_solver.fem_dataset.data_generators.electrics_random_charge import ElectricsRandomChargeGenerator
from gnn_bvp_solver.fem_dataset.data_generators.magnetics_random_current import MagneticsRandomCurrentGenerator
from gnn_bvp_solver.fem_dataset.data_generators.elasticity_fixed_line import ElasticityFixedLineGenerator
from gnn_bvp_solver.fem_dataset.mesh_generators.square_mesh import UnitSquareGenerator
from gnn_bvp_solver.visualization.plot_graph import Visualization
from gnn_bvp_solver.fem_dataset.data_generators.extend_solution import map_extend

import matplotlib.pyplot as plt


def _n_panels(generator: object, plot: List) -> int:
    """Count the subplot columns `plot_in_and_outputs` fills for one generator."""
    count = sum(key in plot for key in ("qin1", "bdr", "bdr_dist", "g_field"))
    # the scalar "q" panel is only drawn when the generator has a scalar output (not elasticity)
    if "q" in plot and not isinstance(generator, ElasticityFixedLineGenerator):
        count += 1
    return count


def plot_in_and_outputs(
    plot: List = None,
    interpolate: bool = False,
    plt_mesh: bool = False,
    plot_el: bool = True,
    savefig: Union[str, None] = None,
) -> None:
    """Plot input and output quantities of one sample per FEM generator on a square mesh.

    Args:
        plot (List, optional): panels to draw per generator: "qin1" (input quantity),
            "bdr" (boundary condition), "bdr_dist" (distance to border), "q" (scalar
            output, skipped for elasticity) and "g_field" (vector field).
            Defaults to ["g_field", "q", "qin1"].
        interpolate (bool, optional): draw "q" / "g_field" interpolated on a grid
            instead of as scattered points. Defaults to False.
        plt_mesh (bool, optional): overlay the mesh on the "qin1", "q" and "g_field" panels.
        plot_el (bool, optional): if True plot only the elasticity generator, otherwise
            the electrostatics and magnetostatics generators. Defaults to True.
        savefig (str, optional): if given, save the figure to this path before showing it.
    """
    if plot is None:
        # plot = ["bdr", "g_field", "q", "bdr_dist", "qin1"]
        plot = ["g_field", "q", "qin1"]

    g = UnitSquareGenerator(15)

    mesh = g.solve_config(g())
    print("nodes: ", mesh.mesh.coordinates().shape)

    es = ElectricsRandomChargeGenerator(1, g, 3)
    em = MagneticsRandomCurrentGenerator(1, g, 3)
    el = ElasticityFixedLineGenerator(1, g, 3)
    fem_g = [es, em]

    if plot_el:
        fem_g = [el]

    skip = 1 if el in fem_g else 0
    # the default layout has 3 columns per generator (2 for elasticity); extra panels such
    # as "bdr" / "bdr_dist" need more, otherwise indexing `ax` raises IndexError
    n_cols = max(3 * len(fem_g) - skip, sum(_n_panels(e, plot) for e in fem_g))
    _, ax = plt.subplots(1, n_cols, figsize=(7 * len(fem_g) - skip * 2, 3), squeeze=False)

    row = 0
    col = -1
    cmap = ListedColormap(["black", "#009E73"])

    for e in fem_g:
        # quantity keys and panel titles depend only on the generator type
        if isinstance(e, ElectricsRandomChargeGenerator):
            q1 = "u"
            tq1 = "(b) El. Potential"
            q2, q3 = "E_x", "E_y"
            tq23 = "(c) El. Field"

            qin1 = "rho"
            tqin1 = "(a) Charge"
        elif isinstance(e, MagneticsRandomCurrentGenerator):
            q1 = "A"
            tq1 = "(e) Magn. Potential"
            q2, q3 = "B_x", "B_y"
            tq23 = "(f) Magn. Field"

            qin1 = "I"
            tqin1 = "(d) Current"
        elif isinstance(e, ElasticityFixedLineGenerator):
            q1 = None
            tq1 = None
            q2, q3 = "u_x", "u_y"
            tq23 = "Displacement Field"

            qin1 = "bdr_v0"
            tqin1 = "Boundary condition"

        for sol in e:
            sol = e.solve_config(sol, debug=True)

            if "qin1" in plot:
                col += 1

                if qin1 is not None:
                    vis = Visualization(
                        sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area
                    )
                    vis.plot_on_grid(sol.quantities[qin1], scatter_only=True, cmap=cmap, s=1000.0)

                    if plt_mesh:
                        vis.plot_mesh(mesh.mesh)

                    ax[row, col].set_title(tqin1)
                    ax[row, col].imshow(vis.to_numpy())

                ax[row, col].axis("off")

            if "bdr" in plot:
                vis = Visualization(
                    sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area
                )
                vis.plot_on_grid(sol.quantities["bdr_v0"])

                col += 1
                ax[row, col].set_title("Boundary condition")
                ax[row, col].imshow(vis.to_numpy())
                ax[row, col].axis("off")

            if "bdr_dist" in plot:
                sol = map_extend(sol)

                vis = Visualization(
                    sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area
                )
                vis.plot_on_grid(sol.quantities["dist_border"])
                vis.quiver_on_grid(sol.quantities["dist_border_x"], sol.quantities["dist_border_y"], interpolate=False)

                col += 1
                ax[row, col].set_title("Distance to border")
                ax[row, col].imshow(vis.to_numpy())
                ax[row, col].axis("off")

            if "q" in plot and q1 is not None:
                col += 1

                vis = Visualization(
                    sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area
                )

                if interpolate:
                    vis.plot_on_grid(sol.quantities[q1])
                else:
                    vis.plot_on_grid(sol.quantities[q1], scatter_only=True, s=1000.0)

                if plt_mesh:
                    vis.plot_mesh(mesh.mesh)

                ax[row, col].set_title(tq1)
                ax[row, col].imshow(vis.to_numpy())
                ax[row, col].axis("off")

            if "g_field" in plot:
                vis = Visualization(
                    sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area
                )
                vis.quiver_on_grid(sol.quantities[q2], sol.quantities[q3], interpolate=False)

                if interpolate:
                    vis.grad_field(sol.quantities[q2], sol.quantities[q3])
                else:
                    vis.grad_field(sol.quantities[q2], sol.quantities[q3], scatter_only=True, s=1000.0)
                if plt_mesh:
                    vis.plot_mesh(mesh.mesh)

                col += 1
                ax[row, col].set_title(tq23)
                ax[row, col].imshow(vis.to_numpy())
                ax[row, col].axis("off")

            # only the first sample of each generator is plotted
            break

    plt.tight_layout()
    if savefig is not None:
        plt.savefig(savefig)
    plt.show()
# -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_trainer/graph_training_module.py: --------------------------------------------------------------------------------
from typing import List, Tuple, Union
import numpy as np
import pytorch_lightning as pl
| import torch 5 | import wandb 6 | import torchmetrics 7 | 8 | from torch_geometric.data.batch import Batch 9 | from gnn_bvp_solver.fem_trainer.models.main_model import MainModel 10 | from gnn_bvp_solver.visualization.plot_graph import Visualization 11 | 12 | 13 | class GNNModule(pl.LightningModule): 14 | def __init__( 15 | self, 16 | dim_in: int, 17 | dim_out: int, 18 | processor: str = "unet3", 19 | processor_hidden: int = 128, 20 | mlp_hidden: int = 128, 21 | augmentation: List = None, 22 | remove_pos: bool = True, 23 | ): 24 | """PT Lightning module for simple graph net training with MSE loss & Adam.""" 25 | super().__init__() 26 | 27 | self.save_hyperparameters() 28 | self.loss_f = torch.nn.MSELoss() 29 | self.model = MainModel(dim_in, mlp_hidden, processor_hidden, dim_out, processor, augmentation, remove_pos) 30 | 31 | self.mape = torchmetrics.MeanAbsolutePercentageError() 32 | 33 | def relative_absolute_error(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 34 | """Calculate relative absolute error""" 35 | return torch.sum(torch.abs(input - target)) / torch.sum(torch.abs(target)) 36 | 37 | def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor: 38 | """Execute model 39 | 40 | Args: 41 | x (torch.Tensor): input 42 | 43 | Returns: 44 | torch.Tensor: model output 45 | """ 46 | return self.model(x, edge_index) 47 | 48 | def training_step(self, batch: Batch, batch_idx: int) -> torch.Tensor: 49 | """Train model 50 | 51 | Args: 52 | batch (Batch): batch containing graph structure and data 53 | batch_idx (int): batch index, currently unused 54 | 55 | Returns: 56 | torch.Tensor: aggregated loss 57 | """ 58 | y_hat = self.forward(batch.x, batch.edge_index) 59 | loss = self.loss_f(y_hat.squeeze(), batch.y) 60 | 61 | self.log("train_loss", loss) 62 | return loss 63 | 64 | def vis_potential( 65 | self, 66 | data_x: np.array, 67 | data_y: np.array, 68 | size_x: int = 5, 69 | size_y: int = 4, 70 | vis_a: np.array = None, 71 | res: 
int = 50, 72 | vmin: float = None, 73 | vmax: float = None, 74 | ) -> None: 75 | """Visualize potential assuming fixed order of data in tensor""" 76 | return ( 77 | Visualization(data_x[:, 0], data_x[:, 1], figsize=(size_x, size_y), visible_area=vis_a, resolution=res) 78 | .plot_on_grid(data_y[:, 0], scatter=False, vmin=vmin, vmax=vmax) 79 | .to_numpy() 80 | ) 81 | 82 | def vis_vector_field( 83 | self, 84 | data_x: np.array, 85 | data_y: np.array, 86 | size_x: int = 5, 87 | size_y: int = 4, 88 | vis_a: np.array = None, 89 | res: int = 50, 90 | vmin: float = None, 91 | vmax: float = None, 92 | ) -> None: 93 | """Visualize vector field assuming fixed order of data in tensor""" 94 | return ( 95 | Visualization(data_x[:, 0], data_x[:, 1], figsize=(size_x, size_y), visible_area=vis_a, resolution=res) 96 | .grad_field(data_y[:, 1], data_y[:, 2], vmin=vmin, vmax=vmax) 97 | .quiver_on_grid(data_y[:, 1], data_y[:, 2], normalize=True, interpolate=False) 98 | .to_numpy() 99 | ) 100 | 101 | def vis_out_gt( 102 | self, 103 | data_x: np.array, 104 | data_y: np.array, 105 | data_y_hat: np.array, 106 | size_x: int = 5, 107 | size_y: int = 4, 108 | log: bool = True, 109 | res: int = 50, 110 | ) -> Union[Tuple, None]: 111 | """Visualize both ground truth and predictions""" 112 | vis_a = None # optionally infer the visible area here from the tensor 113 | 114 | pot_max = max(data_y[:, 0].max(), data_y_hat[:, 0].max()) 115 | pot_min = min(data_y[:, 0].min(), data_y_hat[:, 0].min()) 116 | 117 | ming, maxg = Visualization.get_min_max_grad_field(data_y[:, 1], data_y[:, 2]) 118 | ming_h, maxg_h = Visualization.get_min_max_grad_field(data_y_hat[:, 1], data_y_hat[:, 2]) 119 | v_max = max(maxg, maxg_h) 120 | v_min = min(ming, ming_h) 121 | 122 | potential_gt = self.vis_potential( 123 | data_x, data_y, size_x=size_x, size_y=size_y, vis_a=vis_a, res=res, vmin=pot_min, vmax=pot_max 124 | ) 125 | potential_predicted = self.vis_potential( 126 | data_x, data_y_hat, size_x=size_x, size_y=size_y, 
vis_a=vis_a, res=res, vmin=pot_min, vmax=pot_max 127 | ) 128 | 129 | vector_f_gt = self.vis_vector_field( 130 | data_x, data_y, size_x=size_x, size_y=size_y, vis_a=vis_a, res=res, vmin=v_min, vmax=v_max 131 | ) 132 | vector_f_predicted = self.vis_vector_field( 133 | data_x, data_y_hat, size_x=size_x, size_y=size_y, vis_a=vis_a, res=res, vmin=v_min, vmax=v_max 134 | ) 135 | 136 | if log: 137 | image_potential = np.concatenate([potential_predicted, potential_gt], axis=1) 138 | image_vector = np.concatenate([vector_f_predicted, vector_f_gt], axis=1) 139 | data = np.concatenate([image_potential, image_vector], axis=0) # image_all 140 | 141 | wandb.log({"Field plots (Pred | GT)": wandb.Image(data)}) 142 | else: 143 | return potential_predicted, potential_gt, vector_f_predicted, vector_f_gt 144 | 145 | def validation_step( 146 | self, 147 | batch: Batch, 148 | batch_idx: int, 149 | metric: str = "val_loss", 150 | ) -> None: 151 | """Validate model""" 152 | y_hat = self.forward(batch.x, batch.edge_index) 153 | loss = self.loss_f(y_hat.squeeze(), batch.y) 154 | self.log(metric, loss) 155 | 156 | self.log("mse_potential", self.loss_f(y_hat.squeeze()[:, 0], batch.y[:, 0])) 157 | self.log("mse_vector_field", self.loss_f(y_hat.squeeze()[:, 1:], batch.y[:, 1:])) 158 | 159 | self.mape(y_hat.squeeze(), batch.y) 160 | self.log("relative_absolute_error", self.relative_absolute_error(y_hat.squeeze(), batch.y)) 161 | self.log("mape", self.mape, on_step=True, on_epoch=True) 162 | 163 | # do not log all validation samples 164 | # log only every 10 epochs 165 | if (np.random.randint(200) > 198 or batch_idx == 0) and self.current_epoch % 10 == 0: 166 | data_x = batch.x.cpu().numpy() 167 | data_y = batch.y.cpu().numpy() 168 | data_y_hat = y_hat.detach().cpu().numpy() 169 | 170 | self.vis_out_gt(data_x, data_y, data_y_hat) 171 | # wandb.log({"Field plots": wandb.Image(img)}) 172 | 173 | def test_step(self, batch: Batch, batch_idx: int) -> None: 174 | """Test model 175 | 176 | Args: 177 | 
batch (Batch): batch containing graph structure and data 178 | batch_idx (int): batch index, currently unused 179 | """ 180 | self.validation_step(batch, batch_idx, metric="test_loss") 181 | # set index = 0 to plot all results -> takes long time 182 | 183 | def configure_optimizers(self) -> torch.optim.Optimizer: 184 | """Configure optimizer 185 | 186 | Returns: 187 | torch.optim.Optimizer: used torch optimizer 188 | """ 189 | return torch.optim.Adam(self.parameters()) 190 | -------------------------------------------------------------------------------- /gnn_bvp_solver/visualization/vis_quantities.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | from gnn_bvp_solver.fem_dataset.data_generators.elasticity_fixed_line import ElasticityFixedLineGenerator 3 | from gnn_bvp_solver.fem_dataset.data_generators.electrics_random_charge import ElectricsRandomChargeGenerator 4 | from gnn_bvp_solver.fem_dataset.data_generators.magnetics_random_current import MagneticsRandomCurrentGenerator 5 | from gnn_bvp_solver.fem_dataset.mesh_generators.square_mesh import UnitSquareGenerator 6 | from gnn_bvp_solver.fem_dataset.mesh_generators.cylinder_mesh import CylinderGenerator 7 | from gnn_bvp_solver.fem_dataset.mesh_generators.l_mesh import LMeshGenerator 8 | from gnn_bvp_solver.fem_dataset.mesh_generators.disk_mesh import UnitDiskGenerator 9 | from gnn_bvp_solver.fem_dataset.mesh_generators.u_mesh import UMeshGenerator 10 | from gnn_bvp_solver.visualization.plot_graph import Visualization 11 | from gnn_bvp_solver.fem_dataset.data_generators.extend_solution import map_extend 12 | 13 | import matplotlib.pyplot as plt 14 | import wandb 15 | from matplotlib.colors import ListedColormap 16 | 17 | 18 | def plot_quantities_multimesh( 19 | plot: Union[None, List] = None, randomize: bool = False, fem_g: int = 2, show: bool = True 20 | ) -> None: 21 | """Visualize quantities for multiple meshes""" 22 | if plot is None: 23 
| plot = ["bdr", "g_field", "q", "bdr_dist"] 24 | 25 | if fem_g == 2: 26 | es = ElectricsRandomChargeGenerator 27 | em = MagneticsRandomCurrentGenerator 28 | 29 | fem_g = [es, em] 30 | elif fem_g == 1: 31 | fem_g = [ElectricsRandomChargeGenerator] 32 | 33 | g_small_square = UnitSquareGenerator(15, randomize) 34 | g_cylinder = CylinderGenerator(randomize) 35 | g_disk = UnitDiskGenerator(15, randomize) 36 | g_lmesh = LMeshGenerator(randomize) 37 | g_umesh = UMeshGenerator(randomize) 38 | 39 | if randomize: 40 | mesh_g = [g_small_square, g_cylinder, g_disk, g_lmesh] # g_umesh 41 | else: 42 | mesh_g = [g_small_square, g_cylinder, g_disk, g_lmesh, g_umesh] # g_umesh 43 | 44 | if "bdr" in plot: 45 | f1, ax1 = plt.subplots(len(fem_g), len(mesh_g), figsize=(len(mesh_g) * 5, len(fem_g) * 5), squeeze=False) 46 | f1.set_tight_layout(True) 47 | 48 | if "q" in plot: 49 | f, ax = plt.subplots(len(fem_g), len(mesh_g), figsize=(len(mesh_g) * 5, len(fem_g) * 5), squeeze=False) 50 | f.set_tight_layout(True) 51 | 52 | if "g_field" in plot: 53 | f2, ax2 = plt.subplots(len(fem_g), len(mesh_g), figsize=(len(mesh_g) * 5, len(fem_g) * 5), squeeze=False) 54 | f2.set_tight_layout(True) 55 | 56 | if "bdr_dist" in plot: 57 | f3, ax3 = plt.subplots(len(fem_g), len(mesh_g), figsize=(len(mesh_g) * 5, len(fem_g) * 5), squeeze=False) 58 | f4, ax4 = plt.subplots(len(fem_g), len(mesh_g), figsize=(len(mesh_g) * 5, len(fem_g) * 5), squeeze=False) 59 | f3.set_tight_layout(True) 60 | f4.set_tight_layout(True) 61 | 62 | for i, g in enumerate(mesh_g): 63 | mesh = g.solve_config(g()) 64 | print("nodes: ", mesh.mesh.coordinates().shape) 65 | fem_instances = [t(1, g, 3) for t in fem_g] 66 | 67 | for j, e in enumerate(fem_instances): 68 | for sol in e: 69 | sol = e.solve_config(sol, debug=True) 70 | 71 | if isinstance(e, ElectricsRandomChargeGenerator): 72 | q1 = "u" 73 | q2, q3 = "E_x", "E_y" 74 | elif isinstance(e, MagneticsRandomCurrentGenerator): 75 | q1 = "A" 76 | q2, q3 = "B_x", "B_y" 77 | elif 
isinstance(e, ElasticityFixedLineGenerator): 78 | q1 = "m" 79 | q2, q3 = "u_x", "u_y" 80 | 81 | if "bdr" in plot: 82 | vis = Visualization( 83 | sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area 84 | ) 85 | 86 | if "mesh" in plot: 87 | vis.plot_mesh(sol.mesh.mesh) 88 | 89 | cmap = ListedColormap(["black", "#009E73"]) 90 | vis.plot_on_grid(sol.quantities["bdr_v0"], scatter_only=True, cmap=cmap, s=1000.0) 91 | 92 | ax1[j, i].imshow(vis.to_numpy()) 93 | ax1[j, i].axis("off") 94 | 95 | if "q" in plot: 96 | vis = Visualization( 97 | sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area 98 | ) 99 | vis.plot_on_grid(sol.quantities[q1]) 100 | 101 | ax[j, i].imshow(vis.to_numpy()) 102 | ax[j, i].axis("off") 103 | if "g_field" in plot: 104 | vis = Visualization( 105 | sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area 106 | ) 107 | vis.quiver_on_grid(sol.quantities[q2], sol.quantities[q3], interpolate=False) 108 | vis.grad_field(sol.quantities[q2], sol.quantities[q3]) 109 | 110 | ax2[j, i].imshow(vis.to_numpy()) 111 | ax2[j, i].axis("off") 112 | if "bdr_dist" in plot: 113 | sol = map_extend(sol) 114 | 115 | vis = Visualization( 116 | sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area 117 | ) 118 | vis.plot_on_grid(sol.quantities["dist_border"]) 119 | vis.quiver_on_grid( 120 | sol.quantities["dist_border_x"], sol.quantities["dist_border_y"], interpolate=False 121 | ) 122 | 123 | vis2 = Visualization( 124 | sol.quantities["x"], sol.quantities["y"], axis_off=True, visible_area=sol.visible_area 125 | ) 126 | vis2.plot_on_grid(sol.quantities["dist_bc"]) 127 | vis2.quiver_on_grid(sol.quantities["dist_bc_x"], sol.quantities["dist_bc_y"], interpolate=False) 128 | 129 | ax3[j, i].imshow(vis.to_numpy()) 130 | ax3[j, i].axis("off") 131 | 132 | ax4[j, i].imshow(vis2.to_numpy()) 133 | ax4[j, i].axis("off") 134 | 135 | break 136 | 137 | if show: 138 
| plt.show() 139 | 140 | 141 | def plot_result_compare(artifact_es: str, artifact_ms: str, plot_n: int = 1, fontsize: int = 25) -> None: 142 | """Compare results saved as wandb artifacts""" 143 | import matplotlib.pyplot as plt 144 | 145 | table1 = wandb.use_artifact(artifact_es).get("visualizations") 146 | table2 = wandb.use_artifact(artifact_ms).get("visualizations") 147 | 148 | _, ax = plt.subplots(2 * plot_n, 4, figsize=(12, 6 * plot_n), squeeze=False) 149 | 150 | for i in range(2 * plot_n): 151 | for j in range(4): 152 | ax[i][j].axis("off") 153 | 154 | ax[0][0].set_title("El. Potential (pred)", fontsize=fontsize) 155 | ax[0][1].set_title("El. Potential (gt)", fontsize=fontsize) 156 | ax[0][2].set_title("El. Field (pred)", fontsize=fontsize) 157 | ax[0][3].set_title("El. Field (gt)", fontsize=fontsize) 158 | 159 | ax[plot_n][0].set_title("Magn. Potential (pred)", fontsize=fontsize) 160 | ax[plot_n][1].set_title("Magn. Potential (gt)", fontsize=fontsize) 161 | ax[plot_n][2].set_title("Magn. Field (pred)", fontsize=fontsize) 162 | ax[plot_n][3].set_title("Magn. 
Field (gt)", fontsize=fontsize) 163 | 164 | for i in range(2): 165 | tbl = table1 if i == 0 else table2 166 | 167 | for next_plot in range(plot_n): 168 | for plot_image in range(4): 169 | ax[i * plot_n + next_plot][plot_image].imshow(tbl.data[next_plot][plot_image + 1].image) 170 | 171 | plt.tight_layout() 172 | -------------------------------------------------------------------------------- /gnn_bvp_solver/preprocessing/split_and_normalize.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Dict, Callable 3 | from squirrel.driver.msgpack import MessagepackDriver 4 | from squirrel.serialization import MessagepackSerializer 5 | from squirrel.store import SquirrelStore 6 | from squirrel.iterstream import IterableSource, Composable 7 | 8 | import numpy as np 9 | 10 | 11 | N_SAMPLES = 2500 12 | MAX_VALUE = 10.0 13 | 14 | SPLIT_25 = int(N_SAMPLES * 0.25) 15 | SPLIT_50 = int(N_SAMPLES * 0.5) 16 | SPLIT_80 = int(N_SAMPLES * 0.8) 17 | SPLIT_90 = int(N_SAMPLES * 0.9) 18 | 19 | N_SHARD = 100 20 | 21 | 22 | def update_range_dict(range_dict: Dict, name: str, value: np.array, op: Callable = np.maximum) -> None: 23 | """Track maximum and minimum values for normalization""" 24 | if name in range_dict: 25 | range_dict[name] = op(value, range_dict[name]) 26 | else: 27 | range_dict[name] = value 28 | 29 | 30 | def unify_range_dicts(range_dict1: Dict, range_dict2: Dict, op: Callable = np.maximum) -> Dict: 31 | """Unify maximum and minimum values""" 32 | result = {} 33 | 34 | for name in range_dict1: 35 | result[name] = op(range_dict1[name], range_dict2[name]) 36 | 37 | return result 38 | 39 | 40 | def map_update_ranges(sample: Dict, range_dict: Dict) -> Dict: 41 | """Iterate samples and update minimums and maximums""" 42 | max_x = np.amax(np.abs(sample["data_x"]), axis=0) 43 | max_y = np.amax(np.abs(sample["data_y"]), axis=0) 44 | 45 | update_range_dict(range_dict, "x_range", max_x) 46 | 
update_range_dict(range_dict, "y_range", max_y) 47 | 48 | return sample 49 | 50 | 51 | def get_range_dict(base_url: str, split: str) -> Dict: 52 | """Get maximums and minimums for normalization""" 53 | range_dict = {} 54 | 55 | it = MessagepackDriver(f"{base_url}/{split}").get_iter() 56 | it.map(partial(map_update_ranges, range_dict=range_dict)).tqdm().join() 57 | 58 | return range_dict 59 | 60 | 61 | def save_shard(it: Composable, store: SquirrelStore) -> None: 62 | """Save set of shards""" 63 | store.set(value=list(it)) 64 | 65 | 66 | def scale(sample: Dict, range_dict: Dict) -> Dict: 67 | """Normalize example using the extreme values""" 68 | range_x = np.clip(range_dict["x_range"], a_min=0.000001, a_max=None) 69 | range_y = np.clip(range_dict["y_range"], a_min=0.000001, a_max=None) 70 | 71 | return { 72 | "data_x": sample["data_x"] / range_x.reshape(1, -1), 73 | "data_y": sample["data_y"] / range_y.reshape(1, -1), 74 | "edge_index": sample["edge_index"], 75 | } 76 | 77 | 78 | def filter_max(sample: Dict) -> bool: 79 | """Filter outliers""" 80 | if sample["data_x"].max() > MAX_VALUE: 81 | return False 82 | if sample["data_y"].max() > MAX_VALUE: 83 | return False 84 | return True 85 | 86 | 87 | def save_stream( 88 | it: Composable, output_url: str, split: str, range_dict: Dict = None, filter_outliers: bool = True 89 | ) -> None: 90 | """Scale, filter outliers and save composable as shards""" 91 | if it is None: 92 | return 93 | 94 | store = SquirrelStore(f"{output_url}/{split}", serializer=MessagepackSerializer()) 95 | 96 | if range_dict is not None: 97 | it = it.map(partial(scale, range_dict=range_dict)) 98 | 99 | if filter_outliers: 100 | it = it.filter(filter_max) 101 | 102 | it.batched(N_SHARD, drop_last_if_not_full=False).map(partial(save_shard, store=store)).tqdm().join() 103 | 104 | 105 | def iterate_source_data(fem_generator: str) -> None: 106 | """Filter data for a single generator and iterate if necessary to create splits""" 107 | 108 | mesh_generators = 
[ 109 | "square", 110 | "disk", 111 | "cylinder", 112 | "l_mesh", 113 | "u_mesh", 114 | "square_extra", 115 | "disk_extra", 116 | "cylinder_extra", 117 | "l_mesh_extra", 118 | "u_mesh_extra", 119 | "square_rand", 120 | "disk_rand", 121 | "cylinder_rand", 122 | "l_mesh_rand", 123 | "u_mesh_rand", 124 | ] 125 | 126 | for mesh_g in mesh_generators: 127 | key = f"{fem_generator}_{mesh_g}" 128 | path = f"gs://squirrel-core-public-data/gnn_bvp_solver/{key}" 129 | iter = MessagepackDriver(path).get_iter() 130 | 131 | print("GENERATING:", fem_generator, mesh_g) 132 | 133 | if mesh_g.startswith("u_mesh"): 134 | if mesh_g == "u_mesh": 135 | # test set 2 136 | # TRAIN1, VAL1, TRAIN2, VAL2, TEST1, TEST2 137 | yield None, None, None, None, None, iter 138 | else: 139 | # all but U-mesh 140 | if mesh_g.endswith("extra"): 141 | all_data = iter.tqdm().collect() 142 | 143 | # test set 1 144 | # TRAIN1, VAL1, TRAIN2, VAL2, TEST1, TEST2 145 | yield None, None, None, None, IterableSource(all_data[:SPLIT_25]), None 146 | elif mesh_g.endswith("rand"): 147 | all_data = iter.tqdm().collect() 148 | 149 | # train/val set 2 150 | # TRAIN1, VAL1, TRAIN2, VAL2, TEST1, TEST2 151 | yield None, None, IterableSource(all_data[:SPLIT_80]), IterableSource(all_data[SPLIT_80:]), None, None 152 | else: 153 | all_data = iter.tqdm().collect() 154 | 155 | # train/val set 1 156 | # TRAIN1, VAL1, TRAIN2, VAL2, TEST1, TEST2 157 | yield IterableSource(all_data[:SPLIT_80]), IterableSource(all_data[SPLIT_80:]), None, None, None, None 158 | 159 | 160 | def scale_and_store(in_split: str, out_split: str, range_dict: Dict, base_url_in: str, base_url_out: str) -> None: 161 | """Normalize one stream and save it""" 162 | it = MessagepackDriver(f"{base_url_in}/{in_split}").get_iter() 163 | save_stream(it, base_url_out, out_split, range_dict) 164 | 165 | 166 | def main(fem_generator: str, out_url: str) -> None: 167 | """Generate split for a single generator""" 168 | for append_train1, append_val1, append_train2, 
append_val2, append_test1, append_test2 in iterate_source_data( 169 | fem_generator 170 | ): 171 | print("saving splits") 172 | 173 | print("train1") 174 | save_stream(append_train1, out_url, "raw_train1") 175 | 176 | print("val1") 177 | save_stream(append_val1, out_url, "raw_val1") 178 | 179 | print("train2") 180 | save_stream(append_train2, out_url, "raw_train2") 181 | 182 | print("val2") 183 | save_stream(append_val2, out_url, "raw_val2") 184 | 185 | print("test1") 186 | save_stream(append_test1, out_url, "raw_test1") 187 | 188 | print("test2") 189 | save_stream(append_test2, out_url, "raw_test2") 190 | 191 | print("moving on") 192 | 193 | 194 | def main_scale(in_url: str, out_url: str) -> None: 195 | """Apply normalization to generated data""" 196 | range_dict1 = get_range_dict(in_url, "raw_train1") 197 | range_dict2 = get_range_dict(in_url, "raw_train2") 198 | range_dict = unify_range_dicts(range_dict1, range_dict2) 199 | 200 | print("unnormalized ranges: ", range_dict) 201 | print("scale and store") 202 | 203 | print("train") 204 | scale_and_store("raw_train1", "norm_train_no_ma", range_dict, in_url, out_url) 205 | scale_and_store("raw_train2", "norm_train_ma", range_dict, in_url, out_url) 206 | 207 | print("val") 208 | scale_and_store("raw_val1", "norm_val_no_ma", range_dict, in_url, out_url) 209 | scale_and_store("raw_val2", "norm_val_ma", range_dict, in_url, out_url) 210 | 211 | print("test1") 212 | scale_and_store("raw_test1", "norm_test_sup", range_dict, in_url, out_url) 213 | 214 | print("test2") 215 | scale_and_store("raw_test2", "norm_test_shape", range_dict, in_url, out_url) 216 | 217 | 218 | def process(generator_key: str) -> None: 219 | """Process data from a single fem generator""" 220 | base_url_gs = f"gs://squirrel-core-public-data/gnn_bvp_solver/{generator_key}" 221 | base_url = f"data/{generator_key}" # store intermediate results locally 222 | 223 | main(generator_key, base_url) 224 | main_scale(base_url, base_url_gs) 225 | 226 | 227 | if 
__name__ == "__main__": 228 | for label_g in ["ElectricsRandomChargeGenerator", "MagneticsRandomCurrentGenerator", "ElasticityFixedLineGenerator"]: 229 | process(label_g) 230 | -------------------------------------------------------------------------------- /gnn_bvp_solver/fem_trainer/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pytorch_lightning.loggers import WandbLogger 3 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 4 | from datetime import datetime 5 | from gnn_bvp_solver.fem_dataset.lightning_datamodule import FEMDataModule 6 | from gnn_bvp_solver.fem_dataset.msg_dataset import MsgIterableDataset 7 | from gnn_bvp_solver.fem_trainer.graph_training_module import GNNModule 8 | from squirrel.driver.msgpack import MessagepackDriver, MessagepackSerializer 9 | from squirrel.store import SquirrelStore 10 | from pathlib import Path 11 | from queue import PriorityQueue 12 | 13 | import wandb 14 | import pytorch_lightning as pl 15 | 16 | 17 | class FEMTraining: 18 | def __init__( 19 | self, 20 | data_dir_train: str, 21 | data_dir_val: str, 22 | data_dir_test: str, 23 | project: str = None, 24 | batch_size_train: int = 1, 25 | tags: List = None, 26 | init_wandb: bool = True, 27 | dry_run: bool = False, 28 | profiling: bool = False, 29 | download_data: bool = True, 30 | seed: int = 111, 31 | ): 32 | """Init Fem training. 
Load data from default directory""" 33 | self.profiling = profiling 34 | self.seed = seed 35 | pl.seed_everything(self.seed) 36 | 37 | if tags is None: 38 | tags = [] 39 | 40 | self.dry_run = dry_run 41 | if self.dry_run: 42 | batch_size_train = 1 43 | tags += ["dry_run"] 44 | 45 | if download_data: 46 | data_dir_train = self.convert_to_local_path(data_dir_train) 47 | data_dir_val = self.convert_to_local_path(data_dir_val) 48 | data_dir_test = self.convert_to_local_path(data_dir_test) 49 | self.data_module = self.load_msgpack(data_dir_train, data_dir_val, data_dir_test, batch_size_train, dry_run) 50 | 51 | if init_wandb and not self.profiling: 52 | if tags is None: 53 | tags = [] 54 | 55 | if project is not None: 56 | wandb.init(tags=tags, project=project) 57 | else: 58 | wandb.init(tags=tags) 59 | 60 | wandb.config.seed = self.seed 61 | 62 | def convert_to_local_path(self, path: str) -> str: 63 | """Cache data locally if we train for multiple epochs.""" 64 | N_SHARD = 100 65 | local_path = path.replace("gs://", "") 66 | local_path = "data/" + local_path.replace("/", "-") 67 | 68 | if Path(local_path).exists(): 69 | print(f"directory {local_path} already exists") 70 | return local_path 71 | 72 | print("downloading data") 73 | store = SquirrelStore(local_path, serializer=MessagepackSerializer()) 74 | driver = MessagepackDriver(path) 75 | driver.get_iter().batched(N_SHARD, drop_last_if_not_full=False).map( 76 | lambda it: store.set(value=list(it)) 77 | ).tqdm().join() 78 | 79 | return local_path 80 | 81 | def load_msgpack( 82 | self, path_train: str, path_val: str, path_test: str, batch_size_train: int, dry_run: bool 83 | ) -> FEMDataModule: 84 | """Create pt lightning datamodule for messagepack dataset.""" 85 | return FEMDataModule( 86 | train_data=MsgIterableDataset(path_train, dry_run=dry_run, shuffle=True), 87 | val_data=MsgIterableDataset(path_val, dry_run=dry_run, shuffle=False), 88 | test_data=MsgIterableDataset(path_test, dry_run=dry_run, shuffle=False), 89 | 
batch_size_train=batch_size_train, 90 | ) 91 | 92 | def train( 93 | self, 94 | dim_in: int, 95 | dim_out: int, 96 | processor: str, 97 | augmentation: List, 98 | remove_pos: True, 99 | epochs: int = 1000, 100 | cuda: bool = True, 101 | ) -> None: 102 | """Train for fixed number of epochs. Weights and logs are saved to wandb.""" 103 | gpus = 1 if cuda else 0 104 | lightning_module = GNNModule(dim_in, dim_out, processor, augmentation=augmentation, remove_pos=remove_pos) 105 | 106 | if not self.profiling: 107 | wandb_logger = WandbLogger( 108 | log_model=True, name="gnn_bvp_logs", version=datetime.today().strftime("%Y-%m-%d") 109 | ) 110 | wandb_logger.watch(lightning_module.model, log="all") 111 | 112 | if self.dry_run: 113 | checkpoint_callback = ModelCheckpoint(every_n_epochs=0) 114 | trainer = pl.Trainer( 115 | max_epochs=epochs, 116 | gpus=gpus, 117 | logger=wandb_logger, 118 | callbacks=[checkpoint_callback], 119 | log_every_n_steps=1, 120 | overfit_batches=1, 121 | ) 122 | elif self.profiling: 123 | checkpoint_callback = ModelCheckpoint(every_n_epochs=0) 124 | trainer = pl.Trainer( 125 | max_epochs=epochs, gpus=gpus, callbacks=[checkpoint_callback], profiler="simple", limit_val_batches=0 126 | ) 127 | else: 128 | checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, save_last=True) 129 | trainer = pl.Trainer( 130 | max_epochs=epochs, gpus=gpus, logger=wandb_logger, callbacks=[checkpoint_callback], log_every_n_steps=50 131 | ) 132 | 133 | trainer.fit(lightning_module, datamodule=self.data_module) 134 | 135 | if not self.profiling: 136 | wandb.finish() 137 | 138 | def test(self, artifact_reference: str, cuda: bool = True, project: str = None) -> None: 139 | """Download model artifact from wandb and test, which requires still to copy the artifact name manually. 140 | 141 | Args: 142 | artifact_reference (str): Artifact name from wandb console. 143 | cuda (bool, optional): Use GPU for testing. Defaults to True. 
144 | """ 145 | gpus = 1 if cuda else 0 146 | 147 | if project is not None: 148 | wandb.init(project=project) 149 | else: 150 | wandb.init() 151 | 152 | wandb_logger = WandbLogger(name="gnn_bvp_logs", version=datetime.today().strftime("%Y-%m-%d")) 153 | 154 | artifact = wandb.use_artifact(artifact_reference, type="model") 155 | artifact_dir = artifact.download() 156 | lightning_module = GNNModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt") 157 | 158 | trainer = pl.Trainer(gpus=gpus, logger=wandb_logger, log_every_n_steps=1) 159 | trainer.test(lightning_module, datamodule=self.data_module) 160 | 161 | wandb.finish() 162 | 163 | def vis_cases(self, artifact_reference: str, cuda: bool = True, failure: bool = False) -> None: 164 | """Visualize example cases (either random or highest loss of model)""" 165 | artifact = wandb.use_artifact(artifact_reference, type="model") 166 | artifact_dir = artifact.download() 167 | lightning_module = GNNModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt") 168 | 169 | if cuda: 170 | lightning_module = lightning_module.cuda() 171 | 172 | q = PriorityQueue() 173 | idx = 0 174 | n = 10 175 | 176 | print("testing ..") 177 | for batch in self.data_module.test_dataloader(): 178 | if cuda: 179 | batch = batch.cuda() 180 | 181 | y_hat = lightning_module(batch.x, batch.edge_index) 182 | 183 | # loss = lightning_module.loss_f(y_hat.squeeze(), batch.y) 184 | # print(y_hat.min(), y_hat.mean(), y_hat.max()) 185 | # print(batch.y.min(), batch.y.mean(), batch.y.max()) 186 | # print(loss) 187 | 188 | # compute loss for the VECTOR FIELDS 189 | loss = lightning_module.loss_f(y_hat.squeeze()[:, 1:], batch.y[:, 1:]) 190 | 191 | loss_it = loss.cpu().item() 192 | q.put((-loss_it, idx, (batch, y_hat.cpu().detach()))) 193 | idx += 1 194 | 195 | if idx % 5 == 0: 196 | print(f"\r{idx}", end="") 197 | 198 | if idx >= n and not failure: 199 | print(f"not in failure mode, break after {n} samples") 200 | break 201 | 202 | print(f"select {n} 
examples to visualize") 203 | 204 | columns = ["Loss", "Pred_pot", "GT_pot", "Pred_vec", "GT_vec"] 205 | table = wandb.Table(columns=columns) 206 | 207 | for _i in range(n): 208 | item = q.get() 209 | print("loss", -item[0]) 210 | 211 | batch, data_y_hat = item[2] 212 | data_x = batch.x.cpu() 213 | data_y = batch.y.cpu() 214 | 215 | potential_predicted, potential_gt, vector_f_predicted, vector_f_gt = lightning_module.vis_out_gt( 216 | data_x, data_y, data_y_hat, log=False, size_x=10, size_y=8, res=100 217 | ) 218 | 219 | table.add_data( 220 | -item[0], 221 | wandb.Image(potential_predicted), 222 | wandb.Image(potential_gt), 223 | wandb.Image(vector_f_predicted), 224 | wandb.Image(vector_f_gt), 225 | ) 226 | 227 | wandb.log({"visualizations": table}) 228 | -------------------------------------------------------------------------------- /gnn_bvp_solver/visualization/plot_graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Any, Callable, Tuple 3 | from scipy.interpolate import griddata 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | 8 | 9 | class Visualization: 10 | def __init__( 11 | self, 12 | mx: np.array, 13 | my: np.array, 14 | axis_off: bool = True, 15 | resolution: int = 50, 16 | scale: bool = True, 17 | figsize: Tuple[float, float] = (15, 12), 18 | visible_area: Callable[[float, float], bool] = None, 19 | ) -> None: 20 | """Init visualization. 21 | 22 | Args: 23 | mx (np.array): list of x coordinates of nodes in mesh 24 | my (np.array): list of y coordinates of nodes in mesh 25 | axis_off (bool, optional): Turn off axis. Defaults to True. 26 | resolution (int, optional): Resolution of plotting grid. Defaults to 50. 27 | scale (bool, optional): scale grid to data. Defaults to True. 28 | figsize (Tuple[float, float], optional): figsize of plot. 
29 | visible_area (Callable[[float, float], bool], optional): area to display 30 | """ 31 | self.mx = mx 32 | self.my = my 33 | self.f, self.ax = plt.subplots(figsize=figsize) 34 | self.c_labelsize = 45 35 | 36 | if scale: 37 | f_x = self.mx.max() - self.mx.min() 38 | f_y = self.my.max() - self.my.min() 39 | resolution_x = (f_x * resolution * 2) / (f_x + f_y) 40 | resolution_y = (f_y * resolution * 2) / (f_x + f_y) 41 | else: 42 | resolution_x = resolution 43 | resolution_y = resolution 44 | 45 | self.x = np.linspace(self.mx.min(), self.mx.max(), int(np.round(resolution_x))) 46 | self.y = np.linspace(self.my.min(), self.my.max(), int(np.round(resolution_y))) 47 | 48 | self.xx, self.yy = np.meshgrid(self.x, self.y) 49 | self.extent = (self.x.min(), self.x.max(), self.y.min(), self.y.max()) 50 | 51 | self.vis = None 52 | if visible_area is not None: 53 | self.vis = np.zeros(self.xx.shape) 54 | for i in range(self.vis.shape[0]): 55 | for j in range(self.vis.shape[1]): 56 | if visible_area(self.xx[i, j], self.yy[i, j]): 57 | self.vis[i, j] = 1.0 58 | 59 | if axis_off: 60 | self.ax.axis("off") 61 | 62 | def plot_on_grid( 63 | self, 64 | data: np.array, 65 | cmap: str = "viridis", 66 | scatter: bool = True, 67 | scatter_only: bool = False, 68 | vmin: float = None, 69 | vmax: float = None, 70 | s: float = None, 71 | ) -> Visualization: 72 | """Plot one-dimensional data as colors on a grid. 73 | 74 | Args: 75 | data (np.array): 1d data to plot. 76 | cmap (str, optional): color map. Defaults to 'viridis'. 77 | scatter (bool, optional): plot node points. Defaults to True. 78 | scatter_only (bool, optional): plot only node points. Defaults to False. 79 | vmin (float, optional): min value for color scale. Defaults to None. 80 | vmax (float, optional): max value for color scale. Defaults to None. 
81 | 82 | Returns: 83 | Visualization: return self 84 | """ 85 | if vmax is None: 86 | vmax = data.max() 87 | if vmin is None: 88 | vmin = data.min() 89 | 90 | # interpolation 91 | grid_data = griddata(np.stack([self.mx, self.my], axis=1), data, (self.xx, self.yy), method="linear") 92 | grid_data = np.nan_to_num(grid_data) 93 | 94 | if scatter_only: 95 | self.ax.scatter(self.mx, self.my, c=data, s=25.0 if s is None else s, cmap=cmap, marker="s") 96 | 97 | # Change colorbar 98 | # scat = self.ax.scatter(self.mx, self.my, c=data, s=25.0 if s is None else s, cmap=cmap, marker="s") 99 | # cbar = self.f.colorbar(scat, ax = self.ax, shrink=0.7, ticks=[vmin, vmax]) 100 | # cbar.ax.tick_params(labelsize=self.c_labelsize) 101 | 102 | return self 103 | 104 | if self.vis is not None: 105 | imsh = self.ax.imshow( 106 | grid_data, 107 | extent=self.extent, 108 | cmap=cmap, 109 | origin="lower", 110 | interpolation="bilinear", 111 | vmin=vmin, 112 | vmax=vmax, 113 | alpha=self.vis, 114 | ) 115 | cbar = self.f.colorbar(imsh, ax=self.ax, shrink=0.7) 116 | cbar.ax.tick_params(labelsize=self.c_labelsize) 117 | else: 118 | imsh = self.ax.imshow( 119 | grid_data, extent=self.extent, cmap=cmap, origin="lower", interpolation="bilinear", vmin=vmin, vmax=vmax 120 | ) 121 | cbar = self.f.colorbar(imsh, ax=self.ax, shrink=0.7) 122 | cbar.ax.tick_params(labelsize=self.c_labelsize) 123 | 124 | if scatter: 125 | self.ax.scatter(self.mx, self.my, s=0.5 if s is None else s, c="black") 126 | 127 | return self 128 | 129 | @staticmethod 130 | def get_min_max_grad_field(data_x: np.array, data_y: np.array) -> Tuple: 131 | """Get minimum and maximum gradient values""" 132 | norm = np.sqrt(data_x**2 + data_y**2) 133 | return norm.min(), norm.max() 134 | 135 | def grad_field( 136 | self, 137 | data_x: np.array, 138 | data_y: np.array, 139 | cmap: str = "viridis", 140 | scatter_only: bool = False, 141 | s: float = None, 142 | vmin: float = None, 143 | vmax: float = None, 144 | ) -> Visualization: 145 | 
"""Plot a scalar field""" 146 | norm = np.sqrt(data_x**2 + data_y**2) 147 | 148 | if vmax is None: 149 | vmax = norm.max() 150 | if vmin is None: 151 | vmin = norm.min() 152 | 153 | grid_data = griddata(np.stack([self.mx, self.my], axis=1), norm, (self.xx, self.yy), method="linear") 154 | grid_data = np.nan_to_num(grid_data) 155 | 156 | if scatter_only: 157 | self.ax.scatter(self.mx, self.my, c=norm, s=25.0 if s is None else s, cmap=cmap) 158 | return self 159 | 160 | if self.vis is not None: 161 | imsh = self.ax.imshow( 162 | grid_data, 163 | extent=self.extent, 164 | cmap=cmap, 165 | origin="lower", 166 | interpolation="bilinear", 167 | alpha=self.vis, 168 | vmin=vmin, 169 | vmax=vmax, 170 | ) 171 | cbar = self.f.colorbar(imsh, ax=self.ax, shrink=0.7) 172 | cbar.ax.tick_params(labelsize=self.c_labelsize) 173 | else: 174 | imsh = self.ax.imshow( 175 | grid_data, extent=self.extent, cmap=cmap, origin="lower", interpolation="bilinear", vmin=vmin, vmax=vmax 176 | ) 177 | cbar = self.f.colorbar(imsh, ax=self.ax, shrink=0.7) 178 | cbar.ax.tick_params(labelsize=self.c_labelsize) 179 | 180 | return self 181 | 182 | def quiver_on_grid( 183 | self, data_x: np.array, data_y: np.array, normalize: bool = True, interpolate: bool = True 184 | ) -> Visualization: 185 | """Plot one-dimensional data as colors on a grid. 186 | 187 | Args: 188 | data_x (np.array): x component to plot. 189 | data_y (np.array): y component to plot. 
190 | 191 | Returns: 192 | Visualization: return self 193 | """ 194 | if normalize: 195 | norm = np.sqrt(data_x**2 + data_y**2) 196 | data_x = data_x / norm 197 | data_y = data_y / norm 198 | 199 | width = 0.005 200 | hal = 500.0 201 | 202 | if not interpolate: 203 | self.ax.quiver( 204 | self.mx, 205 | self.my, 206 | data_x, 207 | data_y, 208 | headwidth=hal, 209 | headaxislength=hal * 1.5, 210 | headlength=hal * 1.5, 211 | width=width, 212 | color="white", 213 | ) 214 | return self 215 | 216 | # interpolation 217 | grid_x = griddata(np.stack([self.mx, self.my], axis=1), data_x, (self.xx, self.yy), method="linear") 218 | grid_y = griddata(np.stack([self.mx, self.my], axis=1), data_y, (self.xx, self.yy), method="linear") 219 | 220 | if self.vis is not None: 221 | grid_x = grid_x * self.vis 222 | grid_y = grid_y * self.vis 223 | 224 | self.ax.quiver( 225 | self.x, 226 | self.y, 227 | grid_x, 228 | grid_y, 229 | headwidth=hal, 230 | headaxislength=hal, 231 | headlength=hal, 232 | width=width, 233 | color="white", 234 | ) 235 | return self 236 | 237 | def plot_line(self, idx_1: int, idx_2: int) -> None: 238 | """Plot a line identified by coordinate indiced of two grid points""" 239 | self.ax.plot([self.mx[idx_1], self.mx[idx_2]], [self.my[idx_1], self.my[idx_2]], "k-", lw=0.2) 240 | 241 | def plot_mesh(self, mesh: Any) -> None: 242 | """Plot a fenics mesh""" 243 | for c in mesh.cells(): 244 | # note: we assume a triangle mesh 245 | self.plot_line(c[0], c[1]) 246 | self.plot_line(c[1], c[2]) 247 | self.plot_line(c[2], c[0]) 248 | 249 | def show_inline(self) -> None: 250 | """Display the figure in an inline plot""" 251 | self.ax.axis("scaled") 252 | self.f.show() 253 | 254 | def to_numpy(self) -> np.array: 255 | """Save plot to numpy array. 
256 | 257 | Returns: 258 | np.array: snapshot of the plot as numpy array 259 | """ 260 | self.ax.axis("scaled") 261 | self.f.set_tight_layout(True) 262 | self.f.canvas.draw() 263 | 264 | result = np.frombuffer(self.f.canvas.tostring_rgb(), dtype=np.uint8) 265 | result = result.reshape(self.f.canvas.get_width_height()[::-1] + (3,)) 266 | plt.close(self.f) 267 | 268 | return result 269 | --------------------------------------------------------------------------------