├── .gitattributes ├── .gitignore ├── README.md ├── Training Deep Neural Networks for the Inverse Design of Nanophotonic Structures.pptx ├── cy_funcs.pyx ├── environment-gpu.yml ├── environment.yml ├── forward_model.ipynb ├── forward_model ├── Arch4_Epochs1000_Adam0001_Sigmoid.h5 ├── Arch4_Epochs2500_Adam0001_Sigmoid.h5 ├── Arch4_Epochs4000_Adam0001_Sigmoid.h5 ├── Arch4_Epochs400_Adam0001_Sigmoid.h5 └── Arch4_Epochs6000_Adam0001_Sigmoid.h5 ├── grating.py ├── inverse_model.ipynb ├── inverse_model ├── Epochs400_Adam0001_Sigmoid.h5 ├── InverseNet15EpochsTandem.h5 └── TandemNN_Epochs400_Adam0001_Sigmoid.h5 ├── produce_data.ipynb ├── produce_data.py ├── py_funcs.py ├── requirements.txt └── test_grating.py /.gitattributes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/.gitattributes -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/usage.statistics.xml 10 | .idea/**/dictionaries 11 | .idea/**/shelf 12 | 13 | # Generated files 14 | .idea/**/contentModel.xml 15 | 16 | # Sensitive or high-churn files 17 | .idea/**/dataSources/ 18 | .idea/**/dataSources.ids 19 | .idea/**/dataSources.local.xml 20 | .idea/**/sqlDataSources.xml 21 | .idea/**/dynamic.xml 22 | .idea/**/uiDesigner.xml 23 | .idea/**/dbnavigator.xml 24 | 25 | # Gradle 26 | .idea/**/gradle.xml 27 | .idea/**/libraries 28 | 29 | # Gradle and Maven with auto-import 30 | # When using Gradle or Maven 
with auto-import, you should exclude module files, 31 | # since they will be recreated, and may cause churn. Uncomment if using 32 | # auto-import. 33 | # .idea/artifacts 34 | # .idea/compiler.xml 35 | # .idea/jarRepositories.xml 36 | # .idea/modules.xml 37 | # .idea/*.iml 38 | # .idea/modules 39 | # *.iml 40 | # *.ipr 41 | 42 | # CMake 43 | cmake-build-*/ 44 | 45 | # Mongo Explorer plugin 46 | .idea/**/mongoSettings.xml 47 | 48 | # File-based project format 49 | *.iws 50 | 51 | # IntelliJ 52 | out/ 53 | 54 | # mpeltonen/sbt-idea plugin 55 | .idea_modules/ 56 | 57 | # JIRA plugin 58 | atlassian-ide-plugin.xml 59 | 60 | # Cursive Clojure plugin 61 | .idea/replstate.xml 62 | 63 | # Crashlytics plugin (for Android Studio and IntelliJ) 64 | com_crashlytics_export_strings.xml 65 | crashlytics.properties 66 | crashlytics-build.properties 67 | fabric.properties 68 | 69 | # Editor-based Rest Client 70 | .idea/httpRequests 71 | 72 | # Android studio 3.1+ serialized cache file 73 | .idea/caches/build_file_checksums.ser 74 | 75 | ### Python template 76 | # Byte-compiled / optimized / DLL files 77 | __pycache__/ 78 | *.py[cod] 79 | *$py.class 80 | 81 | # C extensions 82 | *.so 83 | 84 | # Distribution / packaging 85 | .Python 86 | build/ 87 | develop-eggs/ 88 | dist/ 89 | downloads/ 90 | eggs/ 91 | .eggs/ 92 | lib/ 93 | lib64/ 94 | parts/ 95 | sdist/ 96 | var/ 97 | wheels/ 98 | share/python-wheels/ 99 | *.egg-info/ 100 | .installed.cfg 101 | *.egg 102 | MANIFEST 103 | 104 | # PyInstaller 105 | # Usually these files are written by a python script from a template 106 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
107 | *.manifest 108 | *.spec 109 | 110 | # Installer logs 111 | pip-log.txt 112 | pip-delete-this-directory.txt 113 | 114 | # Unit test / coverage reports 115 | htmlcov/ 116 | .tox/ 117 | .nox/ 118 | .coverage 119 | .coverage.* 120 | .cache 121 | nosetests.xml 122 | coverage.xml 123 | *.cover 124 | *.py,cover 125 | .hypothesis/ 126 | .pytest_cache/ 127 | cover/ 128 | 129 | # Translations 130 | *.mo 131 | *.pot 132 | 133 | # Django stuff: 134 | *.log 135 | local_settings.py 136 | db.sqlite3 137 | db.sqlite3-journal 138 | 139 | # Flask stuff: 140 | instance/ 141 | .webassets-cache 142 | 143 | # Scrapy stuff: 144 | .scrapy 145 | 146 | # Sphinx documentation 147 | docs/_build/ 148 | 149 | # PyBuilder 150 | .pybuilder/ 151 | target/ 152 | 153 | # Jupyter Notebook 154 | .ipynb_checkpoints 155 | 156 | # IPython 157 | profile_default/ 158 | ipython_config.py 159 | 160 | # pyenv 161 | # For a library or package, you might want to ignore these files since the code is 162 | # intended to run in multiple environments; otherwise, check them in: 163 | # .python-version 164 | 165 | # pipenv 166 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 167 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 168 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 169 | # install all needed dependencies. 170 | #Pipfile.lock 171 | 172 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 173 | __pypackages__/ 174 | 175 | # Celery stuff 176 | celerybeat-schedule 177 | celerybeat.pid 178 | 179 | # SageMath parsed files 180 | *.sage.py 181 | 182 | # Environments 183 | .env 184 | .venv 185 | env/ 186 | venv/ 187 | ENV/ 188 | env.bak/ 189 | venv.bak/ 190 | 191 | # Spyder project settings 192 | .spyderproject 193 | .spyproject 194 | 195 | # Rope project settings 196 | .ropeproject 197 | 198 | # mkdocs documentation 199 | /site 200 | 201 | # mypy 202 | .mypy_cache/ 203 | .dmypy.json 204 | dmypy.json 205 | 206 | # Pyre type checker 207 | .pyre/ 208 | 209 | # pytype static type analyzer 210 | .pytype/ 211 | 212 | # Cython debug symbols 213 | cython_debug/ 214 | 215 | # These are too large 216 | database.npz 217 | dataset_snapshots 218 | 219 | # ppt temp files 220 | ~$Training Deep Neural Networks for the Inverse Design of Nanophotonic Structures.pptx -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Neural Network for Inverse Design in Nanophotonics 2 | ================================================== 3 | This project demonstrates training a neural network for the inverse design of nanophotonic gratings. 4 | The project is inspired by the [work](https://doi.org/10.1021/acsphotonics.7b01377) of Liu et al. in the ACS Photonics journal. 5 | 6 | Generating Data 7 | ------------------ 8 | * To be able to run the DNN training scripts, you need the dataset file `dataset.npz` in the project root folder. 9 | * If you want to use a pre-generated dataset, you can download the dataset file from [here](https://drive.google.com/file/d/1D0fJ815a0pgtrE-_lObbhYUnVooeIZIe/view?usp=sharing). **Note**: The file size is ~1 GB. 10 | * In case you wish to generate the dataset from scratch, you can run either the [produce_data.py](./produce_data.py) Python script or the [produce_data.ipynb](./produce_data.ipynb) Jupyter notebook. 
# cython: infer_types=True
"""Transfer-matrix (propagator) routines for two-material layered gratings.

A uniform layer is described by the 4x4 matrix ``matrix_M(eps, b)`` whose
eigenvalues and eigenvectors are available in closed form (``eigvalsM`` /
``eigvecsM``).  A layer propagator is ``V exp(i k0 d W) V^-1`` with
``k0 = 2 pi f``; propagators of consecutive layers are multiplied to obtain
the propagator of the whole stack, from which the reflection and
transmission operators are assembled.
"""
import numpy as np
cimport numpy as np
cimport cython
import scipy.linalg as spla
from libc.math cimport sqrt, exp


# 2-D unit vectors and their outer products.
ex = np.array([1, 0])
exex = np.outer(ex, ex)
ey = np.array([0, 1])
eyey = np.outer(ey, ey)
# 2x2 identity: one block of the 2x4 / 4x2 boundary matrices built below.
I = np.eye(2)
I4 = np.eye(4, dtype=np.complex128)


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef matrix_M(double eps, double b=0):
    """Return the 4x4 real matrix ``M`` of a uniform layer.

    Parameters
    ----------
    eps : double
        Layer permittivity.  Must be non-zero because ``b**2 / eps`` is
        evaluated.
    b : double, optional
        Second parameter of the matrix (presumably the normalised tangential
        wave-vector component -- confirm against the callers).

    Returns
    -------
    numpy.ndarray
        ``(4, 4)`` float64 array whose only non-zero entries are
        ``M[0, 2]``, ``M[1, 3]``, ``M[2, 0]`` and ``M[3, 1]``.
    """
    # np.zeros replaces the former np.empty + explicit zero fill.
    M = np.zeros((4, 4), dtype=np.double)
    cdef double[:, ::1] M_view = M
    M_view[0, 2] = eps - b ** 2
    M_view[1, 3] = eps
    M_view[2, 0] = 1
    M_view[3, 1] = 1 - (b ** 2 / eps)
    return M


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef eigvalsM(double eps, double b=0.):
    """Return the eigenvalues ``[-n, -n, n, n]`` of ``matrix_M(eps, b)``.

    ``n = sqrt(eps - b**2)`` is computed with the real C ``sqrt``, so the
    result is NaN when ``eps < b**2``.
    """
    cdef double n = sqrt(eps - b ** 2)
    return np.array([-n, -n, n, n], dtype=np.double)


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef eigvecsM(double eps, double b=0.):
    """Return the eigenvector matrix of ``matrix_M(eps, b)`` and its inverse.

    Returns
    -------
    numpy.ndarray
        ``(2, 4, 4)`` float64 array.  ``out[0]`` holds the eigenvectors as
        columns, ordered like the values returned by ``eigvalsM``;
        ``out[1]`` is the inverse of ``out[0]``.
    """
    cdef double sqrtepsb = sqrt(eps - b ** 2)
    cdef double c1 = sqrtepsb
    cdef double c2 = eps * sqrtepsb / (b ** 2 - eps)
    cdef double c3 = sqrtepsb / (2 * eps)
    cdef double c4 = 0.5 / sqrtepsb
    return np.array([[[0, -c1, 0, c1],
                      [c2, 0, -c2, 0],
                      [0, 1, 0, 1],
                      [1, 0, 1, 0]],
                     [[0, -c3, 0, 0.5],
                      [-c4, 0, 0.5, 0],
                      [0, c3, 0, 0.5],
                      [c4, 0, 0.5, 0]]], dtype=np.double)


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef propagator_layer(f, double eps, double d, double b=0):
    """Propagator, or evolution operator, or transfer matrix of
    a uniform layer with permittivity eps and thickness d
    at frequency f.

    Parameters
    ----------
    f : scalar or one-dimensional array-like
        Frequency or frequencies.  Any real numeric input is accepted and
        converted to float64 (integers, lists and strided arrays included).
        Units are presumably chosen so that ``2 * pi * f * d`` is a
        dimensionless phase -- confirm against the data generator.
    eps : double
        Layer permittivity.
    d : double
        Layer thickness.
    b : double, optional
        Forwarded to ``eigvalsM`` / ``eigvecsM``.

    Returns
    -------
    numpy.ndarray
        complex128 array of shape ``(4, 4)`` when ``f`` is a scalar,
        otherwise ``(len(f), 4, 4)``.

    Raises
    ------
    ValueError
        If ``f`` has more than one dimension.
    """
    scalar_input = np.isscalar(f)
    # Explicit conversion: binding the raw input to a ``double[::1]`` buffer
    # used to fail for integer or non-contiguous frequencies.
    freqs = np.atleast_1d(np.asarray(f, dtype=np.double))
    if freqs.ndim != 1:
        raise ValueError("f must be a scalar or a one-dimensional sequence")
    w = eigvalsM(eps, b)
    vr = eigvecsM(eps, b)
    # phases[i, k] = exp(i * k0_i * d * w_k), with k0_i = 2 * pi * f_i.
    ik0d = 2j * np.pi * freqs * d
    phases = np.exp(ik0d[:, None] * w[None, :])
    # V @ diag(p) @ V^-1 for every frequency at once: scaling the columns of
    # V by p equals V @ diag(p); matmul broadcasts over the leading axis.
    res = np.matmul(vr[0][None, :, :] * phases[:, None, :], vr[1])
    if scalar_input:
        return res[0]
    return res


@cython.boundscheck(False)
@cython.wraparound(False)
def propagator_grating(f, double eps1, double eps2, double[:] D, double b=0):
    """Propagator of a stack of layers with alternating permittivities.

    Layer ``n`` has thickness ``D[n]`` and permittivity ``eps1`` for even
    ``n`` and ``eps2`` for odd ``n``.  The propagator of each further layer
    is multiplied from the left onto the accumulated product.

    Parameters
    ----------
    f : scalar or one-dimensional array-like
        Frequency or frequencies, forwarded to ``propagator_layer``.
    eps1, eps2 : double
        Permittivities of the even- and odd-indexed layers.
    D : double[:]
        Layer thicknesses; must contain at least one element.
    b : double, optional
        Forwarded to ``propagator_layer``.

    Returns
    -------
    numpy.ndarray
        ``(4, 4)`` complex array for scalar ``f``, otherwise
        ``(len(f), 4, 4)``.

    Raises
    ------
    ValueError
        If ``D`` is empty.
    """
    cdef Py_ssize_t n, N = D.shape[0]
    # Bounds checking is disabled for this function, so D[0] on an empty
    # buffer would read out of bounds instead of raising.
    if N == 0:
        raise ValueError("D must contain at least one layer thickness")
    propagator = propagator_layer(f, eps1, D[0], b)
    for n in range(1, N):
        eps = eps1 if n % 2 == 0 else eps2
        # matmul handles both the (4, 4) and the stacked (N, 4, 4) case,
        # replacing the separate scalar / per-frequency code paths.
        propagator = np.matmul(propagator_layer(f, eps, D[n], b), propagator)
    return propagator


cpdef gamma(eps, b=0):
    """Return ``diag(1 / sqrt(eps - b**2), sqrt(eps - b**2) / eps)``.

    This 2x2 matrix is used as a block of the boundary matrices in
    ``operator_r`` and ``operator_t`` (presumably a surface impedance or
    admittance tensor of the surrounding medium -- confirm).
    """
    root = np.sqrt(eps - b ** 2)
    return np.diag([1 / root, root / eps])


cpdef operator_r(propagator, eps_left=1, eps_right=None, b=0):
    """Reflection operator of a multilayer characterized by propagator
    surrounded by media with eps_left at left and eps_right at right.

    Parameters
    ----------
    propagator : (4, 4) array-like
        Propagator of the multilayer at a single frequency.
    eps_left : scalar, optional
        Permittivity of the medium on the left.
    eps_right : scalar or None, optional
        Permittivity of the medium on the right; ``None`` means "same as
        ``eps_left``".
    b : scalar, optional
        Forwarded to ``gamma``.

    Returns
    -------
    2x2 matrix ``inv(factor1) . factor2``.
    """
    # ``is None`` rather than truthiness, so that an explicitly passed value
    # is never silently replaced by eps_left.
    if eps_right is None:
        eps_right = eps_left
    gamma_left = gamma(eps_left, b)
    gamma_right = gamma(eps_right, b)
    # np.bmat is kept on purpose: it fixes the matrix type of the result
    # that existing callers receive.
    factor1 = (np.bmat([gamma_right, -I]).
               dot(propagator).
               dot(np.bmat([I, -gamma_left]).T))
    factor2 = (np.bmat([gamma_right, -I]).
               dot(propagator).
               dot(np.bmat([I, gamma_left]).T))
    return spla.inv(factor1).dot(factor2)


cpdef operator_t(propagator, eps_left=1, eps_right=None, b=0):
    """Transmission operator of a multilayer characterized by propagator
    surrounded by media with eps_left at left and eps_right at right.

    Parameters
    ----------
    propagator : (4, 4) array-like
        Propagator of the multilayer at a single frequency; it must be
        invertible because its inverse is formed.
    eps_left : scalar, optional
        Permittivity of the medium on the left.
    eps_right : scalar or None, optional
        Permittivity of the medium on the right; ``None`` means "same as
        ``eps_left``".
    b : scalar, optional
        Forwarded to ``gamma``.

    Returns
    -------
    2x2 array ``2 * inv([gamma_left, I] . inv(propagator) .
    [I, gamma_right]^T) . gamma_left``.
    """
    if eps_right is None:
        eps_right = eps_left
    gamma_left = gamma(eps_left, b)
    gamma_right = gamma(eps_right, b)
    return 2 * spla.inv((np.bmat([gamma_left, I]).
                         dot(spla.inv(propagator)).
                         dot(np.bmat([I, gamma_right]).T))).dot(gamma_left)
54 | - jsonschema=3.2.0=py_2 55 | - jupyter=1.0.0=py38_7 56 | - jupyter_client=6.1.7=py_0 57 | - jupyter_console=6.2.0=py_0 58 | - jupyter_core=4.7.1=py38haa95532_0 59 | - jupyterlab_pygments=0.1.2=py_0 60 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 61 | - keras-applications=1.0.8=py_1 62 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 63 | - kiwisolver=1.3.1=py38hd77b12b_0 64 | - libpng=1.6.37=h2a8f88b_0 65 | - libprotobuf=3.14.0=h23ce68f_0 66 | - libsodium=1.0.18=h62dcd97_0 67 | - libtiff=4.1.0=h56a325e_1 68 | - lz4-c=1.9.3=h2bbff1b_0 69 | - m2w64-gcc-libgfortran=5.3.0=6 70 | - m2w64-gcc-libs=5.3.0=7 71 | - m2w64-gcc-libs-core=5.3.0=7 72 | - m2w64-gmp=6.1.0=2 73 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 74 | - markdown=3.3.3=py38haa95532_0 75 | - markupsafe=1.1.1=py38he774522_0 76 | - matplotlib=3.3.2=haa95532_0 77 | - matplotlib-base=3.3.2=py38hba9282a_0 78 | - mistune=0.8.4=py38he774522_1000 79 | - mkl=2020.2=256 80 | - mkl-service=2.3.0=py38h196d8e1_0 81 | - mkl_fft=1.2.0=py38h45dec08_0 82 | - mkl_random=1.1.1=py38h47e9c7a_0 83 | - msys2-conda-epoch=20160418=1 84 | - multidict=4.7.6=py38he774522_1 85 | - nbclient=0.5.1=py_0 86 | - nbconvert=6.0.7=py38_0 87 | - nbformat=5.1.2=pyhd3eb1b0_1 88 | - nest-asyncio=1.4.3=pyhd3eb1b0_0 89 | - notebook=6.2.0=py38haa95532_0 90 | - numpy=1.19.2=py38hadc3359_0 91 | - numpy-base=1.19.2=py38ha3acd2a_0 92 | - oauthlib=3.1.0=py_0 93 | - olefile=0.46=py_0 94 | - openssl=1.1.1i=h2bbff1b_0 95 | - opt_einsum=3.1.0=py_0 96 | - packaging=20.9=pyhd3eb1b0_0 97 | - pandoc=2.11=h9490d1a_0 98 | - pandocfilters=1.4.3=py38haa95532_1 99 | - parso=0.8.1=pyhd3eb1b0_0 100 | - pickleshare=0.7.5=pyhd3eb1b0_1003 101 | - pillow=8.1.0=py38h4fa10fc_0 102 | - pip=20.3.3=py38haa95532_0 103 | - prometheus_client=0.9.0=pyhd3eb1b0_0 104 | - prompt-toolkit=3.0.8=py_0 105 | - prompt_toolkit=3.0.8=0 106 | - protobuf=3.14.0=py38hd77b12b_1 107 | - pyasn1=0.4.8=py_0 108 | - pyasn1-modules=0.2.8=py_0 109 | - pycparser=2.20=py_2 110 | - pydot=1.4.1=py38_0 111 | - 
pygments=2.7.4=pyhd3eb1b0_0 112 | - pyjwt=2.0.1=py38haa95532_0 113 | - pyopenssl=20.0.1=pyhd3eb1b0_1 114 | - pyparsing=2.4.7=pyhd3eb1b0_0 115 | - pyqt=5.9.2=py38ha925a31_4 116 | - pyreadline=2.1=py38_1 117 | - pyrsistent=0.17.3=py38he774522_0 118 | - pysocks=1.7.1=py38haa95532_0 119 | - python=3.8.5=h5fd99cc_1 120 | - python-dateutil=2.8.1=pyhd3eb1b0_0 121 | - pywin32=227=py38he774522_1 122 | - pywinpty=0.5.7=py38_0 123 | - pyzmq=20.0.0=py38hd77b12b_1 124 | - qt=5.9.7=vc14h73c81de_0 125 | - qtconsole=5.0.2=pyhd3eb1b0_0 126 | - qtpy=1.9.0=py_0 127 | - requests=2.25.1=pyhd3eb1b0_0 128 | - requests-oauthlib=1.3.0=py_0 129 | - rsa=4.7=pyhd3eb1b0_1 130 | - scikit-learn=0.23.2=py38h47e9c7a_0 131 | - scipy=1.6.0=py38h14eb087_0 132 | - send2trash=1.5.0=pyhd3eb1b0_1 133 | - setuptools=52.0.0=py38haa95532_0 134 | - sip=4.19.13=py38ha925a31_0 135 | - six=1.15.0=py38haa95532_0 136 | - sqlite=3.33.0=h2a8f88b_0 137 | - tensorboard=2.3.0=pyh4dce500_0 138 | - tensorboard-plugin-wit=1.6.0=py_0 139 | - tensorflow=2.3.0=mkl_py38h8557ec7_0 140 | - tensorflow-base=2.3.0=eigen_py38h75a453f_0 141 | - tensorflow-estimator=2.3.0=pyheb71bc4_0 142 | - tensorflow-gpu=2.3.0=he13fc11_0 143 | - termcolor=1.1.0=py38_1 144 | - terminado=0.9.2=py38haa95532_0 145 | - testpath=0.4.4=pyhd3eb1b0_0 146 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 147 | - tk=8.6.10=he774522_0 148 | - tornado=6.1=py38h2bbff1b_0 149 | - traitlets=5.0.5=pyhd3eb1b0_0 150 | - typing-extensions=3.7.4.3=hd3eb1b0_0 151 | - typing_extensions=3.7.4.3=pyh06a4308_0 152 | - urllib3=1.26.3=pyhd3eb1b0_0 153 | - vc=14.2=h21ff451_1 154 | - vs2015_runtime=14.27.29016=h5e58377_2 155 | - wcwidth=0.2.5=py_0 156 | - webencodings=0.5.1=py38_1 157 | - werkzeug=1.0.1=pyhd3eb1b0_0 158 | - wheel=0.36.2=pyhd3eb1b0_0 159 | - widgetsnbextension=3.5.1=py38_0 160 | - win_inet_pton=1.1.0=py38haa95532_0 161 | - wincertstore=0.2=py38_0 162 | - winpty=0.4.3=4 163 | - wrapt=1.12.1=py38he774522_1 164 | - xz=5.2.5=h62dcd97_0 165 | - yarl=1.5.1=py38he774522_0 166 | - 
zeromq=4.3.3=ha925a31_3 167 | - zipp=3.4.0=pyhd3eb1b0_0 168 | - zlib=1.2.11=h62dcd97_4 169 | - zstd=1.4.5=h04227a9_0 170 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: tf 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _tflow_select=2.3.0=eigen 6 | - absl-py=0.11.0=pyhd3eb1b0_1 7 | - aiohttp=3.7.3=py38h2bbff1b_1 8 | - argon2-cffi=20.1.0=py38h2bbff1b_1 9 | - astunparse=1.6.3=py_0 10 | - async-timeout=3.0.1=py38haa95532_0 11 | - async_generator=1.10=pyhd3eb1b0_0 12 | - attrs=20.3.0=pyhd3eb1b0_0 13 | - backcall=0.2.0=pyhd3eb1b0_0 14 | - blas=1.0=mkl 15 | - bleach=3.3.0=pyhd3eb1b0_0 16 | - blinker=1.4=py38haa95532_0 17 | - brotlipy=0.7.0=py38h2bbff1b_1003 18 | - ca-certificates=2021.1.19=haa95532_0 19 | - cachetools=4.2.1=pyhd3eb1b0_0 20 | - certifi=2020.12.5=py38haa95532_0 21 | - cffi=1.14.4=py38hcd4344a_0 22 | - chardet=3.0.4=py38haa95532_1003 23 | - click=7.1.2=pyhd3eb1b0_0 24 | - colorama=0.4.4=pyhd3eb1b0_0 25 | - cryptography=2.9.2=py38h7a1dbc1_0 26 | - decorator=4.4.2=pyhd3eb1b0_0 27 | - defusedxml=0.6.0=pyhd3eb1b0_0 28 | - entrypoints=0.3=py38_0 29 | - gast=0.4.0=py_0 30 | - google-auth=1.24.0=pyhd3eb1b0_0 31 | - google-auth-oauthlib=0.4.2=pyhd3eb1b0_2 32 | - google-pasta=0.2.0=py_0 33 | - grpcio=1.35.0=py38hc60d5dd_0 34 | - h5py=2.10.0=py38h5e291fa_0 35 | - hdf5=1.10.4=h7ebc959_0 36 | - icc_rt=2019.0.0=h0cc432a_1 37 | - icu=58.2=ha925a31_3 38 | - idna=2.10=pyhd3eb1b0_0 39 | - importlib-metadata=2.0.0=py_1 40 | - importlib_metadata=2.0.0=1 41 | - intel-openmp=2020.2=254 42 | - ipykernel=5.3.4=py38h5ca1d4c_0 43 | - ipython=7.20.0=py38hd4e2768_1 44 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 45 | - ipywidgets=7.6.3=pyhd3eb1b0_1 46 | - jedi=0.17.0=py38_0 47 | - jinja2=2.11.3=pyhd3eb1b0_0 48 | - jpeg=9b=hb83a4c4_2 49 | - jsonschema=3.2.0=py_2 50 | - jupyter=1.0.0=py38_7 51 | - jupyter_client=6.1.7=py_0 52 | - 
jupyter_console=6.2.0=py_0 53 | - jupyter_core=4.7.1=py38haa95532_0 54 | - jupyterlab_pygments=0.1.2=py_0 55 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 56 | - keras-applications=1.0.8=py_1 57 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 58 | - libpng=1.6.37=h2a8f88b_0 59 | - libprotobuf=3.14.0=h23ce68f_0 60 | - libsodium=1.0.18=h62dcd97_0 61 | - m2w64-gcc-libgfortran=5.3.0=6 62 | - m2w64-gcc-libs=5.3.0=7 63 | - m2w64-gcc-libs-core=5.3.0=7 64 | - m2w64-gmp=6.1.0=2 65 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2 66 | - markdown=3.3.3=py38haa95532_0 67 | - markupsafe=1.1.1=py38he774522_0 68 | - mistune=0.8.4=py38he774522_1000 69 | - mkl=2020.2=256 70 | - mkl-service=2.3.0=py38h196d8e1_0 71 | - mkl_fft=1.2.0=py38h45dec08_0 72 | - mkl_random=1.1.1=py38h47e9c7a_0 73 | - msys2-conda-epoch=20160418=1 74 | - multidict=4.7.6=py38he774522_1 75 | - nbclient=0.5.1=py_0 76 | - nbconvert=6.0.7=py38_0 77 | - nbformat=5.1.2=pyhd3eb1b0_1 78 | - nest-asyncio=1.4.3=pyhd3eb1b0_0 79 | - notebook=6.2.0=py38haa95532_0 80 | - numpy=1.19.2=py38hadc3359_0 81 | - numpy-base=1.19.2=py38ha3acd2a_0 82 | - oauthlib=3.1.0=py_0 83 | - openssl=1.1.1i=h2bbff1b_0 84 | - opt_einsum=3.1.0=py_0 85 | - packaging=20.9=pyhd3eb1b0_0 86 | - pandoc=2.11=h9490d1a_0 87 | - pandocfilters=1.4.3=py38haa95532_1 88 | - parso=0.8.1=pyhd3eb1b0_0 89 | - pickleshare=0.7.5=pyhd3eb1b0_1003 90 | - pip=20.3.3=py38haa95532_0 91 | - prometheus_client=0.9.0=pyhd3eb1b0_0 92 | - prompt-toolkit=3.0.8=py_0 93 | - prompt_toolkit=3.0.8=0 94 | - protobuf=3.14.0=py38hd77b12b_1 95 | - pyasn1=0.4.8=py_0 96 | - pyasn1-modules=0.2.8=py_0 97 | - pycparser=2.20=py_2 98 | - pygments=2.7.4=pyhd3eb1b0_0 99 | - pyjwt=2.0.1=py38haa95532_0 100 | - pyopenssl=20.0.1=pyhd3eb1b0_1 101 | - pyparsing=2.4.7=pyhd3eb1b0_0 102 | - pyqt=5.9.2=py38ha925a31_4 103 | - pyreadline=2.1=py38_1 104 | - pyrsistent=0.17.3=py38he774522_0 105 | - pysocks=1.7.1=py38haa95532_0 106 | - python=3.8.5=h5fd99cc_1 107 | - python-dateutil=2.8.1=pyhd3eb1b0_0 108 | - 
pywin32=227=py38he774522_1 109 | - pywinpty=0.5.7=py38_0 110 | - pyzmq=20.0.0=py38hd77b12b_1 111 | - qt=5.9.7=vc14h73c81de_0 112 | - qtconsole=5.0.2=pyhd3eb1b0_0 113 | - qtpy=1.9.0=py_0 114 | - requests=2.25.1=pyhd3eb1b0_0 115 | - requests-oauthlib=1.3.0=py_0 116 | - rsa=4.7=pyhd3eb1b0_1 117 | - scipy=1.6.0=py38h14eb087_0 118 | - send2trash=1.5.0=pyhd3eb1b0_1 119 | - setuptools=52.0.0=py38haa95532_0 120 | - sip=4.19.13=py38ha925a31_0 121 | - six=1.15.0=py38haa95532_0 122 | - sqlite=3.33.0=h2a8f88b_0 123 | - tensorboard=2.3.0=pyh4dce500_0 124 | - tensorboard-plugin-wit=1.6.0=py_0 125 | - tensorflow=2.3.0=mkl_py38h8c0d9a2_0 126 | - tensorflow-base=2.3.0=eigen_py38h75a453f_0 127 | - tensorflow-estimator=2.3.0=pyheb71bc4_0 128 | - termcolor=1.1.0=py38_1 129 | - terminado=0.9.2=py38haa95532_0 130 | - testpath=0.4.4=pyhd3eb1b0_0 131 | - tornado=6.1=py38h2bbff1b_0 132 | - traitlets=5.0.5=pyhd3eb1b0_0 133 | - typing-extensions=3.7.4.3=hd3eb1b0_0 134 | - typing_extensions=3.7.4.3=pyh06a4308_0 135 | - urllib3=1.26.3=pyhd3eb1b0_0 136 | - vc=14.2=h21ff451_1 137 | - vs2015_runtime=14.27.29016=h5e58377_2 138 | - wcwidth=0.2.5=py_0 139 | - webencodings=0.5.1=py38_1 140 | - werkzeug=1.0.1=pyhd3eb1b0_0 141 | - wheel=0.36.2=pyhd3eb1b0_0 142 | - widgetsnbextension=3.5.1=py38_0 143 | - win_inet_pton=1.1.0=py38haa95532_0 144 | - wincertstore=0.2=py38_0 145 | - winpty=0.4.3=4 146 | - wrapt=1.12.1=py38he774522_1 147 | - yarl=1.5.1=py38he774522_0 148 | - zeromq=4.3.3=ha925a31_3 149 | - zipp=3.4.0=pyhd3eb1b0_0 150 | - zlib=1.2.11=h62dcd97_4 151 | -------------------------------------------------------------------------------- /forward_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "incorporated-exhaust", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib inline\n", 11 | "import os\n", 12 | "import numpy as np\n", 13 | "import 
matplotlib.pyplot as plt\n", 14 | "import tensorflow as tf\n", 15 | "from tensorflow import keras\n", 16 | "from tensorflow.keras import layers\n", 17 | "from sklearn.model_selection import train_test_split\n", 18 | "from grating import *" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "id": "raised-interference", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "Num GPUs Available: 0\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "print(\"Num GPUs Available: \", len(tf.config.list_physical_devices('GPU')))" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "id": "bulgarian-surrey", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "%config IPCompleter.use_jedi = False" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "incorrect-scotland", 52 | "metadata": {}, 53 | "source": [ 54 | "# Loading the Dataset" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "id": "traditional-functionality", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "Train set contains 588429 samples\n", 68 | "Validation set contains 65381 samples\n" 69 | ] 70 | } 71 | ], 72 | "source": [ 73 | "fname = 'dataset.npz'\n", 74 | "with np.load(fname) as data:\n", 75 | " designs = data['D']\n", 76 | " responses = data['R']\n", 77 | " \n", 78 | "n_grating_layers = designs.shape[-1]\n", 79 | "n_freqs = responses.shape[-1]\n", 80 | "Dtrain, Dtest, Rtrain, Rtest = train_test_split(designs, responses,\n", 81 | " test_size=0.1,\n", 82 | " random_state=42)\n", 83 | "print(\"Train set contains {} samples\".format(Dtrain.shape[0]))\n", 84 | "print(\"Validation set contains {} samples\".format(Dtest.shape[0]))" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "periodic-adelaide", 90 | "metadata": {}, 91 | "source": [ 92 | "# Initializing the 
Model" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 5, 98 | "id": "opening-appearance", 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "Model: \"ForwardNet\"\n", 106 | "_________________________________________________________________\n", 107 | "Layer (type) Output Shape Param # \n", 108 | "=================================================================\n", 109 | "F1 (Dense) (None, 500) 8000 \n", 110 | "_________________________________________________________________\n", 111 | "F2 (Dense) (None, 200) 100200 \n", 112 | "_________________________________________________________________\n", 113 | "F3 (Dense) (None, 200) 40200 \n", 114 | "_________________________________________________________________\n", 115 | "F4 (Dense) (None, 200) 40200 \n", 116 | "_________________________________________________________________\n", 117 | "R (Dense) (None, 200) 40200 \n", 118 | "=================================================================\n", 119 | "Total params: 228,800\n", 120 | "Trainable params: 228,800\n", 121 | "Non-trainable params: 0\n", 122 | "_________________________________________________________________\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "activation = keras.activations.sigmoid\n", 128 | "# Architecture 4\n", 129 | "model = keras.Sequential([layers.Input((n_grating_layers,), name='D'),\n", 130 | " layers.Dense(500, activation=activation, name='F1'),\n", 131 | " layers.Dense(200, activation=activation, name='F2'),\n", 132 | " layers.Dense(200, activation=activation, name='F3'),\n", 133 | " layers.Dense(200, activation=activation, name='F4'),\n", 134 | " layers.Dense(n_freqs, activation='sigmoid', name='R')],\n", 135 | " name='ForwardNet')\n", 136 | "model.summary()" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "coupled-mining", 142 | "metadata": {}, 143 | "source": [ 144 | "# Loading Weights" 145 | ] 146 | }, 147 | { 
148 | "cell_type": "markdown", 149 | "id": "vietnamese-cargo", 150 | "metadata": {}, 151 | "source": [ 152 | "Let us load the previous model state to continue training" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 6, 158 | "id": "hungarian-commons", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "model.load_weights(os.path.join('forward_model',\n", 163 | " 'Arch4_Epochs4000_Adam0001_Sigmoid.h5'))" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 7, 169 | "id": "assured-crack", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "def loss(y_true, y_pred):\n", 174 | " return n_freqs * keras.losses.mse(y_true, y_pred)\n", 175 | "\n", 176 | "model.compile(loss=loss, optimizer='adam')" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": 8, 182 | "id": "protecting-comparison", 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "Epoch 4001/4001\n", 190 | "4598/4598 [==============================] - 17s 4ms/step - loss: 0.5009\n" 191 | ] 192 | } 193 | ], 194 | "source": [ 195 | "initial_epoch = 4000\n", 196 | "\n", 197 | "info = model.fit(Dtrain, Rtrain,\n", 198 | " batch_size=128, epochs=initial_epoch + 1,\n", 199 | " validation_data=(Dtest, Rtest),\n", 200 | " validation_freq=5,\n", 201 | " initial_epoch=initial_epoch)\n", 202 | "initial_epoch = model.history.epoch[-1]" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 9, 208 | "id": "regular-medicine", 209 | "metadata": { 210 | "scrolled": true 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "epochs = initial_epoch + 1\n", 215 | "model.save(\n", 216 | " os.path.join('forward_model',\n", 217 | " 'Arch4_Epochs{}_Adam0001_Sigmoid.h5'.format(epochs)))" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 10, 223 | "id": "marine-terrorism", 224 | "metadata": { 225 | "scrolled": true 
226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "# # Uncomment after training:\n", 230 | "# # ======================== \n", 231 | "# val_loss = model.history.history['val_loss']\n", 232 | "# loss = model.history.history['loss']\n", 233 | "# fig, ax = plt.subplots()\n", 234 | "# ax.plot(info.epoch, loss, label='Loss')\n", 235 | "# ax.plot(info.epoch[::5], val_loss, label='Validation Loss')\n", 236 | "# ax.legend()\n", 237 | "# ax.set_xlabel('Epoch')" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": 11, 243 | "id": "flexible-winning", 244 | "metadata": {}, 245 | "outputs": [ 246 | { 247 | "data": { 248 | "image/png": "\n", 249 | "text/plain": [ 250 | "
" 251 | ] 252 | }, 253 | "metadata": { 254 | "needs_background": "light" 255 | }, 256 | "output_type": "display_data" 257 | } 258 | ], 259 | "source": [ 260 | "idx = np.random.randint(0, Dtest.shape[0], 1)\n", 261 | "dnn_responses = model(Dtest[idx]).numpy()\n", 262 | "responses = Rtest[idx]\n", 263 | "for o, r in zip(dnn_responses, responses):\n", 264 | " line, = plt.plot(o, '--')\n", 265 | " plt.plot(r, '-', color=line.get_color())" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 12, 271 | "id": "korean-display", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "load_dir = 'forward_model'\n", 276 | "def loss(y_true, y_pred):\n", 277 | " return n_freqs * keras.losses.mse(y_true, y_pred)\n", 278 | "\n", 279 | "fmodel400 = keras.models.load_model(\n", 280 | " os.path.join(load_dir, 'Arch4_Epochs400_Adam0001_Sigmoid.h5'),\n", 281 | " custom_objects={'loss': loss})\n", 282 | "fmodel1000 = keras.models.load_model(\n", 283 | " os.path.join(load_dir, 'Arch4_Epochs1000_Adam0001_Sigmoid.h5'),\n", 284 | " custom_objects={'loss': loss})\n", 285 | "fmodel2500 = keras.models.load_model(\n", 286 | " os.path.join(load_dir, 'Arch4_Epochs2500_Adam0001_Sigmoid.h5'),\n", 287 | " custom_objects={'loss': loss})\n", 288 | "fmodel4000 = keras.models.load_model(\n", 289 | " os.path.join(load_dir, 'Arch4_Epochs4000_Adam0001_Sigmoid.h5'),\n", 290 | " custom_objects={'loss': loss})" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 13, 296 | "id": "moral-equipment", 297 | "metadata": {}, 298 | "outputs": [ 299 | { 300 | "data": { 301 | "text/plain": [ 302 | "" 303 | ] 304 | }, 305 | "execution_count": 13, 306 | "metadata": {}, 307 | "output_type": "execute_result" 308 | }, 309 | { 310 | "data": { 311 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAASEAAAD4CAYAAACjW1BIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGzklEQVR4nO3bz4vtdR3H8ddbb5JakXBtkUpjEJUEYUhYQosMKoraFhTRuh8WQVh/QYuIWkQQVgRJLcxFRGSLWkv3alBmgfjzlpESlURwEz8tZgIJszvO+d7XbXo8VjPnfDnf92fOd55zvt85Z9ZaAWi5qD0A8P9NhIAqEQKqRAioEiGg6sRhNj558uTa29s75+3PPvXHrPXMYWc6lL9fckXOZ0svPfunJNv/R/Eo63r+Ged57jv6LEf5+bzQNR/1OTnqMbTVMbHFsb2rWWcuyiUvfcU5b3/69Okn11pXPtd9h4rQ3t5eTp06dc7bP3TXFw7z8C/IvXsf2Xwfz3b9w98+L/s5yrp2PeNhZjnKvl/omo+63qMeQ1sdE1sc27uc9dp33nrO287MI//pPqdjQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUCVCQJUIAVUiBFSJEFAlQkCVCAFVIgRUiRBQJUJAlQgBVSIEVIkQUDVrrXPfeOaJJI8c4vFPJnnysEP9D7LO48U6d+9Va60rn+uOQ0XosGbm1Frrhs12cIGwzuPFOs8vp2NAlQgBVVtH6OsbP/6FwjqPF+s8jza9JgTw3zgdA6pECKjaJEIz866Z+e3MPDAzt26xjwvBzFwzMz+bmftn5r6ZuaU901Zm5uKZuXdmftieZUsz8/KZuWNmfnPwvL6lPdOuzcynD47XX83Md2fmxc15dh6hmbk4yVeTvDvJdUk+ODPX7Xo/F4ink3xmrfX6JDcm+dgxXustSe5vD3EefCXJj9dar0vyxhyzNc/MVUk+meSGtdYbklyc5APNmbZ4JfTmJA+stR5ca51N8r0k799gP3VrrcfXWvccfP1U9g/Yq7pT7d7MXJ3kPUlua8+ypZl5WZK3JflGkqy1zq61/lwdahsnklw6MyeSXJbk981htojQVUkee9b3Z3IMfzH/3czsJbk+yd3lUbbw5SSfTfJMeY6tvTrJE0m
+dXDqedvMXN4eapfWWr9L8sUkjyZ5PMlf1lo/ac60RYTmOW471u8DmJmXJPl+kk+ttf7anmeXZua9Sf641jrdnuU8OJHkTUm+tta6Psnfkhyra5ozc0X2z0yuTfLKJJfPzIeaM20RoTNJrnnW91en/HJvSzPzouwH6Pa11p3teTZwU5L3zczD2T+1fvvMfKc70mbOJDmz1vrXq9k7sh+l4+QdSR5aaz2x1vpHkjuTvLU50BYR+nmS18zMtTNzSfYvev1gg/3Uzcxk//rB/WutL7Xn2cJa63NrravXWnvZfy5/utaq/uXcylrrD0kem5nXHtx0c5JfF0fawqNJbpyZyw6O35tTvvh+YtcPuNZ6emY+nuSu7F95/+Za675d7+cCcVOSDyf55cz84uC2z6+1ftQbiSP6RJLbD/6APpjko+V5dmqtdffM3JHknuz/d/felD++4WMbQJV3TANVIgRUiRBQJUJAlQgBVSIEVIkQUPVPXXwoIjol4D4AAAAASUVORK5CYII=\n", 312 | "text/plain": [ 313 | "
" 314 | ] 315 | }, 316 | "metadata": { 317 | "needs_background": "light" 318 | }, 319 | "output_type": "display_data" 320 | }, 321 | { 322 | "data": { 323 | "image/png": "\n", 324 | "text/plain": [ 325 | "
" 326 | ] 327 | }, 328 | "metadata": { 329 | "needs_background": "light" 330 | }, 331 | "output_type": "display_data" 332 | } 333 | ], 334 | "source": [ 335 | "seed = 8\n", 336 | "np.random.seed(seed)\n", 337 | "\n", 338 | "epsilon_Si = 13.491\n", 339 | "epsilon_SiO2 = 2.085136\n", 340 | "n_freqs = 200\n", 341 | "freqs = np.linspace(0.15, 0.25, n_freqs)\n", 342 | "n_grating_layers = 15\n", 343 | "\n", 344 | "D = np.random.random_sample((1, n_grating_layers))\n", 345 | "gr = Grating(epsilon_Si, epsilon_SiO2, D[0])\n", 346 | "\n", 347 | "# plot grating design and save file\n", 348 | "gr.plot()\n", 349 | "# plt.savefig('grating_seed{}.png'.format(seed), dpi=200)\n", 350 | "\n", 351 | "# plot target and prediction responses and save file\n", 352 | "fig, ax = plt.subplots(figsize=(8,6))\n", 353 | "line, = ax.plot(freqs, gr.transmittivity(freqs), label='target')\n", 354 | "ax.plot(freqs, fmodel400(D)[0].numpy(), '--', label='prediction (400 epochs)') #color=line.get_color()\n", 355 | "ax.plot(freqs, fmodel1000(D)[0].numpy(), '--', label='prediction (1000 epochs)')\n", 356 | "# ax.plot(freqs, fmodel2500(D)[0].numpy(), '--', label='prediction (2500 epochs)')\n", 357 | "ax.plot(freqs, fmodel4000(D)[0].numpy(), '--', label='prediction (4000 epochs)')\n", 358 | "ax.set_xlabel('Frequency [c/a]')\n", 359 | "ax.set_ylabel('Transmission')\n", 360 | "ax.legend()\n", 361 | "# plt.savefig('forward_model_seed{}.png'.format(seed), dpi=200)" 362 | ] 363 | } 364 | ], 365 | "metadata": { 366 | "kernelspec": { 367 | "display_name": "Python 3", 368 | "language": "python", 369 | "name": "python3" 370 | }, 371 | "language_info": { 372 | "codemirror_mode": { 373 | "name": "ipython", 374 | "version": 3 375 | }, 376 | "file_extension": ".py", 377 | "mimetype": "text/x-python", 378 | "name": "python", 379 | "nbconvert_exporter": "python", 380 | "pygments_lexer": "ipython3", 381 | "version": "3.8.5" 382 | } 383 | }, 384 | "nbformat": 4, 385 | "nbformat_minor": 5 386 | } 387 | 
-------------------------------------------------------------------------------- /forward_model/Arch4_Epochs1000_Adam0001_Sigmoid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/forward_model/Arch4_Epochs1000_Adam0001_Sigmoid.h5 -------------------------------------------------------------------------------- /forward_model/Arch4_Epochs2500_Adam0001_Sigmoid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/forward_model/Arch4_Epochs2500_Adam0001_Sigmoid.h5 -------------------------------------------------------------------------------- /forward_model/Arch4_Epochs4000_Adam0001_Sigmoid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/forward_model/Arch4_Epochs4000_Adam0001_Sigmoid.h5 -------------------------------------------------------------------------------- /forward_model/Arch4_Epochs400_Adam0001_Sigmoid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/forward_model/Arch4_Epochs400_Adam0001_Sigmoid.h5 -------------------------------------------------------------------------------- /forward_model/Arch4_Epochs6000_Adam0001_Sigmoid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/forward_model/Arch4_Epochs6000_Adam0001_Sigmoid.h5 -------------------------------------------------------------------------------- /grating.py: -------------------------------------------------------------------------------- 1 | import numpy 
class Layer(object):
    """One homogeneous slab of a grating.

    Plain data holder: ``eps`` is the permittivity of the slab and ``width``
    its thickness, both stored exactly as passed in.
    """

    def __init__(self, eps, width):
        self.eps = eps
        self.width = width


class Grating(object):
    """Multilayer stack whose permittivity alternates between two values.

    Layer ``i`` (counting from zero) has permittivity ``eps1`` when ``i`` is
    even and ``eps2`` when ``i`` is odd; its thickness is ``widths[i]``.

    The numerical work is delegated to ``propagator_layer``,
    ``propagator_grating`` and ``operator_t``, which this module star-imports
    from ``cy_funcs``.
    """

    # (row, column) of the 2x2 transmission operator selected by each
    # polarization label accepted by ``transmittivity``.
    _POL_INDEX = {'x': (0, 0), 'y': (1, 1), 'xy': (0, 1), 'yx': (1, 0)}

    def __init__(self, eps1, eps2, widths):
        """Store the two permittivities and build one ``Layer`` per width."""
        self.eps1 = eps1
        self.eps2 = eps2
        self.widths = widths
        self.layers = [Layer(self._eps_of(i), d) for i, d in enumerate(widths)]

    def _eps_of(self, index):
        """Return the permittivity of the layer at position ``index``."""
        return self.eps1 if index % 2 == 0 else self.eps2

    def props_layers(self, f, b=0):
        """Yield the propagator of every layer, first layer first.

        ``f`` and ``b`` are forwarded unchanged to ``propagator_layer``.
        """
        for i, d in enumerate(self.widths):
            yield propagator_layer(f, self._eps_of(i), d, b)

    def propagator(self, f, b=0):
        """Return the propagator of the whole stack at frequency ``f``."""
        return propagator_grating(f, self.eps1, self.eps2, self.widths, b)

    def transmittivity(self, f, b=0., pol='x'):
        """Return ``|t|**2`` for one element of the transmission operator.

        Args:
            f: Frequencies to evaluate. The result of ``self.propagator`` is
                iterated, so this is presumably an array of frequencies
                yielding one propagator each -- NOTE(review): confirm against
                ``cy_funcs.propagator_grating``.
            b: Forwarded to the propagator and to ``operator_t``.
            pol: Which element of the 2x2 transmission operator to use:
                ``'x'`` -> [0, 0], ``'y'`` -> [1, 1], ``'xy'`` -> [0, 1],
                ``'yx'`` -> [1, 0].

        Returns:
            Array with one squared modulus per propagator.

        Raises:
            ValueError: If ``pol`` is not one of the four labels above.
                (Previously the method fell through and returned ``None``.)
        """
        # Validate first so that a typo fails fast instead of after the
        # (comparatively expensive) propagator computation.
        try:
            row, col = self._POL_INDEX[pol]
        except KeyError:
            raise ValueError(
                "pol must be one of {0}, got {1!r}".format(
                    sorted(self._POL_INDEX), pol)
            ) from None
        propagators = self.propagator(f, b)
        # The stack is evaluated with permittivity 1 on both sides.
        op_t = np.array([operator_t(p, 1., 1., b) for p in propagators])
        return np.abs(op_t[:, row, col]) ** 2

    def plot(self, ax=None, colors=('burlywood', 'lightblue')):
        """Draw the layers as adjacent vertical bands and return the axes.

        Args:
            ax: Matplotlib axes to draw on; a new figure is created when
                omitted.
            colors: Two colours, used for even and odd layers respectively.
        """
        if ax is None:
            _, ax = plt.subplots()
        left = 0
        for i, layer in enumerate(self.layers):
            ax.axvspan(left, left + layer.width, color=colors[i % 2])
            left += layer.width
        ax.set_ylim(0, len(self.layers) / 2)
        ax.set_yticks([])
        ax.set_aspect('equal')
        return ax
https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/inverse_model/Epochs400_Adam0001_Sigmoid.h5 -------------------------------------------------------------------------------- /inverse_model/InverseNet15EpochsTandem.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/inverse_model/InverseNet15EpochsTandem.h5 -------------------------------------------------------------------------------- /inverse_model/TandemNN_Epochs400_Adam0001_Sigmoid.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FiodarM/InvDesignNet/c07ae0ae89fd2a3f527227984bf0b3a14e0c5d83/inverse_model/TandemNN_Epochs400_Adam0001_Sigmoid.h5 -------------------------------------------------------------------------------- /produce_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "vertical-portsmouth", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%matplotlib inline\n", 11 | "import numpy as np\n", 12 | "import scipy.linalg as spla\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "from tqdm.notebook import tqdm\n", 15 | "import os\n", 16 | "from grating import Grating\n", 17 | "import gc" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "united-prerequisite", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "epsilon_Si = 13.491\n", 28 | "epsilon_SiO2 = 2.085136\n", 29 | "n_freqs = 200\n", 30 | "freqs = np.linspace(0.15, 0.25, n_freqs)\n", 31 | "n_grating_layers = 15\n", 32 | "\n", 33 | "np.random.seed(42)\n", 34 | "D = np.random.random_sample(n_grating_layers)\n", 35 | "gr = Grating(epsilon_Si, epsilon_SiO2, D)\n", 36 | "freqs = np.linspace(0.15, 0.25, n_freqs)\n", 37 
| "ts = gr.transmittivity(freqs)\n", 38 | "plt.plot(freqs, ts)\n", 39 | "plt.ylim(0, 1)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "shared-framework", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# Benchmarking\n", 50 | "%timeit gr.propagator(0.2)\n", 51 | "%timeit gr.propagator(freqs)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "fossil-witness", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "batch_size = 1000\n", 62 | "n_grating_layers = 15\n", 63 | "\n", 64 | "a = 1\n", 65 | "freqs = np.linspace(0.15 / a, 0.25 / a, n_freqs)\n", 66 | "\n", 67 | "\n", 68 | "from produce_data import save_samples\n", 69 | "fname = 'dataset.npz'\n", 70 | "samples = dict(D=[], R=[])\n", 71 | "try:\n", 72 | " i = 0\n", 73 | " with tqdm(total=batch_size, leave=False) as pbar:\n", 74 | " while True:\n", 75 | " D = np.random.random_sample(n_grating_layers)\n", 76 | " gr = Grating(epsilon_Si, epsilon_SiO2, D)\n", 77 | " R = gr.transmittivity(freqs)\n", 78 | " i += 1\n", 79 | " pbar.update(1)\n", 80 | " samples['D'].append(D)\n", 81 | " samples['R'].append(R)\n", 82 | " if i == batch_size:\n", 83 | " save_samples(fname, samples)\n", 84 | " samples = dict(D=[], R=[])\n", 85 | " gc.collect()\n", 86 | " pbar.reset()\n", 87 | " i = 0\n", 88 | "except KeyboardInterrupt:\n", 89 | " save_samples(fname, samples)" 90 | ] 91 | } 92 | ], 93 | "metadata": { 94 | "kernelspec": { 95 | "display_name": "Python 3", 96 | "language": "python", 97 | "name": "python3" 98 | }, 99 | "language_info": { 100 | "codemirror_mode": { 101 | "name": "ipython", 102 | "version": 3 103 | }, 104 | "file_extension": ".py", 105 | "mimetype": "text/x-python", 106 | "name": "python", 107 | "nbconvert_exporter": "python", 108 | "pygments_lexer": "ipython3", 109 | "version": "3.7.9" 110 | } 111 | }, 112 | "nbformat": 4, 113 | "nbformat_minor": 5 114 | } 115 | 
# Dataset generation parameters.
batch_size = 1000          # samples accumulated in memory between saves
n_grating_layers = 15      # layers per random grating
n_freqs = 200              # frequency points per spectrum
a = 1.                     # scale dividing both ends of the frequency range
freqs = np.linspace(0.15 / a, 0.25 / a, n_freqs)
epsilon_Si = 13.491
epsilon_SiO2 = 2.085136


def save_samples(fname, samples):
    """Append ``samples`` to the ``.npz`` dataset stored at ``fname``.

    Args:
        fname: Path of the dataset. As with ``np.savez``, ``.npz`` is
            appended when the name does not already end with it; the same
            resolved path is used to look for existing data.
        samples: Mapping from array name to a list of equally long 1-D
            samples. It must contain the key ``'R'``, which is used to count
            samples. The mapping is not modified.

    Every key of ``samples`` is stacked below the rows already stored under
    the same key. The merged dataset is written to a temporary file next to
    the target and then moved over it, so an interrupted write cannot
    destroy previously saved data. An empty batch is a no-op: stacking it
    onto existing rows would raise, and writing it to a new file would
    create arrays that later batches cannot be stacked onto.
    """
    n_samples = len(samples['R'])
    if n_samples == 0:
        print("No new samples to save to {0}.".format(fname))
        return

    # Mirror np.savez's naming rule so the existence check below looks at
    # the file that is actually written.
    path = os.fspath(fname)
    if not path.endswith('.npz'):
        path += '.npz'

    merged = {key: np.asarray(rows) for key, rows in samples.items()}
    if os.path.exists(path):
        with np.load(path) as existing:
            merged = {key: np.vstack((existing[key], rows))
                      for key, rows in merged.items()}

    tmp_path = path + '.tmp.npz'
    try:
        np.savez(tmp_path, **merged)
        os.replace(tmp_path, path)
    except KeyboardInterrupt:
        # Best effort, as before: the interrupt is not propagated, but the
        # half-written temporary file is discarded and the user is told the
        # batch was not stored.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        print("Interrupted while writing {0}; previous contents were kept."
              .format(fname))
        return
    print("Saved {0} samples to {1}. The dataset contains {2} samples."
          .format(n_samples, fname, len(merged['R'])))


fname = 'dataset.npz'


if __name__ == '__main__':
    # Generate random gratings forever, flushing to disk every
    # ``batch_size`` samples, until the user presses Ctrl-C.
    samples = dict(D=[], R=[])
    try:
        i = 0
        with tqdm(total=batch_size, leave=False) as pbar:
            while True:
                D = np.random.random_sample(n_grating_layers)
                gr = Grating(epsilon_Si, epsilon_SiO2, D)
                R = gr.transmittivity(freqs)
                i += 1
                pbar.update(1)
                samples['D'].append(D)
                samples['R'].append(R)
                if i == batch_size:
                    save_samples(fname, samples)
                    samples = dict(D=[], R=[])
                    gc.collect()
                    pbar.reset()
                    i = 0
    except KeyboardInterrupt:
        print('Interrupting calculation...')
        answer = input('Do you want to save calculated data? y/[n]: ')
        if answer.strip().lower() == 'y':
            save_samples(fname, samples)
    finally:
        # No exit(0) here: raising SystemExit from ``finally`` would replace
        # any in-flight exception and report success for a crashed run.
        print('Exiting script')
# Basis vectors, their projectors and identity matrices shared by the
# functions below.
ex = np.array([1, 0])
exex = np.outer(ex, ex)
ey = np.array([0, 1])
eyey = np.outer(ey, ey)
I = np.eye(2)
I4 = np.eye(4, dtype=np.complex128)


def eigvalsM(eps, b=0.):
    """Return ``[-n, -n, n, n]`` with ``n = sqrt(eps - b**2)``.

    These are used by ``propagator_layer`` as the exponents of the diagonal
    factor. NOTE(review): ``b`` is presumably the tangential wave-vector
    component in units of the vacuum wavenumber -- confirm.
    """
    n = np.sqrt(eps - b ** 2)
    return np.array([-n, -n, n, n])


def eigvecsM(eps, b=0.):
    """Return a ``(2, 4, 4)`` array holding a matrix and its inverse.

    Element ``[0]`` is the column matrix ``V`` that ``propagator_layer``
    places to the left of the diagonal factor and element ``[1]`` is the
    matrix it places to the right; the two are mutual inverses.
    """
    sqrtepsb = np.sqrt(eps - b ** 2)
    c1 = sqrtepsb
    c2 = eps * sqrtepsb / (b ** 2 - eps)
    c3 = sqrtepsb / (2 * eps)
    c4 = 0.5 / sqrtepsb
    return np.array([[[0, -c1, 0, c1],
                      [c2, 0, -c2, 0],
                      [0, 1, 0, 1],
                      [1, 0, 1, 0]],
                     [[0, -c3, 0, 0.5],
                      [-c4, 0, 0.5, 0],
                      [0, c3, 0, 0.5],
                      [c4, 0, 0.5, 0]]])


def propagator_layer(f, eps, d, b=0):
    """Propagator, or evolution operator, or transfer matrix of
    a uniform layer with permittivity eps and thickness d
    at frequency f.

    Args:
        f: A scalar frequency or a 1-D sequence of frequencies.
        eps: Permittivity of the layer.
        d: Thickness of the layer.
        b: Forwarded to ``eigvalsM`` / ``eigvecsM``.

    Returns:
        A complex ``(4, 4)`` array for scalar ``f``, otherwise a complex
        ``(N, 4, 4)`` array with one matrix per frequency.
    """
    scalar_input = np.isscalar(f)
    freqs = np.atleast_1d(f)
    w = eigvalsM(eps, b)
    vr = eigvecsM(eps, b)
    # Phase factors exp(i * k0 * d * w) for every frequency: shape (N, 4).
    ik0d = 2j * np.pi * freqs * d
    phases = np.exp(ik0d[:, np.newaxis] * w)
    # V @ diag(p) only rescales the columns of V, so the product
    # V @ diag(p) @ V^-1 is evaluated for all frequencies at once.
    res = (vr[0] * phases[:, np.newaxis, :]) @ vr[1]
    # np.complex128 replaces the np.complex_ alias removed in NumPy 2.0.
    res = res.astype(np.complex128, copy=False)
    if scalar_input:
        return res[0]
    return res


def propagator_grating(f, eps1, eps2, D, b=0):
    """Propagator of a stack of layers with alternating permittivity.

    Layer ``n`` has thickness ``D[n]`` and permittivity ``eps1`` for even
    ``n``, ``eps2`` for odd ``n``. Each further layer's propagator is
    multiplied from the left onto the running product.

    Args:
        f: A scalar frequency or a 1-D sequence of frequencies.
        eps1: Permittivity of the even-numbered layers.
        eps2: Permittivity of the odd-numbered layers.
        D: Layer thicknesses; any sequence accepted by ``np.asarray``.
        b: Forwarded to ``propagator_layer``.

    Returns:
        A ``(4, 4)`` array for scalar ``f``, otherwise ``(N, 4, 4)``. An
        empty ``D`` gives the identity.
    """
    widths = np.atleast_1d(np.asarray(D))
    if widths.size == 0:
        if np.isscalar(f):
            return I4.copy()
        return np.tile(I4, (np.atleast_1d(f).shape[0], 1, 1))
    propagator = propagator_layer(f, eps1, widths[0], b)
    for n, d in enumerate(widths[1:], start=1):
        eps = eps1 if n % 2 == 0 else eps2
        # ``@`` is an ordinary product for (4, 4) operands and a
        # per-frequency product for (N, 4, 4) operands.
        propagator = propagator_layer(f, eps, d, b) @ propagator
    return propagator


def gamma(eps, b=0):
    """Return ``diag(1 / n, n / eps)`` with ``n = sqrt(eps - b**2)``."""
    root = np.sqrt(eps - b ** 2)
    return np.diag([1 / root, root / eps])


def operator_r(propagator, eps_left=1, eps_right=None, b=0):
    """Reflection operator of a multilayer characterized by propagator
    surrounded by media with eps_left at left and eps_right at right.

    Args:
        propagator: ``(4, 4)`` propagator of the multilayer.
        eps_left: Permittivity of the medium on the left.
        eps_right: Permittivity of the medium on the right; ``None`` means
            "same as ``eps_left``".
        b: Forwarded to ``gamma``.

    Returns:
        A ``(2, 2)`` ndarray (plain array, not ``np.matrix``).
    """
    if eps_right is None:
        eps_right = eps_left
    gamma_left = gamma(eps_left, b)
    gamma_right = gamma(eps_right, b)
    # Common left factor [gamma_right, -I] @ propagator, shape (2, 4).
    left = np.hstack([gamma_right, -I]) @ np.asarray(propagator)
    # I and gamma are diagonal, so stacking them vertically equals the
    # transpose of the horizontal block row used previously.
    factor1 = left @ np.vstack([I, -gamma_left])
    factor2 = left @ np.vstack([I, gamma_left])
    return spla.inv(factor1) @ factor2


def operator_t(propagator, eps_left=1, eps_right=None, b=0):
    """Transmission operator of a multilayer characterized by propagator
    surrounded by media with eps_left at left and eps_right at right.

    Args:
        propagator: ``(4, 4)`` propagator of the multilayer.
        eps_left: Permittivity of the medium on the left.
        eps_right: Permittivity of the medium on the right; ``None`` means
            "same as ``eps_left``".
        b: Forwarded to ``gamma``.

    Returns:
        A ``(2, 2)`` ndarray.
    """
    if eps_right is None:
        eps_right = eps_left
    gamma_left = gamma(eps_left, b)
    gamma_right = gamma(eps_right, b)
    core = (np.hstack([gamma_left, I])
            @ spla.inv(propagator)
            @ np.vstack([I, gamma_right]))
    return 2 * (spla.inv(core) @ gamma_left)