├── torch_afem ├── __init__.py ├── utils.py ├── Poisson.py └── mesh2d.py ├── docs ├── _config.yml └── index.md ├── README.md ├── setup.py ├── LICENSE ├── .github └── workflows │ └── python-publish.yml ├── .gitignore └── example_Poisson.ipynb /torch_afem/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ## Welcome to Torch Finite Element 2 | 3 | A pure finite element package in PyTorch. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchFEM: Another Finite Element Package...in PyTorch 2 | 3 | ## Introduction 4 | 5 | TorchFEM is a compact finite element library built on PyTorch. It takes advantage of PyTorch's automatic differentiation to speed up the traditional finite element pipeline, instead of training a black-box model to "solve" PDEs. The code base is adapted from the popular MATLAB package [iFEM by Dr. Long Chen](https://lyc102.github.io/ifem/).
6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'torch-afem', 5 | packages=find_packages(include=['torch_afem', 'torch_afem.*']), 6 | version = '0.0.1', 7 | license='MIT', 8 | description = 'PyTorch Finite Element Method', 9 | long_description='PyTorch Finite Element Method', 10 | long_description_content_type="text/markdown", 11 | author = 'Shuhao Cao', 12 | author_email = 'scao.math@gmail.com', 13 | url = 'https://github.com/scaomath/torch-fem', 14 | keywords = ['pytorch', 'fem', 'pde'], 15 | install_requires=[ 16 | 'seaborn', 17 | 'torchinfo', 18 | 'numpy', 19 | 'torch>=1.9.0', 20 | 'plotly', 21 | 'scipy', 22 | 'psutil', 23 | 'matplotlib', 24 | 'tqdm', 25 | 'PyYAML', 26 | ], 27 | classifiers=[ 28 | 'Development Status :: 4 - Beta', 29 | 'Intended Audience :: Science/Research', 30 | 'Topic :: Scientific/Engineering :: Mathematics', 31 | 'License :: OSI Approved :: MIT License', 32 | 'Programming Language :: Python :: 3.8', 33 | ], 34 | ) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Shuhao Cao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .DS_Store 131 | -------------------------------------------------------------------------------- /example_Poisson.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "from fem.data import *\n", 12 | "from fem.Poisson import Poisson\n", 13 | "from fem.grf import GRF2d\n", 14 | "import numpy as np\n", 15 | "device = torch.device(\"cpu\")\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "h = 1/64\n", 25 | "n = int(1/h)\n", 26 | "poisson = Poisson(h=h, quadrature_order=1, dtype=torch.float)\n", 27 | "pde = GRF2d(n*4, alpha=2, tau=10, device=device, double=False)\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 3, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "poisson._assemble(pde)\n", 37 | "poisson.solve()\n", 38 | "uh_dir = poisson.get_u().cpu().numpy()\n", 39 | "\n", 40 | "uI = pde.solution(pde.source.squeeze(1)).unsqueeze(1)\n", 41 | "uI = F.interpolate(uI, size=(n+1, n+1),\n", 42 | " mode='bilinear',\n", 43 | " align_corners=True)\n", 44 | "uI = uI.view(-1).numpy()" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | 
"execution_count": 5, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "\n", 57 | "1-th iter\n", 58 | "energy: \t 0.0000e+00 \n", 59 | "L2 error: \t 3.0084e-03 \n", 60 | "Linf error: \t 0.0083640 \n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "optimizer = torch.optim.LBFGS(poisson.parameters(), lr=1)\n", 66 | "num_iter = 1\n", 67 | "energies = []\n", 68 | "# re-init\n", 69 | "poisson = Poisson(h=h, quadrature_order=1, dtype=torch.float)\n", 70 | "poisson._assemble(pde)\n", 71 | "b = poisson.b_int\n", 72 | "\n", 73 | "for k in range(num_iter):\n", 74 | " # flux = poisson.forward(poisson.u)\n", 75 | "\n", 76 | " def closure():\n", 77 | " optimizer.zero_grad()\n", 78 | " loss = poisson.energy(poisson.u, b)\n", 79 | " loss.backward(retain_graph=True)\n", 80 | " return loss\n", 81 | "\n", 82 | " optimizer.step(closure)\n", 83 | "\n", 84 | " with torch.no_grad():\n", 85 | " loss_val = poisson.energy(poisson.u, b)\n", 86 | " print(f\"\\n{k+1}-th iter\")\n", 87 | " print(f\"energy: \\t {loss_val.item():.4e} \")\n", 88 | " energies.append(loss_val)\n", 89 | " uh = poisson.get_u().cpu().numpy()\n", 90 | " errL2 = np.linalg.norm(uh_dir - uh) * h\n", 91 | " errLinf = np.abs(uh_dir - uh).max()\n", 92 | " print(f\"L2 error: \\t {errL2:.4e} \")\n", 93 | " print(f\"Linf error: \\t {errLinf:.7f} \")\n" 94 | ] 95 | } 96 | ], 97 | "metadata": { 98 | "kernelspec": { 99 | "display_name": "Python 3.10.8 64-bit", 100 | "language": "python", 101 | "name": "python3" 102 | }, 103 | "language_info": { 104 | "codemirror_mode": { 105 | "name": "ipython", 106 | "version": 3 107 | }, 108 | "file_extension": ".py", 109 | "mimetype": "text/x-python", 110 | "name": "python", 111 | "nbconvert_exporter": "python", 112 | "pygments_lexer": "ipython3", 113 | "version": "3.10.8" 114 | }, 115 | "orig_nbformat": 4, 116 | "vscode": { 117 | "interpreter": { 118 | "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" 
119 | } 120 | } 121 | }, 122 | "nbformat": 4, 123 | "nbformat_minor": 2 124 | } 125 | -------------------------------------------------------------------------------- /torch_afem/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | try: 4 | import plotly.figure_factory as ff 5 | import plotly.io as pio 6 | import plotly.graph_objects as go 7 | except ImportError as e: 8 | print('Please install Plotly for showing mesh and solutions.') 9 | import matplotlib.pyplot as plt 10 | import matplotlib.tri as tri 11 | 12 | def showmesh(node,elem, **kwargs): 13 | triangulation = tri.Triangulation(node[:,0], node[:,1], elem) 14 | markersize = 3000/len(node) 15 | if kwargs.items(): 16 | h = plt.triplot(triangulation, 'b-h', **kwargs) 17 | else: 18 | h = plt.triplot(triangulation, 'b-h', linewidth=0.5, alpha=0.5, markersize=markersize) 19 | return h 20 | 21 | def showsolution(node,elem,u,**kwargs): 22 | ''' 23 | show 2D solution either of a scalar function or a vector field 24 | ''' 25 | markersize = 3000/len(node) 26 | 27 | if u.ndim == 1: 28 | uplot = ff.create_trisurf(x=node[:,0], y=node[:,1], z=u, 29 | simplices=elem, 30 | colormap="Viridis", # similar to matlab's default colormap 31 | showbackground=False, 32 | aspectratio=dict(x=1, y=1, z=1), 33 | ) 34 | fig = go.Figure(data=uplot) 35 | 36 | elif u.ndim == 2 and u.shape[-1] == 2: 37 | assert u.shape[0] == elem.shape[0] 38 | u /= (np.abs(u)).max() 39 | center = node[elem].mean(axis=1) 40 | uplot = ff.create_quiver(x=center[:,0], y=center[:,1], 41 | u=u[:,0], v=u[:,1], 42 | scale=.05, 43 | arrow_scale=.5, 44 | name='gradient of u', 45 | line_width=1, 46 | ) 47 | 48 | fig = go.Figure(data=uplot) 49 | 50 | fig.update_layout(template='plotly_dark', 51 | margin=dict(l=5, r=5, t=5, b=5), 52 | **kwargs) 53 | fig.show() 54 | 55 | def unique(x, 56 | sorted=False, 57 | return_counts=False, 58 | dim=None): 59 | """ 60 | modified from 61 | 
https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810 62 | 63 | Args: 64 | input (Tensor): the input tensor 65 | sorted (bool): Whether to sort the unique elements in ascending order 66 | before returning as output. 67 | return_indices (bool): If True, also return the indices of ar (along the specified axis, if provided, or in the flattened array) that result in the unique array (added in this script). 68 | return_inverse (bool): Whether to also return the indices for where 69 | elements in the original input ended up in the returned unique list. 70 | return_counts (bool): Whether to also return the counts for each unique 71 | element. 72 | dim (int): the dimension to apply unique. If ``None``, the unique of the 73 | flattened input is returned. default: ``None`` 74 | 75 | Returns: 76 | (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing 77 | 78 | - **output** (*Tensor*): the output list of unique scalar elements. 79 | - **indices**: (optional) if :attr:`return_indices` is True, the indices of the first occurrences of the unique values in the original array. 80 | - **inverse_indices** (*Tensor*): (optional) if 81 | :attr:`return_inverse` is True, there will be an additional 82 | returned tensor (same shape as input) representing the indices 83 | for where elements in the original input map to in the output; 84 | otherwise, this function will only return a single tensor. 85 | - **counts** (*Tensor*): (optional) if 86 | :attr:`return_counts` is True, there will be an additional 87 | returned tensor (same shape as output or output.size(dim), 88 | if dim was specified) representing the number of occurrences 89 | for each unique value or tensor. 
90 | """ 91 | if return_counts: 92 | out, inverse, counts = torch.unique(x, 93 | sorted=sorted, 94 | return_inverse=True, 95 | return_counts=True, 96 | dim=dim) 97 | else: 98 | out, inverse = torch.unique(x, 99 | sorted=sorted, 100 | return_inverse=False, 101 | dim=dim) 102 | perm = torch.arange(inverse.size(0), dtype=inverse.dtype, 103 | device=inverse.device) 104 | inverse, perm = inverse.flip([0]), perm.flip([0]) 105 | # scatter_(dim, index, src) 106 | # Writes all values from the tensor src into self at the indices specified in the index tensor. 107 | indices = inverse.new_empty(out.size(0)).scatter_(0, inverse, perm) 108 | if return_counts: 109 | return out, indices, inverse, counts 110 | 111 | else: 112 | return out, indices, inverse -------------------------------------------------------------------------------- /torch_afem/Poisson.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .mesh2d import * 5 | 6 | class Poisson(nn.Module): 7 | """ 8 | A lightweight Poisson equation solver 9 | 10 | - Linear Lagrange element on triangulations 11 | 12 | Reference: Long Chen's iFEM library 13 | https://github.com/lyc102/ifem/blob/master/equation/Poisson.m 14 | """ 15 | 16 | def __init__( 17 | self, 18 | domain=((0, 1), (0, 1)), 19 | h=1 / 4, 20 | quadrature_order=1, 21 | dtype: torch.dtype = torch.float64, 22 | ) -> None: 23 | super().__init__() 24 | # self.quadpts = quadpts 25 | self.domain = domain 26 | self.h = h 27 | self.mesh_size = int(1/h)+1 28 | self.quadrature = quadpts(order=quadrature_order, dtype=dtype) 29 | self.dtype = dtype 30 | self._initialize() 31 | 32 | def _initialize(self) -> None: 33 | # TODO update default TriMesh2D options (done) 34 | 35 | node, elem = rectangleMesh( 36 | x_range=self.domain[0], y_range=self.domain[1], h=self.h 37 | ) 38 | self.trimesh = TriMesh2D(node, elem, dtype=self.dtype) 39 | # TODO set u to be 
parameters (done) 40 | self.freeNode = freeNode = self.trimesh.freeNode 41 | self.nDof = nDof = int(freeNode.sum()) 42 | self.nNode = nNode = node.size(0) 43 | self.nElem = elem.size(0) 44 | 45 | self.register_parameter( 46 | "u", nn.Parameter(torch.zeros((nDof, 1), dtype=self.dtype)) 47 | ) 48 | nn.init.zeros_(self.u) 49 | self.register_buffer("uh", torch.zeros((nNode, 1), dtype=self.dtype)) 50 | 51 | def _assemble(self, pde): 52 | node = self.trimesh.node 53 | elem = self.trimesh.elem 54 | gradPhi = self.trimesh.gradLambda 55 | area = self.trimesh.area 56 | nElem = self.nElem 57 | nNode = self.nNode 58 | 59 | phi, weight = self.quadrature 60 | 61 | # quadrature points 62 | quadPts = torch.einsum("qp, npd->qnd", phi, node[elem]) 63 | 64 | # diffusion coefficient 65 | Kp = torch.stack([pde.diffusion_coeff(p) for p in quadPts], dim=0) 66 | K = torch.einsum("q, qn->n", weight, Kp) 67 | 68 | intgradPhiAgradPhi = torch.einsum( 69 | "n,n,ndi,ndj->nij", K, area, gradPhi, gradPhi) 70 | 71 | I = elem[:, :, None].expand_as(intgradPhiAgradPhi) 72 | J = elem[:, None, :].expand_as(intgradPhiAgradPhi) 73 | IJ = torch.stack([I, J]) 74 | A = torch.sparse_coo_tensor( 75 | IJ.view(2, -1), 76 | intgradPhiAgradPhi.contiguous().view(-1), 77 | size=(nNode, nNode), 78 | ) 79 | 80 | # right hand side 81 | b = torch.zeros((nNode, 1), dtype=self.dtype) 82 | 83 | if callable(pde.source): 84 | fK = torch.stack([pde.source(p) for p in quadPts], dim=0) 85 | bt = torch.einsum("q, qn, qp, n->np", weight, fK, phi, area) 86 | elif torch.is_tensor(pde.source): 87 | 88 | if pde.source.size(-1) == 1: # (bsz, nNode, 1) 89 | pass 90 | else: # (bsz, n, n) 91 | f = F.interpolate(pde.source, 92 | size=(self.mesh_size+1, self.mesh_size+1), 93 | mode='bilinear', 94 | align_corners=True) 95 | fK = f.view(-1)[elem].mean(-1) 96 | bt = torch.einsum("q, n, qp, n->np", weight, fK, phi, area) 97 | 98 | b.scatter_(0, index=elem.view(-1, 1), src=bt.view(-1, 1), reduce="add") 99 | 100 | isBdNode = 
self.trimesh.isBdNode 101 | freeNode = self.trimesh.freeNode 102 | 103 | self.uh.scatter_( 104 | 0, 105 | index=torch.where(isBdNode)[0].unsqueeze(-1), 106 | src=pde.g_D(node[isBdNode]).unsqueeze(-1), 107 | ) 108 | b -= torch.sparse.mm(A, self.uh) 109 | 110 | self.A = A 111 | self.b = b 112 | 113 | A = A.to_dense() 114 | maskFreeNode = torch.outer(freeNode, freeNode) 115 | nDof = self.nDof 116 | A_int = A[maskFreeNode].view(nDof, nDof).to_sparse() 117 | A_int = A_int.coalesce() 118 | b_int = b[freeNode] 119 | self.b_int = b_int 120 | self.A_int = A_int 121 | 122 | def forward(self, u): 123 | return torch.sparse.mm(self.A_int, u) 124 | 125 | def solve(self, f=None) -> None: 126 | """ 127 | direct solver, not working in sparse only 128 | """ 129 | freeNode = self.trimesh.freeNode 130 | 131 | if self.A.is_sparse: 132 | A = self.A.to_dense() 133 | else: 134 | A = self.A.copy() 135 | b = self.b if f is None else f 136 | 137 | self.u.detach_() 138 | self.u = nn.Parameter( 139 | torch.linalg.solve(A[freeNode, :][:, freeNode], b[freeNode]) 140 | ) 141 | 142 | def get_u(self): 143 | """ 144 | assemble u and u_g back into 1 145 | """ 146 | self.uh[self.trimesh.freeNode] = self.u.detach() 147 | return self.uh.squeeze() 148 | 149 | def energy(self, u, b): 150 | """ 151 | 0.5*u^T A u - f*u 152 | """ 153 | Au = self.forward(u) 154 | return 0.5 * (u.T).mm(Au) - (b.T).mm(u) 155 | -------------------------------------------------------------------------------- /torch_afem/mesh2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import * 4 | 5 | def rectangleMesh(x_range=(0,1), y_range=(0,1), h=0.25): 6 | """ 7 | Input: 8 | - x's range, (x_min, x_max) 9 | - y's range, (y_min, y_max) 10 | - h, mesh size, can be a tuple 11 | Return the element matrix (NT, 3) 12 | of the mesh a torch.meshgrid 13 | """ 14 | try: 15 | hx, hy = h[0], h[1] 16 | except: 17 | hx, hy = h, h 18 | 19 | # need to add h 
because arange is not inclusive 20 | xp = torch.arange(x_range[0], x_range[1]+hx, hx) 21 | yp = torch.arange(y_range[0], y_range[1]+hy, hy) 22 | nx, ny = len(xp), len(yp) 23 | 24 | x, y = torch.meshgrid(xp, yp) 25 | 26 | elem = [] 27 | for j in range(ny-1): 28 | for i in range(nx-1): 29 | a = i + j*nx 30 | b = (i+1) + j*nx 31 | d = i + (j+1)*nx 32 | c = (i+1) + (j+1)*nx 33 | elem += [[a, c, d], [b, c, a]] 34 | 35 | node = torch.stack([x.flatten(), y.flatten()], dim=-1) 36 | elem = torch.tensor(elem, dtype=torch.long) 37 | return node, elem 38 | 39 | def quadpts(order=2, dtype=torch.float64): 40 | ''' 41 | ported from iFEM's quadpts 42 | ''' 43 | 44 | if order == 1: # Order 1, nQuad 1 45 | baryCoords = [[1/3, 1/3, 1/3]] 46 | weight = [1] 47 | elif order == 2: # Order 2, nQuad 3 48 | baryCoords = [[2/3, 1/6, 1/6], 49 | [1/6, 2/3, 1/6], 50 | [1/6, 1/6, 2/3]] 51 | weight = [1/3, 1/3, 1/3] 52 | elif order == 3: # Order 3, nQuad 4 53 | baryCoords = [[1/3, 1/3, 1/3], 54 | [0.6, 0.2, 0.2], 55 | [0.2, 0.6, 0.2], 56 | [0.2, 0.2, 0.6]] 57 | weight = [-27/48, 25/48, 25/48, 25/48] 58 | elif order == 4: # Order 4, nQuad 6 59 | baryCoords = [[0.108103018168070, 0.445948490915965, 0.445948490915965], 60 | [0.445948490915965, 0.108103018168070, 0.445948490915965], 61 | [0.445948490915965, 0.445948490915965, 0.108103018168070], 62 | [0.816847572980459, 0.091576213509771, 0.091576213509771], 63 | [0.091576213509771, 0.816847572980459, 0.091576213509771], 64 | [0.091576213509771, 0.091576213509771, 0.816847572980459], ] 65 | weight = [0.223381589678011, 0.223381589678011, 0.223381589678011, 66 | 0.109951743655322, 0.109951743655322, 0.109951743655322] 67 | return torch.tensor(baryCoords, dtype=dtype), torch.tensor(weight, dtype=dtype) 68 | 69 | class TriMesh2D(nn.Module): 70 | ''' 71 | Set up auxiliary data structures for Dirichlet boundary condition 72 | 73 | 74 | Combined the following routine from Long Chen's iFEM 75 | - setboundary: get a boundary bool matrix according to elem 76 
| - delmesh: delete mesh by eval() 77 | - auxstructure: edge-based auxiliary data structure 78 | - gradbasis: compute the gradient of local barycentric coords 79 | 80 | Input: 81 | - node: (N, 2) 82 | - elem: (NT, 3) 83 | 84 | Outputs: 85 | - edge: (NE, 2) global indexing of edges 86 | - elem2edge: (NT, 3) local to global indexing 87 | - edge2edge: (NE, 4) 88 | edge2elem[e,:2] are the global indexes of two elements sharing the e-th edge 89 | edge2elem[e,-2:] are the local indices of e to edge2elem[e,:2] 90 | - neighbor: (NT, 3) the local to global indices map of neighbor of elements 91 | neighbor[t,i] is the global index of the element opposite to the i-th vertex of the t-th element. 92 | 93 | Example: the following routine gets all ifem similar data structures 94 | node, elem = rectangleMesh(x_range=(0,1), y_range=(0,1), h=1/16) 95 | T = TriMesh2D(node,elem) 96 | T.delete_mesh('(x>0) & (y<0)') 97 | T.update_auxstructure() 98 | T.update_gradbasis() 99 | node, elem = T.node, T.elem 100 | Dphi = T.Dlambda 101 | area = T.area 102 | elem2edgeSign = T.elem2edgeSign 103 | edge2elem = T.edge2elem 104 | 105 | Notes: 106 | 1. Python assigns the first appeared entry's index in unique; Matlab assigns the last appeared entry's index in unique. 107 | 2. Matlab uses columns as natural indexing, reshape(NT, 3) in Matlab should be changed to 108 | reshape(3, -1).T in Python if initially the data is concatenated along axis=0 using torch.r_[]. 109 | 110 | TODO: 111 | - Add Neumann boundary. 
112 | - Change torch.bincount to torch.scatter 113 | 114 | ''' 115 | 116 | def __init__(self, 117 | node=None, 118 | elem=None, 119 | bdFlag=None, 120 | dtype: torch.dtype = torch.float64, 121 | ) -> None: 122 | super().__init__() 123 | 124 | self.dtype = dtype 125 | self.node = node.to(dtype) 126 | self.elem = elem 127 | self.bdFlag = bdFlag 128 | self._init_auxstruct() 129 | self._init_grad() 130 | 131 | def _init_auxstruct(self): 132 | elem = self.elem 133 | numElem = self.elem.size(0) 134 | numNode = self.node.size(0) 135 | 136 | # every edge's sign 137 | allEdge = torch.cat( 138 | [elem[:, [1, 2]], elem[:, [2, 0]], elem[:, [0, 1]]], dim=0) 139 | elem2edgeSign = torch.ones(3*numElem, dtype=int) 140 | elem2edgeSign[allEdge[:, 0] > allEdge[:, 1]] = -1 141 | self.elem2edgeSign = elem2edgeSign.view(3, -1).T 142 | allEdge, _ = torch.sort(allEdge, axis=1) 143 | # TODO indices in sort obj is dummy, can be used 144 | 145 | # edge structures 146 | self.edge, E2e, e2E, counts = unique(allEdge, 147 | return_counts=True, 148 | dim=0) 149 | self.elem2edge = e2E.view(3, -1).T 150 | 151 | # neighbor structures 152 | E2e_reverse = torch.zeros_like(E2e) 153 | E2e_reverse[e2E] = torch.arange(3*numElem) 154 | 155 | k1 = torch.div(E2e, numElem, rounding_mode='floor') 156 | k2 = torch.div(E2e_reverse, numElem, rounding_mode='floor') 157 | t1 = E2e - numElem*k1 158 | t2 = E2e_reverse - numElem*k2 159 | ix = self.isIntEdge = (counts == 2) # interior edge indicator 160 | # edge to elem 161 | self.edge2elem = torch.stack([t1, t2, k1, k2], dim=-1) 162 | 163 | self.neighbor = torch.zeros((numElem, 3), dtype=int) 164 | ixElemLocalEdge1 = torch.stack([t1[ix], k1[ix]], dim=-1) 165 | ixElemLocalEdge2 = torch.stack([t2, k2], dim=-1) 166 | ixElemLocalEdge = torch.cat( 167 | [ixElemLocalEdge1, ixElemLocalEdge2], dim=0) 168 | ixElem = torch.cat([t2[ix], t1], dim=0) 169 | for i in range(3): 170 | ix = (ixElemLocalEdge[:, 1] == i) # i-th edge's neighbor 171 | # TODO: check if bincount is necessary 
here 172 | self.neighbor[:, i] = torch.bincount(ixElemLocalEdge[ix, 0], 173 | weights=ixElem[ix], 174 | minlength=numElem) 175 | 176 | isBdEdge = (counts == 1) # boundary edge indicator 177 | if self.bdFlag is None: 178 | self.bdFlag = isBdEdge[e2E].view(3, -1).T 179 | Dirichlet = self.edge[isBdEdge] 180 | self.isBdNode = torch.zeros(numNode, dtype=bool) 181 | self.isBdNode[Dirichlet.ravel()] = True 182 | self.freeNode = ~self.isBdNode 183 | 184 | def _init_grad(self): 185 | node, elem = self.node, self.elem 186 | 187 | ve1 = node[elem[:, 2]]-node[elem[:, 1]] 188 | ve2 = node[elem[:, 0]]-node[elem[:, 2]] 189 | ve3 = node[elem[:, 1]]-node[elem[:, 0]] 190 | area = torch.abs(0.5*(-ve3[:, 0]*ve2[:, 1] + ve3[:, 1]*ve2[:, 0])) 191 | gradLambda = torch.zeros((len(elem), 2, 3), dtype=self.dtype) 192 | # (# elem, 2-dim vector, 3 vertices) 193 | 194 | gradLambda[..., 2] = torch.stack( 195 | [-ve3[:, 1]/(2*area), ve3[:, 0]/(2*area)], dim=-1) 196 | gradLambda[..., 0] = torch.stack( 197 | [-ve1[:, 1]/(2*area), ve1[:, 0]/(2*area)], dim=-1) 198 | gradLambda[..., 1] = torch.stack( 199 | [-ve2[:, 1]/(2*area), ve2[:, 0]/(2*area)], dim=-1) 200 | # torch.stack with dim=-1 is equivalent to np.c_[] 201 | 202 | self.area = area 203 | self.gradLambda = gradLambda 204 | 205 | def get_elem2edge(self): 206 | return self.elem2edge 207 | 208 | def get_bdFlag(self): 209 | return self.bdFlag 210 | 211 | def get_edge(self): 212 | return self.edge 213 | 214 | def forward(self, x): 215 | return None 216 | 217 | def delete_mesh(self, expr=None): 218 | ''' 219 | Update the mesh by deleting the eval(expr) 220 | ''' 221 | assert expr is not None 222 | node, elem = self.node, self.elem 223 | center = node[elem].mean(axis=1) 224 | x, y = center[:, 0], center[:, 1] 225 | 226 | # delete element 227 | idx = eval(expr) 228 | mask = torch.ones(len(elem), dtype=bool) 229 | mask[idx] = False 230 | elem = elem[mask] 231 | 232 | # re-mapping the indices of vertices 233 | # to remove the unused ones 234 | 
isValidNode = torch.zeros(len(node), dtype=bool) 235 | indexMap = torch.zeros(len(node), dtype=int) 236 | 237 | isValidNode[elem.ravel()] = True 238 | self.node = node[isValidNode] 239 | 240 | indexMap[isValidNode] = torch.arange(len(self.node)) 241 | self.elem = indexMap[elem] --------------------------------------------------------------------------------