├── .gitignore ├── LICENSE ├── README.md ├── applications ├── README.md ├── arc_length │ ├── beam2d_arc_length_disp_driven.py │ ├── beam2d_arc_length_force_driven.py │ ├── beam2d_dynamic_relaxation_disp_driven.py │ └── demo_analytical.py ├── battery │ ├── input │ │ ├── data │ │ │ ├── results.mat │ │ │ ├── sol_j_t2.mat │ │ │ ├── sol_macro.mat │ │ │ ├── sol_macro_t1.mat │ │ │ ├── sol_micro.mat │ │ │ └── sol_micro_t1.mat │ │ └── mesh │ │ │ ├── macro_mesh.mat │ │ │ └── micro_mesh.mat │ ├── main.py │ ├── matlab_fns.py │ ├── micro.py │ ├── output │ │ ├── data │ │ │ ├── sol_c_t10.svg │ │ │ ├── sol_j_an_t10.svg │ │ │ ├── sol_j_ca_t10.svg │ │ │ ├── sol_p_t10.svg │ │ │ ├── sol_s_an_t10.svg │ │ │ └── sol_s_ca_t10.svg │ │ └── mesh │ │ │ ├── macro_mesh.svg │ │ │ └── micro_mesh.svg │ ├── para.py │ ├── prep.py │ └── utils.py ├── boundary_index │ └── example.py ├── convergence_rate │ └── example.py ├── crystal_plasticity │ ├── calibration.py │ ├── models.py │ ├── polycrystal_neper.py │ └── simple.py ├── dendrite │ ├── README.md │ ├── explicit_fd │ │ ├── example.py │ │ └── input │ │ │ └── json │ │ │ └── params.json │ ├── explicit_fem │ │ ├── example.py │ │ └── input │ │ │ └── json │ │ │ └── params.json │ ├── implicit_fem │ │ ├── example.py │ │ └── input │ │ │ └── json │ │ │ └── params.json │ └── materials │ │ ├── explicit_fd.png │ │ ├── explicit_fem.png │ │ ├── implicit_fem.png │ │ └── paper.png ├── dynamic_relaxation │ ├── beam3d.py │ ├── cellular_solid.py │ └── input │ │ └── abaqus │ │ └── cellular_solid.inp ├── forming │ ├── model.py │ ├── single_cell.py │ └── thin_plate.py ├── nodal_stress │ └── example.py ├── outdated │ ├── aesthetic │ │ ├── README.md │ │ ├── arguments.py │ │ ├── bridge.py │ │ ├── building.py │ │ ├── debug.py │ │ ├── image_utils.py │ │ ├── input │ │ │ ├── contents │ │ │ │ ├── dancing.jpg │ │ │ │ ├── structure.jpg │ │ │ │ └── tree.png │ │ │ └── styles │ │ │ │ ├── calligraphy.png │ │ │ │ ├── circles.jpg │ │ │ │ ├── circles.png │ │ │ │ ├── circles_ppt_made.png │ 
│ │ │ ├── moha.png │ │ │ │ ├── picasso.jpg │ │ │ │ ├── square.jpg │ │ │ │ ├── square.png │ │ │ │ ├── strip.png │ │ │ │ ├── tree.jpg │ │ │ │ ├── tree.png │ │ │ │ ├── voronoi.jpg │ │ │ │ ├── voronoi.png │ │ │ │ └── wave.jpg │ │ ├── main_not_working.py │ │ ├── models.py │ │ ├── modules.py │ │ ├── style_loss.py │ │ ├── to.py │ │ ├── tree.py │ │ └── tree_utils.py │ ├── design │ │ ├── curve.py │ │ └── input │ │ │ └── abaqus │ │ │ └── beam.inp │ ├── fem_examples │ │ ├── README.md │ │ ├── fenicsx_examples.py │ │ ├── jax_examples.py │ │ └── utils.py │ ├── full_field_infer │ │ ├── README.md │ │ ├── poisson.py │ │ └── utils.py │ ├── multi_scale │ │ ├── README.md │ │ ├── arguments.py │ │ ├── deploy.py │ │ ├── fem_model.py │ │ ├── run.sh │ │ ├── rve.py │ │ ├── trainer.py │ │ └── utils.py │ ├── thermal │ │ ├── bare_plate.py │ │ ├── models.py │ │ ├── thin_wall.py │ │ └── utils.py │ └── top_opt │ │ ├── L_shape.py │ │ ├── README.md │ │ ├── box.py │ │ ├── eigen.py │ │ ├── fem_model.py │ │ ├── freecad.py │ │ ├── mma.py │ │ ├── multi_material.py │ │ └── utils.py ├── periodic_bc │ ├── README.md │ ├── example.py │ └── fenics.py ├── plasticity_gradient │ └── example.py ├── quiet_element │ ├── example.py │ └── input │ │ ├── abaqus │ │ └── thinwall.inp │ │ └── toolpath │ │ └── thinwall_toolpath.crs ├── robin_bc │ ├── example.py │ ├── fenics.py │ └── input │ │ └── numpy │ │ ├── cells_u.npy │ │ └── points_u.npy ├── serendipity │ ├── example.py │ ├── input │ │ └── abaqus │ │ │ └── cube.inp │ └── remarks_using_pprof.md ├── stokes │ ├── example.py │ ├── example_1var.py │ ├── example_2vars.py │ ├── fenics.py │ └── input │ │ ├── numpy │ │ ├── cells_p.npy │ │ ├── cells_u.npy │ │ ├── points_p.npy │ │ └── points_u.npy │ │ └── xml │ │ ├── dolfin_fine.xml.gz │ │ └── dolfin_fine_subdomains.xml.gz └── surrogate_model │ ├── example.py │ ├── input │ ├── dataset │ │ ├── e_strain.npy │ │ └── pk2_data.npy │ └── model.pth │ ├── materials │ ├── RVE.png │ ├── compression.png │ ├── strech.png │ └── 
surrogatemodel.png │ ├── readme.md │ └── train.py ├── demos ├── README.md ├── hyperelasticity │ ├── README.md │ ├── example.py │ └── materials │ │ └── sol.png ├── inverse │ ├── README.md │ └── example.py ├── linear_elasticity │ ├── README.md │ ├── example.py │ └── materials │ │ └── sol.png ├── phase_field_fracture │ ├── README.md │ ├── animation.py │ ├── eigen.py │ ├── example.py │ └── materials │ │ ├── disp_force.png │ │ ├── fracture.gif │ │ └── time_disp.png ├── plasticity │ ├── README.md │ ├── example.py │ └── materials │ │ ├── sol.gif │ │ └── stress_strain.png ├── poisson │ ├── README.md │ ├── example.py │ └── materials │ │ └── sol.png ├── thermal_mechanical │ ├── README.md │ ├── animation.py │ ├── example.py │ └── materials │ │ ├── T.gif │ │ ├── line.gif │ │ ├── phase.gif │ │ └── value.gif ├── thermal_mechanical_full │ ├── README.md │ ├── example.py │ ├── fenics.py │ ├── input │ │ └── numpy │ │ │ ├── cells.npy │ │ │ └── points.npy │ └── material │ │ ├── theta.gif │ │ └── uy.gif ├── topology_optimization │ ├── README.md │ ├── animation.py │ ├── example.py │ └── materials │ │ ├── obj_val.png │ │ └── to.gif └── wave │ ├── README.md │ ├── example.py │ ├── fenics.py │ ├── input │ └── numpy │ │ ├── cells.npy │ │ └── points.npy │ └── material │ └── pressure.gif ├── environment.yml ├── images ├── ded.gif ├── poisson.png ├── polycrystal_grain.gif ├── polycrystal_stress.gif ├── stokes_p.png ├── stokes_u.png ├── to.gif └── von_mises.png ├── jax_fem ├── README.md ├── __init__.py ├── basis.py ├── experimental │ ├── adjoint_save_to_local.py │ ├── autodiff_utils.py │ ├── custom_jvp.py │ ├── jit_global.py │ ├── lm_solver.py │ ├── memory.py │ ├── petsc_solver.py │ ├── safe_grad.py │ ├── sparse.py │ └── when_to_jit.py ├── fe.py ├── generate_mesh.py ├── logger_setup.py ├── mma.py ├── problem.py ├── solver.py └── utils.py ├── pyproject.toml └── tests ├── __init__.py └── benchmarks ├── __init__.py ├── __main__.py ├── fenicsx_gold.py ├── hyperelasticity ├── __init__.py ├── 
__main__.py ├── fenicsx │ ├── sol.pvd │ ├── sol.vtu │ ├── sol000000.pvtu │ ├── sol_p0_000000.vtu │ └── traction.npy └── test_hyper_elasticity.py ├── linear_elasticity_cube ├── __init__.py ├── __main__.py ├── fenicsx │ ├── sol.pvd │ ├── sol.vtu │ ├── sol000000.pvtu │ └── sol_p0_000000.vtu └── test_linear_elasticity_cube.py ├── linear_elasticity_cylinder ├── __init__.py ├── __main__.py ├── fenicsx │ ├── sol.pvd │ ├── sol.vtu │ ├── sol000000.pvtu │ ├── sol_p0_000000.vtu │ └── surface_area.npy └── test_linear_elasticity_cylinder.py ├── linear_poisson ├── __init__.py ├── __main__.py ├── fenicsx │ ├── sol.pvd │ ├── sol.vtu │ ├── sol000000.pvtu │ └── sol_p0_000000.vtu └── test_linear_poisson.py └── plasticity ├── __init__.py ├── __main__.py ├── fenicsx ├── avg_stresses.npy ├── disps.npy ├── sol.pvd ├── sol.vtu ├── sol000000.pvtu └── sol_p0_000000.vtu └── test_plasticity.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/launch.json 2 | *.pyc 3 | 4 | dev_code/ 5 | *reference/ 6 | *.DS_Store 7 | *data/ 8 | *output/ 9 | *defcon/ 10 | *debug/ 11 | tests/benchmarks/*/jax_fem/* 12 | applications/aesthetic/input/models 13 | 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | -------------------------------------------------------------------------------- /applications/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | This folder contains more advanced JAX-FEM examples (mostly without documentation). 
For more elementary and tutorial purpose examples, please visit [tutorial](https://github.com/tianjuxue/jax-fem/tree/main/demos) for more information. 4 | 5 | TODO (if there is certain interest): 6 | 7 | 1. Mass lumping explicit dynamics 8 | 1. RVE mechanics (solve for fluctuation field) 9 | -------------------------------------------------------------------------------- /applications/arc_length/beam2d_arc_length_disp_driven.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import numpy as onp 4 | import os 5 | import glob 6 | 7 | from jax_fem.problem import Problem 8 | from jax_fem.solver import solver, arc_length_solver_disp_driven 9 | from jax_fem.utils import save_sol 10 | from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh 11 | 12 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 13 | vtk_dir = os.path.join(output_dir, 'vtk') 14 | onp.random.seed(0) 15 | 16 | 17 | class HyperElasticity(Problem): 18 | def get_tensor_map(self): 19 | def psi(F): 20 | E = 1e3 21 | nu = 0.3 22 | mu = E / (2. * (1. + nu)) 23 | kappa = E / (3. * (1. - 2. * nu)) 24 | J = np.linalg.det(F) 25 | Jinv = J**(-2. / 2.) 26 | I1 = np.trace(F.T @ F) 27 | energy = (mu / 2.) * (Jinv * I1 - 2.) + (kappa / 2.) * (J - 1.)**2. 28 | return energy 29 | 30 | P_fn = jax.grad(psi) 31 | 32 | def first_PK_stress(u_grad): 33 | I = np.eye(self.dim) 34 | F = u_grad + I 35 | P = P_fn(F) 36 | return P 37 | 38 | return first_PK_stress 39 | 40 | def get_surface_maps(self): 41 | def surface_map(u, x): 42 | # Some small noise to guide the arc-length solver 43 | return np.array([0., 1e-5]) 44 | return [surface_map] 45 | 46 | 47 | def example(): 48 | ele_type = 'QUAD4' 49 | cell_type = get_meshio_cell_type(ele_type) 50 | Lx, Ly = 50., 2. 
51 | 52 | meshio_mesh = rectangle_mesh(Nx=50, Ny=2, domain_x=Lx, domain_y=Ly) 53 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 54 | 55 | files = glob.glob(os.path.join(vtk_dir, f'*')) 56 | for f in files: 57 | os.remove(f) 58 | 59 | def left(point): 60 | return np.isclose(point[0], 0., atol=1e-5) 61 | 62 | def right(point): 63 | return np.isclose(point[0], Lx, atol=1e-5) 64 | 65 | def middle(point): 66 | return np.isclose(point[1], 0., atol=1e-5) & (point[0] > Lx/2. - 2.) & (point[0] < Lx/2. + 2.) 67 | 68 | def zero_dirichlet_val(point): 69 | return 0. 70 | 71 | def compressed_dirichlet_val(point): 72 | return -0.05*Lx 73 | 74 | def small_dirichlet_val(point): 75 | return 0.0*Ly 76 | 77 | dirichlet_bc_info = [[left]*2 + [right]*2, [0, 1, 0, 1], [zero_dirichlet_val]*2 + [compressed_dirichlet_val, small_dirichlet_val]] 78 | 79 | location_fns = [middle] 80 | 81 | problem = HyperElasticity(mesh, vec=2, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns) 82 | 83 | solver_flag = 'arc-length' # 'arc-length' or 'newton' 84 | 85 | if solver_flag == 'arc-length': 86 | # Arc-length solver converges to buckling configuration 87 | u_vec = np.zeros(problem.num_total_dofs_all_vars) 88 | lamda = 0. 89 | Delta_u_vec_dir = np.zeros(problem.num_total_dofs_all_vars) 90 | Delta_lamda_dir = 0. 91 | 92 | for i in range(600): 93 | print(f"\n\nStep {i}, lamda = {lamda}") 94 | u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir = arc_length_solver_disp_driven(problem, u_vec, 95 | lamda, Delta_u_vec_dir, Delta_lamda_dir, Delta_l=0.1, psi=1.) 
96 | sol_list = problem.unflatten_fn_sol_list(u_vec) 97 | if i % 10 == 0: 98 | vtk_path = os.path.join(vtk_dir, f'u{i:05d}.vtu') 99 | sol = sol_list[0] 100 | save_sol(problem.fes[0], np.hstack((sol, np.zeros((len(sol), 1)))), vtk_path) 101 | if lamda > 1.: 102 | break 103 | else: 104 | # Newton's solver does not converge to buckling configuration 105 | sol_list = solver(problem, solver_options={'umfpack_solver': {}}) 106 | sol = sol_list[0] 107 | vtk_path = os.path.join(vtk_dir, f'u.vtu') 108 | save_sol(problem.fes[0], np.hstack((sol, np.zeros((len(sol), 1)))), vtk_path) 109 | 110 | 111 | if __name__ == "__main__": 112 | example() 113 | -------------------------------------------------------------------------------- /applications/arc_length/beam2d_arc_length_force_driven.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import os 4 | import glob 5 | 6 | from jax_fem.problem import Problem 7 | from jax_fem.solver import solver, arc_length_solver_force_driven, get_q_vec 8 | from jax_fem.utils import save_sol 9 | from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh 10 | 11 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 12 | vtk_dir = os.path.join(output_dir, 'vtk') 13 | 14 | 15 | class HyperElasticityMain(Problem): 16 | def get_tensor_map(self): 17 | def psi(F): 18 | E = 1e3 19 | nu = 0.3 20 | mu = E / (2. * (1. + nu)) 21 | kappa = E / (3. * (1. - 2. * nu)) 22 | J = np.linalg.det(F) 23 | Jinv = J**(-2. / 2.) 24 | I1 = np.trace(F.T @ F) 25 | energy = (mu / 2.) * (Jinv * I1 - 2.) + (kappa / 2.) * (J - 1.)**2. 
26 | return energy 27 | 28 | P_fn = jax.grad(psi) 29 | 30 | def first_PK_stress(u_grad): 31 | I = np.eye(self.dim) 32 | F = u_grad + I 33 | P = P_fn(F) 34 | return P 35 | 36 | return first_PK_stress 37 | 38 | def get_surface_maps(self): 39 | def surface_map(u, x): 40 | # Some small noise to guide the arc-length solver 41 | return np.array([0., 1e-5]) 42 | return [surface_map] 43 | 44 | 45 | class HyperElasticityAux(Problem): 46 | def get_tensor_map(self): 47 | def first_PK_stress(u_grad): 48 | return np.zeros((self.dim, self.dim)) 49 | return first_PK_stress 50 | 51 | def get_surface_maps(self): 52 | def surface_map(u, x): 53 | return np.array([10., 0]) 54 | return [surface_map] 55 | 56 | 57 | def example(): 58 | ele_type = 'QUAD4' 59 | cell_type = get_meshio_cell_type(ele_type) 60 | Lx, Ly = 50., 2. 61 | 62 | meshio_mesh = rectangle_mesh(Nx=50, Ny=2, domain_x=Lx, domain_y=Ly) 63 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 64 | 65 | files = glob.glob(os.path.join(vtk_dir, f'*')) 66 | for f in files: 67 | os.remove(f) 68 | 69 | def left(point): 70 | return np.isclose(point[0], 0., atol=1e-5) 71 | 72 | def right(point): 73 | return np.isclose(point[0], Lx, atol=1e-5) 74 | 75 | def zero_dirichlet_val(point): 76 | return 0. 77 | 78 | def middle(point): 79 | return np.isclose(point[1], 0., atol=1e-5) & (point[0] > Lx/2. - 2.) & (point[0] < Lx/2. + 2.) 80 | 81 | location_fns_middle = [middle] 82 | 83 | dirichlet_bc_info = [[left]*2, [0, 1], [zero_dirichlet_val]*2] 84 | 85 | location_fns_right = [right] 86 | 87 | problem_main = HyperElasticityMain(mesh, vec=2, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns_middle) 88 | problem_aux = HyperElasticityAux(mesh, vec=2, dim=2, ele_type=ele_type, location_fns=location_fns_right) 89 | 90 | q_vec = get_q_vec(problem_aux) 91 | u_vec = np.zeros(problem_main.num_total_dofs_all_vars) 92 | lamda = 0. 
93 | Delta_u_vec_dir = np.zeros(problem_main.num_total_dofs_all_vars) 94 | Delta_lamda_dir = 0. 95 | 96 | for i in range(500): 97 | print(f"\n\nStep {i}, lamda = {lamda}") 98 | if i < 200: 99 | Delta_l = 0.1 100 | psi = 0.5 101 | else: 102 | Delta_l = 1. 103 | psi = 0.5 104 | 105 | u_vec, lamda, Delta_u_vec_dir, Delta_lamda_dir = arc_length_solver_force_driven(problem_main, u_vec, 106 | lamda, Delta_u_vec_dir, Delta_lamda_dir, q_vec, Delta_l, psi) 107 | sol_list = problem_main.unflatten_fn_sol_list(u_vec) 108 | if i % 10 == 0: 109 | vtk_path = os.path.join(vtk_dir, f'u{i:05d}.vtu') 110 | sol = sol_list[0] 111 | save_sol(problem_main.fes[0], np.hstack((sol, np.zeros((len(sol), 1)))), vtk_path) 112 | 113 | if lamda > 1.: 114 | break 115 | 116 | 117 | if __name__ == "__main__": 118 | example() 119 | -------------------------------------------------------------------------------- /applications/arc_length/beam2d_dynamic_relaxation_disp_driven.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import numpy as onp 4 | import os 5 | import glob 6 | 7 | from jax_fem.problem import Problem 8 | from jax_fem.solver import solver, dynamic_relax_solve 9 | from jax_fem.utils import save_sol 10 | from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh 11 | 12 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 13 | vtk_dir = os.path.join(output_dir, 'vtk') 14 | onp.random.seed(0) 15 | 16 | 17 | class HyperElasticity(Problem): 18 | def get_tensor_map(self): 19 | def psi(F): 20 | E = 1e3 21 | nu = 0.3 22 | mu = E / (2. * (1. + nu)) 23 | kappa = E / (3. * (1. - 2. * nu)) 24 | J = np.linalg.det(F) 25 | Jinv = J**(-2. / 2.) 26 | I1 = np.trace(F.T @ F) 27 | energy = (mu / 2.) * (Jinv * I1 - 2.) + (kappa / 2.) * (J - 1.)**2. 
28 | return energy 29 | 30 | P_fn = jax.grad(psi) 31 | 32 | def first_PK_stress(u_grad): 33 | I = np.eye(self.dim) 34 | F = u_grad + I 35 | P = P_fn(F) 36 | return P 37 | 38 | return first_PK_stress 39 | 40 | 41 | def get_surface_maps(self): 42 | def surface_map(u, x): 43 | # Some small noise to guide the dynamic relaxation solver 44 | return np.array([0., 1e-5]) 45 | return [surface_map] 46 | 47 | 48 | class HyperElasticityAux(Problem): 49 | def get_tensor_map(self): 50 | def psi(F): 51 | E = 1e3 52 | nu = 0.3 53 | mu = E / (2. * (1. + nu)) 54 | kappa = E / (3. * (1. - 2. * nu)) 55 | J = np.linalg.det(F) 56 | Jinv = J**(-2. / 2.) 57 | I1 = np.trace(F.T @ F) 58 | energy = (mu / 2.) * (Jinv * I1 - 2.) + (kappa / 2.) * (J - 1.)**2. 59 | return energy 60 | 61 | P_fn = jax.grad(psi) 62 | 63 | def first_PK_stress(u_grad): 64 | I = np.eye(self.dim) 65 | F = u_grad + I 66 | P = P_fn(F) 67 | return P 68 | 69 | return first_PK_stress 70 | 71 | 72 | def example(): 73 | ele_type = 'QUAD4' 74 | cell_type = get_meshio_cell_type(ele_type) 75 | Lx, Ly = 50., 2. 76 | 77 | meshio_mesh = rectangle_mesh(Nx=50, Ny=2, domain_x=Lx, domain_y=Ly) 78 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 79 | 80 | files = glob.glob(os.path.join(vtk_dir, f'*')) 81 | for f in files: 82 | os.remove(f) 83 | 84 | def left(point): 85 | return np.isclose(point[0], 0., atol=1e-5) 86 | 87 | def right(point): 88 | return np.isclose(point[0], Lx, atol=1e-5) 89 | 90 | def zero_dirichlet_val(point): 91 | return 0. 92 | 93 | def compressed_dirichlet_val(point): 94 | return -0.05*Lx 95 | 96 | def small_dirichlet_val(point): 97 | return 0. 98 | 99 | def middle(point): 100 | return np.isclose(point[1], 0., atol=1e-5) & (point[0] >= Lx/2. - 1.) & (point[0] <= Lx/2. + 1.) 
101 | 102 | location_fns = [middle] 103 | 104 | dirichlet_bc_info = [[left]*2 + [right]*2, [0, 1, 0, 1], [zero_dirichlet_val]*2 + [compressed_dirichlet_val, small_dirichlet_val]] 105 | 106 | problem = HyperElasticity(mesh, vec=2, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns) 107 | sol = dynamic_relax_solve(problem, tol=1e-8) 108 | vtk_path = os.path.join(output_dir, f'vtk/u.vtu') 109 | save_sol(problem.fes[0], np.hstack((sol, np.zeros((len(sol), 1)))), vtk_path) 110 | 111 | # The aux problem is to verify that our "noise" is actually very small. 112 | # The aux problem does not have any noise, but still converging to a similar configuration. 113 | problem_aux = HyperElasticityAux(mesh, vec=2, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info) 114 | sol_aux = dynamic_relax_solve(problem_aux, tol=1e-8, initial_guess=sol.reshape(-1)) 115 | vtk_path_aux = os.path.join(output_dir, f'vtk/u_aux.vtu') 116 | save_sol(problem_aux.fes[0], np.hstack((sol_aux, np.zeros((len(sol), 1)))), vtk_path_aux) 117 | 118 | 119 | if __name__ == "__main__": 120 | example() 121 | -------------------------------------------------------------------------------- /applications/arc_length/demo_analytical.py: -------------------------------------------------------------------------------- 1 | """ 2 | The code follows the MATLAB example: 3 | https://www.mathworks.com/matlabcentral/fileexchange/48643-demonstrations-of-newton-raphson-method-and-arc-length-method 4 | """ 5 | import jax 6 | import jax.numpy as np 7 | import numpy as onp 8 | import os 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def example(): 13 | def internal_force(x): 14 | return -x**2 + x 15 | 16 | def slope(x): 17 | return 1 - 2.*x 18 | 19 | def arc_length_solver(prev_u_vec, prev_lamda, q_vec): 20 | def newton_update_helper(x): 21 | A_fn = slope(x) 22 | res_vec = internal_force(x) 23 | return res_vec, A_fn 24 | 25 | psi = 1. 
26 | u_vec = prev_u_vec 27 | lamda = prev_lamda 28 | q_vec_mapped = q_vec 29 | 30 | tol = 1e-6 31 | res_val = 1. 32 | step = 0 33 | while (res_val > tol) or (step <= 1): 34 | res_vec, A_fn = newton_update_helper(u_vec) 35 | res_val = np.linalg.norm(res_vec + lamda*q_vec_mapped) 36 | print(f"\nInternal loop step = {step}, res_val = {res_val}") 37 | step += 1 38 | 39 | delta_u_bar = -(res_vec + lamda*q_vec_mapped)/A_fn 40 | delta_u_t = -q_vec_mapped/A_fn 41 | 42 | psi = 1. 43 | Delta_u = u_vec - prev_u_vec 44 | Delta_lamda = lamda - prev_lamda 45 | a1 = delta_u_t**2. + psi**2.*q_vec_mapped**2. 46 | a2 = 2.* (Delta_u + delta_u_bar)*delta_u_t + 2.*psi**2.*Delta_lamda*q_vec_mapped**2. 47 | a3 = (Delta_u + delta_u_bar)**2. + psi**2.*Delta_lamda**2.*q_vec_mapped**2. - Delta_l**2. 48 | print(f"a1 = {a1}, a2 = {a2}, a3 = {a3}") 49 | 50 | delta_lamda1 = (-a2 + np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1) 51 | delta_lamda2 = (-a2 - np.sqrt(a2**2. - 4.*a1*a3))/(2.*a1) 52 | print(f"delta_lamda1 = {delta_lamda1}, delta_lamda2 = {delta_lamda2}") 53 | 54 | delta_lamda = np.maximum(delta_lamda1, delta_lamda2) 55 | delta_u = delta_u_bar + delta_lamda * delta_u_t 56 | 57 | lamda = lamda + delta_lamda 58 | u_vec = u_vec + delta_u 59 | print(f"lamda = {lamda}, u_vec = {u_vec}") 60 | 61 | return u_vec, lamda 62 | 63 | q_vec = -1. 64 | u_vec = 0. 65 | lamda = 0. 
66 | u_vecs = [u_vec] 67 | lamdas = [lamda] 68 | # Delta_l = 0.65316 69 | Delta_l = 0.2 70 | 71 | num_arc_length_steps = 2 72 | for i in range(num_arc_length_steps): 73 | print(f"\n\n############################################################################") 74 | print(f"Arc length step = {i}, u_vec = {u_vec}, lambda = {lamda}") 75 | u_vec, lamda = arc_length_solver(u_vec, lamda, q_vec) 76 | u_vecs.append(u_vec) 77 | lamdas.append(lamda) 78 | 79 | u_vecs = onp.array(u_vecs) 80 | lamdas = onp.array(lamdas) 81 | 82 | fig = plt.figure(figsize=(10, 8)) 83 | ax = plt.gca() 84 | ax.cla() 85 | 86 | us = np.linspace(0, 1, 100) 87 | plt.plot(us, internal_force(us), color='black') 88 | 89 | for i in range(num_arc_length_steps + 1): 90 | circle1 = plt.Circle((u_vecs[i], lamdas[i]), Delta_l, color='red', fill=False) 91 | plt.gca().add_patch(circle1) 92 | 93 | circle2 = plt.Circle((u_vecs[i], lamdas[i]), 0.01, color='blue', fill=True) 94 | plt.gca().add_patch(circle2) 95 | 96 | ax.set_xlim((-0.4, 1.2)) 97 | ax.set_ylim((-0.3, 0.5)) 98 | ax.set_aspect('equal') 99 | 100 | 101 | plt.show() 102 | 103 | if __name__ == "__main__": 104 | example() 105 | -------------------------------------------------------------------------------- /applications/battery/input/data/results.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/data/results.mat -------------------------------------------------------------------------------- /applications/battery/input/data/sol_j_t2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/data/sol_j_t2.mat -------------------------------------------------------------------------------- /applications/battery/input/data/sol_macro.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/data/sol_macro.mat -------------------------------------------------------------------------------- /applications/battery/input/data/sol_macro_t1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/data/sol_macro_t1.mat -------------------------------------------------------------------------------- /applications/battery/input/data/sol_micro.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/data/sol_micro.mat -------------------------------------------------------------------------------- /applications/battery/input/data/sol_micro_t1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/data/sol_micro_t1.mat -------------------------------------------------------------------------------- /applications/battery/input/mesh/macro_mesh.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/mesh/macro_mesh.mat -------------------------------------------------------------------------------- /applications/battery/input/mesh/micro_mesh.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/battery/input/mesh/micro_mesh.mat 
-------------------------------------------------------------------------------- /applications/boundary_index/example.py: -------------------------------------------------------------------------------- 1 | """ 2 | This example shows how to specify boundary locations with point and/or index 3 | 4 | It addresses the Github issue https://github.com/deepmodeling/jax-fem/issues/20 5 | """ 6 | import jax 7 | import jax.numpy as np 8 | import os 9 | 10 | from jax_fem.problem import Problem 11 | from jax_fem.solver import solver 12 | from jax_fem.utils import save_sol 13 | from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh 14 | 15 | 16 | class Poisson(Problem): 17 | def get_tensor_map(self): 18 | return lambda x: x 19 | 20 | def get_surface_maps(self): 21 | def surface_map(u, x): 22 | return -np.array([np.sin(5.*x[0])]) 23 | 24 | return [surface_map, surface_map] 25 | 26 | ele_type = 'QUAD4' 27 | cell_type = get_meshio_cell_type(ele_type) 28 | Lx, Ly = 1., 1. 29 | meshio_mesh = rectangle_mesh(Nx=2, Ny=2, domain_x=Lx, domain_y=Ly) 30 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 31 | 32 | # We consider a problem having global DOF numbering and global cell numbering like the following: 33 | 34 | # 2-----5-----8 35 | # | | | 36 | # | (1) | (3) | 37 | # | | | 38 | # 1-----4-----7 39 | # | | | 40 | # | (0) | (2) | 41 | # | | | 42 | # 0-----3-----6 43 | 44 | # The local face indexing is as the following: 45 | 46 | # 3---[3]---2 47 | # | | 48 | # [1] [2] 49 | # | | 50 | # 0---[0]---1 51 | 52 | # You may define the right boundary point collection and the top boundary point collection as 53 | ind_set_right = np.array([6, 7, 8]) 54 | ind_set_top = np.array([2, 5, 8]) 55 | 56 | # Define boundary locations. 57 | def left(point): 58 | """ 59 | If one argument is passed, it is treated as "point". 
60 | """ 61 | return np.isclose(point[0], 0., atol=1e-5) 62 | 63 | def right(point, ind): 64 | """ 65 | If two arguments are passed, the first will be "point" and the second will be "point index". 66 | """ 67 | return np.isin(ind, ind_set_right) 68 | 69 | def bottom(point): 70 | return np.isclose(point[1], 0., atol=1e-5) 71 | 72 | def top(point, ind): 73 | return np.isin(ind, ind_set_top) 74 | 75 | def dirichlet_val_left(point): 76 | return 0. 77 | 78 | def dirichlet_val_right(point): 79 | return 0. 80 | 81 | location_fns1 = [left, right] 82 | value_fns = [dirichlet_val_left, dirichlet_val_right] 83 | vecs = [0, 0] 84 | dirichlet_bc_info = [location_fns1, vecs, value_fns] 85 | 86 | location_fns2 = [bottom, top] 87 | 88 | problem = Poisson(mesh=mesh, vec=1, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns2) 89 | 90 | print(f"\n\nlocation_fns1 is processed to generate Dirichlet node indices: \n{problem.fes[0].node_inds_list}") 91 | print(f"\nwhere node_inds_list[l][j] returns the jth selected node index in Dirichlet set l") 92 | 93 | print(f"\n\nlocation_fns2 is processed to generate boundary indices list: \n{problem.boundary_inds_list}") 94 | print(f"\nwhere boundary_inds_list[k][i, 0] returns the global cell index of the ith selected face of boundary subset k") 95 | print(f" boundary_inds_list[k][i, 1] returns the local face index of the ith selected face of boundary subset k") 96 | -------------------------------------------------------------------------------- /applications/crystal_plasticity/simple.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import numpy as onp 4 | import os 5 | import glob 6 | import matplotlib.pyplot as plt 7 | 8 | from jax_fem.solver import solver, ad_wrapper 9 | from jax_fem.generate_mesh import Mesh, box_mesh_gmsh, get_meshio_cell_type 10 | from jax_fem.utils import save_sol 11 | 12 | from 
applications.crystal_plasticity.models import CrystalPlasticity 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 15 | 16 | case_name = 'calibration' 17 | 18 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 19 | numpy_dir = os.path.join(data_dir, f'numpy/{case_name}') 20 | vtk_dir = os.path.join(data_dir, f'vtk/{case_name}') 21 | csv_dir = os.path.join(data_dir, f'csv/{case_name}') 22 | 23 | 24 | def problem(): 25 | class CrystalPlasticityModified(CrystalPlasticity): 26 | def set_params(self, all_params): 27 | disp, params = all_params 28 | self.internal_vars = params 29 | self.fes[0].dirichlet_bc_info[-1][-1] = get_dirichlet_top(disp) 30 | self.fes[0].update_Dirichlet_boundary_conditions(self.fes[0].dirichlet_bc_info) 31 | 32 | def targe_val(self, sol): 33 | tgt_volume = 1.01 34 | def det_fn(u_grad): 35 | F = u_grad + np.eye(self.dim) 36 | return np.linalg.det(F) 37 | 38 | u_grads = self.fes[0].sol_to_grad(sol) 39 | vmap_det_fn = jax.jit(jax.vmap(jax.vmap(det_fn))) 40 | crt_volume = np.sum(vmap_det_fn(u_grads) * self.fes[0].JxW) 41 | 42 | square_error = (crt_volume - tgt_volume)**2 43 | 44 | return square_error 45 | 46 | ele_type = 'HEX8' 47 | Nx, Ny, Nz = 1, 1, 1 48 | Lx, Ly, Lz = 1., 1., 1. 
49 | 50 | cell_type = get_meshio_cell_type(ele_type) 51 | meshio_mesh = box_mesh_gmsh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type) 52 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 53 | 54 | files = glob.glob(os.path.join(vtk_dir, f'*')) 55 | for f in files: 56 | os.remove(f) 57 | 58 | disps = np.linspace(0., 0.005, 11) 59 | ts = np.linspace(0., 0.5, 11) 60 | 61 | def corner(point): 62 | flag_x = np.isclose(point[0], 0., atol=1e-5) 63 | flag_y = np.isclose(point[1], 0., atol=1e-5) 64 | flag_z = np.isclose(point[2], 0., atol=1e-5) 65 | return np.logical_and(np.logical_and(flag_x, flag_y), flag_z) 66 | 67 | def bottom(point): 68 | return np.isclose(point[2], 0., atol=1e-5) 69 | 70 | def top(point): 71 | return np.isclose(point[2], Lz, atol=1e-5) 72 | 73 | def zero_dirichlet_val(point): 74 | return 0. 75 | 76 | def get_dirichlet_top(disp): 77 | def val_fn(point): 78 | return disp 79 | return val_fn 80 | 81 | dirichlet_bc_info = [[corner, corner, bottom, top], 82 | [0, 1, 2, 2], 83 | [zero_dirichlet_val, zero_dirichlet_val, zero_dirichlet_val, get_dirichlet_top(disps[0])]] 84 | 85 | quat = onp.array([[1, 0., 0., 0.]]) 86 | cell_ori_inds = onp.zeros(len(mesh.cells), dtype=onp.int32) 87 | problem = CrystalPlasticityModified(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, 88 | additional_info=(quat, cell_ori_inds)) 89 | fwd_pred = ad_wrapper(problem) 90 | 91 | def simulation(scale_d): 92 | params = problem.internal_vars 93 | results_to_save = [] 94 | for i in range(2): 95 | print(f"\nStep {i + 1} in {len(ts) - 1}, disp = {disps[i + 1]}") 96 | problem.dt = ts[i + 1] - ts[i] 97 | sol_list = fwd_pred([scale_d*disps[i + 1], params]) 98 | params = problem.update_int_vars_gp(sol_list[0], params) 99 | obj_val = problem.targe_val(sol_list[0]) 100 | print(f"obj_val = {obj_val}") 101 | return obj_val 102 | 103 | simulation(1.) 104 | 105 | grads = jax.grad(simulation)(1.) 
106 | print(f"grads = {grads}") 107 | 108 | 109 | if __name__ == "__main__": 110 | problem() 111 | -------------------------------------------------------------------------------- /applications/dendrite/README.md: -------------------------------------------------------------------------------- 1 | # Phase field method for dendrite growth 2 | 3 | ## Background 4 | 5 | This example implements both **implicit finite element method** and **explicit finite element method** to solve the dendrite growth problem in the phase field framework. We compare our results with an explicit finite difference code by [1] and the result from the original paper [2]. 6 | 7 | ## Results 8 | 9 |
10 |
11 |
12 |
14 |
15 |
16 |
19 | Implicit finite element (top left); Explicit finite element (top right) 20 |
21 |22 | Explicit finite difference (bottom left); Kobayashi paper (bottom right) 23 |
24 | 25 | ## References 26 | 27 | [1] https://drzgan.github.io/Python_CFD/Konayashi_1993-main/jax_version/kobayashi_aniso_jax_ZGAN-2.html 28 | 29 | [2] Kobayashi, Ryo. "Modeling and numerical simulations of dendritic crystal growth." Physica D: Nonlinear Phenomena 63, no. 3-4 (1993): 410-423. 30 | 31 | -------------------------------------------------------------------------------- /applications/dendrite/explicit_fd/input/json/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "J": 6, 3 | "K": 1.6, 4 | "T_eq": 1.0, 5 | "a": 0.01, 6 | "alpha": 0.9, 7 | "delta": 0.04, 8 | "dt": 0.0001, 9 | "eps_bar": 0.01, 10 | "gamma": 10.0, 11 | "hx": 0.03, 12 | "hy": 0.03, 13 | "nx": 300, 14 | "ny": 300, 15 | "t_OFF": 0.36, 16 | "tau": 0.0003 17 | } -------------------------------------------------------------------------------- /applications/dendrite/explicit_fem/input/json/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "J": 6, 3 | "K": 1.6, 4 | "T_eq": 1.0, 5 | "a": 0.01, 6 | "alpha": 0.9, 7 | "delta": 0.04, 8 | "dt": 0.0002, 9 | "eps_bar": 0.01, 10 | "gamma": 10.0, 11 | "hx": 0.03, 12 | "hy": 0.03, 13 | "nx": 300, 14 | "ny": 300, 15 | "t_OFF": 0.36, 16 | "tau": 0.0003 17 | } -------------------------------------------------------------------------------- /applications/dendrite/implicit_fem/input/json/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "J": 6, 3 | "K": 1.6, 4 | "T_eq": 1.0, 5 | "a": 0.01, 6 | "alpha": 0.9, 7 | "delta": 0.04, 8 | "dt": 0.001, 9 | "eps_bar": 0.01, 10 | "gamma": 10.0, 11 | "hx": 0.03, 12 | "hy": 0.03, 13 | "nx": 300, 14 | "ny": 300, 15 | "t_OFF": 0.36, 16 | "tau": 0.0003 17 | } -------------------------------------------------------------------------------- /applications/dendrite/materials/explicit_fd.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/dendrite/materials/explicit_fd.png -------------------------------------------------------------------------------- /applications/dendrite/materials/explicit_fem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/dendrite/materials/explicit_fem.png -------------------------------------------------------------------------------- /applications/dendrite/materials/implicit_fem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/dendrite/materials/implicit_fem.png -------------------------------------------------------------------------------- /applications/dendrite/materials/paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/dendrite/materials/paper.png -------------------------------------------------------------------------------- /applications/dynamic_relaxation/beam3d.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import os 4 | import glob 5 | 6 | from jax_fem.problem import Problem 7 | from jax_fem.solver import solver, dynamic_relax_solve 8 | from jax_fem.utils import save_sol 9 | from jax_fem.generate_mesh import box_mesh_gmsh, get_meshio_cell_type, Mesh 10 | 11 | 12 | class HyperElasticity(Problem): 13 | def get_tensor_map(self): 14 | def psi(F): 15 | E = 10. 16 | nu = 0.3 17 | mu = E/(2.*(1. + nu)) 18 | kappa = E/(3.*(1. - 2.*nu)) 19 | J = np.linalg.det(F) 20 | Jinv = J**(-2./3.) 
21 | I1 = np.trace(F.T @ F) 22 | energy = (mu/2.)*(Jinv*I1 - 3.) + (kappa/2.) * (J - 1.)**2. 23 | return energy 24 | P_fn = jax.grad(psi) 25 | 26 | def first_PK_stress(u_grad): 27 | I = np.eye(self.dim) 28 | F = u_grad + I 29 | P = P_fn(F) 30 | return P 31 | return first_PK_stress 32 | 33 | 34 | data_dir = os.path.join(os.path.dirname(__file__), 'output') 35 | files = glob.glob(os.path.join(data_dir, f'vtk/*')) 36 | for f in files: 37 | os.remove(f) 38 | 39 | ele_type = 'HEX8' 40 | cell_type = get_meshio_cell_type(ele_type) 41 | Lx, Ly, Lz = 20., 1., 1. 42 | meshio_mesh = box_mesh_gmsh(Nx=100, Ny=5, Nz=5, Lx=Lx, Ly=Ly, Lz=Lz, data_dir=data_dir, ele_type=ele_type) 43 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 44 | 45 | def left(point): 46 | return np.isclose(point[0], 0., atol=1e-5) 47 | 48 | def right(point): 49 | return np.isclose(point[0], Lx, atol=1e-5) 50 | 51 | def left_dirichlet_val_x1(point): 52 | return 0. 53 | 54 | def left_dirichlet_val_x2(point): 55 | return 0. 56 | 57 | def left_dirichlet_val_x3(point): 58 | return 0. 
59 | 60 | def right_dirichlet_val_x1(point): 61 | return -0.2*Lx 62 | 63 | def right_dirichlet_val_x2(point): 64 | return 0.005*Ly 65 | 66 | def right_dirichlet_val_x3(point): 67 | return 0.*Lz 68 | 69 | 70 | dirichlet_bc_info = [[left]*3 + [right]*3, 71 | [0, 1, 2]*2, 72 | [left_dirichlet_val_x1, left_dirichlet_val_x2, left_dirichlet_val_x3, 73 | right_dirichlet_val_x1, right_dirichlet_val_x2, right_dirichlet_val_x3]] 74 | 75 | 76 | problem = HyperElasticity(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info) 77 | sol = dynamic_relax_solve(problem) 78 | 79 | vtk_path = os.path.join(data_dir, f'vtk/u.vtu') 80 | save_sol(problem.fes[0], sol, vtk_path) 81 | -------------------------------------------------------------------------------- /applications/dynamic_relaxation/cellular_solid.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import os 4 | import glob 5 | import meshio 6 | 7 | from jax_fem.problem import Problem 8 | from jax_fem.solver import solver, dynamic_relax_solve 9 | from jax_fem.utils import save_sol 10 | from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh 11 | 12 | 13 | class Elasticity(Problem): 14 | def get_tensor_map(self): 15 | def psi(F_2d): 16 | F = np.array([[F_2d[0, 0], F_2d[0, 1], 0.], 17 | [F_2d[1, 0], F_2d[1, 1], 0.], 18 | [0., 0., 1.]]) 19 | E = 70.e3 20 | nu = 0.3 21 | mu = E/(2.*(1. + nu)) 22 | kappa = E/(3.*(1. - 2.*nu)) 23 | J = np.linalg.det(F) 24 | I1 = np.trace(F.T @ F) 25 | Jinv = J**(-2./3.) 26 | energy = (mu/2.)*(Jinv*I1 - 3.) + (kappa/2.) * (J - 1.)**2. 
27 | return energy 28 | 29 | P_fn = jax.grad(psi) 30 | 31 | def first_PK_stress(u_grad): 32 | I = np.eye(self.dim) 33 | F = u_grad + I 34 | P = P_fn(F) 35 | return P 36 | return first_PK_stress 37 | 38 | 39 | def simulation(): 40 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 41 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 42 | 43 | files = glob.glob(os.path.join(output_dir, f'vtk/*')) 44 | for f in files: 45 | os.remove(f) 46 | 47 | ele_type = 'TRI6' 48 | cell_type = get_meshio_cell_type(ele_type) 49 | 50 | meshio_mesh = meshio.read(os.path.join(input_dir, f"abaqus/cellular_solid.inp")) 51 | meshio_mesh.points[:, 0] -= np.min(meshio_mesh.points[:, 0]) 52 | meshio_mesh.points[:, 1] -= np.min(meshio_mesh.points[:, 1]) 53 | meshio_mesh.write(os.path.join(output_dir, 'vtk/mesh.vtu')) 54 | 55 | Lx, Ly = np.max(meshio_mesh.points[:, 0]), np.max(meshio_mesh.points[:, 1]) 56 | print(f"Lx={Lx}, Ly={Ly}") 57 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 58 | 59 | def top(point): 60 | return np.isclose(point[1], Ly, atol=1e-5) 61 | 62 | def bottom(point): 63 | return np.isclose(point[1], 0., atol=1e-5) 64 | 65 | def dirichlet_val_bottom(point): 66 | return 0. 
67 | 68 | def get_dirichlet_top(disp): 69 | def val_fn(point): 70 | return disp 71 | return val_fn 72 | 73 | disps = -0.15*Ly*np.linspace(1., 1., 1) 74 | 75 | location_fns = [bottom, bottom, top, top] 76 | value_fns = [dirichlet_val_bottom, dirichlet_val_bottom, dirichlet_val_bottom, get_dirichlet_top(disps[0])] 77 | vecs = [0, 1, 0, 1] 78 | 79 | dirichlet_bc_info = [location_fns, vecs, value_fns] 80 | 81 | problem = Elasticity(mesh, ele_type=ele_type, vec=2, dim=2, dirichlet_bc_info=dirichlet_bc_info) 82 | 83 | for i, disp in enumerate(disps): 84 | print(f"\nStep {i + 1} in {len(disps)}, disp = {disp}") 85 | dirichlet_bc_info[-1][-1] = get_dirichlet_top(disp) 86 | problem.fes[0].update_Dirichlet_boundary_conditions(dirichlet_bc_info) 87 | sol = dynamic_relax_solve(problem, tol=1e-6) 88 | vtk_path = os.path.join(output_dir, f'vtk/u_{i + 1:03d}.vtu') 89 | save_sol(problem.fes[0], np.hstack((sol, np.zeros((len(sol), 1)))), vtk_path) 90 | 91 | 92 | if __name__=="__main__": 93 | simulation() 94 | -------------------------------------------------------------------------------- /applications/forming/model.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | 4 | from jax_fem.problem import Problem 5 | 6 | 7 | class Plasticity(Problem): 8 | def custom_init(self): 9 | self.fe = self.fes[0] 10 | self.F_old = np.repeat(np.repeat(np.eye(self.dim)[None, None, :, :], len(self.fe.cells), axis=0), self.fe.num_quads, axis=1) 11 | self.Be_old = np.array(self.F_old) 12 | self.alpha_old = np.zeros((len(self.fe.cells), self.fe.num_quads)) 13 | self.internal_vars = [self.F_old, self.Be_old, self.alpha_old] 14 | 15 | def get_tensor_map(self): 16 | tensor_map, _, _ = self.get_maps() 17 | return tensor_map 18 | 19 | def get_maps(self): 20 | K = 164.e3 21 | G = 80.e3 22 | H1 = 18. 23 | sig0 = 400. 
24 | 25 | def get_partial_tensor_map(F_old, be_bar_old, alpha_old): 26 | def first_PK_stress(u_grad): 27 | _, _, tau = return_map(u_grad) 28 | F = u_grad + np.eye(self.dim) 29 | P = tau @ np.linalg.inv(F).T 30 | return P 31 | 32 | def update_int_vars(u_grad): 33 | be_bar, alpha, _ = return_map(u_grad) 34 | F = u_grad + np.eye(self.dim) 35 | return F, be_bar, alpha 36 | 37 | def compute_cauchy_stress(u_grad): 38 | F = u_grad + np.eye(self.dim) 39 | J = np.linalg.det(F) 40 | P = first_PK_stress(u_grad) 41 | sigma = 1./J*P @ F.T 42 | return sigma 43 | 44 | def get_tau(F, be_bar): 45 | J = np.linalg.det(F) 46 | tau = 0.5*K*(J**2 - 1)*np.eye(self.dim) + G*deviatoric(be_bar) 47 | return tau 48 | 49 | def deviatoric(A): 50 | return A - 1./self.dim*np.trace(A)*np.eye(self.dim) 51 | 52 | def return_map(u_grad): 53 | F = u_grad + np.eye(self.dim) 54 | F_inv = np.linalg.inv(F) 55 | F_old_inv = np.linalg.inv(F_old) 56 | f = F @ F_old_inv 57 | f_bar = np.linalg.det(f)**(-1./3.)*f 58 | # be_bar_trial = f @ be_bar_old @ f.T # Seems that there is a bug here, discovered by Jiachen; should be f_bar @ be_bar_old @ f_bar.T 59 | be_bar_trial = f_bar @ be_bar_old @ f_bar.T 60 | s_trial = G*deviatoric(be_bar_trial) 61 | yield_f_trial = np.linalg.norm(s_trial) - np.sqrt(2./3.)*(sig0 + H1*alpha_old) 62 | 63 | def elastic_loading(): 64 | be_bar = be_bar_trial 65 | alpha = alpha_old 66 | tau = get_tau(F, be_bar) 67 | return be_bar, alpha, tau 68 | 69 | def plastic_loading(): 70 | Ie_bar = 1./3.*np.trace(be_bar_trial) 71 | G_bar = Ie_bar*G 72 | Delta_gamma = (yield_f_trial/(2.*G_bar))/(1. 
+ H1/(3.*G_bar)) 73 | direction = s_trial/np.linalg.norm(s_trial) 74 | s = s_trial - 2.*G_bar*Delta_gamma * direction 75 | alpha = alpha_old + np.sqrt(2./3.)*Delta_gamma 76 | be_bar = s/G + Ie_bar*np.eye(self.dim) 77 | tau = get_tau(F, be_bar) 78 | return be_bar, alpha, tau 79 | 80 | return jax.lax.cond(yield_f_trial < 0., elastic_loading, plastic_loading) 81 | 82 | return first_PK_stress, update_int_vars, compute_cauchy_stress 83 | 84 | def tensor_map(u_grad, F_old, Be_old, alpha_old): 85 | first_PK_stress, _, _ = get_partial_tensor_map(F_old, Be_old, alpha_old) 86 | return first_PK_stress(u_grad) 87 | 88 | def update_int_vars_map(u_grad, F_old, Be_old, alpha_old): 89 | _, update_int_vars, _ = get_partial_tensor_map(F_old, Be_old, alpha_old) 90 | return update_int_vars(u_grad) 91 | 92 | def compute_cauchy_stress_map(u_grad, F_old, Be_old, alpha_old): 93 | _, _, compute_cauchy_stress = get_partial_tensor_map(F_old, Be_old, alpha_old) 94 | return compute_cauchy_stress(u_grad) 95 | 96 | return tensor_map, update_int_vars_map, compute_cauchy_stress_map 97 | 98 | def update_int_vars_gp(self, sol, int_vars): 99 | _, update_int_vars_map, _ = self.get_maps() 100 | vmap_update_int_vars_map = jax.jit(jax.vmap(jax.vmap(update_int_vars_map))) 101 | # (num_cells, 1, num_nodes, vec, 1) * (num_cells, num_quads, num_nodes, 1, self.dim) -> (num_cells, num_quads, num_nodes, vec, self.dim) 102 | u_grads = np.take(sol, self.fe.cells, axis=0)[:, None, :, :, None] * self.fe.shape_grads[:, :, :, None, :] 103 | u_grads = np.sum(u_grads, axis=2) # (num_cells, num_quads, vec, self.dim) 104 | updated_int_vars = vmap_update_int_vars_map(u_grads, *int_vars) 105 | return updated_int_vars 106 | 107 | def compute_stress(self, sol, int_vars): 108 | _, _, compute_cauchy_stress = self.get_maps() 109 | vmap_compute_cauchy_stress = jax.jit(jax.vmap(jax.vmap(compute_cauchy_stress))) 110 | # (num_cells, 1, num_nodes, vec, 1) * (num_cells, num_quads, num_nodes, 1, self.dim) -> (num_cells, num_quads, 
num_nodes, vec, self.dim) 111 | u_grads = np.take(sol, self.fe.cells, axis=0)[:, None, :, :, None] * self.fe.shape_grads[:, :, :, None, :] 112 | u_grads = np.sum(u_grads, axis=2) # (num_cells, num_quads, vec, self.dim) 113 | sigma = vmap_compute_cauchy_stress(u_grads, *int_vars) 114 | return sigma 115 | -------------------------------------------------------------------------------- /applications/forming/single_cell.py: -------------------------------------------------------------------------------- 1 | """Reference 2 | Simo, Juan C., and Thomas JR Hughes. Computational inelasticity. Vol. 7. Springer Science & Business Media, 2006. 3 | Chapter 9: Phenomenological Plasticity Models 4 | """ 5 | import jax 6 | import jax.numpy as np 7 | import jax.flatten_util 8 | import os 9 | import glob 10 | import matplotlib.pyplot as plt 11 | 12 | from jax_fem.problem import Problem 13 | from jax_fem.solver import solver 14 | from jax_fem.utils import save_sol 15 | from jax_fem.generate_mesh import box_mesh_gmsh, get_meshio_cell_type, Mesh 16 | 17 | from applications.forming.model import Plasticity 18 | 19 | def simulation(): 20 | 21 | class SingleCell(Plasticity): 22 | def set_params(self, params): 23 | int_vars, scale = params 24 | self.internal_vars = int_vars 25 | self.fe.dirichlet_bc_info[-1][-1] = get_dirichlet_top(scale) 26 | self.fe.update_Dirichlet_boundary_conditions(self.fe.dirichlet_bc_info) 27 | 28 | ele_type = 'HEX8' 29 | cell_type = get_meshio_cell_type(ele_type) 30 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 31 | vtk_dir = os.path.join(data_dir, 'vtk') 32 | 33 | files = glob.glob(os.path.join(vtk_dir, f'*')) 34 | for f in files: 35 | os.remove(f) 36 | 37 | Lx, Ly, Lz = 1., 1., 1. 
38 | meshio_mesh = box_mesh_gmsh(Nx=1, Ny=1, Nz=1, Lx=Lx, Ly=Ly, Lz=Lz, data_dir=data_dir, ele_type=ele_type) 39 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 40 | 41 | def top(point): 42 | return np.isclose(point[2], Lz, atol=1e-5) 43 | 44 | def bottom(point): 45 | return np.isclose(point[2], 0., atol=1e-5) 46 | 47 | def get_dirichlet_top(scale): 48 | def val_fn(point): 49 | z_disp = scale*Lz 50 | return z_disp 51 | return val_fn 52 | 53 | def dirichlet_val_bottom(point): 54 | return 0. 55 | 56 | scales = 0.01*np.hstack((np.linspace(0., 1., 11), np.linspace(1, 0., 11))) 57 | 58 | location_fns = [bottom, top] 59 | vecs = [2, 2] 60 | value_fns = [dirichlet_val_bottom, get_dirichlet_top(0.)] 61 | dirichlet_bc_info = [location_fns, vecs, value_fns] 62 | 63 | problem = SingleCell(mesh, ele_type=ele_type, vec=3, dim=3, dirichlet_bc_info=dirichlet_bc_info) 64 | 65 | sol_list = [np.zeros(((problem.fe.num_total_nodes, problem.fe.vec)))] 66 | 67 | int_vars = problem.internal_vars 68 | 69 | for i, scale in enumerate(scales): 70 | print(f"\nStep {i} in {len(scales)}, scale = {scale}") 71 | problem.set_params([int_vars, scale]) 72 | sol_list = solver(problem, solver_options={'initial_guess': sol_list}) 73 | int_vars_copy = int_vars 74 | int_vars = problem.update_int_vars_gp(sol_list[0], int_vars) 75 | sigmas = problem.compute_stress(sol_list[0], int_vars_copy).mean(axis=1) 76 | print(f"max alpha = \n{np.max(int_vars[-1])}") 77 | print(sigmas[0]) 78 | vtk_path = os.path.join(vtk_dir, f'u_{i:03d}.vtu') 79 | save_sol(problem.fe, sol_list[0], vtk_path, cell_infos=[('s_norm', np.linalg.norm(sigmas, axis=(1, 2)))]) 80 | 81 | 82 | if __name__=="__main__": 83 | simulation() 84 | -------------------------------------------------------------------------------- /applications/forming/thin_plate.py: -------------------------------------------------------------------------------- 1 | """Reference 2 | Simo, Juan C., and Thomas JR Hughes. Computational inelasticity. 
Vol. 7. Springer Science & Business Media, 2006. 3 | Chapter 9: Phenomenological Plasticity Models 4 | 5 | Line search method is required! 6 | """ 7 | import jax 8 | import jax.numpy as np 9 | import jax.flatten_util 10 | import os 11 | import glob 12 | import matplotlib.pyplot as plt 13 | 14 | from jax_fem.problem import Problem 15 | from jax_fem.solver import solver 16 | from jax_fem.utils import save_sol 17 | from jax_fem.generate_mesh import box_mesh_gmsh, get_meshio_cell_type, Mesh 18 | 19 | from applications.forming.model import Plasticity 20 | 21 | 22 | def simulation(): 23 | 24 | class ThinPlate(Plasticity): 25 | def set_params(self, params): 26 | int_vars, scale = params 27 | self.internal_vars = int_vars 28 | self.fe.dirichlet_bc_info[-1][-1] = get_dirichlet_top(scale) 29 | self.fe.update_Dirichlet_boundary_conditions(self.fe.dirichlet_bc_info) 30 | 31 | ele_type = 'HEX8' 32 | cell_type = get_meshio_cell_type(ele_type) 33 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 34 | vtk_dir = os.path.join(data_dir, 'vtk') 35 | 36 | files = glob.glob(os.path.join(vtk_dir, f'*')) 37 | for f in files: 38 | os.remove(f) 39 | 40 | Lx, Ly, Lz = 10., 10., 0.25 41 | meshio_mesh = box_mesh_gmsh(Nx=40, Ny=40, Nz=1, Lx=Lx, Ly=Ly, Lz=Lz, data_dir=data_dir, ele_type=ele_type) 42 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 43 | 44 | def walls(point): 45 | left = np.isclose(point[0], 0., atol=1e-5) 46 | right = np.isclose(point[0], Lx, atol=1e-5) 47 | front = np.isclose(point[1], 0., atol=1e-5) 48 | back = np.isclose(point[1], Ly, atol=1e-5) 49 | return left | right | front | back 50 | 51 | def top(point): 52 | return np.isclose(point[2], Lz, atol=1e-5) 53 | 54 | def dirichlet_val(point): 55 | return 0. 56 | 57 | def get_dirichlet_top(scale): 58 | def val_fn(point): 59 | x, y = point[0], point[1] 60 | sdf = np.min(np.array([np.abs(x), np.abs(Lx - x), np.abs(y), np.abs(Ly - y)])) 61 | scaled_sdf = sdf/(0.5*np.minimum(Lx, Ly)) 62 | alpha = 3.  # NOTE(review): alpha is never used. The z_disp expression on line 64 simplifies algebraically to -scale*Lx*(scaled_sdf + EPS), i.e. a linear ramp in the scaled distance to the walls. If a smooth step such as 1/(1 + (1/s - 1)**alpha) was intended, the exponent is missing; confirm before relying on the load profile.
63 | EPS = 1e-10 64 | z_disp = -scale*Lx*(1./(1. + (1./(scaled_sdf + EPS) - 1.))) 65 | return z_disp 66 | return val_fn 67 | 68 | scales = 0.2*np.hstack((np.linspace(0., 1., 5), np.linspace(1, 0., 5))) 69 | 70 | location_fns = [walls]*3 + [top] 71 | value_fns = [dirichlet_val]*3 + [get_dirichlet_top(0.)] 72 | vecs = [0, 1, 2, 2] 73 | dirichlet_bc_info = [location_fns, vecs, value_fns] 74 | 75 | problem = ThinPlate(mesh, ele_type=ele_type, vec=3, dim=3, dirichlet_bc_info=dirichlet_bc_info) 76 | sol_list = [np.zeros(((problem.fe.num_total_nodes, problem.fe.vec)))] 77 | 78 | int_vars = problem.internal_vars 79 | 80 | for i, scale in enumerate(scales): 81 | print(f"\nStep {i} in {len(scales)}, scale = {scale}") 82 | problem.set_params([int_vars, scale]) 83 | 84 | # The line search method is necessary to get a converged solution. 85 | sol_list = solver(problem, solver_options={'initial_guess': sol_list, 'line_search_flag': True}) 86 | 87 | int_vars_copy = int_vars 88 | int_vars = problem.update_int_vars_gp(sol_list[0], int_vars) 89 | sigmas = problem.compute_stress(sol_list[0], int_vars_copy).mean(axis=1) 90 | print(f"max alpha = \n{np.max(int_vars[-1])}") 91 | print(sigmas[0]) 92 | vtk_path = os.path.join(vtk_dir, f'u_{i:03d}.vtu') 93 | save_sol(problem.fe, sol_list[0], vtk_path, cell_infos=[('s_norm', np.linalg.norm(sigmas, axis=(1, 2)))]) 94 | 95 | 96 | if __name__=="__main__": 97 | simulation() 98 | -------------------------------------------------------------------------------- /applications/outdated/aesthetic/README.md: -------------------------------------------------------------------------------- 1 | # About 2 | 3 | Download model from [here](https://github.com/nicholasjng/jax-styletransfer) -------------------------------------------------------------------------------- /applications/outdated/aesthetic/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | parser = 
argparse.ArgumentParser() 5 | parser.parse_args() 6 | 7 | parser.add_argument("--pooling", type=str, default="max", help="Pooling method to use.") 8 | parser.add_argument("--num_steps", type=int, default=500, help="Number of training steps.") 9 | parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate of the Adam optimizer.") 10 | parser.add_argument("--image_size", type=int, default=512, help="Target size of the images in pixels.") 11 | parser.add_argument("--save_image_every", type=int, default=1, help="Saves the image every n steps to monitor progress.") 12 | 13 | parser.add_argument("--content_weight", type=float, default=1., help="Content loss weight.") 14 | parser.add_argument("--style_weight", type=float, default=1., help="Style loss weight.") 15 | 16 | output_path = os.path.join(os.path.dirname(__file__), 'output') 17 | parser.add_argument("--output_path", type=str, default=output_path, help="Output directory to save the styled images to.") 18 | # parser.add_argument("--content_layers", default=['conv_4'], 19 | # help="Names of network layers for which to capture content loss.") 20 | parser.add_argument("--content_layers", default={"conv_14": 0.1}, 21 | help="Names of network layers for which to capture content loss.") 22 | 23 | parser.add_argument("--style_layers", default={"conv_3": 1e1, "conv_5": 1e1}, 24 | help="Names of network layers for which to capture style loss.") 25 | 26 | 27 | # parser.add_argument("--style_layers", default={"conv_1": 1, 28 | # "conv_2": 1, 29 | # "conv_3": 1, 30 | # "conv_4": 1, 31 | # "conv_5": 1, 32 | # "conv_6": 1, 33 | # "conv_7": 1, 34 | # "conv_8": 1, 35 | # "conv_9": 1, 36 | # "conv_10": 1, 37 | # "conv_11": 1, 38 | # "conv_12": 1, 39 | # "conv_13": 1, 40 | # "conv_14": 1, 41 | # "conv_15": 1, 42 | # "conv_16": 1}, 43 | # help="Names of network layers for which to capture style loss.") 44 | 45 | parser.add_argument("--Lx", type=float, default=1., help="Length of domain.") 46 | 
parser.add_argument("--Ly", type=float, default=1., help="Width of domain.") 47 | parser.add_argument("--Nx", type=int, default=200, help="Number of elements along x-direction") 48 | parser.add_argument("--Ny", type=int, default=200, help="Number of elements along y-direction.") 49 | 50 | args = parser.parse_args() 51 | 52 | # TODO 53 | class bcolors: 54 | HEADER = '\033[95m' 55 | OKBLUE = '\033[94m' 56 | OKCYAN = '\033[96m' 57 | OKGREEN = '\033[92m' 58 | WARNING = '\033[93m' 59 | FAIL = '\033[91m' 60 | ENDC = '\033[0m' 61 | BOLD = '\033[1m' 62 | UNDERLINE = '\033[4m' 63 | -------------------------------------------------------------------------------- /applications/outdated/aesthetic/debug.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax.config import config 3 | import jax.numpy as np 4 | 5 | config.update("jax_enable_x64", True) 6 | 7 | a = np.array(1.) 8 | print(a.dtype) 9 | 10 | config.update("jax_enable_x64", False) 11 | 12 | 13 | b = np.array(1.) 
14 | print(b.dtype) 15 | 16 | 17 | c = a + b 18 | print(c.dtype) 19 | 20 | 21 | 22 | 23 | config.update("jax_enable_x64", True) 24 | 25 | 26 | print(a.dtype) 27 | print(b.dtype) 28 | c = a + b 29 | print(c.dtype) 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /applications/outdated/aesthetic/image_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import haiku as hk 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | from PIL import Image, ImageOps 8 | from jax import tree_util 9 | 10 | 11 | # TODO: Make target size a tuple for controlling aspect ratio 12 | def load_image(fp: str, img_type: str, target_size: int = 512, dtype=None, reverse = False): 13 | if not os.path.exists(fp): 14 | raise ValueError(f"File {fp} does not exist.") 15 | 16 | print(f'Loading {img_type} image...') 17 | 18 | image = Image.open(fp) 19 | 20 | image = ImageOps.grayscale(image).convert("RGB") 21 | 22 | image.save(fp[:-3] + 'jpg') 23 | 24 | image = image.resize((target_size, target_size)) 25 | 26 | image = image.rotate(-90) 27 | 28 | image = jnp.array(image, dtype=dtype) 29 | image = image / 255. 30 | 31 | image = np.where(image < 0.5, 0., 1.) 32 | 33 | image = jnp.clip(image, 0., 1.) 34 | image = jnp.expand_dims(jnp.moveaxis(image, -1, 0), 0) 35 | 36 | if reverse: 37 | image = 1. - image 38 | 39 | print(f"{img_type.capitalize()} image loaded successfully. " 40 | f"Shape: {image.shape}") 41 | 42 | return image 43 | 44 | 45 | def save_image(params: hk.Params, out_fp: str, reverse = False): 46 | im_data = tree_util.tree_leaves(params)[0] 47 | # clip values to avoid overflow problems in uint8 conversion 48 | im_data = jnp.clip(im_data, 0., 1.) 49 | 50 | if reverse: 51 | im_data = 1. 
- im_data 52 | 53 | # undo transformation block, squeeze off the batch dimension 54 | image: np.ndarray = np.squeeze(np.asarray(im_data)) 55 | image = image * 255 56 | image = image.astype(np.uint8) 57 | image = np.moveaxis(image, 0, -1) 58 | 59 | # TODO: This needs to change for a tiled image 60 | im = Image.fromarray(image, mode="RGB") 61 | 62 | os.makedirs(os.path.dirname(out_fp), exist_ok=True) 63 | im.save(out_fp) 64 | 65 | 66 | def checkpoint(params: hk.Params, out_dir: str, filename: str, reverse = False): 67 | """Saves the image at a checkpoint given by step.""" 68 | out_fp = os.path.join(out_dir, filename) 69 | 70 | save_image(params, out_fp, reverse) 71 | -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/contents/dancing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/contents/dancing.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/contents/structure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/contents/structure.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/contents/tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/contents/tree.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/calligraphy.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/calligraphy.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/circles.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/circles.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/circles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/circles.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/circles_ppt_made.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/circles_ppt_made.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/moha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/moha.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/picasso.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/picasso.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/square.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/square.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/square.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/strip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/strip.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/tree.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/tree.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/tree.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/tree.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/voronoi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/voronoi.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/voronoi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/voronoi.png -------------------------------------------------------------------------------- /applications/outdated/aesthetic/input/styles/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/outdated/aesthetic/input/styles/wave.jpg -------------------------------------------------------------------------------- /applications/outdated/aesthetic/modules.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import haiku as hk 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | # ImageNet statistics 8 | imagenet_mean = jnp.array([0.485, 0.456, 0.406]) 9 | imagenet_std = jnp.array([0.229, 0.224, 0.225]) 10 | 11 | 12 | def gram_matrix(x: jnp.ndarray): 13 | """Computes Gram Matrix of an input array x.""" 14 | # N-C-H-W format 15 | # TODO: Refactor this to compute the Gram matrix of a feature map 16 | # and then apply jax.vmap on the batch dimension 17 | n, c, h, w = x.shape 18 | 19 | 
assert n == 1, "mini-batch has to be singular right now" 20 | 21 | features = jnp.reshape(x, (n * c, h * w)) 22 | 23 | return jnp.dot(features, features.T) / (n * c * h * w) 24 | 25 | 26 | class StyleLoss(hk.Module): 27 | """Identity layer capturing the style loss between input and target.""" 28 | 29 | def __init__(self, target, name: Optional[str] = None, weight = 0.): 30 | super(StyleLoss, self).__init__(name=name) 31 | self.target_g = jax.lax.stop_gradient(gram_matrix(target)) 32 | self.weight = weight 33 | 34 | def __call__(self, x: jnp.ndarray): 35 | g = gram_matrix(x) 36 | 37 | style_loss = self.weight*jnp.mean(jnp.square(g - self.target_g)) 38 | hk.set_state("style_loss", style_loss) 39 | 40 | return x 41 | 42 | 43 | class ContentLoss(hk.Module): 44 | """Identity layer capturing the content loss between input and target.""" 45 | 46 | def __init__(self, target, name: Optional[str] = None, weight = 0.): 47 | super(ContentLoss, self).__init__(name=name) 48 | self.target = jax.lax.stop_gradient(target) 49 | self.weight = weight 50 | 51 | def __call__(self, x: jnp.ndarray): 52 | content_loss = self.weight*jnp.mean(jnp.square(x - self.target)) 53 | hk.set_state("content_loss", content_loss) 54 | 55 | return x 56 | 57 | 58 | class Normalization(hk.Module): 59 | # create a module normalizing the input image 60 | # so we can easily put it into a hk.Sequential 61 | def __init__(self, 62 | image: jnp.ndarray, 63 | mean: jnp.ndarray, 64 | std: jnp.ndarray, 65 | name: Optional[str] = None): 66 | super(Normalization, self).__init__(name=name) 67 | 68 | # save image to make it a trainable parameter 69 | self.image = image 70 | 71 | # expand mean and std to make them [C x 1 x 1] so that they can 72 | # directly work with image Tensor of shape [N x C x H x W]. 
73 | self.mean = jnp.expand_dims(mean, (1, 2)) 74 | self.std = jnp.expand_dims(std, (1, 2)) 75 | 76 | def __call__(self, x: jnp.ndarray, is_training: bool = False): 77 | # throw away the input and (re-)use the tracked parameter instead 78 | # this assures that the image is actually styled 79 | img = hk.get_parameter("image", 80 | shape=self.image.shape, 81 | dtype=self.image.dtype, 82 | init=hk.initializers.Constant(self.image)) 83 | 84 | if is_training: 85 | out = img 86 | else: 87 | out = x 88 | 89 | return (out - self.mean) / self.std 90 | -------------------------------------------------------------------------------- /applications/outdated/aesthetic/tree_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping 2 | 3 | import haiku as hk 4 | import jax.numpy as jnp 5 | from jax import tree_util 6 | 7 | __all__ = ["reduce_loss_tree", 8 | "weighted_loss", 9 | "split_loss_tree", 10 | "calculate_losses"] 11 | 12 | 13 | def reduce_loss_tree(loss_tree: Mapping) -> jnp.array: 14 | """Reduces a loss tree to a scalar (i.e. 
jnp.array w/ size 1).""" 15 | return tree_util.tree_reduce(lambda x, y: x + y, loss_tree) 16 | 17 | 18 | def weighted_loss(loss_tree: Mapping, weights: Mapping) -> Any: 19 | """Updates a loss tree by applying weights at the leaves.""" 20 | return hk.data_structures.map( 21 | # m: module_name, n: param name, v: param value 22 | lambda m, n, v: weights[n] * v, 23 | loss_tree) 24 | 25 | 26 | def split_loss_tree(loss_tree: Mapping): 27 | """Splits a loss tree into content and style loss trees.""" 28 | return hk.data_structures.partition( 29 | lambda m, n, v: n == "content_loss", 30 | loss_tree) 31 | 32 | 33 | def calculate_losses(loss_tree: Mapping, weights: Mapping): 34 | """Returns a tuple of current content loss and style loss.""" 35 | # obtain content and style trees 36 | content_tree, style_tree = split_loss_tree(loss_tree) 37 | 38 | # reduce and return losses 39 | return weights['content_loss']*reduce_loss_tree(content_tree), weights['style_loss']*reduce_loss_tree(style_tree) 40 | -------------------------------------------------------------------------------- /applications/outdated/fem_examples/README.md: -------------------------------------------------------------------------------- 1 | # Information 2 | 3 | The folder contains some examples presented in the JAX-FEM paper. 
-------------------------------------------------------------------------------- /applications/outdated/fem_examples/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as onp 3 | import os 4 | 5 | # Latex style plot 6 | plt.rcParams.update({ 7 | "text.latex.preamble": r"\usepackage{amsmath}", 8 | "text.usetex": True, 9 | "font.family": "sans-serif", 10 | "font.sans-serif": ["Helvetica"]}) 11 | 12 | 13 | def plot_plastic_stress_strain(): 14 | problem_names = ["linear_elasticity", "hyperelasticity", "plasticity"] 15 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 16 | y_lables = [r'Force on top surface [N]', r'Force on top surface [N]', r'Volume averaged stress (z-z) [MPa]'] 17 | ratios = [1e-3, 1e-3, 1.] 18 | 19 | for i in range(len(problem_names)): 20 | disps_path = os.path.join(data_dir, 'numpy', problem_names[i], 'fenicsx/disps.npy') 21 | fenicsx_forces_path = os.path.join(data_dir, 'numpy', problem_names[i], 'fenicsx/forces.npy') 22 | jax_fem_forces_path = os.path.join(data_dir, 'numpy', problem_names[i], 'jax_fem/forces.npy') 23 | fenicsx_forces = onp.load(fenicsx_forces_path) 24 | jax_fem_forces = onp.load(jax_fem_forces_path) 25 | disps = onp.load(disps_path) 26 | fig = plt.figure(figsize=(8, 6)) 27 | plt.plot(disps, fenicsx_forces*ratios[i], label='FEniCSx', color='blue', linestyle="-", linewidth=2) 28 | plt.plot(disps, jax_fem_forces*ratios[i], label='JAX-FEM', color='red', marker='o', markersize=8, linestyle='None') 29 | plt.xlabel(r'Displacement of top surface [mm]', fontsize=20) 30 | plt.ylabel(y_lables[i], fontsize=20) 31 | plt.tick_params(labelsize=18) 32 | plt.legend(fontsize=20, frameon=False) 33 | plt.savefig(os.path.join(data_dir, f'pdf/{problem_names[i]}_stress_strain.pdf'), bbox_inches='tight') 34 | 35 | 36 | def plot_performance(): 37 | data_dir = f"applications/fem/fem_examples/data/" 38 | abaqus_cpu_time = 
onp.loadtxt(os.path.join(data_dir, f"txt/abaqus_fem_time_cpu.txt")) 39 | abaqus_time_np_12 = onp.loadtxt(os.path.join(data_dir, f"txt/abaqus_fem_time_mpi_np_12.txt")) 40 | abaqus_time_np_24 = onp.loadtxt(os.path.join(data_dir, f"txt/abaqus_fem_time_mpi_np_24.txt")) 41 | fenicsx_time_np_1 = onp.loadtxt(os.path.join(data_dir, f"txt/fenicsx_fem_time_mpi_np_1.txt")) 42 | fenicsx_time_np_2 = onp.loadtxt(os.path.join(data_dir, f"txt/fenicsx_fem_time_mpi_np_2.txt")) 43 | fenicsx_time_np_4 = onp.loadtxt(os.path.join(data_dir, f"txt/fenicsx_fem_time_mpi_np_4.txt")) 44 | jax_time_cpu = onp.loadtxt(os.path.join(data_dir, f"txt/jax_fem_cpu_time.txt")) 45 | jax_time_gpu = onp.loadtxt(os.path.join(data_dir, f"txt/jax_fem_gpu_time.txt")) 46 | cpu_dofs = onp.loadtxt(os.path.join(data_dir, f"txt/jax_fem_cpu_dof.txt")) 47 | gpu_dofs = onp.loadtxt(os.path.join(data_dir, f"txt/jax_fem_gpu_dof.txt")) 48 | 49 | plt.figure(figsize=(12, 9)) 50 | plt.plot(gpu_dofs[1:], abaqus_cpu_time[1:], linestyle='-', marker='o', markersize=12, linewidth=2, color='blue', label='Abaqus CPU') 51 | plt.plot(gpu_dofs[1:], abaqus_time_np_12[1:], linestyle='-', marker='s', markersize=12, linewidth=2, color='blue', label='Abaqus CPU MPI 12') 52 | plt.plot(gpu_dofs[1:], abaqus_time_np_24[1:], linestyle='-', marker='^', markersize=12, linewidth=2, color='blue', label='Abaqus CPU MPI 24') 53 | plt.plot(cpu_dofs[1:], fenicsx_time_np_1[1:], linestyle='-', marker='o', markersize=12, linewidth=2, color='green', label='FEniCSx CPU') 54 | plt.plot(cpu_dofs[1:], fenicsx_time_np_2[1:], linestyle='-', marker='s', markersize=12, linewidth=2, color='green', label='FEniCSx CPU MPI 2') 55 | plt.plot(cpu_dofs[1:], fenicsx_time_np_4[1:], linestyle='-', marker='^', markersize=12, linewidth=2, color='green', label='FEniCSx CPU MPI 4') 56 | plt.plot(cpu_dofs[1:], jax_time_cpu[1:], linestyle='-', marker='s', markersize=12, linewidth=2, color='red', label='JAX-FEM CPU') 57 | plt.plot(gpu_dofs[1:], jax_time_gpu[1:], linestyle='-', 
marker='o', markersize=12, linewidth=2, color='red', label='JAX-FEM GPU') 58 | 59 | plt.xscale('log') 60 | plt.yscale('log') 61 | plt.xlabel("Number of DOFs", fontsize=20) 62 | plt.ylabel("Wall time [s]", fontsize=20) 63 | plt.tick_params(labelsize=20) 64 | ax = plt.gca() 65 | # ax.get_xaxis().set_tick_params(which='minor', size=0) 66 | # plt.xticks(plt_tmp, tick_labels) 67 | plt.tick_params(labelsize=20) 68 | plt.legend(fontsize=20, frameon=False) 69 | 70 | plt.savefig(os.path.join(data_dir, f'pdf/performance.pdf'), bbox_inches='tight') 71 | 72 | if __name__ == '__main__': 73 | # plot_plastic_stress_strain() 74 | plot_performance() 75 | plt.show() 76 | -------------------------------------------------------------------------------- /applications/outdated/full_field_infer/README.md: -------------------------------------------------------------------------------- 1 | # Information 2 | 3 | The folder contains the parameter identification examples presented in the JAX-FEM paper. -------------------------------------------------------------------------------- /applications/outdated/full_field_infer/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as onp 3 | import os 4 | 5 | # Latex style plot 6 | plt.rcParams.update({ 7 | "text.latex.preamble": r"\usepackage{amsmath}", 8 | "text.usetex": True, 9 | "font.family": "sans-serif", 10 | "font.sans-serif": ["Helvetica"]}) 11 | 12 | 13 | def plot_results(): 14 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 15 | obj_val, rel_error_sol, rel_error_force = onp.load(os.path.join(data_dir, f"numpy/outputs.npy")) 16 | truncate = 21 17 | obj_val, rel_error_sol, rel_error_force = obj_val[:truncate], rel_error_sol[:truncate], rel_error_force[:truncate] 18 | steps = onp.arange(len(obj_val)) 19 | print(rel_error_sol[-1]) 20 | 21 | plt.figure(figsize=(8, 6)) 22 | plt.plot(steps, obj_val, linestyle='-', marker='o', markersize=10, 
linewidth=2, color='black') 23 | plt.xlabel("Optimization step", fontsize=20) 24 | plt.ylabel("Objective value", fontsize=20) 25 | plt.tick_params(labelsize=20) 26 | plt.tick_params(labelsize=20) 27 | plt.savefig(os.path.join(data_dir, f'pdf/loss.pdf'), bbox_inches='tight') 28 | 29 | plt.figure(figsize=(8, 6)) 30 | plt.plot(steps, rel_error_sol, linestyle='-', marker='o', markersize=10, linewidth=2, color='black') 31 | plt.xlabel("Optimization step", fontsize=20) 32 | plt.ylabel("Inference error", fontsize=20) 33 | plt.tick_params(labelsize=20) 34 | plt.tick_params(labelsize=20) 35 | plt.savefig(os.path.join(data_dir, f'pdf/error.pdf'), bbox_inches='tight') 36 | 37 | 38 | hs, res_zero, res_first = onp.load(os.path.join(data_dir, f"numpy/res.npy")) 39 | 40 | ref_zero = [1/5.*res_zero[-1]/hs[-1] * h for h in hs] 41 | ref_first = [1/5.*res_first[-1]/hs[-1]**2 * h**2 for h in hs] 42 | 43 | plt.figure(figsize=(10, 8)) 44 | plt.plot(hs, res_zero, linestyle='-', marker='o', markersize=10, linewidth=2, color='blue', label=r"$r_{\textrm{zeroth}}$") 45 | plt.plot(hs, ref_zero, linestyle='--', linewidth=2, color='blue', label='First order reference') 46 | plt.plot(hs, res_first, linestyle='-', marker='o', markersize=10, linewidth=2, color='red', label=r"$r_{\textrm{first}}$") 47 | plt.plot(hs, ref_first, linestyle='--', linewidth=2, color='red', label='Second order reference') 48 | plt.xscale('log') 49 | plt.yscale('log') 50 | plt.xlabel(r"Step size $h$", fontsize=20) 51 | plt.ylabel("Residual", fontsize=20) 52 | plt.tick_params(labelsize=20) 53 | plt.tick_params(labelsize=20) 54 | plt.legend(fontsize=20, frameon=False) 55 | 56 | plt.savefig(os.path.join(data_dir, f'pdf/res.pdf'), bbox_inches='tight') 57 | 58 | 59 | 60 | if __name__=="__main__": 61 | plot_results() 62 | plt.show() 63 | -------------------------------------------------------------------------------- /applications/outdated/multi_scale/README.md: 
-------------------------------------------------------------------------------- 1 | # Information 2 | 3 | The folder contains ML-related examples presented in the JAX-FEM paper. -------------------------------------------------------------------------------- /applications/outdated/multi_scale/arguments.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax 3 | import jax.numpy as np 4 | import torch 5 | import argparse 6 | import os 7 | import sys 8 | import numpy as onp 9 | import matplotlib.pyplot as plt 10 | from jax.config import config 11 | 12 | torch.manual_seed(0) 13 | 14 | # Set numpy printing format 15 | onp.random.seed(0) 16 | onp.set_printoptions(threshold=sys.maxsize, linewidth=1000, suppress=True) 17 | onp.set_printoptions(precision=10) 18 | 19 | # Manage arguments 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--L', type=float, default=1.) 22 | parser.add_argument('--num_hex', type=int, default=10) 23 | parser.add_argument('--device', type=int, default=0) 24 | parser.add_argument('--E_in', type=float, default=1e3) 25 | parser.add_argument('--E_out', type=float, default=1e2) 26 | parser.add_argument('--nu_in', type=float, default=0.3) 27 | parser.add_argument('--nu_out', type=float, default=0.4) 28 | parser.add_argument('--ratio', type=float, default=0.3) 29 | 30 | 31 | parser.add_argument('--activation', choices=['tanh', 'selu', 'relu', 'sigmoid', 'softplus'], default='tanh') 32 | parser.add_argument('--width_hidden', type=int, default=64) 33 | parser.add_argument('--n_hidden', type=int, default=8) 34 | parser.add_argument('--lr', type=float, default=1e-4) 35 | parser.add_argument('--batch_size', type=int, default=32) 36 | parser.add_argument('--input_size', type=int, default=6) 37 | 38 | args = parser.parse_args() 39 | 40 | 41 | # Latex style plot 42 | plt.rcParams.update({ 43 | "text.latex.preamble": r"\usepackage{amsmath}", 44 | "text.usetex": True, 45 | 
"font.family": "sans-serif", 46 | "font.sans-serif": ["Helvetica"]}) -------------------------------------------------------------------------------- /applications/outdated/multi_scale/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python -m applications.fem.multi_scale.rve --device 1 & 3 | python -m applications.fem.multi_scale.rve --device 2 -------------------------------------------------------------------------------- /applications/outdated/multi_scale/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax 3 | import jax.numpy as np 4 | 5 | 6 | def flat_to_tensor(X_flat): 7 | return np.array([[X_flat[0], X_flat[3], X_flat[4]], 8 | [X_flat[3], X_flat[1], X_flat[5]], 9 | [X_flat[4], X_flat[5], X_flat[2]]]) 10 | 11 | 12 | def tensor_to_flat(X_tensor): 13 | return np.array([X_tensor[0, 0], X_tensor[1, 1], X_tensor[2, 2], X_tensor[0, 1], X_tensor[0, 2], X_tensor[1, 2]]) 14 | -------------------------------------------------------------------------------- /applications/outdated/thermal/bare_plate.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax 3 | import jax.numpy as np 4 | import os 5 | import glob 6 | import meshio 7 | 8 | from jax_fem.generate_mesh import box_mesh_gmsh, Mesh 9 | from jax_fem.solver import solver 10 | from jax_fem.utils import save_sol 11 | 12 | from applications.fem.thermal.models import Thermal, initialize_hash_map, update_hash_map, get_active_mesh 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 15 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 16 | 17 | 18 | def bare_plate_single_track(): 19 | t_total = 5. 20 | vel = 0.01 21 | dt = 1e-2 22 | T0 = 300. 23 | Cp = 500. 24 | L = 290e3 25 | rho = 8440. 26 | h = 50. 27 | rb = 1e-3 28 | eta = 0.4 29 | P = 500. 
30 | vec = 1 31 | dim = 3 32 | ele_type = 'HEX8' 33 | 34 | ts = np.arange(0., 10e5, dt) 35 | # ts = np.arange(0., 10*dt, dt) 36 | 37 | vtk_dir = os.path.join(data_dir, 'vtk') 38 | 39 | problem_name = f'bare_plate' 40 | Nx, Ny, Nz = 150, 30, 10 41 | Lx, Ly, Lz = 30e-3, 6e-3, 2e-3 42 | meshio_mesh = box_mesh_gmsh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir) 43 | full_mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['hexahedron']) 44 | 45 | def top(point): 46 | return point[2] > 0. 47 | 48 | def walls(point): 49 | return True 50 | 51 | def neumann_top(point, old_T): 52 | # q is the heat flux into the domain 53 | d2 = (point[0] - laser_center[0])**2 + (point[1] - laser_center[1])**2 54 | q_laser = 2*eta*P/(np.pi*rb**2) * np.exp(-2*d2/rb**2) 55 | q = q_laser 56 | return np.array([q]) 57 | 58 | def neumann_walls(point, old_T): 59 | # q is the heat flux into the domain 60 | q_conv = h*(T0 - old_T[0]) 61 | q = q_conv 62 | return np.array([q]) 63 | 64 | neumann_bc_info = [None, [neumann_top, neumann_walls]] 65 | 66 | active_cell_truth_tab = onp.ones(len(full_mesh.cells), dtype=bool) 67 | active_mesh, points_map_active, cells_map_full = get_active_mesh(full_mesh, active_cell_truth_tab) 68 | external_faces, cells_face, hash_map, inner_faces, all_faces = initialize_hash_map(full_mesh, 69 | active_cell_truth_tab, cells_map_full, ele_type) 70 | sol = T0*np.ones((len(active_mesh.points), vec)) 71 | 72 | problem = Thermal(active_mesh, vec=vec, dim=dim, neumann_bc_info=neumann_bc_info, 73 | additional_info=(sol, rho, Cp, dt, external_faces)) 74 | 75 | files = glob.glob(os.path.join(vtk_dir, f'{problem_name}/*')) 76 | for f in files: 77 | os.remove(f) 78 | 79 | vtk_path = os.path.join(vtk_dir, f"{problem_name}/u_{0:05d}.vtu") 80 | save_sol(problem, sol, vtk_path) 81 | 82 | for i in range(len(ts[1:])): 83 | print(f"\nStep {i + 1}, total step = {len(ts)}, laser_x = {Lx*0.2 + vel*ts[i + 1]}") 84 | laser_center = np.array([Lx*0.2 + vel*ts[i + 1], Ly/2., Lz]) 85 | sol = solver(problem) 86 | 
problem.update_int_vars(sol) 87 | if (i + 1) % 10 == 0: 88 | vtk_path = os.path.join(vtk_dir, f"{problem_name}/u_{i + 1:05d}.vtu") 89 | save_sol(problem, sol, vtk_path) 90 | 91 | if Lx*0.2 + vel*ts[i + 1] > Lx*0.4: 92 | break 93 | 94 | 95 | if __name__ == "__main__": 96 | bare_plate_single_track() 97 | -------------------------------------------------------------------------------- /applications/outdated/thermal/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jax_fem.utils import make_video 3 | 4 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 5 | 6 | 7 | if __name__=="__main__": 8 | make_video(data_dir) 9 | -------------------------------------------------------------------------------- /applications/outdated/top_opt/README.md: -------------------------------------------------------------------------------- 1 | # Information 2 | 3 | The folder contains topology examples presented in the JAX-FEM paper. -------------------------------------------------------------------------------- /applications/outdated/top_opt/box.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax 3 | import jax.numpy as np 4 | import os 5 | import glob 6 | import os 7 | import meshio 8 | import time 9 | 10 | from jax_fem.generate_mesh import Mesh, box_mesh_gmsh 11 | from jax_fem.solver import ad_wrapper 12 | from jax_fem.utils import save_sol 13 | from jax_fem.common import walltime 14 | 15 | from applications.fem.top_opt.fem_model import Elasticity 16 | from applications.fem.top_opt.mma import optimize 17 | 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 19 | 20 | 21 | def topology_optimization(): 22 | problem_name = 'box' 23 | data_path = os.path.join(os.path.dirname(__file__), 'data') 24 | 25 | files = glob.glob(os.path.join(data_path, f'vtk/{problem_name}/*')) 26 | for f in files: 27 | os.remove(f) 28 | 29 | Lx, Ly, Lz = 2., 0.5, 1. 
30 | Nx, Ny, Nz = 80, 20, 40 31 | 32 | meshio_mesh = box_mesh_gmsh(Nx, Ny, Nz, Lx, Ly, Lz, data_path) 33 | jax_mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['hexahedron']) 34 | 35 | def fixed_location(point): 36 | return np.isclose(point[0], 0., atol=1e-5) 37 | 38 | def load_location(point): 39 | return np.logical_and(np.isclose(point[0], Lx, atol=1e-5), np.isclose(point[2], 0., atol=0.1*Lz+1e-5)) 40 | 41 | def dirichlet_val(point): 42 | return 0. 43 | 44 | def neumann_val(point): 45 | return np.array([0., 0., -1e6]) 46 | 47 | dirichlet_bc_info = [[fixed_location]*3, [0, 1, 2], [dirichlet_val]*3] 48 | neumann_bc_info = [[load_location], [neumann_val]] 49 | problem = Elasticity(jax_mesh, vec=3, dim=3, dirichlet_bc_info=dirichlet_bc_info, neumann_bc_info=neumann_bc_info, additional_info=(problem_name,)) 50 | fwd_pred = ad_wrapper(problem, linear=True, use_petsc=False) 51 | 52 | def J_fn(dofs, params): 53 | """J(u, p) 54 | """ 55 | sol = dofs.reshape((problem.num_total_nodes, problem.vec)) 56 | compliance = problem.compute_compliance(neumann_val, sol) 57 | return compliance 58 | 59 | def J_total(params): 60 | """J(u(p), p) 61 | """ 62 | sol = fwd_pred(params) 63 | dofs = sol.reshape(-1) 64 | obj_val = J_fn(dofs, params) 65 | return obj_val 66 | 67 | outputs = [] 68 | def output_sol(params, obj_val): 69 | print(f"\nOutput solution - need to solve the forward problem again...") 70 | sol = fwd_pred(params) 71 | vtu_path = os.path.join(data_path, f'vtk/{problem_name}/sol_{output_sol.counter:03d}.vtu') 72 | save_sol(problem, sol, vtu_path, cell_infos=[('theta', problem.full_params[:, 0])]) 73 | print(f"compliance = {obj_val}") 74 | print(f"max theta = {np.max(params)}, min theta = {np.min(params)}, mean theta = {np.mean(params)}") 75 | outputs.append(obj_val) 76 | output_sol.counter += 1 77 | 78 | output_sol.counter = 0 79 | 80 | vf = 0.2 81 | 82 | def objectiveHandle(rho): 83 | J, dJ = jax.value_and_grad(J_total)(rho) 84 | output_sol(rho, J) 85 | return J, dJ 86 
| 87 | def computeConstraints(rho, epoch): 88 | def computeGlobalVolumeConstraint(rho): 89 | g = np.mean(rho)/vf - 1. 90 | return g 91 | c, gradc = jax.value_and_grad(computeGlobalVolumeConstraint)(rho) 92 | c, gradc = c.reshape((1,)), gradc[None, ...] 93 | return c, gradc 94 | 95 | optimizationParams = {'maxIters':41, 'minIters':41, 'relTol':0.05} 96 | rho_ini = vf*np.ones((len(problem.flex_inds), 1)) 97 | _, mma_walltime = walltime(os.path.join(data_path, 'txt'))(optimize)(problem, rho_ini, optimizationParams, 98 | objectiveHandle, computeConstraints, numConstraints=1, movelimit=0.1) 99 | mma_walltime = onp.array(mma_walltime) 100 | print(mma_walltime) 101 | print(onp.sum(mma_walltime)) 102 | onp.save(os.path.join(data_path, f"numpy/{problem_name}_mma_walltime.npy"), mma_walltime) 103 | onp.save(os.path.join(data_path, f"numpy/{problem_name}_outputs.npy"), onp.array(outputs)) 104 | print(f"Compliance = {J_total(np.ones((len(problem.flex_inds), 1)))} for full material") 105 | 106 | 107 | if __name__=="__main__": 108 | topology_optimization() 109 | -------------------------------------------------------------------------------- /applications/outdated/top_opt/multi_material.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import jax 3 | import jax.numpy as np 4 | import os 5 | import glob 6 | import os 7 | import meshio 8 | import time 9 | 10 | from jax_fem.generate_mesh import Mesh, box_mesh_gmsh 11 | from jax_fem.solver import ad_wrapper 12 | from jax_fem.utils import save_sol 13 | 14 | from applications.fem.top_opt.fem_model import Elasticity 15 | from applications.fem.top_opt.mma import optimize 16 | 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 18 | 19 | def topology_optimization(): 20 | problem_name = 'multi_material' 21 | root_path = os.path.join(os.path.dirname(__file__), 'data') 22 | 23 | files = glob.glob(os.path.join(root_path, f'vtk/{problem_name}/*')) 24 | for f in files: 25 | os.remove(f) 
26 | 27 | meshio_mesh = box_mesh_gmsh(50, 30, 1, 50., 30., 1., root_path) 28 | jax_mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['hexahedron']) 29 | 30 | def fixed_location(point): 31 | return np.isclose(point[0], 0., atol=1e-5) 32 | 33 | def load_location(point): 34 | return np.logical_and(np.isclose(point[0], 50., atol=1e-5), np.isclose(point[1], 15., atol=1.5)) 35 | 36 | def dirichlet_val(point): 37 | return 0. 38 | 39 | def neumann_val(point): 40 | return np.array([0., -1., 0.]) 41 | 42 | dirichlet_bc_info = [[fixed_location]*3, [0, 1, 2], [dirichlet_val]*3] 43 | neumann_bc_info = [[load_location], [neumann_val]] 44 | problem = Elasticity(jax_mesh, vec=3, dim=3, dirichlet_bc_info=dirichlet_bc_info, neumann_bc_info=neumann_bc_info, additional_info=(problem_name,)) 45 | fwd_pred = ad_wrapper(problem, linear=False) 46 | 47 | def J_fn(dofs, params): 48 | """J(u, p) 49 | """ 50 | sol = dofs.reshape((problem.num_total_nodes, problem.vec)) 51 | compliance = problem.compute_compliance(neumann_val, sol) 52 | return compliance 53 | 54 | def J_total(params): 55 | """J(u(p), p) 56 | """ 57 | sol = fwd_pred(params) 58 | dofs = sol.reshape(-1) 59 | obj_val = J_fn(dofs, params) 60 | return obj_val 61 | 62 | outputs = [] 63 | def output_sol(params, obj_val): 64 | print(f"\nOutput solution - need to solve the forward problem again...") 65 | sol = fwd_pred(params) 66 | vtu_path = os.path.join(root_path, f'vtk/{problem_name}/sol_{output_sol.counter:03d}.vtu') 67 | save_sol(problem, sol, vtu_path, cell_infos=[('theta1', problem.full_params[:, 0]), ('theta2', problem.full_params[:, 1])]) 68 | print(f"compliance = {obj_val}") 69 | print(f"max theta = {np.max(params)}, min theta = {np.min(params)}, mean theta = {np.mean(params)}") 70 | outputs.append(obj_val) 71 | output_sol.counter += 1 72 | 73 | output_sol.counter = 0 74 | 75 | vf = 0.3 76 | num_flex = len(problem.flex_inds) 77 | 78 | def objectiveHandle(rho): 79 | J, dJ = jax.value_and_grad(J_total)(rho) 80 | if 
objectiveHandle.counter % 10 == 0: 81 | output_sol(rho, J) 82 | objectiveHandle.counter += 1 83 | return J, dJ 84 | 85 | objectiveHandle.counter = 0 86 | 87 | def computeConstraints(rho, epoch): 88 | def computeGlobalVolumeConstraint(rho): 89 | rho1 = rho[:, 0] 90 | rho2 = rho[:, 1] 91 | 92 | # g = np.sum(rho1*(rho2*1. + (1-rho2)*0.4))/num_flex/vf - 1. 93 | 94 | g = np.sum(rho1*(rho2*1 + (1-rho2)*0.4))/num_flex/vf - 1. 95 | 96 | return g 97 | 98 | c, gradc = jax.value_and_grad(computeGlobalVolumeConstraint)(rho) 99 | c, gradc = c.reshape((1,)), gradc[None, ...] 100 | return c, gradc 101 | 102 | optimizationParams = {'maxIters':51, 'minIters':51, 'relTol':0.05} 103 | rho_ini = np.hstack((vf*np.ones((num_flex, 1)), 0.5*np.ones((num_flex, 1)))) 104 | optimize(problem, rho_ini, optimizationParams, objectiveHandle, computeConstraints, numConstraints=1) 105 | onp.save(os.path.join(root_path, f"numpy/{problem_name}_outputs.npy"), onp.array(outputs)) 106 | # print(f"Compliance = {J_total(np.ones((num_flex, 1)))} for full material") 107 | 108 | 109 | if __name__=="__main__": 110 | topology_optimization() 111 | -------------------------------------------------------------------------------- /applications/periodic_bc/README.md: -------------------------------------------------------------------------------- 1 | # Instructions 2 | 3 | Linear multipoint constraint for periodci B.C. 
4 | 5 | (Updated in Dec, 2024) -------------------------------------------------------------------------------- /applications/periodic_bc/fenics.py: -------------------------------------------------------------------------------- 1 | # https://olddocs.fenicsproject.org/dolfin/2016.2.0/python/demo/documented/periodic/python/documentation.html 2 | import numpy as onp 3 | from dolfin import * 4 | 5 | # Source term 6 | class Source(UserExpression): 7 | def eval(self, values, x): 8 | dx = x[0] - 0.5 9 | dy = x[1] - 0.5 10 | values[0] = x[0]*sin(5.0*DOLFIN_PI*x[1]) \ 11 | + 1.0*exp(-(dx*dx + dy*dy)/0.02) 12 | 13 | # Sub domain for Dirichlet boundary condition 14 | class DirichletBoundary(SubDomain): 15 | def inside(self, x, on_boundary): 16 | return bool((x[1] < DOLFIN_EPS or x[1] > (1.0 - DOLFIN_EPS)) \ 17 | and on_boundary) 18 | 19 | # Sub domain for Periodic boundary condition 20 | class PeriodicBoundary(SubDomain): 21 | 22 | # Left boundary is "target domain" G 23 | def inside(self, x, on_boundary): 24 | return bool(x[0] < DOLFIN_EPS and x[0] > -DOLFIN_EPS and on_boundary) 25 | 26 | # Map right boundary (H) to left boundary (G) 27 | def map(self, x, y): 28 | y[0] = x[0] - 1.0 29 | y[1] = x[1] 30 | 31 | # Create mesh and finite element 32 | mesh = UnitSquareMesh(32, 32) 33 | V = FunctionSpace(mesh, "CG", 1, constrained_domain=PeriodicBoundary()) 34 | 35 | # Create Dirichlet boundary condition 36 | u0 = Constant(0.0) 37 | dbc = DirichletBoundary() 38 | bc0 = DirichletBC(V, u0, dbc) 39 | 40 | # Collect boundary conditions 41 | bcs = [bc0] 42 | 43 | # Define variational problem 44 | u = TrialFunction(V) 45 | v = TestFunction(V) 46 | f = Source() 47 | a = dot(grad(u), grad(v))*dx 48 | L = f*v*dx 49 | 50 | # Compute solution 51 | u = Function(V) 52 | solve(a == L, u, bcs) 53 | 54 | print(f"solution max = {onp.max(u.vector().get_local())}") -------------------------------------------------------------------------------- 
/applications/quiet_element/input/toolpath/thinwall_toolpath.crs: -------------------------------------------------------------------------------- 1 | 0.00000000 0.00000000 0.00000000 0.00000000 0 2 | 0.50000000 -36.75000000 0.00000000 0.70000000 0 3 | 11.00000000 36.75000000 0.00000000 0.70000000 1 4 | 12.50000000 -36.75000000 0.00000000 1.40000000 0 5 | 23.00000000 36.75000000 0.00000000 1.40000000 1 6 | 24.50000000 -36.75000000 0.00000000 2.10000000 0 7 | 35.00000000 36.75000000 0.00000000 2.10000000 1 8 | 36.50000000 -36.75000000 0.00000000 2.80000000 0 9 | 47.00000000 36.75000000 0.00000000 2.80000000 1 10 | 48.50000000 -36.75000000 0.00000000 3.50000000 0 11 | 59.00000000 36.75000000 0.00000000 3.50000000 1 12 | 60.50000000 -36.75000000 0.00000000 4.20000000 0 13 | 71.00000000 36.75000000 0.00000000 4.20000000 1 14 | 72.50000000 -36.75000000 0.00000000 4.90000000 0 15 | 83.00000000 36.75000000 0.00000000 4.90000000 1 16 | 84.50000000 -36.75000000 0.00000000 5.60000000 0 17 | 95.00000000 36.75000000 0.00000000 5.60000000 1 18 | 96.50000000 -36.75000000 0.00000000 6.30000000 0 19 | 107.00000000 36.75000000 0.00000000 6.30000000 1 20 | 108.50000000 -36.75000000 0.00000000 7.00000000 0 21 | 119.00000000 36.75000000 0.00000000 7.00000000 1 22 | 120.00000000 -36.75000000 0.00000000 7.00000000 0 -------------------------------------------------------------------------------- /applications/robin_bc/fenics.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import os 3 | from dolfin import * 4 | 5 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 6 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 7 | 8 | class LeftorRight(SubDomain): 9 | def inside(self,x,on_boundary): 10 | return on_boundary and (near(x[0], 0) or near(x[0], 1)) 11 | 12 | # Define Dirichlet boundary (x = 0 or x = 1) 13 | def boundary(x): 14 | return x[1] < DOLFIN_EPS or x[1] > 1.0 - DOLFIN_EPS 15 | 16 | # Create 
mesh and finite element 17 | mesh = UnitSquareMesh(32, 32) 18 | V = FunctionSpace(mesh, "CG", 1) 19 | 20 | boundaries = MeshFunction("size_t", mesh, mesh.topology().dim() - 1, 0) 21 | boundaries.set_all(0) 22 | 23 | # Now mark your Neumann boundary 24 | rightbound = LeftorRight() 25 | rightbound.mark(boundaries, 1) 26 | ds=Measure('ds')[boundaries] 27 | 28 | 29 | # Collect boundary conditions 30 | bc = DirichletBC(V, Constant(1.0), boundary) 31 | bcs = [bc] 32 | 33 | f = Expression("x[0]*sin(5.0*pi*x[1]) + 1.0*exp(-((x[0] - 0.5)*(x[0] - 0.5) + (x[1] - 0.5)*(x[1] - 0.5))/0.02)", degree=3) 34 | 35 | # Define variational problem 36 | u = Function(V) 37 | v = TestFunction(V) 38 | 39 | # F = dot(grad(u), grad(v))*dx + 5*u**2 *v*ds(1) 40 | F = dot(grad(u), grad(v))*dx -f*v*dx + 5*u**2 *v*ds(1) 41 | solve(F == 0, u, bcs) 42 | print(f"solution max = {onp.max(u.vector().get_local())}, min = {onp.min(u.vector().get_local())}") 43 | 44 | u.rename("u", "u") 45 | 46 | vtk_file = os.path.join(output_dir, f"vtk/u_fenics.pvd") 47 | File(vtk_file) << u 48 | 49 | # Save points and cells for the use of JAX-FEM 50 | # Build function space 51 | V = FiniteElement("Lagrange", mesh.ufl_cell(), 1) 52 | V_fs = FunctionSpace(mesh, V) 53 | 54 | points_u = V_fs.tabulate_dof_coordinates() 55 | print(f"points_u.shape = {points_u.shape}") 56 | 57 | cells_u = [] 58 | dofmap = V_fs.dofmap() 59 | for cell in cells(mesh): 60 | dof_index = dofmap.cell_dofs(cell.index()) 61 | # print(cell.index(), dof_index) 62 | cells_u.append(dof_index) 63 | cells_u = onp.stack(cells_u) 64 | print(f"cells_u.shape = {cells_u.shape}") 65 | 66 | onp.save(os.path.join(input_dir, f'numpy/points_u.npy'), points_u) 67 | onp.save(os.path.join(input_dir, f'numpy/cells_u.npy'), cells_u) 68 | 69 | -------------------------------------------------------------------------------- /applications/robin_bc/input/numpy/cells_u.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/robin_bc/input/numpy/cells_u.npy -------------------------------------------------------------------------------- /applications/robin_bc/input/numpy/points_u.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/robin_bc/input/numpy/points_u.npy -------------------------------------------------------------------------------- /applications/serendipity/example.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import numpy as onp 4 | import meshio 5 | import os 6 | import glob 7 | 8 | from jax_fem.problem import Problem 9 | from jax_fem.solver import solver 10 | from jax_fem.generate_mesh import Mesh, box_mesh_gmsh, get_meshio_cell_type 11 | from jax_fem.utils import save_sol 12 | 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 14 | 15 | 16 | class LinearElasticity(Problem): 17 | def get_tensor_map(self): 18 | def stress(u_grad): 19 | E = 70e3 20 | nu = 0.3 21 | mu = E/(2.*(1. 
+ nu)) 22 | lmbda = E*nu/((1+nu)*(1-2*nu)) 23 | epsilon = 0.5*(u_grad + u_grad.T) 24 | sigma = lmbda*np.trace(epsilon)*np.eye(self.dim) + 2*mu*epsilon 25 | return sigma 26 | return stress 27 | 28 | 29 | def problem(): 30 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 31 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 32 | 33 | ele_type = 'HEX20' 34 | cell_type = get_meshio_cell_type(ele_type) 35 | 36 | mesh_file = os.path.join(input_dir, f"abaqus/cube.inp") 37 | meshio_mesh = meshio.read(mesh_file) 38 | 39 | meshio_mesh.points[:, 0] = meshio_mesh.points[:, 0] - onp.min(meshio_mesh.points[:, 0]) 40 | meshio_mesh.points[:, 1] = meshio_mesh.points[:, 1] - onp.min(meshio_mesh.points[:, 1]) 41 | meshio_mesh.points[:, 2] = meshio_mesh.points[:, 2] - onp.min(meshio_mesh.points[:, 2]) 42 | 43 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 44 | 45 | Lx = onp.max(mesh.points[:, 0]) 46 | Ly = onp.max(mesh.points[:, 1]) 47 | Lz = onp.max(mesh.points[:, 2]) 48 | 49 | print(f"Lx = {Lx}, Ly = {Ly}, Lz = {Lz}") 50 | 51 | def left(point): 52 | return np.isclose(point[0], 0., atol=1e-5) 53 | 54 | def right(point): 55 | return np.isclose(point[0], Lx, atol=1e-5) 56 | 57 | def zero_dirichlet_val(point): 58 | return 0. 
59 | 60 | def dirichlet_val(point): 61 | return 0.1 * Lx 62 | 63 | dirichlet_bc_info = [[left, left, left, right, right, right], 64 | [0, 1, 2, 0, 1, 2], 65 | [zero_dirichlet_val, zero_dirichlet_val, zero_dirichlet_val, 66 | dirichlet_val, zero_dirichlet_val, zero_dirichlet_val]] 67 | 68 | problem = LinearElasticity(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info) 69 | sol_list = solver(problem) 70 | vtk_path = os.path.join(output_dir, f'vtk/u.vtu') 71 | save_sol(problem.fes[0], sol_list[0], vtk_path) 72 | 73 | prof_dir = os.path.join(output_dir, f'prof') 74 | os.makedirs(prof_dir, exist_ok=True) 75 | 76 | files = glob.glob(os.path.join(prof_dir, f'*')) 77 | for f in files: 78 | os.remove(f) 79 | 80 | jax.profiler.save_device_memory_profile(os.path.join(prof_dir, f'memory.prof')) 81 | 82 | 83 | if __name__ == "__main__": 84 | problem() 85 | -------------------------------------------------------------------------------- /applications/serendipity/remarks_using_pprof.md: -------------------------------------------------------------------------------- 1 | # command line instructions: 2 | 3 | go tool pprof memory.prof 4 | 5 | # Then input 6 | 7 | help 8 | 9 | # Then input 10 | 11 | svg 12 | 13 | # Open with browser 14 | -------------------------------------------------------------------------------- /applications/stokes/example_1var.py: -------------------------------------------------------------------------------- 1 | """Solve a 3D linear elasticity problem with one variable (vec=3) 2 | This is to verify that the implementation of multi-variable problem is correct. 
3 | Compare the results with jax-fem/applications/stokes/example_2vars.py 4 | """ 5 | import jax 6 | import jax.numpy as np 7 | import jax.flatten_util 8 | import os 9 | 10 | from jax_fem.solver import solver 11 | from jax_fem.generate_mesh import Mesh, box_mesh_gmsh 12 | from jax_fem.utils import save_sol 13 | from jax_fem.problem import Problem 14 | 15 | 16 | class LinearElasticity(Problem): 17 | 18 | def get_universal_kernel(self): 19 | 20 | def stress(u_grad): 21 | E = 70e3 22 | nu = 0.3 23 | mu = E/(2.*(1. + nu)) 24 | lmbda = E*nu/((1+nu)*(1-2*nu)) 25 | epsilon = 0.5*(u_grad + u_grad.T) 26 | sigma = lmbda*np.trace(epsilon)*np.eye(self.dim) + 2*mu*epsilon 27 | return sigma 28 | 29 | def universal_kernel(cell_sol_flat, x, cell_shape_grads, cell_JxW, cell_v_grads_JxW, *cell_internal_vars): 30 | cell_sol_list = self.unflatten_fn_dof(cell_sol_flat) 31 | cell_shape_grads = cell_shape_grads[:, :self.fes[0].num_nodes, :] 32 | cell_sol = cell_sol_list[0] 33 | cell_JxW = cell_JxW[0] 34 | cell_v_grads_JxW = cell_v_grads_JxW[:, :self.fes[0].num_nodes, :, :] 35 | 36 | vec = self.fes[0].vec 37 | # (1, num_nodes, vec, 1) * (num_quads, num_nodes, 1, dim) -> (num_quads, num_nodes, vec, dim) 38 | u_grads = cell_sol[None, :, :, None] * cell_shape_grads[:, :, None, :] 39 | u_grads = np.sum(u_grads, axis=1) # (num_quads, vec, dim) 40 | u_grads_reshape = u_grads.reshape(-1, vec, self.dim) # (num_quads, vec, dim) 41 | # (num_quads, vec, dim) 42 | u_physics = jax.vmap(stress)(u_grads_reshape, *cell_internal_vars).reshape(u_grads.shape) 43 | # (num_quads, num_nodes, vec, dim) -> (num_nodes, vec) -> (num_nodes, vec) 44 | val = np.sum(u_physics[:, None, :, :] * cell_v_grads_JxW, axis=(0, -1)) 45 | 46 | return jax.flatten_util.ravel_pytree(val)[0] 47 | 48 | return universal_kernel 49 | 50 | 51 | def problem(): 52 | """Can be used to test the memory limit of JAX-FEM 53 | """ 54 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 55 | ele_type = 'HEX8' 56 | meshio_mesh = 
box_mesh_gmsh(2, 1, 1, 1., 1., 1., data_dir, ele_type=ele_type) 57 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['hexahedron']) 58 | 59 | def left(point): 60 | return np.isclose(point[0], 0., atol=1e-5) 61 | 62 | def right(point): 63 | return np.isclose(point[0], 1., atol=1e-5) 64 | 65 | def zero_dirichlet_val(point): 66 | return 0. 67 | 68 | def dirichlet_val(point): 69 | return 0.1 70 | 71 | dirichlet_bc_info = [[left, left, left, right, right, right], 72 | [0, 1, 2, 0, 1, 2], 73 | [zero_dirichlet_val, zero_dirichlet_val, zero_dirichlet_val, 74 | dirichlet_val, zero_dirichlet_val, zero_dirichlet_val]] 75 | 76 | problem = LinearElasticity(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info) 77 | 78 | sol_list = solver(problem) 79 | vtk_path = os.path.join(data_dir, f'vtk/u.vtu') 80 | save_sol(problem.fes[0], sol_list[0], vtk_path) 81 | 82 | 83 | if __name__ == "__main__": 84 | problem() -------------------------------------------------------------------------------- /applications/stokes/example_2vars.py: -------------------------------------------------------------------------------- 1 | """Solve a 3D linear elasticity problem with two variables (vec=[1, 2]) 2 | This is to verify that the implementation of multi-variable problem is correct. 3 | Compare the results with jax-fem/applications/stokes/example_1var.py 4 | """ 5 | import jax 6 | import jax.numpy as np 7 | import jax.flatten_util 8 | import os 9 | 10 | from jax_fem.solver import solver 11 | from jax_fem.generate_mesh import Mesh, box_mesh_gmsh 12 | from jax_fem.utils import save_sol 13 | from jax_fem.problem import Problem 14 | 15 | 16 | class LinearElasticity(Problem): 17 | 18 | def get_universal_kernel(self): 19 | 20 | def stress(u_grad): 21 | E = 70e3 22 | nu = 0.3 23 | mu = E/(2.*(1. 
+ nu)) 24 | lmbda = E*nu/((1+nu)*(1-2*nu)) 25 | epsilon = 0.5*(u_grad + u_grad.T) 26 | sigma = lmbda*np.trace(epsilon)*np.eye(self.dim) + 2*mu*epsilon 27 | return sigma 28 | 29 | def universal_kernel(cell_sol_flat, x, cell_shape_grads, cell_JxW, cell_v_grads_JxW, *cell_internal_vars): 30 | cell_sol_list = self.unflatten_fn_dof(cell_sol_flat) 31 | cell_shape_grads = [cell_shape_grads[:, self.num_nodes_cumsum[i]: self.num_nodes_cumsum[i+1], :] 32 | for i in range(self.num_vars)][0] 33 | 34 | cell_sol = np.concatenate(cell_sol_list, axis=1) 35 | cell_v_grads_JxW = [cell_v_grads_JxW[:, self.num_nodes_cumsum[i]: self.num_nodes_cumsum[i+1], :, :] 36 | for i in range(self.num_vars)][0] 37 | 38 | vec = self.fes[0].vec + self.fes[1].vec 39 | 40 | # (1, num_nodes, vec, 1) * (num_quads, num_nodes, 1, dim) -> (num_quads, num_nodes, vec, dim) 41 | u_grads = cell_sol[None, :, :, None] * cell_shape_grads[:, :, None, :] 42 | u_grads = np.sum(u_grads, axis=1) # (num_quads, vec, dim) 43 | u_grads_reshape = u_grads.reshape(-1, vec, self.dim) # (num_quads, vec, dim) 44 | # (num_quads, vec, dim) 45 | u_physics = jax.vmap(stress)(u_grads_reshape, *cell_internal_vars).reshape(u_grads.shape) 46 | # (num_quads, num_nodes, vec, dim) -> (num_nodes, vec) -> (num_nodes, vec) 47 | val = np.sum(u_physics[:, None, :, :] * cell_v_grads_JxW, axis=(0, -1)) 48 | 49 | val = [val[:, :1], val[:, 1:]] 50 | 51 | return jax.flatten_util.ravel_pytree(val)[0] 52 | 53 | return universal_kernel 54 | 55 | 56 | def problem(): 57 | """Can be used to test the memory limit of JAX-FEM 58 | """ 59 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 60 | ele_type = 'HEX8' 61 | meshio_mesh = box_mesh_gmsh(2, 1, 1, 1., 1., 1., data_dir, ele_type=ele_type) 62 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['hexahedron']) 63 | 64 | def left(point): 65 | return np.isclose(point[0], 0., atol=1e-5) 66 | 67 | def right(point): 68 | return np.isclose(point[0], 1., atol=1e-5) 69 | 70 | def 
zero_dirichlet_val(point): 71 | return 0. 72 | 73 | def dirichlet_val(point): 74 | return 0.1 75 | 76 | dirichlet_bc_info1 = [[left, right], 77 | [0, 0], 78 | [zero_dirichlet_val, dirichlet_val]] 79 | 80 | 81 | dirichlet_bc_info2 = [[left, left, right, right], 82 | [0, 1, 0, 1], 83 | [zero_dirichlet_val, zero_dirichlet_val, 84 | zero_dirichlet_val, zero_dirichlet_val]] 85 | 86 | problem = LinearElasticity([mesh]*2, vec=[1, 2], dim=3, ele_type=[ele_type]*2, gauss_order=[None, None], 87 | dirichlet_bc_info=[dirichlet_bc_info1, dirichlet_bc_info2]) 88 | sol_list = solver(problem) 89 | vtk_path = os.path.join(data_dir, f'vtk/u.vtu') 90 | 91 | 92 | if __name__ == "__main__": 93 | problem() 94 | -------------------------------------------------------------------------------- /applications/stokes/fenics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from dolfin import * 3 | import os 4 | import numpy as onp 5 | import matplotlib.pyplot as plt 6 | 7 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 8 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 9 | 10 | # Load mesh and subdomains 11 | mesh = Mesh(os.path.join(input_dir, "xml/dolfin_fine.xml.gz")) 12 | sub_domains = MeshFunction("size_t", mesh, os.path.join(input_dir, "xml/dolfin_fine_subdomains.xml.gz")) 13 | 14 | 15 | # Build function space 16 | V = VectorElement("Lagrange", mesh.ufl_cell(), 2) 17 | Q = FiniteElement("Lagrange", mesh.ufl_cell(), 1) 18 | W = FunctionSpace(mesh, V * Q) 19 | 20 | 21 | # No-slip boundary condition for velocity 22 | # x1 = 0, x1 = 1 and around the dolphin 23 | noslip = Constant((0, 0)) 24 | bc0 = DirichletBC(W.sub(0), noslip, sub_domains, 0) 25 | 26 | # Inflow boundary condition for velocity 27 | # x0 = 1 28 | inflow = Expression(("-sin(x[1]*pi)", "0.0"), degree=2) 29 | bc1 = DirichletBC(W.sub(0), inflow, sub_domains, 1) 30 | 31 | # Boundary condition for pressure at outflow 32 
| # x0 = 0 33 | zero = Constant(0) 34 | bc2 = DirichletBC(W.sub(1), zero, sub_domains, 2) 35 | 36 | # Collect boundary conditions 37 | bcs = [bc0, bc1, bc2] 38 | 39 | # Define variational problem 40 | (u, p) = TrialFunctions(W) 41 | (v, q) = TestFunctions(W) 42 | f = Constant((0, 0)) 43 | a = (inner(grad(u), grad(v)) + div(v)*p + q*div(u))*dx 44 | L = inner(f, v)*dx 45 | 46 | # Compute solution 47 | w = Function(W) 48 | solve(a == L, w, bcs) 49 | 50 | # Split the mixed solution using deepcopy 51 | # (needed for further computation on coefficient vector) 52 | (u, p) = w.split(True) 53 | 54 | print("Norm of velocity coefficient vector: %.15g" % u.vector().norm("l2")) 55 | print("Norm of pressure coefficient vector: %.15g" % p.vector().norm("l2")) 56 | 57 | print(f"Max u = {onp.max(u.vector()[:])}, Min u = {onp.min(u.vector()[:])}") 58 | print(f"Max p = {onp.max(p.vector()[:])}, Min p = {onp.min(p.vector()[:])}") 59 | 60 | # # Split the mixed solution using a shallow copy 61 | (u, p) = w.split() 62 | 63 | # Save solution in VTK format 64 | ufile_pvd = File(os.path.join(output_dir, "vtk/fenics_velocity.pvd")) 65 | u.rename("u", "u") 66 | ufile_pvd << u 67 | pfile_pvd = File(os.path.join(output_dir, "vtk/fenics_pressure.pvd")) 68 | p.rename("p", "p") 69 | pfile_pvd << p 70 | 71 | # Save points and cells for the use of JAX-FEM 72 | 73 | # Build function space 74 | V = FiniteElement("Lagrange", mesh.ufl_cell(), 2) 75 | Q = FiniteElement("Lagrange", mesh.ufl_cell(), 1) 76 | V_fs = FunctionSpace(mesh, V) 77 | Q_fs = FunctionSpace(mesh, Q) 78 | 79 | points_v = V_fs.tabulate_dof_coordinates() 80 | points_p = Q_fs.tabulate_dof_coordinates() 81 | print(f"points_v.shape = {points_v.shape}") 82 | print(f"points_p.shape = {points_p.shape}") 83 | 84 | cells_v = [] 85 | dofmap_v = V_fs.dofmap() 86 | for cell in cells(mesh): 87 | dof_index = dofmap_v.cell_dofs(cell.index()) 88 | # print(cell.index(), dof_index) 89 | cells_v.append(dof_index) 90 | cells_v = onp.stack(cells_v) 91 | 
print(f"cells_v.shape = {cells_v.shape}") 92 | 93 | cells_p = [] 94 | dofmap_p = Q_fs.dofmap() 95 | for cell in cells(mesh): 96 | dof_index = dofmap_p.cell_dofs(cell.index()) 97 | # print(cell.index(), dof_index) 98 | cells_p.append(dof_index) 99 | cells_p = onp.stack(cells_p) 100 | print(f"cells_p.shape = {cells_p.shape}") 101 | 102 | re_order = [0, 1, 2, 5, 3, 4] 103 | cells_v = cells_v[:, re_order] 104 | 105 | onp.save(os.path.join(input_dir, f'numpy/points_u.npy'), points_v) 106 | onp.save(os.path.join(input_dir, f'numpy/cells_u.npy'), cells_v) 107 | onp.save(os.path.join(input_dir, f'numpy/points_p.npy'), points_p) 108 | onp.save(os.path.join(input_dir, f'numpy/cells_p.npy'), cells_p) 109 | 110 | # The dof order now should follow JAX-FEM (same as Abaqus) 111 | # https://classes.engineering.wustl.edu/2009/spring/mase5513/abaqus/docs/v6.6/books/stm/default.htm?startat=ch03s02ath64.html 112 | # But the dof can be clockwise, which needs further modification 113 | selected_p = points_v[cells_v[6]] 114 | 115 | plt.plot(selected_p[0, 0], selected_p[0, 1], marker='o', color='red') 116 | plt.plot(selected_p[1, 0], selected_p[1, 1], marker='o', color='blue') 117 | plt.plot(selected_p[2, 0], selected_p[2, 1], marker='o', color='orange') 118 | plt.plot(selected_p[3, 0], selected_p[3, 1], marker='s', color='red') 119 | plt.plot(selected_p[4, 0], selected_p[4, 1], marker='s', color='blue') 120 | plt.plot(selected_p[5, 0], selected_p[5, 1], marker='s', color='orange') 121 | 122 | plt.show() 123 | -------------------------------------------------------------------------------- /applications/stokes/input/numpy/cells_p.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/stokes/input/numpy/cells_p.npy -------------------------------------------------------------------------------- /applications/stokes/input/numpy/cells_u.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/stokes/input/numpy/cells_u.npy -------------------------------------------------------------------------------- /applications/stokes/input/numpy/points_p.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/stokes/input/numpy/points_p.npy -------------------------------------------------------------------------------- /applications/stokes/input/numpy/points_u.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/stokes/input/numpy/points_u.npy -------------------------------------------------------------------------------- /applications/stokes/input/xml/dolfin_fine.xml.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/stokes/input/xml/dolfin_fine.xml.gz -------------------------------------------------------------------------------- /applications/stokes/input/xml/dolfin_fine_subdomains.xml.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/stokes/input/xml/dolfin_fine_subdomains.xml.gz -------------------------------------------------------------------------------- /applications/surrogate_model/example.py: -------------------------------------------------------------------------------- 1 | # Import standard modules. 2 | import numpy as onp 3 | import os 4 | import pickle 5 | import torch 6 | 7 | # Import JAX modules. 
8 | import jax 9 | import jax.numpy as np 10 | from flax import linen as nn 11 | from jax import grad, jit, vmap 12 | import flax 13 | 14 | # Import JAX-FEM specific modules. 15 | from jax_fem.problem import Problem 16 | from jax_fem.solver import solver 17 | from jax_fem.utils import save_sol 18 | from jax_fem.generate_mesh import rectangle_mesh, get_meshio_cell_type, Mesh 19 | 20 | # Import local modules 21 | from applications.surrogate_model.train import SurrogateModel, Network 22 | 23 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 24 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 25 | 26 | 27 | class NnbasedMetamaterial(Problem): 28 | # Use neural network surrogate model as the constitutive model for the metamaterial. 29 | def get_tensor_map(self): 30 | def P_fn(E): 31 | K2 = np.concatenate((E[0,0].reshape(-1), E[0,1].reshape(-1), E[1,1].reshape(-1)), 0) 32 | P = self.surrogate_model.compute_input_gradient(self.surrogate_model.params, K2) - self.p_initial 33 | PK2 = np.concatenate((P[0].reshape(-1), P[1].reshape(-1), P[1].reshape(-1), P[2].reshape(-1)), 0).reshape(2,2) 34 | return PK2 35 | 36 | def first_PK_stress(u_grad): 37 | I = np.eye(self.dim) 38 | F = u_grad + I 39 | C = F.T @ F 40 | E = 0.5*(C - I) 41 | Pk2 = P_fn(E) 42 | Pk1 = F @ Pk2 43 | return Pk1 44 | return first_PK_stress 45 | 46 | def problem(): 47 | # Specify mesh-related information 48 | ele_type = 'QUAD4' 49 | cell_type = get_meshio_cell_type(ele_type) 50 | 51 | 52 | vtk_dir = os.path.join(output_dir, 'vtk') 53 | dataset_dir = os.path.join(input_dir, 'dataset') 54 | os.makedirs(vtk_dir, exist_ok=True) 55 | Lx, Ly = 30, 30 56 | meshio_mesh = rectangle_mesh(Nx=30, 57 | Ny=30, 58 | domain_x=Lx, 59 | domain_y=Ly) 60 | 61 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 62 | 63 | # Define boundary locations. 
64 | def left(point): 65 | return np.isclose(point[0], 0., atol=1e-5) 66 | 67 | def right(point): 68 | return np.isclose(point[0], Lx, atol=1e-5) 69 | 70 | def bottom(point): 71 | return np.isclose(point[1], 0., atol=1e-5) 72 | 73 | def top(point): 74 | return np.isclose(point[1], Ly, atol=1e-5) 75 | 76 | def dirichlet_val_bottom(point): 77 | return 0. 78 | 79 | # Create an instance of the displacement problem. 80 | def get_dirichlet_top(disp): 81 | def val_fn(x): 82 | return disp 83 | return val_fn 84 | 85 | location_fns = [bottom, bottom, top, top] 86 | vecs = [0, 1, 0, 1] 87 | dirichlet_bc_info = [location_fns, vecs, [dirichlet_val_bottom]*3 + [get_dirichlet_top(-0.05*30)]] 88 | 89 | # Create an instance of the problem. 90 | surrogate_problem = NnbasedMetamaterial(mesh, 91 | vec=2, 92 | dim=2, 93 | ele_type=ele_type, 94 | dirichlet_bc_info=dirichlet_bc_info) 95 | 96 | # Create surrogate model 97 | surrogate_problem.surrogate_model = SurrogateModel(Network) 98 | model_file_path = os.path.join(input_dir, 'model.pth') 99 | if not os.path.exists(model_file_path): 100 | raise ValueError(f"Please run the train.py file to train the model first.") 101 | 102 | pkl_file = pickle.load(open(model_file_path, "rb")) 103 | surrogate_problem.surrogate_model.params = flax.serialization.from_state_dict(target=surrogate_problem.surrogate_model.model, state=pkl_file) 104 | 105 | # Stress of strain-free state 106 | e_initial = np.concatenate((np.array([0.]).reshape(-1), np.array([0.]).reshape(-1), np.array([0.]).reshape(-1)), 0) 107 | surrogate_problem.p_initial = surrogate_problem.surrogate_model.compute_input_gradient(surrogate_problem.surrogate_model.params, e_initial) 108 | 109 | # Solve problem 110 | sol = solver(surrogate_problem, solver_options={'petsc_solver':{}}) 111 | surrogate_problem.fes[0].update_Dirichlet_boundary_conditions(dirichlet_bc_info) 112 | vtk_path = os.path.join(vtk_dir, 'displacement.vtk') 113 | save_sol(surrogate_problem.fes[0], sol[0], vtk_path) 114 | 115 
| 116 | if __name__ == "__main__": 117 | problem() 118 | -------------------------------------------------------------------------------- /applications/surrogate_model/input/dataset/e_strain.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/surrogate_model/input/dataset/e_strain.npy -------------------------------------------------------------------------------- /applications/surrogate_model/input/dataset/pk2_data.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/surrogate_model/input/dataset/pk2_data.npy -------------------------------------------------------------------------------- /applications/surrogate_model/input/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/surrogate_model/input/model.pth -------------------------------------------------------------------------------- /applications/surrogate_model/materials/RVE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/surrogate_model/materials/RVE.png -------------------------------------------------------------------------------- /applications/surrogate_model/materials/compression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/surrogate_model/materials/compression.png -------------------------------------------------------------------------------- 
/applications/surrogate_model/materials/strech.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/surrogate_model/materials/strech.png -------------------------------------------------------------------------------- /applications/surrogate_model/materials/surrogatemodel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/applications/surrogate_model/materials/surrogatemodel.png -------------------------------------------------------------------------------- /applications/surrogate_model/readme.md: -------------------------------------------------------------------------------- 1 | # Neural Network Surrogate Model 2 | 3 | ## Work flow 4 | 5 |
6 |
7 |
20 |
21 |
Geometry for RVE
46 |
47 |
48 |
50 | Deformation under 20% stretch (left) and 5% compression (right) 51 |
52 | 53 | 54 | ## Author 55 | 56 | This example was contributed by Min Shen from Prof. Sheng Mao's research group at Peking University. 57 | 58 | 59 | ## Reference 60 | 61 | [1] Xue, Tianju, Alex Beatson, Maurizio Chiaramonte, Geoffrey Roeder, Jordan T. Ash, Yigit Menguc, Sigrid Adriaenssens, Ryan P. Adams, and Sheng Mao. "A data-driven computational scheme for the nonlinear mechanical properties of cellular mechanical metamaterials under large deformation." Soft Matter 16, no. 32 (2020): 7524-7534. 62 | 63 | -------------------------------------------------------------------------------- /applications/surrogate_model/train.py: -------------------------------------------------------------------------------- 1 | # Import some useful modules. 2 | import jax 3 | import jax.numpy as np 4 | import numpy as onp 5 | import os 6 | 7 | # Import JAX modules. 8 | from flax import linen as nn 9 | import flax 10 | from flax.core import freeze, unfreeze 11 | from jax import random 12 | from jax import grad, jit 13 | import pickle 14 | import optax 15 | from functools import partial 16 | 17 | # Import modules for dataset 18 | import torch 19 | from torch.utils.data import Dataset, DataLoader 20 | from torch.utils.data.sampler import SubsetRandomSampler 21 | torch.manual_seed(1) 22 | 23 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 24 | 25 | # Build the NN model 26 | class Network(nn.Module): 27 | """ 28 | Input: Three components of the Green-Lagrange strain E = 0.5*(F^T F - I): E11, E12, E22 29 | Output: Scalar strain energy; its gradient w.r.t. the input is fit to the PK2 stress (see SurrogateModel) 30 | """ 31 | @nn.compact 32 | def __call__(self, x): 33 | x = nn.Dense(128, kernel_init=nn.initializers.xavier_uniform(), use_bias=True)(x) 34 | x = nn.elu(x) 35 | x = nn.Dense(128, kernel_init=nn.initializers.xavier_uniform(), use_bias=True)(x) 36 | x = nn.elu(x) 37 | x = nn.Dense(128, kernel_init=nn.initializers.xavier_uniform(), use_bias=True)(x) 38 | x = nn.elu(x) 39 | x = nn.Dense(1, kernel_init=nn.initializers.xavier_uniform(), use_bias=True)(x) 40 | return x 41
| 42 | # Build the dataset 43 | class FEMDataset(Dataset): 44 | """ 45 | Each sample contains: 46 | 1. Green-Lagrange strain components: E11, E12, E22 47 | 2. Second PK stress components: PK2_11, PK2_12, PK2_22 (PK2_12 is read from pk2[1, 0], assuming a symmetric tensor) 48 | """ 49 | def __init__(self, lens, e_data, pk2_data): 50 | self.lens = lens 51 | self.e_data = e_data 52 | self.pk2_data = pk2_data 53 | 54 | def __len__(self): 55 | return self.lens 56 | 57 | def __getitem__(self, idx): 58 | if torch.is_tensor(idx): 59 | idx = idx.tolist() 60 | e_strain = self.e_data [idx] 61 | pk2 = self.pk2_data[idx] 62 | pk2_variable = torch.cat((pk2[0, 0].unsqueeze(0), pk2[1, 0].unsqueeze(0), pk2[1, 1].unsqueeze(0)), 0) 63 | return {"e_strain": e_strain, 64 | "pk2": pk2_variable} 65 | 66 | 67 | class SurrogateModel(): 68 | def __init__(self, Network): 69 | self.model = Network() 70 | 71 | def compute_input_gradient(self, params, x): 72 | def forward(param, x): 73 | return self.model.apply(param, x).sum() 74 | grad_fn = jax.grad(forward, argnums=1) 75 | return grad_fn(params, x) 76 | 77 | def train(self, x_data, label_data, init_x): 78 | """ 79 | Training process. 80 | x_data: Input data. 81 | label_data: Output results.
82 | """ 83 | def nllloss(params, x, y_label): 84 | y_pred = self.compute_input_gradient(params, x) 85 | loss = np.sqrt((y_label - y_pred)**2 / (y_label**2 + 1e-3)) 86 | return np.mean(loss) 87 | 88 | @jax.jit 89 | def train_step(params, x, y, opt_state): 90 | loss, grads = jax.value_and_grad(nllloss)(params, x, y) 91 | updates, opt_state = optimizer.update(grads, opt_state) 92 | params = optax.apply_updates(params, updates) 93 | return params, loss, opt_state 94 | 95 | num_epochs = 100 96 | learning_rate = 1e-4 97 | 98 | optimizer = optax.adam(learning_rate) 99 | params = self.model.init(random.PRNGKey(0), init_x) 100 | opt_state = optimizer.init(params) 101 | dataloader = self.get_dataset(x_data, label_data) 102 | for epoch in range(0, num_epochs): 103 | for counts, data in enumerate(dataloader, 0): 104 | x = np.array(data["e_strain"]) 105 | label = np.array(data["pk2"]) 106 | b_size = x.shape[0] 107 | params, loss, opt_state = train_step(params, x, label, opt_state) 108 | if epoch % 1 == 0: 109 | print(f'Epoch {epoch}, Loss: {loss:.4f}') 110 | state_dict = flax.serialization.to_state_dict(params) 111 | 112 | model_file_path = os.path.join(input_dir, 'model.pth') 113 | pickle.dump(state_dict, open(model_file_path, "wb")) 114 | 115 | def get_dataset(self, x_data, label_data): 116 | """ 117 | Get dataset by pytorch package. 
118 | x_data: strain data array in torch format 119 | label_data: stress data array in torch format 120 | """ 121 | dataset = FEMDataset(lens=len(x_data), e_data=x_data, pk2_data=label_data) 122 | indices = list(range((len(dataset)))) 123 | train_sampler = SubsetRandomSampler(indices[1:len(x_data):1]) 124 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, 125 | num_workers=0, sampler=train_sampler, pin_memory=True) 126 | return dataloader 127 | 128 | 129 | if __name__== '__main__': 130 | pk2_data = torch.from_numpy(onp.load(os.path.join(input_dir, "dataset/pk2_data.npy"))) 131 | estrain_data = torch.from_numpy(onp.load(os.path.join(input_dir, "dataset/e_strain.npy"))) 132 | init_x = estrain_data[0] 133 | surrogate_model = SurrogateModel(Network) 134 | surrogate_model.train(estrain_data, pk2_data, init_x) 135 | -------------------------------------------------------------------------------- /demos/README.md: -------------------------------------------------------------------------------- 1 | # About 2 | 3 | _JAX-FEM_ is a light-weight Finite Element Method library in pure Python, accelerated with [_JAX_](https://github.com/google/jax). This folder contains tutorial examples with explanatory comments. The software is still at an experimental stage. 4 | 5 | In the following paragraphs, I want to share some of my motivations on developing the software: How is _JAX-FEM_ different from other FEM codes? What's new? Who cares? What should users expect and NOT expect from this software? 6 | 7 | ## Life Is Short, Use Python 8 | 9 | My first exposure to open-source FEM library was [_Dealii_](https://www.dealii.org/), a powerful C++ software library that allows users to build FEM codes to solve a broad variety of PDEs. While I enjoyed very much the flexibility of Dealii, a significant amount of my time was indeed spent on writing lengthy C++ code that easily became challenging for debugging and maintaining. 
10 | 11 | My second choice was [_FEniCS_](https://fenicsproject.org/) (now _FEniCSx_), an amazing FEM library with high-level Python interfaces. The beauty of _FEniCS_ is that users write near-math code in Python, and immediately solve their (possibly nonlinear) problems, with highly competitive performance due to the C++ backend. Yet, the use of automatic (symbolic) differentiation by _FEniCS_ comes with a price: it becomes cumbersome for complicated constitutive relationships. When solving problems of solid mechanics, typically, a mapping from strain to stress needs to be specified. If this mapping can be explicitly expressed with an analytical form, _FEniCS_ works just fine. However, this is not always the case. There are two examples in my field. One is crystal plasticity, where strain is often times related to stress through an implicit function. The other example is the phase field fracture problem, where eigenvalue decomposition for the strain is necessary. After weeks of unsuccessful trials with _FEniCS_, I started the idea of implementing an FEM code myself that handles complicated constitutive relationships, and that became the start of _JAX-FEM_. 12 | 13 | Staying in the Python ecosystem, _JAX_ becomes a natural choice, due to its [outstanding performance for scientific computing workloads](https://github.com/dionhaefner/pyhpc-benchmarks/tree/master). 14 | 15 | ## The Magic of Automatic Differentiation 16 | 17 | The design of _JAX-FEM_ fundamentally exploits automatic differentiation. The rule of thumb is that whenever there is a derivative to take, let the machine do it. Some typical examples include 18 | 19 | 1. In a hyperelasticity problem, given strain energy density function $\psi(\boldsymbol F)$, compute first PK stress $\boldsymbol{P}=\frac{\partial \psi}{\partial \boldsymbol{F}}$. 20 | 21 | 2. 
In a plasticity problem, given stress $\boldsymbol{\sigma} (\boldsymbol{\varepsilon}, \boldsymbol{\alpha})$ as a function of strain and some internal variables, compute the fourth-order elasto-plastic tangent moduli tensor $\mathbb{C}=\frac{\partial \boldsymbol{\sigma}}{\partial \boldsymbol{\varepsilon}}$. 22 | 3. In a topology optimization problem, the computation of sensitivity can be fully automatic. 23 | 24 | As developers, we are actively using _JAX-FEM_ to solve inverse problems (or PDE-constrained optimization problems) involving complicated constitutive relationships, thanks to AD, which makes this effort easy. 25 | 26 | ## Native in Machine Learning 27 | 28 | Since _JAX_ itself is a framework for machine learning, _JAX-FEM_ trivially has access to the ecosystem of _JAX_. If you have a material model represented by a neural network, and you want to deploy that model into the computation of FEM, _JAX-FEM_ will be a perfect tool. No need to hard-code the neural network coefficients into a Fortran file and run _Abaqus_! 29 | 30 | ## Heads Up! 31 | 32 | 1. **Kernels**. _JAX-FEM_ uses kernels to handle different terms in the FEM weak form, a concept similar to that in [_MOOSE_](https://mooseframework.inl.gov/syntax/Kernels/). Currently, we can handle the "Laplace kernel" $\int_{\Omega} f(\nabla u)\cdot \nabla v$ and the "mass kernel" $\int_{\Omega}h(u)v$ in the weak form. This covers solving typical second-order elliptic equations like those occurring in quasi-static solid mechanics, or time-dependent parabolic problems like a heat equation. We also provide a "universal kernel" that lets users define their own weak form. This is a new feature introduced on Dec 11, 2023. 33 | 34 | 2. **Performance**. In most cases, the majority of computational time is spent on solving the linear system from Newton's method.
If CPU is available, the linear system will be solved by [_PETSc_](https://petsc.org/release/); if GPU is available, solving the linear system with _JAX_ built-in sparse linear solvers will usually be faster and scalable to larger problems. Exploiting multiple CPUs and/or even multiple GPUs is our future work. Please see our _JAX-FEM_ journal paper for performance report. 35 | 36 | 3. **Memory**. The largest problem that is solved without causing memory insufficiency issue on a 48G memory RTX8000 Nvidia GPU contains around 9 million DOFs. 37 | 38 | 4. **Nonlinearity**. _JAX-FEM_ handles material nonlinearity well, but currently does not handle other types of nonlinearities such as contact. Secondary development is needed. 39 | 40 | 5. **Boundary conditions**. As of now, we cannot handle periodic boundary conditions. We need some help on this. 41 | -------------------------------------------------------------------------------- /demos/hyperelasticity/README.md: -------------------------------------------------------------------------------- 1 | # Hyperelasticity 2 | 3 | ## Formulation 4 | 5 | The governing equation for hyperelasticity of a body $\Omega$ can be written as 6 | 7 | $$ 8 | \begin{align*} 9 | -\nabla \cdot \boldsymbol{P} = \boldsymbol{b} & \quad \textrm{in} \nobreakspace \nobreakspace \Omega, \\ 10 | \boldsymbol{u} = \boldsymbol{u}_D & \quad\textrm{on} \nobreakspace \nobreakspace \Gamma_D, \\ 11 | \boldsymbol{P} \cdot \boldsymbol{n} = \boldsymbol{t} & \quad \textrm{on} \nobreakspace \nobreakspace \Gamma_N. 12 | \end{align*} 13 | $$ 14 | 15 | The weak form gives 16 | 17 | $$ 18 | \begin{align*} 19 | \int_{\Omega} \boldsymbol{P} : \nabla \boldsymbol{v} \nobreakspace \nobreakspace \textrm{d}x = \int_{\Omega} \boldsymbol{b} \cdot \boldsymbol{v} \nobreakspace \textrm{d}x + \int_{\Gamma_N} \boldsymbol{t} \cdot \boldsymbol{v} \nobreakspace\nobreakspace \textrm{d}s. 
20 | \end{align*} 21 | $$ 22 | 23 | Here, $\boldsymbol{P}$ is the first Piola-Kirchhoff stress and is given by 24 | 25 | $$ 26 | \begin{align*} 27 | \boldsymbol{P} &= \frac{\partial W}{\partial \boldsymbol{F}}, \\ 28 | \boldsymbol{F} &= \nabla \boldsymbol{u} + \boldsymbol{I}, \\ 29 | W (\boldsymbol{F}) &= \frac{G}{2}(J^{-2/3} I_1 - 3) + \frac{\kappa}{2}(J - 1)^2, 30 | \end{align*} 31 | $$ 32 | 33 | where $\boldsymbol{F}$ is the deformation gradient, $J=\det\boldsymbol{F}$, $I_1=\textrm{tr}(\boldsymbol{F}^{\top}\boldsymbol{F})$, $G$ and $\kappa$ are the shear and bulk moduli, and $W$ is the strain energy density function. This constitutive relationship comes from a neo-Hookean solid model [2]. 34 | 35 | 36 | We have the following definitions: 37 | * $\Omega=(0,1)\times(0,1)\times(0,1)$ (a unit cube) 38 | * $\Gamma_{D_1}=\lbrace 0\rbrace\times(0,1)\times(0,1)$ (first part of Dirichlet boundary) 39 | * $\boldsymbol{u}_{D_1}= [0,(0.5+(x_2-0.5)\textrm{cos}(\pi/3)-(x_3-0.5)\textrm{sin}(\pi/3)-x_2)/2, (0.5+(x_2-0.5)\textrm{sin}(\pi/3)+(x_3-0.5)\textrm{cos}(\pi/3)-x_3)/2]$ 40 | * $\Gamma_{D_2}=\lbrace 1\rbrace\times(0,1)\times(0,1)$ (second part of Dirichlet boundary) 41 | * $\boldsymbol{u}_{D_2}=[0,0,0]$ 42 | * $b=[0, 0, 0]$ 43 | * $t=[0, 0, 0]$ 44 | 45 | ## Execution 46 | Run 47 | ```bash 48 | python -m demos.hyperelasticity.example 49 | ``` 50 | from the `jax-fem/` directory. 51 | 52 | 53 | ## Results 54 | 55 | Visualized with the *ParaView* "Warp By Vector" filter: 56 | 57 |
58 |
59 |
61 | Solution 62 |
63 | 64 | ## References 65 | 66 | [1] https://fenicsproject.org/olddocs/dolfin/1.5.0/python/demo/documented/hyperelasticity/python/documentation.html 67 | 68 | [2] https://en.wikipedia.org/wiki/Neo-Hookean_solid -------------------------------------------------------------------------------- /demos/hyperelasticity/example.py: -------------------------------------------------------------------------------- 1 | # Import some useful modules. 2 | import jax 3 | import jax.numpy as np 4 | import os 5 | 6 | 7 | # Import JAX-FEM specific modules. 8 | from jax_fem.problem import Problem 9 | from jax_fem.solver import solver 10 | from jax_fem.utils import save_sol 11 | from jax_fem.generate_mesh import box_mesh_gmsh, get_meshio_cell_type, Mesh 12 | 13 | 14 | # Define constitutive relationship. 15 | class HyperElasticity(Problem): 16 | # The function 'get_tensor_map' overrides base class method. Generally, JAX-FEM 17 | # solves -div(f(u_grad)) = b. Here, we define f(u_grad) = P. Notice how we first 18 | # define 'psi' (representing W), and then use automatic differentiation (jax.grad) 19 | # to obtain the 'P_fn' function. 20 | def get_tensor_map(self): 21 | 22 | def psi(F): 23 | E = 10. 24 | nu = 0.3 25 | mu = E / (2. * (1. + nu)) 26 | kappa = E / (3. * (1. - 2. * nu)) 27 | J = np.linalg.det(F) 28 | Jinv = J**(-2. / 3.) 29 | I1 = np.trace(F.T @ F) 30 | energy = (mu / 2.) * (Jinv * I1 - 3.) + (kappa / 2.) * (J - 1.)**2. 31 | return energy 32 | 33 | P_fn = jax.grad(psi) 34 | 35 | def first_PK_stress(u_grad): 36 | I = np.eye(self.dim) 37 | F = u_grad + I 38 | P = P_fn(F) 39 | return P 40 | 41 | return first_PK_stress 42 | 43 | 44 | # Specify mesh-related information (first-order hexahedron element). 45 | ele_type = 'HEX8' 46 | cell_type = get_meshio_cell_type(ele_type) 47 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 48 | Lx, Ly, Lz = 1., 1., 1. 
# Structured 20 x 20 x 20 box mesh generated through gmsh, wrapped as a JAX-FEM Mesh.
meshio_mesh = box_mesh_gmsh(Nx=20, Ny=20, Nz=20, Lx=Lx, Ly=Ly, Lz=Lz,
                            data_dir=data_dir, ele_type=ele_type)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])


# Boundary locations: the two faces normal to the x axis.
def left(point):
    """Return True for points on the face x = 0 (absolute tolerance 1e-5)."""
    return np.isclose(point[0], 0., atol=1e-5)


def right(point):
    """Return True for points on the face x = Lx (absolute tolerance 1e-5)."""
    return np.isclose(point[0], Lx, atol=1e-5)


# Dirichlet boundary values.
def zero_dirichlet_val(point):
    """Homogeneous Dirichlet value (zero displacement component)."""
    return 0.


def dirichlet_val_x2(point):
    """Second displacement component on the left face.

    Half of the change in the y coordinate when (y, z) is rotated by pi/3
    about the line y = z = 0.5.
    """
    y, z = point[1], point[2]
    cos_t, sin_t = np.cos(np.pi / 3.), np.sin(np.pi / 3.)
    return (0.5 + (y - 0.5) * cos_t - (z - 0.5) * sin_t - y) / 2.


def dirichlet_val_x3(point):
    """Third displacement component on the left face.

    Half of the change in the z coordinate when (y, z) is rotated by pi/3
    about the line y = z = 0.5.
    """
    y, z = point[1], point[2]
    cos_t, sin_t = np.cos(np.pi / 3.), np.sin(np.pi / 3.)
    return (0.5 + (y - 0.5) * sin_t + (z - 0.5) * cos_t - z) / 2.


# [locations, displacement components, value functions]:
# left face -> (0, rotation_y, rotation_z); right face -> fully clamped.
dirichlet_bc_info = [
    [left, left, left, right, right, right],
    [0, 1, 2, 0, 1, 2],
    [zero_dirichlet_val, dirichlet_val_x2, dirichlet_val_x3,
     zero_dirichlet_val, zero_dirichlet_val, zero_dirichlet_val],
]


# Create an instance of the problem (3 displacement components in 3D).
problem = HyperElasticity(mesh, vec=3, dim=3, ele_type=ele_type,
                          dirichlet_bc_info=dirichlet_bc_info)
# Solve the defined problem, selecting the 'petsc_solver' option with an empty options dict.
sol_list = solver(problem, solver_options={'petsc_solver': {}})

# Store the solution to local file.
99 | vtk_path = os.path.join(data_dir, f'vtk/u.vtu') 100 | save_sol(problem.fes[0], sol_list[0], vtk_path) 101 | -------------------------------------------------------------------------------- /demos/hyperelasticity/materials/sol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/hyperelasticity/materials/sol.png -------------------------------------------------------------------------------- /demos/inverse/README.md: -------------------------------------------------------------------------------- 1 | # Automatic differentiation 2 | 3 | ## Formulation 4 | 5 | In this tutorial, we demostrate the process to calculate the derivative by automatic differentiation and validate the results by the finite difference method. The same hyperelastic body as in our [hyperelasticity example](https://github.com/tianjuxue/jax-fem/tree/main/demos/hyperelasticity) is considered here, i.e., a unit cube with a neo-Hookean solid model. In addition, we have the following definitions: 6 | * $\Omega=(0,1)\times(0,1)\times(0,1)$ (a unit cube) 7 | * $b=[0, 0, 0]$ 8 | * $\Gamma_{D}=(0,1)\times(0,1)\times0$ 9 | * $\boldsymbol{u}_{D}=[0,0,\beta]$ 10 | * $\Gamma_{N_1}=(0,1)\times(0,1)\times1$ 11 | * $\boldsymbol{t}_{N_1}=[0, 0, -1000]$ 12 | * $\Gamma_{N_2}=\partial\Omega\backslash(\Gamma_{D}\cup\Gamma_{N_1})$ 13 | * $\boldsymbol{t}_{N_2}=[0, 0, 0]$ 14 | 15 | The objective function is defined as: 16 | 17 | $$J= \sum_{i=1}^{N_d}(\boldsymbol{u}[i])^2$$ 18 | where $N_d$ is the total number of degrees of freedom. 
$\boldsymbol{u}[i]$ is the $i\text{th}$ component of the displacement vector $\boldsymbol{u}$, which is obtained by solving the following discretized governing PDE: 19 | 20 | $$ 21 | \boldsymbol{C}(\boldsymbol{u},\boldsymbol{\alpha}_1,\boldsymbol{\alpha}_2,...\boldsymbol{\alpha}_N)=\boldsymbol{0} 22 | $$ 23 | 24 | where $\boldsymbol{\alpha}_1,\boldsymbol{\alpha}_2,...\boldsymbol{\alpha}_N$ are the parameter vectors. Here, we set up three parameters, $\boldsymbol{\alpha}_1 = \boldsymbol{E}$ the elastic modulus, $\boldsymbol{\alpha}_2 =\boldsymbol{\rho}$ the material density, and $\boldsymbol{\alpha}_3 =\boldsymbol{\beta}$ the scale factor of the Dirichlet boundary conditions. 25 | 26 | We can see that $\boldsymbol{u}(\boldsymbol{\alpha}_1,\boldsymbol{\alpha}_2,...\boldsymbol{\alpha}_N)$ is an implicit function of the parameter vectors. In JAX-FEM, users can easily compute the derivative of the objective function with respect to these parameters through automatic differentiation. We first wrap the forward problem with the function `jax_fem.solver.ad_wrapper`, which defines the implicit differentiation through `@jax.custom_vjp`. Next, we can use `jax.grad` to calculate the derivative. 27 | 28 | 29 | We then use the forward difference scheme to validate the results. The derivative of the objective with respect to the $k\text{th}$ component of the parameter vector $\boldsymbol{\alpha}_i$ is approximated as: 30 | $$\frac{\partial J}{\partial \boldsymbol{\alpha}_i[k]} \approx \frac{J(\boldsymbol{\alpha}_i+h\boldsymbol{\alpha}_i[k]\boldsymbol{e}_k)-J(\boldsymbol{\alpha}_i)}{h\boldsymbol{\alpha}_i[k]}$$ 31 | 32 | where $h$ is a small relative perturbation and $\boldsymbol{e}_k$ is the $k\text{th}$ unit vector. 33 | 34 | 35 | 36 | ## Execution 37 | Run 38 | ```bash 39 | python -m demos.inverse.example 40 | ``` 41 | from the `jax-fem/` directory. 42 | 43 | 44 | ## Results 45 | 46 | ```bash 47 | Derivative comparison between automatic differentiation (AD) and finite difference (FD) 48 | dE = 4.0641751938577116e-07, dE_fd = 0.0, WRONG results!
Please avoid gradients w.r.t self.E 49 | drho[0, 0] = 0.002266954599447443, drho_fd_00 = 0.0022666187078357325 50 | dscale_d = 431.59223609853564, dscale_d_fd = 431.80823609844765 51 | ``` 52 | -------------------------------------------------------------------------------- /demos/inverse/example.py: -------------------------------------------------------------------------------- 1 | # Import some useful modules. 2 | import numpy as onp 3 | import jax 4 | import jax.numpy as np 5 | import os 6 | import glob 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | # Import JAX-FEM specific modules. 11 | from jax_fem.problem import Problem 12 | from jax_fem.solver import solver, ad_wrapper 13 | from jax_fem.utils import save_sol 14 | from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, box_mesh_gmsh 15 | 16 | 17 | # Define constitutive relationship. 18 | class HyperElasticity(Problem): 19 | def custom_init(self): 20 | self.fe = self.fes[0] 21 | 22 | def get_tensor_map(self): 23 | def psi(F, rho): 24 | E = self.E * rho 25 | nu = 0.3 26 | mu = E/(2.*(1. + nu)) 27 | kappa = E/(3.*(1. - 2.*nu)) 28 | J = np.linalg.det(F) 29 | Jinv = J**(-2./3.) 30 | I1 = np.trace(F.T @ F) 31 | energy = (mu/2.)*(Jinv*I1 - 3.) + (kappa/2.) * (J - 1.)**2. 32 | return energy 33 | P_fn = jax.grad(psi) 34 | 35 | def first_PK_stress(u_grad, rho): 36 | I = np.eye(self.dim) 37 | F = u_grad + I 38 | P = P_fn(F, rho) 39 | return P 40 | return first_PK_stress 41 | 42 | def get_surface_maps(self): 43 | def surface_map(u, x): 44 | return np.array([0., 0., 1e3]) 45 | 46 | return [surface_map] 47 | 48 | def set_params(self, params): 49 | E, rho, scale_d = params 50 | self.E = E 51 | self.internal_vars = [rho] 52 | self.fe.dirichlet_bc_info[-1][-1] = get_dirichlet_bottom(scale_d) 53 | self.fe.update_Dirichlet_boundary_conditions(self.fe.dirichlet_bc_info) 54 | 55 | 56 | # Specify mesh-related information (first-order hexahedron element). 
ele_type = 'HEX8'
cell_type = get_meshio_cell_type(ele_type)
data_dir = os.path.join(os.path.dirname(__file__), 'data')
Lx, Ly, Lz = 1., 1., 1.
meshio_mesh = box_mesh_gmsh(Nx=5, Ny=5, Nz=5, Lx=Lx, Ly=Ly, Lz=Lz,
                            data_dir=data_dir, ele_type=ele_type)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])


# Define Dirichlet boundary values.
def get_dirichlet_bottom(scale):
    """Return a Dirichlet value function prescribing the displacement scale*Lz.

    HyperElasticity.set_params calls this to rebuild the last bottom-face
    boundary value whenever the scale parameter changes.
    """
    def dirichlet_bottom(point):
        z_disp = scale*Lz
        return z_disp
    return dirichlet_bottom


def zero_dirichlet_val(point):
    """Homogeneous Dirichlet value (zero displacement component)."""
    return 0.


# Define boundary locations.
def bottom(point):
    """Return True for points on the face z = 0 (absolute tolerance 1e-5)."""
    return np.isclose(point[2], 0., atol=1e-5)


def top(point):
    """Return True for points on the face z = Lz (absolute tolerance 1e-5)."""
    return np.isclose(point[2], Lz, atol=1e-5)


# Bottom face: components 0 and 1 fixed to zero, component 2 prescribed as
# scale*Lz (the last value function is replaced by set_params).
dirichlet_bc_info = [[bottom]*3, [0, 1, 2], [zero_dirichlet_val]*2 + [get_dirichlet_bottom(1.)]]
# Surface-integral (Neumann) location: the top face.
location_fns = [top]


# Create an instance of the problem.
problem = HyperElasticity(mesh, vec=3, dim=3, ele_type=ele_type,
                          dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns)


# Define parameters: scalar E, rho with one value per (cell, quadrature point),
# and the Dirichlet scale factor scale_d.
rho = 0.5*np.ones((problem.fe.num_cells, problem.fe.num_quads))
E = 1.e6
scale_d = 1.
params = [E, rho, scale_d]


# Implicit differentiation wrapper: fwd_pred(params) returns the solution list
# and can be differentiated with jax.grad (see composed_fn below).
fwd_pred = ad_wrapper(problem)
sol_list = fwd_pred(params)

vtk_path = os.path.join(data_dir, 'vtk/u.vtu')
save_sol(problem.fe, sol_list[0], vtk_path)


def test_fn(sol_list):
    """Objective J: sum of squares of every entry of the first solution variable."""
    return np.sum(sol_list[0]**2)


def composed_fn(params):
    """Objective as a function of the parameters, J(u(params))."""
    return test_fn(fwd_pred(params))


val = test_fn(sol_list)

h = 1e-3  # relative perturbation size for the forward differences


# Forward difference: scale one parameter (or one component) by (1 + h) and
# divide the change in the objective by the absolute step h*param.
E_plus = (1 + h)*E
params_E = [E_plus, rho, scale_d]
dE_fd = (composed_fn(params_E) - val)/(h*E)

rho_plus = rho.at[0, 0].set((1 + h)*rho[0, 0])
params_rho = [E, rho_plus, scale_d]
drho_fd_00 = (composed_fn(params_rho) - val)/(h*rho[0, 0])

scale_d_plus = (1 + h)*scale_d
params_scale_d = [E, rho, scale_d_plus]
dscale_d_fd = (composed_fn(params_scale_d) - val)/(h*scale_d)

# Derivative obtained by automatic differentiation
dE, drho, dscale_d = jax.grad(composed_fn)(params)

# Comparison
print("\nDerivative comparison between automatic differentiation (AD) and finite difference (FD)")
print(f"\ndrho[0, 0] = {drho[0, 0]}, drho_fd_00 = {drho_fd_00}")
print(f"\ndscale_d = {dscale_d}, dscale_d_fd = {dscale_d_fd}")

# NOTE: set_params stores E as the Python attribute self.E rather than as a
# traced input, so the E derivatives below are not trustworthy (as printed).
print(f"\ndE = {dE}, dE_fd = {dE_fd}, WRONG results! Please avoid gradients w.r.t self.E")
print("This is due to the use of global variable self.E, inside a jax jitted function.")

# TODO: show the following will cause an error or not?
141 | # dE_E, _, _ = jax.grad(composed_fn)(params_E) 142 | -------------------------------------------------------------------------------- /demos/linear_elasticity/README.md: -------------------------------------------------------------------------------- 1 | # Linear Elasticity 2 | 3 | ## Formulation 4 | 5 | The governing equation for linear elasticity of a body $\Omega$ can be written as 6 | 7 | $$ 8 | \begin{align*} 9 | -\nabla \cdot \boldsymbol{\sigma} = \boldsymbol{b} & \quad \textrm{in} \nobreakspace \nobreakspace \Omega, \\ 10 | \boldsymbol{u} = \boldsymbol{u}_D & \quad\textrm{on} \nobreakspace \nobreakspace \Gamma_D, \\ 11 | \boldsymbol{\sigma} \cdot \boldsymbol{n} = \boldsymbol{t} & \quad \textrm{on} \nobreakspace \nobreakspace \Gamma_N. 12 | \end{align*} 13 | $$ 14 | 15 | The weak form gives 16 | 17 | $$ 18 | \begin{align*} 19 | \int_{\Omega} \boldsymbol{\sigma} : \nabla \boldsymbol{v} \nobreakspace \nobreakspace \textrm{d}x = \int_{\Omega} \boldsymbol{b} \cdot \boldsymbol{v} \nobreakspace \textrm{d}x + \int_{\Gamma_N} \boldsymbol{t} \cdot \boldsymbol{v} \nobreakspace\nobreakspace \textrm{d}s. 20 | \end{align*} 21 | $$ 22 | 23 | In this example, we consider a vertical bending load applied to the right side of the beam ($\boldsymbol{t}=[0, 0, -100]$) while fixing the left side ($\boldsymbol{u}_D=[0,0,0]$), and ignore body force ($\boldsymbol{b}=[0,0,0]$). 24 | 25 | The constitutive relationship is given by 26 | 27 | $$ 28 | \begin{align*} 29 | \boldsymbol{\sigma} &= \lambda \nobreakspace \textrm{tr}(\boldsymbol{\varepsilon}) \boldsymbol{I} + 2\mu \nobreakspace \boldsymbol{\varepsilon}, \\ 30 | \boldsymbol{\varepsilon} &= \frac{1}{2}\left[\nabla\boldsymbol{u} + (\nabla\boldsymbol{u})^{\top}\right]. 31 | \end{align*} 32 | $$ 33 | 34 | ## Execution 35 | Run 36 | ```bash 37 | python -m demos.linear_elasticity.example 38 | ``` 39 | from the `jax-fem/` directory. 40 | 41 | 42 | ## Results 43 | 44 | Visualized with *ParaWiew*: 45 | 46 |
47 |
48 |
50 | Solution 51 |
-------------------------------------------------------------------------------- /demos/linear_elasticity/example.py: -------------------------------------------------------------------------------- 1 | # Import some useful modules. 2 | import jax.numpy as np 3 | import numpy as onp 4 | import os 5 | import pypardiso 6 | import scipy 7 | 8 | # Import JAX-FEM specific modules. 9 | from jax_fem.problem import Problem 10 | from jax_fem.solver import solver 11 | from jax_fem.utils import save_sol 12 | from jax_fem.generate_mesh import box_mesh_gmsh, get_meshio_cell_type, Mesh 13 | from jax_fem import logger 14 | 15 | import logging 16 | logger.setLevel(logging.DEBUG) 17 | 18 | 19 | def pardiso_solver(A, b, x0, solver_options): 20 | """ 21 | Solves Ax=b with x0 being the initial guess. 22 | 23 | A: PETSc sparse matrix 24 | b: JAX array 25 | x0: JAX array (forward problem) or None (adjoint problem) 26 | solver_options: anything the user defines, at least satisfying solver_options['custom_solver'] = pardiso_solver 27 | """ 28 | logger.debug(f"Pardiso Solver - Solving linear system") 29 | 30 | # If you need to convert PETSc to scipy 31 | indptr, indices, data = A.getValuesCSR() 32 | A_sp_scipy = scipy.sparse.csr_array((data, indices, indptr), shape=A.getSize()) 33 | x = pypardiso.spsolve(A_sp_scipy, onp.array(b)) 34 | return x 35 | 36 | 37 | # Material properties. 38 | E = 70e3 39 | nu = 0.3 40 | mu = E/(2.*(1.+nu)) 41 | lmbda = E*nu/((1+nu)*(1-2*nu)) 42 | 43 | 44 | # Weak forms. 45 | class LinearElasticity(Problem): 46 | # The function 'get_tensor_map' overrides base class method. Generally, JAX-FEM 47 | # solves -div(f(u_grad)) = b. Here, we have f(u_grad) = sigma. 
48 | def get_tensor_map(self): 49 | def stress(u_grad): 50 | epsilon = 0.5 * (u_grad + u_grad.T) 51 | sigma = lmbda * np.trace(epsilon) * np.eye(self.dim) + 2*mu*epsilon 52 | return sigma 53 | return stress 54 | 55 | def get_surface_maps(self): 56 | def surface_map(u, x): 57 | return np.array([0., 0., 100.]) 58 | return [surface_map] 59 | 60 | 61 | # Specify mesh-related information (second-order tetrahedron element). 62 | ele_type = 'TET10' 63 | cell_type = get_meshio_cell_type(ele_type) 64 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 65 | Lx, Ly, Lz = 10., 2., 2. 66 | Nx, Ny, Nz = 25, 5, 5 67 | meshio_mesh = box_mesh_gmsh(Nx=Nx, 68 | Ny=Ny, 69 | Nz=Nz, 70 | Lx=Lx, 71 | Ly=Ly, 72 | Lz=Lz, 73 | data_dir=data_dir, 74 | ele_type=ele_type) 75 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 76 | 77 | 78 | # Define boundary locations. 79 | def left(point): 80 | return np.isclose(point[0], 0., atol=1e-5) 81 | 82 | def right(point): 83 | return np.isclose(point[0], Lx, atol=1e-5) 84 | 85 | 86 | # Define Dirichlet boundary values. 87 | # This means on the 'left' side, we apply the function 'zero_dirichlet_val' 88 | # to all components of the displacement variable u. 89 | def zero_dirichlet_val(point): 90 | return 0. 91 | 92 | dirichlet_bc_info = [[left] * 3, [0, 1, 2], [zero_dirichlet_val] * 3] 93 | 94 | 95 | # Define Neumann boundary locations. 96 | # This means on the 'right' side, we will perform the surface integral to get 97 | # the tractions with the function 'get_surface_maps' defined in the class 'LinearElasticity'. 98 | location_fns = [right] 99 | 100 | 101 | # Create an instance of the problem. 102 | problem = LinearElasticity(mesh, 103 | vec=3, 104 | dim=3, 105 | ele_type=ele_type, 106 | dirichlet_bc_info=dirichlet_bc_info, 107 | location_fns=location_fns) 108 | 109 | # Solve the defined problem. 
sol_list = solver(problem, solver_options={'custom_solver': pardiso_solver})
# sol_list = solver(problem, solver_options={'umfpack_solver': {}})

# Postprocess for stress evaluations.
# Displacement gradient at the quadrature points: (num_cells, num_quads, vec, dim)
u_grad = problem.fes[0].sol_to_grad(sol_list[0])
# Small-strain tensor: symmetric part of the displacement gradient.
epsilon = 0.5 * (u_grad + u_grad.transpose(0, 1, 3, 2))
# sigma = lmbda*tr(epsilon)*I + 2*mu*epsilon:
# (num_cells, num_quads, 1, 1) * (dim, dim) + (num_cells, num_quads, vec, dim)
# -> (num_cells, num_quads, vec, dim)
sigma = lmbda * np.trace(epsilon, axis1=2, axis2=3)[:, :, None, None] * np.eye(problem.dim) + 2*mu*epsilon
# (num_cells, num_quads); index 0 on axis 1 presumably selects the first (only)
# variable -- confirm against Problem.JxW.
cells_JxW = problem.JxW[:, 0, :]
# JxW-weighted average of sigma over each cell's quadrature points:
# (num_cells, num_quads, vec, dim) * (num_cells, num_quads, 1, 1) ->
# (num_cells, vec, dim) / (num_cells, 1, 1)
# --> (num_cells, vec, dim)
sigma_average = np.sum(sigma * cells_JxW[:, :, None, None], axis=1) / np.sum(cells_JxW, axis=1)[:, None, None]

# Von Mises stress of the cell-averaged stress.
# Deviatoric part: (num_cells, dim, dim)
s_dev = (sigma_average - 1/problem.dim * np.trace(sigma_average, axis1=1, axis2=2)[:, None, None]
         * np.eye(problem.dim)[None, :, :])
# sqrt(3/2 * s_dev : s_dev), one value per cell: (num_cells,)
vm_stress = np.sqrt(3./2.*np.sum(s_dev*s_dev, axis=(1, 2)))

# Store the solution to local file.
135 | vtk_path = os.path.join(data_dir, 'vtk/u.vtu') 136 | save_sol(problem.fes[0], sol_list[0], vtk_path, cell_infos=[('vm_stress', vm_stress)]) 137 | -------------------------------------------------------------------------------- /demos/linear_elasticity/materials/sol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/linear_elasticity/materials/sol.png -------------------------------------------------------------------------------- /demos/phase_field_fracture/animation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jax_fem.utils import make_video 3 | 4 | data_path = os.path.join(os.path.dirname(__file__), 'data') 5 | make_video(data_path) -------------------------------------------------------------------------------- /demos/phase_field_fracture/eigen.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import os 4 | 5 | jax.config.update("jax_enable_x64", True) 6 | 7 | np.set_printoptions(precision=10, suppress=True) 8 | 9 | 10 | def get_eigen_f_jax(fn): 11 | fn_vmap = jax.vmap(fn) 12 | def eigen_f_jax(x): 13 | evals, evecs = np.linalg.eigh(x) 14 | evecs = evecs.T 15 | M = np.einsum('bi,bj->bij', evecs, evecs) 16 | # [batch, dim, dim] * [batch, 1, 1] -> [dim, dim] 17 | result = np.sum(M * fn_vmap(evals)[:, None, None], axis=0) 18 | return result 19 | return eigen_f_jax 20 | 21 | 22 | def get_eigen_f_custom(fn): 23 | grad_fn = jax.grad(fn) 24 | fn_vmap = jax.vmap(fn) 25 | grad_fn_vmap = jax.vmap(grad_fn) 26 | 27 | @jax.custom_jvp 28 | def eigen_f(x): 29 | evals, evecs = np.linalg.eigh(x) 30 | evecs = evecs.T 31 | M = np.einsum('bi,bj->bij', evecs, evecs) 32 | # [batch, dim, dim] * [batch, 1, 1] -> [dim, dim] 33 | result = np.sum(M * fn_vmap(evals)[:, None, None], axis=0) 34 | return result 
35 | 36 | @eigen_f.defjvp 37 | def f_jvp(primals, tangents): 38 | """Impelemtation of Miehe's paper (https://doi.org/10.1002/cnm.404) Eq. (19) 39 | """ 40 | x, = primals 41 | v, = tangents 42 | 43 | evals, evecs = np.linalg.eigh(x) 44 | fvals = fn_vmap(evals) 45 | grads = grad_fn_vmap(evals) 46 | evecs = evecs.T 47 | 48 | M = np.einsum('bi,bj->bij', evecs, evecs) 49 | 50 | result = np.sum(M * fvals[:, None, None], axis=0) 51 | 52 | MM = np.einsum('bij,bkl->bijkl', M, M) 53 | # [batch, dim, dim, dim, dim] * [batch, 1, 1, 1, 1] -> [dim, dim, dim, dim] 54 | term1 = np.sum(MM * grads[:, None, None, None, None], axis=0) 55 | 56 | G = np.einsum('aik,bjl->abijkl', M, M) + np.einsum('ail,bjk->abijkl', M, M) 57 | 58 | diff_evals = evals[:, None] - evals[None, :] 59 | diff_fvals = fvals[:, None] - fvals[None, :] 60 | diff_grads = grads[:, None] 61 | 62 | theta = np.where(diff_evals == 0., diff_grads, diff_fvals/diff_evals) 63 | 64 | tmp = G * theta[:, :, None, None, None, None] 65 | tmp1 = np.sum(tmp, axis=(0, 1)) 66 | tmp2 = np.einsum('aa...->...', tmp) 67 | term2 = 0.5*(tmp1 - tmp2) 68 | 69 | P = term1 + term2 70 | jvp_result = np.einsum('ijkl,kl', P, v) 71 | 72 | return result, jvp_result 73 | 74 | return eigen_f 75 | 76 | 77 | def test_eigen_f(): 78 | """When repeated eigenvalues occur, JAX fail to find out the correct derivative (returning NaN) 79 | See the discussion: 80 | https://github.com/google/jax/issues/669 81 | Also see the basic derivation for the case of distinct eigenvalues: 82 | https://mathoverflow.net/questions/229425/derivative-of-eigenvectors-of-a-matrix-with-respect-to-its-components 83 | 84 | Here, we followed Miehe's approach to provide a solution to repeated eigenvalue case: 85 | https://doi.org/10.1002/cnm.404 86 | 87 | You will see how get_eigen_f_jax works for the variables a and b, but fails for c. 88 | Yet, our implementation of get_eigen_f_custom works for all the variables a, b and c. 
89 | """ 90 | a = np.array([[1., -2., 3.], 91 | [-2., 5., 7.], 92 | [3., 7., 10.]]) 93 | 94 | key = jax.random.PRNGKey(0) 95 | b = jax.random.uniform(key, shape=(5, 5), minval=-0.1, maxval=0.1) 96 | b = 0.5*(b + b.T) 97 | 98 | c = np.zeros((3, 3)) 99 | 100 | fn = lambda x: x 101 | 102 | input_vars = [a, b, c] 103 | 104 | eigen_f_jax = get_eigen_f_jax(fn) 105 | eigen_f_custom = get_eigen_f_custom(fn) 106 | 107 | for x in input_vars: 108 | jax_result = jax.jacfwd(eigen_f_jax)(x) 109 | custom_results = jax.jacfwd(eigen_f_custom)(x) 110 | print(f"\nJAX:\n{jax_result}") 111 | print(f"\nCustom:\n{custom_results}") 112 | print(f"\nDiff:\n{jax_result - custom_results}") 113 | 114 | 115 | def f1(x): 116 | unsafe_plus = lambda x: np.maximum(x, 0.) 117 | unsafe_minus = lambda x: np.minimum(x, 0.) 118 | tr_x_plus = unsafe_plus(np.trace(x)) 119 | tr_x_minus = unsafe_minus(np.trace(x)) 120 | return 0.5*tr_x_plus**2 + 0.5*tr_x_minus**2 121 | 122 | def f2(x): 123 | safe_plus = lambda x: 0.5*(x + np.abs(x)) 124 | safe_minus = lambda x: 0.5*(x - np.abs(x)) 125 | tr_x_plus = safe_plus(np.trace(x)) 126 | tr_x_minus = safe_minus(np.trace(x)) 127 | return 0.5*tr_x_plus**2 + 0.5*tr_x_minus**2 128 | 129 | def f_gold(x): 130 | tr_x = np.trace(x) 131 | return 0.5*tr_x**2 132 | 133 | def test_bracket_operator(): 134 | # Different behaviors observed when derivative is taken at x=0. 
135 | # The "abs" way of implemetation is preferred 136 | print(f"{jax.grad(lambda x: np.maximum(x, 0.))(0.)}") 137 | print(f"{jax.grad(lambda x: 0.5*(x + np.abs(x)))(0.)}") 138 | 139 | # Further tests 140 | a = np.zeros((3, 3)) 141 | # f1 gives wrong answer 142 | print(f"\nUnsafe:\n{jax.jacfwd(jax.grad(f1))(a)}") 143 | # f2 gieves correct answer 144 | print(f"\nSafe:\n{jax.jacfwd(jax.grad(f2))(a)}") 145 | # f_gold is the ground truth 146 | print(f"\nGround truth:\n{jax.jacfwd(jax.grad(f_gold))(a)}") 147 | 148 | 149 | if __name__ == "__main__": 150 | test_eigen_f() 151 | test_bracket_operator() 152 | -------------------------------------------------------------------------------- /demos/phase_field_fracture/materials/disp_force.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/phase_field_fracture/materials/disp_force.png -------------------------------------------------------------------------------- /demos/phase_field_fracture/materials/fracture.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/phase_field_fracture/materials/fracture.gif -------------------------------------------------------------------------------- /demos/phase_field_fracture/materials/time_disp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/phase_field_fracture/materials/time_disp.png -------------------------------------------------------------------------------- /demos/plasticity/README.md: -------------------------------------------------------------------------------- 1 | # Plasticity 2 | 3 | ## Formulation 4 | 5 | For perfect J2-plasticity model [1], we assume that the total 
strain $\boldsymbol{\varepsilon}^{n-1}$ and stress $\boldsymbol{\sigma}^{n-1}$ from the previous loading step are known, and the problem is to find the displacement field $\boldsymbol{u}^n$ at the current loading step such that 6 | 7 | $$ 8 | \begin{align*} 9 | -\nabla \cdot \big(\boldsymbol{\sigma}^n (\nabla \boldsymbol{u}^n, \boldsymbol{\varepsilon}^{n-1}, \boldsymbol{\sigma}^{n-1}) \big) = \boldsymbol{b} & \quad \textrm{in} \nobreakspace \nobreakspace \Omega, \nonumber \\ 10 | \boldsymbol{u}^n = \boldsymbol{u}_D & \quad\textrm{on} \nobreakspace \nobreakspace \Gamma_D, \nonumber \\ 11 | \boldsymbol{\sigma}^n \cdot \boldsymbol{n} = \boldsymbol{t} & \quad \textrm{on} \nobreakspace \nobreakspace \Gamma_N. 12 | \end{align*} 13 | $$ 14 | 15 | The stress $\boldsymbol{\sigma}^n$ is defined by the following relations: 16 | 17 | ```math 18 | \begin{align*} 19 | \boldsymbol{\sigma}_\textrm{trial} &= \boldsymbol{\sigma}^{n-1} + \Delta \boldsymbol{\sigma}, \nonumber\\ 20 | \Delta \boldsymbol{\sigma} &= \lambda \nobreakspace \textrm{tr}(\Delta \boldsymbol{\varepsilon}) \boldsymbol{I} + 2\mu \nobreakspace \Delta \boldsymbol{\varepsilon}, \nonumber \\ 21 | \Delta \boldsymbol{\varepsilon} &= \boldsymbol{\varepsilon}^n - \boldsymbol{\varepsilon}^{n-1} = \frac{1}{2}\left[\nabla\boldsymbol{u}^n + (\nabla\boldsymbol{u}^n)^{\top}\right] - \boldsymbol{\varepsilon}^{n-1}, \nonumber\\ 22 | \boldsymbol{s} &= \boldsymbol{\sigma}_\textrm{trial} - \frac{1}{3}\textrm{tr}(\boldsymbol{\sigma}_\textrm{trial})\boldsymbol{I},\nonumber\\ 23 | s &= \sqrt{\frac{3}{2}\boldsymbol{s}:\boldsymbol{s}}, \nonumber\\ 24 | f_{\textrm{yield}} &= s - \sigma_{\textrm{yield}}, \nonumber\\ 25 | \boldsymbol{\sigma}^n &= \boldsymbol{\sigma}_\textrm{trial} - \frac{\boldsymbol{s}}{s} \langle f_{\textrm{yield}} \rangle_{+}, \nonumber 26 | \end{align*} 27 | ``` 28 | 29 | where $`\boldsymbol{\sigma}_\textrm{trial}`$ is the elastic trial stress, $`\boldsymbol{s}`$ is the deviatoric part of
$`\boldsymbol{\sigma}_\textrm{trial}`$, $`f_{\textrm{yield}}`$ is the yield function, $`\sigma_{\textrm{yield}}`$ is the yield strength, $`{\langle x \rangle_{+}}:=\frac{1}{2}(x+|x|)`$ is the ramp function, and $`\boldsymbol{\sigma}^n`$ is the stress at the currently loading step. 30 | 31 | 32 | The weak form gives 33 | 34 | $$ 35 | \begin{align*} 36 | \int_{\Omega} \boldsymbol{\sigma}^n : \nabla \boldsymbol{v} \nobreakspace \nobreakspace \textrm{d}x = \int_{\Omega} \boldsymbol{b} \cdot \boldsymbol{v} \nobreakspace \textrm{d}x + \int_{\Gamma_N} \boldsymbol{t} \cdot \boldsymbol{v} \nobreakspace\nobreakspace \textrm{d}s. 37 | \end{align*} 38 | $$ 39 | 40 | In this example, we consider a displacement-controlled uniaxial tensile loading condition. We assume free traction ($\boldsymbol{t}=[0, 0, 0]$) and ignore body force ($\boldsymbol{b}=[0,0,0]$). We assume quasi-static loadings from 0 to 0.1 mm and then unload from 0.1 mm to 0. 41 | 42 | 43 | > :ghost: A remarkable feature of *JAX-FEM* is that automatic differentiation is used to enhance the development efficiency. In this example, deriving the fourth-order elastoplastic tangent moduli tensor $\mathbb{C}=\frac{\partial \boldsymbol{\sigma}^n}{\partial \boldsymbol{\varepsilon}^n}$ is usually required by traditional FEM implementation, but is **NOT** needed in our program due to automatic differentiation. 44 | 45 | 46 | ## Execution 47 | Run 48 | ```bash 49 | python -m demos.plasticity.example 50 | ``` 51 | from the `jax-fem/` directory. 52 | 53 | 54 | ## Results 55 | 56 | Results can be visualized with *ParaWiew*. 57 | 58 |
59 |
60 |
62 | Deformation (x50) 63 |
64 | 65 | Plot of the $`z-z`$ component of volume-averaged stress versus displacement of the top surface: 66 | 67 | 68 |
69 |
70 |
72 | Stress-strain curve 73 |
74 | 75 | ## References 76 | 77 | [1] Simo, Juan C., and Thomas JR Hughes. *Computational inelasticity*. Vol. 7. Springer Science & Business Media, 2006. -------------------------------------------------------------------------------- /demos/plasticity/materials/sol.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/plasticity/materials/sol.gif -------------------------------------------------------------------------------- /demos/plasticity/materials/stress_strain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/plasticity/materials/stress_strain.png -------------------------------------------------------------------------------- /demos/poisson/README.md: -------------------------------------------------------------------------------- 1 | # Poisson's Equation 2 | 3 | ## Formulation 4 | 5 | The Poisson's equation is the canonical elliptic partial differential equation. Consider a domain $\Omega \subset \mathbb{R}^\textrm{d}$ with boundary $\partial \Omega = \Gamma_D \cup \Gamma_N$, the strong form gives 6 | 7 | $$ 8 | \begin{align*} 9 | -\nabla^2 u = b & \quad \textrm{in} \nobreakspace \nobreakspace \Omega, \\ 10 | u = 0 & \quad\textrm{on} \nobreakspace \nobreakspace \Gamma_D, \\ 11 | \nabla u \cdot \boldsymbol{n} = t & \quad \textrm{on} \nobreakspace \nobreakspace \Gamma_N. 12 | \end{align*} 13 | $$ 14 | 15 | The weak form gives 16 | 17 | $$ 18 | \begin{align*} 19 | \int_{\Omega} \nabla u \cdot \nabla v \nobreakspace \nobreakspace \textrm{d}x = \int_{\Omega} b \nobreakspace v \nobreakspace \textrm{d}x + \int_{\Gamma_N} t\nobreakspace v \nobreakspace\nobreakspace \textrm{d}s. 
20 | \end{align*} 21 | $$ 22 | 23 | We have the following definitions: 24 | * $\Omega=(0,1)\times(0,1)$ (a unit square) 25 | * $\Gamma_D=\{(0, x_2)\cup (1, x_2)\subset\partial\Omega\}$ (Dirichlet boundary) 26 | * $\Gamma_N=\{(x_1, 0)\cup (x_1, 1)\subset\partial\Omega\}$ (Neumann boundary) 27 | * $b=10\nobreakspace\textrm{exp}\big(-((x_1-0.5)^2+(x_2-0.5)^2)/0.02 \big)$ 28 | * $t=\textrm{sin}(5x_1)$ 29 | 30 | ## Execution 31 | Run 32 | ```bash 33 | python -m demos.poisson.example 34 | ``` 35 | from the `jax-fem/` directory. 36 | 37 | 38 | ## Results 39 | 40 |
41 |
42 |
44 | Solution 45 |
46 | 47 | 48 | ## References 49 | 50 | [1] https://fenicsproject.org/olddocs/dolfin/1.3.0/python/demo/documented/poisson/python/documentation.html 51 | 52 | [2] Xue, Tianju, et al. "JAX-FEM: A differentiable GPU-accelerated 3D finite element solver for automatic inverse design and mechanistic data science." *Computer Physics Communications* (2023): 108802. -------------------------------------------------------------------------------- /demos/poisson/example.py: -------------------------------------------------------------------------------- 1 | # Import some generally useful packages. 2 | import jax 3 | import jax.numpy as np 4 | import os 5 | 6 | 7 | # Import JAX-FEM specific modules. 8 | from jax_fem.problem import Problem 9 | from jax_fem.solver import solver 10 | from jax_fem.utils import save_sol 11 | from jax_fem.generate_mesh import get_meshio_cell_type, Mesh, rectangle_mesh 12 | 13 | 14 | # Define constitutive relationship. 15 | class Poisson(Problem): 16 | # The function 'get_tensor_map' overrides base class method. Generally, JAX-FEM 17 | # solves -div.f(u_grad) = b. Here, we define f to be the indentity function. 18 | # We will see how f is deined as more complicated to solve non-linear problems 19 | # in later examples. 20 | def get_tensor_map(self): 21 | return lambda x: x 22 | 23 | # Define the source term b 24 | def get_mass_map(self): 25 | def mass_map(u, x): 26 | val = -np.array([10*np.exp(-(np.power(x[0] - 0.5, 2) + np.power(x[1] - 0.5, 2)) / 0.02)]) 27 | return val 28 | return mass_map 29 | 30 | def get_surface_maps(self): 31 | def surface_map(u, x): 32 | return -np.array([np.sin(5.*x[0])]) 33 | 34 | return [surface_map, surface_map] 35 | 36 | 37 | # Specify mesh-related information. 38 | # We make use of the external package 'meshio' and create a mesh named 'meshio_mesh', 39 | # then converting it into a JAX-FEM compatible one. 40 | ele_type = 'QUAD4' 41 | cell_type = get_meshio_cell_type(ele_type) 42 | Lx, Ly = 1., 1. 
43 | meshio_mesh = rectangle_mesh(Nx=32, Ny=32, domain_x=Lx, domain_y=Ly) 44 | mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) 45 | 46 | 47 | # Define boundary locations. Each function returns True for points on one edge of the [0, Lx] x [0, Ly] rectangle (absolute tolerance 1e-5). 48 | def left(point): 49 | return np.isclose(point[0], 0., atol=1e-5) 50 | 51 | def right(point): 52 | return np.isclose(point[0], Lx, atol=1e-5) 53 | 54 | def bottom(point): 55 | return np.isclose(point[1], 0., atol=1e-5) 56 | 57 | def top(point): 58 | return np.isclose(point[1], Ly, atol=1e-5) 59 | 60 | 61 | # Define Dirichlet boundary values. 62 | # This means on the 'left' side, we apply the function 'dirichlet_val_left' 63 | # to the 0 component of the solution variable; on the 'right' side, we apply 64 | # 'dirichlet_val_right' to the 0 component. 65 | def dirichlet_val_left(point): 66 | return 0. 67 | 68 | def dirichlet_val_right(point): 69 | return 0. 70 | 71 | location_fns_dirichlet = [left, right] 72 | value_fns = [dirichlet_val_left, dirichlet_val_right] 73 | vecs = [0, 0] 74 | dirichlet_bc_info = [location_fns_dirichlet, vecs, value_fns] 75 | 76 | 77 | # Define Neumann boundary locations. 78 | # This means on the 'bottom' and 'top' sides, we will perform the surface integral 79 | # with the maps returned by 'get_surface_maps' in the class 'Poisson' (presumably paired by position with location_fns; confirm in jax_fem.problem). 80 | location_fns = [bottom, top] 81 | 82 | 83 | # Create an instance of the class 'Poisson'. 84 | # Here, vec is the number of components for the solution. 85 | problem = Poisson(mesh=mesh, vec=1, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns) 86 | 87 | 88 | # Solve the problem. 89 | # Pass solver_options to select a different linear solver (see the commented-out examples below). 90 | sol = solver(problem) 91 | # sol = solver(problem, solver_options={'umfpack_solver': {}}) 92 | # sol = solver(problem, solver_options={'petsc_solver': {'ksp_type': 'bcgsl', 'pc_type': 'ilu'}}) 93 | 94 | # Save the solution to a local folder that can be visualized with ParaView.
95 | data_dir = os.path.join(os.path.dirname(__file__), 'data') 96 | vtk_path = os.path.join(data_dir, f'vtk/u.vtu') 97 | save_sol(problem.fes[0], sol[0], vtk_path) 98 | -------------------------------------------------------------------------------- /demos/poisson/materials/sol.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/poisson/materials/sol.png -------------------------------------------------------------------------------- /demos/thermal_mechanical/animation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jax_fem.utils import make_video 3 | 4 | data_path = os.path.join(os.path.dirname(__file__), 'data') 5 | make_video(data_path) -------------------------------------------------------------------------------- /demos/thermal_mechanical/materials/T.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical/materials/T.gif -------------------------------------------------------------------------------- /demos/thermal_mechanical/materials/line.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical/materials/line.gif -------------------------------------------------------------------------------- /demos/thermal_mechanical/materials/phase.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical/materials/phase.gif -------------------------------------------------------------------------------- 
/demos/thermal_mechanical/materials/value.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical/materials/value.gif -------------------------------------------------------------------------------- /demos/thermal_mechanical_full/fenics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dolfin import * 4 | from mshr import * 5 | import numpy as onp 6 | import matplotlib.pyplot as plt 7 | 8 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 9 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 10 | 11 | # Define domain and mesh 12 | L = 1. 13 | R = 0.1 14 | N = 50 # mesh density 15 | 16 | domain = Rectangle(Point(0., 0.), Point(L, L)) - Circle(Point(0., 0.), R, 100) 17 | mesh = generate_mesh(domain, N) 18 | 19 | # Define parameters 20 | T0 = Constant(293.) # ambient temperature 21 | DThole = Constant(10.) # temperature change at hole boundary 22 | E = 70e3 23 | nu = 0.3 24 | lmbda = Constant(E*nu/((1+nu)*(1-2*nu))) 25 | mu = Constant(E/2/(1+nu)) 26 | rho = Constant(2700.) 
# density 27 | alpha = 2.31e-5 # thermal expansion coefficient 28 | kappa = Constant(alpha*(2*mu + 3*lmbda)) 29 | cV = Constant(910e-6)*rho # specific heat per unit volume at constant strain 30 | k = Constant(237e-6) # thermal conductivity 31 | 32 | # Build function space 33 | Vue = VectorElement('CG', mesh.ufl_cell(), 1) # displacement finite element 34 | Vte = FiniteElement('CG', mesh.ufl_cell(), 1) # temperature finite element 35 | V = FunctionSpace(mesh, MixedElement([Vue, Vte])) 36 | 37 | # Boundary condition 38 | def inner_boundary(x, on_boundary): 39 | return near(x[0]**2+x[1]**2, R**2, 1e-3) and on_boundary 40 | def bottom(x, on_boundary): 41 | return near(x[1], 0) and on_boundary 42 | def left(x, on_boundary): 43 | return near(x[0], 0) and on_boundary 44 | 45 | bc1 = DirichletBC(V.sub(0).sub(1), Constant(0.), bottom) 46 | bc2 = DirichletBC(V.sub(0).sub(0), Constant(0.), left) 47 | bc3 = DirichletBC(V.sub(1), DThole, inner_boundary) 48 | bcs = [bc1, bc2, bc3] 49 | 50 | # Define variational problem 51 | U_ = TestFunction(V) 52 | (u_, Theta_) = split(U_) 53 | dU = TrialFunction(V) 54 | (du, dTheta) = split(dU) 55 | Uold = Function(V) 56 | (uold, Thetaold) = split(Uold) 57 | 58 | def eps(v): 59 | return sym(grad(v)) 60 | 61 | def sigma(v, Theta): 62 | return (lmbda*tr(eps(v)) - kappa*Theta)*Identity(2) + 2*mu*eps(v) 63 | 64 | dt = Constant(0.) 
65 | mech_form = inner(sigma(du, dTheta), eps(u_))*dx 66 | therm_form = (cV*(dTheta-Thetaold)/dt*Theta_ + 67 | kappa*T0*tr(eps(du-uold))/dt*Theta_ + 68 | dot(k*grad(dTheta), grad(Theta_)))*dx 69 | form = mech_form + therm_form 70 | 71 | # Compute solution 72 | Nincr = 200 73 | t = onp.logspace(1, 4, Nincr+1) 74 | U = Function(V) 75 | for (i, dti) in enumerate(onp.diff(t)): 76 | print("Increment " + str(i+1)) 77 | dt.assign(dti) 78 | solve(lhs(form) == rhs(form), U, bcs) 79 | Uold.assign(U) 80 | 81 | (u, theta) = U.split(True) 82 | 83 | # Save solution in VTK format 84 | # ufile_pvd = File(os.path.join(output_dir, "vtk/fenics_u.pvd")) 85 | # u.rename("u", "u") 86 | # ufile_pvd << u 87 | # tfile_pvd = File(os.path.join(output_dir, "vtk/fenics_theta.pvd")) 88 | # theta.rename("theta", "theta") 89 | # tfile_pvd << theta 90 | 91 | print(f"Max u = {onp.max(u.vector()[:])}, Min u = {onp.min(u.vector()[:])}") 92 | print(f"Max theta = {onp.max(theta.vector()[:])}, Min p = {onp.min(theta.vector()[:])}") 93 | 94 | # Save points and cells for the use of JAX-FEM 95 | # Build function space 96 | V = FiniteElement("Lagrange", mesh.ufl_cell(), 1) 97 | V_fs = FunctionSpace(mesh, V) 98 | 99 | points = V_fs.tabulate_dof_coordinates() 100 | print(f"points.shape = {points.shape}") 101 | 102 | cells_v = [] 103 | dofmap_v = V_fs.dofmap() 104 | for cell in cells(mesh): 105 | dof_index = dofmap_v.cell_dofs(cell.index()) 106 | # print(cell.index(), dof_index) 107 | cells_v.append(dof_index) 108 | cells = onp.stack(cells_v) 109 | print(f"cells.shape = {cells.shape}") 110 | 111 | numpy_dir = os.path.join(input_dir, f'numpy/') 112 | if not os.path.exists(numpy_dir): os.makedirs(numpy_dir) 113 | onp.save(os.path.join(input_dir, f'numpy/points.npy'), points) 114 | onp.save(os.path.join(input_dir, f'numpy/cells.npy'), cells) -------------------------------------------------------------------------------- /demos/thermal_mechanical_full/input/numpy/cells.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical_full/input/numpy/cells.npy -------------------------------------------------------------------------------- /demos/thermal_mechanical_full/input/numpy/points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical_full/input/numpy/points.npy -------------------------------------------------------------------------------- /demos/thermal_mechanical_full/material/theta.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical_full/material/theta.gif -------------------------------------------------------------------------------- /demos/thermal_mechanical_full/material/uy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/thermal_mechanical_full/material/uy.gif -------------------------------------------------------------------------------- /demos/topology_optimization/animation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jax_fem.utils import make_video 3 | 4 | data_path = os.path.join(os.path.dirname(__file__), 'data') 5 | make_video(data_path) -------------------------------------------------------------------------------- /demos/topology_optimization/materials/obj_val.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/topology_optimization/materials/obj_val.png -------------------------------------------------------------------------------- /demos/topology_optimization/materials/to.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/topology_optimization/materials/to.gif -------------------------------------------------------------------------------- /demos/wave/README.md: -------------------------------------------------------------------------------- 1 | # Wave equation 2 | 3 | ## Formulation 4 | 5 | ### Governing Equations 6 | 7 | We consider the scalar wave equation in a domain $\Omega\subset\mathbb{R}^d$ with boundary $\partial\Omega =\Gamma_D\cup\Gamma_N$, the strong form gives: 8 | 9 | $$\begin{align*}\frac{1}{c^2}\frac{\partial^2u}{\partial t^2}&=\nabla^2u+q& &\textrm{in} \nobreakspace \nobreakspace \Omega \times(0, t_f],\\ 10 | u &= u_0 & &\textrm{at} \nobreakspace \nobreakspace t=0, \\ 11 | u&=u_D & &\textrm{on} \nobreakspace \nobreakspace \Gamma_{D} \times (0,t_f], \\ 12 | \nabla u \cdot \boldsymbol{n} &= t && \textrm{on} \nobreakspace \nobreakspace \Gamma_N \times (0,t_f].\end{align*}$$ 13 | 14 | where $u$ is the unknown pressure field, $c$ the speed of wave, and $q$ the source term. 
15 | 16 | We have the following definitions: 17 | 18 | * $\Omega=(0,1)\times(0,1)$ (a unit square) 19 | * $\Gamma_D=\partial\Omega$ 20 | * $q = 0$ 21 | * $t = 0$ 22 | 23 | ### Discretization in Time 24 | 25 | We first approximate the second-order time derivative with the backward difference scheme: 26 | 27 | $$\displaystyle\frac{\partial^2 u}{\partial t^2}\approx\displaystyle\frac{u^{n}-2u^{n-1}+u^{n-2}}{\Delta t^2}$$ 28 | 29 | The governing equation at time step $n$ for the pressure field can be stated as: 30 | 31 | $$\begin{align*}\frac{1}{c^2}\displaystyle\frac{u^{n}-2u^{n-1}+u^{n-2}}{\Delta t^2}&=\nabla^2u^n& &\textrm{in} \nobreakspace \nobreakspace \Omega,\\ 32 | u&=u_D & &\textrm{on} \nobreakspace \nobreakspace \partial\Omega. \\ 33 | \end{align*}$$ 34 | 35 | ### Weak form 36 | The weak form for $u^n$ is the following: 37 | 38 | $$\int_{\Omega}(c^2\Delta t^2\nabla u^n\cdot\nabla \delta u+u^n\delta u)dx-\int_{\Omega}(2u^{n-1}-u^{n-2})\delta udx=0$$ 39 | 40 | 41 | ## Execution 42 | Run 43 | ```bash 44 | python -m demos.wave.example 45 | ``` 46 | from the `jax-fem/` directory. 47 | 48 | 49 | ## Results 50 | 51 | Results can be visualized with *ParaWiew*. 52 |
53 |
54 |
56 | Wave: pressure 57 |
58 | 59 | -------------------------------------------------------------------------------- /demos/wave/fenics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dolfin import * 4 | from mshr import * 5 | import matplotlib.pyplot as plt 6 | import numpy as onp 7 | 8 | input_dir = os.path.join(os.path.dirname(__file__), 'input') 9 | output_dir = os.path.join(os.path.dirname(__file__), 'output') 10 | 11 | # Define domain and mesh 12 | Lx , Ly = 1., 1. 13 | Nx, Ny = 100, 100 14 | mesh = RectangleMesh(Point(0, 0), Point(1,1), Nx, Ny, 'left') 15 | plot(mesh) 16 | # Define parameters 17 | dt = 1/250000 # temporal sampling interval 18 | c = 5000 # speed of sound 19 | steps = 200 20 | 21 | # Build function space 22 | V = FunctionSpace(mesh, "Lagrange", 1) 23 | 24 | # Define boundary conditions 25 | bcs = DirichletBC(V, Constant(1.), "on_boundary") # Pure Dirichlet boundary conditions 26 | 27 | # Define variational problem 28 | u0 = interpolate(Constant(0.0), V) # u_old_2dt 29 | u1 = interpolate(Constant(0.0), V) # u_old_dt 30 | 31 | u = TrialFunction(V) 32 | v = TestFunction(V) 33 | 34 | a = inner(u, v) * dx + Constant(dt**2 * c**2) * inner(grad(u), grad(v)) * dx 35 | L = (2*u1 - u0) * v * dx 36 | 37 | # Compute solution 38 | u = Function(V) 39 | for n in range(steps): 40 | solve(a==L, u, bcs) 41 | u0.assign(u1) 42 | u1.assign(u) 43 | 44 | print(f"Max u = {onp.max(u.vector()[:])}, Min u = {onp.min(u.vector()[:])}") 45 | 46 | # Save points and cells for the use of JAX-FEM 47 | # Build function space 48 | V = FiniteElement("Lagrange", mesh.ufl_cell(), 1) 49 | V_fs = FunctionSpace(mesh, V) 50 | 51 | points = V_fs.tabulate_dof_coordinates() 52 | print(f"points.shape = {points.shape}") 53 | 54 | cells_v = [] 55 | dofmap_v = V_fs.dofmap() 56 | for cell in cells(mesh): 57 | dof_index = dofmap_v.cell_dofs(cell.index()) 58 | # print(cell.index(), dof_index) 59 | cells_v.append(dof_index) 60 | cells = onp.stack(cells_v) 61 
| print(f"cells.shape = {cells.shape}") 62 | 63 | numpy_dir = os.path.join(input_dir, f'numpy/') 64 | if not os.path.exists(numpy_dir): os.makedirs(numpy_dir) 65 | onp.save(os.path.join(input_dir, f'numpy/points.npy'), points) 66 | onp.save(os.path.join(input_dir, f'numpy/cells.npy'), cells) -------------------------------------------------------------------------------- /demos/wave/input/numpy/cells.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/wave/input/numpy/cells.npy -------------------------------------------------------------------------------- /demos/wave/input/numpy/points.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/wave/input/numpy/points.npy -------------------------------------------------------------------------------- /demos/wave/material/pressure.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/wave/material/pressure.gif -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: jax-fem-env 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python==3.9.18 7 | - numpy==1.24.4 8 | - scipy==1.11.3 9 | - matplotlib==3.8.0 10 | - meshio==5.3.4 11 | - petsc4py==3.20.0 12 | - pip==23.2.1 13 | - fenics==2019.1.0 14 | - pip: 15 | - setuptools==68.2 16 | - wheel==0.41 17 | - gmsh==4.11.1 18 | - fenics-basix==0.6.0 19 | - pyfiglet==1.0 -------------------------------------------------------------------------------- /images/ded.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/ded.gif -------------------------------------------------------------------------------- /images/poisson.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/poisson.png -------------------------------------------------------------------------------- /images/polycrystal_grain.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/polycrystal_grain.gif -------------------------------------------------------------------------------- /images/polycrystal_stress.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/polycrystal_stress.gif -------------------------------------------------------------------------------- /images/stokes_p.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/stokes_p.png -------------------------------------------------------------------------------- /images/stokes_u.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/stokes_u.png -------------------------------------------------------------------------------- /images/to.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/to.gif -------------------------------------------------------------------------------- /images/von_mises.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepmodeling/jax-fem/1bdbf060bb32951d04ed9848c238c9a470fee1b4/images/von_mises.png -------------------------------------------------------------------------------- /jax_fem/README.md: -------------------------------------------------------------------------------- 1 | # JAX-FEM 2 | 3 | A differentiable 2D/3D finite element solver for automatic inverse design and mechanistic data science. This is the source code directory. Please visit [tutorial](https://github.com/tianjuxue/jax-fem/tree/main/demos) for more information. 4 | 5 | -------------------------------------------------------------------------------- /jax_fem/__init__.py: -------------------------------------------------------------------------------- 1 | from pyfiglet import Figlet 2 | 3 | f = Figlet(font='starwars') 4 | print(f.renderText('JAX - FEM')) 5 | 6 | from .logger_setup import setup_logger 7 | # LOGGING 8 | logger = setup_logger(__name__) 9 | 10 | # TODO: Be automatic 11 | # __version__ = "0.0.9" -------------------------------------------------------------------------------- /jax_fem/experimental/adjoint_save_to_local.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import os 4 | 5 | jax.config.update("jax_enable_x64", True) 6 | 7 | crt_file_path = os.path.dirname(__file__) 8 | data_dir = os.path.join(crt_file_path, 'data') 9 | numpy_dir = os.path.join(data_dir, 'numpy') 10 | file_path = os.path.join(numpy_dir, 'tmp.npy') 11 | 12 | 13 | def raw_f(x, y): 14 | return np.sin(x) * y 15 | 16 | @jax.custom_vjp 17 | def f(x, y): 18 | return np.sin(x) * y 19 | 20 | def f_fwd(x, y): 21 | 
np.save(file_path, np.cos(x)) 22 | return f(x, y), (np.sin(x), y) 23 | 24 | def f_bwd(res, g): 25 | cos_x = np.load(file_path) 26 | sin_x, y = res 27 | return (cos_x * g * y, sin_x * g) 28 | 29 | f.defvjp(f_fwd, f_bwd) 30 | 31 | print(f(1., 2.)) 32 | print(raw_f(1., 2.)) 33 | 34 | print(jax.grad(f)(1., 2.)) 35 | print(jax.grad(raw_f)(1., 2.)) 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /jax_fem/experimental/custom_jvp.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import os 4 | 5 | 6 | def implicit_residual(x, y): 7 | A = np.diag(np.array([1., 2., 3])) 8 | # Assume that Ay - x = 0 9 | return A @ y - x 10 | 11 | @jax.custom_jvp 12 | def newton_solver(x): 13 | y_0 = np.zeros(3) 14 | step_0 = 0 15 | res_vec_0 = implicit_residual(x, y_0) 16 | tol = 1e-8 17 | 18 | def cond_fun(state): 19 | step, res_vec, y = state 20 | return np.linalg.norm(res_vec) > tol 21 | 22 | def body_fun(state): 23 | step, res_vec, y = state 24 | f_partial = lambda y: implicit_residual(x, y) 25 | jac = jax.jacfwd(f_partial)(y) # Works for small system 26 | y_inc = np.linalg.solve(jac, -res_vec) # Works for small system 27 | res_vec = f_partial(y + y_inc) 28 | step_update = step + 1 29 | return step_update, res_vec, y + y_inc 30 | 31 | step_f, res_vec_f, y_f = jax.lax.while_loop(cond_fun, body_fun, (step_0, res_vec_0, y_0)) 32 | return y_f 33 | 34 | @newton_solver.defjvp 35 | def f_jvp(primals, tangents): 36 | x, = primals 37 | v, = tangents 38 | y = newton_solver(x) 39 | jac_x = jax.jacfwd(implicit_residual, argnums=0)(x, y) # Works for small system 40 | jac_y = jax.jacfwd(implicit_residual, argnums=1)(x, y) # Works for small system 41 | jvp_result = np.linalg.solve(jac_y, -(jac_x @ v[:, None]).reshape(-1)) # Works for small system 42 | return y, jvp_result 43 | 44 | 45 | x = np.ones(3) 46 | y = newton_solver(x) 47 | print(f"\ny = 
{y}") 48 | 49 | jac_y_over_x_fwd = jax.jacfwd(newton_solver)(x) 50 | jac_y_over_x_rev = jax.jacrev(newton_solver)(x) 51 | 52 | print(f"\njac_y_over_x_fwd = \n{jac_y_over_x_fwd}") 53 | print(f"\njac_y_over_x_rev = \n{jac_y_over_x_rev}") -------------------------------------------------------------------------------- /jax_fem/experimental/jit_global.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | import jax 3 | 4 | class A: 5 | def __init__(self): 6 | self.f = self.get_fn() 7 | 8 | def get_fn(self): 9 | @jax.jit 10 | def f(): 11 | return self.E 12 | return f 13 | 14 | def set_params(self, E): 15 | self.E = E 16 | 17 | a = A() 18 | 19 | def test_fn_a(E): 20 | a.set_params(E) 21 | return a.f() 22 | 23 | E_a = 10. 24 | 25 | print(test_fn_a(E_a)) 26 | print(jax.grad(test_fn_a)(E_a)) 27 | 28 | 29 | class B: 30 | def f(self): 31 | return self.E 32 | 33 | def set_params(self, E): 34 | self.E = E 35 | 36 | b = B() 37 | 38 | @jax.jit 39 | def test_fn_b(E): 40 | b.set_params(E) 41 | return b.f() 42 | 43 | E_b = 10. 44 | 45 | print(test_fn_b(E_b)) 46 | print(jax.grad(test_fn_b)(E_b)) 47 | 48 | E_b = 20. 
49 | 50 | print(test_fn_b(E_b)) 51 | print(jax.grad(test_fn_b)(E_b)) -------------------------------------------------------------------------------- /jax_fem/experimental/memory.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as np 3 | import jax.profiler 4 | import numpy as onp 5 | 6 | def func1(x): 7 | return np.tile(x, 10) * 0.5 8 | 9 | def func2(x): 10 | y = func1(x) 11 | return y, np.tile(x, 10) + 1 12 | 13 | x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000)) 14 | y, z = func2(x) 15 | 16 | k = np.ones((10000, 10000)) 17 | 18 | del k 19 | 20 | z.block_until_ready() 21 | 22 | jax.profiler.save_device_memory_profile(f"modules/fem/experiments/data/memory.prof") 23 | -------------------------------------------------------------------------------- /jax_fem/experimental/petsc_solver.py: -------------------------------------------------------------------------------- 1 | import numpy as onp 2 | import petsc4py 3 | petsc4py.init() 4 | from petsc4py import PETSc 5 | 6 | n = 10 # Size of vector 7 | x = PETSc.Vec().createSeq(n) # Faster way to create a sequential vector. 
# --- PETSc Vec basics: set/get values, shift, reductions, norms ---
# Relies on `PETSc`, `n` (= 10) and the sequential Vec `x` created earlier
# in this script.
x.setValues(range(n), range(n))

print(x.getArray())
print(x.getValues(3))
print(x.getValues([1, 2]))

x.setValues(range(n), range(n))
x.shift(1)
print(x.getArray())
x.shift(-1)
print(x.getArray())

x.setValues(range(n), range(n))

print(x.sum())
print(x.min())
print(x.max())

print(x.dot(x))  # dot product with self

print('2-norm =', x.norm())
print('Infinity-norm =', x.norm(PETSc.NormType.NORM_INFINITY))

# --- PETSc Mat basics: sparse AIJ matrix, diagonal entries, zeroRows ---
m, n = 4, 4  # size of the matrix
A = PETSc.Mat().createAIJ([m, n])  # AIJ represents sparse matrix
A.setUp()
A.assemble()

print(A.getValues(range(m), range(n)))
print(A.getValues(range(2), range(1)))

A.setValue(1, 1, -1)
A.setValue(0, 0, -2)
A.setValue(2, 2, -5)
A.setValue(3, 3, 6)
# A.setValues([0, 1], [2, 3], [1, 1, 1, 1])

# A.setValuesIJV([0, 1], [2, 3], [1, 1])

A.assemble()

A.zeroRows([0, 1])

print(A.getValues(range(m), range(n)))


# Deliberate early stop: the code below is kept as reference and is not
# executed. `raise SystemExit` is the programmatic equivalent of the
# interactive-shell `exit()` helper, which is only injected by `site`.
raise SystemExit

print(A.getSize())
B = A.copy()
B.transpose()
print(A.getSize(), B.getSize())
print(B.getValues(range(4), range(4)))

C = A.matMult(B)
print(C.getValues(range(m), range(n)))

x = PETSc.Vec().createSeq(4)  # making the x vector
x.set(1)  # assigning value 1 to all the elements
y = PETSc.Vec().createSeq(4)  # Put answer here.
A.mult(x, y)  # y = A*x, with x the all-ones vector created just above
print(y.getArray())


print("Matrix A: ")
print(A.getValues(range(m), range(n)))  # printing the matrix A defined above

b = PETSc.Vec().createSeq(4)  # creating a vector
b.setValues(range(4), [10, 5, 3, 6])  # assigning values to the vector

print('\n Vector b: ')  # a real newline, not a literal backslash-n
print(b.getArray())  # printing the vector

x = PETSc.Vec().createSeq(4)  # create the solution vector x

ksp = PETSc.KSP().create()  # creating a KSP object named ksp
ksp.setOperators(A)

# Allow for solver choice to be set from command line with -ksp_type