├── .gitignore
├── Instructions.md
├── README.md
├── materials
│   ├── NU_1.png
│   ├── NU_2.png
│   ├── file_structure.png
│   ├── laser_path.png
│   ├── single_track_T_DNS.gif
│   ├── single_track_T_PEGN.gif
│   ├── single_track_eta_DNS.gif
│   ├── single_track_eta_PEGN.gif
│   ├── single_track_zeta_DNS.gif
│   ├── single_track_zeta_PEGN.gif
│   ├── solidification_anisotropic.gif
│   └── solidification_isotropic.gif
└── src
    ├── allen_cahn.py
    ├── arguments.py
    ├── curved_grain.py
    ├── example.py
    ├── fit_ellipsoid.py
    ├── multi_layer.py
    ├── npj_review.py
    ├── property.py
    ├── single_layer.py
    ├── solidification.py
    ├── temperature.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | reference
3 | *.DS_Store
4 |
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
--------------------------------------------------------------------------------
/Instructions.md:
--------------------------------------------------------------------------------
1 | # Frequently Asked Questions (FAQ)
2 |
3 | We believe JAX is an exciting tool for scientific simulations. The phase-field simulation here is just one example, and only the beginning of this area. Here are our reasons:
4 |
5 | 1. You program in Python, and the code runs on ___CPU/GPU___ as fast as compiled languages (C/Fortran). Efficient for both human developers and machines!
6 | 2. It is a natural fit for research in ___AI for science/engineering___, because JAX is designed for high-performance machine learning.
7 | 3. The automatic differentiation feature of JAX opens the door for ___design and optimization___ because gradients are now available on the fly.
8 |
9 | The idea is __open code = better science__, and this is how we try to contribute to the scientific computing (with a focus on computational mechanics) community.
10 |
11 |
12 |
13 | ### Q: What are the dependencies of this code?
14 |
15 | You need to install [JAX](https://github.com/google/jax). We highly encourage you to create a `conda` environment and install all Python packages in that environment. Although JAX works on CPU, its true power is on GPU: you may expect a 10x to 100x speedup on GPU compared with CPU.
16 |
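17 | Another major tool we use is [Neper](https://neper.info/) for polycrystal structure generation. Neper is CPU-intensive software. The graph construction additionally uses [Jraph](https://github.com/deepmind/jraph), and `src/allen_cahn.py` also imports `meshio`, SciPy, and Matplotlib.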
18 |
19 |
20 |
21 | ### Q: What are the source files and what are they intended for?
22 |
23 | For developers, the key file for computation is `allen_cahn.py`, where we solve the phase-field equations. Almost everything interesting happens in this file. The parameters are defined in `arguments.py`, and utility and post-processing functions are defined in `utils.py`.
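In `allen_cahn.py`, time integration is built around a stepper with the signature `stepper(state, t_crt, f, *ode_params)`, where `state = (y_prev, t_prev)`, and `odeint` calls it once per time stamp. The plain-NumPy sketch below mirrors that pattern on a toy ODE `dy/dt = -k*y`. The right-hand side `decay` and the driver `integrate` are hypothetical; the repository versions are additionally compiled with `jax.jit` (with `f` marked static via `static_argnums=(2,)`) and receive the laser position and switch as `ode_params`.

```python
import numpy as np


def explicit_euler(state, t_crt, f, *ode_params):
    """One explicit Euler step, mirroring the stepper signature in allen_cahn.py."""
    y_prev, t_prev = state
    h = t_crt - t_prev
    y_crt = y_prev + h * f(y_prev, t_prev, *ode_params)
    return (y_crt, t_crt), y_crt


def decay(y, t, k):
    """Hypothetical toy right-hand side: dy/dt = -k * y."""
    return -k * y


def integrate(stepper, f, y0, ts, *ode_params):
    """Hypothetical driver that marches through the time stamps ts, like odeint."""
    state = (y0, ts[0])
    y = y0
    for t_crt in ts[1:]:
        state, y = stepper(state, t_crt, f, *ode_params)
    return y


ts = np.linspace(0.0, 1.0, 1001)  # step size h = 0.001
y_end = integrate(explicit_euler, decay, np.array([1.0]), ts, 1.0)
print(y_end, np.exp(-1.0))
```

With h = 0.001, explicit Euler yields (1 - h)^1000 ≈ 0.3677 instead of exp(-1) ≈ 0.3679. Like any explicit scheme, it also needs a sufficiently small time step to remain stable.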
24 |
25 | For users, there are several example files, such as `example.py` and `solidification.py`.
26 |
27 |
28 |
29 | ### Q: What are the units used in the simulation?
30 |
31 | We use SI units for all quantities except for length, which is in [mm].
32 |
33 |
34 |
35 | ### Q: How do I run an example?
36 |
37 | We run scripts as modules with Python's [`-m` switch](https://stackoverflow.com/questions/7610001/what-is-the-purpose-of-the-m-switch), so that imports such as `from src.arguments import args` resolve correctly. For instance, to run `src/example.py`, go to the repository root `/polycrystal` and enter the following on the command line:
38 |
39 | ```
40 | python -m src.example
41 | ```
42 |
43 |
44 |
45 | ### Q: How do I specify laser path?
46 |
47 | You specify the laser path with a txt file, such as `data/txt/fd_example.txt`, in which each row is one time stamp.
48 |
49 |
50 |
51 |
52 |
53 | The four columns are time, x position, y position, and laser switch. In this file, we turn the laser on at t=0 [s], x=0.2 [mm], y=0.1 [mm]; turn it off at t=0.001 [s], x=0.8 [mm], y=0.1 [mm]; and finally keep it off at t=0.002 [s], x=0.8 [mm], y=0.1 [mm]. Between consecutive time stamps, the laser is assumed to travel at constant speed.
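As an illustration, the sketch below writes the three rows described above as text and evaluates the laser state at an arbitrary time. The row values come from the description; the exact number formatting of the real `fd_example.txt` may differ. The helper `laser_state` is hypothetical and is not the repository's code. It interpolates the position linearly, which is what "constant speed between time stamps" implies, and holds each row's switch value until the next time stamp. The per-step arrays `xs`, `ys`, `ps` passed to `odeint` in `src/allen_cahn.py` are presumably sampled in a similar way.

```python
import io

import numpy as np

# Rows reconstructed from the description: time [s], x [mm], y [mm], switch (1 = on).
LASER_PATH_TXT = """\
0.000 0.2 0.1 1
0.001 0.8 0.1 0
0.002 0.8 0.1 0
"""


def laser_state(path, t):
    """Hypothetical helper: return (x, y, switch) of the laser at time t.

    Position is linearly interpolated between time stamps (constant speed);
    the switch value of a row holds until the next time stamp.
    """
    ts, xs, ys, ps = path.T
    x = np.interp(t, ts, xs)
    y = np.interp(t, ts, ys)
    # Index of the last time stamp <= t, clipped to the valid row range.
    k = np.clip(np.searchsorted(ts, t, side="right") - 1, 0, len(ts) - 1)
    return x, y, ps[k]


path = np.loadtxt(io.StringIO(LASER_PATH_TXT))
x, y, on = laser_state(path, 0.0005)
speed = (path[1, 1] - path[0, 1]) / (path[1, 0] - path[0, 0])  # [mm/s]
print(x, y, on, speed)
```

On the first segment, the laser travels 0.6 mm in 0.001 s, i.e., 600 mm/s (0.6 m/s), which agrees with the length unit [mm] used throughout.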
54 |
55 |
56 |
57 | ### Q: I got an error of "certain directory does not exist".
58 |
59 | Create the directory manually. We do not track the `data` folder on GitHub (it is listed in `.gitignore`), so you need to create it yourself on your local machine.
60 |
61 |
62 |
63 |
64 |
65 | We sort the data files by type, e.g., `data/neper`, `data/numpy`, `data/txt`, and `data/vtk`.
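If you prefer not to create the folders by hand, a small script can set up the top-level layout. The folder names below come from paths that appear in `src/allen_cahn.py` (`data/neper/...`, `data/numpy/...`, `data/txt/...`, `data/vtk/...`). The helper `make_data_dirs` is hypothetical, and other scripts may expect further folders.

```python
import os

# Top-level sub-folders of data/ that src/allen_cahn.py reads from or writes to.
DATA_SUBFOLDERS = ("neper", "numpy", "txt", "vtk")


def make_data_dirs(root="data"):
    """Create <root>/<name> for every entry in DATA_SUBFOLDERS (no-op if it exists)."""
    paths = [os.path.join(root, name) for name in DATA_SUBFOLDERS]
    for path in paths:
        os.makedirs(path, exist_ok=True)
    return paths


if __name__ == "__main__":
    # Run from the repository root so the relative paths match those in the code.
    for path in make_data_dirs():
        print("ready:", path)
```

Case-specific sub-folders such as `data/vtk/<case>/sols` are created automatically by `clean_sols`. However, `data/neper/<domain_name>/` still needs to contain the Neper output that `polycrystal_fd` and `polycrystal_gn` read, for example `domain.msh`, `domain.stcell`, and `domain.stface`.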
66 |
67 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Phase-field simulation of microstructure evolution in additive manufacturing
2 |
3 | This is the repository for our paper "[Physics-embedded graph network for accelerating phase-field simulation of microstructure evolution in additive manufacturing](https://doi.org/10.1038/s41524-022-00890-9)" in _npj Computational Materials - Nature_. We implemented both classic direct numerical simulation (DNS) based on the finite difference method and a new physics-embedded graph network (PEGN) approach. PEGN is a computationally light alternative to DNS. The code runs on both CPU and GPU. The DNS approach generally follows [Yang et al](https://www.nature.com/articles/s41524-021-00524-6). This code was developed at [AMPL](https://www.cao.mech.northwestern.edu/), Northwestern University.
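Both approaches can be understood as computations on a graph. In `src/allen_cahn.py`, `polycrystal_fd` (DNS) uses one node per finite-difference cell and connects neighboring cells. In contrast, `polycrystal_gn` (PEGN) uses one node per grain and connects grains that share a face, weighting each edge by `ch_len = face_area / centroid_distance`. The edge update in `update_graph` then sums a gradient energy of the form `0.5 * kappa * sum(ch_len * (u_sender - u_receiver)**2)`. The NumPy sketch below evaluates this energy for a single scalar field on a made-up three-node chain, together with its analytic gradient (a weighted graph Laplacian). The function names and toy values are illustrative assumptions; the repository builds the energy with JAX and Jraph instead of these helpers.

```python
import numpy as np


def grad_energy(u, edges, ch_len, kappa):
    """0.5 * kappa * sum over edges (s, r) of ch_len * (u[s] - u[r])**2."""
    s, r = edges[:, 0], edges[:, 1]
    return 0.5 * kappa * np.sum(ch_len * (u[s] - u[r]) ** 2)


def grad_energy_derivative(u, edges, ch_len, kappa):
    """dE/du: each edge adds +flux to its sender and -flux to its receiver."""
    s, r = edges[:, 0], edges[:, 1]
    flux = kappa * ch_len * (u[s] - u[r])
    g = np.zeros_like(u)
    np.add.at(g, s, flux)   # unbuffered add, correct for repeated node indices
    np.add.at(g, r, -flux)
    return g


# Toy graph: three nodes in a chain 0-1-2 (illustrative values only).
edges = np.array([[0, 1], [1, 2]])
face_areas = np.array([1.0, 2.0])
centroid_distances = np.array([0.5, 1.0])
ch_len = face_areas / centroid_distances   # same definition as in polycrystal_gn
u = np.array([1.0, 0.0, 3.0])

E = grad_energy(u, edges, ch_len, kappa=1.0)
g = grad_energy_derivative(u, edges, ch_len, kappa=1.0)
print(E, g)
```

For a uniform cubic grid of spacing h, `ch_len` reduces to h^2 / h = h. Each edge therefore contributes (Δu/h)^2 · h^3, i.e., the squared gradient times the cell volume. This is consistent with `polycrystal_fd` setting `ch_len` to the cube root of the cell volume.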
4 |
5 | ## Requirements
6 |
7 | We use [JAX](https://github.com/google/jax) for implementation of the computationally intensive part. The graph construction is based on [Jraph](https://github.com/deepmind/jraph). The polycrystal structure is generated with [Neper](https://neper.info/).
8 |
9 | ## Descriptions
10 |
11 | The typical workflow consists of two major steps:
12 | 1. Generate a polycrystal structure and mesh it (with Neper).
13 | 2. Perform the phase-field simulation.
14 |
15 | The file `src/example.py` is an instructive example. To run it, go to the root directory and run
16 |
17 | ```
18 | python -m src.example
19 | ```
20 |
21 | Please see the comments in `src/example.py` for further details. We also have [an instruction file](https://github.com/tianjuxue/polycrystal/blob/main/Instructions.md) for FAQs.
22 |
23 | ## Case studies
24 |
25 | __Single-layer single-track powder bed fusion process__
26 |
27 | The left column shows results of DNS, while the right column shows results of PEGN. The first row shows the temperature field, the second row the melt pool, and the third row grain evolution.
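The three rows correspond to columns of the per-node state array `y` used in `src/allen_cahn.py`. Column 0 is the temperature `T`. Column 1 is the phase variable `zeta`, and a node counts as liquid when `zeta < 0.5`. Columns 2 onward hold one order parameter `eta` per orientation, and the displayed grain orientation is the `argmax` over `eta`. The NumPy sketch below mirrors that post-processing (as done in `inspect_sol` and `write_sols_heper`) on made-up values. The helper `split_fields` is hypothetical.

```python
import numpy as np


def split_fields(y, zeta_liquid=0.5):
    """Hypothetical helper: split y of shape (num_nodes, 2 + num_oris) into plotted fields."""
    T = y[:, 0]                          # temperature
    zeta = y[:, 1]                       # phase variable (liquid below the threshold)
    eta = y[:, 2:]                       # one order parameter per orientation
    liquid = zeta < zeta_liquid          # melt pool indicator
    ori_inds = np.argmax(eta, axis=1)    # dominant orientation index per node
    return T, liquid, ori_inds


# Made-up state for 3 nodes and 2 orientations: [T, zeta, eta_0, eta_1].
y = np.array([[300.0, 1.0, 0.9, 0.1],
              [1800.0, 0.2, 0.0, 0.0],
              [900.0, 0.8, 0.2, 0.7]])
T, liquid, ori_inds = split_fields(y)
print(T, liquid, ori_inds)
```

In liquid nodes, `eta` is forced to zero (`force_eta_zero_in_liquid`), so the `argmax` there carries no meaning. This is why the repository blanks the orientation colors where `zeta < 0.1`.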
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | __Multi-layer multi-track powder bed fusion process__
46 |
47 | Below is the result of using PEGN for simulating a 20-layer process.
48 |
49 |
50 |
51 |
52 |
53 |
54 | __Directional solidification__
55 |
56 | Competitive grain growth is observed when grain anisotropy is considered. The left panel shows the case without grain anisotropy, and the right panel the case with it. In the anisotropic case, competitiveness ranks red > blue > green. The results are based on DNS.
57 |
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/materials/NU_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/NU_1.png
--------------------------------------------------------------------------------
/materials/NU_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/NU_2.png
--------------------------------------------------------------------------------
/materials/file_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/file_structure.png
--------------------------------------------------------------------------------
/materials/laser_path.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/laser_path.png
--------------------------------------------------------------------------------
/materials/single_track_T_DNS.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/single_track_T_DNS.gif
--------------------------------------------------------------------------------
/materials/single_track_T_PEGN.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/single_track_T_PEGN.gif
--------------------------------------------------------------------------------
/materials/single_track_eta_DNS.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/single_track_eta_DNS.gif
--------------------------------------------------------------------------------
/materials/single_track_eta_PEGN.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/single_track_eta_PEGN.gif
--------------------------------------------------------------------------------
/materials/single_track_zeta_DNS.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/single_track_zeta_DNS.gif
--------------------------------------------------------------------------------
/materials/single_track_zeta_PEGN.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/single_track_zeta_PEGN.gif
--------------------------------------------------------------------------------
/materials/solidification_anisotropic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/solidification_anisotropic.gif
--------------------------------------------------------------------------------
/materials/solidification_isotropic.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CMSL-HKUST/polycrystal/ddbe19bf1ed756ba60c4dc95927a9be0c4c11e97/materials/solidification_isotropic.gif
--------------------------------------------------------------------------------
/src/allen_cahn.py:
--------------------------------------------------------------------------------
1 | import jraph
2 | import jax
3 | import jax.numpy as np
4 | import numpy as onp
5 | import meshio
6 | import os
7 | import glob
8 | import time
9 | import pickle
10 | from functools import partial
11 | from scipy.spatial.transform import Rotation as R
12 | from collections import namedtuple
13 | from matplotlib import pyplot as plt
14 | from src.arguments import args
15 | from src.utils import unpack_state, get_unique_ori_colors, obj_to_vtu, walltime
16 |
17 |
18 | # TODO: unique_oris_rgb and unique_grain_directions should be a class property, not an instance property
19 | PolyCrystal = namedtuple('PolyCrystal', ['edges', 'ch_len', 'face_areas', 'centroids', 'volumes', 'unique_oris_rgb',
20 |                                          'unique_grain_directions', 'cell_ori_inds', 'boundary_face_areas', 'boundary_face_centroids', 'meta_info'])
21 |
22 |
23 | @partial(jax.jit, static_argnums=(2,))
24 | def rk4(state, t_crt, f, *ode_params):
25 |     '''
26 |     Fourth-order Runge-Kutta method.
27 |     We probably don't need this one.
28 |     '''
29 |     y_prev, t_prev = state
30 |     h = t_crt - t_prev
31 |     k1 = h * f(y_prev, t_prev, *ode_params)
32 |     k2 = h * f(y_prev + k1/2., t_prev + h/2., *ode_params)
33 |     k3 = h * f(y_prev + k2/2., t_prev + h/2., *ode_params)
34 |     k4 = h * f(y_prev + k3, t_prev + h, *ode_params)
35 |     y_crt = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)
36 |     return (y_crt, t_crt), y_crt
37 |
38 |
39 | # gpus = jax.devices('gpu')
40 | # @partial(jax.jit, static_argnums=(2,), device=gpus[-1])
41 |
42 | @partial(jax.jit, static_argnums=(2,))
43 | def explicit_euler(state, t_crt, f, *ode_params):
44 |     '''
45 |     Explicit Euler method.
46 |     '''
47 |     y_prev, t_prev = state
48 |     h = t_crt - t_prev
49 |     y_crt = y_prev + h * f(y_prev, t_prev, *ode_params)
50 |     return (y_crt, t_crt), y_crt
51 |
52 |
53 | @jax.jit
54 | def force_eta_zero_in_liquid(y):
55 |     '''
56 |     In the liquid zone, set all eta to zero.
57 |     '''
58 |     T, zeta, eta = unpack_state(y)
59 |     eta = np.where(zeta < 0.5, 0., eta)
60 |     return np.hstack((T, zeta, eta))
61 |
62 |
63 | # @walltime
64 | def odeint_no_output(polycrystal, mesh, mesh_bottom_layer, stepper, f, y0, melt, ts, xs, ys, ps, overwrite_T=None):
65 |     '''
66 |     Just for measuring wall time.
67 |     '''
68 |     before_jit = time.time()
69 |     state = (y0, ts[0])
70 |     for (i, t_crt) in enumerate(ts[1:]):
71 |         state, y = stepper(state, t_crt, f, xs[i + 1], ys[i + 1], ps[i + 1])
72 |         state = (force_eta_zero_in_liquid(y), t_crt)
73 |         if (i + 1) % 1000 == 0:
74 |             print(f"step {i + 1} of {len(ts[1:])}, unix timestamp = {time.time()}")
75 |         if i == 0:
76 |             start_time = time.time()
77 |
78 |     end_time = time.time()
79 |     platform = jax.default_backend()
80 |     print(f"Time for jitting: {start_time - before_jit}, execution: {end_time - start_time} on platform {platform}")
81 |     with open(f'data/txt/walltime_{platform}_{args.case}_{args.layer:03d}.txt', 'w') as txt_file:
82 |         txt_file.write(f'{start_time - before_jit} {end_time - start_time} {len(polycrystal.centroids)} {len(polycrystal.edges)} {args.num_oris}\n')
83 |
84 |
85 | def odeint(polycrystal, mesh, mesh_bottom_layer, stepper, f, y0, melt, ts, xs, ys, ps, overwrite_T=None):
86 |     '''
87 |     ODE integrator.
88 |     '''
89 |     clean_sols()
90 |     state = (y0, ts[0])
91 |     if overwrite_T is not None:
92 |         state = (overwrite_T(y0, polycrystal.centroids, ts[0]), ts[0])
93 |     write_sols(polycrystal, mesh, y0, melt, 0)
94 |     for (i, t_crt) in enumerate(ts[1:]):
95 |         state, y = stepper(state, t_crt, f, xs[i + 1], ys[i + 1], ps[i + 1])
96 |         state = (force_eta_zero_in_liquid(y), t_crt)
97 |         if overwrite_T is not None:
98 |             state = (overwrite_T(y, polycrystal.centroids, t_crt), t_crt)
99 |         melt = np.logical_or(melt, y[:, 1] < 0.5)
100 |         if (i + 1) % 20 == 0:
101 |             print(f"step {i + 1} of {len(ts[1:])}, unix timestamp = {time.time()}")
102 |             # print(y[:10, :5])
103 |             inspect_sol(y, y0)
104 |             if not np.all(np.isfinite(y)):
105 |                 raise ValueError("Found np.inf or np.nan in y - stop the program")
106 |         write_sol_interval = args.write_sol_interval
107 |         if (i + 1) % write_sol_interval == 0:
108 |             write_sols(polycrystal, mesh, y, melt, (i + 1) // write_sol_interval)
109 |
110 |     write_final_sols(polycrystal, mesh_bottom_layer, y, melt)
111 |     write_info(polycrystal)
112 |
113 |
114 | def inspect_sol(y, y0):
115 |     '''
116 |     While running simulations, print out some useful information.
117 |     '''
118 |     T = y[:, 0]
119 |     zeta = y[:, 1]
120 |     change_zeta = np.where(zeta < 0.5, 1, 0)
121 |     eta0 = np.argmax(y0[:, 2:], axis=1)
122 |     eta = np.argmax(y[:, 2:], axis=1)
123 |     change_eta = np.where(eta0 == eta, 0, 1)
124 |     change_T = np.where(T >= args.T_melt, 1, 0)
125 |     print(f"percent of zeta in liquid = {np.sum(change_zeta)/len(change_zeta)*100}%")
126 |     print(f"percent of changed orientations = {np.sum(change_eta)/len(change_eta)*100}%")
127 |     print(f"percent of T >= T_melt = {np.sum(change_T)/len(change_T)*100}%")
128 |     print(f"max T = {np.max(T)}")
129 |
130 |
131 | def clean_sols():
132 |     '''
133 |     Create the output folders if needed and remove old solution files from them.
134 |     '''
135 |     if args.case.startswith('gn_multi_layer'):
136 |         vtk_folder = f"data/vtk/{args.case}/sols/layer_{args.layer:03d}"
137 |         if not os.path.exists(vtk_folder):
138 |             os.makedirs(vtk_folder)
139 |         numpy_folder = f'data/numpy/{args.case}/sols/layer_{args.layer:03d}'
140 |         if not os.path.exists(numpy_folder):
141 |             os.makedirs(numpy_folder)
142 |         group_folder = f"data/vtk/{args.case}/sols/group"
143 |         if not os.path.exists(group_folder):
144 |             os.makedirs(group_folder)
145 |     else:
146 |         vtk_folder = f"data/vtk/{args.case}/sols"
147 |         numpy_folder = f"data/numpy/{args.case}/sols"
148 |         if not os.path.exists(vtk_folder):
149 |             os.makedirs(vtk_folder)
150 |         if not os.path.exists(numpy_folder):
151 |             os.makedirs(numpy_folder)
152 |
153 |     files_vtk = glob.glob(vtk_folder + "/*")
154 |     files_numpy = glob.glob(numpy_folder + "/*")
155 |     files = files_vtk + files_numpy
156 |     for f in files:
157 |         os.remove(f)
158 |
159 |
160 | def write_info(polycrystal):
161 |     '''
162 |     Mostly for post-processing, e.g., computing grain volumes, aspect ratios, etc.
163 |     '''
164 |     if not args.case.startswith('gn_multi_layer'):
165 |         numpy_folder_info = f"data/numpy/{args.case}/info"
166 |         if not os.path.exists(numpy_folder_info):
167 |             os.makedirs(numpy_folder_info)
168 |         onp.save(f"data/numpy/{args.case}/info/edges.npy", polycrystal.edges)
169 |         onp.save(f"data/numpy/{args.case}/info/face_areas.npy", polycrystal.face_areas)
170 |         onp.save(f"data/numpy/{args.case}/info/vols.npy", polycrystal.volumes)
171 |         onp.save(f"data/numpy/{args.case}/info/centroids.npy", polycrystal.centroids)
172 |
173 |
174 | def write_sols_heper(polycrystal, mesh, y, melt):
175 |     T = y[:, 0]
176 |     zeta = y[:, 1]
177 |     eta = y[:, 2:]
178 |     eta_max = onp.max(eta, axis=1)
179 |     cell_ori_inds = onp.argmax(eta, axis=1)
180 |     ipf_x = onp.take(polycrystal.unique_oris_rgb[0], cell_ori_inds, axis=0)
181 |     ipf_y = onp.take(polycrystal.unique_oris_rgb[1], cell_ori_inds, axis=0)
182 |     ipf_z = onp.take(polycrystal.unique_oris_rgb[2], cell_ori_inds, axis=0)
183 |
184 |     # ipf_x[eta_max < 0.1] = 0.
185 |     # ipf_y[eta_max < 0.1] = 0.
186 |     # ipf_z[eta_max < 0.1] = 0.
187 |
188 |     # TODO: Is this better?
189 |     ipf_x[zeta < 0.1] = 0.
190 |     ipf_y[zeta < 0.1] = 0.
191 |     ipf_z[zeta < 0.1] = 0.
192 |
193 |     mesh.cell_data['T'] = [onp.array(T, dtype=onp.float32)]
194 |     mesh.cell_data['zeta'] = [onp.array(zeta, dtype=onp.float32)]
195 |     mesh.cell_data['ipf_x'] = [ipf_x]
196 |     mesh.cell_data['ipf_y'] = [ipf_y]
197 |     mesh.cell_data['ipf_z'] = [ipf_z]
198 |     mesh.cell_data['melt'] = [onp.array(melt, dtype=onp.float32)]
199 |     cell_ori_inds = onp.array(cell_ori_inds, dtype=onp.int32)
200 |     # Remark: cell_ori_inds starts with index 0
201 |     mesh.cell_data['ori_inds'] = [cell_ori_inds]
202 |
203 |     return T, zeta, cell_ori_inds
204 |
205 |
206 | def write_sols(polycrystal, mesh, y, melt, step):
207 |     '''
208 |     Use ParaView to open the .vtu files for visualization of:
209 |     1. Temperature field (T)
210 |     2. Liquid/Solid phase (zeta)
211 |     3. Grain orientations (eta)
212 |     '''
213 |     print("Write sols to file...")
214 |     T, zeta, cell_ori_inds = write_sols_heper(polycrystal, mesh, y, melt)
215 |     if args.case.startswith('gn_multi_layer'):
216 |         if args.layer == args.num_total_layers:
217 |             mesh.write(f"data/vtk/{args.case}/sols/layer_{args.layer:03d}/u{step:03d}.vtu")
218 |     else:
219 |         onp.save(f"data/numpy/{args.case}/sols/T_{step:03d}.npy", T)
220 |         onp.save(f"data/numpy/{args.case}/sols/zeta_{step:03d}.npy", zeta)
221 |         onp.save(f"data/numpy/{args.case}/sols/cell_ori_inds_{step:03d}.npy", cell_ori_inds)
222 |         onp.save(f"data/numpy/{args.case}/sols/melt_{step:03d}.npy", melt)
223 |         mesh.write(f"data/vtk/{args.case}/sols/u{step:03d}.vtu")
224 |
225 |
226 | def write_final_sols(polycrystal, mesh_bottom_layer, y, melt):
227 |     if args.case.startswith('gn_multi_layer'):
228 |         # top layer solutions are saved to be the initial values of the next layer
229 |         y_top = onp.array(y[args.layer_num_dofs:, :])
230 |         y_top[:, 0] = args.T_ambient
231 |         y_top[:, 1] = 1.
232 |         onp.save(f'data/numpy/{args.case}/sols/layer_{args.layer:03d}/y_final_top.npy', y_top)
233 |         onp.save(f'data/numpy/{args.case}/sols/layer_{args.layer:03d}/melt_final_top.npy', melt[args.layer_num_dofs:])
234 |
235 |         # bottom layer solutions are saved for analysis
236 |         melt_bottom = melt[:args.layer_num_dofs]
237 |         y_bottom = y[:args.layer_num_dofs, :]
238 |         T_bottom, zeta_bottom, cell_ori_inds_bottom = write_sols_heper(polycrystal, mesh_bottom_layer, y_bottom, melt_bottom)
239 |         onp.save(f'data/numpy/{args.case}/sols/layer_{args.layer:03d}/melt_final_bottom.npy', melt_bottom)
240 |         onp.save(f"data/numpy/{args.case}/sols/layer_{args.layer:03d}/cell_ori_inds_bottom.npy", cell_ori_inds_bottom)
241 |         mesh_bottom_layer.write(f"data/vtk/{args.case}/sols/group/sol_bottom_layer_{args.layer:03d}.vtu")
242 |
243 |
244 | def polycrystal_gn(domain_name='single_layer'):
245 |     '''
246 |     Prepare graph information for reduced-order modeling (PEGN), with one node per grain.
247 |     '''
248 |     unique_oris_rgb, unique_grain_directions = get_unique_ori_colors()
249 |     grain_oris_inds = onp.random.randint(args.num_oris, size=args.num_grains)
250 |     cell_ori_inds = grain_oris_inds
251 |     mesh = obj_to_vtu(domain_name)
252 |
253 |     stface = onp.loadtxt(f'data/neper/{domain_name}/domain.stface')
254 |     face_centroids = stface[:, :3]
255 |     face_areas = stface[:, 3]
256 |
257 |     edges = [[] for i in range(len(face_areas))]
258 |     centroids = []
259 |     volumes = []
260 |
261 |     with open(f'data/neper/{domain_name}/domain.stcell', 'r') as file:
262 |         lines = file.readlines()
263 |
264 |     assert args.num_grains == len(lines)
265 |
266 |     boundary_face_areas = onp.zeros((args.num_grains, 6))
267 |     boundary_face_centroids = onp.zeros((args.num_grains, 6, args.dim))
268 |
269 |     for i, line in enumerate(lines):
270 |         l = line.split()
271 |         centroids.append([float(l[0]), float(l[1]), float(l[2])])
272 |         volumes.append(float(l[3]))
273 |         l = l[4:]
274 |         num_nb_faces = len(l)
275 |         for j in range(num_nb_faces):
276 |             edges[int(l[j]) - 1].append(i)
277 |
278 |     centroids = onp.array(centroids)
279 |     volumes = onp.array(volumes)
280 |
281 |     new_face_areas = []
282 |     new_edges = []
283 |
284 |     def face_centroids_to_boundary_index(face_centroid):
285 |         domain_measures = [args.domain_length, args.domain_width, args.domain_height]
286 |         for i, domain_measure in enumerate(domain_measures):
287 |             if onp.isclose(face_centroid[i], 0., atol=1e-08):
288 |                 return 2*i
289 |             if onp.isclose(face_centroid[i], domain_measure, atol=1e-08):
290 |                 return 2*i + 1
291 |         raise ValueError(f"Expect a boundary face, got centroid {face_centroid} that is not on any boundary.")
292 |
293 |     for i, edge in enumerate(edges):
294 |         if len(edge) == 1:
295 |             grain_index = edge[0]
296 |             boundary_index = face_centroids_to_boundary_index(face_centroids[i])
297 |             face_area = face_areas[i]
298 |             face_centroid = face_centroids[i]
299 |             boundary_face_areas[grain_index, boundary_index] = face_area
300 |             boundary_face_centroids[grain_index, boundary_index] = face_centroid
301 |         elif len(edge) == 2:
302 |             new_edges.append(edge)
303 |             new_face_areas.append(face_areas[i])
304 |         else:
305 |             raise ValueError(f"Number of connected grains for any face must be 1 or 2, got {len(edge)}.")
306 |
307 |     new_edges = onp.array(new_edges)
308 |     new_face_areas = onp.array(new_face_areas)
309 |
310 |     centroids_1 = onp.take(centroids, new_edges[:, 0], axis=0)
311 |     centroids_2 = onp.take(centroids, new_edges[:, 1], axis=0)
312 |     grain_distances = onp.sqrt(onp.sum((centroids_1 - centroids_2)**2, axis=1))
313 |
314 |     ch_len = new_face_areas / grain_distances
315 |
316 |     # domain_vol = args.domain_length*args.domain_width*args.domain_height
317 |     # ch_len_avg = (domain_vol / args.num_grains)**(1./3.) * onp.ones(len(new_face_areas))
318 |
319 |     meta_info = onp.array([0., 0., 0., args.domain_length, args.domain_width, args.domain_height])
320 |     polycrystal = PolyCrystal(new_edges, ch_len, new_face_areas, centroids, volumes, unique_oris_rgb, unique_grain_directions,
321 |                               cell_ori_inds, boundary_face_areas, boundary_face_centroids, meta_info)
322 |
323 |     return polycrystal, mesh
324 |
325 |
326 | def polycrystal_fd(domain_name='single_layer'):
327 |     '''
328 |     Prepare graph information for the finite difference method (DNS), with one node per hexahedral cell.
329 |     '''
330 |     filepath = f'data/neper/{domain_name}/domain.msh'
331 |     mesh = meshio.read(filepath)
332 |     points = mesh.points
333 |     cells = mesh.cells_dict['hexahedron']
334 |     cell_grain_inds = mesh.cell_data['gmsh:physical'][0] - 1
335 |
336 |     numpy_folder = f"data/numpy/{args.case}/info"
337 |     if not os.path.exists(numpy_folder):
338 |         os.makedirs(numpy_folder)
339 |
340 |     onp.save(f"data/numpy/{args.case}/info/cell_grain_inds.npy", cell_grain_inds)
341 |     assert args.num_grains == onp.max(cell_grain_inds) + 1
342 |
343 |     unique_oris_rgb, unique_grain_directions = get_unique_ori_colors()
344 |     grain_oris_inds = onp.random.randint(args.num_oris, size=args.num_grains)
345 |     cell_ori_inds = onp.take(grain_oris_inds, cell_grain_inds, axis=0)
346 |
347 |     Nx = round(args.domain_length / points[1, 0])
348 |     Ny = round(args.domain_width / points[Nx + 1, 1])
349 |     Nz = round(args.domain_height / points[(Nx + 1)*(Ny + 1), 2])
350 |     args.Nx = Nx
351 |     args.Ny = Ny
352 |     args.Nz = Nz
353 |
354 |     print(f"Total num of grains = {args.num_grains}")
355 |     print(f"Total num of orientations = {args.num_oris}")
356 |     print(f"Total num of finite difference cells = {len(cells)}")
357 |     assert Nx*Ny*Nz == len(cells)
358 |
359 |     edges = []
360 |     for i in range(Nx):
361 |         if i % 100 == 0:
362 |             print(f"i = {i}")
363 |         for j in range(Ny):
364 |             for k in range(Nz):
365 |                 crt_ind = i + j * Nx + k * Nx * Ny
366 |                 if i != Nx - 1:
367 |                     edges.append([crt_ind, (i + 1) + j * Nx + k * Nx * Ny])
368 |                 if j != Ny - 1:
369 |                     edges.append([crt_ind, i + (j + 1) * Nx + k * Nx * Ny])
370 |                 if k != Nz - 1:
371 |                     edges.append([crt_ind, i + j * Nx + (k + 1) * Nx * Ny])
372 |
373 |     edges = onp.array(edges)
374 |     cell_points = onp.take(points, cells, axis=0)
375 |     centroids = onp.mean(cell_points, axis=1)
376 |     domain_vol = args.domain_length*args.domain_width*args.domain_height
377 |     volumes = domain_vol / (Nx*Ny*Nz) * onp.ones(len(cells))
378 |     ch_len = (domain_vol / len(cells))**(1./3.) * onp.ones(len(edges))
379 |     face_areas = (domain_vol / len(cells))**(2./3.) * onp.ones(len(edges))
380 |
381 |     face_inds = [[0, 3, 4, 7], [1, 2, 5, 6], [0, 1, 4, 5], [2, 3, 6, 7], [0, 1, 2, 3], [4, 5, 6, 7]]
382 |     boundary_face_centroids = onp.transpose(onp.stack([onp.mean(onp.take(cell_points, face_ind, axis=1), axis=1)
383 |                                                        for face_ind in face_inds]), axes=(1, 0, 2))
384 |
385 |     boundary_face_areas = []
386 |     domain_measures = [args.domain_length, args.domain_width, args.domain_height]
387 |     face_cell_nums = [Ny*Nz, Nx*Nz, Nx*Ny]
388 |     for i, domain_measure in enumerate(domain_measures):
389 |         cell_area = domain_vol/domain_measure/face_cell_nums[i]
390 |         boundary_face_area1 = onp.where(onp.isclose(boundary_face_centroids[:, 2*i, i], 0., atol=1e-08), cell_area, 0.)
391 |         boundary_face_area2 = onp.where(onp.isclose(boundary_face_centroids[:, 2*i + 1, i], domain_measure, atol=1e-08), cell_area, 0.)
392 |         boundary_face_areas += [boundary_face_area1, boundary_face_area2]
393 |
394 |     boundary_face_areas = onp.transpose(onp.stack(boundary_face_areas))
395 |
396 |     meta_info = onp.array([0., 0., 0., args.domain_length, args.domain_width, args.domain_height])
397 |     polycrystal = PolyCrystal(edges, ch_len, face_areas, centroids, volumes, unique_oris_rgb, unique_grain_directions,
398 |                               cell_ori_inds, boundary_face_areas, boundary_face_centroids, meta_info)
399 |
400 |     return polycrystal, mesh
401 |
402 |
403 | def build_graph(polycrystal, y0):
404 |     '''
405 |     Initialize the graph with Jraph, a graph network library for JAX:
406 |     https://github.com/deepmind/jraph
407 |     '''
408 |     print("Build graph...")
409 |     num_nodes = len(polycrystal.centroids)
410 |     senders = polycrystal.edges[:, 0]
411 |     receivers = polycrystal.edges[:, 1]
412 |     n_node = np.array([num_nodes])
413 |     n_edge = np.array([len(senders)])
414 |     senders = np.array(senders)
415 |     receivers = np.array(receivers)
416 |
417 |     print(f"Total number of nodes = {n_node[0]}, total number of edges = {n_edge[0]}")
418 |
419 |     node_features = {'state': y0,
420 |                      'centroids': polycrystal.centroids,
421 |                      'volumes': polycrystal.volumes[:, None],
422 |                      'boundary_face_areas': polycrystal.boundary_face_areas,
423 |                      'boundary_face_centroids': polycrystal.boundary_face_centroids}
424 |
425 |     edge_features = {'ch_len': polycrystal.ch_len[:, None],
426 |                      'anisotropy': np.ones((n_edge[0], args.num_oris))}
427 |
428 |     graph = jraph.GraphsTuple(nodes=node_features, edges=edge_features, senders=senders, receivers=receivers,
429 |                               n_node=n_node, n_edge=n_edge, globals={})
430 |
431 |     return graph
432 |
433 |
434 | def update_graph():
435 | '''
436 | With the help of Jraph, we can compute both grad_energy and local_energy easily.
437 | Note that grad_energy should be understood as stored in edges, while local_energy stored in nodes.
438 | '''
439 | # TODO: Don't do sum here. Let Jraph do sum by defining global energy.
440 | def update_edge_fn(edges, senders, receivers, globals_):
441 | '''
442 | Compute grad_energy for T, zeta, eta
443 | '''
444 | del globals_
445 | sender_T, sender_zeta, sender_eta = unpack_state(senders['state'])
446 | receiver_T, receiver_zeta, receiver_eta = unpack_state(receivers['state'])
447 | ch_len = edges['ch_len']
448 | anisotropy = edges['anisotropy']
449 | assert anisotropy.shape == sender_eta.shape
450 | grad_energy_T = args.kappa_T * 0.5 * np.sum((sender_T - receiver_T)**2 * ch_len)
451 | grad_energy_zeta = args.kappa_p * 0.5 * np.sum((sender_zeta - receiver_zeta)**2 * ch_len)
452 | grad_energy_eta = args.kappa_g * 0.5 * np.sum((sender_eta - receiver_eta)**2 * ch_len * anisotropy)
453 | grad_energy = (grad_energy_zeta + grad_energy_eta) * args.ad_hoc + grad_energy_T
454 |
455 | return {'grad_energy': grad_energy}
456 |
457 | def update_node_fn(nodes, sent_edges, received_edges, globals_):
458 | '''
459 | Compute local_energy for zeta and eta
460 | '''
461 | del sent_edges, received_edges
462 |
463 | T, zeta, eta = unpack_state(nodes['state'])
464 | assert T.shape == zeta.shape
465 |
466 | # phi = 0.5 * (1 - np.tanh(1e2*(T/args.T_melt - 1)))
467 | phi = 0.5 * (1 - np.tanh(1e10*(T/args.T_melt - 1)))
468 |
469 | phase_energy = args.m_p * np.sum(((1 - zeta)**2 * phi + zeta**2 * (1 - phi)))
470 | gamma = 1
471 | vmap_outer = jax.vmap(np.outer, in_axes=(0, 0))
472 | grain_energy_1 = np.sum((eta**4/4. - eta**2/2.))
473 | grain_energy_2 = gamma * (np.sum(np.sum(vmap_outer(eta, eta)**2, axis=(1, 2))[:, None]) - np.sum(eta**4))
474 | grain_energy_3 = np.sum((1 - zeta.reshape(-1))**2 * np.sum(eta**2, axis=1).reshape(-1))
475 | grain_energy = args.m_g * (grain_energy_1 + grain_energy_2 + grain_energy_3)
476 |
477 | local_energy = phase_energy + grain_energy
478 | local_energy = local_energy / args.ad_hoc
479 |
480 | return {'local_energy': local_energy}
481 |
482 | def update_global_fn(nodes, edges, globals_):
483 | del globals_
484 | total_energy = edges['grad_energy'] + nodes['local_energy']
485 | return {'total_energy': total_energy}
486 |
487 | net_fn = jraph.GraphNetwork(update_edge_fn=update_edge_fn,
488 | update_node_fn=update_node_fn,
489 | update_global_fn=update_global_fn)
490 |
491 | return net_fn
492 |
493 |
494 | def phase_field(graph, polycrystal):
495 | net_fn = update_graph()
496 | volumes = graph.nodes['volumes']
497 | centroids = graph.nodes['centroids']
498 |
499 | def heat_source(y, t, *ode_params):
500 | '''
501 | Boundary heat source, following this reference:
502 | Lian, Yanping, et al. "A cellular automaton finite volume method for microstructure evolution during
503 | additive manufacturing." Materials & Design 169 (2019): 107672.
504 | The heat source only acts on the top surface.
505 | Also, convection and radiation are considered, which act on all surfaces.
506 | '''
507 | power_x, power_y, power_on = ode_params
508 | T, zeta, eta = unpack_state(y)
509 | boundary_face_areas = graph.nodes['boundary_face_areas']
510 | boundary_face_centroids = graph.nodes['boundary_face_centroids']
511 |
512 | q_convection = np.sum(args.h_conv*(args.T_ambient - T)*boundary_face_areas, axis=1)
513 | q_radiation = np.sum(args.emissivity*args.SB_constant*(args.T_ambient**4 - T**4)*boundary_face_areas, axis=1)
514 |
515 | # 0: left surface, 1: right surface, 2: front surface, 3: back surface, 4: bottom surface, 5: top surface
516 | upper_face_centroids = boundary_face_centroids[:, 5, :]
517 | upper_face_areas = boundary_face_areas[:, 5]
518 |
519 | X = upper_face_centroids[:, 0] - power_x
520 | Y = upper_face_centroids[:, 1] - power_y
521 | q_laser = 2*args.power*args.power_fraction/(np.pi * args.r_beam**2) * np.exp(-2*(X**2 + Y**2)/args.r_beam**2) * upper_face_areas
522 | q_laser = q_laser * power_on
523 |
524 | q = q_convection + q_radiation + q_laser
525 |
526 | return q[:, None]
527 |
528 | def update_anisotropy():
529 | '''
530 | Determine anisotropy (see Yan paper Eq. (12))
531 | '''
532 | print("Start of compute_anisotropy...")
533 | # The anisotropy computation is debatable here.
534 | # One option is to use grad_eta to determine the grain boundary direction, which seems more physical.
535 | # The Yan paper, however, appears to simply use the finite-difference cell boundary direction as the grain boundary direction.
536 | use_eta_flag = False
537 | if args.case.startswith('fd') and use_eta_flag:
538 | y = graph.nodes['state']
539 | eta = y[:, 2:]
540 | eta_xyz = np.reshape(eta, (args.Nz, args.Ny, args.Nx, args.num_oris))
541 | eta_neg_x = np.concatenate((eta_xyz[:, :, :1, :], eta_xyz[:, :, :-1, :]), axis=2)
542 | eta_pos_x = np.concatenate((eta_xyz[:, :, 1:, :], eta_xyz[:, :, -1:, :]), axis=2)
543 | eta_neg_y = np.concatenate((eta_xyz[:, :1, :, :], eta_xyz[:, :-1, :, :]), axis=1)
544 | eta_pos_y = np.concatenate((eta_xyz[:, 1:, :, :], eta_xyz[:, -1:, :, :]), axis=1)
545 | eta_neg_z = np.concatenate((eta_xyz[:1, :, :, :], eta_xyz[:-1, :, :, :]), axis=0)
546 | eta_pos_z = np.concatenate((eta_xyz[1:, :, :, :], eta_xyz[-1:, :, :, :]), axis=0)
547 | directions_xyz = np.stack((eta_pos_x - eta_neg_x, eta_pos_y - eta_neg_y, eta_pos_z - eta_neg_z), axis=-1)
548 | assert directions_xyz.shape == (args.Nz, args.Ny, args.Nx, args.num_oris, args.dim)
549 | directions = directions_xyz.reshape(-1, args.num_oris, args.dim)
550 | sender_directions = np.take(directions, graph.senders, axis=0)
551 | receiver_directions = np.take(directions, graph.receivers, axis=0)
552 | edge_directions = (sender_directions + receiver_directions) / 2.
553 | else:
554 | sender_centroids = np.take(centroids, graph.senders, axis=0)
555 | receiver_centroids = np.take(centroids, graph.receivers, axis=0)
556 | edge_directions = sender_centroids - receiver_centroids
557 | edge_directions = np.repeat(edge_directions[:, None, :], args.num_oris, axis=1) # (num_edges, num_oris, dim)
558 |
559 | unique_grain_directions = polycrystal.unique_grain_directions # (num_directions_per_cube, num_oris, dim)
560 |
561 | assert edge_directions.shape == (len(graph.senders), args.num_oris, args.dim)
562 | cosines = np.sum(unique_grain_directions[None, :, :, :] * edge_directions[:, None, :, :], axis=-1) \
563 | / (np.linalg.norm(edge_directions, axis=-1)[:, None, :])
564 | angles = np.arccos(cosines)
565 | angles = np.where(np.isfinite(angles), angles, 0.)
566 | angles = np.where(angles < np.pi/2., angles, np.pi - angles)
567 | angles = np.min(angles, axis=1)
568 |
569 | anisotropy_term = 1. + args.anisotropy * (np.cos(angles)**4 + np.sin(angles)**4) # (num_edges, num_oris)
570 |
571 | assert anisotropy_term.shape == (len(graph.senders), args.num_oris)
572 | graph.edges['anisotropy'] = anisotropy_term
573 | print("End of compute_anisotropy...")
574 |
575 | def compute_energy(y, t, *ode_params):
576 | '''
577 | When you call net_fn, you are asking Jraph to compute the total free energy (grad energy + local energy) for you.
578 | '''
579 | q = heat_source(y, t, *ode_params)
580 | graph.nodes['state'] = y
581 | new_graph = net_fn(graph)
582 | return new_graph.edges['grad_energy'], new_graph.nodes['local_energy'], q
583 |
584 | grad_energy_der_fn = jax.grad(lambda y, t, *ode_params: compute_energy(y, t, *ode_params)[0])
585 | local_energy_der_fn = jax.grad(lambda y, t, *ode_params: compute_energy(y, t, *ode_params)[1])
586 |
587 | def state_rhs(y, t, *ode_params):
588 | '''
589 | Define the right-hand-side function for the ODE system
590 | '''
591 | update_anisotropy()
592 | _, _, q = compute_energy(y, t, *ode_params)
593 | T, zeta, eta = unpack_state(y)
594 |
595 | # Cap T at 2000 K: a very large T makes Lg very large and the solution diverges; such a high T is also unphysical.
596 | T = np.where(T > 2000., 2000., T)
597 |
598 | der_grad = grad_energy_der_fn(y, t, *ode_params)
599 | der_local = local_energy_der_fn(y, t, *ode_params)
600 |
601 | # How to choose Lp is an open question; here it is evaluated at T_melt.
602 | Lp = args.L0 * np.exp(-args.Qg / (args.T_melt*args.gas_const))
603 |
604 | # The problem with the following definition of Lp is that
605 | # if a cell is liquid and T is very low, Lp becomes very small,
606 | # so the liquid cell cannot transform into the solid state.
607 | # Lp = args.L0 * np.exp(-args.Qg / (T*args.gas_const))
608 |
609 | Lg = args.L0 * np.exp(-args.Qg / (T*args.gas_const))
610 |
611 | rhs_p = -Lp * (der_grad[:, 1:2]/volumes + der_local[:, 1:2])
612 | rhs_g = -Lg * (der_grad[:, 2:]/volumes + der_local[:, 2:])
613 | rhs_T = (-der_grad[:, 0:1] + q)/volumes/(args.rho * args.c_p)
614 |
615 | rhs = np.hstack((rhs_T, rhs_p, rhs_g))
616 |
617 | return rhs
618 |
619 | return state_rhs
620 |
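update_anisotropy computes, for every edge and every orientation, the factor 1 + ε(cos⁴θ + sin⁴θ) (the code cites Yan paper Eq. (12)), where θ is the smallest angle between the edge direction and any crystal axis of that orientation, folded into [0, π/2] because an axis and its negative are equivalent. The sketch below isolates that step as a hypothetical pure-NumPy helper, `anisotropy_factor`, outside Jraph and JAX. It assumes unit-length grain directions and non-zero edge vectors, and it clips the cosines to [-1, 1] instead of relying on the `np.isfinite` fallback used above.

```python
import numpy as np


def anisotropy_factor(edge_vectors, grain_directions, strength=0.15):
    """Hypothetical standalone version of the anisotropy term.

    edge_vectors:     (num_edges, dim) vectors joining sender and receiver centroids (assumed non-zero)
    grain_directions: (num_dirs, num_oris, dim) unit crystal axes for each orientation (assumed unit length)
    Returns an array of shape (num_edges, num_oris).
    """
    norms = np.linalg.norm(edge_vectors, axis=-1)  # (num_edges,)
    # Dot product of every edge with every crystal axis of every orientation
    cosines = np.sum(grain_directions[None, :, :, :] * edge_vectors[:, None, None, :], axis=-1)
    cosines = cosines / norms[:, None, None]  # (num_edges, num_dirs, num_oris)
    # Clip so that round-off (|cos| slightly above 1) cannot produce NaN in arccos
    angles = np.arccos(np.clip(cosines, -1.0, 1.0))
    # An axis and its negative are equivalent: fold angles into [0, pi/2]
    angles = np.where(angles < np.pi / 2, angles, np.pi - angles)
    # Keep the crystal axis closest to the edge direction
    angles = np.min(angles, axis=1)  # (num_edges, num_oris)
    return 1.0 + strength * (np.cos(angles)**4 + np.sin(angles)**4)


# One orientation whose crystal axes coincide with x, y, z: shape (3, 1, 3)
cube_axes = np.eye(3)[:, None, :]
edges = np.array([[1.0, 0.0, 0.0],   # aligned with a cube axis
                  [1.0, 1.0, 0.0]])  # 45 degrees between x and y
factors = anisotropy_factor(edges, cube_axes)
print(factors)  # about 1.15 for the first edge and 1.075 for the second
```

Since cos⁴θ + sin⁴θ = 1 - ½ sin²(2θ), the factor ranges from 1 + ε/2 (edge at 45° to its nearest axis) to 1 + ε (edge aligned with an axis), so with the default `--anisotropy 0.15` the edge weights in the gradient energy differ by at most about 7%.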
--------------------------------------------------------------------------------
/src/arguments.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax
3 | import jax.numpy as np
4 | import argparse
5 | import sys
6 | import numpy as onp
7 | import matplotlib.pyplot as plt
8 | from jax.config import config
9 |
10 | # Seed the random number generator and set the numpy printing format
11 | onp.random.seed(0)
12 | onp.set_printoptions(threshold=sys.maxsize, linewidth=1000, suppress=True)
13 | onp.set_printoptions(precision=10)
14 |
15 | # np.set_printoptions(threshold=sys.maxsize, linewidth=1000, suppress=True)
16 | # np.set_printoptions(precision=5)
17 |
18 | # Manage arguments
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('--num_oris', type=int, default=20)
21 | parser.add_argument('--num_grains', type=int, default=40000)
22 | parser.add_argument('--dim', type=int, default=3)
23 | parser.add_argument('--domain_height', type=float, help='Unit: mm', default=0.1)
24 | parser.add_argument('--domain_width', type=float, help='Unit: mm', default=0.4)
25 | parser.add_argument('--domain_length', type=float, help='Unit: mm', default=1.)
26 | parser.add_argument('--dt', type=float, help='Unit: s', default=2e-7)
27 | parser.add_argument('--T_melt', type=float, help='Unit: K', default=1700.)
28 | parser.add_argument('--T_ambient', type=float, help='Unit: K', default=300.)
29 | parser.add_argument('--rho', type=float, help='Unit: kg/mm^3', default=8.08e-6)
30 | parser.add_argument('--c_p', type=float, help='Unit: J/(kg*K)', default=770.)
31 | # parser.add_argument('--laser_vel', type=float, help='Unit: mm/s', default=500.)
32 | parser.add_argument('--power', type=float, help='Unit: W', default=200.)
33 | parser.add_argument('--power_fraction', type=float, help='Unit: None', default=0.4)
34 | parser.add_argument('--r_beam', type=float, help='Unit: mm', default=0.1)
35 | parser.add_argument('--emissivity', type=float, help='Unit: None', default=0.2)
36 | parser.add_argument('--SB_constant', type=float, help='Unit: W/(mm^2*K^4)', default=5.67e-14)
37 | parser.add_argument('--h_conv', type=float, help='Unit: W/(mm^2*K)', default=1e-4)
38 | parser.add_argument('--kappa_T', type=float, help='Unit: W/(mm*K)', default=1e-2)
39 | parser.add_argument('--gas_const', type=float, help='Unit: J/(Mol*K)', default=8.3)
40 | parser.add_argument('--Qg', type=float, help='Unit: J/Mol', default=1.4e5)
41 | parser.add_argument('--L0', type=float, help='Unit: mm^4/(J*s)', default=3.5e12)
42 | parser.add_argument('--kappa_p', type=float, help='Unit: J/mm', default=2.77e-9)
43 |
44 | # We don't know a_k value in Yan paper Eq. (16), so let's just make kappa_g = kappa_p.
45 | # parser.add_argument('--kappa_g', type=float, help='Unit: J/mm', default=3.7e-9)
46 | parser.add_argument('--kappa_g', type=float, help='Unit: J/mm', default=2.77e-9)
47 |
48 | parser.add_argument('--m_p', type=float, help='Unit: J/mm^3', default=1.2e-4)
49 | parser.add_argument('--m_g', type=float, help='Unit: J/mm^3', default=2.4e-4)
50 | parser.add_argument('--anisotropy', type=float, help='Unit: None', default=0.15)
51 | parser.add_argument('--ad_hoc', type=float, help='Unit: None', default=1.)
52 |
53 | parser.add_argument('--layer', type=int, help='layer number', default=1)
54 | parser.add_argument('--write_sol_interval', type=int, help='interval of writing solutions to file', default=500)
55 |
56 | args = parser.parse_args()
57 |
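arguments.py calls `parser.parse_args()` at import time, and every script then overrides selected defaults by assigning to the shared `args` object (see `set_params()` in example.py). When the module is imported from a Jupyter kernel or a test runner whose `sys.argv` carries unrelated flags, `parse_args()` can abort with "unrecognized arguments". A minimal sketch, assuming you want to tolerate such flags, swaps in argparse's `parse_known_args`, which returns the namespace plus a list of unrecognized strings instead of exiting. Only two options from the parser above are reproduced, and `--unrelated_flag` is a made-up stand-in for a foreign flag.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num_oris', type=int, default=20)
parser.add_argument('--r_beam', type=float, help='Unit: mm', default=0.1)

# In a real module you would call parser.parse_known_args() with no list (it reads sys.argv[1:]);
# an explicit list is passed here so the example is reproducible.
args, unknown = parser.parse_known_args(['--num_oris', '5', '--unrelated_flag=1'])

# Scripts then override selected defaults in place, as set_params() does
args.r_beam = 0.03
print(args.num_oris, args.r_beam, unknown)  # 5 0.03 ['--unrelated_flag=1']
```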
--------------------------------------------------------------------------------
/src/curved_grain.py:
--------------------------------------------------------------------------------
1 | '''
2 | For debugging purposes, replicate the curved shape grains as reported in Voorhees's paper
3 | See https://doi.org/10.1016/j.actamat.2021.116862
4 | '''
5 | import jax
6 | import jax.numpy as np
7 | import numpy as onp
8 | import os
9 | import meshio
10 | from src.utils import read_path, obj_to_vtu, unpack_state, walltime
11 | from src.arguments import args
12 | from src.allen_cahn import polycrystal_fd, build_graph, phase_field, odeint, explicit_euler
13 |
14 |
15 | def set_params():
16 | '''
17 | If a certain parameter is not set, a default value will be used (see src/arguments.py for details).
18 | '''
19 | args.case = 'fd_example'
20 | args.num_grains = 20000
21 | args.domain_length = 1.
22 | args.domain_width = 0.2
23 | args.domain_height = 0.1
24 | args.r_beam = 0.03
25 | args.power = 100
26 | args.write_sol_interval = 1000
27 |
28 | args.ad_hoc = 0.1
29 |
30 |
31 | def neper_domain():
32 | '''
33 | We use Neper to generate polycrystal structure.
34 | Neper has two major functions: generate a polycrystal structure, and mesh it.
35 | See https://neper.info/ for more information.
36 | '''
37 | set_params()
38 | os.system(f'neper -T -n {args.num_grains} -id 1 -regularization 0 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
39 | -o data/neper/{args.case}/domain -format tess,obj,ori')
40 | os.system(f'neper -T -loadtess data/neper/{args.case}/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
41 | os.system(f'neper -M -rcl 1 -elttype hex -faset faces data/neper/{args.case}/domain.tess')
42 |
43 |
44 | def write_vtu_files():
45 | '''
46 | This is just a helper function if you want to visualize the polycrystal or the mesh generated by Neper.
47 | You may use Paraview to open the output vtu files.
48 | '''
49 | set_params()
50 | filepath = f'data/neper/{args.case}/domain.msh'
51 | fd_mesh = meshio.read(filepath)
52 | fd_mesh.write(f'data/vtk/{args.case}/mesh/fd_mesh.vtu')
53 | poly_mesh = obj_to_vtu(args.case)
54 | poly_mesh.write(f'data/vtk/{args.case}/mesh/poly_mesh.vtu')
55 |
56 |
57 | def get_T(centroids, t):
58 | '''
59 | Analytic T from https://doi.org/10.1016/j.actamat.2021.116862
60 | '''
61 | T_ambient = 300.
62 | alpha = 5.2
63 | Q = 25
64 | kappa = 2.7*1e-2
65 | x0 = 0.2*args.domain_length
66 |
67 | vel = 0.6/0.0024
68 |
69 | X = centroids[:, 0] - x0 - vel * t
70 | Y = centroids[:, 1] - 0.5*args.domain_width
71 | Z = centroids[:, 2] - args.domain_height
72 | R = np.sqrt(X**2 + Y**2 + Z**2)
73 | T = T_ambient + Q / (2 * np.pi * kappa) / R * np.exp(-vel / (2*alpha) * (R + X))
74 |
75 | return T[:, None]
76 |
77 |
78 |
79 | @jax.jit
80 | def overwrite_T(y, centroids, t):
81 | '''
82 | We overwrite T if T is prescribed.
83 | '''
84 | T, zeta, eta = unpack_state(y)
85 | T = get_T(centroids, t)
86 | return np.hstack((T, zeta, eta))
87 |
88 |
89 |
90 | def initialization(poly_sim):
91 | '''
92 | Prescribe the initial conditions for T, zeta and eta.
93 | '''
94 | num_nodes = len(poly_sim.centroids)
95 | T = args.T_ambient*np.ones(num_nodes)
96 | zeta = np.ones(num_nodes)
97 | eta = np.zeros((num_nodes, args.num_oris))
98 | eta = eta.at[np.arange(num_nodes), poly_sim.cell_ori_inds].set(1)
99 | # shape of state: (num_nodes, 1 + 1 + args.num_oris)
100 | y0 = np.hstack((T[:, None], zeta[:, None], eta))
101 | melt = np.zeros(len(y0), dtype=bool)
102 | return y0, melt
103 |
104 |
105 | def run():
106 | '''
107 | The laser scanning path is defined using a txt file.
108 | Each line of the txt file stands for:
109 | time [s], x_position [mm], y_position [mm], action_of_turning_laser_on_or_off_at_this_time [N/A]
110 | '''
111 | set_params()
112 | ts, xs, ys, ps = read_path(f'data/txt/fd_example_1.txt')
113 | polycrystal, mesh = polycrystal_fd(args.case)
114 | y0, melt = initialization(polycrystal)
115 | graph = build_graph(polycrystal, y0)
116 | state_rhs = phase_field(graph, polycrystal)
117 | odeint(polycrystal, mesh, None, explicit_euler, state_rhs, y0, melt, ts, xs, ys, ps, overwrite_T)
118 |
119 |
120 | if __name__ == "__main__":
121 | # neper_domain()
122 | # write_vtu_files()
123 | run()
124 |
125 |
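The closed-form temperature in `get_T` has the form of the classical Rosenthal solution for a point heat source moving at constant speed v in the +x direction over a semi-infinite body, with the source on the top surface (z = H) at mid-width (y = W/2). Written out with the symbols used in `get_T` (T_0 is the hard-coded 300 K, Q = 25 is presumably the absorbed power in W, κ = 2.7e-2 the conductivity, α = 5.2 the thermal diffusivity, v = 0.6/0.0024 mm/s; the units are an assumption inferred from the mm-based units in arguments.py):

```latex
T(x, y, z, t) = T_0 + \frac{Q}{2\pi\kappa R}\,
  \exp\!\left(-\frac{v\,(R + X)}{2\alpha}\right),
\qquad
\begin{aligned}
X &= x - x_0 - v t, & Y &= y - \tfrac{W}{2}, & Z &= z - H, \\
R &= \sqrt{X^2 + Y^2 + Z^2}
\end{aligned}
```

Ahead of the source (X > 0) the exponent is roughly -vR/α and the field decays quickly; behind it (X < 0 with small |Y| and |Z|) R + X is small, so the exponential stays close to 1 and the hot zone trails the beam. The expression is singular at R = 0, but cell centroids lie strictly below the top surface (Z < 0), so R stays positive.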
--------------------------------------------------------------------------------
/src/example.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as np
3 | import numpy as onp
4 | import os
5 | import meshio
6 | from src.utils import read_path, obj_to_vtu
7 | from src.arguments import args
8 | from src.allen_cahn import polycrystal_fd, build_graph, phase_field, odeint, explicit_euler
9 |
10 |
11 | def set_params():
12 | '''
13 | If a certain parameter is not set, a default value will be used (see src/arguments.py for details).
14 | '''
15 | args.case = 'fd_example'
16 | args.num_grains = 20000
17 | args.domain_length = 1.
18 | args.domain_width = 0.2
19 | args.domain_height = 0.1
20 |
21 | args.r_beam = 0.03
22 | args.power = 80
23 |
24 | # args.r_beam = 0.02
25 | # args.power = 50
26 |
27 | args.write_sol_interval = 1000
28 | # args.m_g = 1.2e-4
29 |
30 |
31 | def neper_domain():
32 | '''
33 | We use Neper to generate polycrystal structure.
34 | Neper has two major functions: generate a polycrystal structure, and mesh it.
35 | See https://neper.info/ for more information.
36 | '''
37 | set_params()
38 | os.system(f'neper -T -n {args.num_grains} -id 1 -regularization 0 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
39 | -o data/neper/{args.case}/domain -format tess,obj,ori')
40 | os.system(f'neper -T -loadtess data/neper/{args.case}/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
41 | os.system(f'neper -M -rcl 1 -elttype hex -faset faces data/neper/{args.case}/domain.tess')
42 |
43 |
44 | def write_vtu_files():
45 | '''
46 | This is just a helper function if you want to visualize the polycrystal or the mesh generated by Neper.
47 | You may use Paraview to open the output vtu files.
48 | '''
49 | set_params()
50 | filepath = f'data/neper/{args.case}/domain.msh'
51 | fd_mesh = meshio.read(filepath)
52 | fd_mesh.write(f'data/vtk/{args.case}/mesh/fd_mesh.vtu')
53 | poly_mesh = obj_to_vtu(args.case)
54 | poly_mesh.write(f'data/vtk/{args.case}/mesh/poly_mesh.vtu')
55 |
56 |
57 | def initialization(poly_sim):
58 | '''
59 | Prescribe the initial conditions for T, zeta and eta.
60 | '''
61 | num_nodes = len(poly_sim.centroids)
62 | T = args.T_ambient*np.ones(num_nodes)
63 | zeta = np.ones(num_nodes)
64 | eta = np.zeros((num_nodes, args.num_oris))
65 | eta = eta.at[np.arange(num_nodes), poly_sim.cell_ori_inds].set(1)
66 | # shape of state: (num_nodes, 1 + 1 + args.num_oris)
67 | y0 = np.hstack((T[:, None], zeta[:, None], eta))
68 | melt = np.zeros(len(y0), dtype=bool)
69 | return y0, melt
70 |
71 |
72 | def run():
73 | '''
74 | The laser scanning path is defined using a txt file.
75 | Each line of the txt file stands for:
76 | time [s], x_position [mm], y_position [mm], action_of_turning_laser_on_or_off_at_this_time [N/A]
77 | '''
78 | set_params()
79 | ts, xs, ys, ps = read_path(f'data/txt/{args.case}.txt')
80 | polycrystal, mesh = polycrystal_fd(args.case)
81 | y0, melt = initialization(polycrystal)
82 | graph = build_graph(polycrystal, y0)
83 | state_rhs = phase_field(graph, polycrystal)
84 | odeint(polycrystal, mesh, None, explicit_euler, state_rhs, y0, melt, ts, xs, ys, ps)
85 |
86 |
87 | if __name__ == "__main__":
88 | # neper_domain()
89 | # write_vtu_files()
90 | run()
91 |
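example.py narrows the beam (`r_beam = 0.03` mm) and sets `power = 80` W. In `heat_source`, the laser input on each top face is the Gaussian intensity 2Pf/(πr_b²)·exp(-2r²/r_b²) multiplied by the face area, where f is `power_fraction` (0.4 by default); over an infinite plane this integrates to exactly P·f. The sketch below re-implements that expression as a hypothetical standalone NumPy function, `gaussian_laser_flux`, and checks numerically that the absorbed power on a synthetic 200 × 200 grid of square top faces (an assumption; real Neper faces are irregular) is close to 80 × 0.4 = 32 W.

```python
import numpy as np


def gaussian_laser_flux(x, y, power_x, power_y, power, power_fraction, r_beam, face_areas, power_on=1.0):
    """Heat input [W] per top face, mirroring the q_laser expression in heat_source (hypothetical helper)."""
    X = x - power_x
    Y = y - power_y
    peak = 2 * power * power_fraction / (np.pi * r_beam**2)  # peak intensity [W/mm^2]
    return peak * np.exp(-2 * (X**2 + Y**2) / r_beam**2) * face_areas * power_on


# Tile a 0.4 mm x 0.4 mm top surface with n x n square faces of side h, beam centred at the origin
n, h = 200, 0.002
centers = -0.2 + h * (np.arange(n) + 0.5)
xx, yy = np.meshgrid(centers, centers, indexing='ij')
areas = np.full(xx.shape, h * h)

q = gaussian_laser_flux(xx, yy, 0.0, 0.0, power=80.0, power_fraction=0.4, r_beam=0.03, face_areas=areas)
total = q.sum()
print(total)  # close to 32 W
```

The check is a quick way to confirm that a mesh resolves the beam: if the faces are coarse compared with r_beam, or the beam sits near a domain edge, the summed input will drift away from P·f.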
--------------------------------------------------------------------------------
/src/fit_ellipsoid.py:
--------------------------------------------------------------------------------
1 | '''
2 | The code is a direct copy of
3 | https://github.com/minillinim/ellipsoid/blob/master/ellipsoid.py
4 | The purpose is to compute aspect ratio.
5 | The code is slow. We didn't use it in the manuscript.
6 | '''
7 |
8 |
9 | from __future__ import division
10 | from mpl_toolkits.mplot3d import Axes3D
11 | import matplotlib.pyplot as plt
12 | import sys
13 | import numpy as np
14 | from numpy import linalg
15 | from random import random
16 |
17 | class EllipsoidTool:
18 | """Some stuff for playing with ellipsoids"""
19 | def __init__(self): pass
20 |
21 | def getMinVolEllipse(self, P=None, tolerance=0.01):
22 | """ Find the minimum volume ellipsoid which holds all the points
23 |
24 | Based on work by Nima Moshtagh
25 | http://www.mathworks.com/matlabcentral/fileexchange/9542
26 | and also by looking at:
27 | http://cctbx.sourceforge.net/current/python/scitbx.math.minimum_covering_ellipsoid.html
28 | Which is based on the first reference anyway!
29 |
30 | Here, P is a numpy array of N dimensional points like this:
31 | P = [[x,y,z,...], <-- one point per line
32 | [x,y,z,...],
33 | [x,y,z,...]]
34 |
35 | Returns:
36 | (center, radii, rotation)
37 |
38 | """
39 | (N, d) = np.shape(P)
40 | d = float(d)
41 |
42 | # Q will be our working array
43 | Q = np.vstack([np.copy(P.T), np.ones(N)])
44 | QT = Q.T
45 |
46 | # initializations
47 | err = 1.0 + tolerance
48 | u = (1.0 / N) * np.ones(N)
49 |
50 | # Khachiyan Algorithm
51 | while err > tolerance:
52 | V = np.dot(Q, np.dot(np.diag(u), QT))
53 | M = np.diag(np.dot(QT , np.dot(linalg.inv(V), Q))) # M the diagonal vector of an NxN matrix
54 | j = np.argmax(M)
55 | maximum = M[j]
56 | step_size = (maximum - d - 1.0) / ((d + 1.0) * (maximum - 1.0))
57 | new_u = (1.0 - step_size) * u
58 | new_u[j] += step_size
59 | err = np.linalg.norm(new_u - u)
60 | u = new_u
61 |
62 | # center of the ellipse
63 | center = np.dot(P.T, u)
64 |
65 | # the A matrix for the ellipse
66 | A = linalg.inv(
67 | np.dot(P.T, np.dot(np.diag(u), P)) -
68 | np.array([[a * b for b in center] for a in center])
69 | ) / d
70 |
71 | # Get the values we'd like to return
72 | U, s, rotation = linalg.svd(A)
73 | radii = 1.0/np.sqrt(s)
74 |
75 | return (center, radii, rotation)
76 |
77 | def getEllipsoidVolume(self, radii):
78 | """Calculate the volume of the blob"""
79 | return 4./3.*np.pi*radii[0]*radii[1]*radii[2]
80 |
81 | def plotEllipsoid(self, center, radii, rotation, ax=None, plotAxes=False, cageColor='b', cageAlpha=0.2):
82 | """Plot an ellipsoid"""
83 | make_ax = ax is None
84 | if make_ax:
85 | fig = plt.figure()
86 | ax = fig.add_subplot(111, projection='3d')
87 |
88 | u = np.linspace(0.0, 2.0 * np.pi, 100)
89 | v = np.linspace(0.0, np.pi, 100)
90 |
91 | # cartesian coordinates that correspond to the spherical angles:
92 | x = radii[0] * np.outer(np.cos(u), np.sin(v))
93 | y = radii[1] * np.outer(np.sin(u), np.sin(v))
94 | z = radii[2] * np.outer(np.ones_like(u), np.cos(v))
95 | # rotate accordingly
96 | for i in range(len(x)):
97 | for j in range(len(x)):
98 | [x[i,j],y[i,j],z[i,j]] = np.dot([x[i,j],y[i,j],z[i,j]], rotation) + center
99 |
100 | if plotAxes:
101 | # make some purdy axes
102 | axes = np.array([[radii[0],0.0,0.0],
103 | [0.0,radii[1],0.0],
104 | [0.0,0.0,radii[2]]])
105 | # rotate accordingly
106 | for i in range(len(axes)):
107 | axes[i] = np.dot(axes[i], rotation)
108 |
109 |
110 | # plot axes
111 | for p in axes:
112 | X3 = np.linspace(-p[0], p[0], 100) + center[0]
113 | Y3 = np.linspace(-p[1], p[1], 100) + center[1]
114 | Z3 = np.linspace(-p[2], p[2], 100) + center[2]
115 | ax.plot(X3, Y3, Z3, color=cageColor)
116 |
117 | # plot ellipsoid
118 | ax.plot_wireframe(x, y, z, rstride=4, cstride=4, color=cageColor, alpha=cageAlpha)
119 |
120 | if make_ax:
121 | plt.show()
122 | plt.close(fig)
123 | del fig
124 |
125 | if __name__ == "__main__":
126 | # make 100 random points
127 | P = np.reshape([random()*100 for i in range(300)],(100,3))
128 |
129 | print(P.shape)
130 |
131 | # find the ellipsoid
132 | ET = EllipsoidTool()
133 | (center, radii, rotation) = ET.getMinVolEllipse(P, .01)
134 |
135 | print(radii)
136 |
137 | fig = plt.figure()
138 | ax = fig.add_subplot(111, projection='3d')
139 |
140 | # plot points
141 | ax.scatter(P[:,0], P[:,1], P[:,2], color='g', marker='*', s=100)
142 |
143 | # plot ellipsoid
144 | ET.plotEllipsoid(center, radii, rotation, ax=ax, plotAxes=True)
145 |
146 | plt.show()
147 | plt.close(fig)
148 | del fig
149 |
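The module header notes that this minimum-volume ellipsoid fit (Khachiyan's algorithm) is slow. If only a rough grain aspect ratio is needed, a cheaper and different technique is principal component analysis: take the covariance of a grain's points and compare the largest and smallest eigenvalues. The sketch below uses a hypothetical helper, `pca_aspect_ratio`; it is not what the authors used. For a uniformly filled ellipsoid the principal variances are a²/5, b²/5, c²/5, so the PCA ratio equals the radii ratio; for other shapes the two measures will differ.

```python
import numpy as np


def pca_aspect_ratio(points):
    """Ratio of the longest to the shortest principal semi-axis of a point cloud (hypothetical helper).

    points: (N, dim) array, e.g. points sampled inside one grain
    """
    centered = points - points.mean(axis=0)
    cov = centered.T @ centered / len(points)  # (dim, dim) covariance matrix
    eigvals = np.linalg.eigvalsh(cov)          # eigenvalues in ascending order
    # Semi-axis lengths scale with the square root of the variance along each principal axis
    return float(np.sqrt(eigvals[-1] / eigvals[0]))


# A regular 11 x 11 x 11 grid spanning 6 x 2 x 2, i.e. three times longer in x
g = np.linspace(-1.0, 1.0, 11)
xx, yy, zz = np.meshgrid(3.0 * g, g, g, indexing='ij')
pts = np.stack((xx.ravel(), yy.ravel(), zz.ravel()), axis=1)
ratio = pca_aspect_ratio(pts)
print(ratio)  # approximately 3.0
```

Because it needs only one covariance matrix and one symmetric eigendecomposition per grain, this scales to the tens of thousands of grains used in the simulations, whereas the iterative ellipsoid fit inverts a matrix at every iteration.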
--------------------------------------------------------------------------------
/src/multi_layer.py:
--------------------------------------------------------------------------------
1 | import numpy as onp
2 | import jax
3 | import jax.numpy as np
4 | import os
5 | from src.utils import obj_to_vtu, read_path, walltime
6 | from src.arguments import args
7 | from src.allen_cahn import polycrystal_gn, PolyCrystal, build_graph, phase_field, odeint, explicit_euler
8 | import copy
9 | import meshio
10 |
11 |
12 | def set_params():
13 | args.num_grains = 100000
14 | args.domain_length = 2.
15 | args.domain_width = 2.
16 | args.domain_height = 0.025
17 | args.write_sol_interval = 10000
18 |
19 |
20 | def neper_domain():
21 | set_params()
22 | os.system(f'neper -T -n {args.num_grains} -id 1 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
23 | -o data/neper/multi_layer/domain -format tess,obj,ori')
24 | os.system(f'neper -T -loadtess data/neper/multi_layer/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
25 |
26 |
27 | def default_initialization(poly_sim):
28 | num_nodes = len(poly_sim.centroids)
29 | T = args.T_ambient*np.ones(num_nodes)
30 | zeta = np.ones(num_nodes)
31 | eta = np.zeros((num_nodes, args.num_oris))
32 | eta = eta.at[np.arange(num_nodes), poly_sim.cell_ori_inds].set(1)
33 | # shape of state: (num_nodes, 1 + 1 + args.num_oris)
34 | y0 = np.hstack((T[:, None], zeta[:, None], eta))
35 | melt = np.zeros(len(y0), dtype=bool)
36 | return y0, melt
37 |
38 |
39 | def layered_initialization(poly_top_layer):
40 | y_top, melt_top = default_initialization(poly_top_layer)
41 | # Current lower layer is previous upper layer
42 | y_down = np.load(f'data/numpy/{args.case}/sols/layer_{args.layer - 1:03d}/y_final_top.npy')
43 | melt_down = np.load(f'data/numpy/{args.case}/sols/layer_{args.layer - 1:03d}/melt_final_top.npy')
44 | return np.vstack((y_down, y_top)), np.hstack((melt_down, melt_top))
45 |
46 |
47 | def lift_poly(poly, delta_z):
48 | poly.boundary_face_centroids[:, :, 2] = poly.boundary_face_centroids[:, :, 2] + delta_z
49 | poly.centroids[:, 2] = poly.centroids[:, 2] + delta_z
50 | poly.meta_info[2] = poly.meta_info[2] + delta_z
51 |
52 |
53 | def flip_poly(poly, base_z):
54 | new_boundary_face_areas = onp.hstack((poly.boundary_face_areas[:, :4], poly.boundary_face_areas[:, 5:6], poly.boundary_face_areas[:, 4:5]))
55 | poly.boundary_face_areas[:] = new_boundary_face_areas
56 |
57 | new_boundary_face_centroids = onp.array(poly.boundary_face_centroids)
58 | new_boundary_face_centroids[:, :, 2] = 2*base_z - new_boundary_face_centroids[:, :, 2]
59 | new_boundary_face_centroids = onp.concatenate((new_boundary_face_centroids[:, :4, :], new_boundary_face_centroids[:, 5:6, :],
60 | new_boundary_face_centroids[:, 4:5, :]), axis=1)
61 | poly.boundary_face_centroids[:] = new_boundary_face_centroids
62 |
63 | poly.centroids[:, 2] = 2*base_z - poly.centroids[:, 2]
64 | poly.meta_info[2] = 2*base_z - poly.meta_info[2] - poly.meta_info[5]
65 |
66 |
67 | def lift_mesh(mesh, delta_z):
68 | mesh.points[:, 2] = mesh.points[:, 2] + delta_z
69 |
70 |
71 | def flip_mesh(mesh, base_z):
72 | mesh.points[:, 2] = 2*base_z - mesh.points[:, 2]
73 |
74 |
75 | def merge_mesh(mesh1, mesh2):
76 | '''
77 | Merge two meshes
78 | '''
79 | print("Merge two meshes...")
80 | points1 = mesh1.points
81 | points2 = mesh2.points
82 | cells1 = mesh1.cells_dict['polyhedron']
83 | cells2 = mesh2.cells_dict['polyhedron']
84 |
85 | num_points1 = len(points1)
86 |
87 | for cell in cells2:
88 | for face in cell:
89 | for i in range(len(face)):
90 | face[i] += num_points1
91 |
92 | points_merged = onp.vstack((points1, points2))
93 | cells_merged = [('polyhedron', onp.concatenate((cells1, cells2)))]
94 | mesh_merged = meshio.Mesh(points_merged, cells_merged)
95 | return mesh_merged
96 |
97 |
98 | def merge_poly(poly1, poly2):
99 | '''
100 | Merge two polycrystals: poly2 must sit exactly on top of poly1
101 | '''
102 | print("Merge two polycrystal domains...")
103 |
104 | poly1_top_z = poly1.meta_info[2] + poly1.meta_info[5]
105 | poly2_bottom_z = poly2.meta_info[2]
106 | num_nodes1 = len(poly1.volumes)
107 | num_nodes2 = len(poly2.volumes)
108 |
109 | assert onp.isclose(poly1_top_z, poly2_bottom_z, atol=1e-8)
110 |
111 | inds1 = onp.argwhere(poly1.boundary_face_areas[:, 5] > 0).reshape(-1)
112 | inds2 = onp.argwhere(poly2.boundary_face_areas[:, 4] > 0).reshape(-1)
113 |
114 | face_areas1 = onp.take(poly1.boundary_face_areas[:, 5], inds1)
115 | face_areas2 = onp.take(poly2.boundary_face_areas[:, 4], inds2)
116 |
117 | assert onp.isclose(onp.sum(onp.absolute(face_areas1 - face_areas2)), 0., atol=1e-8)
118 |
119 | grain_distances = 2 * (poly1_top_z - onp.take(poly1.centroids[:, 2], inds1)) # assumes poly2's bottom cells mirror poly1's top cells across the interface
120 | ch_len_interface = face_areas1 / grain_distances
121 | edges_interface = onp.stack((inds1, inds2 + num_nodes1)).T
122 | edges_merged = onp.vstack((poly1.edges, poly2.edges + num_nodes1, edges_interface))
123 | ch_len_merged = onp.hstack((poly1.ch_len, poly2.ch_len, ch_len_interface))
124 |
125 | boundary_face_areas1 = onp.hstack((poly1.boundary_face_areas[:, :5], onp.zeros((num_nodes1, 1))))
126 | boundary_face_areas2 = onp.hstack((poly2.boundary_face_areas[:, :4], onp.zeros((num_nodes2, 1)), poly2.boundary_face_areas[:, 5:6]))
127 | boundary_face_areas_merged = onp.vstack((boundary_face_areas1, boundary_face_areas2))
128 |
129 | boundary_face_centroids_merged = onp.concatenate((poly1.boundary_face_centroids, poly2.boundary_face_centroids), axis=0)
130 | volumes_merged = onp.hstack((poly1.volumes, poly2.volumes))
131 | centroids_merged = onp.vstack((poly1.centroids, poly2.centroids))
132 | cell_ori_inds_merged = onp.hstack((poly1.cell_ori_inds, poly2.cell_ori_inds))
133 |
134 | meta_info = onp.hstack((poly1.meta_info[:5], poly1.meta_info[5:] + poly2.meta_info[5:]))
135 |
136 | poly_merged = PolyCrystal(edges_merged, ch_len_merged, centroids_merged, volumes_merged, poly1.unique_oris_rgb, poly1.unique_grain_directions,
137 | cell_ori_inds_merged, boundary_face_areas_merged, boundary_face_centroids_merged, meta_info)
138 |
139 | return poly_merged
140 |
141 |
142 | def randomize_oris(poly, seed):
143 | onp.random.seed(seed)
144 | cell_ori_inds = onp.random.randint(args.num_oris, size=len(poly.volumes))
145 | poly.cell_ori_inds[:] = cell_ori_inds
146 |
147 |
148 | # @walltime
149 | def run_helper(path):
150 | print(f"Merge into poly layer")
151 | poly1, mesh1 = polycrystal_gn('multi_layer')
152 | N_random = 100
153 | randomize_oris(poly1, N_random*args.layer + 1)
154 | poly2 = copy.deepcopy(poly1)
155 | randomize_oris(poly2, N_random*args.layer + 2)
156 |
157 | mesh2 = copy.deepcopy(mesh1)
158 | flip_poly(poly2, poly1.meta_info[2] + poly1.meta_info[5])
159 | flip_mesh(mesh2, poly1.meta_info[2] + poly1.meta_info[5])
160 | mesh_layer1 = merge_mesh(mesh1, mesh2)
161 | poly_layer1 = merge_poly(poly1, poly2)
162 | args.layer_num_dofs = len(poly_layer1.volumes)
163 | args.layer_height = poly_layer1.meta_info[5]
164 |
165 | # mesh_layer1.write(f'data/vtk/part/domain.vtu')
166 |
167 | print(f"Merge into poly sim")
168 | poly_layer2 = copy.deepcopy(poly_layer1)
169 | randomize_oris(poly_layer2, N_random*args.layer + 3)
170 |
171 | mesh_layer2 = copy.deepcopy(mesh_layer1)
172 | lift_poly(poly_layer2, poly_layer1.meta_info[5])
173 | lift_mesh(mesh_layer2, poly_layer1.meta_info[5])
174 | mesh_sim = merge_mesh(mesh_layer1, mesh_layer2)
175 | poly_sim = merge_poly(poly_layer1, poly_layer2)
176 |
177 | poly_top_layer = poly_layer2
178 | bottom_mesh = mesh_layer1
179 |
180 | lift_val = (args.layer - 1) * args.layer_height
181 | lift_poly(poly_sim, lift_val)
182 | lift_poly(poly_top_layer, lift_val)
183 | lift_mesh(mesh_sim, lift_val)
184 | lift_mesh(bottom_mesh, lift_val)
185 |
186 | if args.layer == 1:
187 | y0, melt = default_initialization(poly_sim)
188 | else:
189 | y0, melt = layered_initialization(poly_top_layer)
190 |
191 | graph = build_graph(poly_sim, y0)
192 | state_rhs = phase_field(graph, poly_sim)
193 | # This is how you generate NU.txt
194 | # traveled_time = onp.cumsum(onp.array([0., 0.6, (0.6**2 + 0.3**2)**0.5, 0.6, 0.2, 0.6, 0.3, 0.6, 0.4]))/500.
195 | ts, xs, ys, ps = read_path(path)
196 | odeint(poly_sim, mesh_sim, bottom_mesh, explicit_euler, state_rhs, y0, melt, ts, xs, ys, ps)
197 |
198 |
199 | def write_info():
200 | args.case = 'gn_multi_layer_scan_2'
201 | set_params()
202 | print(f"Merge into poly layer")
203 | poly1, mesh1 = polycrystal_gn('multi_layer')
204 | poly2 = copy.deepcopy(poly1)
205 | flip_poly(poly2, poly1.meta_info[2] + poly1.meta_info[5])
206 | poly_layer1 = merge_poly(poly1, poly2)
207 | args.layer_height = poly_layer1.meta_info[5]
208 |
209 | crt_layers = copy.deepcopy(poly_layer1)
210 |
211 | args.num_total_layers = 10
212 | for i in range(1, args.num_total_layers):
213 | print(f"Merge layer {i + 1} into current {i} layers")
214 | poly_layer_new = copy.deepcopy(poly_layer1)
215 | lift_poly(poly_layer_new, args.layer_height * i)
216 | crt_layers = merge_poly(crt_layers, poly_layer_new)
217 |
218 | onp.save(f"data/numpy/{args.case}/info/edges.npy", crt_layers.edges)
219 | onp.save(f"data/numpy/{args.case}/info/vols.npy", crt_layers.volumes)
220 | onp.save(f"data/numpy/{args.case}/info/centroids.npy", crt_layers.centroids)
221 |
222 |
223 | def run_NU():
224 | args.case = 'gn_multi_layer_NU'
225 | set_params()
226 | args.num_total_layers = 20
227 | for i in range(args.num_total_layers):
228 | print(f"\nLayer {i + 1}...")
229 | args.layer = i + 1
230 | onp.random.seed(args.layer)
231 | run_helper(f'data/txt/{args.case}.txt')
232 |
233 |
234 | def run_scans_1():
235 | args.case = 'gn_multi_layer_scan_1'
236 | set_params()
237 | args.num_total_layers = 10
238 | for i in range(args.num_total_layers - 1, args.num_total_layers):
239 | print(f"\nLayer {i + 1}...")
240 | args.layer = i + 1
241 | onp.random.seed(args.layer)
242 | run_helper(f'data/txt/{args.case}.txt')
243 |
244 |
245 | def rotate(points, angle, center):
246 | rot_mat = onp.array([[onp.cos(angle), -onp.sin(angle)], [onp.sin(angle), onp.cos(angle)]])
247 | return onp.matmul(rot_mat, (points - center[None, :]).T).T + center[None, :]
248 |
249 |
250 | def run_scans_2():
251 | args.case = 'gn_multi_layer_scan_2'
252 | set_params()
253 | path1 = f'data/txt/{args.case}-1.txt'
254 | path2 = f'data/txt/{args.case}-2.txt'
255 | path_info = onp.loadtxt(path1)
256 | center = onp.array([args.domain_length/2., args.domain_width/2.])
257 | rotated_points = rotate(path_info[:, 1:3], onp.pi/2, center)
258 | onp.savetxt(path2, onp.hstack((path_info[:, :1], rotated_points, path_info[:, -1:])), fmt='%.5f')
259 |
260 | args.num_total_layers = 10
261 | for i in range(args.num_total_layers - 1, args.num_total_layers):
262 | print(f"\nLayer {i + 1}...")
263 | args.layer = i + 1
264 | onp.random.seed(args.layer)
265 | if i % 2 == 0:
266 | run_helper(path1)
267 | else:
268 | run_helper(path2)
269 |
270 |
271 | if __name__ == "__main__":
272 | # neper_domain()
273 | # write_info()
274 | # run_NU()
275 | run_scans_1()
276 | run_scans_2()
277 |
--------------------------------------------------------------------------------
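The path files consumed by `read_path` are tables. Each row holds a traveled time, the x and y coordinates of a corner, and a power-control value. The commented line in `run_helper` builds the time column as cumulative segment lengths divided by 500 (the scan speed). `run_scans_2` then derives a second path by rotating the corners 90 degrees about the domain center.

The NumPy sketch below reproduces both steps under these assumptions:
- The helper names `traveled_time_from_corners` and `make_path_table` are hypothetical illustrations, not part of the repository.
- The corner coordinates and the rotation center are made up.
- `rotate` mirrors the function of the same name in multi_layer.py.

```python
import numpy as onp


def traveled_time_from_corners(xs, ys, speed):
    """Arrival time at each corner for a constant scan speed (hypothetical helper)."""
    seg_lengths = onp.hypot(onp.diff(xs), onp.diff(ys))
    return onp.concatenate(([0.], onp.cumsum(seg_lengths))) / speed


def rotate(points, angle, center):
    """Rotate (N, 2) points counter-clockwise by `angle` about `center` (as in multi_layer.py)."""
    rot_mat = onp.array([[onp.cos(angle), -onp.sin(angle)],
                         [onp.sin(angle), onp.cos(angle)]])
    return onp.matmul(rot_mat, (points - center[None, :]).T).T + center[None, :]


def make_path_table(xs, ys, powers, speed):
    """Rows of (t, x, y, p), the column layout read by read_path (hypothetical helper)."""
    return onp.column_stack((traveled_time_from_corners(xs, ys, speed), xs, ys, powers))


# Hypothetical corners: 0.6 along x, then 0.3 along y, at speed 500.
xs = onp.array([0.2, 0.8, 0.8])
ys = onp.array([0.1, 0.1, 0.4])
powers = onp.array([1., 1., 0.])  # read_path ignores the power value of the last row
table = make_path_table(xs, ys, powers, speed=500.)

# Second path: rotate the corners by 90 degrees about a center, keep t and p unchanged.
center = onp.array([0.5, 0.5])
rotated = rotate(table[:, 1:3], onp.pi / 2, center)
table_rot = onp.hstack((table[:, :1], rotated, table[:, -1:]))
# onp.savetxt('hypothetical_path.txt', table_rot, fmt='%.5f')  # same format as run_scans_2
```

A rotation preserves segment lengths, so the rotated path has the same time column. This is why reusing `path_info[:, :1]` unchanged in `run_scans_2` is consistent.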
/src/npj_review.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as np
3 | import numpy as onp
4 | import os
5 | import matplotlib.pyplot as plt
6 | from src.utils import read_path, obj_to_vtu
7 | from src.arguments import args
8 | from src.allen_cahn import polycrystal_gn, polycrystal_fd, build_graph, phase_field, odeint, odeint_no_output, explicit_euler
9 |
10 |
11 | def debug():
12 | neper_mesh = 'debug'
13 | morpho = 'gg'
14 | # args.num_grains = 1000
15 | os.system(f'neper -T -n 10000 -morpho centroidal -morphooptistop itermax=50 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
16 | -o data/neper/{neper_mesh}/domain -format tess,obj,ori')
17 | # os.system(f'neper -T -n 100 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
18 | # -o data/neper/{neper_mesh}/domain -format tess,obj,ori')
19 |
20 | # os.system(f'neper -T -loadtess data/neper/{neper_mesh}/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
21 | # os.system(f'neper -M -rcl 1 -elttype hex -faset faces data/neper/{neper_mesh}/domain.tess')
22 | write_vtu_files(neper_mesh)
23 |
24 |
25 | def write_vtu_files(neper_mesh):
26 | args.case = 'gn_' + neper_mesh
27 | poly_mesh = obj_to_vtu(neper_mesh)
28 | vtk_folder = f'data/vtk/{args.case}/mesh/'
29 | if not os.path.exists(vtk_folder):
30 | os.makedirs(vtk_folder)
31 | poly_mesh.write(f'data/vtk/{args.case}/mesh/poly_mesh.vtu')
32 |
33 |
34 | def neper_domain(neper_mesh, morpho):
35 | itermax = 50 if morpho == 'centroidal' else 1000000
36 | # TODO: Will itermax=1000000 cause a bug for voronoi?
37 | os.system(f'neper -T -n {args.num_grains} -morpho {morpho} -morphooptistop itermax={itermax} -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
38 | -o data/neper/{neper_mesh}/domain -format tess,obj,ori')
39 | os.system(f'neper -T -loadtess data/neper/{neper_mesh}/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
40 | os.system(f'neper -M -rcl 1 -elttype hex -faset faces data/neper/{neper_mesh}/domain.tess')
41 | write_vtu_files(neper_mesh)
42 |
43 |
44 | def default_initialization(poly_sim):
45 | num_nodes = len(poly_sim.centroids)
46 | T = args.T_ambient*np.ones(num_nodes)
47 | zeta = np.ones(num_nodes)
48 | eta = np.zeros((num_nodes, args.num_oris))
49 | eta = eta.at[np.arange(num_nodes), poly_sim.cell_ori_inds].set(1)
50 | # shape of state: (num_nodes, 1 + 1 + args.num_oris)
51 | y0 = np.hstack((T[:, None], zeta[:, None], eta))
52 | melt = np.zeros(len(y0), dtype=bool)
53 | return y0, melt
54 |
55 |
56 | def simulate(func, neper_mesh, just_measure_time=False):
57 | print(f"Running case {args.case}")
58 | polycrystal, mesh = func(neper_mesh)
59 | y0, melt = default_initialization(polycrystal)
60 | graph = build_graph(polycrystal, y0)
61 | state_rhs = phase_field(graph, polycrystal)
62 | ts, xs, ys, ps = read_path(f'data/txt/single_track.txt')
63 | if just_measure_time:
64 | odeint_no_output(polycrystal, mesh, None, explicit_euler, state_rhs, y0, melt, ts, xs, ys, ps)
65 | else:
66 | odeint(polycrystal, mesh, None, explicit_euler, state_rhs, y0, melt, ts, xs, ys, ps)
67 |
68 |
69 | def run_voronoi():
70 | neper_mesh = 'npj_review_voronoi'
71 | args.num_oris = 20
72 | args.num_grains = 40000
73 | # neper_domain(neper_mesh, 'voronoi')
74 |
75 | args.case = 'gn_npj_review_voronoi'
76 | simulate(polycrystal_gn, neper_mesh)
77 | # simulate(polycrystal_gn, neper_mesh, True)
78 |
79 | args.case = 'fd_npj_review_voronoi'
80 | simulate(polycrystal_fd, neper_mesh)
81 | # simulate(polycrystal_fd, neper_mesh, True)
82 |
83 |
84 | def run_voronoi_more_oris():
85 | neper_mesh = 'npj_review_voronoi'
86 | args.num_oris = 40
87 | args.num_grains = 40000
88 |
89 | args.case = 'gn_npj_review_voronoi_more_oris'
90 | # simulate(polycrystal_gn, neper_mesh)
91 | simulate(polycrystal_gn, neper_mesh, True)
92 |
93 | args.case = 'fd_npj_review_voronoi_more_oris'
94 | # simulate(polycrystal_fd, neper_mesh)
95 | simulate(polycrystal_fd, neper_mesh, True)
96 |
97 |
98 | def run_voronoi_less_oris():
99 | neper_mesh = 'npj_review_voronoi'
100 | args.num_oris = 10
101 | args.num_grains = 40000
102 |
103 | args.case = 'gn_npj_review_voronoi_less_oris'
104 | # simulate(polycrystal_gn, neper_mesh)
105 | simulate(polycrystal_gn, neper_mesh, True)
106 |
107 | args.case = 'fd_npj_review_voronoi_less_oris'
108 | # simulate(polycrystal_fd, neper_mesh)
109 | simulate(polycrystal_fd, neper_mesh, True)
110 |
111 |
112 | def run_voronoi_fine():
113 | neper_mesh = 'npj_review_voronoi_fine'
114 | args.num_oris = 20
115 | args.num_grains = 80000
116 | # neper_domain(neper_mesh, 'voronoi')
117 |
118 | args.case = 'gn_npj_review_voronoi_fine'
119 | # simulate(polycrystal_gn, neper_mesh)
120 | simulate(polycrystal_gn, neper_mesh, True)
121 |
122 | args.case = 'fd_npj_review_voronoi_fine'
123 | # simulate(polycrystal_fd, neper_mesh)
124 | simulate(polycrystal_fd, neper_mesh, True)
125 |
126 |
127 | def run_voronoi_coarse():
128 | neper_mesh = 'npj_review_voronoi_coarse'
129 | args.num_oris = 20
130 | args.num_grains = 20000
131 |
132 | # neper_domain(neper_mesh, 'voronoi')
133 |
134 | args.case = 'gn_npj_review_voronoi_coarse'
135 | # simulate(polycrystal_gn, neper_mesh)
136 | simulate(polycrystal_gn, neper_mesh, True)
137 |
138 | args.case = 'fd_npj_review_voronoi_coarse'
139 | # simulate(polycrystal_fd, neper_mesh)
140 | simulate(polycrystal_fd, neper_mesh, True)
141 |
142 |
143 | def run_centroidal():
144 | neper_mesh = 'npj_review_centroidal'
145 | args.num_oris = 20
146 | args.num_grains = 40000
147 |
148 | # neper_domain(neper_mesh, 'centroidal')
149 |
150 | args.case = 'gn_npj_review_centroidal'
151 | # simulate(polycrystal_gn, neper_mesh)
152 | simulate(polycrystal_gn, neper_mesh, True)
153 |
154 | args.case = 'fd_npj_review_centroidal'
155 | # simulate(polycrystal_fd, neper_mesh)
156 | simulate(polycrystal_fd, neper_mesh, True)
157 |
158 |
159 | def run_voronoi_laser_150():
160 | neper_mesh = 'npj_review_voronoi'
161 | args.power = 150.
162 |
163 | args.case = 'gn_npj_review_voronoi_laser_150'
164 | simulate(polycrystal_gn, neper_mesh)
165 |
166 | args.case = 'fd_npj_review_voronoi_laser_150'
167 | simulate(polycrystal_fd, neper_mesh)
168 |
169 |
170 | def run_voronoi_laser_250():
171 | neper_mesh = 'npj_review_voronoi'
172 | args.power = 250.
173 |
174 | args.case = 'gn_npj_review_voronoi_laser_250'
175 | simulate(polycrystal_gn, neper_mesh)
176 |
177 | args.case = 'fd_npj_review_voronoi_laser_250'
178 | simulate(polycrystal_fd, neper_mesh)
179 |
180 |
181 | def run_voronoi_laser_100():
182 | neper_mesh = 'npj_review_voronoi'
183 | args.power = 100.
184 |
185 | args.case = 'gn_npj_review_voronoi_laser_100'
186 | simulate(polycrystal_gn, neper_mesh)
187 |
188 | args.case = 'fd_npj_review_voronoi_laser_100'
189 | simulate(polycrystal_fd, neper_mesh)
190 |
191 |
192 | def npj_review_initial_size_distribution():
193 | args.case = 'none'
194 | poly_voronoi, _ = polycrystal_gn('npj_review_voronoi')
195 | voronoi_vols = poly_voronoi.volumes*1e9
196 | poly_centroidal, _ = polycrystal_gn('npj_review_centroidal')
197 | centroidal_vols = poly_centroidal.volumes*1e9
198 |
199 | colors = ['red', 'blue']
200 | labels = ['Voronoi', 'Centroidal']
201 | fig = plt.figure(figsize=(8, 6))
202 | plt.hist([voronoi_vols, centroidal_vols], bins=onp.linspace(0, 3*1e3, 13), color=colors, label=labels)
203 | plt.legend(fontsize=20, frameon=False)
204 | plt.xlabel(r'Grain volume [$\mu$m$^3$]', fontsize=20)
205 | plt.ylabel(r'Count', fontsize=20)
206 | plt.tick_params(labelsize=18)
207 | plt.grid(False)
208 | plt.savefig(f'data/pdf/initial_size_distribution.pdf', bbox_inches='tight')
209 |
210 |
211 | if __name__ == "__main__":
212 | # run_voronoi()
213 | # run_voronoi_more_oris()
214 | # run_voronoi_less_oris()
215 | # run_voronoi_fine()
216 | # run_voronoi_coarse()
217 | # run_centroidal()
218 | # run_voronoi_laser_150()
219 | # run_voronoi_laser_250()
220 | # run_voronoi_laser_100()
221 |
222 | npj_review_initial_size_distribution()
223 | # plt.show()
--------------------------------------------------------------------------------
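`default_initialization` packs the per-node state into one array. Its columns are the temperature T, the phase variable zeta, and `args.num_oris` orientation variables eta, with eta set one-hot from `cell_ori_inds`.

The sketch re-expresses that layout in plain NumPy; in-place fancy indexing replaces JAX's `.at[...].set`. The ambient temperature of 300 and the orientation indices are placeholders. Reading zeta = 1 as "solid" is an inference from `solidification_initialization`, which starts a melt from zeta = 0.

```python
import numpy as onp


def default_state(cell_ori_inds, num_oris, T_ambient):
    """NumPy re-expression of default_initialization: columns [T, zeta, eta_1..eta_n]."""
    num_nodes = len(cell_ori_inds)
    T = T_ambient * onp.ones(num_nodes)
    zeta = onp.ones(num_nodes)                      # presumably 1 = solid
    eta = onp.zeros((num_nodes, num_oris))
    eta[onp.arange(num_nodes), cell_ori_inds] = 1.  # one-hot orientation per node
    y0 = onp.hstack((T[:, None], zeta[:, None], eta))
    melt = onp.zeros(len(y0), dtype=bool)           # no node has melted yet
    return y0, melt


def unpack_state(state):
    """Same column split as src.utils.unpack_state."""
    return state[..., 0:1], state[..., 1:2], state[..., 2:]


# Placeholder inputs: 4 nodes, 3 orientations, ambient temperature 300.
y0, melt = default_state(onp.array([2, 0, 1, 2]), num_oris=3, T_ambient=300.)
T, zeta, eta = unpack_state(y0)
```

Keeping T and zeta as width-1 slices (`0:1`, `1:2`) rather than integer indices preserves a trailing axis. That way `np.hstack((T, zeta, eta))`, as used in `overwrite_T`, reassembles the state without any reshaping.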
/src/property.py:
--------------------------------------------------------------------------------
1 | '''
2 | We tried to use Neper's sister software FEPX (https://fepx.info/) or DAMASK for crystal plasticity analysis.
3 | '''
4 | import os, meshio, numpy as onp, jax.numpy as np
5 | def selected_cube_hex():
6 | '''
7 | This function produces input files for DAMASK (and for the OSU folks).
8 | Since we're using MOOSE now instead of DAMASK, this function should be deprecated.
9 | See https://github.com/tianjuxue/cp_gnn
10 | '''
11 | property_name = 'property_damask'
12 |
13 | offset_x = 0.5
14 | offset_y = 0.1
15 | offset_z = 0.05
16 |
17 | neper_create_cube = True
18 | if neper_create_cube:
19 | select_length = 0.2
20 | select_width = 0.2
21 | select_height = 0.05
22 |
23 | os.system(f'neper -T -n 1 -reg 0 -domain "cube({select_length},{select_width},{select_height})" -o data/neper/{property_name}/simple -format tess')
24 | os.system(f'neper -M -rcl 0.1 -elttype hex data/neper/{property_name}/simple.tess ')
25 |
26 |
27 | filepath_raw = f'data/neper/single_layer/domain.msh'
28 |
29 | mesh = meshio.read(filepath_raw)
30 | points = mesh.points
31 | cells = mesh.cells_dict['hexahedron']
32 |
33 | cell_points = onp.take(points, cells, axis=0)
34 | centroids = onp.mean(cell_points, axis=1)
35 |
36 | min_x, min_y, min_z = onp.min(points[:, 0]), onp.min(points[:, 1]), onp.min(points[:, 2])
37 | max_x, max_y, max_z = onp.max(points[:, 0]), onp.max(points[:, 1]), onp.max(points[:, 2])
38 | domain_length = max_x - min_x
39 | domain_width = max_y - min_y
40 | domain_height = max_z - min_z
41 |
42 | Nx = round(domain_length / (points[1, 0] - min_x))
43 | Ny = round(domain_width / (points[Nx + 1, 1] - min_y))
44 | Nz = round(domain_height / (points[(Nx + 1)*(Ny + 1), 2] - min_z))
45 | tick_x, tick_y, tick_z = domain_length / Nx, domain_width / Ny, domain_height / Nz
46 |
47 | assert Nx*Ny*Nz == len(cells)
48 |
49 | filepath_neper = f'data/neper/{property_name}/simple.msh'
50 | mesh = meshio.read(filepath_neper)
51 | points = mesh.points
52 |
53 | cells = mesh.cells_dict['hexahedron']
54 | cell_points = onp.take(points, cells, axis=0)
55 | order2_hex_centroids = onp.mean(cell_points, axis=1)
56 | indx = onp.round((order2_hex_centroids[:, 0] + offset_x - min_x - tick_x / 2.) / tick_x)
57 | indy = onp.round((order2_hex_centroids[:, 1] + offset_y - min_y - tick_y / 2.) / tick_y)
58 | indz = onp.round((order2_hex_centroids[:, 2] + offset_z - min_z - tick_z / 2.) / tick_z)
59 | total_ind = onp.array(indx + indy * Nx + indz * Nx * Ny, dtype=np.int32)
60 |
61 |
62 | def helper(case, step):
63 | print(f"Processing case {case} and step {step}")
64 | if case == 'fd_single_layer':
65 | cell_ori_inds = onp.load(f"data/numpy/{case}/sols/cell_ori_inds_{step:03d}.npy")
66 | else:
67 | grain_oris_inds = onp.load(f"data/numpy/{case}/sols/cell_ori_inds_{step:03d}.npy")
68 | cell_grain_inds = onp.load(f"data/numpy/fd_single_layer/info/cell_grain_inds.npy")
69 | cell_ori_inds = onp.take(grain_oris_inds, cell_grain_inds, axis=0)
70 |
71 | order2_cell_ori_inds = onp.take(cell_ori_inds, total_ind, axis=0)
72 |
73 | file_to_read = open(filepath_neper, 'r')
74 | lines = file_to_read.readlines()
75 | new_lines = []
76 | flag = True
77 | for i, line in enumerate(lines):
78 | l = line.split()
79 | # TODO: dirty
80 | if len(l) == 14:
81 | if flag:
82 | offset_cell_ind = int(l[0])
83 | flag = False
84 |
85 | ori_ind = order2_cell_ori_inds[int(l[0]) - offset_cell_ind]
86 | l[3] = str(ori_ind)
87 | l[4] = str(ori_ind)
88 | new_line = " ".join(l) + "\n"
89 | else:
90 | new_line = line
91 | new_lines.append(new_line)
92 |
93 | ori_quat = onp.load(f"data/numpy/quat.npy")
94 | file_to_write = f'data/neper/{property_name}/simulation_{case}_{step:03d}.msh'
95 | with open(file_to_write, 'w') as f:
96 | for i, line in enumerate(new_lines):
97 | l = line.split()
98 | if l[0] == "$ElsetOrientations":
99 | break
100 | # TODO: dirty
101 | elif len(l) == 4:
102 | l[1] = str(float(l[1]) + offset_x)
103 | l[2] = str(float(l[2]) + offset_y)
104 | l[3] = str(float(l[3]) + offset_z)
105 | f.write(" ".join(l) + "\n")
106 | else:
107 | f.write(line)
108 | f.write(f"$ElsetOrientations\n{len(ori_quat)} quaternion:active\n")
109 | for i, line in enumerate(ori_quat):
110 | f.write(f"{i + 1}")
111 | for q in line:
112 | f.write(f" {q}")
113 | f.write("\n")
114 | f.write(f"$EndElsetOrientations\n")
115 |
116 | mesh = meshio.read(file_to_write)
117 | mesh.write(f"data/neper/{property_name}/simulation_{case}_{step:03d}.vtu")
118 |
119 | helper('gn_single_layer', 30)
120 | helper('fd_single_layer', 30)
121 | helper('fd_single_layer', 0)
122 |
123 |
124 | if __name__ == "__main__":
125 | selected_cube_hex()
126 |
--------------------------------------------------------------------------------
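In `selected_cube_hex`, each centroid of the small Neper cube is mapped to the cell of the structured `single_layer` hex mesh that contains it. The mapping is `indx + indy * Nx + indz * Nx * Ny`, which presumes cells are numbered with x varying fastest, then y, then z.

Below is a minimal sketch of that mapping on a toy 4 x 3 x 2 grid; the grid sizes and spacings are made up for illustration. It checks that every cell centroid recovers its own flat index.

```python
import numpy as onp


def centroid_to_flat_index(centroids, lower, tick, shape):
    """Flat index of the structured hex cell containing each centroid.

    Uses the same rounding as selected_cube_hex and assumes x varies
    fastest, then y, then z.
    """
    Nx, Ny, _ = shape
    ijk = onp.round((centroids - lower - tick / 2.) / tick).astype(onp.int64)
    return ijk[:, 0] + ijk[:, 1] * Nx + ijk[:, 2] * Nx * Ny


# Toy grid: 4 x 3 x 2 cells (hypothetical sizes, not the single_layer mesh).
shape = (4, 3, 2)
lower = onp.array([0., 0., 0.])
tick = onp.array([0.25, 0.2, 0.05])

# Enumerate cell indices with x fastest: C-order ravel over (z, y, x).
k, j, i = onp.meshgrid(onp.arange(shape[2]), onp.arange(shape[1]),
                       onp.arange(shape[0]), indexing='ij')
ijk = onp.stack((i.ravel(), j.ravel(), k.ravel()), axis=1)
centroids = lower + (ijk + 0.5) * tick

inds = centroid_to_flat_index(centroids, lower, tick, shape)

# A point strictly inside a cell maps to the same index as that cell's centroid.
inside = centroids + 0.3 * tick
inds_inside = centroid_to_flat_index(inside, lower, tick, shape)
```

Gathering per-cell orientations is then `onp.take(cell_ori_inds, inds, axis=0)`, which is what `helper` does with `total_ind`.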
/src/single_layer.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as np
3 | import numpy as onp
4 | import os
5 | from src.utils import read_path, obj_to_vtu
6 | from src.arguments import args
7 | from src.allen_cahn import polycrystal_gn, polycrystal_fd, build_graph, phase_field, odeint, explicit_euler
8 |
9 |
10 | def neper_domain():
11 | os.system(f'neper -T -n {args.num_grains} -id 1 -regularization 0 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
12 | -o data/neper/single_layer/domain -format tess,obj,ori')
13 | os.system(f'neper -T -loadtess data/neper/single_layer/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
14 | os.system(f'neper -M -rcl 1 -elttype hex -faset faces data/neper/single_layer/domain.tess')
15 |
16 |
17 | def default_initialization(poly_sim):
18 | num_nodes = len(poly_sim.centroids)
19 | T = args.T_ambient*np.ones(num_nodes)
20 | zeta = np.ones(num_nodes)
21 | eta = np.zeros((num_nodes, args.num_oris))
22 | eta = eta.at[np.arange(num_nodes), poly_sim.cell_ori_inds].set(1)
23 | # shape of state: (num_nodes, 1 + 1 + args.num_oris)
24 | y0 = np.hstack((T[:, None], zeta[:, None], eta))
25 | melt = np.zeros(len(y0), dtype=bool)
26 | return y0, melt
27 |
28 |
29 | def simulate(ts, xs, ys, ps, func):
30 | polycrystal, mesh = func()
31 | y0, melt = default_initialization(polycrystal)
32 | graph = build_graph(polycrystal, y0)
33 | state_rhs = phase_field(graph, polycrystal)
34 | odeint(polycrystal, mesh, None, explicit_euler, state_rhs, y0, melt, ts, xs, ys, ps)
35 |
36 |
37 | def run_gn():
38 | args.case = 'gn_single_layer'
39 | ts, xs, ys, ps = read_path(f'data/txt/single_track.txt')
40 | simulate(ts, xs, ys, ps, polycrystal_gn)
41 |
42 |
43 | def run_fd():
44 | args.case = 'fd_single_layer'
45 | ts, xs, ys, ps = read_path(f'data/txt/single_track.txt')
46 | simulate(ts, xs, ys, ps, polycrystal_fd)
47 |
48 |
49 | if __name__ == "__main__":
50 | # neper_domain()
51 | run_gn()
52 | run_fd()
53 |
--------------------------------------------------------------------------------
/src/solidification.py:
--------------------------------------------------------------------------------
1 | import os
2 | import jax
3 | import jax.numpy as np
4 | import numpy as onp
5 | from src.arguments import args
6 | from src.allen_cahn import polycrystal_fd, build_graph, phase_field, odeint, explicit_euler
7 | from src.utils import unpack_state, walltime, read_path
8 |
9 |
10 | def set_params():
11 | args.case = 'fd_solidification'
12 | args.num_grains = 10000
13 | args.domain_length = 1.
14 | args.domain_width = 0.01
15 | args.domain_height = 1.
16 | args.write_sol_interval = 1000
17 |
18 | # The following parameter controls the anisotropy level; see Eq. (12) of the Yan paper.
19 | # If it is set to zero, isotropic grain growth is considered.
20 | # If it is not explicitly set here, the default value is used.
21 | # args.anisotropy = 0.
22 |
23 |
24 | def neper_domain():
25 | set_params()
26 | os.system(f'neper -T -n {args.num_grains} -id 1 -regularization 0 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
27 | -o data/neper/solidification/domain -format tess,obj,ori')
28 | os.system(f'neper -T -loadtess data/neper/solidification/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
29 | os.system(f'neper -M -rcl 1 -elttype hex -faset faces data/neper/solidification/domain.tess')
30 |
31 |
32 | def solidification_initialization(poly_sim):
33 | num_nodes = len(poly_sim.centroids)
34 | T = args.T_ambient*np.ones(num_nodes)
35 | zeta = np.zeros(num_nodes)
36 | eta = np.zeros((num_nodes, args.num_oris))
37 | # shape of state: (num_nodes, 1 + 1 + args.num_oris)
38 | eta = eta.at[np.arange(num_nodes), poly_sim.cell_ori_inds].set(1)
39 | y0 = np.hstack((T[:, None], zeta[:, None], eta))
40 | melt = np.zeros(len(y0), dtype=bool)
41 | return y0, melt
42 |
43 |
44 | def get_T(centroids, t):
45 | '''
46 | Given spatial coordinates and t, we prescribe the value of T.
47 | '''
48 | z = centroids[:, 2]
49 | vel = 200.
50 | thermal_grad = 500.
51 | cooling_rate = thermal_grad * vel
52 | t_total = args.domain_height / vel
53 | T = args.T_melt + thermal_grad * z - cooling_rate * t
54 | return T[:, None]
55 |
56 |
57 | @jax.jit
58 | def overwrite_T(y, centroids, t):
59 | '''
60 | We overwrite T if T is prescribed.
61 | '''
62 | T, zeta, eta = unpack_state(y)
63 | T = get_T(centroids, t)
64 | return np.hstack((T, zeta, eta))
65 |
66 |
67 | def run():
68 | set_params()
69 | ts, xs, ys, ps = read_path(f'data/txt/solidification.txt')
70 | polycrystal, mesh = polycrystal_fd('solidification')
71 | y0, melt = solidification_initialization(polycrystal)
72 | graph = build_graph(polycrystal, y0)
73 | state_rhs = phase_field(graph, polycrystal)
74 | odeint(polycrystal, mesh, None, explicit_euler, state_rhs, y0, melt, ts, xs, ys, ps, overwrite_T)
75 |
76 |
77 | if __name__ == "__main__":
78 | set_params()
79 | # get_unique_ori_colors()
80 | # neper_domain()
81 | run()
82 |
--------------------------------------------------------------------------------
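`get_T` prescribes a frozen temperature field T = T_melt + G z - G V t, with G = 500 and V = 200. The melting isotherm therefore rises at speed V, and every point cools at the rate G V.

The sketch below evaluates the same expression outside JAX. `T_MELT = 1700.` is a placeholder for `args.T_melt`, whose actual value is not shown here.

```python
import numpy as onp

# Placeholder for args.T_melt; the real value lives in src/arguments.py.
T_MELT = 1700.
THERMAL_GRAD = 500.   # G in get_T
VEL = 200.            # V in get_T


def prescribed_T(z, t, T_melt=T_MELT, G=THERMAL_GRAD, V=VEL):
    """Frozen-temperature field, same expression as get_T."""
    return T_melt + G * z - G * V * t


def melt_front_height(t, V=VEL):
    """Height where prescribed_T equals T_melt: G*z - G*V*t = 0 gives z = V*t."""
    return V * t


t = 2e-3
z_front = melt_front_height(t)
z = onp.array([0.3, z_front, 0.5])
T = prescribed_T(z, t)  # below, at, and above the front

# Finite-difference cooling rate at a fixed height; should equal G*V.
dt = 1e-6
cooling_rate = (prescribed_T(0.1, t) - prescribed_T(0.1, t + dt)) / dt
```

With `args.domain_height = 1.`, the front needs 1 / V = 0.005 time units to sweep the domain. This matches the otherwise unused `t_total` in `get_T`.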
/src/temperature.py:
--------------------------------------------------------------------------------
1 | '''
2 | Just for debugging purposes, we wrote this quick FEniCS-based FEM solver for the temperature field.
3 | '''
4 | import fenics as fe
5 | import numpy as np
6 | from src.arguments import args
7 |
8 |
9 | def simulation():
10 | domain_x = 1.
11 | domain_y = 0.2
12 | domain_z = 0.1
13 | total_t = 2400*1e-6
14 | ambient_T = args.T_ambient
15 |
16 | rho = args.rho
17 |
18 | Cp = args.c_h
19 |
20 | k = args.kappa_T
21 |
22 | h = args.h_conv
23 |
24 | eta = args.power_fraction
25 | r = args.r_beam
26 | P = args.power
27 |
28 | x0 = 0.2*args.domain_length
29 | y0 = 0.5*args.domain_width
30 | vel = 0.6*args.domain_length/total_t
31 |
32 | finer = False
33 | ele_size = 0.01
34 | dt = 2*1e-7
35 |
36 | EPS = 1e-8
37 | ts = np.arange(0., total_t + dt, dt)
38 |
39 | # Building mesh, see https://fenicsproject.org/olddocs/dolfin/1.5.0/python/programmers-reference/cpp/mesh/BoxMesh.html
40 | mesh = fe.BoxMesh(fe.Point(0., 0., 0.), fe.Point(domain_x, domain_y, domain_z),
41 | round(domain_x/ele_size), round(domain_y/ele_size), round(domain_z/ele_size))
42 |
43 | # Save mesh to local file, optional, just for inspection
44 | mesh_file = fe.File(f'data/vtk/fem/mesh.pvd')
45 | mesh_file << mesh
46 |
47 | # Define bottom surface
48 | class Bottom(fe.SubDomain):
49 | def inside(self, x, on_boundary):
50 | # The condition for a point x to be on bottom side is that x[2] < EPS
51 | return on_boundary and x[2] < EPS
52 |
53 | # Define top surface
54 | class Top(fe.SubDomain):
55 | def inside(self, x, on_boundary):
56 | return on_boundary and x[2] > domain_z - EPS
57 |
58 | # Define the other four surfaces
59 | class SurroundingSurfaces(fe.SubDomain):
60 | def inside(self, x, on_boundary):
61 | return on_boundary and (x[0] < EPS or x[0] > domain_x - EPS or x[1] < EPS or x[1] > domain_y - EPS)
62 |
63 | # The following few lines mark different boundaries with different numbers
64 | # For example, the top surface is marked with the integer number 2
65 | bottom = Bottom()
66 | top = Top()
67 | surrounding_surfaces = SurroundingSurfaces()
68 | boundaries = fe.MeshFunction("size_t", mesh, mesh.topology().dim() - 1)
69 | boundaries.set_all(0)
70 | bottom.mark(boundaries, 1)
71 | top.mark(boundaries, 2)
72 | surrounding_surfaces.mark(boundaries, 3)
73 | ds = fe.Measure('ds')(subdomain_data=boundaries)
74 |
75 | # Define the FEM function space as second-order continuous Galerkin (CG2); use ('CG', 1) for the more common first-order elements
76 | V = fe.FunctionSpace(mesh, 'CG', 2)
77 |
78 | # u_crt is the temperature field we want to solve
79 | # u_crt = fe.Function(V)
80 | u_crt = fe.interpolate(fe.Constant(ambient_T), V)
81 |
82 | # u_pre is the temperature from the previous step
83 | # We initialize u_pre to be a constant field = ambient_T (assign initial values)
84 | u_pre = fe.interpolate(fe.Constant(ambient_T), V)
85 |
86 | # v is the test function in FEM
87 | v = fe.TestFunction(V)
88 |
89 | # theta = 0. recovers implicit (backward) Euler, theta = 1. recovers explicit (forward) Euler, and theta = 0.5 (Crank-Nicolson) seems to be a good choice.
90 | theta = 1.
91 | u_rhs = theta*u_pre + (1 - theta)*u_crt
92 |
93 | # Define Dirichlet boundary conditions for the bottom surface to be always at ambient temperature
94 | bcs = [fe.DirichletBC(V, fe.Constant(ambient_T), bottom)]
95 |
96 | # Define the laser heat source; t is a mutable attribute that is updated at every time step
97 | class LaserExpression(fe.UserExpression):
98 | def __init__(self, t):
99 | # Construction method of base class has to be called first
100 | super(LaserExpression, self).__init__()
101 | self.t = t
102 |
103 | def eval(self, values, x):
104 | t = self.t
105 | values[0] = 2*P*eta/(np.pi*r**2) * np.exp(-2*((x[0] - x0 - vel*t)**2 + (x[1] - y0)**2) / r**2)
106 |
107 | def value_shape(self):
108 | return ()
109 |
110 | q_laser = LaserExpression(None)
111 | q_convection = h * (ambient_T - u_rhs)  # inward heat flux; negative (a heat loss) when u_rhs > ambient_T
112 |
113 | # For the top surface, we will consider both convection and laser heating
114 | q_top = q_convection + q_laser
115 | # For the four side surfaces, we will only consider convection
116 | q_surr = q_convection
117 |
118 | # Define the weak form residual
119 | # Terms with fe.dx are volume integrals
120 | # ds(2) is a surface integral restricted to boundary marker 2 (the top surface) and ds(3) to marker 3 (the four side surfaces), as marked above
121 | residual = rho*Cp/dt*(u_crt - u_pre) * v * fe.dx + k * fe.dot(fe.grad(u_rhs), fe.grad(v)) * fe.dx \
122 | - q_top * v * ds(2) - q_surr * v * ds(3)
123 |
124 | # Open a pvd file to store results
125 | u_vtk_file = fe.File(f'data/vtk/fem/u.pvd')
126 |
127 | # Store solution at the 0th step
128 | u_vtk_file << u_pre
129 |
130 | # for i in range(len(ts) - 1):
131 |
132 | for i in range(101):
133 |
134 | print(f"step {i + 1}, time = {ts[i + 1]}")
135 |
136 | # Update the time parameter in laser
137 | q_laser.t = theta*ts[i] + (1 - theta)*ts[i + 1]
138 |
139 | # Solve the problem at this time step
140 | solver_parameters = {'newton_solver': {'maximum_iterations': 20, 'linear_solver': 'mumps'}}
141 | fe.solve(residual == 0, u_crt, bcs, solver_parameters=solver_parameters)
142 |
143 | # After solving, update u_pre so that it is equal to the newly solved u_crt
144 | u_pre.assign(u_crt)
145 |
146 | # Store solution at this step
147 | u_vtk_file << u_pre
148 |
149 | print(f"min T = {np.min(np.array(u_pre.vector()))}")
150 | print(f"max T = {np.max(np.array(u_pre.vector()))}")
151 |
152 |
153 | if __name__ == '__main__':
154 | simulation()
155 |
--------------------------------------------------------------------------------
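Two quick checks on the FEniCS script:
- The Gaussian flux in `LaserExpression.eval`, 2 eta P / (pi r^2) exp(-2 rho^2 / r^2), integrates over the plane to eta P. Only the fraction `power_fraction` of the laser power is therefore absorbed.
- The script's `theta` weights the previous step, so theta = 1 is explicit and theta = 0 is implicit. Applying the same weighting to the scalar decay equation du/dt = -lambda u gives the closed-form update in `theta_step`.

The values of P, eta, r, lambda and dt below are placeholders, not the defaults in `args`.

```python
import numpy as onp


def gaussian_flux(x, y, P, eta, r):
    """Surface flux of LaserExpression.eval, centred at the origin."""
    return 2 * P * eta / (onp.pi * r**2) * onp.exp(-2 * (x**2 + y**2) / r**2)


def theta_step(u_old, lam, dt, theta):
    """One step of du/dt = -lam*u with the script's weighting
    u_rhs = theta*u_old + (1 - theta)*u_new, solved for u_new."""
    return u_old * (1. - dt * lam * theta) / (1. + dt * lam * (1. - theta))


# Placeholder laser parameters.
P, eta, r = 200., 0.4, 0.05

# Riemann sum over a square wide enough (+/- 5 r) for the Gaussian tails to vanish.
s = onp.linspace(-5 * r, 5 * r, 1001)
h = s[1] - s[0]
X, Y = onp.meshgrid(s, s)
absorbed = gaussian_flux(X, Y, P, eta, r).sum() * h * h
print(f"absorbed = {absorbed:.6f}, eta*P = {eta * P:.6f}")

for theta in (1., 0.5, 0.):
    print(theta, theta_step(1., lam=2., dt=0.1, theta=theta))
```

This weighting is the reverse of the common textbook convention, in which theta multiplies the new time level. It is worth keeping in mind when comparing against other theta-scheme codes.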
/src/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file produces the figures in the manuscript.
3 | It also contains some post-processing functions.
4 | '''
5 | import jax.numpy as np
6 | import jax
7 | import numpy as onp
8 | import orix
9 | import meshio
10 | import pickle
11 | import time
12 | import os
13 | import glob
14 | import matplotlib.pyplot as plt
15 | from orix import plot, sampling
16 | from orix.crystal_map import Phase
17 | from orix.quaternion import Orientation, symmetry
18 | from orix.vector import Vector3d
19 | from src.arguments import args
20 | from sklearn.decomposition import PCA
21 | from scipy.spatial.transform import Rotation as R
22 | from src.fit_ellipsoid import EllipsoidTool
23 |
24 |
25 | # Latex style plot
26 | plt.rcParams.update({
27 | "text.latex.preamble": r"\usepackage{amsmath}",
28 | "text.usetex": True,
29 | "font.family": "sans-serif",
30 | "font.sans-serif": ["Helvetica"]})
31 |
32 |
33 | def unpack_state(state):
34 | T = state[..., 0:1]
35 | zeta = state[..., 1:2]
36 | eta = state[..., 2:]
37 | return T, zeta, eta
38 |
39 |
40 | def get_unique_ori_colors():
41 | onp.random.seed(1)
42 |
43 | if args.case == 'fd_solidification':
44 |
45 | # axes = onp.array([[1., 0., 0.],
46 | # [1., 1., 0.],
47 | # [1., 1., 1.],
48 | # [1., 1., 0.],
49 | # [1., 0., 0.],
50 | # [1., -1., 0.]])
51 | # angles = onp.array([0.,
52 | # onp.pi/8,
53 | # onp.pi/4,
54 | # onp.pi/4,
55 | # onp.pi/4,
56 | # onp.pi/2 - onp.arccos(onp.sqrt(2)/onp.sqrt(3))])
57 |
58 | axes = onp.array([[1., 0., 0.],
59 | [1., 0., 0.],
60 | [1., -1., 0.]])
61 | angles = onp.array([0.,
62 | onp.pi/4,
63 | onp.pi/2 - onp.arccos(onp.sqrt(2)/onp.sqrt(3))])
64 |
65 | args.num_oris = len(axes)
66 | ori2 = Orientation.from_axes_angles(axes, angles)
67 | else:
68 | ori2 = Orientation.random(args.num_oris)
69 |
70 | vx = Vector3d((1, 0, 0))
71 | vy = Vector3d((0, 1, 0))
72 | vz = Vector3d((0, 0, 1))
73 | ipfkey_x = plot.IPFColorKeyTSL(symmetry.Oh, vx)
74 | rgb_x = ipfkey_x.orientation2color(ori2)
75 | ipfkey_y = plot.IPFColorKeyTSL(symmetry.Oh, vy)
76 | rgb_y = ipfkey_y.orientation2color(ori2)
77 | ipfkey_z = plot.IPFColorKeyTSL(symmetry.Oh, vz)
78 | rgb_z = ipfkey_z.orientation2color(ori2)
79 | rgb = onp.stack((rgb_x, rgb_y, rgb_z))
80 |
81 | onp.save(f"data/numpy/quat_{args.num_oris:03d}.npy", ori2.data)
82 | dx = onp.array([1., 0., 0.])
83 | dy = onp.array([0., 1., 0.])
84 | dz = onp.array([0., 0., 1.])
85 | scipy_quat = onp.concatenate((ori2.data[:, 1:], ori2.data[:, :1]), axis=1)
86 | r = R.from_quat(scipy_quat)
87 | grain_directions = onp.stack((r.apply(dx), r.apply(dy), r.apply(dz)))
88 |
89 | save_ipf = False
90 | if save_ipf:
91 | # Plot IPF for those orientations
92 | new_params = {
93 | "figure.facecolor": "w",
94 | "figure.figsize": (6, 3),
95 | "lines.markersize": 10,
96 | "font.size": 20,
97 | "axes.grid": True,
98 | }
99 | plt.rcParams.update(new_params)
100 | ori2.symmetry = symmetry.Oh
101 | ori2.scatter("ipf", c=rgb_x, direction=ipfkey_x.direction)
102 | # plt.savefig(f'data/pdf/ipf_x.pdf', bbox_inches='tight')
103 | ori2.scatter("ipf", c=rgb_y, direction=ipfkey_y.direction)
104 | # plt.savefig(f'data/pdf/ipf_y.pdf', bbox_inches='tight')
105 | ori2.scatter("ipf", c=rgb_z, direction=ipfkey_z.direction)
106 | # plt.savefig(f'data/pdf/ipf_z.pdf', bbox_inches='tight')
107 |
108 | return rgb, grain_directions
109 |
110 |
111 | def ipf_logo():
112 | new_params = {
113 | "figure.facecolor": "w",
114 | "figure.figsize": (6, 3),
115 | "lines.markersize": 10,
116 | "font.size": 25,
117 | "axes.grid": True,
118 | }
119 | plt.rcParams.update(new_params)
120 | plot.IPFColorKeyTSL(symmetry.Oh).plot()
121 | plt.savefig(f'data/pdf/ipf_legend.pdf', bbox_inches='tight')
122 |
123 |
124 | def generate_demo_graph():
125 | '''
126 | Produce the grain graph in Fig. 1 in the manuscript
127 | '''
128 | args.num_grains = 10
129 | args.domain_length = 1.
130 | args.domain_width = 1.
131 | args.domain_height = 1.
132 | # os.system(f'neper -T -n {args.num_grains} -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
133 | # -o data/neper/graph/domain -format tess,obj')
134 |
135 | os.system(f'neper -T -n {args.num_grains} -periodic 1 -domain "cube({args.domain_length},{args.domain_width},{args.domain_height})" \
136 | -o data/neper/graph/domain -format tess,obj')
137 |
138 | os.system(f'neper -T -loadtess data/neper/graph/domain.tess -statcell x,y,z,vol,facelist -statface x,y,z,area')
139 | mesh = obj_to_vtu(domain_name='graph')
140 | num = len(mesh.cells_dict['polyhedron'])
141 | mesh.cell_data['color'] = [onp.hstack((onp.random.uniform(low=0., high=1., size=(num, 3)), onp.ones((num, 1))))]
142 | mesh.cell_data['id'] = [onp.arange(num)]
143 | mesh.write(f'data/vtk/graph/demo.vtu')
144 |
145 | # poly, _ = polycrystal_gn(domain_name='graph')
146 | # print(poly.edges)
147 |
148 |
149 | def make_video():
150 | # The option -pix_fmt yuv420p ensures the video can be previewed on macOS
151 | # https://apple.stackexchange.com/questions/166553/why-wont-video-from-ffmpeg-show-in-quicktime-imovie-or-quick-preview
152 | # The option -vf "crop=trunc(iw/2)*2:trunc(ih/2)*2" crops the frame to an even width and height, avoiding the libx264 "not divisible by 2" error
153 | # https://stackoverflow.com/questions/20847674/ffmpeg-libx264-height-not-divisible-by-2
154 | # -y means always overwrite
155 | os.system('ffmpeg -y -framerate 10 -i data/png/tmp/u.%04d.png -pix_fmt yuv420p -vf "crop=trunc(iw/2)*2:trunc(ih/2)*2" data/mp4/test.mp4')
156 |
157 |
158 | def obj_to_vtu(domain_name):
159 |     filepath = f'data/neper/{domain_name}/domain.obj'
160 |     with open(filepath, 'r') as file:  # the file is closed after reading
161 |         lines = file.readlines()
162 | points = []
163 | cells_inds = []
164 |
165 | for i, line in enumerate(lines):
166 | l = line.split()
167 | if l[0] == 'v':
168 | points.append([float(l[1]), float(l[2]), float(l[3])])
169 | if l[0] == 'g':
170 | cells_inds.append([])
171 | if l[0] == 'f':
172 | cells_inds[-1].append([int(pt_ind) - 1 for pt_ind in l[1:]])
173 |
174 | cells = [('polyhedron', cells_inds)]
175 | mesh = meshio.Mesh(points, cells)
176 | return mesh
177 |
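# Illustrative sketch, not part of the original pipeline: the v/g/f parsing that
# obj_to_vtu performs, written as a pure function on OBJ text so that it can be
# checked without files or meshio. `parse_obj_polyhedra` is a hypothetical name.
# Blank lines and '#' comments are skipped. OBJ's 1-based vertex indices are
# converted to 0-based, as in obj_to_vtu. Faces are assumed to be plain integer
# indices, which is what obj_to_vtu also assumes.


def parse_obj_polyhedra(text):
    """Return (points, polyhedra), where polyhedra[k] is the list of faces of group k."""
    points, polyhedra = [], []
    for line in text.splitlines():
        tokens = line.split()
        if not tokens or tokens[0].startswith('#'):
            continue
        if tokens[0] == 'v':
            points.append([float(t) for t in tokens[1:4]])
        elif tokens[0] == 'g':
            polyhedra.append([])  # each 'g' line starts a new polyhedron (grain)
        elif tokens[0] == 'f':
            polyhedra[-1].append([int(t) - 1 for t in tokens[1:]])
    return points, polyhedra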
178 |
179 | def walltime(func):
180 | def wrapper(*list_args, **keyword_args):
181 | start_time = time.time()
182 | return_values = func(*list_args, **keyword_args)
183 | end_time = time.time()
184 | time_elapsed = end_time - start_time
185 |         platform = jax.default_backend()  # e.g. 'cpu', 'gpu' or 'tpu'
186 | print(f"Time elapsed {time_elapsed} on platform {platform}")
187 | with open(f'data/txt/walltime_{platform}_{args.case}_{args.layer:03d}.txt', 'w') as f:
188 | f.write(f'{start_time}, {end_time}, {time_elapsed}\n')
189 | return return_values
190 | return wrapper
191 |
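# Illustrative sketch, not part of the original pipeline: a variant of walltime
# (hypothetical name `timed`). It keeps the wrapped function's name and docstring
# via functools.wraps. It times with time.perf_counter, a monotonic clock meant
# for measuring intervals. It stores the elapsed seconds on the wrapper instead of
# writing data/txt files; that storage choice is an assumption.
# JAX dispatches work asynchronously, so when timing a function that returns JAX
# arrays, call .block_until_ready() on them inside the timed region if the
# function does not already force the computation.
import functools
import time


def timed(func):
    @functools.wraps(func)
    def wrapper(*args_, **kwargs):
        start = time.perf_counter()
        result = func(*args_, **kwargs)
        wrapper.last_elapsed = time.perf_counter() - start  # seconds
        return result
    wrapper.last_elapsed = None
    return wrapper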
192 |
193 | def read_path(path):
194 | path_info = onp.loadtxt(path)
195 | traveled_time = path_info[:, 0]
196 | x_corners = path_info[:, 1]
197 | y_corners = path_info[:, 2]
198 | power_control = path_info[:-1, 3]
199 | ts, xs, ys, ps = [], [], [], []
200 | for i in range(len(traveled_time) - 1):
201 | ts_seg = onp.arange(traveled_time[i], traveled_time[i + 1], args.dt)
202 | xs_seg = onp.linspace(x_corners[i], x_corners[i + 1], len(ts_seg))
203 | ys_seg = onp.linspace(y_corners[i], y_corners[i + 1], len(ts_seg))
204 |         ps_seg = onp.full(len(ts_seg), power_control[i])  # laser power is constant along a segment
205 | ts.append(ts_seg)
206 | xs.append(xs_seg)
207 | ys.append(ys_seg)
208 | ps.append(ps_seg)
209 |
210 | ts, xs, ys, ps = onp.hstack(ts), onp.hstack(xs), onp.hstack(ys), onp.hstack(ps)
211 | print(f"Total number of time steps = {len(ts)}")
212 | return ts, xs, ys, ps
213 |
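# Illustrative sketch, not part of the original pipeline: the discretization in
# read_path as a pure function of the loaded path array. `discretize_path` is a
# hypothetical name, and `dt` stands in for args.dt.
# Row i holds (time, x, y, power). Segment i runs from row i to row i + 1 at the
# power of row i.
# onp.arange excludes its end point, but onp.linspace includes it. As a result,
# the last sample of each segment lies at the segment's end corner one time step
# before the corner time, exactly as in read_path.
import numpy as onp


def discretize_path(path_info, dt):
    ts, xs, ys, ps = [], [], [], []
    for i in range(len(path_info) - 1):
        t0, x0, y0, power = path_info[i, :4]
        t1, x1, y1 = path_info[i + 1, :3]
        ts_seg = onp.arange(t0, t1, dt)
        ts.append(ts_seg)
        xs.append(onp.linspace(x0, x1, len(ts_seg)))
        ys.append(onp.linspace(y0, y1, len(ts_seg)))
        ps.append(onp.full(len(ts_seg), power))  # constant power per segment
    return onp.hstack(ts), onp.hstack(xs), onp.hstack(ys), onp.hstack(ps)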
214 |
215 | def fd_helper(num_fd_nodes):
216 | domain_vol = args.domain_length*args.domain_width*args.domain_height
217 | avg_cell_vol = domain_vol / num_fd_nodes
218 | avg_cell_len = avg_cell_vol**(1/3)
219 | avg_grain_vol = domain_vol / args.num_grains
220 | print(f"avg fd cell_vol = {avg_cell_vol}")
221 | print(f"avg grain vol = {avg_grain_vol}")
222 | return avg_cell_vol, avg_cell_len
223 |
224 |
225 | def get_edges_and_face_in_order(edges, face_areas, num_graph_nodes):
226 | edges_in_order = [[] for _ in range(num_graph_nodes)]
227 | face_areas_in_order = [[] for _ in range(num_graph_nodes)]
228 |
229 | assert len(edges) == len(face_areas)
230 |
231 | print(f"Re-ordering edges and face_areas...")
232 | for i, edge in enumerate(edges):
233 | node1 = edge[0]
234 | node2 = edge[1]
235 | edges_in_order[node1].append(node2)
236 | edges_in_order[node2].append(node1)
237 | face_areas_in_order[node1].append(face_areas[i])
238 | face_areas_in_order[node2].append(face_areas[i])
239 |
240 | return edges_in_order, face_areas_in_order
241 |
242 |
243 | def get_edges_in_order(edges, num_graph_nodes):
244 | edges_in_order = [[] for _ in range(num_graph_nodes)]
245 | print(f"Re-ordering edges...")
246 | for i, edge in enumerate(edges):
247 | node1 = edge[0]
248 | node2 = edge[1]
249 | edges_in_order[node1].append(node2)
250 | edges_in_order[node2].append(node1)
251 | return edges_in_order
252 |
253 |
254 | def BFS(edges_in_order, melt, cell_ori_inds, combined=True):
255 | num_graph_nodes = len(melt)
256 | print(f"BFS...")
257 | visited = onp.zeros(num_graph_nodes)
258 | grains = [[] for _ in range(args.num_oris)]
259 | for i in range(len(visited)):
260 | if visited[i] == 0 and melt[i]:
261 | oris_index = cell_ori_inds[i]
262 | grains[oris_index].append([])
263 |             queue, head = [i], 0  # a read pointer avoids the O(n) cost of list.pop(0)
264 |             visited[i] = 1
265 |             while head < len(queue):
266 |                 s, head = queue[head], head + 1
267 | grains[oris_index][-1].append(s)
268 | connected_nodes = edges_in_order[s]
269 | for cn in connected_nodes:
270 | if visited[cn] == 0 and cell_ori_inds[cn] == oris_index and melt[cn]:
271 | queue.append(cn)
272 | visited[cn] = 1
273 |
274 | grains_combined = []
275 | for i in range(len(grains)):
276 | grains_oris = grains[i]
277 | for j in range(len(grains_oris)):
278 | grains_combined.append(grains_oris[j])
279 |
280 | if combined:
281 | return grains_combined
282 | else:
283 | return grains
284 |
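# Illustrative sketch, not part of the original pipeline: the grain
# identification done by get_edges_in_order + BFS, condensed into one function.
# `same_orientation_components` is a hypothetical name.
# A grain is a connected set of melted cells that share one orientation index.
# collections.deque gives O(1) pops from the front of the queue, whereas
# list.pop(0) is O(n).
# Unlike BFS(..., combined=True), grains are returned in discovery order rather
# than grouped by orientation.
from collections import deque


def same_orientation_components(edges, melt, ori_inds):
    num_nodes = len(melt)
    neighbors = [[] for _ in range(num_nodes)]
    for a, b in edges:
        neighbors[a].append(b)
        neighbors[b].append(a)

    visited = [False] * num_nodes
    grains = []
    for start in range(num_nodes):
        if visited[start] or not melt[start]:
            continue
        visited[start] = True
        queue, grain = deque([start]), []
        while queue:
            node = queue.popleft()
            grain.append(node)
            for nb in neighbors[node]:
                if not visited[nb] and melt[nb] and ori_inds[nb] == ori_inds[node]:
                    visited[nb] = True
                    queue.append(nb)
        grains.append(grain)
    return grains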
285 |
286 | def get_aspect_ratio_inputs_single_track(grains_combined, volumes, centroids):
287 | grain_vols = []
288 | grain_centroids = []
289 | for i in range(len(grains_combined)):
290 | grain = grains_combined[i]
291 | grain_vol = onp.array([volumes[g] for g in grain])
292 | grain_centroid = onp.take(centroids, grain, axis=0)
293 | assert grain_centroid.shape == (len(grain_vol), 3)
294 | grain_vols.append(grain_vol)
295 | grain_centroids.append(grain_centroid)
296 |
297 | return grain_vols, grain_centroids
298 |
299 |
300 | def compute_aspect_ratios_and_vols(grain_vols, grain_centroids):
301 | pca = PCA(n_components=3)
302 | print(f"Call compute_aspect_ratios_and_vols")
303 | grain_sum_vols = []
304 | grain_sum_aspect_ratios = []
305 |
306 | for i in range(len(grain_vols)):
307 | grain_vol = grain_vols[i]
308 | sum_vol = onp.sum(grain_vol)
309 |
310 |         if len(grain_vol) < 3:  # PCA with n_components=3 needs at least 3 samples
311 |             grain_sum_aspect_ratios.append(1.)
312 |         else:
313 |             positions = grain_centroids[i]  # cell centroids of this grain, shape (n, 3)
314 |             weighted_positions = positions * grain_vol[:, None]  # each centroid scaled by its cell volume
315 |             # No explicit centering is needed: sklearn's PCA subtracts the mean before fitting
316 |             pca.fit(weighted_positions)
317 |             # Spread along each principal axis = sqrt of its explained variance (sorted, largest first)
318 |             ev = pca.explained_variance_
319 |             lengths = onp.sqrt(ev)
320 |             aspect_ratio = 2*lengths[0]/(lengths[1] + lengths[2])
321 | grain_sum_aspect_ratios.append(aspect_ratio)
322 |
323 | grain_sum_vols.append(sum_vol)
324 |
325 | return [grain_sum_vols, grain_sum_aspect_ratios]
326 |
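# Illustrative sketch, not part of the original pipeline: the aspect ratio from
# compute_aspect_ratios_and_vols, computed without scikit-learn.
# `pca_aspect_ratio` is a hypothetical name.
# sklearn's PCA.explained_variance_ are the eigenvalues of the sample covariance
# matrix (denominator n - 1), so the same quantity is
#   AR = 2 * sqrt(l1) / (sqrt(l2) + sqrt(l3)),  with l1 >= l2 >= l3
# where l_k are those covariance eigenvalues.
# `points` plays the role of the (volume-scaled) centroids passed to pca.fit.
import numpy as onp


def pca_aspect_ratio(points):
    points = onp.asarray(points, dtype=float)
    if len(points) < 3:
        return 1.  # mirrors the original fallback for grains with fewer than 3 cells
    cov = onp.cov(points, rowvar=False)  # 3x3 sample covariance (ddof=1)
    eigvals = onp.linalg.eigvalsh(cov)[::-1]  # eigvalsh returns ascending order
    lengths = onp.sqrt(onp.clip(eigvals, 0., None))
    return 2 * lengths[0] / (lengths[1] + lengths[2])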
327 |
328 | def compute_stats_multi_layer():
329 | args.case = 'gn_multi_layer_scan_1'
330 | args.num_total_layers = 10
331 |
332 | grain_oris_inds = []
333 | melt = []
334 | for i in range(args.num_total_layers):
335 | grain_ori_inds_bottom = onp.load(f"data/numpy/{args.case}/sols/layer_{i + 1:03d}/cell_ori_inds_bottom.npy")
336 | melt_final_bottom = onp.load(f'data/numpy/{args.case}/sols/layer_{i + 1:03d}/melt_final_bottom.npy')
337 | assert grain_ori_inds_bottom.shape == melt_final_bottom.shape
338 | grain_oris_inds.append(grain_ori_inds_bottom)
339 | melt.append(melt_final_bottom)
340 |
341 | melt = onp.hstack(melt)
342 | grain_oris_inds = onp.hstack(grain_oris_inds)
343 |
344 | edges = onp.load(f"data/numpy/{args.case}/info/edges.npy")
345 | volumes = onp.load(f"data/numpy/{args.case}/info/vols.npy")
346 | centroids = onp.load(f"data/numpy/{args.case}/info/centroids.npy")
347 |
348 | assert melt.shape == volumes.shape
349 |
350 |     grains_combined = BFS(get_edges_in_order(edges, len(centroids)), melt, grain_oris_inds)  # BFS needs adjacency lists, not raw edge pairs
351 |
352 | grain_sum_vols = []
353 | for i in range(len(grains_combined)):
354 | grain = grains_combined[i]
355 | grain_vol = onp.sum(onp.array([volumes[g] for g in grain]))
356 | grain_sum_vols.append(grain_vol)
357 |
358 | grain_sum_vols = onp.array(grain_sum_vols)
359 |
360 | val = 0.
361 | inds = onp.argwhere(grain_sum_vols > val)[:, 0]
362 | grain_sum_vols = grain_sum_vols[inds]*1e9
363 | # grain_sum_aspect_ratios = grain_sum_aspect_ratios[inds]
364 |
365 | onp.save(f"data/numpy/{args.case}/post-processing/grain_sum_vols.npy", grain_sum_vols)
366 |
367 | return grain_sum_vols
368 |
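# Illustrative sketch, not part of the original pipeline: per-grain volumes and
# the unit conversion used in compute_stats_multi_layer. `grain_volumes_um3` is a
# hypothetical name.
# Cell volumes are in mm^3, because domain lengths are in mm (they are plotted in
# um after multiplying by 1e3). Since 1 mm^3 = (1e3 um)^3 = 1e9 um^3, the
# conversion factor is 1e9.
# Grains whose volume is not above `min_vol` (in mm^3) are dropped, as with `val`
# in the original.
import numpy as onp


def grain_volumes_um3(grains, volumes, min_vol=0.):
    volumes = onp.asarray(volumes)
    sums = onp.array([volumes[onp.asarray(g, dtype=int)].sum() for g in grains])
    return sums[sums > min_vol] * 1e9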
369 |
370 | def produce_figures_multi_layer():
371 | grain_sum_vols_scan1 = onp.load(f"data/numpy/gn_multi_layer_scan_1/post-processing/grain_sum_vols.npy")
372 | grain_sum_vols_scan2 = onp.load(f"data/numpy/gn_multi_layer_scan_2/post-processing/grain_sum_vols.npy")
373 |
374 | colors = ['blue', 'red']
375 | labels = ['Scan 1', 'Scan 2']
376 |
377 | print(f"total vol of scan 1 = {onp.sum(grain_sum_vols_scan1)}, mean = {onp.mean(grain_sum_vols_scan1)}")
378 | print(f"total vol of scan 2 = {onp.sum(grain_sum_vols_scan2)}, mean = {onp.mean(grain_sum_vols_scan2)}")
379 | print(f"total number of grains for scan 1 {len(grain_sum_vols_scan1)}")
380 | print(f"total number of grains for scan 2 {len(grain_sum_vols_scan2)}")
381 |
382 | log_grain_sum_vols_scan1 = onp.log10(grain_sum_vols_scan1)
383 | log_grain_sum_vols_scan2 = onp.log10(grain_sum_vols_scan2)
384 |
385 |     # 25 logarithmically spaced bin edges from 1e2 to 1e7 um^3
386 |     logbins = onp.logspace(2, 7, 25)
387 |
388 | fig = plt.figure(figsize=(8, 6))
389 | plt.hist([grain_sum_vols_scan1, grain_sum_vols_scan2], bins=logbins, color=colors, label=labels)
390 |
391 | plt.xscale('log')
392 | plt.xlabel(r'Grain volume [$\mu$m$^3$]', fontsize=20)
393 | plt.ylabel(r'Count', fontsize=20)
394 | plt.tick_params(labelsize=18)
395 | plt.legend(fontsize=20, frameon=False)
396 | # plt.savefig(f'data/pdf/multi_layer_vol.pdf', bbox_inches='tight')
397 |
398 |
399 | def compute_stats_single_layer(neper_mesh):
400 | edges = onp.load(f"data/numpy/fd_{neper_mesh}/info/edges.npy")
401 | volumes = onp.load(f"data/numpy/fd_{neper_mesh}/info/vols.npy")
402 | centroids = onp.load(f"data/numpy/fd_{neper_mesh}/info/centroids.npy")
403 | cell_grain_inds = onp.load(f"data/numpy/fd_{neper_mesh}/info/cell_grain_inds.npy")
404 | num_fd_nodes = len(volumes)
405 | avg_cell_vol, avg_cell_len = fd_helper(num_fd_nodes)
406 |
407 | def compute_stats_helper():
408 | if case.startswith('fd'):
409 | cell_ori_inds = onp.load(f"data/numpy/{case}/sols/cell_ori_inds_{step:03d}.npy")
410 | melt = onp.load(f"data/numpy/{case}/sols/melt_{step:03d}.npy")
411 | T = onp.load(f"data/numpy/{case}/sols/T_{step:03d}.npy")
412 | zeta = onp.load(f"data/numpy/{case}/sols/zeta_{step:03d}.npy")
413 | else:
414 | grain_oris_inds = onp.load(f"data/numpy/{case}/sols/cell_ori_inds_{step:03d}.npy")
415 | grain_melt = onp.load(f"data/numpy/{case}/sols/melt_{step:03d}.npy")
416 | grain_T = onp.load(f"data/numpy/{case}/sols/T_{step:03d}.npy")
417 |             grain_zeta = onp.load(f"data/numpy/{case}/sols/zeta_{step:03d}.npy")
418 |             cell_ori_inds = onp.take(grain_oris_inds, cell_grain_inds, axis=0)
419 |             melt = onp.take(grain_melt, cell_grain_inds, axis=0)
420 |             T = onp.take(grain_T, cell_grain_inds, axis=0)
421 |             zeta = onp.take(grain_zeta, cell_grain_inds, axis=0)
422 |
423 |         # Arguably more reasonable (NOT used for the results in the paper): keep only melted cells with zeta > 0.5, i.e. currently solid
424 | # melt = onp.logical_and(melt, zeta > 0.5)
425 |
426 | return T, zeta, melt, cell_ori_inds
427 |
428 | def process_T():
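429 |         sampling_depth = 5  # cell layers sampled below the top surface
430 |         sampling_width = 5  # cells sampled on each side of y = domain_width/2
431 |         avg_length = 8  # consecutive x positions averaged into one data point
432 |         sampling_section = sampling_depth*sampling_width*2  # cells per sampled y-z cross-section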
433 |
434 | bias = avg_cell_len/2. if neper_mesh == 'npj_review_voronoi_coarse' else 0.
435 | inds = onp.argwhere((centroids[:, 2] > args.domain_height - sampling_depth*avg_cell_len) &
436 | (centroids[:, 2] < args.domain_height) &
437 | (centroids[:, 1] > args.domain_width/2 + bias - sampling_width*avg_cell_len) &
438 | (centroids[:, 1] < args.domain_width/2 + bias + sampling_width*avg_cell_len))[:, 0]
439 |
440 | T_sampled = T[inds].reshape(sampling_section, -1)
441 | T_sampled_len = T_sampled.shape[1]
442 | T_sampled = T_sampled[:, :T_sampled_len//avg_length*avg_length].T
443 | T_sampled = T_sampled.reshape(-1, sampling_section*avg_length)
444 | T_sampled = onp.mean(T_sampled, axis=1)
445 |
446 | return T_sampled
447 |
448 | def process_zeta():
449 | inds_melt_pool = onp.argwhere(zeta < 0.5)[:, 0]
450 |
451 | if len(inds_melt_pool) == 0:
452 | return onp.zeros(4)
453 |
454 | centroids_melt_pool = onp.take(centroids, inds_melt_pool, axis=0)
455 | length_melt_pool = onp.max(centroids_melt_pool[:, 0]) - onp.min(centroids_melt_pool[:, 0])
456 | width_melt_pool = onp.max(centroids_melt_pool[:, 1]) - onp.min(centroids_melt_pool[:, 1])
457 | height_melt_pool = onp.max(centroids_melt_pool[:, 2]) - onp.min(centroids_melt_pool[:, 2])
458 | volume_melt_pool = avg_cell_vol*len(inds_melt_pool)
459 | characteristics = onp.array([length_melt_pool, width_melt_pool, height_melt_pool, volume_melt_pool])
460 |
461 | return characteristics
462 |
463 | def process_eta():
464 | grains_combined = BFS(edges_in_order, melt, cell_ori_inds)
465 | grain_vols, grain_centroids = get_aspect_ratio_inputs_single_track(grains_combined, volumes, centroids)
466 | eta_results = compute_aspect_ratios_and_vols(grain_vols, grain_centroids)
467 | return eta_results
468 |
469 | edges_in_order = get_edges_in_order(edges, len(centroids))
470 |
471 |
472 | # cases = ['gn', 'fd']
473 | # steps = [20]
474 | cases = [f'gn_{neper_mesh}', f'fd_{neper_mesh}']
475 |
476 | for case in cases:
477 | numpy_folder = f"data/numpy/{case}/post-processing"
478 | if not os.path.exists(numpy_folder):
479 | os.makedirs(numpy_folder)
480 |
481 | T_collect = []
482 | zeta_collect = []
483 | eta_collect = []
484 | for step in range(31):
485 | print(f"step = {step}, case = {case}")
486 | T, zeta, melt, cell_ori_inds = compute_stats_helper()
487 | T_results = process_T()
488 | zeta_results = process_zeta()
489 | eta_results = process_eta()
490 | T_collect.append(T_results)
491 | zeta_collect.append(zeta_results)
492 | eta_collect.append(eta_results)
493 |
494 | onp.save(f"data/numpy/{case}/post-processing/T_collect.npy", onp.array(T_collect))
495 | onp.save(f"data/numpy/{case}/post-processing/zeta_collect.npy", onp.array(zeta_collect))
496 | onp.save(f"data/numpy/{case}/post-processing/eta_collect.npy", onp.array(eta_collect, dtype=object))
497 |
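# Illustrative sketch, not part of the original pipeline: the reshaping in
# process_T (inside compute_stats_single_layer) as a standalone function.
# `block_average_along_x` is a hypothetical name.
# `T_section_x` has one row per sampled cross-section cell and one column per x
# position, as assumed by T[inds].reshape(sampling_section, -1).
# Columns are grouped into blocks of `avg_length` consecutive x positions.
# Trailing columns that do not fill a block are dropped, and each block is
# averaged over all of its entries.
import numpy as onp


def block_average_along_x(T_section_x, avg_length):
    T_section_x = onp.asarray(T_section_x, dtype=float)
    num_section, num_x = T_section_x.shape
    usable = num_x // avg_length * avg_length
    blocks = T_section_x[:, :usable].T.reshape(-1, num_section * avg_length)
    return blocks.mean(axis=1)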
498 |
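# Illustrative sketch, not part of the original pipeline: the melt-pool measures
# of process_zeta as a standalone function. `melt_pool_characteristics` is a
# hypothetical name.
# Cells with zeta below `threshold` (0.5 in the original) are treated as liquid.
# Length, width and height are the extents of their centroids along x, y and z.
# The volume is the liquid-cell count times an average cell volume, as in the
# original.
import numpy as onp


def melt_pool_characteristics(zeta, centroids, avg_cell_vol, threshold=0.5):
    liquid = onp.asarray(zeta) < threshold
    if not onp.any(liquid):
        return onp.zeros(4)
    pts = onp.asarray(centroids)[liquid]
    extents = pts.max(axis=0) - pts.min(axis=0)  # (length, width, height)
    return onp.append(extents, avg_cell_vol * liquid.sum())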
499 | def produce_figures_single_layer(neper_mesh, additional_info=None):
500 | pdf_folder = f"data/pdf/{neper_mesh}"
501 | if not os.path.exists(pdf_folder):
502 | os.makedirs(pdf_folder)
503 |
504 | ts, xs, ys, ps = read_path(f'data/txt/single_track.txt')
505 | ts = ts[::args.write_sol_interval]*1e6
506 |
507 | volumes = onp.load(f"data/numpy/fd_{neper_mesh}/info/vols.npy")
508 | num_fd_nodes = len(volumes)
509 | avg_cell_vol, avg_cell_len = fd_helper(num_fd_nodes)
510 |
511 | def T_plot():
512 | T_results_fd = onp.load(f"data/numpy/fd_{neper_mesh}/post-processing/T_collect.npy")
513 | T_results_gn = onp.load(f"data/numpy/gn_{neper_mesh}/post-processing/T_collect.npy")
514 |
515 | step = 12
516 | T_select_fd = T_results_fd[step]
517 | T_select_gn = T_results_gn[step]
518 | x = onp.linspace(0., args.domain_length, len(T_select_fd))*1e3
519 |
520 | fig = plt.figure(figsize=(8, 6))
521 | plt.plot(x, T_select_fd, label='DNS', color='blue', marker='o', markersize=8, linestyle="-", linewidth=2)
522 | plt.plot(x, T_select_gn, label='PEGN', color='red', marker='o', markersize=8, linestyle="-", linewidth=2)
523 | plt.xlabel(r'x-axis [$\mu$m]', fontsize=20)
524 | plt.ylabel(r'Temperature [K]', fontsize=20)
525 | plt.tick_params(labelsize=18)
526 | plt.legend(fontsize=20, frameon=False)
527 | plt.savefig(f'data/pdf/{neper_mesh}/T_scanning_line.pdf', bbox_inches='tight')
528 |
529 | ind_T = T_results_fd.shape[1]//2
530 | fig = plt.figure(figsize=(8, 6))
531 | plt.plot(ts, T_results_fd[:, ind_T], label='DNS', color='blue', marker='o', markersize=8, linestyle="-", linewidth=2)
532 | plt.plot(ts, T_results_gn[:, ind_T], label='PEGN', color='red', marker='o', markersize=8, linestyle="-", linewidth=2)
533 | plt.xlabel(r'Time [$\mu$s]', fontsize=20)
534 | plt.ylabel(r'Temperature [K]', fontsize=20)
535 | plt.tick_params(labelsize=18)
536 | plt.legend(fontsize=20, frameon=False)
537 | plt.savefig(f'data/pdf/{neper_mesh}/T_center.pdf', bbox_inches='tight')
538 |
539 |
540 | def zeta_plot():
541 | zeta_results_fd = onp.load(f"data/numpy/fd_{neper_mesh}/post-processing/zeta_collect.npy")
542 | zeta_results_gn = onp.load(f"data/numpy/gn_{neper_mesh}/post-processing/zeta_collect.npy")
543 | labels = ['Melt pool length [mm]', 'Melt pool width [mm]', 'Melt pool height [mm]', 'Melt pool volume [mm$^3$]']
544 | names = ['melt_pool_length', 'melt_pool_width', 'melt_pool_height', 'melt_pool_volume']
545 | for i in range(4):
546 | fig = plt.figure(figsize=(8, 6))
547 | plt.plot(ts, zeta_results_fd[:, i], label='DNS', color='blue', marker='o', markersize=8, linestyle="-", linewidth=2)
548 | plt.plot(ts, zeta_results_gn[:, i], label='PEGN', color='red', marker='o', markersize=8, linestyle="-", linewidth=2)
549 | plt.xlabel(r'Time [$\mu$s]', fontsize=20)
550 | plt.ylabel(labels[i], fontsize=20)
551 | plt.tick_params(labelsize=18)
552 | plt.legend(fontsize=20, frameon=False)
553 | plt.savefig(f'data/pdf/{neper_mesh}/{names[i]}.pdf', bbox_inches='tight')
554 |
555 |
556 | def eta_plot(neper_mesh):
557 | eta_results_fd = onp.load(f"data/numpy/fd_{neper_mesh}/post-processing/eta_collect.npy", allow_pickle=True)
558 | eta_results_gn = onp.load(f"data/numpy/gn_{neper_mesh}/post-processing/eta_collect.npy", allow_pickle=True)
559 |
560 |         # val: minimum grain volume [mm^3] kept in the statistics; smaller grains are discarded.
561 |         # 1e-7 was used before anisotropy was considered; 1.6e-7 is used afterwards.
562 | if neper_mesh == 'npj_review_voronoi_fine':
563 | val = 0.8*1e-7
564 | elif neper_mesh == 'npj_review_voronoi_coarse':
565 | val = 3.2*1e-7
566 | else:
567 | val = 1.6*1e-7
568 |
569 | if additional_info == 'npj_review_centroidal_big_grain':
570 | neper_mesh = additional_info
571 | val = 1e-5
572 |
573 | def eta_helper(eta_results):
574 | vols_filtered = []
575 | aspect_ratios_filtered = []
576 | num_vols = []
577 | avg_vol = []
578 | for item in eta_results:
579 | grain_vols, aspect_ratios = item
580 | grain_vols = onp.array(grain_vols)
581 | inds = onp.argwhere(grain_vols > val)[:, 0]
582 | grain_vols = grain_vols[inds]*1e9
583 | aspect_ratios = onp.array(aspect_ratios)
584 | aspect_ratios = aspect_ratios[inds]
585 | num_vols.append(len(grain_vols))
586 | avg_vol.append(onp.mean(grain_vols))
587 | vols_filtered.append(grain_vols)
588 | aspect_ratios_filtered.append(aspect_ratios)
589 | return num_vols, avg_vol, vols_filtered, aspect_ratios_filtered
590 |
591 | num_vols_fd, avg_vol_fd, vols_filtered_fd, aspect_ratios_filtered_fd = eta_helper(eta_results_fd)
592 | num_vols_gn, avg_vol_gn, vols_filtered_gn, aspect_ratios_filtered_gn = eta_helper(eta_results_gn)
593 |
594 | fig = plt.figure(figsize=(8, 6))
595 | plt.plot(ts, num_vols_fd, label='DNS', color='blue', marker='o', markersize=8, linestyle="-", linewidth=2)
596 | plt.plot(ts, num_vols_gn, label='PEGN', color='red', marker='o', markersize=8, linestyle="-", linewidth=2)
597 | plt.xlabel(r'Time [$\mu$s]', fontsize=20)
598 | plt.ylabel(r'Number of grains', fontsize=20)
599 | plt.tick_params(labelsize=18)
600 | plt.legend(fontsize=20, frameon=False)
601 | plt.savefig(f'data/pdf/{neper_mesh}/num_grains.pdf', bbox_inches='tight')
602 |
603 | fig = plt.figure(figsize=(8, 6))
604 | plt.plot(ts, avg_vol_fd, label='DNS', color='blue', marker='o', markersize=8, linestyle="-", linewidth=2)
605 | plt.plot(ts, avg_vol_gn, label='PEGN', color='red', marker='o', markersize=8, linestyle="-", linewidth=2)
606 | plt.xlabel(r'Time [$\mu$s]', fontsize=20)
607 | plt.ylabel(r'Average grain volume [$\mu$m$^3$]', fontsize=20)
608 | plt.tick_params(labelsize=18)
609 | plt.legend(fontsize=20, frameon=False)
610 | plt.savefig(f'data/pdf/{neper_mesh}/grain_vol.pdf', bbox_inches='tight')
611 |
612 | step = 30
613 | assert len(aspect_ratios_filtered_fd) == 31
614 |         assert len(vols_filtered_fd) == len(vols_filtered_gn)
615 |
616 | fd_vols = vols_filtered_fd[step]
617 | gn_vols = vols_filtered_gn[step]
618 | fd_aspect_ratios = aspect_ratios_filtered_fd[step]
619 | gn_aspect_ratios = aspect_ratios_filtered_gn[step]
620 |
621 | print("\n")
622 | print(f"fd mean vol = {onp.mean(fd_vols)}")
623 | print(f"gn mean vol = {onp.mean(gn_vols)}")
624 |
625 | print("\n")
626 | print(f"fd median aspect_ratio = {onp.median(fd_aspect_ratios)}")
627 | print(f"gn median aspect_ratio = {onp.median(gn_aspect_ratios)}")
628 |
629 | colors = ['blue', 'red']
630 | labels = ['DNS', 'PEGN']
631 |
632 | fig = plt.figure(figsize=(8, 6))
633 | plt.hist([fd_vols, gn_vols], color=colors, bins=onp.linspace(0., 1e4, 6), label=labels)
634 | plt.legend(fontsize=20, frameon=False)
635 | plt.xlabel(r'Grain volume [$\mu$m$^3$]', fontsize=20)
636 | plt.ylabel(r'Count', fontsize=20)
637 | plt.tick_params(labelsize=18)
638 | plt.savefig(f'data/pdf/{neper_mesh}/vol_distribution.pdf', bbox_inches='tight')
639 |
640 | fig = plt.figure(figsize=(8, 6))
641 | plt.hist([fd_aspect_ratios, gn_aspect_ratios], color=colors, bins=onp.linspace(1, 4, 13), label=labels)
642 | plt.legend(fontsize=20, frameon=False)
643 | plt.xlabel(r'Aspect ratio', fontsize=20)
644 | plt.ylabel(r'Count', fontsize=20)
645 | plt.tick_params(labelsize=18)
646 | plt.savefig(f'data/pdf/{neper_mesh}/aspect_distribution.pdf', bbox_inches='tight')
647 |
648 |
649 | # print(num_vols_fd)
650 | # print(num_vols_gn)
651 |
652 | T_plot()
653 | # zeta_plot()
654 | # eta_plot(neper_mesh)
655 |
656 |
657 | def compute_vol_and_area(grain, volumes, centroids, face_areas_in_order, edges_in_order):
658 | vol = onp.sum(onp.take(volumes, grain))
659 | cen = onp.mean(onp.take(centroids, grain, axis=0), axis=0)
660 | hash_table = set(grain)
661 | # print(f"Total number of g = {len(grain)}")
662 | area = 0.
663 | for g in grain:
664 | count = 0
665 | for i, f_area in enumerate(face_areas_in_order[g]):
666 | if edges_in_order[g][i] not in hash_table:
667 | area += f_area
668 | else:
669 | count += 1
670 | # print(f"Found {count} neighbor")
671 | # print(f"Total number of neighbors found = {count}")
672 | return vol, area, cen
673 |
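# Illustrative sketch, not part of the original pipeline: the surface area from
# compute_vol_and_area, computed directly from the edge list.
# `grain_boundary_area` is a hypothetical name.
# face_areas[k] is the area of the face shared by the two cells of edges[k]. A
# face contributes to the grain's surface when exactly one of its two cells
# belongs to the grain.
# As in the original, only faces that appear as graph edges are counted, so
# faces on the domain boundary are not.
import numpy as onp


def grain_boundary_area(grain, edges, face_areas):
    edges = onp.asarray(edges)
    in_a = onp.isin(edges[:, 0], grain)
    in_b = onp.isin(edges[:, 1], grain)
    return float(onp.sum(onp.asarray(face_areas)[in_a != in_b]))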
674 |
675 | def grain_nodes_and_edges(grain, edges_in_order):
676 | hash_table = set(grain)
677 | count = 0
678 | for g in grain:
679 | for e in edges_in_order[g]:
680 | if e in hash_table:
681 | count += 1
682 | return len(grain), count//2
683 |
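# Illustrative sketch, not part of the original pipeline: a vectorized
# alternative to grain_nodes_and_edges. `grain_nodes_and_edges_vectorized` is a
# hypothetical name.
# An edge is internal when both of its cells belong to the grain. Counting on the
# undirected edge list sees each edge once, so no division by 2 is needed.
import numpy as onp


def grain_nodes_and_edges_vectorized(grain, edges):
    edges = onp.asarray(edges)
    internal = onp.isin(edges[:, 0], grain) & onp.isin(edges[:, 1], grain)
    return len(grain), int(internal.sum())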
684 |
685 | def npj_review_grain_growth():
686 | neper_mesh = 'npj_review_voronoi'
687 |
688 | cases = ['gn_npj_review_voronoi']
689 | # cases = [f'gn_{neper_mesh}', f'fd_{neper_mesh}']
690 |
691 | for case in cases:
692 | args.case = case
693 | args.num_oris = 20
694 | args.num_grains = 40000
695 |
696 | compute = True
697 | if compute:
698 | files_vtk = glob.glob(f"data/vtk/{args.case}/single_grain/*")
699 | for f in files_vtk:
700 | os.remove(f)
701 | unique_oris_rgb, unique_grain_directions = get_unique_ori_colors()
702 | edges = onp.load(f"data/numpy/{args.case}/info/edges.npy")
703 | volumes = onp.load(f"data/numpy/{args.case}/info/vols.npy")
704 | centroids = onp.load(f"data/numpy/{args.case}/info/centroids.npy")
705 | face_areas = onp.load(f"data/numpy/{args.case}/info/face_areas.npy")
706 |
707 | edges_in_order, face_areas_in_order = get_edges_and_face_in_order(edges, face_areas, len(centroids))
708 |
709 | grain_geo = []
710 | for step in range(15, 31, 5):
711 | print(f"step = {step}, case = {args.case}")
712 | oris_inds = onp.load(f"data/numpy/{args.case}/sols/cell_ori_inds_{step:03d}.npy")
713 | melt = onp.load(f"data/numpy/{args.case}/sols/melt_{step:03d}.npy")
714 | zeta = onp.load(f"data/numpy/{args.case}/sols/zeta_{step:03d}.npy")
715 | melt = onp.logical_and(melt, zeta > 0.5)
716 | ipf_z = onp.take(unique_oris_rgb[2], oris_inds, axis=0)
717 |
718 | grains = BFS(edges_in_order, melt, oris_inds, combined=False)
719 |                 grains_combined = [g for grains_ori in grains for g in grains_ori]  # same as BFS(..., combined=True), without a second BFS
720 |
721 |                 # Ad hoc: hand-picked graph-node index identifying the grain tracked over time
722 | if args.case.startswith('gn'):
723 | selected_grain_id = 17980
724 | else:
725 | selected_grain_id = 2876108
726 |
727 |                 # Grain graph statistics (nodes and edges per grain), reported in reply to Reviewer 1, Q5
728 | if step == 30 and args.case.startswith('gn'):
729 | nums_nodes = []
730 | nums_edges = []
731 | for grain in grains_combined:
732 | num_nodes, num_edges = grain_nodes_and_edges(grain, edges_in_order)
733 | nums_nodes.append(num_nodes)
734 | nums_edges.append(num_edges)
735 | nums_nodes = onp.array(nums_nodes)
736 | nums_edges = onp.array(nums_edges)
737 | print(f"len(nums_nodes) = {len(nums_nodes)}")
738 | print(f"max nums_nodes = {onp.max(nums_nodes)}, min nums_nodes = {onp.min(nums_nodes)}")
739 | print(f"mean nums_nodes = {onp.mean(nums_nodes)}, std nums_nodes = {onp.std(nums_nodes)}")
740 | print(f"max nums_edges = {onp.max(nums_edges)}, min nums_edges = {onp.min(nums_edges)}")
741 | print(f"mean nums_edges = {onp.mean(nums_edges)}, std nums_edges = {onp.std(nums_edges)}")
742 | print(f"onp.argmax(nums_nodes) = {onp.argmax(nums_nodes)}")
743 | print(f"onp.argmax(nums_edges) = {onp.argmax(nums_edges)}")
744 |
745 |
746 | grains_same_ori = []
747 | idx = 11
748 | # 11: pink color
749 | for i, g in enumerate(grains[idx]):
750 | vol, area, cen = compute_vol_and_area(onp.array(g), volumes, centroids, face_areas_in_order, edges_in_order)
751 | grains_same_ori += g
752 | # print(f"vol = {vol}, area = {area}")
753 | # if cen[0] > args.domain_length/2. and cen[1] < args.domain_width:
754 | # print(f"cen = {cen}, i = {i}, g = {g}")
755 | if selected_grain_id in g:
756 | single_grain_idx = g
757 | grain_geo.append([vol, area])
758 |
759 |
760 | def plot_some_grains(grain_ids, name):
761 | if args.case.startswith('gn'):
762 | mesh = obj_to_vtu(neper_mesh)
763 | cells = [('polyhedron', onp.take(mesh.cells_dict['polyhedron'], grain_ids, axis=0))]
764 | else:
765 | mesh = meshio.read(f"data/vtk/{args.case}/sols/u000.vtu")
766 | cells = [('hexahedron', onp.take(mesh.cells_dict['hexahedron'], grain_ids, axis=0))]
767 |
768 | new_mesh = meshio.Mesh(mesh.points, cells)
769 | new_mesh.cell_data['ipf_z'] = [onp.take(ipf_z, grain_ids, axis=0)]
770 | new_mesh.write(f'data/vtk/{args.case}/single_grain/{name}_u{step:03d}.vtu')
771 |
772 | plot_some_grains(grains_same_ori, 'same_color')
773 | plot_some_grains(single_grain_idx, 'single_grain')
774 |
775 |
776 | onp.save(f"data/numpy/{args.case}/post-processing/grain_geo.npy", onp.array(grain_geo))
777 |
778 |
779 | fd_grain_geo = onp.load(f"data/numpy/fd_{neper_mesh}/post-processing/grain_geo.npy")
780 | gn_grain_geo = onp.load(f"data/numpy/gn_{neper_mesh}/post-processing/grain_geo.npy")
781 |
782 | ts, xs, ys, ps = read_path(f'data/txt/single_track.txt')
783 | ts = ts[::args.write_sol_interval]*1e6
784 | ts = ts[15:31:5]
785 |
786 | fig = plt.figure(figsize=(8, 6))
787 | plt.plot(ts, fd_grain_geo[:, 0]*1e9, label='DNS', color='blue', marker='o', markersize=8, linestyle="-", linewidth=2)
788 | plt.plot(ts, gn_grain_geo[:, 0]*1e9, label='PEGN', color='red', marker='o', markersize=8, linestyle="-", linewidth=2)
789 | plt.xlabel(r'Time [$\mu$s]', fontsize=20)
790 | plt.ylabel(r'Grain volume [$\mu$m$^3$]', fontsize=20)
791 | plt.tick_params(labelsize=18)
792 | plt.legend(fontsize=20, frameon=False)
793 | # plt.savefig(f'data/pdf/npj_review_grain_growth/grain_vol.pdf', bbox_inches='tight')
794 |
795 | fig = plt.figure(figsize=(8, 6))
796 | plt.plot(ts, fd_grain_geo[:, 1]*1e6, label='DNS', color='blue', marker='o', markersize=8, linestyle="-", linewidth=2)
797 | plt.plot(ts, gn_grain_geo[:, 1]*1e6, label='PEGN', color='red', marker='o', markersize=8, linestyle="-", linewidth=2)
798 | plt.xlabel(r'Time [$\mu$s]', fontsize=20)
799 | plt.ylabel(r'Surface area [$\mu$m$^2$]', fontsize=20)
800 | plt.tick_params(labelsize=18)
801 | plt.legend(fontsize=20, frameon=False)
802 | # plt.savefig(f'data/pdf/npj_review_grain_growth/grain_area.pdf', bbox_inches='tight')
803 |
804 |
805 | def single_layer():
806 | neper_mesh = 'single_layer'
807 | compute_stats_single_layer(neper_mesh)
808 | produce_figures_single_layer(neper_mesh)
809 |
810 |
811 | def npj_review_voronoi():
812 | args.num_oris = 20
813 | args.num_grains = 40000
814 | neper_mesh = 'npj_review_voronoi'
815 | # compute_stats_single_layer(neper_mesh)
816 | produce_figures_single_layer(neper_mesh)
817 |
818 |
819 | def npj_review_voronoi_more_oris():
820 | args.num_oris = 40
821 | args.num_grains = 40000
822 | neper_mesh = 'npj_review_voronoi_more_oris'
823 | # compute_stats_single_layer(neper_mesh)
824 | produce_figures_single_layer(neper_mesh)
825 |
826 |
827 | def npj_review_voronoi_less_oris():
828 | args.num_oris = 10
829 | args.num_grains = 40000
830 | neper_mesh = 'npj_review_voronoi_less_oris'
831 | # compute_stats_single_layer(neper_mesh)
832 | produce_figures_single_layer(neper_mesh)
833 |
834 |
835 | def npj_review_voronoi_fine():
836 | args.num_oris = 20
837 | args.num_grains = 80000
838 | neper_mesh = 'npj_review_voronoi_fine'
839 | # compute_stats_single_layer(neper_mesh)
840 | produce_figures_single_layer(neper_mesh)
841 |
842 |
843 | def npj_review_voronoi_coarse():
844 | args.num_oris = 20
845 | args.num_grains = 20000
846 | neper_mesh = 'npj_review_voronoi_coarse'
847 | # compute_stats_single_layer(neper_mesh)
848 | produce_figures_single_layer(neper_mesh)
849 |
850 |
851 | def npj_review_centroidal():
852 | args.num_oris = 20
853 | args.num_grains = 40000
854 | neper_mesh = 'npj_review_centroidal'
855 | # compute_stats_single_layer(neper_mesh)
856 | produce_figures_single_layer(neper_mesh)
857 |
858 |
859 | def npj_review_centroidal_big_grain():
860 | args.num_oris = 20
861 | args.num_grains = 40000
862 | neper_mesh = 'npj_review_centroidal'
863 | produce_figures_single_layer(neper_mesh, 'npj_review_centroidal_big_grain')
864 |
865 |
866 | def npj_review_laser_150():
867 | args.power = 150.
868 | neper_mesh = 'npj_review_voronoi_laser_150'
869 | # compute_stats_single_layer(neper_mesh)
870 | produce_figures_single_layer(neper_mesh)
871 |
872 |
873 | def npj_review_laser_250():
874 | args.power = 250.
875 | neper_mesh = 'npj_review_voronoi_laser_250'
876 | # compute_stats_single_layer(neper_mesh)
877 | produce_figures_single_layer(neper_mesh)
878 |
879 |
880 | def npj_review_laser_100():
881 | args.power = 100.
882 | neper_mesh = 'npj_review_voronoi_laser_100'
883 | # compute_stats_single_layer(neper_mesh)
884 | produce_figures_single_layer(neper_mesh)
885 |
886 |
887 | if __name__ == "__main__":
888 | # generate_demo_graph()
889 | # vtk_convert_from_server()
890 | # get_unique_ori_colors()
891 | # ipf_logo()
892 | # make_video()
893 | # compute_stats_multi_layer()
894 | # produce_figures_multi_layer()
895 |
896 | npj_review_voronoi()
897 | npj_review_voronoi_more_oris()
898 | npj_review_voronoi_less_oris()
899 | npj_review_voronoi_fine()
900 | npj_review_voronoi_coarse()
901 | npj_review_centroidal()
902 | # npj_review_centroidal_big_grain()
903 | npj_review_laser_100()
904 |
905 | # npj_review_grain_growth()
906 | # plt.show()
907 |
--------------------------------------------------------------------------------