├── optogpt ├── core │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── sim.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── datasets.cpython-38.pyc │ │ │ └── datasets_ddp.cpython-38.pyc │ │ ├── sim.py │ │ └── datasets.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── transformer.cpython-38.pyc │ │ └── transformer.py │ ├── trains │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── train.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ │ └── train.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── common.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ │ ├── cie-cmf.txt │ │ ├── color_system.py │ │ └── common.py │ └── __pycache__ │ │ └── __init__.cpython-38.pyc ├── figures │ ├── optogpt.png │ ├── color2spec.png │ ├── embedding.png │ ├── ol_transformer.jpg │ └── self_improving.png ├── visualization │ └── fonts │ │ └── webfonts.tar.gz ├── nk │ ├── MgF2_HZ.csv │ ├── Ni_HZ.csv │ ├── Ag_HZ.csv │ ├── Cr.csv │ ├── Cr_HZ.csv │ ├── HfO2_HZ.csv │ ├── Au.csv │ ├── Al_HZ.csv │ ├── SiO2_HZ.csv │ └── Si_HZ.csv ├── environment.yml ├── run_ol_transformer.py ├── run_optogpt.py ├── README.md └── data_conversion.ipynb ├── self_improving ├── requirements.txt ├── README.md ├── combine_data.py ├── model_retrain.py ├── run_self_improving.py ├── generate_dev_data.py ├── prepare_aug_data.py └── data_perturb.py └── README.md /optogpt/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optogpt/core/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optogpt/core/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optogpt/core/trains/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optogpt/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /optogpt/figures/optogpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/figures/optogpt.png -------------------------------------------------------------------------------- /optogpt/figures/color2spec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/figures/color2spec.png -------------------------------------------------------------------------------- /optogpt/figures/embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/figures/embedding.png -------------------------------------------------------------------------------- /optogpt/figures/ol_transformer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/figures/ol_transformer.jpg 
-------------------------------------------------------------------------------- /optogpt/figures/self_improving.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/figures/self_improving.png -------------------------------------------------------------------------------- /optogpt/visualization/fonts/webfonts.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/visualization/fonts/webfonts.tar.gz -------------------------------------------------------------------------------- /optogpt/core/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/datasets/__pycache__/sim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/datasets/__pycache__/sim.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/trains/__pycache__/train.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/trains/__pycache__/train.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/utils/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/utils/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/trains/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/trains/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/datasets/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/datasets/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/models/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/models/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/core/datasets/__pycache__/datasets_ddp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taigaoma1997/optogpt/HEAD/optogpt/core/datasets/__pycache__/datasets_ddp.cpython-38.pyc -------------------------------------------------------------------------------- /optogpt/nk/MgF2_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.2,1.42088,0.0 3 | 0.25,1.40849,0.0 4 | 0.32299999999999995,1.3984,0.0 5 | 0.349,1.39607,0.0 6 | 0.41100000000000003,1.39,0.0 7 | 0.5,1.385,0.0 8 | 0.604,1.38,0.0 9 | 0.715,1.376,0.0 10 | 1.55,1.37,0.0 11 | 2.2,1.37,0.0 12 | 15,1.37,0.0 -------------------------------------------------------------------------------- /self_improving/requirements.txt: -------------------------------------------------------------------------------- 1 | # Core dependencies 2 | torch>=1.8.0 3 | numpy>=1.19.0 4 | pandas>=1.2.0 5 | scipy>=1.6.0 6 | matplotlib>=3.3.0 7 | seaborn>=0.11.0 8 | 9 | # Optimization 10 | pyswarms>=1.3.0 11 | 12 | # NLP (for tokenization) 13 | nltk>=3.6 14 | 15 | # Utilities 16 | pickle5 # For Python 3.7 compatibility 17 | tqdm>=4.60.0 18 | 19 | # Optional but recommended 20 | scikit-learn>=0.24.0 21 | jupyterlab>=3.0.0 # For running notebooks -------------------------------------------------------------------------------- /optogpt/nk/Ni_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.3,1.63,2.11 3 | 0.355,1.63,2.11 4 | 0.365,1.62,2.17 5 | 0.376,1.61,2.23 6 | 0.388,1.61,2.3 7 | 0.401,1.61,2.36 8 | 0.414,1.61,2.44 9 | 0.428,1.62,2.52 10 | 0.444,1.63,2.61 11 | 0.46,1.64,2.71 12 | 0.478,1.65,2.81 13 | 0.497,1.67,2.93 14 | 0.518,1.71,3.06 15 | 0.54,1.75,3.19 16 | 0.565,1.8,3.33 17 | 0.592,1.85,3.48 18 | 0.621,1.92,3.65 19 | 0.654,2.02,3.82 20 | 0.69,2.14,4.01 21 | 0.731,2.28,4.18 22 | 0.776,2.43,4.31 23 | 0.828,2.53,4.47 24 | 0.887,2.65,4.63 25 | 0.956,2.74,4.85 26 | 1.04,2.85,5.1 27 | 1.13,2.97,5.38 28 | 1.24,3.06,5.74 29 | 1.38,3.18,6.23 30 | 1.55,3.38,6.82 31 | 1.77,3.59,7.48 32 | 2.07,3.84,8.35 33 | 2.48,4.03,9.64 34 | 3.11,3.84,11.4 35 | 3.55,4.03,13.1 36 | 4.14,4.19,15.1 37 | 4.97,4.25,17.7 38 | 6.21,4.12,22.5 39 | 8.28,5.45,30.6 40 | 12.4,9.54,45.8 41 | -------------------------------------------------------------------------------- /optogpt/nk/Ag_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.1879,1.07,1.212 3 | 0.1916,1.1,1.232 4 | 0.1953,1.12,1.255 5 | 0.1993,1.14,1.277 6 | 0.2033,1.15,1.296 7 | 0.2073,1.18,1.312 8 | 0.2119,1.2,1.325 9 | 0.2164,1.22,1.336 10 | 0.2214,1.25,1.342 11 | 0.2262,1.26,1.344 12 | 0.2313,1.28,1.357 13 | 0.2371,1.28,1.367 14 | 0.2426,1.3,1.378 15 | 0.249,1.31,1.389 16 | 0.2551,1.33,1.393 17 | 0.2616,1.35,1.387 18 | 0.2689,1.38,1.372 19 | 0.2761,1.41,1.331 20 | 0.2844,1.41,1.264 21 | 0.2924,1.39,1.161 22 | 0.3009,1.34,0.964 23 | 0.3107,1.13,0.616 24 | 
0.3204,0.81,0.392 25 | 0.3315,0.17,0.829 26 | 0.3425,0.14,1.142 27 | 0.3542,0.1,1.419 28 | 0.3679,0.07,1.657 29 | 0.3815,0.05,1.864 30 | 0.3974,0.05,2.07 31 | 0.4133,0.05,2.275 32 | 0.4305,0.04,2.462 33 | 0.4509,0.04,2.657 34 | 0.4714,0.05,2.869 35 | 0.4959,0.05,3.093 36 | 0.5209,0.05,3.324 37 | 0.5486,0.06,3.586 38 | 0.5821,0.05,3.858 39 | 0.6168,0.06,4.152 40 | 0.6595,0.05,4.483 41 | 0.7045,0.04,4.838 42 | 0.756,0.03,5.242 43 | 0.8211,0.04,5.727 44 | 0.892,0.04,6.312 45 | 0.984,0.04,6.992 46 | 1.088,0.04,7.795 47 | 1.216,0.09,8.828 48 | 1.393,0.13,10.1 49 | 1.61,0.15,11.85 50 | 1.937,0.24,14.08 51 | 15,0.24,14.08 -------------------------------------------------------------------------------- /optogpt/nk/Cr.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.2938,0.94,2.58 3 | 0.3002,0.98,2.67 4 | 0.3077,1.02,2.76 5 | 0.3163,1.06,2.85 6 | 0.3237,1.12,2.95 7 | 0.3333,1.18,3.04 8 | 0.3416,1.26,3.12 9 | 0.3512,1.33,3.18 10 | 0.3625,1.39,3.24 11 | 0.3723,1.43,3.31 12 | 0.385,1.44,3.4 13 | 0.3961,1.48,3.54 14 | 0.4092,1.54,3.71 15 | 0.4246,1.65,3.89 16 | 0.4381,1.8,4.06 17 | 0.4558,1.99,4.22 18 | 0.4714,2.22,4.36 19 | 0.4901,2.49,4.44 20 | 0.5123,2.75,4.46 21 | 0.5321,2.98,4.45 22 | 0.5585,3.18,4.41 23 | 0.5821,3.34,4.38 24 | 0.6108,3.48,4.36 25 | 0.7005,3.84,4.37 26 | 0.8157,4.23,4.34 27 | 0.8266,4.27,4.33 28 | 0.8492,4.31,4.32 29 | 0.861,4.33,4.32 30 | 0.8856,4.38,4.31 31 | 0.9116,4.42,4.3 32 | 0.9393,4.47,4.29 33 | 0.9686,4.49,4.28 34 | 0.9999,4.5,4.28 35 | 1.033,4.52,4.29 36 | 1.069,4.53,4.3 37 | 1.107,4.53,4.31 38 | 1.148,4.53,4.34 39 | 1.192,4.51,4.36 40 | 1.24,4.47,4.43 41 | 1.2919999999999998,4.47,4.5 42 | 1.348,4.45,4.56 43 | 1.378,4.43,4.6 44 | 1.442,4.35,4.66 45 | 1.5119999999999998,4.24,4.81 46 | 1.59,4.13,5.03 47 | 1.675,4.06,5.3 48 | 1.771,4.01,5.59 49 | 1.879,3.96,5.95 50 | 2.0,4.01,6.31 51 | 2.138,4.01,6.65 52 | 2.296,3.92,7.06 53 | 2.48,3.83,7.58 54 | 2.695,3.72,8.2 55 | 2.952,3.47,8.97 56 | -------------------------------------------------------------------------------- /optogpt/nk/Cr_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.2938,0.94,2.58 3 | 0.3002,0.98,2.67 4 | 0.3077,1.02,2.76 5 | 0.3163,1.06,2.85 6 | 0.3237,1.12,2.95 7 | 0.3333,1.18,3.04 8 | 0.3416,1.26,3.12 9 | 0.3512,1.33,3.18 10 | 0.3625,1.39,3.24 11 | 0.3723,1.43,3.31 12 | 0.385,1.44,3.4 13 | 0.3961,1.48,3.54 14 | 0.4092,1.54,3.71 15 | 0.4246,1.65,3.89 16 | 0.4381,1.8,4.06 17 | 0.4558,1.99,4.22 18 | 0.4714,2.22,4.36 19 | 0.4901,2.49,4.44 20 | 0.5123,2.75,4.46 21 | 0.5321,2.98,4.45 22 | 0.5585,3.18,4.41 23 | 0.5821,3.34,4.38 24 | 0.6108,3.48,4.36 25 | 0.7005,3.84,4.37 26 | 0.8157,4.23,4.34 27 | 0.8266,4.27,4.33 28 | 0.8492,4.31,4.32 29 | 0.861,4.33,4.32 30 | 0.8856,4.38,4.31 31 | 0.9116,4.42,4.3 32 | 0.9393,4.47,4.29 33 | 0.9686,4.49,4.28 34 | 0.9999,4.5,4.28 35 | 1.033,4.52,4.29 36 | 1.069,4.53,4.3 37 | 1.107,4.53,4.31 38 | 1.148,4.53,4.34 39 | 1.192,4.51,4.36 40 | 1.24,4.47,4.43 41 | 1.2919999999999998,4.47,4.5 42 | 1.348,4.45,4.56 43 | 1.378,4.43,4.6 44 | 1.442,4.35,4.66 45 | 1.5119999999999998,4.24,4.81 46 | 1.59,4.13,5.03 47 | 1.675,4.06,5.3 48 | 1.771,4.01,5.59 49 | 1.879,3.96,5.95 50 | 2.0,4.01,6.31 51 | 2.138,4.01,6.65 52 | 2.296,3.92,7.06 53 | 2.48,3.83,7.58 54 | 2.695,3.72,8.2 55 | 2.952,3.47,8.97 56 | -------------------------------------------------------------------------------- /optogpt/core/utils/cie-cmf.txt: -------------------------------------------------------------------------------- 
1 | 400 0.0143 0.0004 0.0679 2 | 405 0.0232 0.0006 0.1102 3 | 410 0.0435 0.0012 0.2074 4 | 415 0.0776 0.0022 0.3713 5 | 420 0.1344 0.0040 0.6456 6 | 425 0.2148 0.0073 1.0391 7 | 430 0.2839 0.0116 1.3856 8 | 435 0.3285 0.0168 1.6230 9 | 440 0.3483 0.0230 1.7471 10 | 445 0.3481 0.0298 1.7826 11 | 450 0.3362 0.0380 1.7721 12 | 455 0.3187 0.0480 1.7441 13 | 460 0.2908 0.0600 1.6692 14 | 465 0.2511 0.0739 1.5281 15 | 470 0.1954 0.0910 1.2876 16 | 475 0.1421 0.1126 1.0419 17 | 480 0.0956 0.1390 0.8130 18 | 485 0.0580 0.1693 0.6162 19 | 490 0.0320 0.2080 0.4652 20 | 495 0.0147 0.2586 0.3533 21 | 500 0.0049 0.3230 0.2720 22 | 505 0.0024 0.4073 0.2123 23 | 510 0.0093 0.5030 0.1582 24 | 515 0.0291 0.6082 0.1117 25 | 520 0.0633 0.7100 0.0782 26 | 525 0.1096 0.7932 0.0573 27 | 530 0.1655 0.8620 0.0422 28 | 535 0.2257 0.9149 0.0298 29 | 540 0.2904 0.9540 0.0203 30 | 545 0.3597 0.9803 0.0134 31 | 550 0.4334 0.9950 0.0087 32 | 555 0.5121 1.0000 0.0057 33 | 560 0.5945 0.9950 0.0039 34 | 565 0.6784 0.9786 0.0027 35 | 570 0.7621 0.9520 0.0021 36 | 575 0.8425 0.9154 0.0018 37 | 580 0.9163 0.8700 0.0017 38 | 585 0.9786 0.8163 0.0014 39 | 590 1.0263 0.7570 0.0011 40 | 595 1.0567 0.6949 0.0010 41 | 600 1.0622 0.6310 0.0008 42 | 605 1.0456 0.5668 0.0006 43 | 610 1.0026 0.5030 0.0003 44 | 615 0.9384 0.4412 0.0002 45 | 620 0.8544 0.3810 0.0002 46 | 625 0.7514 0.3210 0.0001 47 | 630 0.6424 0.2650 0.0000 48 | 635 0.5419 0.2170 0.0000 49 | 640 0.4479 0.1750 0.0000 50 | 645 0.3608 0.1382 0.0000 51 | 650 0.2835 0.1070 0.0000 52 | 655 0.2187 0.0816 0.0000 53 | 660 0.1649 0.0610 0.0000 54 | 665 0.1212 0.0446 0.0000 55 | 670 0.0874 0.0320 0.0000 56 | 675 0.0636 0.0232 0.0000 57 | 680 0.0468 0.0170 0.0000 58 | 685 0.0329 0.0119 0.0000 59 | 690 0.0227 0.0082 0.0000 60 | 695 0.0158 0.0057 0.0000 61 | 700 0.0114 0.0041 0.0000 62 | 705 0.0081 0.0029 0.0000 63 | 710 0.0058 0.0021 0.0000 64 | 715 0.0041 0.0015 0.0000 65 | 720 0.0029 0.0010 0.0000 66 | 725 0.0020 0.0007 0.0000 67 | 730 0.0014 0.0005 0.0000 68 | 735 0.0010 0.0004 0.0000 69 | 740 0.0007 0.0002 0.0000 70 | 745 0.0005 0.0002 0.0000 71 | 750 0.0003 0.0001 0.0000 72 | 755 0.0002 0.0001 0.0000 73 | 760 0.0002 0.0001 0.0000 74 | 765 0.0001 0.0000 0.0000 75 | 770 0.0001 0.0000 0.0000 76 | 775 0.0001 0.0000 0.0000 77 | 780 0.0000 0.0000 0.0000 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome to the OptoGPT Repo 2 | 3 | This is the complete repo for three research works: 4 | 5 | ## OptoGPT: A foundation model for multilayer thin film inverse design 6 | In this work, we introduce OptoGPT (Opto Generative Pretrained Transformer), a decoder-only transformer, to solve inverse design of multi-layer thin film structures. 7 | 8 | Check our paper at: https://www.oejournal.org/article/doi/10.29026/oea.2024.240062 9 | 10 | The code can be found in folder **/optogpt** 11 | 12 | ![Workflow](optogpt/figures/optogpt.png) 13 | ## OL-Transformer: A Fast and Universal Surrogate Simulator for Optical Multilayer Thin Film Structures 14 | In this work, we propose the Opto-Layer (OL) Transformer to act as a universal surrogate simulator for fast and efficient simulation of multilayer thin film structures. 
15 | 16 | Check our paper at: https://arxiv.org/abs/2305.11984 17 | 18 | 19 | The code can be found in folder **/optogpt** 20 | 21 | ![Workflow](optogpt/figures/ol_transformer.jpg) 22 | ## Solving Out-of-Distribution Challenges in Optical Foundation Models using Self-Improving Data Augmentation 23 | In this work, we propose a self-improving data augmentation technique by leveraging neural networks' extrapolation ability. Using this method, we show significant improvement in various real-applicative design tasks with minimum fine-tuning, which can also be potentially generalized to inverse scientific foundation models. 24 | 25 | Check our paper at: https://openreview.net/forum?id=8jqhElTmNP 26 | 27 | The code can be found in folder **/self_improving** 28 | ![Workflow](optogpt/figures/self_improving.png) 29 | ## Citation 30 | To cite this work: 31 | ~~~ 32 | @article{ma2024optogpt, 33 | title={OptoGPT: a foundation model for inverse design in optical multilayer thin film structures}, 34 | author={Ma, Taigao and Wang, Haozhu and Guo, L Jay}, 35 | journal={Opto-Electronic Advances}, 36 | volume={7}, 37 | number={7}, 38 | year={2024}, 39 | publisher={Opto-Electronic Advance} 40 | } 41 | 42 | @article{ma2023ol, 43 | title={OL-Transformer: A Fast and Universal Surrogate Simulator for Optical Multilayer Thin Film Structures}, 44 | author={Ma, Taigao and Wang, Haozhu and Guo, L Jay}, 45 | journal={arXiv preprint arXiv:2305.11984}, 46 | year={2023} 47 | } 48 | 49 | @inproceedings{ma2024solving, 50 | title={Solving Out-of-Distribution Challenges in Optical Foundation Models using Self-Improving Data Augmentation}, 51 | author={Ma, Mingqian and Ma, Taigao and Guo, L Jay}, 52 | booktitle={Neurips 2024 Workshop Foundation Models for Science: Progress, Opportunities, and Challenges} 53 | } 54 | ~~~ -------------------------------------------------------------------------------- /optogpt/nk/HfO2_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.3,2.18,0 3 | 0.365,2.184180485,0 4 | 0.4114,2.157402548,0 5 | 0.4577,2.139457506,0 6 | 0.5041,2.126701404,0 7 | 0.5504,2.117291268,0 8 | 0.5968,2.110085253,0 9 | 0.6431,2.104439338,0 10 | 0.6895,2.099892459,0 11 | 0.7358,2.096172295,0 12 | 0.7822,2.09306009,0 13 | 0.8285,2.090425506,0 14 | 0.8749,2.088152214,0 15 | 0.9212,2.086172286,0 16 | 0.9676,2.084418244,0 17 | 1.014,2.082849329,0 18 | 1.06,2.08144194,0 19 | 1.107,2.080127888,0 20 | 1.153,2.078941097,0 21 | 1.199,2.077835009,0 22 | 1.246,2.07677352,0 23 | 1.292,2.075789991,0 24 | 1.338,2.074851582,0 25 | 1.385,2.073931075,0 26 | 1.431,2.073060828,0 27 | 1.477,2.072215241,0 28 | 1.524,2.07137176,0 29 | 1.57,2.070562118,0 30 | 1.616,2.069764668,0 31 | 1.663,2.06895933,0 32 | 1.709,2.068177711,0 33 | 1.756,2.067383444,0 34 | 1.802,2.066608304,0 35 | 1.848,2.065833611,0 36 | 1.895,2.065040944,0 37 | 1.941,2.064262651,0 38 | 1.987,2.063480681,0 39 | 2.034,2.062676799,0 40 | 2.08,2.061884237,0 41 | 2.126,2.061085089,0 42 | 2.173,2.060260963,0 43 | 2.219,2.059446218,0 44 | 2.265,2.058622771,0 45 | 2.312,2.057771841,0 46 | 2.358,2.056929107,0 47 | 2.404,2.056076088,0 48 | 2.451,2.055193451,0 49 | 2.497,2.054318344,0 50 | 2.543,2.053431737,0 51 | 2.59,2.052513621,0 52 | 2.636,2.051602742,0 53 | 2.683,2.050659188,0 54 | 2.729,2.049722828,0 55 | 2.775,2.048773467,0 56 | 2.822,2.047789786,0 57 | 2.868,2.046813413,0 58 | 2.914,2.045823349,0 59 | 2.961,2.044797408,0 60 | 3.007,2.04377905,0 61 | 3.053,2.042746415,0 62 | 3.1,2.041676402,0 63 | 3.146,2.040614372,0 64 | 
3.192,2.039537555,0 65 | 3.239,2.038421893,0 66 | 3.285,2.037314709,0 67 | 3.331,2.036192283,0 68 | 3.378,2.035029566,0 69 | 3.424,2.033875896,0 70 | 3.47,2.032706568,0 71 | 3.517,2.031495511,0 72 | 3.563,2.030294129,0 73 | 3.61,2.029050057,0 74 | 3.656,2.027816108,0 75 | 3.702,2.026565865,0 76 | 3.749,2.025271489,0 77 | 3.795,2.023987939,0 78 | 3.841,2.022687736,0 79 | 3.888,2.021341947,0 80 | 3.934,2.020007722,0 81 | 3.98,2.018656493,0 82 | 4.027,2.017258209,0 83 | 4.073,2.015872255,0 84 | 4.119,2.014468952,0 85 | 4.166,2.013017106,0 86 | 4.212,2.011578379,0 87 | 4.258,2.01012196,0 88 | 4.305,2.008615484,0 89 | 4.351,2.007122942,0 90 | 4.397,2.005612358,0 91 | 4.444,2.004050181,0 92 | 4.49,2.002502768,0 93 | 4.537,2.000902717,0 94 | 4.583,1.999317996,0 95 | 4.629,1.997714642,0 96 | 4.676,1.996057062,0 97 | 4.722,1.994415671,0 98 | 4.768,1.992755281,0 99 | 4.815,1.991039047,0 100 | 4.861,1.989339873,0 101 | 4.907,1.987621326,0 102 | 4.954,1.985845279,0 103 | 5,1.984087178,0 104 | 8,1.984,0 105 | 15,1.984,0 -------------------------------------------------------------------------------- /optogpt/environment.yml: -------------------------------------------------------------------------------- 1 | name: optogpt 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=2_gnu 8 | - bzip2=1.0.8=h4bc722e_7 9 | - ca-certificates=2025.4.26=hbd8a1cb_0 10 | - ld_impl_linux-64=2.43=h712a8e2_4 11 | - libffi=3.4.6=h2dba641_1 12 | - libgcc=15.1.0=h767d61c_2 13 | - libgcc-ng=15.1.0=h69a702a_2 14 | - libgomp=15.1.0=h767d61c_2 15 | - liblzma=5.8.1=hb9d3cd8_1 16 | - liblzma-devel=5.8.1=hb9d3cd8_1 17 | - libnsl=2.0.1=hd590300_0 18 | - libsqlite=3.50.0=hee588c1_0 19 | - libuuid=2.38.1=h0b41bf4_0 20 | - libxcrypt=4.4.36=hd590300_1 21 | - libzlib=1.3.1=hb9d3cd8_2 22 | - ncurses=6.5=h2d0b736_3 23 | - openssl=3.5.0=h7b32b05_1 24 | - pip=24.3.1=pyh8b19718_0 25 | - python=3.8.19=hd12c33a_0_cpython 26 | - readline=8.2=h8c095d6_2 27 | - setuptools=75.3.0=pyhd8ed1ab_0 28 | - tk=8.6.13=noxft_hd72426e_102 29 | - wheel=0.45.1=pyhd8ed1ab_0 30 | - xz=5.8.1=hbcc6ac9_1 31 | - xz-gpl-tools=5.8.1=hbcc6ac9_1 32 | - xz-tools=5.8.1=hb9d3cd8_1 33 | - pip: 34 | - asttokens==3.0.0 35 | - backcall==0.2.0 36 | - certifi==2025.4.26 37 | - cmake==4.0.2 38 | - colour-science==0.4.1 39 | - comm==0.2.2 40 | - contourpy==1.1.1 41 | - cycler==0.12.1 42 | - debugpy==1.8.14 43 | - decorator==5.2.1 44 | - entrypoints==0.4 45 | - executing==2.2.0 46 | - filelock==3.16.1 47 | - fonttools==4.57.0 48 | - idna==3.10 49 | - imageio==2.35.1 50 | - importlib-resources==6.4.5 51 | - ipykernel==6.29.5 52 | - ipython==8.12.3 53 | - jedi==0.19.2 54 | - kiwisolver==1.4.7 55 | - lit==18.1.8 56 | - matplotlib==3.7.5 57 | - matplotlib-inline==0.1.7 58 | - mpmath==1.3.0 59 | - networkx==3.1 60 | - numpy==1.23.4 61 | - nvidia-cublas-cu11==11.10.3.66 62 | - nvidia-cuda-cupti-cu11==11.7.101 63 | - nvidia-cuda-nvrtc-cu11==11.7.99 64 | - nvidia-cuda-runtime-cu11==11.7.99 65 | - nvidia-cudnn-cu11==8.5.0.96 66 | - nvidia-cufft-cu11==10.9.0.58 67 | - nvidia-curand-cu11==10.2.10.91 68 | - nvidia-cusolver-cu11==11.4.0.1 69 | - nvidia-cusparse-cu11==11.7.4.91 70 | - nvidia-nccl-cu11==2.14.3 71 | - nvidia-nvtx-cu11==11.7.91 72 | - pandas==1.5.1 73 | - parso==0.8.4 74 | - pexpect==4.9.0 75 | - pickleshare==0.7.5 76 | - pillow==10.4.0 77 | - platformdirs==4.3.6 78 | - prompt-toolkit==3.0.51 79 | - psutil==7.0.0 80 | - pure-eval==0.2.3 81 | - pyparsing==3.1.4 82 | - pyyaml==6.0.2 83 | - scikit-learn==1.2.1 84 
| - scipy==1.10.1 85 | - seaborn==0.13.2 86 | - six==1.17.0 87 | - stack-data==0.6.3 88 | - sympy==1.13.3 89 | - threadpoolctl==3.5.0 90 | - tmm==0.1.8 91 | - torch==2.0.1 92 | - tqdm==4.67.1 93 | - triton==2.0.0 94 | - typing-extensions==4.13.2 95 | - tzdata==2025.2 96 | - urllib3==2.2.3 97 | - wcwidth==0.2.13 98 | prefix: /home/taigaom/anaconda3/envs/optogpt 99 | -------------------------------------------------------------------------------- /self_improving/README.md: -------------------------------------------------------------------------------- 1 | # Self-Improving Data Augmentation for OptoGPT 2 | 3 | This module implements self-improving data augmentation to solve out-of-distribution challenges in optical design foundation models. 4 | 5 | ## Overview 6 | 7 | The self-improving approach automatically generates better training data by: 8 | 1. Exploring diverse solutions with multiple decoding strategies 9 | 2. Applying intelligent perturbations to promising structures 10 | 3. Filtering perturbed structures that improve performance 11 | 4. Retraining the model with the augmented dataset 12 | 13 | ## Quick Start 14 | 15 | ```bash 16 | python run_self_improving.py \ 17 | --model_path path/to/pretrained_model.pt \ 18 | --train_struct_path path/to/train_struct.pkl \ 19 | --train_spec_path path/to/train_spec.pkl \ 20 | --dev_struct_path path/to/dev_struct.pkl \ 21 | --dev_spec_path path/to/dev_spec.pkl \ 22 | --output_dir ./output \ 23 | --decoding_method TOP-KP_Decode_v2 \ 24 | --perturbation_method GA_PSO \ 25 | --epochs 10 26 | ``` 27 | 28 | ## Module Components 29 | 30 | ### Core Scripts 31 | 32 | - **`run_self_improving.py`**: Main pipeline that orchestrates the complete workflow 33 | 34 | - **`prepare_aug_data.py`**: Prepares augmentation data using different decoding strategies 35 | - Greedy Decode: Deterministic decoding 36 | - TOP-KP Decode: Top-k/Top-p sampling for diversity 37 | - Beam Search: Multi-path exploration 38 | 39 | - **`data_perturb.py`**: Implements perturbation strategies 40 | - Random: Simple thickness perturbation 41 | - PSO: Particle Swarm Optimization for thickness 42 | - GA_PSO: Combined Genetic Algorithm and PSO approach 43 | 44 | - **`generate_dev_data.py`**: Generates out-of-distribution test spectra 45 | - Gaussian and Double Gaussian spectra 46 | - Smooth pulse functions 47 | - DBR (Distributed Bragg Reflector) structures 48 | 49 | - **`combine_data.py`**: Combines original and augmented data 50 | 51 | - **`model_retrain.py`**: Fine-tunes model on augmented dataset 52 | 53 | ## Material Categories 54 | 55 | Materials are categorized by refractive index for intelligent perturbation: 56 | - **High Index**: TiO2, ZnS, ZnSe, Ta2O5, HfO2 57 | - **Medium Index**: SiO2, Al2O3, MgF2, Si3N4 58 | - **Low Index**: MgO, ITO 59 | - **Metals**: Al, Ag 60 | - **Semiconductors**: Ge, Si 61 | 62 | ## Key Parameters 63 | 64 | - **Decoding Method**: `Greedy Decode`, `TOP-KP Decode`, `TOP-KP Decode_v2`, `Beam Search` 65 | - **Perturbation Method**: `random`, `PSO`, `GA_PSO` 66 | - **Error Type**: `MAE`, `MSE` 67 | 68 | ## Requirements 69 | 70 | - PyTorch >= 1.8.0 71 | - NumPy 72 | - Pandas 73 | - pyswarms (for PSO optimization) 74 | - scipy 75 | - multiprocessing support 76 | 77 | ## Notes 78 | 79 | - The module uses multiprocessing for efficient parallel computation 80 | - GPU is recommended for model inference and training 81 | - Ensure sufficient disk space for saving augmented datasets 82 | - The `nk` folder with material refractive index data must be accessible 83 | 84 | 
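## Example: Random Thickness Perturbation (illustrative)

The snippet below is a minimal, self-contained sketch of the "random" perturbation idea described under `data_perturb.py`: each layer thickness is nudged by a few steps on the 5 nm design grid. It is not the module's actual API; the token format (`Material_thickness`), the function name `perturb_structure`, and the example materials are assumptions made only for illustration. In the full pipeline, each perturbed structure is re-simulated and kept only if it improves the spectrum match (MAE or MSE) before being added to the augmented dataset.

```python
# Illustrative sketch only: hypothetical names, not the data_perturb.py API.
import random

THICK_GRID = list(range(5, 255, 5))  # allowed thickness tokens: 5-250 nm in 5 nm steps

def perturb_structure(structure, max_steps=2, seed=None):
    """Randomly shift each layer thickness by up to `max_steps` grid steps.

    `structure` is a list of tokens such as 'TiO2_100' (material_thickness in nm);
    this token format is assumed here for illustration.
    """
    rng = random.Random(seed)
    perturbed = []
    for token in structure:
        mat, thick = token.rsplit('_', 1)            # split material from thickness
        idx = THICK_GRID.index(int(thick))           # locate current grid position
        idx += rng.randint(-max_steps, max_steps)    # random shift on the grid
        idx = min(max(idx, 0), len(THICK_GRID) - 1)  # clamp to the valid range
        perturbed.append(f"{mat}_{THICK_GRID[idx]}")
    return perturbed

if __name__ == "__main__":
    print(perturb_structure(["TiO2_100", "SiO2_45", "MgF2_250"], seed=0))
```

The PSO and GA_PSO options swap this random shift for an optimizer-driven search over the same design space; in every case only perturbations that improve the spectrum match are kept for retraining.
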
## Citation 85 | 86 | If you use this code, please cite: 87 | ``` 88 | @inproceedings{ma2024solving, 89 | title={Solving Out-of-Distribution Challenges in Optical Foundation Models using Self-Improving Data Augmentation}, 90 | author={Ma, Mingqian and Ma, Taigao and Guo, L Jay}, 91 | booktitle={Neurips 2024 Workshop Foundation Models for Science: Progress, Opportunities, and Challenges} 92 | } 93 | ``` -------------------------------------------------------------------------------- /optogpt/core/utils/color_system.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def xyz_from_xy(x, y): 4 | """Return the vector (x, y, 1-x-y).""" 5 | return np.array((x, y, 1-x-y)) 6 | 7 | class ColourSystem: 8 | """A class representing a colour system. 9 | 10 | A colour system defined by the CIE x, y and z=1-x-y coordinates of 11 | its three primary illuminants and its "white point". 12 | 13 | TODO: Implement gamma correction 14 | 15 | """ 16 | 17 | # The CIE colour matching function for 400 - 780 nm in 5 nm intervals 18 | cmf = np.loadtxt('cie-cmf.txt', usecols=(1,2,3)) 19 | 20 | def __init__(self, red, green, blue, white): 21 | """Initialise the ColourSystem object. 22 | 23 | Pass vectors (ie NumPy arrays of shape (3,)) for each of the 24 | red, green, blue chromaticities and the white illuminant 25 | defining the colour system. 26 | 27 | """ 28 | 29 | # Chromaticities 30 | self.red, self.green, self.blue = red, green, blue 31 | self.white = white 32 | # The chromaticity matrix (rgb -> xyz) and its inverse 33 | self.M = np.vstack((self.red, self.green, self.blue)).T 34 | self.MI = np.linalg.inv(self.M) 35 | # White scaling array 36 | self.wscale = self.MI.dot(self.white) 37 | # xyz -> rgb transformation matrix 38 | self.T = self.MI / self.wscale[:, np.newaxis] 39 | 40 | def xyz_to_rgb(self, xyz, out_fmt=None): 41 | """Transform from xyz to rgb representation of colour. 42 | 43 | The output rgb components are normalized on their maximum 44 | value. If xyz is out the rgb gamut, it is desaturated until it 45 | comes into gamut. 46 | 47 | By default, fractional rgb components are returned; if 48 | out_fmt='html', the HTML hex string '#rrggbb' is returned. 49 | 50 | """ 51 | 52 | rgb = self.T.dot(xyz) 53 | # if np.any(rgb < 0): 54 | # # We're not in the RGB gamut: approximate by desaturating 55 | # w = - np.min(rgb) 56 | # rgb += w 57 | if not np.all(rgb==0): 58 | # Normalize the rgb vector 59 | rgb /= np.max(rgb) 60 | 61 | if out_fmt == 'html': 62 | return self.rgb_to_hex(rgb) 63 | return rgb 64 | 65 | def rgb_to_hex(self, rgb): 66 | """Convert from fractional rgb values to HTML-style hex string.""" 67 | 68 | hex_rgb = (255 * rgb).astype(int) 69 | return '#{:02x}{:02x}{:02x}'.format(*hex_rgb) 70 | 71 | def spec_to_xyz(self, spec): 72 | """Convert a spectrum to an xyz point. 73 | 74 | The spectrum must be on the same grid of points as the colour-matching 75 | function, self.cmf: 380-780 nm in 5 nm steps. 
76 | 77 | """ 78 | 79 | XYZ = np.sum(spec[:, np.newaxis] * self.cmf, axis=0) 80 | den = np.sum(XYZ) 81 | if den == 0.: 82 | return XYZ 83 | return XYZ / den 84 | 85 | def spec_to_rgb(self, spec, out_fmt=None): 86 | """Convert a spectrum to an rgb value.""" 87 | 88 | xyz = self.spec_to_xyz(spec) 89 | return self.xyz_to_rgb(xyz, out_fmt) 90 | 91 | illuminant_D65 = xyz_from_xy(0.3127, 0.3291) 92 | cs_srgb = ColourSystem(red=xyz_from_xy(0.64, 0.33), 93 | green=xyz_from_xy(0.30, 0.60), 94 | blue=xyz_from_xy(0.15, 0.06), 95 | white=illuminant_D65) -------------------------------------------------------------------------------- /optogpt/core/utils/common.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from collections import defaultdict 3 | import pickle 4 | import sys 5 | import torch 6 | import json 7 | from typing import Any 8 | from tabulate import tabulate 9 | import os 10 | import time 11 | import psutil 12 | 13 | 14 | def get_mem_info(pid: int) -> dict[str, int]: 15 | res = defaultdict(int) 16 | for mmap in psutil.Process(pid).memory_maps(): 17 | res['rss'] += mmap.rss 18 | res['pss'] += mmap.pss 19 | res['uss'] += mmap.private_clean + mmap.private_dirty 20 | res['shared'] += mmap.shared_clean + mmap.shared_dirty 21 | if mmap.path.startswith('/'): 22 | res['shared_file'] += mmap.shared_clean + mmap.shared_dirty 23 | return res 24 | 25 | 26 | class MemoryMonitor(): 27 | def __init__(self, pids: list[int] = None): 28 | if pids is None: 29 | pids = [os.getpid()] 30 | self.pids = pids 31 | 32 | def add_pid(self, pid: int): 33 | assert pid not in self.pids 34 | self.pids.append(pid) 35 | 36 | def _refresh(self): 37 | self.data = {pid: get_mem_info(pid) for pid in self.pids} 38 | return self.data 39 | 40 | def table(self) -> str: 41 | self._refresh() 42 | table = [] 43 | keys = list(list(self.data.values())[0].keys()) 44 | now = str(int(time.perf_counter() % 1e5)) 45 | for pid, data in self.data.items(): 46 | table.append((now, str(pid)) + tuple(self.format(data[k]) for k in keys)) 47 | return tabulate(table, headers=["time", "PID"] + keys) 48 | 49 | def str(self): 50 | self._refresh() 51 | keys = list(list(self.data.values())[0].keys()) 52 | res = [] 53 | for pid in self.pids: 54 | s = f"PID={pid}" 55 | for k in keys: 56 | v = self.format(self.data[pid][k]) 57 | s += f", {k}={v}" 58 | res.append(s) 59 | return "\n".join(res) 60 | 61 | @staticmethod 62 | def format(size: int) -> str: 63 | for unit in ('', 'K', 'M', 'G'): 64 | if size < 1024: 65 | break 66 | size /= 1024.0 67 | return "%.1f%s" % (size, unit) 68 | 69 | 70 | def create_coco() -> list[Any]: 71 | # Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json 72 | with open("instances_train2017.json") as f: 73 | obj = json.load(f) 74 | return obj["annotations"] 75 | 76 | 77 | def read_sample(x): 78 | # A function that is supposed to read object x, incrementing its refcount. 79 | # This mimics what a real dataloader would do. 80 | if sys.version_info >= (3, 10, 6): 81 | # Before this version, pickle does not increment refcount. This is a bug that's 82 | # fixed in https://github.com/python/cpython/pull/92931. 
83 | return pickle.dumps(x) 84 | else: 85 | import msgpack 86 | return msgpack.dumps(x) 87 | 88 | 89 | class DatasetFromList(torch.utils.data.Dataset): 90 | def __init__(self, lst): 91 | self.lst = lst 92 | def __len__(self): 93 | return len(self.lst) 94 | def __getitem__(self, idx: int): 95 | return self.lst[idx] 96 | 97 | 98 | if __name__ == "__main__": 99 | from serialize import NumpySerializedList 100 | monitor = MemoryMonitor() 101 | print("Initial", monitor.str()) 102 | lst = create_coco() 103 | print("JSON", monitor.str()) 104 | lst = NumpySerializedList(lst) 105 | print("Serialized", monitor.str()) 106 | del lst; import gc; gc.collect() 107 | print("End", monitor.str()) 108 | -------------------------------------------------------------------------------- /optogpt/core/datasets/sim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import pi 3 | import colour 4 | import pandas as pd 5 | import colour 6 | import pickle as pkl 7 | from tmm import coh_tmm, inc_tmm 8 | from scipy.interpolate import interp1d 9 | from colour import SDS_ILLUMINANTS, SpectralDistribution 10 | from colour.colorimetry import MSDS_CMFS 11 | from colour.plotting import plot_single_colour_swatch, ColourSwatch, plot_chromaticity_diagram_CIE1931 12 | import matplotlib.pyplot as plt 13 | from tqdm import tqdm 14 | import matplotlib as mpl 15 | import os 16 | import itertools 17 | from multiprocessing import Pool 18 | import pyswarms as ps 19 | from colour.difference import delta_E, delta_E_CIE2000 20 | 21 | 22 | 23 | DATABASE = './nk' 24 | illuminant = SDS_ILLUMINANTS['D65'] 25 | cmfs = MSDS_CMFS['CIE 1931 2 Degree Standard Observer'] 26 | 27 | mats = ['Al', 'Al2O3', 'AlN', 'Ge', 'HfO2', 'ITO', 'MgF2', 'MgO', 'Si', 'Si3N4', 'SiO2', 'Ta2O5', 'TiN', 'TiO2', 'ZnO', 'ZnS', 'ZnSe', 'Glass_Substrate'] 28 | thicks = [str(i) for i in range(5, 255, 5)] 29 | 30 | lamda_low = 0.4 31 | lamda_high = 1.1 32 | wavelengths = np.arange(lamda_low, lamda_high+1e-3, 0.01) 33 | 34 | 35 | 36 | def load_materials(all_mats = mats, wavelengths = wavelengths, DATABASE = './nk'): 37 | ''' 38 | Load material nk and return corresponding interpolators. 
39 | 40 | Return: 41 | nk_dict: dict, key -- material name, value: n, k in the 42 | self.wavelength range 43 | ''' 44 | nk_dict = {} 45 | 46 | for mat in all_mats: 47 | nk = pd.read_csv(os.path.join(DATABASE, mat + '.csv')) 48 | nk.dropna(inplace=True) 49 | 50 | wl = nk['wl'].to_numpy() 51 | index_n = nk['n'].to_numpy() 52 | index_k = nk['k'].to_numpy() 53 | 54 | n_fn = interp1d( 55 | wl, index_n, bounds_error=False, fill_value='extrapolate', kind=3) 56 | k_fn = interp1d( 57 | wl, index_k, bounds_error=False, fill_value='extrapolate', kind=1) 58 | 59 | nk_dict[mat] = n_fn(wavelengths) + 1j*k_fn(wavelengths) 60 | 61 | return nk_dict 62 | 63 | def spectrum(materials, thickness, pol = 's', theta=0, wavelengths = wavelengths, nk_dict = {}, substrate = 'Glass_Substrate', substrate_thick = 500000): 64 | ''' 65 | Input: 66 | metal materials: list 67 | thickness: list 68 | theta: degree, the incidence angle 69 | 70 | Return: 71 | All_results: dictionary contains R, T, A, RGB, LAB 72 | ''' 73 | #aa = time.time() 74 | degree = pi/180 75 | theta = theta *degree 76 | wavess = (1e3 * wavelengths).astype('int') 77 | 78 | 79 | thickness = [np.inf] + thickness + [substrate_thick, np.inf] 80 | 81 | R, T, A = [], [], [] 82 | inc_list = ['i'] + ['c']*len(materials) + ['i', 'i'] 83 | for i, lambda_vac in enumerate(wavess): 84 | 85 | n_list = [1] + [nk_dict[mat][i] for mat in materials] + [nk_dict[substrate][i], 1] 86 | 87 | res = inc_tmm(pol, n_list, thickness, inc_list, theta, lambda_vac) 88 | 89 | R.append(res['R']) 90 | T.append(res['T']) 91 | 92 | # thickness = [np.inf] + thickness + [np.inf] 93 | 94 | # R, T, A = [], [], [] 95 | # for i, lambda_vac in enumerate(wavess): 96 | 97 | # n_list = [1] + [nk_dict[mat][i] for mat in materials] + [nk_dict[substrate][i]] 98 | 99 | # res = coh_tmm(pol, n_list, thickness, theta, lambda_vac) 100 | 101 | # R.append(res['R']) 102 | # T.append(res['T']) 103 | 104 | return R + T -------------------------------------------------------------------------------- /self_improving/combine_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../optogpt') 3 | import os 4 | import numpy as np 5 | import pickle as pkl 6 | 7 | def combine_data(original_spec_train, original_struct_train, original_spec_test, original_struct_test, 8 | new_data, ratio=0.1, type_T="Greedy Decode", output_dir='./dataset'): 9 | """ 10 | Separate the new data into train and test data, and combine the original data and new data 11 | 12 | Args: 13 | original_spec_train: Path to original training spectrum data 14 | original_struct_train: Path to original training structure data 15 | original_spec_test: Path to original test spectrum data 16 | original_struct_test: Path to original test structure data 17 | new_data: DataFrame with augmented data 18 | ratio: Ratio for train/test split 19 | type_T: Type of decoding method used 20 | output_dir: Directory to save output files 21 | 22 | Returns: 23 | Tuple of paths to new train/test spec/struct files 24 | """ 25 | # Create output directory if it doesn't exist 26 | os.makedirs(output_dir, exist_ok=True) 27 | 28 | with open(original_struct_train, 'rb') as fp: 29 | train_struc = pkl.load(fp) 30 | 31 | with open(original_spec_train, 'rb') as fp: 32 | train_spec = pkl.load(fp) 33 | 34 | with open(original_struct_test, 'rb') as fp: 35 | test_struc = pkl.load(fp) 36 | 37 | with open(original_spec_test, 'rb') as fp: 38 | test_spec = pkl.load(fp) 39 | 40 | print("Read original data") 41 | 42 | # Separate 
the new data into train and test data, size 100:1 43 | new_data = new_data.sample(frac=1).reset_index(drop=True) 44 | new_data_train = new_data[:int(len(new_data)*0.99)] 45 | new_data_test = new_data[int(len(new_data)*0.99):] 46 | 47 | # Combine the original data and new data 48 | add_struct_train = new_data_train["perturb_struct"].tolist() 49 | add_spec_train = new_data_train["perturb_spec"].tolist() 50 | 51 | add_struct_test = new_data_test["perturb_struct"].tolist() 52 | add_spec_test = new_data_test["perturb_spec"].tolist() 53 | 54 | new_train_struct = train_struc + add_struct_train 55 | new_train_spec = np.array(train_spec.tolist() + add_spec_train) 56 | 57 | new_test_struct = test_struc + add_struct_test 58 | new_test_spec = np.array(test_spec.tolist() + add_spec_test) 59 | 60 | print("Combined data") 61 | 62 | # Convert decoding type to simplified string for filenames 63 | if type_T == "Greedy Decode": 64 | method_suffix = "greedy" 65 | elif type_T == "TOP-KP Decode": 66 | method_suffix = "topkp" 67 | elif type_T == "TOP-KP Decode_v2": 68 | method_suffix = "topkp_v2" 69 | elif type_T == "Beam Search": 70 | method_suffix = "beam" 71 | else: 72 | method_suffix = "custom" 73 | 74 | # Create consistent naming pattern 75 | train_spec_filename = f"train_spectrum_augmented_{method_suffix}.pkl" 76 | train_struct_filename = f"train_structure_augmented_{method_suffix}.pkl" 77 | test_spec_filename = f"test_spectrum_augmented_{method_suffix}.pkl" 78 | test_struct_filename = f"test_structure_augmented_{method_suffix}.pkl" 79 | 80 | # Full paths 81 | train_spec_path = os.path.join(output_dir, train_spec_filename) 82 | train_struct_path = os.path.join(output_dir, train_struct_filename) 83 | test_spec_path = os.path.join(output_dir, test_spec_filename) 84 | test_struct_path = os.path.join(output_dir, test_struct_filename) 85 | 86 | # Save the combined data 87 | with open(train_spec_path, 'wb') as fp: 88 | pkl.dump(new_train_spec, fp) 89 | 90 | with open(train_struct_path, 'wb') as fp: 91 | pkl.dump(new_train_struct, fp) 92 | 93 | with open(test_spec_path, 'wb') as fp: 94 | pkl.dump(new_test_spec, fp) 95 | 96 | with open(test_struct_path, 'wb') as fp: 97 | pkl.dump(new_test_struct, fp) 98 | 99 | print(f"Saved combined data with {method_suffix} method:") 100 | print(f" Train files: {train_spec_filename}, {train_struct_filename}") 101 | print(f" Test files: {test_spec_filename}, {test_struct_filename}") 102 | 103 | return train_spec_path, train_struct_path, test_spec_path, test_struct_path -------------------------------------------------------------------------------- /optogpt/run_ol_transformer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import Counter 7 | from torch.autograd import Variable 8 | import seaborn as sns 9 | import matplotlib.pyplot as plt 10 | import pickle as pkl 11 | 12 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") #CUDA_VISIBLE_DEVICES=3 13 | from core.datasets.datasets import * 14 | from core.models.transformer import * 15 | from core.trains.train import * 16 | 17 | print(DEVICE) 18 | 19 | if __name__ == '__main__': 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--seeds', default=42, type=int, help='random seeds') 23 | parser.add_argument('--epochs', default=1000, type=int, help='Num of training epoches') 24 | parser.add_argument('--ratios', 
default=100, type=int, help='Ratio of training dataset') 25 | parser.add_argument('--batch_size', default=1000, type=int, help='Batch size') 26 | parser.add_argument('--dropout', default=0.1, type=float, help='dropout rate') 27 | parser.add_argument('--max_lr', default=1.0, type=float, help='maximum learning rate') 28 | parser.add_argument('--warm_steps', default=100000, type=int, help='learning rate warmup steps') 29 | 30 | parser.add_argument('--struc_dim', default=104, type=int, help='Num of struc tokens') 31 | parser.add_argument('--spec_dim', default=142, type=int, help='Spec dimension') 32 | 33 | parser.add_argument('--layers', default=12, type=int, help='Encoder layers') 34 | parser.add_argument('--head_num', default=8, type=int, help='Attention head numbers') 35 | parser.add_argument('--d_model', default=1024, type=int, help='Total attention dim = head_num * head_dim') 36 | parser.add_argument('--d_ff', default=512, type=int, help='Feed forward layer dim') 37 | parser.add_argument('--max_len', default=22, type=int, help='Transformer horizons') 38 | 39 | parser.add_argument('--save_folder', default='test', type=str, help='First order folder') 40 | parser.add_argument('--save_name', default='model_forward', type=str, help='First order folder') 41 | parser.add_argument('--spec_type', default='R_T', type=str, help='If predict R/T/R+T') 42 | parser.add_argument('--TRAIN_FILE', default='TRAIN_FILE', type=str, help='TRAIN_FILE') 43 | parser.add_argument('--TRAIN_SPEC_FILE', default='TRAIN_SPEC_FILE', type=str, help='TRAIN_SPEC_FILE') 44 | parser.add_argument('--DEV_FILE', default='DEV_FILE', type=str, help='DEV_FILE') 45 | parser.add_argument('--DEV_SPEC_FILE', default='DEV_SPEC_FILE', type=str, help='DEV_SPEC_FILE') 46 | parser.add_argument('--struc_index_dict', default={2:'BOS'}, type=dict, help='struc_index_dict') 47 | parser.add_argument('--struc_word_dict', default={'BOS':2}, type=dict, help='struc_word_dict') 48 | 49 | 50 | args = parser.parse_args() 51 | 52 | torch.manual_seed(args.seeds) 53 | np.random.seed(args.seeds) 54 | 55 | temp = [args.ratios, args.batch_size, args.max_lr, args.warm_steps, args.layers, args.head_num, args.d_model, args.d_ff] 56 | args.save_name += '_' + args.spec_type 57 | args.save_name += '_R_B_LR_WU_L_H_D_F_'+str(temp) 58 | 59 | TRAIN_FILE = './dataset/Structure_train.pkl' 60 | TRAIN_SPEC_FILE = './dataset/Spectrum_train.pkl' 61 | DEV_FILE = './dataset/Structure_dev.pkl' 62 | DEV_SPEC_FILE = './dataset/Spectrum_dev.pkl' 63 | 64 | args.TRAIN_FILE, args.TRAIN_SPEC_FILE, args.DEV_FILE, args.DEV_SPEC_FILE = TRAIN_FILE, TRAIN_SPEC_FILE, DEV_FILE, DEV_SPEC_FILE 65 | 66 | data = PrepareData(TRAIN_FILE, TRAIN_SPEC_FILE, args.ratios, DEV_FILE, DEV_SPEC_FILE, args.batch_size, args.spec_type) 67 | 68 | src_vocab = len(data.struc_word_dict) 69 | tgt_vocab = len(data.dev_spec[0]) 70 | args.struc_dim = src_vocab 71 | args.spec_dim = tgt_vocab 72 | args.struc_index_dict = data.struc_index_dict 73 | args.struc_word_dict = data.struc_word_dict 74 | 75 | print(f"struc_vocab {src_vocab}") 76 | print(f"spec_vocab {tgt_vocab}") 77 | 78 | model = make_model( 79 | args.struc_dim, 80 | args.spec_dim, 81 | args.layers, 82 | args.d_model, 83 | args.d_ff, 84 | args.head_num, 85 | args.dropout 86 | ).to(DEVICE) 87 | 88 | print('Model Transformer, Number of parameters {}'.format(count_params(model))) 89 | 90 | # Step 3: Training model 91 | print(">>>>>>> start train") 92 | train_start = time.time() 93 | criterion = torch.nn.MSELoss() 94 | 95 | optimizer = NoamOpt(args.d_model, 
args.max_lr, args.warm_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9,0.98), eps=1e-9)) 96 | 97 | train(data, model, criterion, optimizer, args, DEVICE) 98 | print(f"<<<<<<< finished train, cost {time.time()-train_start:.4f} seconds") 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /optogpt/run_optogpt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from collections import Counter 7 | from torch.autograd import Variable 8 | import seaborn as sns 9 | import matplotlib.pyplot as plt 10 | import pickle as pkl 11 | 12 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") #CUDA_VISIBLE_DEVICES=3 13 | from core.datasets.datasets import * 14 | from core.models.transformer import * 15 | from core.trains.train import * 16 | 17 | 18 | if __name__ == '__main__': 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--seeds', default=42, type=int, help='random seeds') 22 | parser.add_argument('--epochs', default=1000, type=int, help='Num of training epoches') 23 | parser.add_argument('--ratios', default=100, type=int, help='Ratio of training dataset') 24 | parser.add_argument('--batch_size', default=1000, type=int, help='Batch size') 25 | parser.add_argument('--dropout', default=0.1, type=float, help='dropout rate') 26 | parser.add_argument('--max_lr', default=1.0, type=float, help='maximum learning rate') 27 | parser.add_argument('--warm_steps', default=100000, type=int, help='learning rate warmup steps') 28 | 29 | parser.add_argument('--smoothing', default=0.1, type=float, help='Smoothing for KL divergence') 30 | 31 | parser.add_argument('--struc_dim', default=104, type=int, help='Num of struc tokens') 32 | parser.add_argument('--spec_dim', default=142, type=int, help='Spec dimension') 33 | 34 | parser.add_argument('--layers', default=1, type=int, help='Encoder layers') 35 | parser.add_argument('--head_num', default=8, type=int, help='Attention head numbers') 36 | parser.add_argument('--d_model', default=1024, type=int, help='Total attention dim = head_num * head_dim') 37 | parser.add_argument('--d_ff', default=512, type=int, help='Feed forward layer dim') 38 | parser.add_argument('--max_len', default=22, type=int, help='Transformer horizons') 39 | 40 | parser.add_argument('--save_folder', default='test', type=str, help='First order folder') 41 | parser.add_argument('--save_name', default='model_inverse', type=str, help='First order folder') 42 | parser.add_argument('--spec_type', default='R_T', type=str, help='If predict R/T/R+T') 43 | parser.add_argument('--TRAIN_FILE', default='TRAIN_FILE', type=str, help='TRAIN_FILE') 44 | parser.add_argument('--TRAIN_SPEC_FILE', default='TRAIN_SPEC_FILE', type=str, help='TRAIN_SPEC_FILE') 45 | parser.add_argument('--DEV_FILE', default='DEV_FILE', type=str, help='DEV_FILE') 46 | parser.add_argument('--DEV_SPEC_FILE', default='DEV_SPEC_FILE', type=str, help='DEV_SPEC_FILE') 47 | parser.add_argument('--struc_index_dict', default={2:'BOS'}, type=dict, help='struc_index_dict') 48 | parser.add_argument('--struc_word_dict', default={'BOS':2}, type=dict, help='struc_word_dict') 49 | 50 | args = parser.parse_args() 51 | 52 | torch.manual_seed(args.seeds) 53 | np.random.seed(args.seeds) 54 | 55 | temp = [args.ratios, args.smoothing, args.batch_size, args.max_lr, args.warm_steps, 
args.layers, args.head_num, args.d_model, args.d_ff] 56 | args.save_name += '_' + args.spec_type 57 | args.save_name += '_S_R_B_LR_WU_L_H_D_F_'+str(temp) 58 | 59 | TRAIN_FILE = './dataset/Structure_train.pkl' 60 | TRAIN_SPEC_FILE = './dataset/Spectrum_train.pkl' 61 | DEV_FILE = './dataset/Structure_dev.pkl' 62 | DEV_SPEC_FILE = './dataset/Spectrum_dev.pkl' 63 | 64 | args.TRAIN_FILE, args.TRAIN_SPEC_FILE, args.DEV_FILE, args.DEV_SPEC_FILE = TRAIN_FILE, TRAIN_SPEC_FILE, DEV_FILE, DEV_SPEC_FILE 65 | 66 | data = PrepareData(TRAIN_FILE, TRAIN_SPEC_FILE, args.ratios, DEV_FILE, DEV_SPEC_FILE, args.batch_size, args.spec_type, 'Inverse') 67 | 68 | tgt_vocab = len(data.struc_word_dict) 69 | src_vocab = len(data.dev_spec[0]) 70 | args.struc_dim = tgt_vocab 71 | args.spec_dim = src_vocab 72 | args.struc_index_dict = data.struc_index_dict 73 | args.struc_word_dict = data.struc_word_dict 74 | 75 | print(f"struc_vocab {src_vocab}") 76 | print(f"spec_vocab {tgt_vocab}") 77 | 78 | model = make_model_I( 79 | args.spec_dim, 80 | args.struc_dim, 81 | args.layers, 82 | args.d_model, 83 | args.d_ff, 84 | args.head_num, 85 | args.dropout 86 | ).to(DEVICE) 87 | 88 | print('Model Transformer, Number of parameters {}'.format(count_params(model))) 89 | 90 | # Step 3: Training model 91 | print(">>>>>>> start train") 92 | train_start = time.time() 93 | criterion = LabelSmoothing(tgt_vocab, padding_idx = 0, smoothing= args.smoothing) 94 | 95 | optimizer = NoamOpt(args.d_model, args.max_lr, args.warm_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9,0.98), eps=1e-9)) 96 | 97 | train_I(data, model, criterion, optimizer, args, DEVICE) 98 | print(f"<<<<<<< finished train, cost {time.time()-train_start:.4f} seconds") -------------------------------------------------------------------------------- /optogpt/nk/Au.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.3,1.61783946,1.92591542 3 | 0.31,1.65423659,1.93079422 4 | 0.32,1.68172004,1.92480464 5 | 0.33,1.69212673,1.91224392 6 | 0.34,1.68363206,1.8968735 7 | 0.35,1.65736939,1.89665208 8 | 0.36,1.63126096,1.92356822 9 | 0.37,1.61569793,1.95201259 10 | 0.38,1.60767985,1.97413747 11 | 0.39,1.60518794,1.99180722 12 | 0.4,1.59915947,1.99984804 13 | 0.41,1.58698287,2.00174335 14 | 0.42,1.56953206,2.00233117 15 | 0.43,1.54514587,1.99683745 16 | 0.44,1.51356155,1.9862035 17 | 0.45,1.47044934,1.96776892 18 | 0.46,1.41125099,1.94435033 19 | 0.47,1.32913574,1.91877424 20 | 0.48,1.21296262,1.9001494 21 | 0.49,1.06656637,1.90987639 22 | 0.5,0.906704742,1.96228634 23 | 0.51,0.761291669,2.06346006 24 | 0.52,0.639418671,2.19935212 25 | 0.53,0.554638906,2.33304055 26 | 0.54,0.487291101,2.46412886 27 | 0.55,0.434403561,2.59213465 28 | 0.56,0.391020095,2.71722435 29 | 0.57,0.355549579,2.83576007 30 | 0.58,0.325333725,2.95416074 31 | 0.59,0.298993077,3.0684205 32 | 0.6,0.277677889,3.18419893 33 | 0.61,0.258734449,3.29545101 34 | 0.62,0.24134197,3.40146254 35 | 0.63,0.226579103,3.50765886 36 | 0.64,0.215667805,3.6145369 37 | 0.65,0.204666045,3.7188903 38 | 0.66,0.19636131,3.82142006 39 | 0.67,0.189107572,3.92236834 40 | 0.68,0.184643121,4.02355566 41 | 0.69,0.178887146,4.12044518 42 | 0.7,0.17674719,4.21775367 43 | 0.71,0.175831732,4.31214755 44 | 0.72,0.175549525,4.40493828 45 | 0.73,0.176358218,4.49879233 46 | 0.74,0.1777416,4.58950938 47 | 0.75,0.179043511,4.68265772 48 | 0.76,0.180860178,4.7718728 49 | 0.77,0.181949436,4.85748775 50 | 0.78,0.185046347,4.94759841 51 | 0.79,0.185339889,5.03345022 52 | 0.8,0.187064345,5.11763898 
53 | 0.81,0.189520204,5.20375303 54 | 0.82,0.195110208,5.29776872 55 | 0.83,0.193744024,5.36890621 56 | 0.84,0.197930018,5.45611436 57 | 0.85,0.199790075,5.53437366 58 | 0.86,0.205244053,5.62676185 59 | 0.87,0.211010171,5.71407507 60 | 0.88,0.21395881,5.79490053 61 | 0.89,0.212656461,5.86492403 62 | 0.9,0.216331153,5.94469408 63 | 0.91,0.21935199,6.02396629 64 | 0.92,0.22414654,6.1067325 65 | 0.93,0.229797483,6.19229508 66 | 0.94,0.232607781,6.26902542 67 | 0.95,0.23704489,6.34934485 68 | 0.96,0.240509086,6.4268699 69 | 0.97,0.246718218,6.51125502 70 | 0.98,0.250438281,6.58872086 71 | 0.99,0.256576978,6.6715479 72 | 1,0.25655369,6.7375508 73 | 1.01,0.263584649,6.82234694 74 | 1.02,0.2646218,6.88818953 75 | 1.03,0.278585592,6.99872749 76 | 1.04,0.283963428,7.07749835 77 | 1.05,0.288391969,7.15927066 78 | 1.06,0.290576756,7.22900595 79 | 1.07,0.298268038,7.3147441 80 | 1.08,0.305740907,7.3978619 81 | 1.09,0.304288195,7.46163467 82 | 1.1,0.310787119,7.54735399 83 | 1.11,0.320585844,7.62961479 84 | 1.12,0.323817758,7.69623412 85 | 1.13,0.331362368,7.7764131 86 | 1.14,0.335643765,7.8573744 87 | 1.15,0.33826523,7.93288949 88 | 1.16,0.343892181,8.01004551 89 | 1.17,0.353740945,8.09674549 90 | 1.18,0.354434494,8.16187979 91 | 1.19,0.36030252,8.2383744 92 | 1.2,0.367462172,8.31567714 93 | 1.21,0.368502468,8.39484412 94 | 1.22,0.379622097,8.46861541 95 | 1.23,0.385832929,8.54408328 96 | 1.24,0.391271341,8.62159475 97 | 1.25,0.397540274,8.69773375 98 | 1.26,0.406180459,8.78292039 99 | 1.27,0.404675711,8.8452797 100 | 1.28,0.410182453,8.91774497 101 | 1.29,0.41846701,8.99823878 102 | 1.3,0.427119383,9.07652745 103 | 1.31,0.43347882,9.15066735 104 | 1.32,0.438702623,9.22615887 105 | 1.33,0.444799144,9.29992841 106 | 1.34,0.453582764,9.37777243 107 | 1.35,0.461846521,9.45103501 108 | 1.36,0.469203894,9.52740543 109 | 1.37,0.474751914,9.59985306 110 | 1.38,0.481276075,9.67287427 111 | 1.39,0.48886687,9.74728119 112 | 1.4,0.492746002,9.8268533 113 | 1.41,0.504031585,9.90114802 114 | 1.42,0.510780248,9.97408648 115 | 1.43,0.510288532,10.0403904 116 | 1.44,0.519775692,10.116375 117 | 1.45,0.527035273,10.1898781 118 | 1.46,0.532867313,10.2687343 119 | 1.47,0.544649492,10.3476553 120 | 1.48,0.547589218,10.4094155 121 | 1.49,0.555481749,10.4799791 122 | 1.5,0.558814908,10.5633678 123 | 1.51,0.57128982,10.6300872 124 | 1.52,0.578929627,10.7110355 125 | 1.53,0.584512947,10.7752519 126 | 1.54,0.593974095,10.8621738 127 | 1.55,0.595768292,10.9226244 128 | 1.56,0.604615236,11.0007124 129 | 1.57,0.608956178,11.0750847 130 | 1.58,0.618540012,11.1508136 131 | 1.59,0.62976668,11.2207716 132 | 1.6,0.643011418,11.2975389 133 | 1.61,0.646509141,11.3684578 134 | 1.62,0.654064212,11.4508738 135 | 1.63,0.66001648,11.526675 136 | 1.64,0.669719247,11.6016693 137 | 1.65,0.676406022,11.6721402 138 | 1.66,0.690263603,11.7392467 139 | 1.67,0.698898482,11.8193879 140 | 1.68,0.702935018,11.8848169 141 | 1.69,0.705036798,11.9494043 142 | 1.7,0.727437326,12.0213473 143 | 1.71,0.725509695,12.0962452 144 | 1.72,0.734958354,12.177221 145 | 1.73,0.742519492,12.2437827 146 | 1.74,0.748727421,12.309308 147 | 1.75,0.753407718,12.378775 148 | 1.76,0.760841931,12.4554859 149 | 1.77,0.769842834,12.5281773 150 | 1.78,0.777584532,12.6011156 151 | 1.79,0.786724744,12.6721867 152 | 1.8,0.795400415,12.7429019 153 | 1.81,0.809131832,12.8201888 154 | 1.82,0.814552818,12.8898472 155 | 1.83,0.82060085,12.9535926 156 | 1.84,0.828590025,13.0232969 157 | 1.85,0.838228471,13.0928779 158 | 1.86,0.844166212,13.1557882 159 | 1.87,0.852370186,13.2310206 160 | 
1.88,0.864016715,13.3065992 161 | 1.89,0.870248673,13.3677447 162 | 1.9,0.888674006,13.4665863 163 | 1.91,0.900940666,13.5411184 164 | 1.92,0.911992415,13.6112341 165 | 1.93,0.918093169,13.685776 166 | 1.94,0.926388921,13.7622216 167 | 1.95,0.93614448,13.834178 168 | 1.96,0.942111589,13.9061232 169 | 1.97,0.950296795,13.9809941 170 | 1.98,0.960693627,14.0482626 171 | 1.99,0.972762005,14.1252331 172 | 2,0.982293583,14.1969217 -------------------------------------------------------------------------------- /optogpt/core/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from collections import Counter 4 | import pickle as pkl 5 | from torch.autograd import Variable 6 | 7 | UNK = 0 # unknow word-id 8 | PAD = 1 # padding word-id 9 | 10 | def seq_padding(X, padding=0): 11 | """ 12 | add padding to a batch data 13 | """ 14 | L = [len(x) for x in X] 15 | ML = max(L) 16 | return np.array([ 17 | np.concatenate([x, [padding] * (ML - len(x))]) if len(x) < ML else x for x in X 18 | ]) 19 | 20 | class Batch: 21 | "Object for holding a batch of data with mask during training." 22 | def __init__(self, src, trg, if_inverse = 'Forward', pad=0): 23 | 24 | # convert words id to long format. 25 | src = torch.from_numpy(src).long() 26 | trg = torch.tensor(trg).float() 27 | if if_inverse == 'Forward': 28 | self.src = src # source is structure 29 | self.trg = trg # target is spectrum 30 | # # get the padding postion binary mask 31 | # # change the matrix shape to 1×seq.length 32 | self.src_mask = (src != pad).unsqueeze(-2) 33 | self.ntokens = (self.src != pad).data.sum() 34 | elif if_inverse == 'Inverse': 35 | self.trg = src[:, :-1] # target is structure 36 | # decoder target from trg 37 | self.trg_y = src[:, 1:] 38 | self.src = trg.unsqueeze(-2) # source is spectrum 39 | # # get the padding postion binary mask 40 | # # change the matrix shape to 1×seq.length 41 | self.src_mask = None 42 | self.trg_mask = self.make_std_mask(self.trg, pad) 43 | self.ntokens = (self.trg != pad).data.sum() 44 | else: 45 | raise NotImplementedError 46 | # Mask 47 | @staticmethod 48 | def make_std_mask(tgt, pad): 49 | "Create a mask to hide padding and future words." 50 | tgt_mask = (tgt != pad).unsqueeze(-2) 51 | tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)) 52 | return tgt_mask # subsequent_mask is defined in 'decoder' section. 53 | 54 | def subsequent_mask(size): 55 | "Mask out subsequent positions." 56 | attn_shape = (1, size, size) 57 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 58 | return torch.from_numpy(subsequent_mask) == 0 59 | 60 | 61 | class PrepareData: 62 | def __init__(self, train_file, train_spec_file, train_ratio, dev_file, dev_spec_file, BATCH_SIZE=128, spec_type = 'R_T', if_inverse = 'Forward'): 63 | 64 | # 01. Read the data and tokenize 65 | self.train_struc, self.train_spec = self.load_data(train_file, train_spec_file, train_ratio) 66 | self.dev_struc, self.dev_spec = self.load_data(dev_file, dev_spec_file) 67 | 68 | dims = self.train_spec.shape[1]//2 69 | if spec_type == 'R': 70 | self.train_spec = self.train_spec[:, :dims] 71 | self.dev_spec = self.dev_spec[:, :dims] 72 | elif spec_type == 'T': 73 | self.train_spec = self.train_spec[:, dims:] 74 | self.dev_spec = self.dev_spec[:, dims:] 75 | 76 | # 02. build dictionary: structure 77 | self.struc_word_dict, self.struc_total_words, self.struc_index_dict = self.build_dict(self.train_struc) 78 | 79 | # 03. 
word to id by dictionary 80 | self.train_struc = self.wordToID(self.train_struc, self.struc_word_dict) 81 | self.dev_struc = self.wordToID(self.dev_struc, self.struc_word_dict) 82 | 83 | # 04. batch + padding + mask 84 | self.train_data = self.splitBatch(self.train_struc, self.train_spec, BATCH_SIZE, if_inverse) 85 | self.dev_data = self.splitBatch(self.dev_struc, self.dev_spec, BATCH_SIZE, if_inverse) 86 | 87 | def load_data(self, path, spec_path, ratio = 100): 88 | 89 | """ 90 | Read structure and spec Data 91 | tokenize the structure and add start/end marks(Begin of Sentence; End of Sentence) 92 | """ 93 | 94 | struc = [] 95 | all_struc = [] 96 | 97 | with open (path, 'rb') as fp: 98 | all_struc = pkl.load(fp) 99 | 100 | for ele in all_struc: 101 | struc.append(["BOS"] + ele + ["EOS"]) 102 | 103 | with open (spec_path, 'rb') as fp: 104 | spec = pkl.load(fp) 105 | 106 | if ratio <=0 or ratio > 100: 107 | raise NameError('Wrong training dataset ratio. Make sure it is (0, 100]. ') 108 | 109 | lengs = len(struc) 110 | struc = struc[:int(ratio*lengs/100)] 111 | spec = spec[:int(ratio*lengs/100)] 112 | 113 | return struc, spec 114 | 115 | def build_dict(self, sentences, max_words = 1000): 116 | """ 117 | sentences: list of word list 118 | build dictonary as {key(word): value(id)} 119 | """ 120 | word_count = Counter() 121 | for sentence in sentences: 122 | for s in sentence: 123 | word_count[s] += 1 124 | 125 | ls = word_count.most_common(max_words) 126 | total_words = len(ls) + 2 127 | word_dict = {w[0]: index + 2 for index, w in enumerate(ls)} 128 | word_dict['UNK'] = UNK 129 | word_dict['PAD'] = PAD 130 | index_dict = {v: k for k, v in word_dict.items()} 131 | return word_dict, total_words, index_dict 132 | 133 | def wordToID(self, en, en_dict, sort=False): 134 | """ 135 | convert input/output word lists to id lists. 136 | Use input word list length to sort, reduce padding. 137 | """ 138 | 139 | out_en_ids = [[en_dict.get(w, 0) for w in sent] for sent in en] 140 | # out_cn_ids = [[cn_dict.get(w, 0) for w in sent] for sent in cn] 141 | 142 | def len_argsort(seq): 143 | """ 144 | get sorted index w.r.t length. 
145 | """ 146 | return sorted(range(len(seq)), key=lambda x: len(seq[x])) 147 | 148 | if sort: # update index 149 | sorted_index = len_argsort(out_en_ids) 150 | out_en_ids = [out_en_ids[id] for id in sorted_index] 151 | return out_en_ids 152 | 153 | def splitBatch(self, struc, spec, batch_size, if_inverse = 'Forward', shuffle=False): 154 | """ 155 | get data into batches 156 | """ 157 | idx_list = np.arange(0, len(struc), batch_size) 158 | if shuffle: 159 | np.random.shuffle(idx_list) 160 | 161 | batch_indexs = [] 162 | for idx in idx_list: 163 | batch_indexs.append(np.arange(idx, min(idx + batch_size, len(struc)))) 164 | 165 | # print(batch_indexs, len(struc)) 166 | batches = [] 167 | for batch_index in batch_indexs: 168 | batch_struc = [struc[index] for index in batch_index] 169 | batch_spec = [spec[index] for index in batch_index] 170 | 171 | batch_struc = seq_padding(batch_struc) # pad the structure sequence 172 | batches.append(Batch(batch_struc, np.array(batch_spec), if_inverse)) 173 | 174 | return batches 175 | -------------------------------------------------------------------------------- /self_improving/model_retrain.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../optogpt') 3 | import os 4 | import time 5 | import torch 6 | import numpy as np 7 | from core.datasets.datasets import PrepareDataAug 8 | from core.models.transformer import make_model_I 9 | from core.trains.train import LabelSmoothing, run_epoch_I, SimpleLossCompute 10 | 11 | # Default device can be overridden from outside 12 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | def save_checkpoint(model, optimizer, epoch, loss_all, path, configs): 15 | """Save model checkpoint.""" 16 | os.makedirs(os.path.dirname(path), exist_ok=True) 17 | torch.save({ 18 | 'epoch': epoch, 19 | 'model_state_dict': model.state_dict(), 20 | 'optimizer_state_dict': optimizer.state_dict(), 21 | 'loss': loss_all, 22 | 'configs': configs 23 | }, path) 24 | 25 | def train_I_aug(data, model, criterion, optimizer, configs, device, epochs, early_stopping_patience, type_T, output_dir="./saved_models"): 26 | """ 27 | Train and Save the model with early stopping. 
28 | 29 | Args: 30 | data: Dataset object containing train and dev data 31 | model: Model to train 32 | criterion: Loss criterion 33 | optimizer: Optimizer 34 | configs: Configuration parameters 35 | device: Device to train on 36 | epochs: Number of epochs to train 37 | early_stopping_patience: Early stopping patience 38 | type_T: Type of decoding method used 39 | output_dir: Directory to save model checkpoints 40 | """ 41 | best_dev_loss = 1e5 42 | loss_all = {'train_loss': [], 'dev_loss': []} 43 | save_name = configs.save_name if hasattr(configs, 'save_name') else 'model' 44 | EPOCHS = epochs 45 | epochs_without_improvement = 0 46 | 47 | # Set up model save directory based on decoding type 48 | if type_T == "Greedy Decode": 49 | save_dir = os.path.join(output_dir, "greedy_decode") 50 | elif type_T == "TOP-KP Decode": 51 | save_dir = os.path.join(output_dir, "topkp_decode") 52 | elif type_T == "TOP-KP Decode_v2": 53 | save_dir = os.path.join(output_dir, "topkp_decode_v2") 54 | elif type_T == "Beam Search": 55 | save_dir = os.path.join(output_dir, "beam_search") 56 | else: 57 | save_dir = os.path.join(output_dir, "custom") 58 | 59 | # Create save directory 60 | os.makedirs(save_dir, exist_ok=True) 61 | 62 | best_model_path = os.path.join(save_dir, f"{save_name}_best_augmented.pt") 63 | recent_model_path = os.path.join(save_dir, f"{save_name}_recent_augmented.pt") 64 | 65 | print(f"Models will be saved to: {save_dir}") 66 | 67 | for epoch in range(EPOCHS): 68 | start_time = time.time() 69 | 70 | # Train model 71 | model.train() 72 | train_loss = run_epoch_I(data.train_data, model, SimpleLossCompute(model.generator, criterion, optimizer), epoch, device) 73 | model.eval() 74 | 75 | # validate model on dev dataset 76 | print('>>>>> Evaluate') 77 | dev_loss = run_epoch_I(data.dev_data, model, SimpleLossCompute(model.generator, criterion, None), epoch, device) 78 | print(f'<<<<< Evaluate loss: {dev_loss:.2f} (Epoch {epoch+1}/{EPOCHS}, {time.time()-start_time:.2f}s)') 79 | loss_all['train_loss'].append(train_loss.detach()) 80 | loss_all['dev_loss'].append(dev_loss.detach()) 81 | 82 | # Check for early stopping 83 | if dev_loss < best_dev_loss: 84 | best_dev_loss = dev_loss 85 | epochs_without_improvement = 0 86 | # Save the model if it's the best so far 87 | save_checkpoint(model, optimizer, epoch, loss_all, best_model_path, configs) 88 | print(f"New best model saved (loss: {dev_loss:.4f})") 89 | else: 90 | epochs_without_improvement += 1 91 | if epochs_without_improvement >= early_stopping_patience: 92 | print(f'Early stopping triggered after {epoch + 1} epochs') 93 | break 94 | 95 | # Save the recent model 96 | save_checkpoint(model, optimizer, epoch, loss_all, recent_model_path, configs) 97 | 98 | class FinetuneOpt: 99 | "Optim wrapper that implements rate." 
100 | def __init__(self, model_size, factor, warmup, optimizer): 101 | self.optimizer = optimizer 102 | self._step = 0 103 | self.warmup = warmup 104 | self.factor = factor 105 | self.model_size = model_size 106 | self._rate = 0 107 | 108 | def step(self): 109 | "Update parameters and rate" 110 | self._step += 1 111 | rate = self.rate() 112 | for p in self.optimizer.param_groups: 113 | p['lr'] = rate 114 | self._rate = rate 115 | self.optimizer.step() 116 | 117 | def rate(self, step = None): 118 | "Implement `lrate` above" 119 | if step is None: 120 | step = self._step 121 | return self.factor 122 | 123 | def retrain_model(model_path, new_train_struct, new_train_spec, new_test_struct, new_test_spec, 124 | epochs, early_stopping_patience, struct_word_dict, struct_index_dict, type_T, device=None, output_dir="./saved_models"): 125 | """ 126 | Retrain the model with augmented data. 127 | 128 | Args: 129 | model_path: Path to the pretrained model 130 | new_train_struct: Path to augmented training structure data 131 | new_train_spec: Path to augmented training spectrum data 132 | new_test_struct: Path to augmented test structure data 133 | new_test_spec: Path to augmented test spectrum data 134 | epochs: Number of training epochs 135 | early_stopping_patience: Early stopping patience 136 | struct_word_dict: Structure word dictionary 137 | struct_index_dict: Structure index dictionary 138 | type_T: Type of decoding method used 139 | device: Device to use (if None, uses global DEVICE) 140 | output_dir: Directory to save model outputs 141 | """ 142 | if device is None: 143 | device = DEVICE 144 | 145 | a = torch.load(model_path, map_location=device) 146 | args = a['configs'] 147 | torch.manual_seed(args.seeds) 148 | np.random.seed(args.seeds) 149 | model = make_model_I( 150 | args.spec_dim, 151 | args.struc_dim, 152 | args.layers, 153 | args.d_model, 154 | args.d_ff, 155 | args.head_num, 156 | args.dropout 157 | ).to(device) 158 | 159 | model.load_state_dict(a['model_state_dict']) 160 | 161 | TRAIN_FILE = new_train_struct 162 | TRAIN_SPEC_FILE = new_train_spec 163 | DEV_FILE = new_test_struct 164 | DEV_SPEC_FILE = new_test_spec 165 | 166 | args.TRAIN_FILE, args.TRAIN_SPEC_FILE, args.DEV_FILE, args.DEV_SPEC_FILE = TRAIN_FILE, TRAIN_SPEC_FILE, DEV_FILE, DEV_SPEC_FILE 167 | 168 | data = PrepareDataAug(TRAIN_FILE, TRAIN_SPEC_FILE, args.ratios, DEV_FILE, DEV_SPEC_FILE, args.batch_size, args.spec_type, 'Inverse', struct_word_dict, struct_index_dict) 169 | tgt_vocab = len(data.struc_word_dict) 170 | print(tgt_vocab) 171 | src_vocab = len(data.dev_spec[0]) 172 | args.struc_dim = tgt_vocab 173 | args.spec_dim = src_vocab 174 | args.struc_index_dict = data.struc_index_dict 175 | args.struc_word_dict = data.struc_word_dict 176 | 177 | print(">>>>>>> start train") 178 | criterion = LabelSmoothing(tgt_vocab, padding_idx = 0, smoothing= args.smoothing) 179 | 180 | optimizer = FinetuneOpt(args.d_model, 5e-5, 1, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 181 | 182 | train_I_aug(data, model, criterion, optimizer, args, device, epochs, early_stopping_patience, type_T, output_dir) 183 | print("<<<<<<< finished train") 184 | 185 | 186 | -------------------------------------------------------------------------------- /optogpt/README.md: -------------------------------------------------------------------------------- 1 | # README for OptoGPT & OL-Transforemr 2 | 3 | ## Project Overview 4 | 5 | Optical multilayer thin film structures have been widely used in numerous photonic applications. 
However, existing inverse design methods have several drawbacks: they either fail to adapt quickly to different design targets, or are difficult to extend to different types of structures, e.g., designing with different materials at each layer. They also cannot accommodate versatile design situations, such as different incidence angles and polarizations. In addition, practical fabrication and manufacturing constraints have not been extensively considered. In this work, we introduce OptoGPT (Opto Generative Pretrained Transformer), a decoder-only transformer, to address all of these issues simultaneously.
6 | 
7 | Deep learning-based methods have recently been established as fast and accurate surrogate simulators for optical multilayer thin film structures. However, existing methods only work for a limited set of structure types and material arrangements, preventing their application to diverse and universal structures. Here, we propose the Opto-Layer (OL) Transformer to act as a universal surrogate simulator for an enormous variety of structures. Combined with the technique of structure serialization, our model can predict accurate reflection and transmission spectra for up to $10^{25}$ different multilayer structures, while still achieving a six-fold reduction in simulation time compared to physical solvers. Further investigation reveals that this general learning ability arises because our model first learns physical embeddings and then uses the self-attention mechanism to capture the hidden light-matter interactions between layers.
8 | 
9 | 
10 | Check our OptoGPT paper here: https://www.oejournal.org/article/doi/10.29026/oea.2024.240062
11 | ![Workflow](figures/optogpt.png)
12 | 
13 | Check our OL-Transformer paper here: https://arxiv.org/abs/2305.11984
14 | ![Workflow](figures/ol_transformer.jpg)
15 | 
16 | ## Installation
17 | 
18 | To set up the environment for this project, follow these steps:
19 | 
20 | ### 1. Clone this repository
21 | ~~~
22 | git clone https://github.com/taigaoma1997/optogpt.git
23 | 
24 | cd optogpt/optogpt
25 | ~~~
26 | 
27 | ### 2. Create and activate a conda environment:
28 | ~~~
29 | conda env create -f environment.yml
30 | conda activate optogpt
31 | ~~~
32 | ## Environment Details
33 | The project uses the following key dependencies (full list in environment.yml):
34 | ~~~
35 | python==3.8.19
36 | datasets==2.17.0
37 | numpy==1.23.4
38 | torch==2.0.1
39 | tmm==0.1.8
40 | scikit-learn==1.2.1
41 | pandas==1.5.1
42 | matplotlib==3.7.5
43 | seaborn==0.13.2
44 | colour-science==0.4.1
45 | ~~~
46 | ## Dataset
47 | 
48 | Please download our dataset from https://huggingface.co/datasets/mataigao/optogpt_data and place it in this folder.
49 | 
50 | Then run data_conversion.ipynb to convert the .csv files into the .pkl files used by our loading function.
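The conversion produces, for each split, a pickled list of tokenized structures (e.g. `['SiO2_100', 'TiO2_50', ...]`) and a pickled array of the corresponding reflection/transmission spectra, which is the format `PrepareData` in core/datasets/datasets.py loads. A minimal sketch of such a conversion is shown below, using the paths expected by run_optogpt.py; the CSV file name and the `structure` column are placeholders, so please refer to data_conversion.ipynb for the actual column layout:
~~~
import pickle as pkl
import pandas as pd

# Hypothetical CSV layout: one row per structure, holding a serialized structure
# string (e.g. "SiO2_100 TiO2_50 ...") plus the sampled R and T values.
df = pd.read_csv('optogpt_data/train.csv')  # placeholder file name

structures = [row.split() for row in df['structure']]  # list of token lists
spectra = df.drop(columns=['structure']).to_numpy()    # rows of concatenated R and T values

with open('dataset/Structure_train.pkl', 'wb') as fp:
    pkl.dump(structures, fp)
with open('dataset/Spectrum_train.pkl', 'wb') as fp:
    pkl.dump(spectra, fp)
~~~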
51 | 
52 | ## Model
53 | 
54 | Please download our trained model here: https://huggingface.co/mataigao/optogpt
55 | 
56 | ## Usage
57 | To run the main training scripts:
58 | ~~~
59 | CUDA_VISIBLE_DEVICES=0 python run_optogpt.py # for OptoGPT
60 | CUDA_VISIBLE_DEVICES=0 python run_ol_transformer.py # for OL-Transformer
61 | ~~~
62 | To run evaluations, please use the following notebooks:
63 | ~~~
64 | analysis_ol_transformer.ipynb # analyze your trained OL-Transformer and reproduce paper results
65 | 
66 | analysis_optogpt.ipynb # analyze your trained OptoGPT and reproduce paper results
67 | ~~~
68 | ![Workflow](figures/embedding.png)
69 | ## Color to spectrum conversion
70 | 
71 | We noticed that the color-to-spectrum algorithm could be helpful for the structural color research community, so we also put the code here:
72 | ~~~
73 | import numpy as np
74 | from scipy.optimize import minimize
75 | import matplotlib.pyplot as plt
import colour  # added import: the module itself is used below (colour.sd_to_XYZ, colour.XYZ_to_sRGB, ...)
76 | from colour.difference import delta_E, delta_E_CIE2000
77 | from colour import SDS_ILLUMINANTS, SpectralDistribution
78 | from colour.colorimetry import MSDS_CMFS
79 | illuminant = SDS_ILLUMINANTS['D65']
80 | cmfs = MSDS_CMFS['CIE 1931 2 Degree Standard Observer']
81 | 
82 | TARGET = [0, 0.6, 0.6] # target sRGB color
83 | TARGET_LAB = [50, 80, 0] # target CIELAB color
84 | 
85 | wavelengthss = np.arange(400, 1101, 10)
86 | 
87 | def get_rgb_from_spec(spec):
88 |     wavelengthss = np.arange(400, 801, 10)
89 |     data = dict(zip(wavelengthss, spec[:41]))
90 |     sd = colour.SpectralDistribution(data, name='Sample')
91 |     rgb = colour.convert(sd, 'Spectral Distribution', 'sRGB')
92 |     return rgb
93 | 
94 | def get_color(spec):
95 |     # return the xyY, RGB, LAB from spectrum
96 |     wavelengthss = np.arange(400, 801, 10)
97 |     data = dict(zip((wavelengthss).astype('int'), spec[:41]))
98 |     sd = SpectralDistribution(data)
99 |     XYZ = colour.sd_to_XYZ(sd, cmfs, illuminant)
100 |     xyY = colour.XYZ_to_xyY(XYZ)
101 |     Lab = colour.XYZ_to_Lab(XYZ / 100)
102 |     RGB = colour.XYZ_to_sRGB(XYZ / 100)
103 |     return Lab, RGB, xyY
104 | 
105 | def fitness(spec, MODE = 'DeltaE'):
106 |     lab, rgb, xyy = get_color(spec)
107 |     if MODE == 'DeltaE':
108 |         mse = delta_E_CIE2000(lab, TARGET_LAB)
109 |         smoothness = np.square(np.gradient(np.gradient(spec))).mean()
110 |         return mse + 50*smoothness # you can change this alpha factor accordingly
111 |     elif MODE == 'RGB':
112 |         mse = np.mean(np.square(TARGET - rgb))
113 |         smoothness = np.square(np.gradient(np.gradient(spec))).mean()
114 |         return mse + 10*smoothness # you can change this alpha factor accordingly
115 |     else:
116 |         raise NotImplementedError
117 | 
118 | # color to spectrum conversion using optimization
119 | np.random.seed(100)
120 | x0 = np.random.rand((1100-400) // 10 + 1)
121 | bounds = [[0, 1] for _ in range(len(x0))]
122 | res = minimize(fitness, x0, method='SLSQP', options={'disp': True, 'maxiter':100, 'tol':1e-9}, bounds=bounds)
123 | 
124 | # calculate the performance of the converted spectrum
125 | spec_rgb = list(res.x)
126 | lab2, rgb2, xyy2 = get_color(res.x)
127 | mse = np.mean(np.abs(TARGET - rgb2))
128 | deltae = delta_E_CIE2000(lab2, TARGET_LAB)
129 | smoothness = np.abs(np.gradient(np.gradient(res.x))).mean()
130 | print('Designed RGB:', rgb2, '\nDesigned LAB:', lab2, '\nTarget LAB:', TARGET_LAB, '\nDelta E:', deltae, '\nColor MSE:', mse, '\nSpectrum Smoothness:', smoothness)
131 | 
132 | # plot the retrieved spectrum
133 | fig = plt.figure(dpi=120, figsize=[4, 3])
134 | fig.patch.set_facecolor('white')
135 | plt.plot(wavelengthss, res.x, color = np.clip(rgb2, 0, 1))
136 | plt.title(rgb2)
137 | plt.ylim(-0.020, 1.03)
138 | plt.show()
139 | ~~~
140 | ![Workflow](figures/color2spec.png)
141 | ## Project Structure
142 | ~~~
143 | optogpt/optogpt/
144 | |--- core/ # core functions for OptoGPT and OL-Transformer
145 | |--- datasets/
146 | |--- datasets.py # dataset preprocessing
147 | |--- sim.py # multilayer system simulation functions
148 | |--- models/ # model architecture for OptoGPT and OL-Transformer
149 | |--- trains/ # training algorithms
150 | |--- utils/
151 | |--- cie-cmf.txt # color matching function file
152 | |--- color_system.py # helper functions for color conversion
153 | |--- common.py # shared helper functions
154 | |--- datasets/ # please convert the .csv files to .pkl files
155 | |--- Spectrum_train.pkl
156 | |--- Spectrum_dev.pkl
157 | |--- Structure_train.pkl
158 | |--- Structure_dev.pkl
159 | |--- figures/ # repo figures
160 | |--- nk/
161 | |--- multiple .csv files # each file contains refractive index data we measured experimentally
162 | |--- optogpt_data/ # please download it from Hugging Face
163 | |--- analysis_ol_transformer.ipynb # analyze your trained OL-Transformer and reproduce paper results
164 | |--- analysis_optogpt.ipynb # analyze your trained OptoGPT and reproduce paper results
165 | |--- data_conversion.ipynb # data conversion from .csv to .pkl
166 | |--- environment.yml
167 | |--- README.md
168 | |--- run_ol_transformer.py
169 | |--- run_optogpt.py
170 | ~~~
171 | 
172 | To cite this work:
173 | ~~~
174 | @article{ma2024optogpt,
175 | title={OptoGPT: a foundation model for inverse design in optical multilayer thin film structures},
176 | author={Ma, Taigao and Wang, Haozhu and Guo, L Jay},
177 | journal={Opto-Electronic Advances},
178 | volume={7},
179 | number={7},
180 | year={2024},
181 | publisher={Opto-Electronic Advance}
182 | }
183 | 
184 | @article{ma2023ol,
185 | title={OL-Transformer: A Fast and Universal Surrogate Simulator for Optical Multilayer Thin Film Structures},
186 | author={Ma, Taigao and Wang, Haozhu and Guo, L Jay},
187 | journal={arXiv preprint arXiv:2305.11984},
188 | year={2023}
189 | }
190 | ~~~
191 | 
--------------------------------------------------------------------------------
/optogpt/core/trains/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import copy
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from nltk import word_tokenize
10 | from collections import Counter
11 | from torch.autograd import Variable
12 | import seaborn as sns
13 | import matplotlib.pyplot as plt
14 | import pickle as pkl
15 | from torch.nn.parallel import DistributedDataParallel as DDP
16 | import torch.distributed as dist
17 | 
18 | 
19 | 
20 | class LabelSmoothing(nn.Module):
21 |     "Implement label smoothing."
22 | def __init__(self, size, padding_idx, smoothing=0.0): 23 | super(LabelSmoothing, self).__init__() 24 | self.criterion = nn.KLDivLoss(reduction='sum') # 2020 update 25 | self.padding_idx = padding_idx 26 | self.confidence = 1.0 - smoothing 27 | self.smoothing = smoothing 28 | self.size = size 29 | self.true_dist = None 30 | 31 | def forward(self, x, target): 32 | assert x.size(1) == self.size 33 | true_dist = x.data.clone() 34 | true_dist.fill_(self.smoothing / (self.size - 2)) 35 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 36 | true_dist[:, self.padding_idx] = 0 37 | mask = torch.nonzero(target.data == self.padding_idx) 38 | if mask.dim() > 0: 39 | true_dist.index_fill_(0, mask.squeeze(), 0.0) 40 | self.true_dist = true_dist 41 | return self.criterion(x, Variable(true_dist, requires_grad=False)) 42 | 43 | class SimpleLossCompute: 44 | def __init__(self, generator, criterion, opt=None): 45 | self.generator = generator 46 | self.criterion = criterion 47 | self.opt = opt 48 | 49 | def __call__(self, x, y, norm): 50 | x = self.generator(x) 51 | loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 52 | y.contiguous().view(-1)) / norm 53 | loss.backward() 54 | if self.opt is not None: 55 | self.opt.step() 56 | self.opt.optimizer.zero_grad() 57 | return loss.data.item() * norm.float() 58 | 59 | # We used factor=2, warmup-step = 4000 60 | def get_std_opt(model): 61 | return NoamOpt(model.src_embed[0].d_model, 2, 4000, 62 | torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) 63 | 64 | 65 | class NoamOpt: 66 | "Optim wrapper that implements rate." 67 | def __init__(self, model_size, factor, warmup, optimizer): 68 | self.optimizer = optimizer 69 | self._step = 0 70 | self.warmup = warmup 71 | self.factor = factor 72 | self.model_size = model_size 73 | self._rate = 0 74 | 75 | def step(self): 76 | "Update parameters and rate" 77 | self._step += 1 78 | rate = self.rate() 79 | for p in self.optimizer.param_groups: 80 | p['lr'] = rate 81 | self._rate = rate 82 | self.optimizer.step() 83 | 84 | def rate(self, step = None): 85 | "Implement `lrate` above" 86 | if step is None: 87 | step = self._step 88 | return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))) 89 | 90 | 91 | def run_epoch(data, model, criterion, optimizer, epoch, DEVICE): 92 | start = time.time() 93 | total_tokens = 0. 94 | total_loss = 0. 95 | tokens = 0. 
96 | for i , batch in enumerate(data): 97 | out = model(batch.src.to(DEVICE), batch.src_mask.to(DEVICE)) 98 | # print(out.size(), batch.trg.size()) 99 | loss = criterion(out, batch.trg.to(DEVICE)) 100 | 101 | if optimizer is not None: 102 | loss.backward() 103 | optimizer.step() 104 | optimizer.optimizer.zero_grad() 105 | 106 | total_loss += loss 107 | total_tokens += batch.ntokens 108 | tokens += batch.ntokens 109 | if i % 50 == 1: 110 | elapsed = time.time() - start 111 | print("Epoch {:d} Batch: {:d} Loss: {:.4f} Tokens per Sec: {:.2f}s".format(epoch, i - 1, loss, (tokens.float() / elapsed))) 112 | start = time.time() 113 | tokens = 0 114 | del out, loss 115 | print(total_loss, i) 116 | return total_loss/i 117 | 118 | def count_params(model): 119 | 120 | return sum([np.prod(layer.size()) for layer in model.parameters() if layer.requires_grad]) 121 | 122 | def save_checkpoint(model, optimizer, epoch, loss_all, path, configs): 123 | # save the saved file 124 | torch.save({ 125 | 'epoch': epoch, 126 | 'model_state_dict': model.state_dict(), 127 | 'optimizer_state_dict': optimizer, 128 | 'loss_all':loss_all, 129 | 'configs':configs, 130 | # 'seed':seed, 131 | }, path) 132 | 133 | 134 | def train(data, model, criterion, optimizer, configs, DEVICE): 135 | """ 136 | Train and Save the model. 137 | """ 138 | # init loss as a large value 139 | best_dev_loss = 1e5 140 | loss_all = {'train_loss':[], 'dev_loss':[]} 141 | 142 | save_folder = configs.save_folder 143 | save_name = configs.save_name 144 | EPOCHS = configs.epochs 145 | 146 | for epoch in range(EPOCHS): 147 | # Train model 148 | model.train() 149 | train_loss = run_epoch(data.train_data, model, criterion, optimizer, epoch, DEVICE) 150 | 151 | # validate model on dev dataset 152 | 153 | model.eval() 154 | print('>>>>> Evaluate') 155 | with torch.no_grad(): 156 | dev_loss = run_epoch(data.dev_data, model, criterion, None, epoch, DEVICE) 157 | print('<<<<< Evaluate loss: {:.8f}'.format(dev_loss)) 158 | loss_all['train_loss'].append(train_loss.detach()) 159 | loss_all['dev_loss'].append(dev_loss.detach()) 160 | 161 | # save the model with best-dev-loss 162 | if dev_loss < best_dev_loss: 163 | best_dev_loss = dev_loss 164 | save_checkpoint(model, optimizer, epoch, loss_all, 'saved_models/ol_transformer/'+save_folder+'/'+save_name+'_best.pt', configs) 165 | print('Saved') 166 | if epoch%2 == 1: 167 | best_dev_loss = dev_loss 168 | save_checkpoint(model, optimizer, epoch, loss_all, 'saved_models/ol_transformer/'+save_folder+'/'+save_name+'_recent.pt', configs) 169 | 170 | print(f">>>>> current best loss: ", best_dev_loss) 171 | 172 | 173 | def run_epoch_I(data, model, loss_compute, epoch, DEVICE): 174 | start = time.time() 175 | total_tokens = 0. 176 | total_loss = 0. 177 | tokens = 0. 178 | for i , batch in enumerate(data): 179 | out = model(batch.src.to(DEVICE), batch.trg.to(DEVICE), batch.src_mask, batch.trg_mask.to(DEVICE)) 180 | loss = loss_compute(out, batch.trg_y.to(DEVICE), batch.ntokens.to(DEVICE)) 181 | total_loss += loss 182 | total_tokens += batch.ntokens 183 | tokens += batch.ntokens 184 | if i % 50 == 1: 185 | elapsed = time.time() - start 186 | print("Epoch {:d} Batch: {:d} Loss: {:.4f} Tokens per Sec: {:.2f}s".format(epoch, i - 1, loss / batch.ntokens, (tokens.float() / elapsed ))) 187 | start = time.time() 188 | tokens = 0 189 | del out, loss 190 | 191 | return total_loss / total_tokens 192 | 193 | 194 | def train_I(data, model, criterion, optimizer, configs, DEVICE): 195 | """ 196 | Train and Save the model. 
197 | """ 198 | # init loss as a large value 199 | best_dev_loss = 1e5 200 | loss_all = {'train_loss':[], 'dev_loss':[]} 201 | 202 | save_folder = configs.save_folder 203 | save_name = configs.save_name 204 | EPOCHS = configs.epochs 205 | 206 | for epoch in range(EPOCHS): 207 | # Train model 208 | model.train() 209 | train_loss = run_epoch_I(data.train_data, model, SimpleLossCompute(model.generator, criterion, optimizer), epoch, DEVICE) 210 | model.eval() 211 | 212 | # validate model on dev dataset 213 | print('>>>>> Evaluate') 214 | dev_loss = run_epoch_I(data.dev_data, model, SimpleLossCompute(model.generator, criterion, None), epoch, DEVICE) 215 | print('<<<<< Evaluate loss: {:.2f}'.format(dev_loss)) 216 | loss_all['train_loss'].append(train_loss.detach()) 217 | loss_all['dev_loss'].append(dev_loss.detach()) 218 | 219 | # save the model with best-dev-loss 220 | 221 | if dev_loss < best_dev_loss: 222 | best_dev_loss = dev_loss 223 | save_checkpoint(model, optimizer, epoch, loss_all, 'saved_models/optogpt/'+save_folder+'/'+save_name+'_best.pt', configs) 224 | 225 | save_checkpoint(model, optimizer, epoch, loss_all, 'saved_models/optogpt/'+save_folder+'/'+save_name+'_recent.pt', configs) 226 | 227 | print(f">>>>> current best loss: {best_dev_loss}") 228 | -------------------------------------------------------------------------------- /optogpt/nk/Al_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 1.2399E-04,9.9999E-01,8.2410E-08 3 | 1.3051E-04,9.9999E-01,1.2720E-07 4 | 1.3776E-04,9.9999E-01,1.2488E-07 5 | 1.5498E-04,9.9999E-01,1.9173E-07 6 | 1.7712E-04,9.9999E-01,3.1463E-07 7 | 2.0664E-04,9.9998E-01,5.5102E-07 8 | 2.4797E-04,9.9998E-01,1.0726E-06 9 | 3.0996E-04,9.9997E-01,2.4843E-06 10 | 3.5424E-04,9.9995E-01,4.0484E-06 11 | 4.1328E-04,9.9994E-01,7.1408E-06 12 | 4.9594E-04,9.9991E-01,1.3642E-05 13 | 6.1993E-04,9.9987E-01,3.2079E-05 14 | 6.5255E-04,9.9986E-01,3.8497E-05 15 | 6.8881E-04,9.9985E-01,4.7138E-05 16 | 7.2932E-04,9.9984E-01,5.7974E-05 17 | 7.4021E-04,9.9984E-01,6.1633E-05 18 | 7.7734E-04,9.9986E-01,7.4862E-05 19 | 7.8472E-04,9.9987E-01,7.8480E-05 20 | 7.9274E-04,9.9990E-01,8.5418E-05 21 | 7.9376E-04,9.9992E-01,7.1847E-05 22 | 7.9478E-04,9.9993E-01,4.9998E-05 23 | 7.9529E-04,9.9992E-01,4.1131E-05 24 | 7.9580E-04,9.9992E-01,3.2911E-05 25 | 7.9785E-04,9.9991E-01,2.3687E-05 26 | 7.9887E-04,9.9991E-01,1.1026E-05 27 | 7.9990E-04,9.9989E-01,6.0484E-06 28 | 8.0510E-04,9.9987E-01,6.1901E-06 29 | 8.1569E-04,9.9985E-01,6.4818E-06 30 | 8.2657E-04,9.9983E-01,6.8049E-06 31 | 8.5507E-04,9.9981E-01,7.8422E-06 32 | 8.8561E-04,9.9979E-01,8.9400E-06 33 | 9.5373E-04,9.9974E-01,1.1636E-05 34 | 1.0332E-03,9.9968E-01,1.5628E-05 35 | 1.1271E-03,9.9961E-01,2.1586E-05 36 | 1.2399E-03,9.9953E-01,3.1115E-05 37 | 1.3776E-03,9.9940E-01,4.6771E-05 38 | 1.5498E-03,9.9924E-01,7.2992E-05 39 | 1.7712E-03,9.9898E-01,1.1811E-04 40 | 2.0664E-03,9.9860E-01,2.2001E-04 41 | 2.4797E-03,9.9797E-01,4.3503E-04 42 | 3.0996E-03,9.9694E-01,9.6887E-04 43 | 4.1328E-03,9.9480E-01,2.3492E-03 44 | 4.9594E-03,9.9313E-01,4.1863E-03 45 | 6.1993E-03,9.9111E-01,7.5099E-03 46 | 6.5255E-03,9.9054E-01,8.4716E-03 47 | 6.8881E-03,9.9007E-01,9.6517E-03 48 | 7.2932E-03,9.8909E-01,1.0987E-02 49 | 7.7491E-03,9.8912E-01,1.3728E-02 50 | 8.2657E-03,9.8966E-01,1.4773E-02 51 | 8.5507E-03,9.8934E-01,1.5437E-02 52 | 8.8561E-03,9.8883E-01,1.6304E-02 53 | 9.1841E-03,9.8793E-01,1.7765E-02 54 | 9.5373E-03,9.8761E-01,2.0606E-02 55 | 9.9188E-03,9.8941E-01,2.3421E-02 56 | 1.0332E-02,9.9139E-01,2.4063E-02 
57 | 1.0781E-02,9.9233E-01,2.4928E-02 58 | 1.1271E-02,9.9415E-01,2.5452E-02 59 | 1.1808E-02,9.9285E-01,2.4415E-02 60 | 1.2399E-02,9.9123E-01,2.9920E-02 61 | 1.2652E-02,9.9265E-01,3.3061E-02 62 | 1.2915E-02,9.9652E-01,3.5883E-02 63 | 1.3190E-02,1.0012E+00,3.5311E-02 64 | 1.3477E-02,1.0041E+00,3.3392E-02 65 | 1.3776E-02,1.0058E+00,3.0918E-02 66 | 1.4089E-02,1.0060E+00,2.8956E-02 67 | 1.4417E-02,1.0065E+00,2.8232E-02 68 | 1.4760E-02,1.0074E+00,2.6826E-02 69 | 1.5120E-02,1.0077E+00,2.5460E-02 70 | 1.5498E-02,1.0075E+00,2.4476E-02 71 | 1.5694E-02,1.0075E+00,2.4501E-02 72 | 1.5896E-02,1.0078E+00,2.4757E-02 73 | 1.6102E-02,1.0095E+00,2.5180E-02 74 | 1.6314E-02,1.0106E+00,2.3853E-02 75 | 1.6422E-02,1.0110E+00,2.3955E-02 76 | 1.6531E-02,1.0118E+00,2.4020E-02 77 | 1.6642E-02,1.0132E+00,2.4343E-02 78 | 1.6755E-02,1.0151E+00,2.4184E-02 79 | 1.6777E-02,1.0156E+00,2.4228E-02 80 | 1.6800E-02,1.0161E+00,2.4282E-02 81 | 1.6823E-02,1.0167E+00,2.4375E-02 82 | 1.6846E-02,1.0173E+00,2.4499E-02 83 | 1.6869E-02,1.0181E+00,2.4831E-02 84 | 1.6892E-02,1.0194E+00,2.5432E-02 85 | 1.6915E-02,1.0219E+00,2.6018E-02 86 | 1.6938E-02,1.0259E+00,2.4227E-02 87 | 1.6961E-02,1.0262E+00,1.9564E-02 88 | 1.6984E-02,1.0246E+00,1.9145E-02 89 | 1.7008E-02,1.0255E+00,2.0012E-02 90 | 1.7031E-02,1.0305E+00,2.0072E-02 91 | 1.7054E-02,1.0349E+00,1.2476E-02 92 | 1.7078E-02,1.0305E+00,4.1164E-03 93 | 1.7101E-02,1.0249E+00,3.6218E-03 94 | 1.7125E-02,1.0226E+00,3.4141E-03 95 | 1.7149E-02,1.0206E+00,3.4400E-03 96 | 1.7172E-02,1.0192E+00,3.4767E-03 97 | 1.7196E-02,1.0179E+00,3.4877E-03 98 | 1.7208E-02,1.0174E+00,3.5108E-03 99 | 1.7220E-02,1.0169E+00,3.5249E-03 100 | 1.7463E-02,1.0108E+00,3.4957E-03 101 | 1.7712E-02,1.0070E+00,3.5425E-03 102 | 1.8233E-02,1.0019E+00,3.6730E-03 103 | 1.8786E-02,9.9791E-01,3.8926E-03 104 | 1.9373E-02,9.9457E-01,4.1092E-03 105 | 1.9998E-02,9.9143E-01,4.2397E-03 106 | 2.0664E-02,9.8827E-01,4.3696E-03 107 | 2.2543E-02,9.7998E-01,5.0004E-03 108 | 2.4797E-02,9.7048E-01,5.7469E-03 109 | 2.7552E-02,9.5834E-01,6.6191E-03 110 | 3.0996E-02,9.4189E-01,7.8466E-03 111 | 3.5424E-02,9.1802E-01,9.3121E-03 112 | 4.1328E-02,8.8013E-01,1.1651E-02 113 | 4.9594E-02,8.1512E-01,1.5894E-02 114 | 6.1993E-02,6.7912E-01,2.2340E-02 115 | 6.5255E-02,6.3242E-01,2.4770E-02 116 | 6.8881E-02,5.7251E-01,2.7681E-02 117 | 7.2932E-02,4.9131E-01,3.2409E-02 118 | 7.7491E-02,3.7197E-01,4.4202E-02 119 | 7.8472E-02,3.4031E-01,4.8320E-02 120 | 7.9478E-02,3.0373E-01,5.3349E-02 121 | 7.9990E-02,2.8271E-01,5.6697E-02 122 | 8.0510E-02,2.5936E-01,6.1407E-02 123 | 8.1036E-02,2.3344E-01,6.8348E-02 124 | 8.1569E-02,2.0569E-01,7.9959E-02 125 | 8.2109E-02,1.7943E-01,9.4223E-02 126 | 8.2657E-02,1.5065E-01,1.1041E-01 127 | 8.3774E-02,9.4517E-02,1.6589E-01 128 | 8.4921E-02,6.7041E-02,2.3420E-01 129 | 8.6101E-02,5.4863E-02,2.9293E-01 130 | 8.8561E-02,4.4168E-02,3.9115E-01 131 | 9.5373E-02,3.6437E-02,5.9086E-01 132 | 1.0332E-01,3.5753E-02,7.7163E-01 133 | 1.1271E-01,3.8468E-02,9.5677E-01 134 | 1.2399E-01,4.6304E-02,1.1555E+00 135 | 1.3776E-01,5.7167E-02,1.3775E+00 136 | 1.5498E-01,7.2505E-02,1.6366E+00 137 | 1.7712E-01,9.4236E-02,1.9519E+00 138 | 2.0664E-01,1.2677E-01,2.3563E+00 139 | 2.4797E-01,1.8137E-01,2.9029E+00 140 | 3.0996E-01,2.8003E-01,3.7081E+00 141 | 3.2628E-01,3.1474E-01,3.9165E+00 142 | 3.6466E-01,3.9877E-01,4.3957E+00 143 | 4.1328E-01,5.2135E-01,5.0008E+00 144 | 4.4280E-01,6.0790E-01,5.3676E+00 145 | 4.7687E-01,7.2780E-01,5.7781E+00 146 | 5.1660E-01,8.7340E-01,6.2418E+00 147 | 5.6357E-01,1.0728E+00,6.7839E+00 148 | 6.1993E-01,1.3660E+00,7.4052E+00 149 | 
6.5225E-01,1.5724E+00,7.7354E+00 150 | 6.8881E-01,1.8301E+00,8.0601E+00 151 | 7.2932E-01,2.1606E+00,8.3565E+00 152 | 7.7491E-01,2.6154E+00,8.4914E+00 153 | 7.9478E-01,2.7675E+00,8.3866E+00 154 | 8.1569E-01,2.7668E+00,8.2573E+00 155 | 8.3774E-01,2.6945E+00,8.1878E+00 156 | 8.8561E-01,2.2802E+00,8.1134E+00 157 | 9.1166E-01,1.9739E+00,8.3058E+00 158 | 9.3928E-01,1.6784E+00,8.5970E+00 159 | 9.6863E-01,1.4867E+00,9.0655E+00 160 | 9.9988E-01,1.4359E+00,9.4939E+00 161 | 1.0332E+00,1.3998E+00,9.8914E+00 162 | 1.1271E+00,1.3281E+00,1.0969E+01 163 | 1.2399E+00,1.3157E+00,1.2245E+01 164 | 1.3776E+00,1.3899E+00,1.3784E+01 165 | 1.5498E+00,1.5782E+00,1.5656E+01 166 | 1.7712E+00,1.9205E+00,1.7991E+01 167 | 2.0664E+00,2.4738E+00,2.0982E+01 168 | 2.4797E+00,3.3372E+00,2.5004E+01 169 | 2.7552E+00,3.9380E+00,2.7580E+01 170 | 3.0996E+00,4.7097E+00,3.0737E+01 171 | 3.2628E+00,5.0735E+00,3.2183E+01 172 | 3.4440E+00,5.4903E+00,3.3814E+01 173 | 3.6466E+00,5.9564E+00,3.5608E+01 174 | 3.8745E+00,6.4808E+00,3.7595E+01 175 | 4.1328E+00,7.0796E+00,3.9826E+01 176 | 4.4280E+00,7.7757E+00,4.2367E+01 177 | 4.7687E+00,8.5881E+00,4.5257E+01 178 | 5.1660E+00,9.5580E+00,4.8593E+01 179 | 5.6357E+00,1.0742E+01,5.2518E+01 180 | 6.1993E+00,1.2195E+01,5.7156E+01 181 | 6.8881E+00,1.4088E+01,6.2841E+01 182 | 7.7491E+00,1.6755E+01,6.9857E+01 183 | 8.8561E+00,2.0837E+01,7.8274E+01 184 | 1.0332E+01,2.6216E+01,8.8197E+01 185 | 1.2399E+01,3.3519E+01,1.0128E+02 186 | 1.3776E+01,3.8461E+01,1.0896E+02 187 | 1.5498E+01,4.3775E+01,1.1839E+02 188 | 1.7712E+01,5.0951E+01,1.2949E+02 189 | 1.9075E+01,5.4413E+01,1.3609E+02 190 | 2.0664E+01,5.8580E+01,1.4423E+02 191 | 2.2543E+01,6.3554E+01,1.5345E+02 192 | 2.4797E+01,6.8535E+01,1.6481E+02 193 | 2.7552E+01,7.5748E+01,1.8178E+02 194 | 3.0996E+01,9.1955E+01,1.9999E+02 195 | 3.3333E+01,1.0210E+02,2.0810E+02 196 | 4.0000E+01,1.2514E+02,2.3019E+02 197 | 4.4444E+01,1.4005E+02,2.4343E+02 198 | 5.0000E+01,1.5730E+02,2.5826E+02 199 | 5.7144E+01,1.7793E+02,2.7534E+02 200 | 6.6666E+01,2.0263E+02,2.9542E+02 201 | 8.0001E+01,2.3356E+02,3.2108E+02 202 | 9.9996E+01,2.7438E+02,3.5435E+02 203 | 1.2500E+02,3.1881E+02,3.9171E+02 204 | 1.3776E+02,3.3962E+02,4.0892E+02 205 | 1.5385E+02,3.6404E+02,4.2962E+02 206 | 1.7712E+02,3.9793E+02,4.5850E+02 207 | 2.0000E+02,4.2396E+02,4.8370E+02 208 | ,, 209 | ,, 210 | ,, 211 | ,, 212 | ,, 213 | ,, 214 | ,, 215 | ,, 216 | ,, 217 | ,, 218 | ,, 219 | ,, 220 | ,, 221 | ,, 222 | ,, 223 | ,, 224 | ,, 225 | ,, 226 | ,, 227 | ,, 228 | ,, 229 | ,, 230 | ,, 231 | ,, 232 | ,, 233 | ,, 234 | ,, 235 | ,, 236 | ,, 237 | ,, 238 | ,, 239 | ,, 240 | ,, 241 | ,, 242 | ,, 243 | ,, 244 | ,, 245 | ,, 246 | ,, 247 | ,, 248 | ,, 249 | ,, 250 | ,, 251 | ,, 252 | ,, 253 | ,, 254 | ,, 255 | ,, 256 | ,, 257 | ,, 258 | ,, 259 | ,, 260 | ,, 261 | ,, 262 | ,, 263 | ,, 264 | ,, 265 | ,, 266 | ,, 267 | ,, 268 | ,, 269 | ,, 270 | ,, 271 | ,, 272 | ,, 273 | ,, 274 | ,, 275 | ,, 276 | ,, 277 | ,, 278 | ,, 279 | ,, 280 | ,, 281 | ,, 282 | ,, 283 | ,, 284 | ,, 285 | ,, 286 | ,, 287 | ,, 288 | ,, 289 | ,, 290 | ,, 291 | ,, 292 | ,, 293 | ,, 294 | ,, 295 | ,, 296 | ,, 297 | ,, 298 | ,, 299 | ,, 300 | ,, 301 | ,, 302 | ,, 303 | ,, 304 | ,, 305 | ,, 306 | ,, 307 | ,, 308 | ,, 309 | ,, 310 | ,, 311 | ,, 312 | ,, 313 | ,, 314 | ,, 315 | ,, 316 | ,, 317 | ,, 318 | ,, 319 | ,, 320 | ,, 321 | ,, 322 | ,, 323 | ,, 324 | ,, 325 | ,, 326 | ,, 327 | ,, 328 | ,, 329 | ,, 330 | ,, 331 | ,, 332 | ,, 333 | ,, 334 | ,, 335 | ,, 336 | ,, 337 | ,, 338 | ,, 339 | ,, 340 | ,, 341 | ,, 342 | ,, 343 | ,, 344 | ,, 345 | ,, 346 | ,, 347 | ,, 
348 | ,, 349 | ,, 350 | ,, 351 | ,, 352 | ,, 353 | ,, 354 | ,, 355 | ,, 356 | ,, 357 | ,, 358 | ,, 359 | ,, 360 | ,, 361 | ,, 362 | ,, 363 | ,, 364 | ,, 365 | ,, 366 | ,, 367 | ,, 368 | ,, 369 | ,, 370 | ,, 371 | ,, 372 | ,, 373 | ,, 374 | ,, 375 | ,, 376 | ,, 377 | ,, 378 | ,, 379 | ,, 380 | ,, 381 | ,, 382 | ,, 383 | ,, 384 | ,, 385 | ,, 386 | ,, 387 | ,, 388 | ,, 389 | ,, 390 | ,, 391 | ,, 392 | ,, 393 | ,, 394 | ,, 395 | ,, 396 | ,, 397 | ,, 398 | ,, 399 | ,, 400 | ,, 401 | ,, 402 | ,, 403 | ,, 404 | ,, 405 | ,, 406 | ,, 407 | ,, 408 | ,, 409 | ,, 410 | ,, 411 | ,, 412 | ,, 413 | ,, 414 | ,, 415 | ,, -------------------------------------------------------------------------------- /self_improving/run_self_improving.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Clean example script for running the self-improving data augmentation pipeline. 4 | This script demonstrates the full workflow without hardcoded paths. 5 | """ 6 | 7 | import argparse 8 | import os 9 | import torch 10 | import pandas as pd 11 | from pathlib import Path 12 | 13 | from prepare_aug_data import Prepare_Augment_Data 14 | from data_perturb import get_to_be_perturbed_data, perturb_data, simulate_perturbed_struct, get_perturbed_better_data 15 | from combine_data import combine_data 16 | from model_retrain import retrain_model 17 | from generate_dev_data import generate_dev_data 18 | from core.datasets.datasets import PrepareData 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Self-Improving Data Augmentation for OptoGPT') 23 | 24 | # Required paths 25 | parser.add_argument('--model_path', type=str, required=True, 26 | help='Path to pretrained OptoGPT model') 27 | parser.add_argument('--train_struct_path', type=str, required=True, 28 | help='Path to training structure data (pkl file)') 29 | parser.add_argument('--train_spec_path', type=str, required=True, 30 | help='Path to training spectrum data (pkl file)') 31 | parser.add_argument('--dev_struct_path', type=str, required=True, 32 | help='Path to development structure data (pkl file)') 33 | parser.add_argument('--dev_spec_path', type=str, required=True, 34 | help='Path to development spectrum data (pkl file)') 35 | 36 | # Output directory 37 | parser.add_argument('--output_dir', type=str, default='./output', 38 | help='Directory to save outputs') 39 | 40 | # Method configurations 41 | parser.add_argument('--decoding_method', type=str, default='TOP-KP Decode_v2', 42 | choices=['Greedy Decode', 'TOP-KP Decode', 'TOP-KP Decode_v2', 'Beam Search'], 43 | help='Decoding method for data generation') 44 | parser.add_argument('--perturbation_method', type=str, default='GA_PSO', 45 | choices=['random', 'PSO', 'GA_PSO'], 46 | help='Perturbation method') 47 | parser.add_argument('--error_type', type=str, default='MSE', 48 | choices=['MAE', 'MSE'], 49 | help='Error metric type') 50 | 51 | # Hyperparameters 52 | parser.add_argument('--give_up_threshold', type=float, default=3.0, 53 | help='Maximum error threshold for structures') 54 | parser.add_argument('--kp_num', type=int, default=50, 55 | help='Number of decoding attempts per spectrum') 56 | parser.add_argument('--keep_num', type=int, default=20, 57 | help='Number of best structures to keep') 58 | parser.add_argument('--target_aug_size', type=int, default=200000, 59 | help='Target size for augmented dataset') 60 | 61 | # Training parameters 62 | parser.add_argument('--epochs', type=int, default=10, 63 | help='Number of 
retraining epochs') 64 | parser.add_argument('--early_stopping_patience', type=int, default=10, 65 | help='Early stopping patience') 66 | 67 | # GPU settings 68 | parser.add_argument('--gpu', type=int, default=0, 69 | help='GPU device ID (-1 for CPU)') 70 | parser.add_argument('--num_workers', type=int, default=4, 71 | help='Number of CPU workers for multiprocessing') 72 | 73 | return parser.parse_args() 74 | 75 | 76 | def setup_directories(output_dir): 77 | """Create necessary output directories.""" 78 | os.makedirs(output_dir, exist_ok=True) 79 | os.makedirs(os.path.join(output_dir, 'augmented_data'), exist_ok=True) 80 | os.makedirs(os.path.join(output_dir, 'models'), exist_ok=True) 81 | os.makedirs(os.path.join(output_dir, 'logs'), exist_ok=True) 82 | 83 | 84 | def main(): 85 | args = parse_args() 86 | 87 | # Set device 88 | if args.gpu >= 0 and torch.cuda.is_available(): 89 | device = torch.device(f'cuda:{args.gpu}') 90 | else: 91 | device = torch.device('cpu') 92 | 93 | # Override hardcoded devices in modules 94 | import prepare_aug_data 95 | import data_perturb 96 | prepare_aug_data.DEVICE = device 97 | data_perturb.DEVICE = device 98 | 99 | # Set multiprocessing workers 100 | import multiprocessing 101 | if args.num_workers > 0: 102 | multiprocessing.set_start_method('spawn', force=True) 103 | 104 | # Setup directories 105 | setup_directories(args.output_dir) 106 | 107 | print(f"Running self-improving data augmentation...") 108 | print(f"Device: {device}") 109 | print(f"Output directory: {args.output_dir}") 110 | 111 | # Step 1: Generate development data 112 | print("\n1. Generating development data...") 113 | dev_spec = generate_dev_data() 114 | print(f"Generated {len(dev_spec)} development spectra") 115 | 116 | # Step 2: Prepare augmentation data 117 | print("\n2. Preparing augmentation data...") 118 | augment_data = Prepare_Augment_Data( 119 | model_path=args.model_path, 120 | decoding_method=args.decoding_method, 121 | error_type=args.error_type, 122 | top_k=10, 123 | top_p=0.9, 124 | kp_num=args.kp_num, 125 | keep_num=args.keep_num 126 | ) 127 | 128 | # Step 3: Get data to be perturbed 129 | print("\n3. Selecting data for perturbation...") 130 | perturbed_df = get_to_be_perturbed_data(augment_data, args.give_up_threshold) 131 | print(f"Selected {len(perturbed_df)} structures for perturbation") 132 | 133 | # Step 4: Perturb data 134 | print(f"\n4. Perturbing data using {args.perturbation_method}...") 135 | perturbed_df = perturb_data(perturbed_df, method=args.perturbation_method) 136 | 137 | # Save intermediate results 138 | perturbed_df.to_pickle(os.path.join(args.output_dir, 'augmented_data', 'perturbed_data.pkl')) 139 | 140 | # Step 5: Simulate perturbed structures 141 | print("\n5. Simulating perturbed structures...") 142 | perturbed_df = simulate_perturbed_struct(perturbed_df, error_type=args.error_type) 143 | 144 | # Step 6: Filter better data 145 | print("\n6. 
Filtering improved structures...") 146 | added_data = get_perturbed_better_data(perturbed_df) 147 | initial_size = len(added_data) 148 | print(f"Found {initial_size} improved structures") 149 | 150 | # Remove duplicates 151 | added_data = added_data.drop_duplicates(subset=['new_error']) 152 | 153 | # Duplicate to reach target size 154 | while len(added_data) < args.target_aug_size: 155 | added_data = pd.concat([added_data, added_data], ignore_index=True) 156 | added_data = added_data[:args.target_aug_size] 157 | 158 | # Save augmented data 159 | added_data.to_pickle(os.path.join(args.output_dir, 'augmented_data', 'added_data.pkl')) 160 | print(f"Augmented dataset size: {len(added_data)}") 161 | 162 | # Step 7: Combine data 163 | print("\n7. Combining original and augmented data...") 164 | new_train_spec_path, new_train_struct_path, new_test_spec_path, new_test_struct_path = combine_data( 165 | args.train_spec_path, 166 | args.train_struct_path, 167 | args.dev_spec_path, 168 | args.dev_struct_path, 169 | added_data, 170 | ratio=0.1, 171 | type_T=args.decoding_method 172 | ) 173 | 174 | # Move combined data to output directory 175 | import shutil 176 | for old_path, new_name in [ 177 | (new_train_spec_path, 'train_spec_augmented.pkl'), 178 | (new_train_struct_path, 'train_struct_augmented.pkl'), 179 | (new_test_spec_path, 'test_spec_augmented.pkl'), 180 | (new_test_struct_path, 'test_struct_augmented.pkl') 181 | ]: 182 | new_path = os.path.join(args.output_dir, 'augmented_data', new_name) 183 | shutil.move(old_path, new_path) 184 | if 'train_spec' in new_name: 185 | new_train_spec_path = new_path 186 | elif 'train_struct' in new_name: 187 | new_train_struct_path = new_path 188 | elif 'test_spec' in new_name: 189 | new_test_spec_path = new_path 190 | elif 'test_struct' in new_name: 191 | new_test_struct_path = new_path 192 | 193 | # Step 8: Load original model configuration 194 | print("\n8. Loading model configuration...") 195 | model_checkpoint = torch.load(args.model_path, map_location=device) 196 | original_args = model_checkpoint['configs'] 197 | 198 | # Get vocabulary from original data 199 | data = PrepareData( 200 | args.train_struct_path, 201 | args.train_spec_path, 202 | original_args.ratios, 203 | args.dev_struct_path, 204 | args.dev_spec_path, 205 | original_args.batch_size, 206 | original_args.spec_type, 207 | 'Inverse' 208 | ) 209 | struct_word_dict = data.struc_word_dict 210 | struct_index_dict = data.struc_index_dict 211 | 212 | # Step 9: Retrain model 213 | print(f"\n9. 
Retraining model for {args.epochs} epochs...") 214 | 215 | # Override device selection in retrain function 216 | import model_retrain 217 | model_retrain.DEVICE = device 218 | 219 | # Models directory 220 | models_dir = os.path.join(args.output_dir, 'models') 221 | 222 | retrain_model( 223 | args.model_path, 224 | new_train_struct_path, 225 | new_train_spec_path, 226 | new_test_spec_path, 227 | new_test_struct_path, 228 | args.epochs, 229 | args.early_stopping_patience, 230 | struct_word_dict, 231 | struct_index_dict, 232 | args.decoding_method, 233 | device=device, 234 | output_dir=models_dir 235 | ) 236 | 237 | print(f"\n✓ Self-improving data augmentation completed!") 238 | print(f"Results saved to: {args.output_dir}") 239 | print(f"- Augmented data: {os.path.join(args.output_dir, 'augmented_data')}") 240 | print(f"- Trained models: {models_dir}") 241 | 242 | 243 | if __name__ == '__main__': 244 | main() -------------------------------------------------------------------------------- /self_improving/generate_dev_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to generate development data for self-improving data augmentation. 3 | This creates diverse spectrum patterns to test the model's ability to handle 4 | out-of-distribution data. 5 | """ 6 | import sys 7 | sys.path.append('../optogpt') 8 | import os 9 | import numpy as np 10 | import torch 11 | from core.datasets.sim import load_materials, inc_tmm 12 | import math 13 | 14 | # Materials for multilayer thin film structures 15 | MATERIALS = ['Al', 'Ag', 'Al2O3', 'AlN', 'Ge', 'HfO2', 'ITO', 'MgF2', 'MgO', 16 | 'Si', 'Si3N4', 'SiO2', 'Ta2O5', 'TiN', 'TiO2', 'ZnO', 'ZnS', 17 | 'ZnSe', 'Glass_Substrate'] 18 | 19 | # Thickness range for materials (in nm) 20 | THICKNESSES = [str(i) for i in range(10, 505, 10)] 21 | 22 | # Wavelength range for simulations (in micrometers) 23 | LAMBDA_LOW = 0.4 24 | LAMBDA_HIGH = 1.1 25 | WAVELENGTHS = np.arange(LAMBDA_LOW, LAMBDA_HIGH+1e-3, 0.01) 26 | 27 | # Load material refractive index data 28 | NK_DICT = load_materials(all_mats=MATERIALS, wavelengths=WAVELENGTHS, DATABASE='../optogpt/nk') 29 | 30 | pi = math.pi 31 | 32 | def spectrum(materials, thickness, pol='s', theta=0, wavelengths=WAVELENGTHS, 33 | nk_dict=NK_DICT, substrate='Glass_Substrate', substrate_thick=500000): 34 | """ 35 | Calculate reflection and transmission spectra for multilayer thin film structure. 
36 | 37 | Args: 38 | materials: List of material names 39 | thickness: List of layer thicknesses in nm 40 | pol: Polarization ('s' or 'p') 41 | theta: Incidence angle in degrees 42 | wavelengths: Array of wavelengths in micrometers 43 | nk_dict: Dictionary of refractive indices 44 | substrate: Substrate material name 45 | substrate_thick: Substrate thickness in nm 46 | 47 | Returns: 48 | List of reflection and transmission values at each wavelength 49 | """ 50 | degree = pi/180 51 | theta = theta * degree 52 | wavess = (1e3 * wavelengths).astype('int') 53 | 54 | thickness = [np.inf] + thickness + [substrate_thick, np.inf] 55 | 56 | R, T = [], [] 57 | inc_list = ['i'] + ['c']*len(materials) + ['i', 'i'] 58 | for i, lambda_vac in enumerate(wavess): 59 | n_list = [1] + [nk_dict[mat][i] for mat in materials] + [nk_dict[substrate][i], 1] 60 | res = inc_tmm(pol, n_list, thickness, inc_list, theta, lambda_vac) 61 | R.append(res['R']) 62 | T.append(res['T']) 63 | 64 | return R + T 65 | 66 | def sigmoid(x, x0, k): 67 | return 1 / (1 + np.exp(-k * (x - x0))) 68 | 69 | def smooth_pulse_function(start, end=0, steepness=50, reverse=False, wavelength=WAVELENGTHS): 70 | """ 71 | Create a smooth pulse function using sigmoid. 72 | 73 | Args: 74 | start: Starting wavelength 75 | end: Ending wavelength (not used in current implementation) 76 | steepness: Steepness of the sigmoid transition 77 | reverse: Whether to invert the function 78 | wavelength: Array of wavelengths 79 | 80 | Returns: 81 | Array of spectral values 82 | """ 83 | # Sigmoid rise 84 | rising_edge = sigmoid(wavelength, start, steepness) 85 | 86 | if reverse: 87 | return 1 - rising_edge # return the reverse of the smooth pulse function 88 | return rising_edge 89 | 90 | def gaussian_spec(center, std, peak, b=0, if_reverse=False, wavelength=WAVELENGTHS): 91 | """ 92 | Generate a Gaussian-shaped spectrum. 93 | 94 | Args: 95 | center: Center wavelength of the Gaussian 96 | std: Standard deviation (width) of the Gaussian 97 | peak: Peak amplitude 98 | b: Baseline value 99 | if_reverse: Whether to invert the spectrum 100 | wavelength: Array of wavelengths 101 | 102 | Returns: 103 | Array of spectral values 104 | """ 105 | spec = [] 106 | for i in range(len(wavelength)): 107 | temp = np.round(peak * np.exp(-0.5*((center-wavelength[i])/std)**2), 3) 108 | temp = max(temp, b) 109 | if if_reverse: 110 | spec.append(1 - temp) 111 | else: 112 | spec.append(temp) 113 | return spec 114 | 115 | def double_gaussian_spec(center1, std1, peak1, center2, std2, peak2, 116 | b=0, if_reverse=False, wavelength=WAVELENGTHS): 117 | """ 118 | Generate a spectrum with two Gaussian peaks. 
119 | 120 | Args: 121 | center1, center2: Center wavelengths of the two Gaussians 122 | std1, std2: Standard deviations of the two Gaussians 123 | peak1, peak2: Peak amplitudes of the two Gaussians 124 | b: Baseline value 125 | if_reverse: Whether to invert the spectrum 126 | wavelength: Array of wavelengths 127 | 128 | Returns: 129 | Array of spectral values 130 | """ 131 | # Initialize the spectrum array 132 | spec = np.zeros(len(wavelength)) 133 | 134 | # Calculate the first Gaussian curve 135 | for i in range(len(wavelength)): 136 | temp1 = peak1 * np.exp(-0.5 * ((center1 - wavelength[i]) / std1) ** 2) 137 | temp1 = max(temp1, b) # Ensure the value is not below the baseline 138 | spec[i] += temp1 139 | 140 | # Add the second Gaussian curve 141 | for i in range(len(wavelength)): 142 | temp2 = peak2 * np.exp(-0.5 * ((center2 - wavelength[i]) / std2) ** 2) 143 | temp2 = max(temp2, b) # Ensure the value is not below the baseline 144 | spec[i] += temp2 145 | 146 | # If reverse is true, invert the spectrum 147 | if if_reverse: 148 | spec = 1 - spec 149 | 150 | # Round the spectrum values 151 | spec = np.round(spec, 3) 152 | return spec 153 | 154 | def generate_dbr_spectra(all_spec): 155 | """ 156 | Generate spectra for Distributed Bragg Reflector (DBR) structures. 157 | 158 | Args: 159 | all_spec: List to append the generated spectra to 160 | 161 | Returns: 162 | Updated list with DBR spectra added 163 | """ 164 | # Center wavelength ranges (nm) 165 | centers = [550, 850, 1050] # Different center wavelengths 166 | 167 | # Material combinations for DBR structures 168 | dbr_combinations = [ 169 | # Material pairs and thickness ratios for 550nm center 170 | ('ZnO', 'Al2O3', 60, 80), 171 | ('ZnO', 'MgF2', 60, 100), 172 | ('ZnO', 'SiO2', 60, 90), 173 | ('TiO2', 'Al2O3', 50, 80), 174 | ('TiO2', 'MgF2', 50, 100), 175 | ('TiO2', 'SiO2', 50, 90), 176 | ('ZnS', 'Al2O3', 50, 80), 177 | ('ZnS', 'MgF2', 50, 100), 178 | ('ZnS', 'SiO2', 50, 90), 179 | ('ZnSe', 'Al2O3', 50, 80), 180 | ('ZnSe', 'MgF2', 50, 100), 181 | ('ZnSe', 'SiO2', 50, 90) 182 | ] 183 | 184 | # Scale factors for different center wavelengths (550nm -> 1x, 850nm -> ~1.5x, 1050nm -> ~1.9x) 185 | scale_factors = [1.0, 1.5, 1.9] 186 | 187 | for center_idx, scale in enumerate(scale_factors): 188 | for mat1, mat2, thick1, thick2 in dbr_combinations: 189 | # Scale thicknesses according to center wavelength 190 | scaled_thick1 = int(thick1 * scale) 191 | scaled_thick2 = int(thick2 * scale) 192 | 193 | unit_struct = [mat1, mat2] 194 | unit_thick = [scaled_thick1, scaled_thick2] 195 | 196 | # Create DBR structures with 5-10 repeated unit cells 197 | for rep in range(5, 11): 198 | dbr_struct = unit_struct * rep 199 | dbr_thick = unit_thick * rep 200 | all_spec.append(np.array(spectrum(dbr_struct, dbr_thick))) 201 | 202 | return all_spec 203 | 204 | def generate_dev_data(): 205 | """ 206 | Generate diverse spectrum patterns for testing. 
207 | 208 | Returns: 209 | List of spectra for testing 210 | """ 211 | all_spec = [] 212 | 213 | # Generate Gaussian spectra 214 | centers = [0.5, 0.6, 0.7, 0.8, 0.9] 215 | stds = [0.02, 0.03, 0.04, 0.05] 216 | peaks = [1, 0.95, 0.9] 217 | 218 | for c in centers: 219 | for s in stds: 220 | for p in peaks: 221 | # Four variations: normal, inverted, normal with zero T, and zero R with normal T 222 | all_spec.append(np.concatenate((gaussian_spec(c, s, p), gaussian_spec(c, s, p, 0, True)))) 223 | all_spec.append(np.concatenate((gaussian_spec(c, s, p, 0, True), gaussian_spec(c, s, p)))) 224 | all_spec.append(np.concatenate((gaussian_spec(c, s, p), np.zeros(71)))) 225 | all_spec.append(np.concatenate((np.zeros(71), gaussian_spec(c, s, p)))) 226 | 227 | # Generate Double Gaussian spectra 228 | centers1 = [0.4, 0.5] 229 | centers2 = [0.8, 0.9] 230 | stds1 = [0.04, 0.06] 231 | stds2 = [0.04, 0.06] 232 | peaks1 = [1, 0.9] 233 | peaks2 = [1, 0.9] 234 | 235 | for c1 in centers1: 236 | for c2 in centers2: 237 | for s1 in stds1: 238 | for s2 in stds2: 239 | for p1 in peaks1: 240 | for p2 in peaks2: 241 | # Four variations: normal, inverted, normal with zero T, and zero R with normal T 242 | all_spec.append(np.concatenate(( 243 | double_gaussian_spec(c1, s1, p1, c2, s2, p2), 244 | double_gaussian_spec(c1, s1, p1, c2, s2, p2, 0, True) 245 | ))) 246 | all_spec.append(np.concatenate(( 247 | double_gaussian_spec(c1, s1, p1, c2, s2, p2, 0, True), 248 | double_gaussian_spec(c1, s1, p1, c2, s2, p2) 249 | ))) 250 | all_spec.append(np.concatenate(( 251 | double_gaussian_spec(c1, s1, p1, c2, s2, p2), 252 | np.zeros(71) 253 | ))) 254 | all_spec.append(np.concatenate(( 255 | np.zeros(71), 256 | double_gaussian_spec(c1, s1, p1, c2, s2, p2) 257 | ))) 258 | 259 | # Generate smooth pulse functions 260 | starts = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95] 261 | smoothness_values = [35, 40, 45, 50, 55, 60] 262 | 263 | for s in starts: 264 | for sm in smoothness_values: 265 | # Four variations: normal with zero T, zero R with normal T, normal with inverted T, and inverted R with normal T 266 | all_spec.append(np.concatenate(( 267 | smooth_pulse_function(s, 0, sm), 268 | np.zeros(71) 269 | ))) 270 | all_spec.append(np.concatenate(( 271 | np.zeros(71), 272 | smooth_pulse_function(s, 0, sm) 273 | ))) 274 | all_spec.append(np.concatenate(( 275 | smooth_pulse_function(s, 0, sm), 276 | smooth_pulse_function(s, 0, sm, reverse=True) 277 | ))) 278 | all_spec.append(np.concatenate(( 279 | smooth_pulse_function(s, 0, sm, reverse=True), 280 | smooth_pulse_function(s, 0, sm) 281 | ))) 282 | 283 | # Add DBR (Distributed Bragg Reflector) spectra 284 | all_spec = generate_dbr_spectra(all_spec) 285 | 286 | return all_spec 287 | -------------------------------------------------------------------------------- /optogpt/nk/SiO2_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.3,1.491874,0 3 | 0.35131683300000005,1.491874,0 4 | 0.35290734900000004,1.4916639999999999,0 5 | 0.354497894,1.491456,0 6 | 0.35608850099999995,1.491248,0 7 | 0.35767910799999997,1.491042,0 8 | 0.359269714,1.4908379999999999,0 9 | 0.360860352,1.490635,0 10 | 0.362450989,1.490433,0 11 | 0.364041687,1.490233,0 12 | 0.365632385,1.490034,0 13 | 0.367223083,1.489837,0 14 | 0.368813782,1.489641,0 15 | 0.370404541,1.4894459999999998,0 16 | 0.3719953,1.4892530000000002,0 17 | 0.37358606000000005,1.489061,0 18 | 0.375176849,1.4888709999999998,0 19 | 0.37676763900000004,1.488682,0 20 | 0.378358459,1.488494,0 21 | 
0.37994931000000004,1.488308,0 22 | 0.381540161,1.4881229999999999,0 23 | 0.383131042,1.48794,0 24 | 0.38472189300000004,1.487758,0 25 | 0.386312805,1.4875770000000001,0 26 | 0.387903687,1.487398,0 27 | 0.38949459799999997,1.48722,0 28 | 0.39108551,1.4870430000000001,0 29 | 0.392676453,1.486868,0 30 | 0.394267426,1.486695,0 31 | 0.395858398,1.486522,0 32 | 0.397449371,1.486351,0 33 | 0.399040344,1.486181,0 34 | 0.400631348,1.486013,0 35 | 0.402222351,1.485846,0 36 | 0.40381338499999997,1.48568,0 37 | 0.405404388,1.485516,0 38 | 0.406995422,1.4853530000000001,0 39 | 0.40858648700000005,1.485191,0 40 | 0.410177551,1.4850299999999999,0 41 | 0.41176861600000003,1.4848709999999998,0 42 | 0.41335967999999995,1.484713,0 43 | 0.41495074500000007,1.484556,0 44 | 0.41654184,1.4844,0 45 | 0.41813290399999997,1.484246,0 46 | 0.41972402999999997,1.484093,0 47 | 0.421315125,1.483941,0 48 | 0.42290625,1.48379,0 49 | 0.42449737499999995,1.483641,0 50 | 0.42608853099999994,1.483493,0 51 | 0.427679626,1.4833459999999998,0 52 | 0.429270813,1.4832,0 53 | 0.430861908,1.483055,0 54 | 0.432453064,1.482911,0 55 | 0.43404424999999996,1.482769,0 56 | 0.435635345,1.482628,0 57 | 0.437226501,1.482487,0 58 | 0.438817688,1.482348,0 59 | 0.440408844,1.4822110000000002,0 60 | 0.442000061,1.482074,0 61 | 0.443591217,1.481938,0 62 | 0.445182373,1.481803,0 63 | 0.446773529,1.48167,0 64 | 0.44836468499999993,1.481537,0 65 | 0.44995590199999996,1.481406,0 66 | 0.451547058,1.481276,0 67 | 0.45313824500000005,1.4811459999999999,0 68 | 0.45472943099999996,1.481018,0 69 | 0.45632061799999996,1.480891,0 70 | 0.45791177400000005,1.4807649999999999,0 71 | 0.459502991,1.480639,0 72 | 0.46109414699999995,1.480515,0 73 | 0.46268530300000005,1.480392,0 74 | 0.464276459,1.48027,0 75 | 0.46586767599999995,1.480148,0 76 | 0.46745883200000005,1.4800280000000001,0 77 | 0.469049988,1.479909,0 78 | 0.47064117400000005,1.47979,0 79 | 0.47223233,1.479673,0 80 | 0.47382348600000007,1.4795559999999999,0 81 | 0.475414642,1.4794399999999999,0 82 | 0.477005768,1.479326,0 83 | 0.478596924,1.479212,0 84 | 0.48018807999999996,1.479099,0 85 | 0.481779205,1.478987,0 86 | 0.4833703,1.478876,0 87 | 0.484961487,1.4787649999999999,0 88 | 0.486552582,1.478656,0 89 | 0.488143707,1.478548,0 90 | 0.48973480199999997,1.47844,0 91 | 0.491325867,1.478333,0 92 | 0.492916992,1.478227,0 93 | 0.494508057,1.478122,0 94 | 0.496099182,1.478017,0 95 | 0.497690247,1.477914,0 96 | 0.499281311,1.477811,0 97 | 0.500872314,1.477709,0 98 | 0.502463379,1.477608,0 99 | 0.5040543820000001,1.477508,0 100 | 0.505645416,1.477408,0 101 | 0.50723642,1.477309,0 102 | 0.508827454,1.477211,0 103 | 0.510418457,1.477114,0 104 | 0.512009399,1.477017,0 105 | 0.513600403,1.4769219999999998,0 106 | 0.5151913450000001,1.476826,0 107 | 0.516782288,1.476732,0 108 | 0.51837323,1.476638,0 109 | 0.519964172,1.476546,0 110 | 0.521555054,1.476453,0 111 | 0.5231459350000001,1.4763620000000002,0 112 | 0.5247368769999999,1.4762709999999999,0 113 | 0.5263277590000001,1.476181,0 114 | 0.527918579,1.476091,0 115 | 0.529509399,1.4760030000000002,0 116 | 0.53110022,1.4759149999999999,0 117 | 0.53269104,1.475827,0 118 | 0.5342818599999999,1.47574,0 119 | 0.535872681,1.475654,0 120 | 0.5374633790000001,1.4755690000000001,0 121 | 0.539054138,1.475484,0 122 | 0.5406448970000001,1.4754,0 123 | 0.542235596,1.475316,0 124 | 0.543826294,1.475233,0 125 | 0.5454169919999999,1.475151,0 126 | 0.54700769,1.475069,0 127 | 0.5485982669999999,1.474988,0 128 | 0.550188904,1.4749079999999999,0 129 | 0.551779541,1.474828,0 130 | 
0.5533701169999999,1.474748,0 131 | 0.554960693,1.4746700000000001,0 132 | 0.55655127,1.474592,0 133 | 0.5581417849999999,1.4745139999999999,0 134 | 0.5597323,1.474437,0 135 | 0.561322815,1.474361,0 136 | 0.562913269,1.474285,0 137 | 0.5645037230000001,1.474209,0 138 | 0.566094116,1.474135,0 139 | 0.56768457,1.47406,0 140 | 0.5692749629999999,1.4739870000000002,0 141 | 0.570865295,1.473913,0 142 | 0.572455688,1.473841,0 143 | 0.574045959,1.473769,0 144 | 0.575636292,1.473697,0 145 | 0.577226501,1.473626,0 146 | 0.578816772,1.473555,0 147 | 0.5804069820000001,1.473485,0 148 | 0.581997192,1.473416,0 149 | 0.5835873410000001,1.473347,0 150 | 0.585177551,1.473278,0 151 | 0.586767639,1.4732100000000001,0 152 | 0.588357788,1.473143,0 153 | 0.589947815,1.473076,0 154 | 0.5915378419999999,1.473009,0 155 | 0.593127869,1.4729430000000001,0 156 | 0.594717896,1.472877,0 157 | 0.596307861,1.472812,0 158 | 0.597897766,1.472747,0 159 | 0.599487671,1.472683,0 160 | 0.601077576,1.4726190000000001,0 161 | 0.602667419,1.4725549999999998,0 162 | 0.604257202,1.472492,0 163 | 0.605847046,1.47243,0 164 | 0.607436829,1.4723680000000001,0 165 | 0.60902655,1.472306,0 166 | 0.610616272,1.472245,0 167 | 0.6122059329999999,1.472184,0 168 | 0.613795532,1.472124,0 169 | 0.6153851929999999,1.472064,0 170 | 0.6169747309999999,1.472004,0 171 | 0.618564331,1.471945,0 172 | 0.62015387,1.471886,0 173 | 0.621743347,1.4718280000000001,0 174 | 0.623332825,1.47177,0 175 | 0.6249222409999999,1.4717120000000001,0 176 | 0.6265115969999999,1.471655,0 177 | 0.628100952,1.471598,0 178 | 0.6296903079999999,1.471542,0 179 | 0.631279602,1.471486,0 180 | 0.632868835,1.47143,0 181 | 0.634458069,1.471375,0 182 | 0.636047241,1.4713200000000002,0 183 | 0.637636353,1.471266,0 184 | 0.639225464,1.471211,0 185 | 0.6408145749999999,1.471158,0 186 | 0.642403625,1.4711040000000002,0 187 | 0.643992615,1.471051,0 188 | 0.645581604,1.470998,0 189 | 0.647170532,1.4709459999999999,0 190 | 0.648759399,1.4708940000000001,0 191 | 0.6503482669999999,1.470842,0 192 | 0.651937134,1.470791,0 193 | 0.653525879,1.47074,0 194 | 0.6551146240000001,1.470689,0 195 | 0.656703308,1.470639,0 196 | 0.6582919919999999,1.470589,0 197 | 0.6598806150000001,1.470539,0 198 | 0.661469177,1.4704899999999999,0 199 | 0.663057739,1.4704409999999999,0 200 | 0.66464624,1.470392,0 201 | 0.666234741,1.470344,0 202 | 0.66782312,1.4702959999999998,0 203 | 0.6694114990000001,1.470248,0 204 | 0.670999878,1.470201,0 205 | 0.6725881349999999,1.4701540000000002,0 206 | 0.674176392,1.470107,0 207 | 0.675764587,1.4700600000000001,0 208 | 0.677352783,1.470014,0 209 | 0.678940918,1.469968,0 210 | 0.680528992,1.469922,0 211 | 0.682117065,1.469877,0 212 | 0.683705017,1.469832,0 213 | 0.685292969,1.4697870000000002,0 214 | 0.68688092,1.469743,0 215 | 0.68846875,1.469698,0 216 | 0.6900565190000001,1.469654,0 217 | 0.6916443480000001,1.469611,0 218 | 0.6932320559999999,1.469567,0 219 | 0.694819702,1.469524,0 220 | 0.6964073490000001,1.469481,0 221 | 0.6979949339999999,1.469439,0 222 | 0.699582458,1.4693969999999998,0 223 | 0.701169983,1.469355,0 224 | 0.702757446,1.4693129999999999,0 225 | 0.704344849,1.469271,0 226 | 0.705932129,1.46923,0 227 | 0.707519409,1.4691889999999999,0 228 | 0.709106689,1.469148,0 229 | 0.710693909,1.469108,0 230 | 0.712281067,1.469068,0 231 | 0.713868103,1.469028,0 232 | 0.7154552,1.468988,0 233 | 0.717042175,1.468948,0 234 | 0.718629089,1.468909,0 235 | 0.720216003,1.4688700000000001,0 236 | 0.7218028559999999,1.468831,0 237 | 0.723389648,1.468793,0 238 | 
0.724976379,1.4687540000000001,0 239 | 0.72656311,1.468716,0 240 | 0.728149719,1.4686780000000002,0 241 | 0.729736267,1.468641,0 242 | 0.731322815,1.468603,0 243 | 0.7329093019999999,1.468566,0 244 | 0.734495728,1.4685290000000002,0 245 | 0.736082092,1.468493,0 246 | 0.737668335,1.468456,0 247 | 0.739254639,1.46842,0 248 | 0.74084082,1.468384,0 249 | 0.7424269410000001,1.468348,0 250 | 0.744013062,1.468312,0 251 | 0.74559906,1.468277,0 252 | 0.747185059,1.468242,0 253 | 0.748770996,1.468207,0 254 | 0.750356812,1.468172,0 255 | 0.751942688,1.468137,0 256 | 0.753528442,1.4681030000000002,0 257 | 0.755114136,1.468069,0 258 | 0.756699707,1.468035,0 259 | 0.758285278,1.4680010000000001,0 260 | 0.7598707889999999,1.467968,0 261 | 0.761456238,1.467934,0 262 | 0.7630416259999999,1.4679010000000001,0 263 | 0.7646270139999999,1.467868,0 264 | 0.76621228,1.467835,0 265 | 0.7677974849999999,1.467803,0 266 | 0.769382629,1.46777,0 267 | 0.770967712,1.467738,0 268 | 0.772552734,1.467706,0 269 | 0.774137695,1.467674,0 270 | 0.7757225950000001,1.4676420000000001,0 271 | 0.7773074950000001,1.467611,0 272 | 0.7788922119999999,1.4675790000000002,0 273 | 0.780476929,1.4675479999999999,0 274 | 0.7820616459999999,1.467517,0 275 | 0.783646179,1.467487,0 276 | 0.785230713,1.4674559999999999,0 277 | 0.786815186,1.467425,0 278 | 0.788399597,1.467395,0 279 | 0.789983948,1.467365,0 280 | 0.7915682370000001,1.467335,0 281 | 0.793152405,1.4673049999999999,0 282 | 0.7947365110000001,1.467276,0 283 | 0.796320618,1.4672459999999998,0 284 | 0.7979046630000001,1.467217,0 285 | 0.7994885860000001,1.467188,0 286 | 0.8010723879999999,1.467159,0 287 | 0.80265625,1.4671299999999998,0 288 | 0.80423999,1.4671020000000001,0 289 | 0.8058236080000001,1.4670729999999998,0 290 | 0.8074072269999999,1.4670450000000002,0 291 | 0.8089907839999999,1.467017,0 292 | 0.810574219,1.4669889999999999,0 293 | 0.812157593,1.466961,0 294 | 0.813740967,1.466933,0 295 | 0.8153242190000001,1.466906,0 296 | 0.8169073490000001,1.466879,0 297 | 0.81849054,1.4668510000000001,0 298 | 0.820073547,1.466824,0 299 | 0.821656555,1.466797,0 300 | 0.823239441,1.4667709999999998,0 301 | 0.824822266,1.466744,0 302 | 0.826404968,1.466718,0 303 | 0.8279876709999999,1.466691,0 304 | 0.829570313,1.4666649999999999,0 305 | 0.831152832,1.466639,0 306 | 0.8327353519999999,1.466613,0 307 | 0.834317688,1.466587,0 308 | 0.8359000240000001,1.4665620000000001,0 309 | 0.8374823,1.466536,0 310 | 0.839064453,1.4665110000000001,0 311 | 0.8406466060000001,1.466486,0 312 | 0.842228638,1.466461,0 313 | 0.843810547,1.466436,0 314 | 0.845392395,1.4664110000000001,0 315 | 0.846974243,1.466386,0 316 | 0.8485558470000001,1.4663620000000002,0 317 | 0.8501375119999999,1.466337,0 318 | 0.8517191159999999,1.466313,0 319 | 0.853300598,1.466289,0 320 | 0.8548820189999999,1.466265,0 321 | 0.8564633789999999,1.466241,0 322 | 0.858044617,1.466217,0 323 | 0.859625793,1.466193,0 324 | 0.861206848,1.46617,0 325 | 0.862787903,1.466146,0 326 | 0.864368835,1.466123,0 327 | 0.865949707,1.4661,0 328 | 0.8675304570000001,1.466077,0 329 | 0.869111145,1.4660540000000002,0 330 | 0.870691772,1.4660309999999999,0 331 | 0.8722723390000001,1.4660090000000001,0 332 | 0.8738527220000001,1.465986,0 333 | 0.875433228,1.465964,0 334 | 0.8770134890000001,1.465941,0 335 | 0.878593628,1.465919,0 336 | 0.880173828,1.465897,0 337 | 0.881753784,1.465875,0 338 | 0.883333801,1.465853,0 339 | 0.884913696,1.4658309999999999,0 340 | 0.886493469,1.46581,0 341 | 0.8880732419999999,1.4657879999999999,0 342 | 0.889652832,1.465767,0 
343 | 0.8912323,1.465746,0 344 | 0.8928118289999999,1.465724,0 345 | 0.894391235,1.465703,0 346 | 0.89597052,1.465682,0 347 | 0.8975496829999999,1.465661,0 348 | 0.899128784,1.465641,0 349 | 0.9007078249999999,1.4656200000000001,0 350 | 0.9022868039999999,1.4655989999999999,0 351 | 0.903865601,1.4655790000000002,0 352 | 0.9054444580000001,1.465559,0 353 | 0.9070230710000001,1.465538,0 354 | 0.9086016849999999,1.465518,0 355 | 0.910180176,1.465498,0 356 | 0.9117586060000001,1.465478,0 357 | 0.913336914,1.465458,0 358 | 0.9149151609999999,1.465438,0 359 | 0.9164933470000001,1.465419,0 360 | 0.9180714109999999,1.465399,0 361 | 0.9196492919999999,1.46538,0 362 | 0.921227234,1.46536,0 363 | 0.922804993,1.465341,0 364 | 0.9243826900000001,1.465322,0 365 | 0.925960327,1.465303,0 366 | 0.927537781,1.465284,0 367 | 0.929115234,1.465265,0 368 | 0.930692566,1.4652459999999998,0 369 | 0.9322698359999999,1.465227,0 370 | 0.933846985,1.465209,0 371 | 0.9354240109999999,1.46519,0 372 | 0.937000977,1.465172,0 373 | 0.93857782,1.465154,0 374 | 0.940154602,1.465135,0 375 | 0.9417312619999999,1.465117,0 376 | 0.943307922,1.465099,0 377 | 0.944884399,1.4650809999999999,0 378 | 0.946460754,1.465063,0 379 | 0.9480369869999999,1.4650450000000002,0 380 | 0.949613281,1.465028,0 381 | 0.951189331,1.4650100000000001,0 382 | 0.95276532,1.464992,0 383 | 0.954341309,1.464975,0 384 | 0.9559171140000001,1.464958,0 385 | 0.957492798,1.46494,0 386 | 0.959068359,1.464923,0 387 | 0.9606439210000001,1.464906,0 388 | 0.96221936,1.4648889999999999,0 389 | 0.963794617,1.464872,0 390 | 0.9653699339999999,1.464855,0 391 | 0.9669450070000001,1.4648379999999999,0 392 | 0.9685200199999999,1.464821,0 393 | 0.9700949099999999,1.464805,0 394 | 0.9716698,1.464788,0 395 | 0.973244507,1.464772,0 396 | 0.9748191529999999,1.464755,0 397 | 0.9763936769999999,1.464739,0 398 | 0.9779680789999998,1.464723,0 399 | 0.979542358,1.4647059999999998,0 400 | 0.981116577,1.46469,0 401 | 0.982690735,1.464674,0 402 | 0.9842647710000001,1.464658,0 403 | 0.985838684,1.464642,0 404 | 0.987412415,1.4646270000000001,0 405 | 0.9889861449999999,1.464611,0 406 | 0.9905597529999999,1.464595,0 407 | 0.992133179,1.464579,0 408 | 0.993706665,1.464564,0 409 | 0.995279907,1.464548,0 410 | 0.996853027,1.464533,0 411 | 0.998426147,1.464518,0 412 | 0.9999991460000001,1.464503,0 413 | 8,1.4645,0 414 | 15,1.4645,0 415 | -------------------------------------------------------------------------------- /optogpt/core/models/transformer.py: -------------------------------------------------------------------------------- 1 | # Build models 2 | import os 3 | import math 4 | import copy 5 | import time 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from nltk import word_tokenize 11 | from collections import Counter 12 | from torch.autograd import Variable 13 | import seaborn as sns 14 | import matplotlib.pyplot as plt 15 | import pickle as pkl 16 | 17 | 18 | def clones(module, N): 19 | """ 20 | "Produce N identical layers." 21 | Use deepcopy the weight are indenpendent. 
22 |     """
23 |     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
24 | 
25 | 
26 | class Embeddings(nn.Module):
27 |     def __init__(self, d_model, vocab):
28 |         super(Embeddings, self).__init__()
29 |         self.lut = nn.Embedding(vocab, d_model)
30 |         self.d_model = d_model
31 | 
32 |     def forward(self, x):
33 |         # return the embedding of x, scaled by math.sqrt(d_model)
34 |         return self.lut(x) * math.sqrt(self.d_model)
35 | 
36 | 
37 | class PositionalEncoding(nn.Module):
38 |     def __init__(self, d_model, dropout, max_len=5000):
39 |         super(PositionalEncoding, self).__init__()
40 |         self.dropout = nn.Dropout(p=dropout)
41 | 
42 |         pe = torch.zeros(max_len, d_model)
43 |         position = torch.arange(0., max_len).unsqueeze(1)
44 |         div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
45 |         pe_pos = torch.mul(position, div_term)
46 |         pe[:, 0::2] = torch.sin(pe_pos)
47 |         pe[:, 1::2] = torch.cos(pe_pos)
48 |         pe = pe.unsqueeze(0)
49 |         self.register_buffer('pe', pe)  # saved with the model but not updated by the optimizer
50 | 
51 |     def forward(self, x):
52 |         # add the positional encoding for the first x.size(1) positions
53 |         x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)
54 |         return self.dropout(x)
55 | 
56 | def attention(query, key, value, mask=None, dropout=None):
57 |     "Compute 'Scaled Dot Product Attention'"
58 |     d_k = query.size(-1)
59 |     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
60 |     if mask is not None:
61 |         scores = scores.masked_fill(mask == 0, -1e9)
62 |     p_attn = F.softmax(scores, dim=-1)
63 |     if dropout is not None:
64 |         p_attn = dropout(p_attn)
65 |     return torch.matmul(p_attn, value), p_attn
66 | 
67 | 
68 | class MultiHeadedAttention(nn.Module):
69 |     def __init__(self, h, d_model, dropout=0.1):
70 |         "Take in model size and number of heads."
71 |         super(MultiHeadedAttention, self).__init__()
72 |         # h: number of attention heads
73 |         assert d_model % h == 0  # d_model must be divisible by the number of heads
74 |         self.d_k = d_model // h
75 |         self.h = h
76 |         # 4 linear layers: WQ, WK, WV and the final output mapping WO
77 |         self.linears = clones(nn.Linear(d_model, d_model), 4)
78 |         self.attn = None
79 |         self.dropout = nn.Dropout(p=dropout)
80 | 
81 |     def forward(self, query, key, value, mask=None):
82 |         # compute all heads at once with batched projections
83 | 
84 |         if mask is not None:
85 |             # Same mask applied to all h heads.
86 |             mask = mask.unsqueeze(1)
87 |         nbatches = query.size(0)  # get batch size
88 | 
89 |         # 1) Do all the linear projections in batch from d_model => h x d_k
90 |         #    partition into h heads and swap the sequence and head axes for the attention computation.
91 |         query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
92 |                              for l, x in zip(self.linears, (query, key, value))]
93 | 
94 |         # 2) Apply attention on all the projected vectors in batch.
95 |         x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
96 | 
97 |         # 3) "Concat" using a view and apply a final linear.
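        # Illustrative shape walk-through for the step below (assuming the default d_model=512, h=8, so d_k=64):
        #   attention output x: (nbatches, h, seq_len, d_k) = (B, 8, L, 64)
        #   transpose(1, 2)         -> (B, L, 8, 64)
        #   contiguous().view(...)  -> (B, L, 8*64) = (B, L, 512), i.e. the heads are re-concatenated
        #   before the final output projection self.linears[-1].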
98 | x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) 99 | # contiguous: when use transpose, PyTorch does not create a new tensor and just changes the meta data 100 | # use contiguous to make a copy of the tensor with transpose data 101 | 102 | return self.linears[-1](x) # final linear layer 103 | 104 | class LayerNorm(nn.Module): 105 | def __init__(self, features, eps=1e-6): 106 | super(LayerNorm, self).__init__() 107 | self.a_2 = nn.Parameter(torch.ones(features)) 108 | self.b_2 = nn.Parameter(torch.zeros(features)) 109 | self.eps = eps 110 | 111 | def forward(self, x): 112 | mean = x.mean(-1, keepdim=True) # rows 113 | std = x.std(-1, keepdim=True) 114 | x_zscore = (x - mean)/ torch.sqrt(std ** 2 + self.eps) 115 | return self.a_2*x_zscore+self.b_2 116 | 117 | class SublayerConnection(nn.Module): 118 | """ 119 | A residual connection followed by a layer norm. 120 | Note for code simplicity the norm is first as opposed to last. 121 | SublayerConnection: connect Multi-Head Attention and Feed Forward Layers 122 | """ 123 | def __init__(self, size, dropout): 124 | super(SublayerConnection, self).__init__() 125 | self.norm = LayerNorm(size) 126 | self.dropout = nn.Dropout(dropout) 127 | 128 | def forward(self, x, sublayer): 129 | "Apply residual connection to any sublayer with the same size." 130 | return x + self.dropout(sublayer(self.norm(x))) 131 | 132 | class FullyConnectedLayers(nn.Module): 133 | def __init__(self, input_dim, out_dim): 134 | super(FullyConnectedLayers, self).__init__() 135 | self.fc1 = nn.Linear(input_dim, input_dim) 136 | self.fc2 = nn.Linear(input_dim, out_dim) 137 | self.norm = LayerNorm(input_dim) 138 | 139 | def forward(self, x): 140 | return self.fc2(self.norm(self.fc1(x))) 141 | 142 | 143 | class PositionwiseFeedForward(nn.Module): 144 | def __init__(self, d_model, d_ff, dropout=0.1): 145 | super(PositionwiseFeedForward, self).__init__() 146 | self.w_1 = nn.Linear(d_model, d_ff) 147 | self.w_2 = nn.Linear(d_ff, d_model) 148 | self.dropout = nn.Dropout(dropout) 149 | 150 | def forward(self, x): 151 | h1 = self.w_1(x) 152 | h2 = self.dropout(h1) 153 | return self.w_2(h2) 154 | 155 | class Encoder(nn.Module): 156 | "Core encoder is a stack of N layers (blocks)" 157 | def __init__(self, layer, N): 158 | super(Encoder, self).__init__() 159 | self.layers = clones(layer, N) 160 | self.norm = LayerNorm(layer.size) 161 | 162 | def forward(self, x, mask): 163 | """ 164 | Pass the input (and mask) through each layer in turn. 165 | """ 166 | for layer in self.layers: 167 | x = layer(x, mask) 168 | return self.norm(x) 169 | 170 | class EncoderLayer(nn.Module): 171 | def __init__(self, size, self_attn, feed_forward, dropout): 172 | super(EncoderLayer, self).__init__() 173 | self.self_attn = self_attn 174 | self.feed_forward = feed_forward 175 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 176 | self.size = size # d_model 177 | 178 | def forward(self, x, mask): 179 | # X-embedding to Multi-head-Attention 180 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) # why use lambda? 
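        # Note on the lambda above: SublayerConnection.forward expects a callable of a single
        # argument (the normalized x), so the lambda binds q = k = v = x together with the mask.
        # This lets the same residual + LayerNorm wrapper be reused for the attention sublayer here
        # and for the feed-forward sublayer below.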
181 | # X-embedding to feed-forwad nn 182 | return self.sublayer[1](x, self.feed_forward) 183 | 184 | class Transformer(nn.Module): 185 | def __init__(self, encoder, fc, src_embed): 186 | super(Transformer, self).__init__() 187 | self.encoder = encoder 188 | self.fc = fc 189 | self.src_embed = src_embed 190 | 191 | def encode(self, src, src_mask): 192 | return self.encoder(self.src_embed(src), src_mask) 193 | 194 | def forward(self, src, src_mask): 195 | "Take in and process masked src and target sequences." 196 | # encoder output will be the decoder's memory for decoding 197 | en = self.encode(src, src_mask) 198 | 199 | en = en[:, 0,:] 200 | 201 | return self.fc(en) 202 | 203 | def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h = 8, dropout=0.1): 204 | 205 | # d_model: dimension of Query, Key, Value 206 | # d_ff: neurons for FeedForward layer 207 | # h: num of head attention 208 | # N: number of transformer stacks. 209 | 210 | c = copy.deepcopy 211 | # Attention 212 | attn = MultiHeadedAttention(h, d_model) 213 | # FeedForward 214 | ff = PositionwiseFeedForward(d_model, d_ff, dropout) 215 | # Positional Encoding 216 | position = PositionalEncoding(d_model, dropout) 217 | # Fully connected layers 218 | fc = FullyConnectedLayers(d_model, tgt_vocab) 219 | # Transformer 220 | model = Transformer( 221 | Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), 222 | fc, 223 | nn.Sequential(Embeddings(d_model, src_vocab), c(position))) 224 | 225 | # This was important from their code. 226 | # Initialize parameters with Glorot / fan_avg. 227 | # Paper title: Understanding the difficulty of training deep feedforward neural networks Xavier 228 | 229 | for p in model.parameters(): 230 | if p.dim() > 1: 231 | nn.init.xavier_uniform_(p) 232 | return model 233 | 234 | 235 | class Decoder(nn.Module): 236 | def __init__(self, layer, N): 237 | "Generic N layer decoder with masking." 238 | super(Decoder, self).__init__() 239 | self.layers = clones(layer, N) 240 | self.norm = LayerNorm(layer.size) 241 | 242 | def forward(self, x, memory, src_mask, tgt_mask): 243 | """ 244 | Repeat decoder N times 245 | Decoderlayer get a input attention mask (src) 246 | and a output attention mask (tgt) + subsequent mask 247 | """ 248 | for layer in self.layers: 249 | x = layer(x, memory, src_mask, tgt_mask) 250 | return self.norm(x) 251 | 252 | 253 | class DecoderLayer(nn.Module): 254 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 255 | super(DecoderLayer, self).__init__() 256 | self.size = size 257 | self.self_attn = self_attn 258 | self.src_attn = src_attn 259 | self.feed_forward = feed_forward 260 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 261 | 262 | def forward(self, x, memory, src_mask, tgt_mask): 263 | m = memory # encoder output embedding 264 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 265 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 266 | # Context-Attention:q=decoder hidden,k,v from encoder hidden 267 | return self.sublayer[2](x, self.feed_forward) 268 | 269 | 270 | def subsequent_mask(size): 271 | "Mask out subsequent positions." 
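    # Illustrative example of the returned mask (True = position may be attended to):
    #   subsequent_mask(3)[0] ==
    #     [[ True, False, False],
    #      [ True,  True, False],
    #      [ True,  True,  True]]
    # i.e. each decoding step can attend only to itself and to earlier positions.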
272 |     attn_shape = (1, size, size)
273 |     subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
274 |     return torch.from_numpy(subsequent_mask) == 0
275 | 
276 | class Transformer_I(nn.Module):
277 |     def __init__(self, fc, decoder, tgt_embed, generator):
278 |         super(Transformer_I, self).__init__()
279 |         self.fc = fc
280 |         self.decoder = decoder
281 |         # self.src_embed = src_embed
282 |         self.tgt_embed = tgt_embed
283 |         self.generator = generator
284 | 
285 |     # def encode(self, src, src_mask):
286 |     #     return self.encoder(self.src_embed(src), src_mask)
287 | 
288 |     def decode(self, memory, src_mask, tgt, tgt_mask):
289 |         return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
290 | 
291 |     def forward(self, src, tgt, src_mask, tgt_mask):
292 |         "Take in and process masked src and target sequences."
293 |         # the fc projection of the input spectrum serves as the decoder's memory
294 |         return self.decode(self.fc(src), src_mask, tgt, tgt_mask)
295 | 
296 | class Generator(nn.Module):
297 |     def __init__(self, d_model, vocab):
298 |         super(Generator, self).__init__()
299 |         # project the d_model-dimensional decoder output to vocabulary logits
300 |         self.proj = nn.Linear(d_model, vocab)
301 | 
302 |     def forward(self, x):
303 |         return F.log_softmax(self.proj(x), dim=-1)  # log-probabilities over the structure vocabulary
304 | 
305 | def make_model_I(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h = 8, dropout=0.1):
306 | 
307 |     # src_vocab: dimension of the input spectrum vector
308 |     # tgt_vocab: size of the structure-token vocabulary
309 | 
310 |     # d_model: dimension of Query, Key, Value
311 |     # d_ff: neurons for FeedForward layer
312 |     # h: number of attention heads
313 |     # N: number of stacked decoder layers
314 | 
315 |     c = copy.deepcopy
316 |     # Attention
317 |     attn = MultiHeadedAttention(h, d_model)
318 |     # FeedForward
319 |     ff = PositionwiseFeedForward(d_model, d_ff, dropout)
320 |     # Positional Encoding
321 |     position = PositionalEncoding(d_model, dropout)
322 |     # Fully connected layers
323 |     fc = FullyConnectedLayers(src_vocab, d_model)
324 |     # Transformer
325 |     model = Transformer_I(
326 |         fc,
327 |         Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
328 |         nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
329 |         Generator(d_model, tgt_vocab))
330 | 
331 |     # This was important from their code.
332 |     # Initialize parameters with Glorot / fan_avg.
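    # For reference: nn.init.xavier_uniform_ draws each weight from U(-a, a) with
    # a = sqrt(6 / (fan_in + fan_out)) (gain = 1), keeping activation variance roughly
    # constant across layers at initialization. 1-D parameters (biases, the LayerNorm
    # scale/shift) are skipped by the dim() > 1 check below.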
333 | # Paper title: Understanding the difficulty of training deep feedforward neural networks Xavier 334 | 335 | for p in model.parameters(): 336 | if p.dim() > 1: 337 | nn.init.xavier_uniform_(p) 338 | return model 339 | -------------------------------------------------------------------------------- /optogpt/nk/Si_HZ.csv: -------------------------------------------------------------------------------- 1 | wl,n,k 2 | 0.38,4.0998779999999995,1.0945562 3 | 0.381,4.09293,1.0877290000000002 4 | 0.382,4.086036,1.0809068000000002 5 | 0.38299999999999995,4.079196,1.0740902 6 | 0.384,4.0724089999999995,1.0672801 7 | 0.385,4.065675,1.0604770000000001 8 | 0.386,4.058993,1.0536818000000001 9 | 0.387,4.052363,1.0468950000000001 10 | 0.38799999999999996,4.045783999999999,1.0401174000000002 11 | 0.389,4.039256,1.0333496 12 | 0.39,4.0327779999999995,1.0265921999999998 13 | 0.391,4.02635,1.0198459 14 | 0.392,4.019971,1.0131113 15 | 0.39299999999999996,4.0136400000000005,1.006389 16 | 0.39399999999999996,4.007358,0.9996796 17 | 0.395,4.0011230000000015,0.9929836 18 | 0.396,3.994936,0.9863017 19 | 0.397,3.9887949999999996,0.9796344000000001 20 | 0.39799999999999996,3.9827,0.9729821 21 | 0.39899999999999997,3.9766510000000004,0.9663455 22 | 0.4,3.970648,0.9597251 23 | 0.401,3.964689,0.9531214 24 | 0.402,3.958775,0.9465348 25 | 0.40299999999999997,3.952904,0.9399658999999999 26 | 0.40399999999999997,3.947078,0.9334150999999999 27 | 0.405,3.941294,0.9268829 28 | 0.406,3.935553,0.9203697 29 | 0.40700000000000003,3.929854,0.9138759 30 | 0.408,3.924197,0.9074021000000001 31 | 0.409,3.918582,0.9009485999999999 32 | 0.41,3.913008,0.8945158000000001 33 | 0.41100000000000003,3.9074739999999997,0.8881041 34 | 0.41200000000000003,3.9019800000000004,0.8817139 35 | 0.413,3.896527,0.8753456000000001 36 | 0.414,3.891113,0.8689995 37 | 0.415,3.885737,0.862676 38 | 0.41600000000000004,3.880401,0.8563754000000001 39 | 0.41700000000000004,3.8751029999999997,0.8500981 40 | 0.418,3.869843,0.8438445 41 | 0.419,3.864621,0.8376146999999999 42 | 0.42,3.859435,0.8314091 43 | 0.42100000000000004,3.8542870000000002,0.8252280999999999 44 | 0.42200000000000004,3.849175,0.8190719 45 | 0.423,3.8441,0.8129408 46 | 0.424,3.83906,0.8068350000000001 47 | 0.425,3.834056,0.8007549 48 | 0.426,3.8290870000000004,0.7947006 49 | 0.4270000000000001,3.824153,0.7886725 50 | 0.428,3.819254,0.7826706999999999 51 | 0.429,3.8143879999999997,0.7766956 52 | 0.43,3.809557,0.7707472 53 | 0.431,3.804759,0.7648259 54 | 0.4320000000000001,3.7999949999999996,0.7589318 55 | 0.433,3.795263,0.7530650999999999 56 | 0.434,3.790565,0.7472261 57 | 0.435,3.7858980000000004,0.7414149 58 | 0.436,3.7812639999999997,0.7356315999999999 59 | 0.4370000000000001,3.7766610000000003,0.7298765 60 | 0.43799999999999994,3.77209,0.7241496999999999 61 | 0.439,3.76755,0.7184514 62 | 0.44,3.7630410000000003,0.7127817 63 | 0.441,3.7585620000000004,0.7071407 64 | 0.442,3.754114,0.7015286999999999 65 | 0.44299999999999995,3.7496970000000003,0.6959456 66 | 0.444,3.745308,0.6903916999999999 67 | 0.445,3.74095,0.684867 68 | 0.446,3.736621,0.6793716999999999 69 | 0.447,3.7323199999999996,0.6739059000000001 70 | 0.44799999999999995,3.728049,0.6684696 71 | 0.449,3.7238059999999997,0.663063 72 | 0.45,3.7195910000000003,0.6576860999999999 73 | 0.451,3.715404,0.652339 74 | 0.452,3.711245,0.6470218000000001 75 | 0.45299999999999996,3.707114,0.6417345999999999 76 | 0.45399999999999996,3.70301,0.6364774000000001 77 | 0.455,3.6989330000000002,0.6312503 78 | 0.456,3.6948819999999998,0.6260532 79 | 
0.457,3.690858,0.6208864000000001 80 | 0.45799999999999996,3.686861,0.6157498 81 | 0.45899999999999996,3.68289,0.6106434000000001 82 | 0.46,3.6789440000000004,0.6055673 83 | 0.461,3.6750239999999996,0.6005216 84 | 0.462,3.6711300000000002,0.5955060999999999 85 | 0.46299999999999997,3.667261,0.5905210999999999 86 | 0.46399999999999997,3.6634160000000002,0.5855664 87 | 0.465,3.659597,0.580642 88 | 0.466,3.6558019999999996,0.5757481 89 | 0.467,3.6520309999999996,0.5708845 90 | 0.46799999999999997,3.648285,0.5660513 91 | 0.469,3.6445629999999998,0.5612484999999999 92 | 0.47,3.640864,0.5564760999999999 93 | 0.47100000000000003,3.637189,0.5517339 94 | 0.47200000000000003,3.6335370000000013,0.5470222 95 | 0.473,3.6299080000000004,0.5423406999999999 96 | 0.474,3.6263019999999995,0.5376895 97 | 0.475,3.622719,0.5330686 98 | 0.47600000000000003,3.619159,0.5284778 99 | 0.47700000000000004,3.615621,0.5239173 100 | 0.478,3.612105,0.5193869000000001 101 | 0.479,3.6086110000000002,0.5148865 102 | 0.48,3.605139,0.5104162999999999 103 | 0.48100000000000004,3.601688,0.505976 104 | 0.48200000000000004,3.5982589999999997,0.5015657 105 | 0.483,3.5948519999999995,0.4971852 106 | 0.484,3.5914650000000004,0.49283459999999996 107 | 0.485,3.5880989999999997,0.4885138 108 | 0.486,3.584754,0.4842227 109 | 0.48700000000000004,3.58143,0.47996120000000003 110 | 0.488,3.578126,0.47572929999999997 111 | 0.489,3.574842,0.4715269 112 | 0.49,3.5715790000000003,0.46735390000000004 113 | 0.491,3.568335,0.4632103 114 | 0.4920000000000001,3.565111,0.4590959 115 | 0.493,3.561907,0.4550107 116 | 0.494,3.558722,0.4509545999999999 117 | 0.495,3.5555559999999997,0.4469276 118 | 0.496,3.55241,0.4429294 119 | 0.4970000000000001,3.549282,0.43896009999999996 120 | 0.498,3.5461730000000005,0.4350196 121 | 0.499,3.543083,0.43110770000000004 122 | 0.5,3.5400120000000004,0.42722429999999995 123 | 0.501,3.5369589999999995,0.42336949999999995 124 | 0.502,3.533924,0.4195429 125 | 0.503,3.530907,0.4157447 126 | 0.504,3.527908,0.41197449999999997 127 | 0.505,3.5249269999999995,0.4082325 128 | 0.506,3.521963,0.4045183 129 | 0.507,3.519018,0.400832 130 | 0.508,3.516089,0.3971734000000001 131 | 0.509,3.513178,0.3935425 132 | 0.51,3.5102839999999995,0.389939 133 | 0.511,3.5074059999999996,0.3863629 134 | 0.512,3.504546,0.3828140999999999 135 | 0.513,3.501702,0.37929250000000003 136 | 0.514,3.498875,0.3757978 137 | 0.515,3.496065,0.3723301 138 | 0.516,3.49327,0.36888920000000003 139 | 0.517,3.4904919999999997,0.365475 140 | 0.518,3.4877300000000004,0.3620873 141 | 0.519,3.484984,0.3587261 142 | 0.52,3.482254,0.35539109999999996 143 | 0.521,3.479539,0.3520824 144 | 0.522,3.47684,0.34879970000000005 145 | 0.523,3.474157,0.3455429 146 | 0.524,3.471489,0.3423119 147 | 0.525,3.468836,0.3391066 148 | 0.526,3.4661980000000003,0.3359268 149 | 0.527,3.463575,0.33277249999999997 150 | 0.528,3.4609669999999997,0.32964340000000003 151 | 0.529,3.4583739999999996,0.3265395 152 | 0.53,3.455796,0.3234606 153 | 0.531,3.453232,0.3204066 154 | 0.532,3.450682,0.3173773 155 | 0.5329999999999999,3.448147,0.3143727 156 | 0.534,3.4456260000000003,0.3113925 157 | 0.535,3.4431199999999995,0.3084367 158 | 0.536,3.4406269999999997,0.3055051 159 | 0.537,3.4381480000000004,0.30259759999999997 160 | 0.5379999999999999,3.435683,0.299714 161 | 0.539,3.433232,0.2968542 162 | 0.54,3.430794,0.29401809999999995 163 | 0.541,3.4283699999999997,0.29120559999999995 164 | 0.542,3.4259589999999998,0.2884164 165 | 0.5429999999999999,3.423561,0.28565050000000003 166 | 
0.544,3.4211769999999997,0.28290770000000004 167 | 0.545,3.4188059999999996,0.28018790000000005 168 | 0.546,3.4164480000000004,0.277491 169 | 0.547,3.414103,0.27481669999999997 170 | 0.5479999999999999,3.4117699999999997,0.2721651 171 | 0.5489999999999999,3.40945,0.2695358 172 | 0.55,3.407143,0.2669289 173 | 0.551,3.404849,0.26434409999999997 174 | 0.552,3.402567,0.2617813 175 | 0.5529999999999999,3.400297,0.2592404 176 | 0.5539999999999999,3.39804,0.2567212 177 | 0.555,3.3957949999999997,0.2542236 178 | 0.556,3.393562,0.2517475 179 | 0.557,3.39134,0.24929279999999998 180 | 0.5579999999999999,3.389131,0.2468592 181 | 0.5589999999999999,3.3869339999999997,0.24444659999999999 182 | 0.56,3.3847480000000005,0.24205500000000002 183 | 0.561,3.382575,0.23968409999999998 184 | 0.562,3.380412,0.2373339 185 | 0.563,3.3782620000000003,0.23500410000000002 186 | 0.564,3.376122,0.2326947 187 | 0.565,3.373994,0.23040560000000002 188 | 0.5660000000000001,3.3718779999999997,0.22813649999999996 189 | 0.5670000000000001,3.369772,0.2258874 190 | 0.568,3.3676779999999997,0.223658 191 | 0.569,3.365594,0.2214484 192 | 0.57,3.363522,0.21925830000000002 193 | 0.5710000000000001,3.3614599999999997,0.21708760000000002 194 | 0.5720000000000001,3.35941,0.2149361 195 | 0.573,3.3573699999999995,0.2128038 196 | 0.574,3.35534,0.21069050000000009 197 | 0.575,3.3533209999999998,0.20859610000000006 198 | 0.5760000000000001,3.3513129999999998,0.20652039999999997 199 | 0.5770000000000001,3.3493150000000003,0.2044632 200 | 0.578,3.347328,0.2024246 201 | 0.579,3.3453510000000004,0.2004042 202 | 0.58,3.3433839999999995,0.1984021 203 | 0.581,3.3414269999999995,0.196418 204 | 0.5820000000000001,3.3394800000000004,0.1944519 205 | 0.583,3.337543,0.19250350000000002 206 | 0.584,3.3356169999999996,0.19057290000000002 207 | 0.585,3.3337,0.1886597 208 | 0.586,3.3317930000000002,0.186764 209 | 0.5870000000000001,3.3298949999999996,0.1848856 210 | 0.588,3.328007,0.1830243 211 | 0.589,3.326129,0.18118 212 | 0.59,3.324261,0.1793526 213 | 0.591,3.322402,0.177542 214 | 0.5920000000000001,3.3205519999999997,0.17574800000000002 215 | 0.593,3.318712,0.1739705 216 | 0.594,3.316881,0.17220950000000002 217 | 0.595,3.3150589999999998,0.1704646 218 | 0.596,3.313246,0.168736 219 | 0.597,3.311443,0.16702329999999999 220 | 0.598,3.309648,0.16532650000000002 221 | 0.599,3.3078629999999998,0.1636455 222 | 0.6,3.306086,0.16198010000000002 223 | 0.601,3.304318,0.16033029999999998 224 | 0.602,3.3025589999999996,0.1586958 225 | 0.603,3.3008089999999997,0.1570766 226 | 0.604,3.299068,0.15547260000000002 227 | 0.605,3.297335,0.1538836 228 | 0.606,3.29561,0.15230950000000001 229 | 0.607,3.2938949999999996,0.15075020000000006 230 | 0.608,3.292187,0.1492056 231 | 0.609,3.290488,0.14767550000000002 232 | 0.61,3.2887980000000003,0.14615989999999998 233 | 0.611,3.287116,0.14465870000000006 234 | 0.612,3.2854410000000005,0.1431716 235 | 0.613,3.283776,0.1416986 236 | 0.614,3.282118,0.14023970000000002 237 | 0.615,3.280468,0.1387945 238 | 0.616,3.2788269999999997,0.13736320000000002 239 | 0.617,3.277193,0.1359454 240 | 0.618,3.2755669999999997,0.13454129999999995 241 | 0.619,3.273949,0.1331505 242 | 0.62,3.272339,0.131773 243 | 0.621,3.270737,0.1304087 244 | 0.622,3.269143,0.1290575 245 | 0.623,3.2675560000000003,0.1277193 246 | 0.624,3.2659759999999998,0.126394 247 | 0.625,3.264405,0.12508139999999998 248 | 0.626,3.262841,0.1237815 249 | 0.627,3.261284,0.12249410000000001 250 | 0.628,3.259735,0.1212192 251 | 0.629,3.258193,0.1199566 252 | 
0.63,3.2566580000000003,0.11870619999999996 253 | 0.631,3.255131,0.117468 254 | 0.632,3.2536110000000003,0.1162418 255 | 0.633,3.252098,0.1150275 256 | 0.634,3.250593,0.113825 257 | 0.635,3.249094,0.1126343 258 | 0.636,3.247603,0.11145510000000004 259 | 0.637,3.246118,0.1102875 260 | 0.638,3.244641,0.10913130000000001 261 | 0.639,3.24317,0.10798640000000001 262 | 0.64,3.241707,0.1068528 263 | 0.6409999999999999,3.24025,0.10573030000000001 264 | 0.642,3.2388,0.10461880000000001 265 | 0.643,3.237357,0.1035182 266 | 0.644,3.23592,0.1024285 267 | 0.645,3.23449,0.1013495 268 | 0.6459999999999999,3.2330669999999997,0.1002811 269 | 0.647,3.23165,0.0992233 270 | 0.648,3.2302400000000002,0.098176 271 | 0.649,3.228837,0.097139 272 | 0.65,3.22744,0.09611230000000001 273 | 0.6509999999999999,3.226049,0.09509580000000001 274 | 0.652,3.2246650000000003,0.09408939999999999 275 | 0.653,3.223287,0.093093 276 | 0.654,3.221915,0.09210650000000001 277 | 0.655,3.22055,0.0911299 278 | 0.6559999999999999,3.219191,0.09016289999999999 279 | 0.657,3.217838,0.08920560000000001 280 | 0.6579999999999999,3.216491,0.0882579 281 | 0.659,3.2151509999999996,0.0873197 282 | 0.66,3.213816,0.08639089999999999 283 | 0.6609999999999999,3.212488,0.08547130000000001 284 | 0.662,3.2111650000000003,0.084561 285 | 0.6629999999999999,3.209849,0.0836598 286 | 0.664,3.208539,0.0827677 287 | 0.665,3.207234,0.0818846 288 | 0.6659999999999999,3.2059349999999998,0.08101030000000001 289 | 0.667,3.2046419999999998,0.08014489999999999 290 | 0.6679999999999999,3.2033549999999997,0.07928819999999999 291 | 0.669,3.202074,0.07844010000000001 292 | 0.67,3.200798,0.0776007 293 | 0.6709999999999999,3.199529,0.0767697 294 | 0.672,3.198264,0.0759471 295 | 0.6729999999999999,3.197006,0.0751328 296 | 0.674,3.195753,0.07432689999999999 297 | 0.675,3.194505,0.0735291 298 | 0.6759999999999999,3.193263,0.0727394 299 | 0.677,3.192027,0.0719577 300 | 0.6779999999999999,3.1907959999999997,0.071184 301 | 0.679,3.189571,0.0704182 302 | 0.68,3.188351,0.06966019999999999 303 | 0.6809999999999999,3.187136,0.0689099 304 | 0.682,3.185926,0.0681673 305 | 0.6829999999999999,3.1847220000000003,0.0674323 306 | 0.684,3.1835240000000002,0.0667049 307 | 0.685,3.18233,0.06598480000000001 308 | 0.6859999999999999,3.181142,0.0652722 309 | 0.687,3.179959,0.06456680000000001 310 | 0.688,3.1787810000000003,0.0638687 311 | 0.6890000000000001,3.177608,0.0631778 312 | 0.69,3.17644,0.062494000000000015 313 | 0.691,3.175277,0.0618172 314 | 0.6920000000000001,3.1741200000000003,0.06114740000000001 315 | 0.693,3.172967,0.060484500000000004 316 | 0.6940000000000001,3.1718189999999997,0.0598284 317 | 0.695,3.170677,0.059179099999999984 318 | 0.696,3.169539,0.0585366 319 | 0.6970000000000001,3.168406,0.057900599999999997 320 | 0.698,3.167278,0.05727130000000001 321 | 0.6990000000000001,3.166155,0.056648500000000004 322 | 0.7,3.165036,0.0560321 323 | 0.701,3.163923,0.0554221 324 | 0.7020000000000001,3.162814,0.05481849999999999 325 | 0.703,3.16171,0.0542211 326 | 0.7040000000000001,3.1606099999999997,0.05363 327 | 0.705,3.1595150000000003,0.053045 328 | 0.706,3.158425,0.052466 329 | 0.7070000000000001,3.15734,0.0518932 330 | 0.708,3.156259,0.051326300000000005 331 | 0.7090000000000001,3.155183,0.0507653 332 | 0.71,3.154111,0.050210199999999997 333 | 0.711,3.153044,0.0496608 334 | 0.7120000000000001,3.151981,0.0491173 335 | 0.713,3.150923,0.0485794 336 | 0.7140000000000001,3.149869,0.0480471 337 | 0.715,3.148819,0.0475205 338 | 0.716,3.147774,0.04699930000000002 339 | 
0.7170000000000001,3.146734,0.046483699999999996 340 | 0.718,3.145697,0.045973400000000005 341 | 0.7190000000000001,3.1446650000000003,0.0454685 342 | 0.72,3.143638,0.045 343 | 0.721,3.142614,0.044474599999999996 344 | 0.722,3.141595,0.043985500000000004 345 | 0.723,3.14058,0.043501599999999994 346 | 0.7240000000000001,3.139569,0.0430227 347 | 0.725,3.138563,0.0425489 348 | 0.726,3.1375599999999997,0.0420801 349 | 0.727,3.136562,0.0416163 350 | 0.728,3.135568,0.041157400000000004 351 | 0.7290000000000001,3.134578,0.040703300000000005 352 | 0.73,3.1335919999999997,0.0403 353 | 0.731,3.1326099999999997,0.0398 354 | 0.732,3.1316319999999997,0.0393698 355 | 0.733,3.1306580000000004,0.0389 356 | 0.7340000000000001,3.1296880000000002,0.0385 357 | 0.735,3.1287220000000002,0.0381 358 | 0.736,3.12776,0.0377 359 | 0.737,3.126802,0.0372 360 | 0.738,3.1258470000000003,0.0368275 361 | 0.7390000000000001,3.1248970000000003,0.036419400000000005 362 | 0.74,3.123951,0.0360157 363 | 0.741,3.123008,0.0356 364 | 0.742,3.1220689999999998,0.0352211 365 | 0.743,3.121134,0.0348 366 | 0.7440000000000001,3.120202,0.034443400000000006 367 | 0.745,3.119275,0.0340608 368 | 0.746,3.1183509999999997,0.0337 369 | 0.747,3.1174310000000003,0.0333 370 | 0.748,3.116514,0.0329373 371 | 0.7490000000000001,3.115601,0.0325708 372 | 0.75,3.114692,0.0322082 373 | 0.7509999999999999,3.113787,0.0318 374 | 0.752,3.112885,0.0314947 375 | 0.753,3.111986,0.0311437 376 | 0.754,3.1110919999999997,0.030796499999999997 377 | 0.755,3.110201,0.030452999999999997 378 | 0.7559999999999999,3.109313,0.0301 379 | 0.757,3.108429,0.029776999999999998 380 | 0.758,3.107548,0.0294445 381 | 0.759,3.106671,0.0291 382 | 0.76,3.105797,0.0288 383 | 0.7609999999999999,3.104927,0.0285 384 | 0.762,3.10406,0.0281 385 | 0.763,3.103197,0.0278349 386 | 0.764,3.102336,0.0275 387 | 0.765,3.10148,0.0272 388 | 0.7659999999999999,3.1006259999999997,0.0269103 389 | 0.767,3.099776,0.0266087 390 | 0.768,3.0989299999999997,0.0263 391 | 0.769,3.098086,0.0260154 392 | 0.77,3.0972459999999997,0.0257235 393 | 0.7709999999999999,3.096409,0.0254 394 | 0.772,3.095576,0.0251493 395 | 0.773,3.094745,0.0248668 396 | 0.774,3.093918,0.0246 397 | 0.775,3.093094,0.0243111 398 | 0.7759999999999999,3.092274,0.024 399 | 0.777,3.091456,0.0237674 400 | 0.778,3.090642,0.0235 401 | 0.779,3.08983,0.0232 402 | 0.78,3.089022,0.023 403 | 5.0,3.0,0.0 404 | -------------------------------------------------------------------------------- /self_improving/prepare_aug_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../optogpt') 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import pickle as pkl 9 | from core.datasets.sim import load_materials, spectrum 10 | from core.models.transformer import make_model_I, subsequent_mask 11 | from core.trains.train import Variable 12 | from generate_dev_data import generate_dev_data 13 | import multiprocessing as mp 14 | 15 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | DATABASE = '../optogpt/nk' 18 | illuminant = SDS_ILLUMINANTS['D65'] 19 | cmfs = MSDS_CMFS['CIE 1931 2 Degree Standard Observer'] 20 | 21 | mats = ['Al', 'Al2O3', 'AlN', 'Ge', 'HfO2', 'ITO', 'MgF2', 'MgO', 'Si', 'Si3N4', 'SiO2', 'Ta2O5', 'TiN', 'TiO2', 'ZnO', 'ZnS', 'ZnSe', 'Glass_Substrate'] 22 | thicks = [str(i) for i in range(5, 255, 5)] 23 | 24 | lamda_low = 0.4 25 | lamda_high = 1.1 26 | wavelengths = 
np.arange(lamda_low, lamda_high+1e-3, 0.01) 27 | 28 | class Prepare_Augment_Data: 29 | 30 | def __init__(self, model_path, decoding_method = "Greedy Decode",error_type = "MAE",top_k = 10, top_p = 0.9, kp_num = 50, keep_num = 20): 31 | print("Initializing...") 32 | # load the model 33 | self.model,self.args = self.load_model(model_path) 34 | 35 | self.error_type = error_type 36 | 37 | # load data 38 | self.load_data_file(model_path) 39 | 40 | # get the id_word_dict 41 | self.id_word_dict = self.get_id_word_dict() 42 | 43 | print("Calculating original spec...") 44 | self.original_spec = self.get_original_spec() 45 | 46 | self.top_k = top_k 47 | self.top_p = top_p 48 | self.kp_num = kp_num 49 | 50 | print("Calculating designed structure...") 51 | if decoding_method == "Greedy Decode": 52 | self.all_mat, self.all_thick, self.designed_struct = self.get_designed_struct(self.greedy_decode) 53 | elif decoding_method == "TOP-KP Decode": 54 | # repeat the decoding process for kp_num times 55 | self.all_mat, self.all_thick, self.designed_struct = [],[],[] 56 | for i in range(self.kp_num): 57 | print("Decoding process: {}/{}".format(i+1,self.kp_num)) 58 | mat, thick, struct = self.get_designed_struct(self.top_kp_decode) 59 | self.all_mat.append(mat) 60 | self.all_thick.append(thick) 61 | self.designed_struct.append(struct) 62 | # flatten the list 63 | self.all_mat = [inner for outer in self.all_mat for inner in outer] 64 | self.all_thick = [inner for outer in self.all_thick for inner in outer] 65 | self.designed_struct = [inner for outer in self.designed_struct for inner in outer] 66 | #copy original_spec kp_num times 67 | all_original_spec = [] 68 | for i in range(self.kp_num): 69 | all_original_spec.append(self.original_spec) 70 | self.original_spec = [inner for outer in all_original_spec for inner in outer] 71 | elif decoding_method == "TOP-KP Decode_v2": 72 | # repeat the decoding process for kp_num times 73 | self.all_mat, self.all_thick, self.designed_struct = [],[],[] 74 | all_original_spec = [] 75 | cnt = 0 76 | for spec in self.original_spec: 77 | tmp_mat, tmp_thick, tmp_struct = [], [], [] 78 | for i in range(self.kp_num): 79 | with torch.no_grad(): 80 | struct = self.top_kp_decode(spec, 10, 'BOS') 81 | mat, thick = self.return_mat_thick(struct) 82 | tmp_mat.append(mat) 83 | tmp_thick.append(thick) 84 | tmp_struct.append(struct) 85 | # calculate the error for each decoding 86 | tmp_error = [] 87 | for i in range(self.kp_num): 88 | int_list = [int(item) for item in tmp_thick[i]] 89 | tmp_error.append(self.calc_single_error(spec, spectrum(tmp_mat[i],int_list))) 90 | # keep only the keep_num best decoding 91 | best_idx = np.argsort(tmp_error)[:keep_num] 92 | self.all_mat.append([tmp_mat[i] for i in best_idx]) 93 | self.all_thick.append([tmp_thick[i] for i in best_idx]) 94 | self.designed_struct.append([tmp_struct[i] for i in best_idx]) 95 | all_original_spec.append([spec]*keep_num) 96 | cnt += 1 97 | if cnt % 100 == 0: 98 | print("Decoding process: {}/{}".format(cnt,len(self.original_spec))) 99 | # flatten the list 100 | self.all_mat = [inner for outer in self.all_mat for inner in outer] 101 | self.all_thick = [inner for outer in self.all_thick for inner in outer] 102 | self.designed_struct = [inner for outer in self.designed_struct for inner in outer] 103 | self.original_spec = all_original_spec 104 | self.original_spec = [inner for outer in all_original_spec for inner in outer] 105 | elif decoding_method == 'Beam Search': 106 | self.all_mat, self.all_thick, self.designed_struct = [],[],[] 107 
| cnt = 0 108 | for spec in self.original_spec: 109 | with torch.no_grad(): 110 | if cnt % 100 == 0: 111 | print("Decoding process: {}/{}".format(cnt,len(self.original_spec))) 112 | cnt += 1 113 | struct = self.beam_search_decode(spec, 10, 'BOS',keep_num) 114 | for i in range(keep_num): 115 | mat, thick = self.return_mat_thick(struct[i]) 116 | self.all_mat.append(mat) 117 | self.all_thick.append(thick) 118 | self.designed_struct.append(struct[i]) 119 | # duplicate self.original_spec keep_num times like [spec1,spec1,spec1,spec2,spec2,spec2,...] 120 | all_original_spec = [] 121 | for spec in self.original_spec: 122 | for i in range(keep_num): 123 | all_original_spec.append(spec) 124 | self.original_spec = all_original_spec 125 | 126 | 127 | print("Calculating designed spec...") 128 | self.designed_spec = self.simulate_spec() 129 | 130 | print("Calculating Error...") 131 | self.error = self.calc_error() 132 | 133 | 134 | # helper function 135 | def return_mat_thick(self, struc_list): 136 | materials = [] 137 | thickness = [] 138 | for struc_ in struc_list: 139 | materials.append(struc_.split('_')[0]) 140 | thickness.append(struc_.split('_')[1]) 141 | 142 | return materials, thickness 143 | 144 | def get_id_word_dict(self): 145 | # translate the encoding back to structure 146 | word_id_dict = self.struc_word_dict 147 | # make a reverse dictionary 148 | id_word_dict = {} 149 | for key, value in word_id_dict.items(): 150 | id_word_dict[value] = key 151 | return id_word_dict 152 | 153 | #translate an array of word ids to words 154 | def translate(self, word_ids, id_word_dict): 155 | word_ids = word_ids.to('cpu').numpy() 156 | words = [] 157 | for i in word_ids: 158 | words.append(id_word_dict[i]) 159 | return words 160 | 161 | def load_model(self, model_path): 162 | #load the model 163 | a = torch.load(model_path) 164 | args = a['configs'] 165 | torch.manual_seed(args.seeds) 166 | np.random.seed(args.seeds) 167 | model = make_model_I( 168 | args.spec_dim, 169 | args.struc_dim, 170 | args.layers, 171 | args.d_model, 172 | args.d_ff, 173 | args.head_num, 174 | args.dropout 175 | ).to(DEVICE) 176 | 177 | model.load_state_dict(a['model_state_dict']) 178 | return model, args 179 | 180 | def load_data_file(self, model_path): 181 | # load the training and spec data 182 | a = torch.load(model_path) 183 | self.train_spec = generate_dev_data() 184 | self.struc_word_dict, self.struc_index_dict = a['configs'].struc_word_dict, a['configs'].struc_index_dict 185 | 186 | return 187 | 188 | def get_original_spec(self): 189 | return self.train_spec 190 | 191 | def get_designed_struct(self, decoding_method): 192 | all_mat = [] 193 | all_thick = [] 194 | designed_struc = [] 195 | 196 | with torch.no_grad(): 197 | for src in self.original_spec: 198 | crt_res = decoding_method(list(src), 10, 'BOS') 199 | material, thickness = self.return_mat_thick(crt_res) 200 | all_mat.append(material) 201 | all_thick.append(thickness) 202 | designed_struc.append(crt_res) 203 | return all_mat, all_thick, designed_struc 204 | 205 | def simulate_spec(self): 206 | NUM_CORES = min(mp.cpu_count(), 16) # Reasonable default 207 | DATABASE = './nk' 208 | nk_dict = load_materials(all_mats=mats, wavelengths=wavelengths, DATABASE=DATABASE) 209 | 210 | from multiprocessing import Pool 211 | args_for_starmap = [(mat, thick, 's', 0, wavelengths, nk_dict, 'Glass_Substrate', 500000) 212 | for mat, thick in zip(self.all_mat, self.all_thick)] 213 | 214 | # Create a pool and use starmap 215 | with Pool(NUM_CORES) as pool: 216 | spec_res = 
pool.starmap(spectrum, args_for_starmap)
217 |             pool.close()
218 | 
219 |         return np.array(spec_res)
220 | 
221 |     def calc_error(self):
222 |         error = []
223 |         for i in range(len(self.original_spec)):
224 |             if self.error_type == "MAE":
225 |                 error.append(np.mean(np.abs(self.original_spec[i] - self.designed_spec[i])))
226 |             elif self.error_type == "MSE":
227 |                 error.append(np.mean(np.square(self.original_spec[i] - self.designed_spec[i])))
228 |         return error
229 | 
230 |     def calc_single_error(self, spec, designed_spec):
231 |         if self.error_type == "MAE":
232 |             return np.mean(np.abs(spec - designed_spec))
233 |         elif self.error_type == "MSE":
234 |             return np.mean(np.square(spec - designed_spec))
235 | 
236 |     def greedy_decode(self, spec_target, max_len, start_symbol, start_mat = None):
237 |         """
238 |         Greedy decoding: generate a structure token sequence for the target spectrum.
239 |         """
240 |         # init a 1x1 LongTensor prediction filled with the 'BOS' token id
241 |         start_symbol = self.struc_word_dict[start_symbol]
242 |         ys = torch.ones(1, 1).fill_(start_symbol).type(torch.LongTensor).to(DEVICE)
243 | 
244 |         if start_mat:
245 |             start_id = self.struc_word_dict[start_mat]
246 |             ys = torch.tensor([[start_symbol, start_id]]).type(torch.LongTensor).to(DEVICE)
247 |             struc_design = [start_mat]  # keep the seed token itself, not its id
248 |         else:
249 |             struc_design = []
250 | 
251 |         # process src
252 |         src = torch.tensor([spec_target]).unsqueeze(0).float().to(DEVICE)
253 |         src_mask = None
254 | 
255 | 
256 |         probs = []
257 |         for i in range(max_len-1):
258 |             # decode one by one
259 |             trg_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
260 | 
261 |             out = self.model(src.to(DEVICE), Variable(ys), src_mask, trg_mask.to(DEVICE))
262 | 
263 |             # the generator output is log-probabilities
264 |             prob = self.model.generator(out[:, -1])
265 |             probs.append(prob[0, :].to('cpu').tolist())
266 | 
267 |             # get the max-prob id
268 |             _, next_word = torch.max(prob, dim = 1)
269 |             next_word = next_word.data[0]
270 |             # concatenate with earlier predictions
271 |             ys = torch.cat([ys,torch.ones(1, 1).type(torch.LongTensor).fill_(next_word).to(DEVICE)], dim=1)
272 |             sym = self.struc_index_dict[next_word.to('cpu').item()]
273 |             if sym != 'EOS':
274 |                 struc_design.append(sym)
275 |             else:
276 |                 break
277 | 
278 |         return struc_design
279 | 
280 |     def top_kp_decode(self, spec_target, max_len, start_symbol, start_mat=None):
281 |         """
282 |         Top-k / top-p (nucleus) sampling: generate a structure token sequence.
283 |         """
284 |         # process src
285 |         src = torch.tensor([spec_target]).unsqueeze(0).float().to(DEVICE)
286 |         src_mask = None
287 | 
288 |         probs = []
289 |         start_symbol = self.struc_word_dict[start_symbol]
290 |         ys = torch.ones(1, 1).fill_(start_symbol).type(torch.LongTensor).to(DEVICE)
291 | 
292 |         if start_mat:
293 |             start_id = self.struc_word_dict[start_mat]
294 |             ys = torch.tensor([[start_symbol, start_id]]).type(torch.LongTensor).to(DEVICE)
295 |             struc_design = [start_mat]  # keep the seed token itself, not its id
296 |         else:
297 |             struc_design = []
298 | 
299 |         # (src and src_mask were already prepared above)
300 | 
301 | 
302 |         for i in range(max_len - 1):
303 |             trg_mask = Variable(subsequent_mask(ys.size(1)).type_as(src.data))
304 | 
305 |             out = self.model(src.to(DEVICE), Variable(ys), src_mask, trg_mask.to(DEVICE))
306 | 
307 |             # generator gives log-probabilities; exp() converts them back to probabilities
308 |             prob = self.model.generator(out[:, -1]).exp().to('cpu')
309 | 
310 |             prob_sort = torch.argsort(prob, descending=True)
311 | 
312 |             prob_item_select = []
313 |             prob_total = 0
314 |             i = 0
315 |             while prob_total < self.top_p and len(prob_item_select) < min(int(self.top_k), prob.size(1)):
316 | 
mat_design = self.struc_index_dict[prob_sort[0, i].item()].split('_')[0] 317 | prob_item_select.append(prob_sort[0, i].item()) 318 | prob_total += prob[0, prob_sort[0, i]].item() 319 | i += 1 320 | prob_select = [prob[0, i].item() for i in prob_item_select] 321 | probs.append(prob_item_select + prob_select) 322 | 323 | temp_sum = sum(prob_select) 324 | prob_select = [i/temp_sum for i in prob_select] 325 | 326 | next_word = np.random.choice(prob_item_select, p=prob_select) 327 | 328 | # concatnate with early predictions 329 | ys = torch.cat([ys,torch.ones(1, 1).type(torch.LongTensor).fill_(next_word).to(DEVICE)], dim=1) 330 | sym = self.struc_index_dict[next_word] 331 | 332 | if sym != 'EOS': 333 | struc_design.append(sym) 334 | else: 335 | break 336 | 337 | return struc_design 338 | 339 | def beam_search_decode(self, spec_target, max_len, start_symbol, beam_width=5, start_mat=None): 340 | """ 341 | Use beam search decode to generate text 342 | """ 343 | src = torch.tensor([spec_target]).unsqueeze(0).float().to(DEVICE) 344 | src_mask = None 345 | 346 | start_symbol = self.struc_word_dict[start_symbol] 347 | initial_beam = { 348 | 'ys': torch.ones(1, 1).fill_(start_symbol).type(torch.LongTensor).to(DEVICE), 349 | 'score': 0, 350 | 'struc_design': [] 351 | } 352 | 353 | if start_mat: 354 | start_mat = self.struc_word_dict[start_mat] 355 | initial_beam['ys'] = torch.tensor([[start_symbol, start_mat]]).type(torch.LongTensor).to(DEVICE) 356 | initial_beam['struc_design'] = [start_mat] 357 | 358 | beams = [initial_beam] 359 | 360 | for _ in range(max_len - 1): 361 | all_candidates = [] 362 | for beam in beams: 363 | trg_mask = Variable(subsequent_mask(beam['ys'].size(1)).type_as(src.data)) 364 | out = self.model(src, Variable(beam['ys']), src_mask, trg_mask) 365 | prob = self.model.generator(out[:, -1]).softmax(dim=-1) 366 | 367 | # Get top beam_width candidates for this beam 368 | top_probs, top_idxs = prob.topk(beam_width) 369 | 370 | for i in range(beam_width): 371 | next_word = top_idxs[0][i].item() 372 | score = beam['score'] - torch.log(top_probs[0][i]) # Use log prob for numerical stability 373 | ys = torch.cat([beam['ys'], torch.tensor([[next_word]], device=DEVICE)], dim=1) 374 | sym = self.struc_index_dict[next_word] 375 | struc_design = beam['struc_design'] + [sym] if sym != 'EOS' else beam['struc_design'] 376 | 377 | candidate = { 378 | 'ys': ys, 379 | 'score': score, 380 | 'struc_design': struc_design 381 | } 382 | all_candidates.append(candidate) 383 | 384 | # Sort all candidates by score and select top beam_width 385 | beams = sorted(all_candidates, key=lambda x: x['score'])[:beam_width] 386 | 387 | # Check if all beams ended with EOS 388 | if all(self.struc_index_dict[beam['ys'][0, -1].item()] == 'EOS' for beam in beams): 389 | break 390 | 391 | return [beam['struc_design'] for beam in beams] 392 | -------------------------------------------------------------------------------- /optogpt/data_conversion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "id": "2bb3f0fe", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pandas as pd \n", 11 | "import numpy as np\n", 12 | "import os\n", 13 | "import pickle as pkl" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "4268e4f7", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "mode = 'train' # or 'test\n", 24 | "\n", 25 | "data_path = 
'optogpt_data/' + mode \n", 26 | "\n", 27 | "data_names = os.listdir(data_path)\n", 28 | "\n", 29 | "df_list = []\n", 30 | "\n", 31 | "for file in data_names:\n", 32 | " df_ = pd.read_csv(os.path.join(data_path, file))\n", 33 | " df_list.append(df_)\n", 34 | " \n", 35 | "df = pd.concat(df_list, ignore_index=True)\n" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 5, 41 | "id": "56297afb", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/html": [ 47 | "
\n", 48 | "\n", 61 | "\n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 
| " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | "
Unnamed: 0structurenum_layersR_400nmR_410nmR_420nmR_430nmR_440nmR_450nmR_460nm...T_1010nmT_1020nmT_1030nmT_1040nmT_1050nmT_1060nmT_1070nmT_1080nmT_1090nmT_1100nm
00['Ta2O5_300', 'AlN_320', 'Ta2O5_340', 'HfO2_40...170.1431540.2704090.0682580.2031400.2077520.1339080.256749...7.279138e-206.708738e-204.977434e-203.349879e-202.357321e-201.845059e-201.634083e-201.645585e-201.893093e-202.501610e-20
11['MgO_270', 'TiO2_40', 'ZnO_190', 'HfO2_380', ...170.1367120.1791320.0946690.1003620.2212890.3008300.291660...2.466956e-012.217513e-012.162272e-012.241115e-012.363959e-012.409754e-012.310539e-012.122547e-011.949739e-011.856023e-01
22['Ge_290', 'Si3N4_10', 'Al_240', 'Si3N4_70', '...170.4682740.4663900.4649400.4631120.4610490.4588920.456624...5.828044e-239.855431e-231.372599e-221.465879e-221.327228e-221.149973e-229.994108e-238.733339e-237.615815e-236.604243e-23
33['Ta2O5_240', 'ITO_380', 'SiO2_280', 'ZnO_480'...190.1346390.2240110.1607320.2154280.1713780.0641310.113954...1.791174e-012.166785e-012.527115e-012.603744e-012.430273e-012.243570e-012.154357e-012.151012e-012.178291e-012.189599e-01
44['MgF2_180', 'MgO_260', 'Al_80', 'TiN_450', 'T...190.6512190.6592210.7054390.7593900.8024100.8316430.849628...3.709293e-183.902026e-183.491132e-182.958163e-182.665547e-182.660998e-182.863903e-183.030304e-182.821330e-182.301221e-18
..................................................................
1999995999995['SiO2_250', 'TiO2_230', 'SiO2_280', 'Al2O3_46...90.2482470.3081160.1343440.1070440.1782230.2897470.185345...7.265530e-017.134139e-016.981519e-016.820254e-016.665148e-016.529206e-016.421852e-016.347691e-016.306720e-016.294758e-01
1999996999996['ZnO_390', 'Ta2O5_230', 'ZnSe_280', 'SiO2_40'...200.2125420.1668140.0594360.1295340.2398940.2118350.077812...2.165693e-193.526536e-195.911565e-199.458620e-191.289738e-181.388152e-181.214964e-189.433717e-196.990779e-195.163660e-19
1999997999997['HfO2_470', 'MgF2_380', 'SiO2_180', 'MgO_190'...180.0623300.0579410.1757880.3404790.2081170.3087020.236718...1.499599e-011.784428e-011.989469e-012.081872e-012.118915e-012.082258e-011.912477e-011.665002e-011.461210e-011.359751e-01
1999998999998['ZnS_30', 'Ta2O5_180', 'Si_240', 'TiO2_130', ...160.0786940.1022080.1516830.2087800.2601420.2950360.311392...3.359812e-213.639918e-214.014346e-214.422215e-214.775591e-214.983909e-214.991032e-214.789933e-214.425239e-213.967554e-21
1999999999999['Si3N4_230', 'Al2O3_300', 'AlN_380', 'MgO_190...180.2235630.3607720.2460680.0192240.0700390.2224330.200406...1.674699e-011.698070e-011.845447e-012.087810e-012.312235e-012.395830e-012.362823e-012.329998e-012.346766e-012.354661e-01
\n", 355 | "

2000000 rows × 145 columns

\n", 356 | "
" 357 | ], 358 | "text/plain": [ 359 | " Unnamed: 0 structure \\\n", 360 | "0 0 ['Ta2O5_300', 'AlN_320', 'Ta2O5_340', 'HfO2_40... \n", 361 | "1 1 ['MgO_270', 'TiO2_40', 'ZnO_190', 'HfO2_380', ... \n", 362 | "2 2 ['Ge_290', 'Si3N4_10', 'Al_240', 'Si3N4_70', '... \n", 363 | "3 3 ['Ta2O5_240', 'ITO_380', 'SiO2_280', 'ZnO_480'... \n", 364 | "4 4 ['MgF2_180', 'MgO_260', 'Al_80', 'TiN_450', 'T... \n", 365 | "... ... ... \n", 366 | "1999995 999995 ['SiO2_250', 'TiO2_230', 'SiO2_280', 'Al2O3_46... \n", 367 | "1999996 999996 ['ZnO_390', 'Ta2O5_230', 'ZnSe_280', 'SiO2_40'... \n", 368 | "1999997 999997 ['HfO2_470', 'MgF2_380', 'SiO2_180', 'MgO_190'... \n", 369 | "1999998 999998 ['ZnS_30', 'Ta2O5_180', 'Si_240', 'TiO2_130', ... \n", 370 | "1999999 999999 ['Si3N4_230', 'Al2O3_300', 'AlN_380', 'MgO_190... \n", 371 | "\n", 372 | " num_layers R_400nm R_410nm R_420nm R_430nm R_440nm \\\n", 373 | "0 17 0.143154 0.270409 0.068258 0.203140 0.207752 \n", 374 | "1 17 0.136712 0.179132 0.094669 0.100362 0.221289 \n", 375 | "2 17 0.468274 0.466390 0.464940 0.463112 0.461049 \n", 376 | "3 19 0.134639 0.224011 0.160732 0.215428 0.171378 \n", 377 | "4 19 0.651219 0.659221 0.705439 0.759390 0.802410 \n", 378 | "... ... ... ... ... ... ... \n", 379 | "1999995 9 0.248247 0.308116 0.134344 0.107044 0.178223 \n", 380 | "1999996 20 0.212542 0.166814 0.059436 0.129534 0.239894 \n", 381 | "1999997 18 0.062330 0.057941 0.175788 0.340479 0.208117 \n", 382 | "1999998 16 0.078694 0.102208 0.151683 0.208780 0.260142 \n", 383 | "1999999 18 0.223563 0.360772 0.246068 0.019224 0.070039 \n", 384 | "\n", 385 | " R_450nm R_460nm ... T_1010nm T_1020nm T_1030nm \\\n", 386 | "0 0.133908 0.256749 ... 7.279138e-20 6.708738e-20 4.977434e-20 \n", 387 | "1 0.300830 0.291660 ... 2.466956e-01 2.217513e-01 2.162272e-01 \n", 388 | "2 0.458892 0.456624 ... 5.828044e-23 9.855431e-23 1.372599e-22 \n", 389 | "3 0.064131 0.113954 ... 1.791174e-01 2.166785e-01 2.527115e-01 \n", 390 | "4 0.831643 0.849628 ... 3.709293e-18 3.902026e-18 3.491132e-18 \n", 391 | "... ... ... ... ... ... ... \n", 392 | "1999995 0.289747 0.185345 ... 7.265530e-01 7.134139e-01 6.981519e-01 \n", 393 | "1999996 0.211835 0.077812 ... 2.165693e-19 3.526536e-19 5.911565e-19 \n", 394 | "1999997 0.308702 0.236718 ... 1.499599e-01 1.784428e-01 1.989469e-01 \n", 395 | "1999998 0.295036 0.311392 ... 3.359812e-21 3.639918e-21 4.014346e-21 \n", 396 | "1999999 0.222433 0.200406 ... 1.674699e-01 1.698070e-01 1.845447e-01 \n", 397 | "\n", 398 | " T_1040nm T_1050nm T_1060nm T_1070nm T_1080nm \\\n", 399 | "0 3.349879e-20 2.357321e-20 1.845059e-20 1.634083e-20 1.645585e-20 \n", 400 | "1 2.241115e-01 2.363959e-01 2.409754e-01 2.310539e-01 2.122547e-01 \n", 401 | "2 1.465879e-22 1.327228e-22 1.149973e-22 9.994108e-23 8.733339e-23 \n", 402 | "3 2.603744e-01 2.430273e-01 2.243570e-01 2.154357e-01 2.151012e-01 \n", 403 | "4 2.958163e-18 2.665547e-18 2.660998e-18 2.863903e-18 3.030304e-18 \n", 404 | "... ... ... ... ... ... 
\n", 405 | "1999995 6.820254e-01 6.665148e-01 6.529206e-01 6.421852e-01 6.347691e-01 \n", 406 | "1999996 9.458620e-19 1.289738e-18 1.388152e-18 1.214964e-18 9.433717e-19 \n", 407 | "1999997 2.081872e-01 2.118915e-01 2.082258e-01 1.912477e-01 1.665002e-01 \n", 408 | "1999998 4.422215e-21 4.775591e-21 4.983909e-21 4.991032e-21 4.789933e-21 \n", 409 | "1999999 2.087810e-01 2.312235e-01 2.395830e-01 2.362823e-01 2.329998e-01 \n", 410 | "\n", 411 | " T_1090nm T_1100nm \n", 412 | "0 1.893093e-20 2.501610e-20 \n", 413 | "1 1.949739e-01 1.856023e-01 \n", 414 | "2 7.615815e-23 6.604243e-23 \n", 415 | "3 2.178291e-01 2.189599e-01 \n", 416 | "4 2.821330e-18 2.301221e-18 \n", 417 | "... ... ... \n", 418 | "1999995 6.306720e-01 6.294758e-01 \n", 419 | "1999996 6.990779e-19 5.163660e-19 \n", 420 | "1999997 1.461210e-01 1.359751e-01 \n", 421 | "1999998 4.425239e-21 3.967554e-21 \n", 422 | "1999999 2.346766e-01 2.354661e-01 \n", 423 | "\n", 424 | "[2000000 rows x 145 columns]" 425 | ] 426 | }, 427 | "execution_count": 5, 428 | "metadata": {}, 429 | "output_type": "execute_result" 430 | } 431 | ], 432 | "source": [ 433 | "df" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 20, 439 | "id": "822cee0d", 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "import ast\n", 444 | "\n", 445 | "spec_index = ['R_'+str(i)+'nm' for i in range(400, 1110, 10)] + ['T_'+str(i)+'nm' for i in range(400, 1110, 10)]\n", 446 | "\n", 447 | "spec = df[spec_index].values\n", 448 | "\n", 449 | "structure = list(df['structure'].apply(ast.literal_eval))" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "id": "7cdafb04", 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [ 459 | "\n", 460 | "with open('dataset/Structure_'+mode+'.pkl', 'wb') as files:\n", 461 | " pkl.dump(structure, files)\n", 462 | " \n", 463 | "with open('dataset/Spectrum_'+mode+'.pkl', 'wb') as files:\n", 464 | " pkl.dump(spec, files)" 465 | ] 466 | } 467 | ], 468 | "metadata": { 469 | "kernelspec": { 470 | "display_name": "transformer", 471 | "language": "python", 472 | "name": "python3" 473 | }, 474 | "language_info": { 475 | "codemirror_mode": { 476 | "name": "ipython", 477 | "version": 3 478 | }, 479 | "file_extension": ".py", 480 | "mimetype": "text/x-python", 481 | "name": "python", 482 | "nbconvert_exporter": "python", 483 | "pygments_lexer": "ipython3", 484 | "version": "3.8.19" 485 | } 486 | }, 487 | "nbformat": 4, 488 | "nbformat_minor": 5 489 | } 490 | -------------------------------------------------------------------------------- /self_improving/data_perturb.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../optogpt') 3 | import os 4 | import math 5 | import numpy as np 6 | import torch 7 | import random 8 | import pandas as pd 9 | from core.datasets.sim import load_materials, spectrum 10 | from multiprocessing import Pool 11 | import multiprocessing 12 | import multiprocessing as mp 13 | import pyswarms as ps 14 | 15 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | DATABASE = '../optogpt/nk' 17 | illuminant = SDS_ILLUMINANTS['D65'] 18 | cmfs = MSDS_CMFS['CIE 1931 2 Degree Standard Observer'] 19 | 20 | mats = ['Al', 'Ag', 'Al2O3', 'AlN', 'Ge', 'HfO2', 'ITO', 'MgF2', 'MgO', 'Si', 'Si3N4', 'SiO2', 'Ta2O5', 'TiN', 'TiO2', 'ZnO', 'ZnS', 'ZnSe', 'Glass_Substrate'] 21 | thicks = [str(i) for i in range(10, 505, 10)] 22 | high_index_mats = 
['TiO2','ZnS','ZnSe','Ta2O5','HfO2'] 23 | medium_index_mats = ['SiO2','Al2O3','MgF2','Si3N4'] 24 | low_index_mats = ['MgO','ITO'] 25 | metal_mats = ['Al','Ag'] 26 | semi_mats = ['Ge','Si'] 27 | 28 | lamda_low = 0.4 29 | lamda_high = 1.1 30 | wavelengths = np.arange(lamda_low, lamda_high+1e-3, 0.01) 31 | 32 | nk_dict = load_materials(all_mats = mats, wavelengths = wavelengths, DATABASE = DATABASE) 33 | 34 | 35 | def return_mat_thick(struc_list): 36 | materials = [] 37 | thickness = [] 38 | for struc_ in struc_list: 39 | materials.append(struc_.split('_')[0]) 40 | thickness.append(struc_.split('_')[1]) 41 | 42 | return materials, thickness 43 | 44 | POPULATION_SIZE = 100 45 | 46 | # Valid genes: materials 47 | 48 | # Target string to be generated 49 | def mutate_genes(individual): 50 | ''' 51 | gene mutation with specified probabilities 52 | ''' 53 | mat_to_mutate = np.random.randint(0, len(individual.chromosome)) 54 | material = individual.chromosome[mat_to_mutate] 55 | 56 | # Determine mutation type based on specified probabilities 57 | mutation_type = np.random.choice(['same_class', 'unchanged', 'other_class'], p=[0.8, 0.1, 0.1]) 58 | 59 | if mutation_type == 'same_class': 60 | # mutate within the same category 61 | if material in high_index_mats: 62 | new_material = random.choice([m for m in high_index_mats if m != material]) # Avoid picking the same material 63 | elif material in medium_index_mats: 64 | new_material = random.choice([m for m in medium_index_mats if m != material]) 65 | elif material in low_index_mats: 66 | new_material = random.choice([m for m in low_index_mats if m != material]) 67 | elif material in metal_mats: 68 | new_material = random.choice([m for m in metal_mats if m != material]) 69 | else: 70 | new_material = random.choice([m for m in semi_mats if m != material]) 71 | elif mutation_type == 'unchanged': 72 | # Keep the material unchanged 73 | new_material = material 74 | elif mutation_type == 'other_class': 75 | # mutate to a different category 76 | all_mats = [] 77 | if material in high_index_mats: 78 | all_mats = medium_index_mats + low_index_mats + metal_mats + semi_mats 79 | elif material in medium_index_mats: 80 | all_mats = high_index_mats + low_index_mats + metal_mats + semi_mats 81 | elif material in low_index_mats: 82 | all_mats = high_index_mats + medium_index_mats + metal_mats + semi_mats 83 | elif material in metal_mats: 84 | all_mats = high_index_mats + medium_index_mats + low_index_mats + semi_mats 85 | else: 86 | all_mats = high_index_mats + medium_index_mats + low_index_mats + metal_mats 87 | new_material = random.choice([m for m in all_mats if m != material]) # Ensure different material is chosen 88 | 89 | individual.chromosome[mat_to_mutate] = new_material 90 | individual.cal_fitness() 91 | return individual 92 | 93 | def mate(par1, par2): 94 | ''' 95 | Perform mating and produce new offspring 96 | ''' 97 | 98 | # chromosome for offspring 99 | child_chromosome = [] 100 | child_thickness = [] 101 | cut_off_point = np.random.randint(0, min(len(par1.chromosome), len(par2.chromosome))) 102 | dice = np.random.random() 103 | if dice < 0.5: 104 | child_chromosome = par1.chromosome[:cut_off_point] + par2.chromosome[cut_off_point:] 105 | child_thickness = par1.thickness[:cut_off_point] + par2.thickness[cut_off_point:] 106 | else: 107 | child_chromosome = par2.chromosome[:cut_off_point] + par1.chromosome[cut_off_point:] 108 | child_thickness = par2.thickness[:cut_off_point] + par1.thickness[cut_off_point:] 109 | 110 | # create new Individual(offspring) 
using 111 | child_individual = [str(child_chrom) + '_' + str(child_thick) for child_chrom, child_thick in zip(child_chromosome, child_thickness)] 112 | # generated chromosome for offspring 113 | return Individual(child_individual, par1.target) 114 | 115 | 116 | class Individual(): 117 | ''' 118 | Class representing individual in population 119 | ''' 120 | def __init__(self, chromosome, target): 121 | self.chromosome, self.thickness = return_mat_thick(chromosome) 122 | self.target = target 123 | self.fitness = self.cal_fitness() 124 | 125 | def cal_fitness(self): 126 | ''' 127 | Calculate fitness score, it is the number of 128 | characters in string which differ from target 129 | string. 130 | ''' 131 | designed_spec = spectrum(self.chromosome, self.thickness, wavelengths=wavelengths, nk_dict=nk_dict, substrate='Glass_Substrate', substrate_thick=500000) 132 | fitness = np.mean(np.square(np.array(designed_spec) - np.array(self.target))) 133 | 134 | return fitness 135 | 136 | def GA_perturb(mat_thicks,target,size): 137 | ''' 138 | Create initial population 139 | ''' 140 | population = [] 141 | for i in range(size): 142 | crt = mat_thicks[i] 143 | population.append(Individual(crt, target)) 144 | generation = 1 145 | # stopping criteria: mean fitness score of the population doesn't improve for 10 generations 146 | early_stopping_patience = 20 147 | early_stopping = 0 148 | crt_best_fitness = np.mean([ind.fitness for ind in population]) 149 | crt_best_population = mat_thicks 150 | crt_best_err = [] 151 | print("Original mean fitness: ", np.mean([ind.fitness for ind in population])) 152 | while True: 153 | ''' 154 | Sort the population in increasing order of fitness score 155 | ''' 156 | population = sorted(population, key = lambda x:x.fitness) 157 | # if the mean fitness score doesn't improve for 10 generations, stop 158 | 159 | #calculate the mean fitness score 160 | mean_fitness = np.mean([population[i].fitness for i in range(20)]) 161 | if mean_fitness < crt_best_fitness: 162 | crt_best_fitness = mean_fitness 163 | early_stopping = 0 164 | crt_best_population = [] 165 | for i in range(20): 166 | crt_struct = [] 167 | for j in range(len(population[i].chromosome)): 168 | crt_struct.append(str(population[i].chromosome[j]) + '_' + str(population[i].thickness[j])) 169 | crt_best_population.append(crt_struct) 170 | crt_best_err = [] 171 | for i in range(20): 172 | crt_best_err.append(population[i].fitness) 173 | else: 174 | early_stopping += 1 175 | if early_stopping == early_stopping_patience: 176 | break 177 | 178 | ''' 179 | Otherwise generate new offsprings for new generation 180 | ''' 181 | new_generation = [] 182 | ''' 183 | Perform Elitism, that mean 10% of fittest population 184 | goes to the next generation 185 | ''' 186 | s = int((10*POPULATION_SIZE)/100) 187 | new_generation.extend(population[:s]) 188 | ''' 189 | From 50% of fittest population, Individuals 190 | will mate to produce offspring 191 | ''' 192 | s = int((90*POPULATION_SIZE)/100) 193 | for _ in range(s): 194 | parent1 = random.choice(population[:10]) 195 | parent2 = random.choice(population[:10]) 196 | child = mate(parent1, parent2) 197 | mutated_child = mutate_genes(child) 198 | new_generation.append(mutated_child) 199 | 200 | 201 | population = new_generation 202 | print("Generation: ", generation, "Fitness: ", np.mean([ind.fitness for ind in population]) ) 203 | # print("Best struc: ", population[0].fitness) 204 | generation += 1 205 | return crt_best_population 206 | 207 | 208 | def PSO_perturb(mat_thicks, target, 
type='MAE'): 209 | 210 | materials, thickness = return_mat_thick(mat_thicks) 211 | 212 | M_t = len(thickness) 213 | x = [float(ii) for ii in thickness] 214 | 215 | def objective_func(x): 216 | # x is a matrix of shape (n_particles, dimensions) 217 | # Each row represents a particle (potential solution) 218 | n_particles = x.shape[0] 219 | j = [0] * n_particles 220 | 221 | for i in range(n_particles): 222 | R_T = spectrum(materials, list(x[i]), wavelengths=wavelengths, nk_dict=nk_dict, substrate='Glass_Substrate', substrate_thick=500000) 223 | if type == 'MAE': 224 | j[i] = np.mean(np.abs(np.array(R_T) - np.array(target))) 225 | elif type == 'MSE': 226 | j[i] = np.mean(np.square(np.array(R_T) - np.array(target))) 227 | else: 228 | raise NotImplementedError 229 | return np.array(j) 230 | # PSO hyperparameters 231 | options = {'c1': 0.5, 'c2': 0.3, 'w': 0.9} 232 | 233 | # Bounds for each dimension in the search space 234 | bounds = (np.array([10] * M_t), np.array([500] * M_t)) 235 | 236 | init_pos = np.full((5, M_t), x) 237 | 238 | # Initialize the optimizer 239 | optimizer = ps.single.GlobalBestPSO(n_particles=5, dimensions=M_t, options=options, bounds=bounds, init_pos = init_pos) 240 | 241 | # Create the final structure 242 | 243 | all_structures = [] 244 | for _ in range(5): # Number of iterations 245 | # Perform one step of optimization 246 | optimizer.optimize(objective_func, iters=1) 247 | 248 | # Record the current structures 249 | current_positions = optimizer.swarm.position 250 | for particle_pos in current_positions: 251 | temp_struc = [materials[i] + '_' + str(int(pos)) for i, pos in enumerate(particle_pos)] 252 | all_structures.append(temp_struc) 253 | 254 | # Return the list of all structures 255 | return all_structures 256 | 257 | 258 | def apply_func_to_chunk(chunk_df): 259 | return chunk_df.apply(lambda x: PSO_perturb(x['designed_struct'], x['original_spec'], type='MSE'), axis=1) 260 | 261 | def apply_func_to_chunk2(chunk_df): 262 | return chunk_df.apply(lambda x: PSO_perturb(x['perturb_struct1'], x['original_spec'], type='MSE'), axis=1) 263 | 264 | 265 | def get_to_be_perturbed_data(original_data, give_up_threshold=3): 266 | origin_df = pd.DataFrame() 267 | origin_df['original_spec'] = original_data.original_spec 268 | origin_df['designed_spec'] = original_data.designed_spec.tolist() 269 | origin_df['designed_struct'] = original_data.designed_struct 270 | origin_df['error'] = original_data.error 271 | 272 | # get the data with mae smaller than threshold 273 | origin_df = origin_df[origin_df['error'] < give_up_threshold].reset_index(drop=True) 274 | 275 | return origin_df 276 | 277 | def random_perturb(crt_struct): 278 | if len(crt_struct) == 0: 279 | return crt_struct 280 | idx = np.random.randint(0, len(crt_struct)) 281 | while (crt_struct[idx] in ['EOS','BOS','UNK']): 282 | idx = np.random.randint(0, len(crt_struct)) 283 | # separate the structure into material and thickness 284 | material, thickness = crt_struct[idx].split('_') 285 | # disturb the thickness randomly by 10/20/30/40/50 nm, normal distribution 286 | # Define the numbers and their corresponding probabilities 287 | numbers = [10, 20, 30, 40, 50] 288 | probabilities = [0.30, 0.25, 0.20, 0.15, 0.10] 289 | 290 | # Sample a number based on the defined probabilities 291 | sampled_number = random.choices(numbers, probabilities, k=1)[0] 292 | 293 | # Randomly choose whether the number is positive or negative 294 | sampled_number *= random.choice([-1, 1]) 295 | thickness = int(thickness) + sampled_number 296 | if 
(thickness > 500): 297 | thickness = 500 298 | if (thickness < 10): 299 | thickness = 10 300 | # combine the material and thickness 301 | crt_struct[idx] = str(material)+'_'+str(thickness) 302 | # perturb the structure 303 | return crt_struct 304 | 305 | def apply_func_to_chunk_GA(chunk_df): 306 | return chunk_df.apply(lambda x: GA_perturb(x['designed_struct'], x['original_spec'], 20), axis=1) 307 | 308 | def perturb_data(df, method="random", output_dir=None): 309 | 310 | if method == "random": 311 | # random perturb 312 | perturb_df = df.copy() 313 | perturb_df['perturb_struct'] = perturb_df['designed_struct'].apply(lambda x:random_perturb(x)) 314 | elif method == "PSO": 315 | # PSO perturb 316 | perturb_df = df.copy() 317 | 318 | def parallelize_dataframe(df, func): 319 | num_cores = mp.cpu_count() # Number of CPU cores 320 | df_split = np.array_split(df, min(100, num_cores)) # Split DataFrame into chunks 321 | pool = mp.Pool(num_cores) 322 | results = pool.map(func, df_split) # Process each chunk in parallel 323 | pool.close() 324 | pool.join() 325 | # Concatenate results into a single Series 326 | concatenated_series = pd.Series([item for sublist in results for item in sublist]) 327 | return concatenated_series 328 | perturb_df['perturb_struct'] = parallelize_dataframe(perturb_df, apply_func_to_chunk) 329 | perturb_df = expand_perturbed_data(perturb_df) 330 | perturb_df['perturb_struct'] = perturb_df.apply(lambda x:round_to_nearest_length(x),axis=1) 331 | 332 | elif method == "GA_PSO": 333 | # first use GA to perturb the structure, then use PSO to optimize the perturbed structure 334 | 335 | def parallelize_dataframe(df, func): 336 | num_cores = mp.cpu_count() # Number of CPU cores 337 | df_split = np.array_split(df, min(100, num_cores)) # Split DataFrame into chunks 338 | pool = mp.Pool(num_cores) 339 | results = pool.map(func, df_split) # Process each chunk in parallel 340 | pool.close() 341 | pool.join() 342 | # Concatenate results into a single Series 343 | concatenated_series = pd.Series([item for sublist in results for item in sublist]) 344 | return concatenated_series 345 | new_df = df.copy() 346 | original_error = new_df['error'][19::20].reset_index(drop=True) 347 | designed_spec = new_df['designed_spec'][19::20].reset_index(drop=True) 348 | original_spec = new_df['original_spec'][19::20].reset_index(drop=True) 349 | 350 | tmp = [] 351 | for i in range(0, len(new_df), 20): 352 | tmp.append(new_df['designed_struct'][i:i+20].to_list()) 353 | chunks = pd.Series(tmp) 354 | designed_structs = pd.DataFrame({'designed_struct':chunks}) 355 | designed_structs['original_spec'] = original_spec 356 | designed_structs['designed_spec'] = designed_spec 357 | designed_structs['error'] = original_error 358 | designed_structs.reset_index() 359 | print(designed_structs.iloc[0]) 360 | 361 | 362 | designed_structs['perturb_struct1'] = parallelize_dataframe(designed_structs, apply_func_to_chunk_GA) 363 | designed_structs = designed_structs.explode('perturb_struct1').reset_index(drop=True) 364 | 365 | # Save intermediate results if output_dir is provided 366 | if output_dir: 367 | os.makedirs(output_dir, exist_ok=True) 368 | designed_structs.to_pickle(os.path.join(output_dir, "designed_structs.pkl")) 369 | 370 | # apply PSO to the perturbed structure 371 | perturb_df = designed_structs.copy() 372 | perturb_df['perturb_struct'] = parallelize_dataframe(perturb_df, apply_func_to_chunk2) 373 | perturb_df = expand_perturbed_data(perturb_df) 374 | perturb_df['perturb_struct'] = perturb_df.apply(lambda 
x:round_to_nearest_length(x),axis=1) 375 | 376 | # Save final results if output_dir is provided 377 | if output_dir: 378 | perturb_df.to_pickle(os.path.join(output_dir, "perturb_df.pkl")) 379 | 380 | else: 381 | raise NotImplementedError 382 | return perturb_df 383 | 384 | def simulate_perturbed_struct(perturb_df, error_type="MAE"): 385 | def return_mat_thick(struc_list): 386 | materials = [] 387 | thickness = [] 388 | for struc_ in struc_list: 389 | if (struc_ != 'EOS' and struc_ != 'BOS' and struc_ != 'UNK' and struc_ != 'PAD'): 390 | materials.append(struc_.split('_')[0]) 391 | thickness.append(struc_.split('_')[1]) 392 | return materials, thickness 393 | 394 | def simulate_spec(all_mat, all_thick): 395 | NUM_CORES = min(mp.cpu_count(), 16) # Reasonable default 396 | DATABASE = './nk' 397 | mats = ['Al', 'Ag', 'Al2O3', 'AlN', 'Ge', 'HfO2', 'ITO', 'MgF2', 'MgO', 'Si', 'Si3N4', 'SiO2', 'Ta2O5', 'TiN', 'TiO2', 'ZnO', 'ZnS', 'ZnSe', 'Glass_Substrate'] 398 | nk_dict = load_materials(all_mats = mats, wavelengths = wavelengths, DATABASE = DATABASE ) 399 | args_for_starmap = [(mat, thick, 's', 0, wavelengths, nk_dict,'Glass_Substrate', 500000) 400 | for mat, thick in zip(all_mat, all_thick)] 401 | 402 | # Create a pool and use starmap 403 | with Pool(NUM_CORES) as pool: 404 | spec_res = pool.starmap(spectrum, args_for_starmap) 405 | pool.close() 406 | 407 | return spec_res 408 | 409 | # separate the structure into material and thickness for perturbed 410 | all_mat = [] 411 | all_thick = [] 412 | for i in range(len(perturb_df)): 413 | material, thickness = return_mat_thick(perturb_df.iloc[i,:]['perturb_struct']) 414 | all_mat.append(material) 415 | all_thick.append(thickness) 416 | 417 | # simulate the perturbed spectrum 418 | perturb_df['perturb_spec'] = simulate_spec(all_mat, all_thick) 419 | 420 | # calculate the new error 421 | if error_type == "MAE": 422 | perturb_df['new_error'] = perturb_df.apply(lambda x:np.mean(np.abs(np.array(x['original_spec']) - np.array(x['perturb_spec']))),axis=1) 423 | else: 424 | perturb_df['new_error'] = perturb_df.apply(lambda x:np.mean(np.square(np.array(x['original_spec']) - np.array(x['perturb_spec']))),axis=1) 425 | return perturb_df 426 | 427 | def get_perturbed_better_data(perturb_df): 428 | # get the data with new mae smaller than original mae 429 | perturb_df = perturb_df[perturb_df['new_error'] < perturb_df['error']] 430 | return perturb_df 431 | 432 | def expand_perturbed_data(perturb_df): 433 | # Create a new DataFrame 434 | new_df = pd.DataFrame() 435 | 436 | # Copy the relevant columns 437 | for col in ['original_spec', 'designed_spec', 'designed_struct', 'perturb_struct', 'error']: 438 | new_df[col] = perturb_df[col] 439 | 440 | # Explode the perturb_struct column 441 | new_df = new_df.explode('perturb_struct') 442 | 443 | # Reset the index 444 | new_df.reset_index(drop=True, inplace=True) 445 | 446 | return new_df 447 | 448 | def round_to_nearest_length(perturb_df_slice): 449 | perturb_struct = perturb_df_slice['perturb_struct'] 450 | for i in range(len(perturb_struct)): 451 | material, thickness = perturb_struct[i].split('_') 452 | thickness = int(thickness) 453 | # round the thickness to the nearest 10 454 | thickness = int(round(thickness / 10.0)) * 10 455 | # if the thickness is smaller than 10, set it to 10 456 | if (thickness < 10): 457 | thickness = 10 458 | # if the thickness is larger than 500, set it to 500 459 | if (thickness > 500): 460 | thickness = 500 461 | # combine the material and thickness 462 | perturb_struct[i] = material + 
'_' + str(thickness) 463 | return perturb_struct 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | 473 | 474 | --------------------------------------------------------------------------------
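
Below is a minimal usage sketch (not part of the repository) showing one plausible way the perturbation helpers in self_improving/data_perturb.py chain together. The `eval_results` object, the import path, and the toy spectra are illustrative assumptions only; they indicate the expected shapes, not real data.

# Hypothetical driver for the self-improving perturbation step -- a sketch only.
# Assumptions (not in the repo): the module is importable as `data_perturb`,
# and `eval_results` stands in for whatever object carries the decoded designs.
import numpy as np
from types import SimpleNamespace

from data_perturb import (
    get_to_be_perturbed_data,
    perturb_data,
    simulate_perturbed_struct,
    get_perturbed_better_data,
)

# Toy stand-in illustrating the interface get_to_be_perturbed_data() expects:
# per-design target spectra, simulated spectra, token lists, and errors.
eval_results = SimpleNamespace(
    original_spec=[np.zeros(142)],               # placeholder target spectrum (length must match the simulator's R/T output)
    designed_spec=np.zeros((1, 142)),            # placeholder simulated spectrum (needs .tolist())
    designed_struct=[['SiO2_100', 'TiO2_200']],  # decoded material_thickness tokens
    error=[0.05],                                # error of the design vs. its target
)

# Keep only designs whose error is below the give-up threshold.
to_perturb = get_to_be_perturbed_data(eval_results, give_up_threshold=3)

# Perturb the kept designs: "random" nudges one layer thickness,
# "PSO" re-optimises thicknesses, "GA_PSO" mutates materials first and then runs PSO.
perturbed = perturb_data(to_perturb, method="random")

# Re-simulate the perturbed stacks (requires the nk/ database on disk) and keep improvements.
perturbed = simulate_perturbed_struct(perturbed, error_type="MAE")
improved = get_perturbed_better_data(perturbed)
improved.to_pickle("perturbed_better.pkl")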