├── .gitignore ├── LICENSE ├── README.md ├── docs ├── Makefile ├── apidoc │ ├── pypret.fourier.rst │ ├── pypret.frequencies.rst │ ├── pypret.graphics.rst │ ├── pypret.io.rst │ ├── pypret.lib.rst │ ├── pypret.material.rst │ ├── pypret.mesh_data.rst │ ├── pypret.pnps.rst │ ├── pypret.pulse.rst │ ├── pypret.pulse_error.rst │ └── pypret.retrieval.rst ├── conf.py ├── getting_started.rst ├── index.rst ├── installation.rst ├── make.bat ├── references.rst └── requirements.txt ├── pypret ├── __init__.py ├── autocorrelation.py ├── fourier.py ├── frequencies.py ├── graphics.py ├── io │ ├── __init__.py │ ├── handlers.py │ ├── io.py │ ├── options.py │ └── tests │ │ └── test_io.py ├── lib.py ├── material.py ├── mesh_data.py ├── pnps.py ├── pulse.py ├── pulse_error.py ├── random_pulse.py ├── retrieval │ ├── __init__.py │ ├── nlo_retriever.py │ ├── retriever.py │ └── step_retriever.py └── tests │ ├── data │ ├── dscan-sd-copra-retrieved.hdf5 │ ├── dscan-sd-trace.hdf5 │ ├── dscan-shg-copra-retrieved.hdf5 │ ├── dscan-shg-gp-dscan-retrieved.hdf5 │ ├── dscan-shg-trace.hdf5 │ ├── dscan-thg-copra-retrieved.hdf5 │ ├── dscan-thg-gp-dscan-retrieved.hdf5 │ ├── dscan-thg-trace.hdf5 │ ├── frog-pg-copra-retrieved.hdf5 │ ├── frog-pg-trace.hdf5 │ ├── frog-shg-copra-retrieved.hdf5 │ ├── frog-shg-gpa-retrieved.hdf5 │ ├── frog-shg-pcgpa-retrieved.hdf5 │ ├── frog-shg-pie-retrieved.hdf5 │ ├── frog-shg-trace.hdf5 │ ├── ifrog-sd-copra-retrieved.hdf5 │ ├── ifrog-sd-trace.hdf5 │ ├── ifrog-shg-copra-retrieved.hdf5 │ ├── ifrog-shg-trace.hdf5 │ ├── ifrog-thg-copra-retrieved.hdf5 │ ├── ifrog-thg-trace.hdf5 │ ├── initial.hdf5 │ ├── miips-sd-copra-retrieved.hdf5 │ ├── miips-sd-trace.hdf5 │ ├── miips-shg-copra-retrieved.hdf5 │ ├── miips-shg-trace.hdf5 │ ├── miips-thg-copra-retrieved.hdf5 │ ├── miips-thg-trace.hdf5 │ ├── pulse.hdf5 │ ├── tdp-shg-copra-retrieved.hdf5 │ └── tdp-shg-trace.hdf5 │ ├── test_fourier.py │ ├── test_mesh_data.py │ ├── test_regression.py │ └── test_sellmeier.py └── scripts ├── 
benchmarking.py ├── create_pulse_bank.py ├── nlo_retrievers.py ├── path_helper.py ├── pulse_bank.hdf5 ├── result.png ├── simple_example.py ├── test_retrieval_algorithms.py ├── test_retrieval_algorithms_plot.py.py └── test_single_retrieval.py /.gitignore: -------------------------------------------------------------------------------- 1 | # repository specific stuff 2 | scripts/results/ 3 | 4 | Thumbs.db 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # pytest 40 | .pytest_cache 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | .static_storage/ 64 | .media/ 65 | local_settings.py 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 
103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Nils C. Geib 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python for Pulse Retrieval 2 | 3 | This project aims to provide numerical algorithms for ultrashort laser pulse measurement methods such as frequency-resolved optical gating (FROG), dispersion scan (d-scan), or time-domain ptychography (TDP) and more. 
Specifically, it provides a reference implementation of the algorithms presented in our paper ["Common pulse retrieval algorithm: a fast and universal method to retrieve ultrashort pulses"](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-6-4-495). 4 | 5 | ![Example output](scripts/result.png?raw=true "Result") 6 | 7 | ## dispersion scan 8 | We decided to remove the implementation of dispersion scan from the publicly available code. If you want to use pypret for d-scan please contact us. 9 | 10 | ## Notes 11 | 12 | This code is a complete re-implementation of the (rather messy) code used in our research. It was created with the express purpose to be well-documented and educational. The notation in the code tries to match the notation in the paper and references it. I would strongly recommend reading the publication before diving into the code. 13 | 14 | As a down-side the code is not optimized and on many occasions I deliberately decided to go with the less efficient but more expressive and straightforward solution. This pains me somewhat and I do not recommend using the code as an example for high-performance, numerical Python. It creates unnecessarily many temporary copies and re-calculates many values that could be stored. 15 | 16 | ## Documentation 17 | 18 | The full documentation can be found at [pypret.readthedocs.io](https://pypret.readthedocs.io). The ``scripts`` folder contains examples of how the package is used. 19 | 20 | ### Usage example 21 | The code is contained in the plain Python package ``pypret`` (PYthon for Pulse Retrieval). Point your PYTHONPATH to it and you can use it. The package contains many classes that are in general useful for ultrashort pulse simulations. 
The most iconic usage, however, is to simulate pulse measurement schemes and perform pulse retrieval: 22 | 23 | ```python 24 | import numpy as np 25 | import pypret 26 | # create simulation grid 27 | ft = pypret.FourierTransform(256, dt=5.0e-15) 28 | # instantiate a pulse object, central wavelength 800 nm 29 | pulse = pypret.Pulse(ft, 800e-9) 30 | # create a random pulse with time-bandwidth product of 2. 31 | pypret.random_pulse(pulse, 2.0) 32 | # plot the pulse 33 | pypret.PulsePlot(pulse) 34 | 35 | # simulate a frog measurement 36 | delay = np.linspace(-500e-15, 500e-15, 128) # delay in s 37 | pnps = pypret.PNPS(pulse, "frog", "shg") 38 | # calculate the measurement trace 39 | pnps.calculate(pulse.spectrum, delay) 40 | original_spectrum = pulse.spectrum 41 | # and plot it 42 | pypret.MeshDataPlot(pnps.trace) 43 | 44 | # and do the retrieval 45 | ret = pypret.Retriever(pnps, "copra", verbose=True, maxiter=300) 46 | # start with a Gaussian spectrum with random phase as initial guess 47 | pypret.random_gaussian(pulse, 50e-15, phase_max=0.0) 48 | # now retrieve from the synthetic trace simulated above 49 | ret.retrieve(pnps.trace, pulse.spectrum) 50 | # and print the retrieval results 51 | ret.result(original_spectrum) 52 | ``` 53 | The text output should look similar to this: 54 | ``` 55 | Retrieval report 56 | trace error R = 1.26951777571186456e-11 57 | min. trace error R0 = 0.00000000000000000e+00 58 | R - R0 = 1.26951777571186456e-11 59 | 60 | pulse error ε = 1.52280811825410699e-07 61 | ``` 62 | This shows that the retrieval converged to the input trace within almost the numerical accuracy of the underlying calculations. This is, of course, only possible if no noise was added to the input data. Occasionally, the algorithm will not converge - in which case you would have to run it again. 63 | 64 | More elaborate examples of how to use this package can be found in the scripts directory. 65 | 66 | ## Author 67 | 68 | This project was developed by Nils C. 
Geib at the [Institute of Applied Physics](https://www.iap.uni-jena.de/Micro_+structure+Technology/Research+Group%3Cbr%3EPhotonics+in+2D_Materials/Ultrashort+Laser+Pulse+Metrology.html) of the [University of Jena](https://www.uni-jena.de), Germany. 69 | 70 | For any questions or comments, you can contact me via email: nils.geib@uni-jena.de 71 | 72 | ## License 73 | 74 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 75 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/apidoc/pypret.fourier.rst: -------------------------------------------------------------------------------- 1 | pypret.fourier module 2 | ===================== 3 | 4 | .. automodule:: pypret.fourier 5 | :no-members: 6 | .. autoclass:: pypret.fourier.FourierTransform -------------------------------------------------------------------------------- /docs/apidoc/pypret.frequencies.rst: -------------------------------------------------------------------------------- 1 | pypret.frequencies module 2 | ========================= 3 | 4 | .. 
automodule:: pypret.frequencies 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.graphics.rst: -------------------------------------------------------------------------------- 1 | pypret.graphics module 2 | ====================== 3 | 4 | .. automodule:: pypret.graphics 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.io.rst: -------------------------------------------------------------------------------- 1 | pypret.io package 2 | ================= 3 | 4 | .. automodule:: pypret.io 5 | 6 | Public interface 7 | ---------------- 8 | .. autoclass:: pypret.io.options.HDF5Options 9 | .. autofunction:: pypret.io.save 10 | .. autofunction:: pypret.io.load 11 | .. autoclass:: pypret.io.IO 12 | :members: 13 | 14 | 15 | Custom handlers 16 | --------------- 17 | 18 | .. automodule:: pypret.io.handlers 19 | :no-members: 20 | .. autoclass:: pypret.io.handlers.Handler 21 | :members: 22 | .. autoclass:: pypret.io.handlers.TypeHandler 23 | .. autoclass:: pypret.io.handlers.InstanceHandler 24 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.lib.rst: -------------------------------------------------------------------------------- 1 | pypret.lib module 2 | ================= 3 | 4 | .. automodule:: pypret.lib 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.material.rst: -------------------------------------------------------------------------------- 1 | pypret.material module 2 | ====================== 3 | 4 | .. automodule:: pypret.material 5 | :no-members: 6 | 7 | Available materials 8 | ------------------- 9 | .. autodata:: pypret.material.BK7 10 | .. 
autodata:: pypret.material.FS 11 | 12 | Base classes 13 | ------------ 14 | .. autoclass:: pypret.material.BaseMaterial 15 | .. autoclass:: pypret.material.SellmeierF1 16 | .. autoclass:: pypret.material.SellmeierF2 17 | 18 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.mesh_data.rst: -------------------------------------------------------------------------------- 1 | pypret.mesh\_data module 2 | ======================== 3 | 4 | .. automodule:: pypret.mesh_data 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.pnps.rst: -------------------------------------------------------------------------------- 1 | pypret.pnps module 2 | ================== 3 | 4 | .. automodule:: pypret.pnps 5 | :no-members: 6 | 7 | Public interface 8 | ---------------- 9 | .. autofunction:: pypret.pnps.PNPS 10 | .. autoclass:: pypret.pnps.FROG 11 | .. autoclass:: pypret.pnps.IFROG 12 | .. autoclass:: pypret.pnps.TDP 13 | .. autoclass:: pypret.pnps.DSCAN 14 | .. autoclass:: pypret.pnps.MIIPS 15 | 16 | API 17 | --- 18 | .. autoclass:: pypret.pnps.BasePNPS 19 | .. autoclass:: pypret.pnps.CollinearPNPS 20 | .. autoclass:: pypret.pnps.NoncollinearPNPS 21 | 22 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.pulse.rst: -------------------------------------------------------------------------------- 1 | pypret.pulse module 2 | =================== 3 | 4 | .. automodule:: pypret.pulse 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. autofunction:: pypret.random_pulse.random_pulse 10 | 11 | .. 
autofunction:: pypret.random_pulse.random_gaussian 12 | -------------------------------------------------------------------------------- /docs/apidoc/pypret.pulse_error.rst: -------------------------------------------------------------------------------- 1 | pypret.pulse_error 2 | ----------------------- 3 | 4 | .. automodule:: pypret.pulse_error 5 | :members: -------------------------------------------------------------------------------- /docs/apidoc/pypret.retrieval.rst: -------------------------------------------------------------------------------- 1 | pypret.retrieval package 2 | ======================== 3 | 4 | .. automodule:: pypret.pnps 5 | :no-members: 6 | 7 | Retrieval algorithms 8 | -------------------- 9 | .. autofunction:: pypret.retrieval.retriever.Retriever 10 | .. autoclass:: pypret.retrieval.step_retriever.COPRARetriever 11 | .. autoclass:: pypret.retrieval.step_retriever.PCGPARetriever 12 | .. autoclass:: pypret.retrieval.step_retriever.GPARetriever 13 | .. autoclass:: pypret.retrieval.step_retriever.PIERetriever 14 | .. autoclass:: pypret.retrieval.step_retriever.GPDSCANRetriever 15 | .. autoclass:: pypret.retrieval.nlo_retriever.LMRetriever 16 | .. autoclass:: pypret.retrieval.nlo_retriever.NMRetriever 17 | .. autoclass:: pypret.retrieval.nlo_retriever.DERetriever 18 | .. autoclass:: pypret.retrieval.nlo_retriever.BFGSRetriever 19 | 20 | API 21 | --- 22 | .. autoclass:: pypret.retrieval.retriever.BaseRetriever 23 | .. autoclass:: pypret.retrieval.step_retriever.StepRetriever 24 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. 
For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | # import os 16 | # import sys 17 | # sys.path.insert(0, os.path.abspath('.')) 18 | import sys 19 | from pathlib import Path 20 | pypret_folder = Path(__file__).resolve().parents[1] 21 | sys.path.insert(0, str(pypret_folder)) 22 | import pypret 23 | 24 | # -- Project information ----------------------------------------------------- 25 | 26 | project = 'pypret' 27 | copyright = '2019, Nils C. Geib' 28 | author = 'Nils C. Geib' 29 | 30 | # The short X.Y version 31 | version = pypret.__version__[:pypret.__version__.find(".")+2] 32 | # The full version, including alpha/beta/rc tags 33 | release = pypret.__version__ 34 | 35 | 36 | # -- General configuration --------------------------------------------------- 37 | 38 | # If your documentation needs a minimal Sphinx version, state it here. 39 | # 40 | # needs_sphinx = '1.0' 41 | 42 | # Add any Sphinx extension module names here, as strings. They can be 43 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 44 | # ones. 45 | extensions = [ 46 | 'sphinx.ext.autodoc', 47 | 'sphinx.ext.autosummary', 48 | 'sphinx.ext.mathjax', 49 | 'sphinx.ext.viewcode', 50 | 'sphinx.ext.napoleon', 51 | ] 52 | 53 | # Add any paths that contain templates here, relative to this directory. 54 | templates_path = ['_templates'] 55 | 56 | # The suffix(es) of source filenames. 57 | # You can specify multiple suffix as a list of string: 58 | # 59 | # source_suffix = ['.rst', '.md'] 60 | source_suffix = '.rst' 61 | 62 | # The master toctree document. 
63 | master_doc = 'index' 64 | 65 | # The language for content autogenerated by Sphinx. Refer to documentation 66 | # for a list of supported languages. 67 | # 68 | # This is also used if you do content translation via gettext catalogs. 69 | # Usually you set "language" from the command line for these cases. 70 | language = None 71 | 72 | # List of patterns, relative to source directory, that match files and 73 | # directories to ignore when looking for source files. 74 | # This pattern also affects html_static_path and html_extra_path. 75 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 76 | 77 | # The name of the Pygments (syntax highlighting) style to use. 78 | pygments_style = None 79 | 80 | # autosummary options 81 | autosummary_generate = True 82 | 83 | # autodoc options 84 | autodoc_default_options = { 85 | 'members': None, 86 | 'member-order': 'bysource', 87 | 'special-members': '__init__', 88 | 'undoc-members': True, 89 | 'exclude-members': '__weakref__' 90 | } 91 | 92 | # -- Options for HTML output ------------------------------------------------- 93 | 94 | # The theme to use for HTML and HTML Help pages. See the documentation for 95 | # a list of builtin themes. 96 | # 97 | html_theme = 'default' 98 | 99 | # Theme options are theme-specific and customize the look and feel of a theme 100 | # further. For a list of options available for each theme, see the 101 | # documentation. 102 | # 103 | # html_theme_options = {} 104 | 105 | # Add any paths that contain custom static files (such as style sheets) here, 106 | # relative to this directory. They are copied after the builtin static files, 107 | # so a file named "default.css" will overwrite the builtin "default.css". 108 | html_static_path = ['_static'] 109 | 110 | # Custom sidebar templates, must be a dictionary that maps document names 111 | # to template names. 112 | # 113 | # The default sidebars (for documents that don't match any pattern) are 114 | # defined by theme itself. 
Builtin themes are using these templates by 115 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 116 | # 'searchbox.html']``. 117 | # 118 | # html_sidebars = {} 119 | 120 | 121 | # -- Options for HTMLHelp output --------------------------------------------- 122 | 123 | # Output file base name for HTML help builder. 124 | htmlhelp_basename = 'pypretdoc' 125 | 126 | 127 | # -- Options for LaTeX output ------------------------------------------------ 128 | 129 | latex_elements = { 130 | # The paper size ('letterpaper' or 'a4paper'). 131 | # 132 | # 'papersize': 'letterpaper', 133 | 134 | # The font size ('10pt', '11pt' or '12pt'). 135 | # 136 | # 'pointsize': '10pt', 137 | 138 | # Additional stuff for the LaTeX preamble. 139 | # 140 | # 'preamble': '', 141 | 142 | # Latex figure (float) alignment 143 | # 144 | # 'figure_align': 'htbp', 145 | } 146 | 147 | # Grouping the document tree into LaTeX files. List of tuples 148 | # (source start file, target name, title, 149 | # author, documentclass [howto, manual, or own class]). 150 | latex_documents = [ 151 | (master_doc, 'pypret.tex', 'pypret Documentation', 152 | 'Nils C. Geib', 'manual'), 153 | ] 154 | 155 | 156 | # -- Options for manual page output ------------------------------------------ 157 | 158 | # One entry per manual page. List of tuples 159 | # (source start file, name, description, authors, manual section). 160 | man_pages = [ 161 | (master_doc, 'pypret', 'pypret Documentation', 162 | [author], 1) 163 | ] 164 | 165 | 166 | # -- Options for Texinfo output ---------------------------------------------- 167 | 168 | # Grouping the document tree into Texinfo files. 
List of tuples 169 | # (source start file, target name, title, author, 170 | # dir menu entry, description, category) 171 | texinfo_documents = [ 172 | (master_doc, 'pypret', 'pypret Documentation', 173 | author, 'pypret', 'One line description of project.', 174 | 'Miscellaneous'), 175 | ] 176 | 177 | 178 | # -- Options for Epub output ------------------------------------------------- 179 | 180 | # Bibliographic Dublin Core info. 181 | epub_title = project 182 | 183 | # The unique identifier of the text. This can be a ISBN number 184 | # or the project homepage. 185 | # 186 | # epub_identifier = '' 187 | 188 | # A unique identification for the text. 189 | # 190 | # epub_uid = '' 191 | 192 | # A list of files that should not be packed into the epub file. 193 | epub_exclude_files = ['search.html'] 194 | 195 | 196 | # -- Extension configuration ------------------------------------------------- 197 | -------------------------------------------------------------------------------- /docs/getting_started.rst: -------------------------------------------------------------------------------- 1 | Getting started 2 | =============== 3 | 4 | pypret is a package to simulate and retrieve from measurements such as 5 | frequency-resolved optical gating (FROG), dispersion scan (d-scan), 6 | interferometric FROG (iFROG), time-domain ptychography (TDP) and even 7 | multiphoton intrapulse interference phase scan (MIIPS). These are all 8 | measurements used for ultrashort (sub-ps) laser pulse measurement. More 9 | generally the package can handle all kinds of parametrized nonlinear 10 | process spectra (PNPS) measurements. 11 | 12 | A good place to start reading on the algorithms and the used notation is 13 | our paper [Geib2019]_ and its supplement. pypret can be thought to accompany 14 | this publication and can be used to reproduce most of the results shown there. 15 | 16 | Basic Use 17 | --------- 18 | pypret can be used to simulate PNPS measurements. 
This is useful for designing 19 | experiments and necessary for retrieval, of course. 20 | 21 | In a first step you have to set up the simulation grid in time and frequency:: 22 | 23 | ft = pypret.FourierTransform(256, dt=2.5e-15) 24 | 25 | which generates a 256-element grid with a temporal spacing of 2.5 fs centered 26 | around t=0. The frequency grid is chosen to match the reciprocity relation 27 | ``dt * dw = 2 * pi / N``. Alternatively you can specify the frequency spacing. 28 | See the documentation at :doc:`apidoc/pypret.fourier`. 29 | Next you can instantiate a :class:`pypret.Pulse` object:: 30 | 31 | pulse = pypret.Pulse(ft, 800e-9) 32 | 33 | where we used a central wavelength of 800 nm. This class can already be used 34 | for small but useful calculations:: 35 | 36 | # generate pulse with Gaussian spectrum and field standard deviation 37 | # of 20 nm 38 | pulse.spectrum = pypret.lib.gaussian(pulse.wl, x0=800e-9, sigma=20e-9) 39 | # print the accurate FWHM of the temporal intensity envelope 40 | print(pulse.fwhm(dt=pulse.dt/100)) 41 | # propagate it through 1cm of BK7 (remove first order) 42 | phase = np.exp(1.0j * pypret.material.BK7.k(pulse.wl) * 0.01) 43 | pulse.spectrum = pulse.spectrum * phase 44 | # print the temporal FWHM again 45 | print(pulse.fwhm(dt=pulse.dt/100)) 46 | # finally plot the pulse 47 | pypret.graphics.PulsePlot(pulse) 48 | 49 | You can now instantiate a PNPS class with that pulse object:: 50 | 51 | insertion = np.linspace(-0.025, 0.025, 128) # insertion in m 52 | pnps = pypret.PNPS(pulse, "dscan", "shg", material=pypret.material.BK7) 53 | # calculate the measurement trace 54 | pnps.calculate(pulse.spectrum, insertion) 55 | original_spectrum = pulse.spectrum 56 | # and plot it 57 | pypret.MeshDataPlot(pnps.trace) 58 | 59 | The PNPS constructor supports a lot of different PNPS measurements (see docs 60 | at :doc:`apidoc/pypret.pnps`). Furthermore, it is easy to implement your own. 
61 | 62 | Finally, you can use pypret for pulse retrieval by instantiating a Retriever 63 | object:: 64 | 65 | # do the retrieval 66 | ret = pypret.Retriever(pnps, "copra", verbose=True, maxiter=300) 67 | # start with a Gaussian spectrum with random phase as initial guess 68 | pypret.random_gaussian(pulse, 50e-15, phase_max=0.0) 69 | # now retrieve from the synthetic trace simulated above 70 | ret.retrieve(pnps.trace, pulse.spectrum) 71 | # and print the retrieval results 72 | ret.result(original_spectrum) 73 | 74 | A lot of different retrieval algorithms besides the default, COPRA, are 75 | implemented (see docs at :doc:`apidoc/pypret.retrieval`). While COPRA should 76 | work for all PNPS measurements, you may try one of the others for verification. 77 | 78 | Storage 79 | ------- 80 | The :doc:`apidoc/pypret.io` subpackage supports saving almost arbitrary Python 81 | structures and all pypret classes to HDF5 files. You can either use the 82 | :func:`pypret.save` function or the `save` method on classes:: 83 | 84 | pnps.calculate(pulse.spectrum, insertion) 85 | pnps.trace.save("trace.hdf5") 86 | # or 87 | pypret.save(pnps.trace, "trace.hdf5") 88 | # load it with 89 | trace = pypret.load("trace.hdf5") 90 | 91 | This should make storing intermediate or final results almost effortless. 92 | 93 | Experimental data 94 | ----------------- 95 | As this question is surely going to come: you can use pypret to retrieve pulses 96 | from experimental data, however, it currently has no pre-processing functions 97 | to make that convenient. The data fed to the retrieval functions has to be 98 | properly dark-subtracted and interpolated. Furthermore, some features that are 99 | very useful for retrieval from experimental data (e.g., handling non-calibrated 100 | traces) are not yet implemented. This is on the top of the ToDo-list, though. 
101 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | pypret 2 | ====== 3 | 4 | :Release: |release| 5 | :Date: |today| 6 | 7 | This is the documentation of `Python for pulse retrieval`. It is a Python 8 | package that aims to provide algorithms and tools to retrieve ultrashort 9 | laser pulses from parametrized nonlinear process spectra, such as 10 | frequency-resolved optical gating (FROG), dispersion scan (d-scan), 11 | time-domain ptychography (TDP) or multiphoton intrapulse interference phase 12 | scan (MIIPS). 13 | 14 | The package is currently in an early alpha state. It provides the 15 | algorithms but still requires thorough understanding of what they do to apply 16 | them correctly on measured data. 17 | 18 | Background 19 | ---------- 20 | 21 | The package was developed at the `Institute of Applied Physics`_ at the 22 | `Friedrich Schiller University Jena`_. Main author is Nils C. Geib. You 23 | can reach me at nils.geib@uni-jena.de if you have questions or comments on 24 | the code. 25 | 26 | The current capabilities of the package reflect mostly what we 27 | presented in our publication on a common pulse retrieval algorithm [Geib2019]_. 28 | If you want to reference this package you may cite that paper. 29 | 30 | The code in its current state mainly serves to give a reference implementation 31 | of the algorithms discussed within and allow the reproduction of our results. 32 | It is planned, however, to expand the package to make it a more full-fledged 33 | solution for pulse retrieval. 34 | 35 | .. _`Institute of Applied Physics`: https://www.iap.uni-jena.de/Micro_+structure+Technology/Research+Group%3Cbr%3EPhotonics+in+2D_Materials/Ultrashort+Laser+Pulse+Metrology.html 36 | .. _`Friedrich Schiller University Jena`: https://www.uni-jena.de 37 | 38 | 39 | User documentation 40 | ------------------ 41 | 42 | .. 
toctree:: 43 | :maxdepth: 1 44 | 45 | installation 46 | getting_started 47 | references 48 | 49 | API documentation 50 | ----------------- 51 | 52 | .. toctree:: 53 | :maxdepth: 1 54 | 55 | apidoc/pypret.fourier 56 | apidoc/pypret.pulse 57 | apidoc/pypret.pnps 58 | apidoc/pypret.retrieval 59 | apidoc/pypret.pulse_error 60 | apidoc/pypret.io 61 | apidoc/pypret.lib 62 | apidoc/pypret.frequencies 63 | apidoc/pypret.material 64 | apidoc/pypret.mesh_data 65 | apidoc/pypret.graphics 66 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Installation with ``pip`` or ``conda`` is currently neither supported nor 5 | necessary. Just clone the code repository from git:: 6 | 7 | git clone https://github.com/ncgeib/pypret.git 8 | 9 | and the directory ``pypret`` within contains all the required code of the 10 | package. Either add its location to your PYTHONPATH or copy it in your 11 | working directory. 12 | 13 | As the package matures I may add an installer. 14 | 15 | Requirements 16 | ------------ 17 | 18 | It requires Python >=3.6 and recent versions of NumPy and SciPy. Furthermore, 19 | it requires ``h5py`` for storage and loading. 20 | Optional dependencies are 21 | 22 | - pyfftw (for faster FFTs) 23 | - numba (for optimization of some low-level routines) 24 | - python-magic (to recognize zipped HDF5 files) 25 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 28 | goto end 29 | 30 | :help 31 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 32 | 33 | :end 34 | popd 35 | -------------------------------------------------------------------------------- /docs/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ---------- 3 | 4 | .. [Geib2019] Nils C. Geib, Matthias Zilk, Thomas Pertsch, and Falk 5 | Eilenberger, "Common pulse retrieval algorithm: a fast and universal 6 | method to retrieve ultrashort pulses," Optica 6, 495-505 (2019) 7 | .. [Lozovoy2004] V. V. Lozovoy, I. Pastirk and M. Dantus, "Multiphoton intrapulse 8 | interference. IV. Ultrashort laser pulse spectral phase 9 | characterization and compensation," Opt. Lett. 29, 775-777 10 | (OSA, 2004). 11 | .. [Xu2006] B. Xu, J. M. Gunn, J. M. D. Cruz, V. V. Lozovoy and M. Dantus, 12 | "Quantitative investigation of the multiphoton intrapulse 13 | interference phase scan method for simultaneous phase measurement 14 | and compensation of femtosecond laser pulses," J. Opt. Soc. Am. B 15 | 23, 750-759 (OSA, 2006). 16 | .. [Miranda2012a] M. Miranda, T. Fordell, C. Arnold, A. L'Huillier and H. Crespo, 17 | "Simultaneous compression and characterization of ultrashort laser 18 | pulses using chirped mirrors and glass wedges," Opt. Express 20, 19 | 688-697 (OSA, 2012). 20 | .. [Miranda2012b] M. Miranda, C. L. Arnold, T. Fordell, F. Silva, B. Alonso, R. 21 | Weigand, A. L'Huillier and H. 
Crespo, "Characterization of 22 | broadband few-cycle laser pulses with the d-scan technique," Opt. 23 | Express 20, 18732-18743 (OSA, 2012). 24 | .. [Kane1993] D. J. Kane and R. Trebino, "Characterization of arbitrary 25 | femtosecond pulses using frequency-resolved optical gating," IEEE 26 | J. Quant. Electron. 29, 571-579 (IEEE, 1993). 27 | .. [Kane1999] D. J. Kane, "Recent progress toward real-time measurement of 28 | ultrashort laser pulses," IEEE J. Quant. Electron. 35, 421-431 29 | (IEEE, 1999). 30 | .. [Trebino2000] R. Trebino, "Frequency-Resolved Optical Gating: The Measurement of 31 | Ultrashort Laser Pulses," , (Springer US, 2000). 32 | .. [Witting2016] T. Witting, D. Greening, D. Walke, P. Matia-Hernando, T. Barillot, 33 | J. P. Marangos and J. W. G. Tisch, "Time-domain ptychography of 34 | over-octave-spanning laser pulses in the single-cycle regime," Opt. 35 | Lett. 41, 4218-4221 (OSA, 2016). 36 | .. [Dorrer2002] C. Dorrer and I. A. Walmsley, "Accuracy criterion for ultrashort 37 | pulse characterization techniques: application to spectral phase 38 | interferometry for direct electric field reconstruction," J. Opt. 39 | Soc. Am. B 19, 1019-1029 (OSA, 2002). 40 | .. [Briggs1995] W. L. Briggs and v. E. Henson, "The DFT: an owners' manual for the 41 | discrete Fourier transform," (SIAM, 1995). 42 | .. [Hansen2014] E. W. Hansen, "Fourier transforms: principles and applications," (John 43 | Wiley & Sons, 2014). 44 | .. [Trefethen2014] L. N. Trefethen and J. A. C. Weideman, "The exponentially convergent 45 | trapezoidal rule," SIAM Review 56, 385-458 (2014). 46 | .. [DispersionFormulas] http://refractiveindex.info/database/doc/Dispersion%20formulas.pdf 47 | .. [Sidorenko2016] P. Sidorenko, O. Lahav, Z. Avnat and O. Cohen, "Ptychographic 48 | reconstruction algorithm for frequency-resolved optical gating: 49 | super-resolution and supreme robustness," Optica 3, 1320-1330 (OSA, 50 | 2016). 51 | .. [Sidorenko2017] P. Sidorenko, O. Lahav, Z. Avnat and O. 
Cohen, "Ptychographic 52 | reconstruction algorithm for frequency resolved optical gating: 53 | super-resolution and extreme robustness: erratum," Optica 4, 54 | 1388-1389 (OSA, 2017). 55 | .. [Miranda2017] M. Miranda, J. Penedones, C. Guo, A. Harth, M. Louisy, L. Neoričić, 56 | A. L'Huillier and C. L. Arnold, "Fast iterative retrieval algorithm 57 | for ultrashort pulse characterization using dispersion scans," J. 58 | Opt. Soc. Am. B 34, 190-197 (OSA, 2017). 59 | .. [DeLong1994] K. W. DeLong, B. Kohler, K. Wilson, D. N. Fittinghoff and R. 60 | Trebino, "Pulse retrieval in frequency-resolved optical gating 61 | based on the method of generalized projections," Opt. Lett. 19, 62 | 2152-2154 (Optical Society of America, 1994) 63 | .. [Escoto2018] E. Escoto, A. Tajalli, T. Nagy and G. Steinmeyer, "Advanced phase 64 | retrieval for dispersion scan: a comparative study," J. Opt. Soc. 65 | Am. B 35, 8-19 (OSA, 2018). 66 | .. [Diels2006] J.-C. Diels and W. Rudolph, "Ultrashort laser pulse phenomena," 67 | 2nd ed. (Academic press, 2006) -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements file for readthedocs virtualenv 2 | sphinx>=1.8 3 | numpy 4 | scipy 5 | h5py -------------------------------------------------------------------------------- /pypret/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Disclaimer 3 | ---------- 4 | 5 | THIS CODE IS FOR EDUCATIONAL PURPOSES ONLY! The code in this package was not 6 | optimized for accuracy or performance. Rather it aims to provide a simple 7 | implementation of the basic algorithms. 8 | 9 | Author: Nils C. 
Geib, nils.geib@uni-jena.de 10 | """ 11 | __version__ = "0.1alpha" 12 | from .autocorrelation import autocorrelation 13 | from .fourier import FourierTransform 14 | from .pulse import Pulse 15 | from .random_pulse import random_pulse, random_gaussian 16 | from .pulse_error import pulse_error 17 | from .pnps import PNPS 18 | from .mesh_data import MeshData 19 | from .graphics import MeshDataPlot, PulsePlot 20 | from .retrieval import Retriever 21 | from . import lib 22 | from . import material 23 | from . import io 24 | from .io import load, save 25 | -------------------------------------------------------------------------------- /pypret/autocorrelation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from . import lib 3 | 4 | 5 | def autocorrelation(pulse, tau=None, collinear=False): 6 | ''' Calculates the intensity or second-order autocorrelation G2(tau). 7 | 8 | Parameters 9 | ---------- 10 | pulse : Pulse instance 11 | The pulse of which the autocorrelation is calculated 12 | tau : 1d-array, optional 13 | The delays at which the autocorrelation is evaluated. If `None` 14 | the temporal grid `pulse.t` of the pulse object is used. 15 | collinear : bool, optional 16 | Calculates the collinear autocorrelation, with background and higher 17 | frequency terms. Otherwise only the non-collinear intensity 18 | autocorrelation without background is calculated. Default is `False`. 19 | 20 | Returns 21 | ------- 22 | tau : 1d-array 23 | The delay axis at which the autocorrelation was evaluated. 24 | ac : 1d-array 25 | The autocorrelation signal. 26 | 27 | Notes 28 | ----- 29 | Calculates the following expression:: 30 | 31 | G2(tau) = int | [E(t-tau) + E(t)]^2 |^2 dt 32 | 33 | where E(t) is the real-valued electric field. 34 | This expression is expanded in terms of the complex-valued pulse envelope 35 | and then evaluated with help of the convolution theorem and the Fourier 36 | transform. 
Terms containing oscillations in t, e.g., exp(2j w0 t), are 37 | neglected. 38 | Specifically, it implements Eq. (9.7) on page 460 from [Diels2006]_. 39 | ''' 40 | ft, ift = pulse.ft.forward, pulse.ft.backward 41 | ift_at = pulse.ft.backward_at 42 | 43 | # if the delays are not provided, use the pulse grid 44 | fft = False 45 | if tau is None: 46 | fft = True # in this case we can use the FFT directly 47 | tau = pulse.t 48 | 49 | if collinear: 50 | f = pulse.field 51 | i = lib.abs2(f) # intensity 52 | if fft: 53 | # uses the n log(n) FFT 54 | a0 = ift(lib.abs2(ft(pulse.intensity))) 55 | a1 = ift(np.real(ft(f).conj() * ft(f * i))) 56 | a2 = ift(lib.abs2(ft(f * f))) 57 | else: 58 | # uses the n^2 DFT 59 | a0 = ift_at(lib.abs2(ft(pulse.intensity)), tau) 60 | a1 = ift_at(np.real(ft(f).conj() * ft(f * i)), tau) 61 | a2 = ift_at(lib.abs2(ft(f * f)), tau) 62 | 63 | ac = ( 4.0 * np.real(a0) # A0 64 | + pulse.dt * (i * i).sum() / np.pi # background 65 | + 8.0 * np.real(a1 * np.exp(1.0j * pulse.w0 * tau)) # A1 66 | + 2.0 * np.real(a2 * np.exp(2.0j * pulse.w0 * tau)) ) # A2 67 | # scale for classic 8:1 ratio 68 | ac *= 8.0 / np.max(ac) 69 | else: 70 | i = pulse.intensity 71 | ac_omega = 4.0 * lib.abs2(ft(i)) 72 | if fft: 73 | ac = ift(ac_omega).real 74 | else: 75 | ac = ift_at(ac_omega, tau).real 76 | # normalize 77 | ac /= np.max(ac) 78 | 79 | return tau, ac 80 | -------------------------------------------------------------------------------- /pypret/fourier.py: -------------------------------------------------------------------------------- 1 | """ This module implements the Fourier transforms on linear grids. 2 | 3 | The following code approximates the continuous Fourier transform (FT) on 4 | equidistantly spaced grids. While this is usually associated with 5 | 'just doing a fast Fourier transform (FFT)', surprisingly, much can be done 6 | wrong. 7 | 8 | The reason is that the correct expressions depend on the grid location. 
In 9 | fact, the FT can be calculated with one DFT but in general it requires a prior 10 | and posterior multiplication with phase factors. 11 | 12 | The FT convention we are going to use here is the following:: 13 | 14 | Ẽ(w) = 1/2pi ∫ E(t) exp(+i w t) dt 15 | E(t) = ∫ Ẽ(w) exp(-i t w) dw 16 | 17 | where w is the angular frequency. We can approximate these integrals by their 18 | Riemann sums on the following equidistantly spaced grids:: 19 | 20 | t_k = t_0 + k Δt, k=0, ..., N-1 21 | w_n = w_0 + n Δw, n=0, ..., N-1 22 | 23 | and define E_k = E(t_k) and Ẽ_n = Ẽ(w_n) to obtain:: 24 | 25 | Ẽ_n = Δt/2pi ∑_k E_k exp(+i w_n t_k) 26 | E_k = Δw ∑_n Ẽ_n exp(-i t_k w_n). 27 | 28 | To evaluate the sum using the FFT we can expand the exponential to obtain:: 29 | 30 | Ẽ_n = Δt/2pi exp(+i n t_0 Δw) ∑_k [E_k exp(+i t_k w_0) ] exp(+i n k Δt Δw) 31 | E_k = Δw exp(-i t_k w_0) ∑_n [Ẽ_n exp(-i n t_0 Δw)] exp(-i k n Δt Δw) 32 | 33 | Additionally, we have to require the so-called reciprocity relation for 34 | the grid spacings:: 35 | 36 | ! 37 | Δt Δw = 2pi / N = ζ (reciprocity relation) 38 | 39 | This is what enables us to use the DFT/FFT! Now we look at the definition of 40 | the FFT in NumPy:: 41 | 42 | fft[x_m] -> X_k = ∑_m exp(-2pi i m k / N) 43 | ifft[X_k] -> x_m = 1/N ∑_k exp(+2pi i k m / N) 44 | 45 | which gives the final expressions:: 46 | 47 | Ẽ_n = Δt N/2pi r_n ifft[E_k s_k ] 48 | E_k = Δw s_k^* fft[Ẽ_n r_n^*] 49 | 50 | with r_n = exp(+i n t_0 Δw) 51 | s_k = exp(+i t_k w_0) 52 | 53 | where ^* means complex conjugation. We see that the array to be transformed 54 | has to be multiplied with an appropriate phase factor before and after 55 | performing the DFT. And those phase factors mainly depend on the starting 56 | points of the grids: w_0 and t_0. Note also that due to our sign convention 57 | for the FT we have to use ifft for the forward transform and vice versa. 
58 | 59 | Trivially, we can see that for ``w_0 = t_0 = 0`` the phase factors vanish and 60 | the FT is approximated well by just the DFT. However, in optics these 61 | grids are unusual. 62 | For ``w_0 = l Δw`` and ``t_0 = m Δt``, where l, m are integers (i.e., w_0 and 63 | t_0 are multiples of the grid spacing), the phase factors can be 64 | incorperated into the DFT. Then the phase factors can be replaced by circular 65 | shifts of the input and output arrays. 66 | 67 | This is exactly what the functions (i)fftshift are doing for one specific 68 | choice of l and m, namely for:: 69 | 70 | t_0 = -floor(N/2) Δt 71 | w_0 = -floor(N/2) Δw. 72 | 73 | In this specific case only we can approximate the FT by:: 74 | 75 | Ẽ_n = Δt N/2pi fftshift(ifft(ifftshift(E_k))) 76 | E_k = Δw fftshift( fft(ifftshift(Ẽ_n))) (no mistake!) 77 | 78 | We see that the ifftshift _always_ has to appear on the inside. Failure to do 79 | so will still be correct for even N (here fftshift is the same as ifftshift) 80 | but will produce wrong results for odd N. 81 | 82 | Additionally you have to watch out not to violate the assumptions for the 83 | grid positions. Using a symmetrical grid, e.g.,:: 84 | 85 | x = linspace(-1, 1, 128) 86 | 87 | will also produce wrong results, as the elements of x are not multiples of the 88 | grid spacing (but shifted by half a grid point). 89 | 90 | The main drawback of this approach is that circular shifts are usually far more 91 | time- and memory-consuming than an elementwise multiplication, especially for 92 | higher dimensions. In fact I see no advantage in using the shift approach at 93 | all. But for some reason it got stuck in the minds of people and you find the 94 | notion of having to re-order the output of the DFT everywhere. 95 | 96 | Long story short: here we are going to stick with multiplying the correct 97 | phase factors. The code tries to follow the notation used above. 
98 | 99 | Good, more comprehensive expositions of the issues above can be found in 100 | [Briggs1995]_ and [Hansen2014]_. For the reason why the first-order 101 | approximation to the Riemann integral suffices, see [Trefethen2014]_. 102 | """ 103 | import numpy as np 104 | # scipy.fftpack is still faster than numpy.fft (should change in numpy 1.17) 105 | import scipy.fftpack as fft 106 | from . import io 107 | from .lib import twopi, sqrt2pi 108 | _fft_backend = 'scipy' 109 | try: 110 | import pyfftw 111 | _fft_backend = 'pyfftw' 112 | except ImportError: 113 | pass 114 | 115 | 116 | class FourierTransformBase(io.IO): 117 | """ This class implements the Fourier transform on linear grids. 118 | 119 | This simple implementation is mainly for educational use. 120 | 121 | Attributes 122 | ---------- 123 | N : int 124 | Size of the grid 125 | dt : float 126 | Temporal spacing 127 | dw : float 128 | Frequency spacing (angular frequency) 129 | t0 : float 130 | The first element of the temporal grid 131 | w0 : float 132 | The first element of the frequency grid 133 | t : 1d-array 134 | The temporal grid 135 | w : 1d-array 136 | The frequency grid (angular frequency) 137 | """ 138 | _io_store = ['N', 'dt', 'dw', 't0', 'w0'] 139 | 140 | def __init__(self, N, dt=None, dw=None, t0=None, w0=None): 141 | """ Creates conjugate grids and calculates the Fourier transform. 142 | 143 | Parameters 144 | ---------- 145 | N : int 146 | Array size 147 | dt : float, optional 148 | The temporal grid spacing. If ``None`` will be calculated by the 149 | reciprocity relation ``dt = 2 * pi / (N * dw)``. Exactly one of 150 | ``dt`` or ``dw`` has be provided. 151 | dw : float, optional 152 | The spectral grid spacing. If ``None`` will be calculated by the 153 | reciprocity relation ``dw = 2 * pi / (N * dt)``. Exactly one of 154 | ``dt`` or ``dw`` has be provided. 155 | t0 : float, optional 156 | The first element of the temporal grid. If ``None`` will be 157 | ``t0 = -floor(N/2) * dt``. 
158 | w0 : float, optional 159 | The first element of the spectral grid. If ``None`` will be 160 | ``w0 = -floor(N/2) * dw``. 161 | """ 162 | if dw is None and dt is not None: 163 | dw = np.pi / (0.5 * N * dt) 164 | elif dt is None and dw is not None: 165 | dt = np.pi / (0.5 * N * dw) 166 | else: 167 | raise ValueError("Exactly one of the grid spacings has to be " 168 | "provided!") 169 | 170 | if t0 is None: 171 | t0 = -np.floor(0.5 * N) * dt 172 | if w0 is None: 173 | w0 = -np.floor(0.5 * N) * dw 174 | self.N = N 175 | self.dt = dt 176 | self.dw = dw 177 | self.t0 = t0 178 | self.w0 = w0 179 | self._post_init() 180 | 181 | def _post_init(self): 182 | """ Hook to initialize an object from storage. 183 | """ 184 | # calculate the grids 185 | n = k = np.arange(self.N) 186 | self.t = self.t0 + k * self.dt 187 | self.w = self.w0 + n * self.dw 188 | # pre-calculate the phase factors 189 | # TODO: possibly inaccurate for large t0, w0 190 | self._fr = self.dt * self.N / twopi * np.exp(1.0j * n * self.t0 * 191 | self.dw) 192 | self._fs = np.exp(1.0j * self.t * self.w0) 193 | # complex conjugate of the above 194 | self._br = np.exp(-1.0j * n * self.t0 * self.dw) 195 | self._bs = self.dw * np.exp(-1.0j * self.t * self.w0) 196 | 197 | def forward_at(self, x, w): 198 | """ Calculates the forward Fourier transform of `x` at the 199 | frequencies `w`. 200 | 201 | This function calculates the Riemann sum directly and has quadratic 202 | runtime. However, it can evaluate the integral at arbitrary 203 | frequencies, even if they are non-equidistantly spaced. Effectively, 204 | it performs a trigonometric interpolation. 205 | """ 206 | Dnk = self.dt / twopi * np.exp(1.0j * w[:, None] * self.t[None, :]) 207 | return Dnk @ x 208 | 209 | def backward_at(self, x, t): 210 | """ Calculates the backward Fourier transform of `x` at the 211 | times `t`. 212 | 213 | This function calculates the Riemann sum directly and has quadratic 214 | runtime. 
However, it can evaluate the integral at arbitrary 215 | times, even if they are non-equidistantly spaced. Effectively, 216 | it performs a trigonometric interpolation. 217 | """ 218 | Dkn = self.dw * np.exp(-1.0j * t[:, None] * self.w[None, :]) 219 | return Dkn @ x 220 | 221 | 222 | # ============================================================================= 223 | # Fourier backend selection 224 | # ============================================================================= 225 | if _fft_backend == "scipy": 226 | class FourierTransform(FourierTransformBase): 227 | 228 | def forward(self, x, out=None): 229 | """ Calculates the (forward) Fourier transform of ``x``. 230 | 231 | For n-dimensional arrays it operates on the last axis, which has 232 | to match the size of `x`. 233 | 234 | Parameters 235 | ---------- 236 | x : ndarray 237 | The array of which the Fourier transform will be calculated. 238 | out : ndarray or None, optional 239 | A location into which the result is stored. If not provided or 240 | None, a freshly-allocated array is returned. 241 | """ 242 | if out is None: 243 | out = np.empty(x.shape, dtype=np.complex128) 244 | out[:] = self._fr * fft.ifft(self._fs * x) 245 | return out 246 | 247 | def backward(self, x, out=None): 248 | """ Calculates the backward (inverse) Fourier transform of ``x``. 249 | 250 | For n-dimensional arrays it operates on the last axis, which has 251 | to match the size of `x`. 252 | 253 | Parameters 254 | ---------- 255 | x : ndarray 256 | The array of which the Fourier transform will be calculated. 257 | out : ndarray or None, optional 258 | A location into which the result is stored. If not provided or 259 | None, a freshly-allocated array is returned. 
260 | """ 261 | if out is None: 262 | out = np.empty(x.shape, dtype=np.complex128) 263 | out[:] = self._bs * fft.fft(self._br * x) 264 | return out 265 | 266 | elif _fft_backend == "pyfftw": 267 | class FourierTransform(FourierTransformBase): 268 | 269 | def _post_init(self): 270 | super()._post_init() 271 | # do not need the additional N factor 272 | n = np.arange(self.N) 273 | self._fr = self.dt / twopi * np.exp(1.0j * n * self.t0 * self.dw) 274 | # create the aligned arrays 275 | a = self._field = pyfftw.empty_aligned(self.N, dtype="complex128") 276 | b = self._spectrum = pyfftw.empty_aligned(self.N, 277 | dtype="complex128") 278 | # instantiate the FFTW objects 279 | self._fft = pyfftw.FFTW(b, a, direction="FFTW_FORWARD") 280 | self._ifft = pyfftw.FFTW(a, b, direction="FFTW_BACKWARD") 281 | 282 | def forward(self, x, out=None): 283 | """ Calculates the (forward) Fourier transform of ``x``. 284 | 285 | For n-dimensional arrays it operates on the last axis, which has 286 | to match the size of `x`. 287 | 288 | Parameters 289 | ---------- 290 | x : ndarray 291 | The array of which the Fourier transform will be calculated. 292 | out : ndarray or None, optional 293 | A location into which the result is stored. If not provided or 294 | None, a freshly-allocated array is returned. 295 | """ 296 | if out is None: 297 | out = np.empty(x.shape, dtype=np.complex128) 298 | f, s = self._field, self._spectrum 299 | if x.ndim == 1: 300 | # fast code path for single dimension 301 | f[:] = x 302 | f *= self._fs 303 | self._ifft.execute() 304 | s *= self._fr 305 | out[:] = s 306 | else: 307 | # implicitly work along last axis and return copy 308 | for idx in np.ndindex(x.shape[:-1]): 309 | f[:] = x[idx] 310 | f *= self._fs 311 | self._ifft.execute() 312 | s *= self._fr 313 | out[idx] = s 314 | return out 315 | 316 | def backward(self, x, out=None): 317 | """ Calculates the backward (inverse) Fourier transform of ``x``. 
318 | 319 | For n-dimensional arrays it operates on the last axis, which has 320 | to match the size of `x`. 321 | 322 | Parameters 323 | ---------- 324 | x : ndarray 325 | The array of which the Fourier transform will be calculated. 326 | out : ndarray or None, optional 327 | A location into which the result is stored. If not provided or 328 | None, a freshly-allocated array is returned. 329 | """ 330 | if out is None: 331 | out = np.empty(x.shape, dtype=np.complex128) 332 | f, s = self._field, self._spectrum 333 | if x.ndim == 1: 334 | # fast code path for single dimension 335 | s[:] = x 336 | s *= self._br 337 | self._fft.execute() 338 | f *= self._bs 339 | out[:] = f 340 | else: 341 | # implicitly work along last axis and return copy 342 | for idx in np.ndindex(x.shape[:-1]): 343 | s[:] = x[idx] 344 | s *= self._br 345 | self._fft.execute() 346 | f *= self._bs 347 | out[idx] = f 348 | return out 349 | 350 | 351 | class Gaussian: 352 | """ This class can be used for testing the Fourier transform. 353 | """ 354 | 355 | def __init__(self, dt, t0=0.0, phase=0.0): 356 | """ Instantiates a shifted Gaussian function. 357 | 358 | The Gaussian is calculated by:: 359 | 360 | f(t) = exp(-0.5 (t - t0)^2 / dt^2) * exp(1.0j * phase) 361 | 362 | Its Fourier transform is:: 363 | 364 | F(w) = dt/sqrt(2pi) exp(-0.5 * (w + phase)^2 * dt^2 + 365 | 1j * t0 * w) 366 | 367 | Parameters 368 | ---------- 369 | dt : float 370 | The standard deviation of the temporal amplitude distribution. 371 | t0 : float 372 | The center of the temporal amplitude distribution. 373 | phase : float 374 | The linear phase coefficient of the temporal distribution. 375 | """ 376 | self.dt = dt 377 | self.t0 = t0 378 | self.phase = phase 379 | 380 | def temporal(self, t): 381 | """ Returns the temporal distribution. 382 | """ 383 | arg = (t - self.t0) / self.dt 384 | return np.exp(-0.5 * arg**2) * np.exp(1.0j * self.phase * t) 385 | 386 | def spectral(self, w): 387 | """ Returns the spectral distribution. 
388 | """ 389 | w = w + self.phase 390 | arg = w * self.dt 391 | return (self.dt * np.exp(-0.5 * arg**2) * np.exp(1.0j * self.t0 * w) / 392 | sqrt2pi) 393 | -------------------------------------------------------------------------------- /pypret/frequencies.py: -------------------------------------------------------------------------------- 1 | """ This module handles conversion between frequency units. 2 | 3 | The supported units and their shorthands are: 4 | 5 | - wl : wavelength in meter 6 | - om: angular frequency in rad/s 7 | - f: frequency in 1/s 8 | - k: angular wavenumber in rad/m 9 | 10 | The conversion functions have the form `shorthand2shorthand` which is not 11 | pythonic but very short. A more pythonic conversion can be achieved by using 12 | the `convert` function 13 | 14 | >>> convert(x, 'wl', 'om') 15 | 16 | The shorthands will be used throughout the package to identify frequency units. 17 | 18 | The functions in this module should be used wherever a frequency convention 19 | is necessary to avoid mistakes and make the code more expressive. 
20 | """ 21 | from copy import copy 22 | from .lib import sol, twopi 23 | 24 | 25 | frequency_labels = { 26 | 'wl': 'wavelength', 27 | 'om': 'angular frequency', 28 | 'f': 'frequency', 29 | 'k': 'angular wavenumber' 30 | } 31 | 32 | frequency_units = { 33 | 'wl': 'm', 34 | 'om': 'Hz rad', 35 | 'f': 'Hz', 36 | 'k': 'rad/m' 37 | } 38 | 39 | 40 | def om2wl(om): 41 | return twopi/om*sol 42 | 43 | 44 | def k2wl(k): 45 | return twopi/k 46 | 47 | 48 | def f2wl(f): 49 | return sol/f 50 | 51 | 52 | def wl2f(wl): 53 | return sol/wl 54 | 55 | 56 | def om2f(om): 57 | return om/twopi 58 | 59 | 60 | def k2f(k): 61 | return k*sol/twopi 62 | 63 | 64 | def wl2om(wl): 65 | return twopi*sol/wl 66 | 67 | 68 | def f2om(f): 69 | return twopi*f 70 | 71 | 72 | def k2om(k): 73 | return k*sol 74 | 75 | 76 | def wl2k(wl): 77 | return twopi/wl 78 | 79 | 80 | def om2k(om): 81 | return om/sol 82 | 83 | 84 | def f2k(f): 85 | return twopi*f/sol 86 | 87 | 88 | # this dictionary can be used for programmatic conversions 89 | conversions = { 90 | 'wl': { 91 | 'wl': lambda x: copy(x), 92 | 'om': wl2om, 93 | 'f': wl2f, 94 | 'k': wl2k 95 | }, 96 | 'om': { 97 | 'wl': om2wl, 98 | 'om': lambda x: copy(x), 99 | 'f': om2f, 100 | 'k': om2k 101 | }, 102 | 'f': { 103 | 'wl': f2wl, 104 | 'om': f2om, 105 | 'f': lambda x: copy(x), 106 | 'k': f2k 107 | }, 108 | 'k': { 109 | 'wl': k2wl, 110 | 'om': k2om, 111 | 'f': k2f, 112 | 'k': lambda x: copy(x) 113 | } 114 | } 115 | 116 | 117 | def convert(x, unit1, unit2): 118 | """ Convert between two frequency units. 119 | 120 | Parameters 121 | ---------- 122 | x : float or array_like 123 | Numerical value or array that should be converted. 124 | unit1, unit2 : str 125 | Shorthands for the original unit (`unit1`) and the destination unit 126 | (`unit2`). 127 | 128 | Returns 129 | ------- 130 | float or array_like 131 | The converted numerical value or array. It will always be a copy, even 132 | if `unit1 == unit2`. 
133 | 134 | Notes 135 | ----- 136 | Unit shorthands can be any of 137 | `wl` : wavelength in meter 138 | `om` : angular frequency in rad/s 139 | `f` : frequency in 1/s 140 | `k` : angular wavenumber in rad/m 141 | """ 142 | return conversions[unit1][unit2](x) 143 | -------------------------------------------------------------------------------- /pypret/graphics.py: -------------------------------------------------------------------------------- 1 | """ This module implements several helper routines for plotting. 2 | """ 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.ticker import EngFormatter 6 | from . import lib 7 | from .frequencies import convert 8 | 9 | 10 | def plot_meshdata(ax, md, cmap="nipy_spectral"): 11 | x, y = lib.edges(md.axes[1]), lib.edges(md.axes[0]) 12 | im = ax.pcolormesh(x, y, md.data, cmap=cmap) 13 | ax.set_xlabel(md.labels[1]) 14 | ax.set_ylabel(md.labels[0]) 15 | 16 | fx = EngFormatter(unit=md.units[1]) 17 | ax.xaxis.set_major_formatter(fx) 18 | fy = EngFormatter(unit=md.units[0]) 19 | ax.yaxis.set_major_formatter(fy) 20 | return im 21 | 22 | 23 | class MeshDataPlot: 24 | 25 | def __init__(self, mesh_data, plot=True, **kwargs): 26 | self.md = mesh_data 27 | if plot: 28 | self.plot(**kwargs) 29 | 30 | def plot(self, show=True): 31 | md = self.md 32 | 33 | fig, ax = plt.subplots() 34 | im = plot_meshdata(ax, md, "nipy_spectral") 35 | fig.colorbar(im, ax=ax) 36 | 37 | self.fig, self.ax = fig, ax 38 | self.im = im 39 | if show: 40 | fig.tight_layout() 41 | plt.show() 42 | 43 | def show(self): 44 | plt.show() 45 | 46 | 47 | def plot_complex(x, y, ax, ax2, yaxis='intensity', limit=False, 48 | phase_blanking=False, phase_blanking_threshold=1e-3, 49 | amplitude_line="r-", phase_line="b-"): 50 | if yaxis == "intensity": 51 | amp = lib.abs2(y) 52 | elif yaxis == "amplitude": 53 | amp = np.abs(y) 54 | else: 55 | raise ValueError("yaxis mode '%s' is unknown!" 
% yaxis) 56 | phase = lib.phase(y) 57 | # center phase by weighted mean 58 | phase -= lib.mean(phase, amp * amp) 59 | if phase_blanking: 60 | x2, phase2 = lib.mask_phase(x, amp, phase, phase_blanking_threshold) 61 | else: 62 | x2, phase2 = x, phase 63 | if limit: 64 | xlim = lib.limit(x, amp) 65 | ax.set_xlim(xlim) 66 | f = (x2 >= xlim[0]) & (x2 <= xlim[1]) 67 | ax2.set_ylim(lib.limit(phase2[f], padding=0.05)) 68 | 69 | li1, = ax.plot(x, amp, amplitude_line) 70 | li2, = ax2.plot(x2, phase2, phase_line) 71 | 72 | return li1, li2, amp, phase 73 | 74 | 75 | class PulsePlot: 76 | 77 | def __init__(self, pulse, plot=True, **kwargs): 78 | self.pulse = pulse 79 | if plot: 80 | self.plot(**kwargs) 81 | 82 | def plot(self, xaxis='wavelength', yaxis='intensity', limit=True, 83 | oversampling=False, phase_blanking=False, 84 | phase_blanking_threshold=1e-3, show=True): 85 | pulse = self.pulse 86 | 87 | fig, axs = plt.subplots(1, 2) 88 | ax1, ax2 = axs.flat 89 | ax12 = ax1.twinx() 90 | ax22 = ax2.twinx() 91 | 92 | if oversampling: 93 | t = np.linspace(pulse.t[0], pulse.t[-1], pulse.N * oversampling) 94 | field = pulse.field_at(t) 95 | else: 96 | t = pulse.t 97 | field = pulse.field 98 | 99 | # time domain 100 | li11, li12, tamp, tpha = plot_complex(t, field, ax1, ax12, yaxis=yaxis, 101 | phase_blanking=phase_blanking, limit=limit, 102 | phase_blanking_threshold=phase_blanking_threshold) 103 | fx = EngFormatter(unit="s") 104 | ax1.xaxis.set_major_formatter(fx) 105 | ax1.set_title("time domain") 106 | ax1.set_xlabel("time") 107 | ax1.set_ylabel(yaxis) 108 | ax12.set_ylabel("phase (rad)") 109 | 110 | # frequency domain 111 | if oversampling: 112 | w = np.linspace(pulse.w[0], pulse.w[-1], pulse.N * oversampling) 113 | spectrum = pulse.spectrum_at(w) 114 | else: 115 | w = pulse.w 116 | spectrum = pulse.spectrum 117 | 118 | if xaxis == "wavelength": 119 | w = convert(w + pulse.w0, "om", "wl") 120 | unit = "m" 121 | label = "wavelength" 122 | elif xaxis == "frequency": 123 | w = w 124 
| unit = " rad Hz" 125 | label = "frequency" 126 | 127 | li21, li22, samp, spha = plot_complex(w, spectrum, ax2, ax22, yaxis=yaxis, 128 | phase_blanking=phase_blanking, limit=limit, 129 | phase_blanking_threshold=phase_blanking_threshold) 130 | fx = EngFormatter(unit=unit) 131 | ax2.xaxis.set_major_formatter(fx) 132 | ax2.set_title("frequency domain") 133 | ax2.set_xlabel(label) 134 | ax2.set_ylabel(yaxis) 135 | ax22.set_ylabel("phase (rad)") 136 | 137 | self.fig = fig 138 | self.ax1, self.ax2 = ax1, ax2 139 | self.ax12, self.ax22 = ax12, ax22 140 | self.li11, self.li12, self.li21, self.li22 = li11, li12, li21, li22 141 | self.tamp, self.tpha = tamp, tpha 142 | self.samp, self.spha = samp, spha 143 | 144 | if show: 145 | fig.tight_layout() 146 | plt.show() 147 | -------------------------------------------------------------------------------- /pypret/io/__init__.py: -------------------------------------------------------------------------------- 1 | """ A subpackage that provides Python object persistence in HDF5 files. 2 | 3 | It was written to make the storage of arbitrary nested Python structures 4 | in the exchangable HDF5 format easy. Its main purpose is to easily add 5 | persistence to existing numerical or data analysis codes. 6 | 7 | While the files itself are plain HDF5 and can be read in any language 8 | supporting HDF5, the format is not compatible to Matlab's own file format. 9 | If you are searching for such a solution look at the hdf5storage package. 10 | 11 | Usage 12 | ----- 13 | The module exports a ``save()`` function that stores arbitrary structures 14 | of Python and NumPy data types. For example 15 | 16 | >>> x = {'data': [1, 2, 3], 'xrange': np.arange(5, dtype=np.uint8)} 17 | >>> io.save(x, "test.hdf5") 18 | 19 | This function should suffice for most needs as long as only standard types 20 | are used. 
The ``load()`` function loads these files and restores the structure 21 | and the types of the data: 22 | 23 | >>> io.load("test.hdf5") 24 | {'data': [1, 2, 3], 'xrange': array([0, 1, 2, 3, 4], dtype=uint8)} 25 | 26 | Custom Objects 27 | --------------- 28 | If you are using objects as simple containers without functionality you may 29 | consider using the SimpleNamespace class from the ``types`` module of the 30 | standard library. The advantage is that io knows how to handle it.:: 31 | 32 | from types import SimpleNamespace 33 | a = SimpleNamespace(name="my object", data=np.arange(5)) 34 | a.data2 = np.arange(10) 35 | copra.save(a) 36 | 37 | If your objects are containers with methods but without a custom ``__init__()`` 38 | the simplest way is to inherit or mix-in the ``IO`` class:: 39 | 40 | class Data(io.IO): 41 | x = 1 42 | 43 | def squared(self): 44 | return self.x * self.x 45 | 46 | When using the ``IO`` class by default all instance attributes are stored 47 | and loaded. More flexibility can be achieved by specifying ``_io``-attributes 48 | of your custom class. 49 | 50 | _io_store : list of str or None, optional 51 | Specify the the instance attributes that are stored exclusively. Acts 52 | as a whitelist. If ``None`` all instance attributes are stored. Default 53 | is ``None``. 54 | _io_store_not : list of str or None, optional 55 | Specify which instance attributes are not stored. Acts as a blacklist. 56 | If ``None`` no blacklisting is done. 57 | 58 | If you want to add attributes to storage you can call the 59 | ``_io_add_to_storage(key)`` method on your instance. 60 | The IO class initalizes the instance without calling ``__init__()``. Instead 61 | ``__new__()`` is called on the class and afterwards the ``_post_init()`` 62 | method which subclasses can implement. 
A fully working example of a class 63 | is the following (reduced from copra.FourierTransform):: 64 | 65 | class Grid(io.IO): 66 | _io_store = ['N', 'dx', 'x0'] 67 | 68 | def __init__(self, N, dx, x0=0.0): 69 | # This is _not_ called upon loading from storage 70 | self.N = N 71 | self.dx = dx 72 | self.x0 = x0 73 | self._post_init() 74 | 75 | def _post_init(self): 76 | # this is called upon loading from storage 77 | # calculate the grids 78 | n = np.arange(self.N) 79 | self.x = self.x0 + n * self.dx 80 | 81 | In this example the object can be exactly reproduced upon loading but only 82 | a minimal amount of storage is required. 83 | 84 | If you want to implement your own storage interface for a custom object 85 | you should inherit from ``IO`` and implement your own ``to_dict()`` and 86 | ``from_dict()`` methods. Look at the implementation of the default in ``IO`` 87 | to understand their behavior. 88 | 89 | 90 | File Format 91 | ----------- 92 | The file format this module uses is a straightforward mapping of Python 93 | types to the HDF5 data structure. Dictionaries and objects are mapped to 94 | HDF5 groups, numpy arrays use h5py's type translation. 95 | Iterables are converted to groups by introducing artificial keys of the 96 | type ``idx_%d``. This is rather inefficient which explains why the 97 | module should not be used to store large numerical arrays as a Python list. 98 | To store the type information it uses an HDF5 attribute ``__class__``. 99 | Furthermore, for scalars the attribute ``__dtype__`` and for strings the 100 | attribute ``__encoding__`` are additionally used. 101 | 102 | In conclusion, nested structures of Python types stored with this package are 103 | not suitable for exchanging. Dictionaries of numerical data stored with this 104 | package can be easily opened with any program that supports HDF5. 
105 | """ 106 | from .handlers import (save_to_level, load_from_level, TypeHandler, 107 | InstanceHandler) 108 | from .io import save, load, IO, MetaIO 109 | -------------------------------------------------------------------------------- /pypret/io/handlers.py: -------------------------------------------------------------------------------- 1 | """ Implements functions that handle the serialization of types and classes. 2 | 3 | Type handlers store and load objects of exactly that type. Instance handlers 4 | work also work for subclasses of that type. 5 | 6 | The instance handlers are processed in the order they are stored. This means 7 | that if an object is an instance of several handled classes it will not raise 8 | an error and will be handled by the first matching handler in the OrderedDict. 9 | """ 10 | import numpy as np 11 | import types 12 | import inspect 13 | from collections import OrderedDict 14 | 15 | """ The handler dictionaries are automatically filled when Handler class 16 | definitions are parsed via the metaclass __new__ function. 17 | """ 18 | # Saver dictionaries: classes as keys, handler classes as values 19 | type_saver_handlers = dict() 20 | instance_saver_handlers = OrderedDict() 21 | # Loader dictionaries: class names as keys, handler classes as values 22 | loader_handlers = dict() 23 | 24 | 25 | def classname(val): 26 | """ Returns a qualified class name as string. 27 | 28 | The qualified class name consists of the module and the class name, 29 | separated by a dot. If an instance is passed to this function, the name 30 | of its class is returned. 31 | 32 | Parameters 33 | ---------- 34 | val : instance or class 35 | The instance or a class of which the qualified class name is returned. 36 | 37 | Returns 38 | ------- 39 | str : The qualified class name. 
40 | """ 41 | if inspect.isclass(val): 42 | return ".".join([val.__module__, 43 | val.__name__]) 44 | return ".".join([val.__class__.__module__, 45 | val.__class__.__name__]) 46 | 47 | 48 | def set_attribute(level, key, value): 49 | level.attrs[key] = np.string_(value) 50 | 51 | 52 | def get_attribute(level, key): 53 | return level.attrs[key].decode('ascii') 54 | 55 | 56 | def set_classname(level, clsname): 57 | set_attribute(level, '__class__', clsname) 58 | 59 | 60 | def get_classname(level): 61 | return get_attribute(level, '__class__') 62 | 63 | 64 | def save_to_level(val, level, options, name=None): 65 | """ A generic save function that dispatches the correct handler. 66 | """ 67 | t = type(val) 68 | if t in type_saver_handlers: 69 | return type_saver_handlers[t].save_to_level(val, level, options, name) 70 | for i in instance_saver_handlers: 71 | if isinstance(val, i): 72 | return instance_saver_handlers[i].save_to_level(val, level, 73 | options, name) 74 | raise ValueError("%s of type %s is not supported by any handler!" % 75 | (str(val), str(t))) 76 | 77 | 78 | def load_from_level(level, obj=None): 79 | """ Loads an object from an HDF5 group or dataset. 80 | 81 | Parameters 82 | ---------- 83 | level : h5py.Dataset or h5py.Group 84 | An HDF5 node that stores an object in a valid format. 85 | obj : instance or None 86 | If provided this instance will be updated from the HDF5 node instead 87 | of creating a new instance of the stored object. 88 | 89 | Returns 90 | ------- 91 | instance of the stored object 92 | """ 93 | clsname = get_classname(level) 94 | if clsname not in loader_handlers: 95 | raise ValueError('Class `%s` has no registered handler.' % clsname) 96 | handler = loader_handlers[clsname] 97 | return handler.load_from_level(level, obj=obj) 98 | 99 | 100 | class TypeRegister(type): 101 | """ Metaclass that registers a type handler in a global dictionary. 
102 | """ 103 | def __new__(cls, clsname, bases, attrs): 104 | # convert all methods to classmethods 105 | for attr_name, attr_value in attrs.items(): 106 | if isinstance(attr_value, types.FunctionType): 107 | attrs[attr_name] = classmethod(attr_value) 108 | newclass = super().__new__(cls, clsname, bases, attrs) 109 | # register the class as a handler for all specified types 110 | for t in newclass.types: 111 | newclass.register(t) 112 | return newclass 113 | 114 | 115 | class InstanceRegister(type): 116 | """Metaclass that registers an instance handler in global dictionary. 117 | """ 118 | def __new__(cls, clsname, bases, attrs): 119 | # convert all methods to classmethods 120 | for attr_name, attr_value in attrs.items(): 121 | if isinstance(attr_value, types.FunctionType): 122 | attrs[attr_name] = classmethod(attr_value) 123 | newclass = super().__new__(cls, clsname, bases, attrs) 124 | # register the class as a handler for all specified instances 125 | for t in newclass.instances: 126 | newclass.register(t) 127 | return newclass 128 | 129 | 130 | class Handler: 131 | # a default for subclasses (only `group` subclasses have to overwrite) 132 | level_type = 'dataset' 133 | 134 | @classmethod 135 | def save_to_level(cls, val, level, options, name): 136 | """ A generic wrapper around the custom save method that each 137 | handler implements. It creates a dataset or a group depending 138 | on the `level_type` class attribute and sets the `__class__` 139 | attribute correctly. 140 | For more flexibility subclasses can overwrite this method. 
141 | """ 142 | # get the qualified class name of the object to be saved 143 | clsname = classname(val) 144 | # create the dataset or the group and call the save method on it 145 | if cls.is_dataset() and name is None: 146 | # if we want to save a dataset in the root group (name = None) 147 | # we have to give it a name 148 | name = "default" 149 | if cls.is_group(): 150 | if name is not None: 151 | level = cls.create_group(level, name, options) 152 | set_classname(level, clsname) 153 | ret = cls.save(val, level, options, name) 154 | if cls.is_dataset(): 155 | set_classname(ret, clsname) 156 | return ret 157 | 158 | @classmethod 159 | def load_from_level(cls, level, obj=None): 160 | """ The loader that has to be implemented by subclasses. 161 | """ 162 | raise NotImplementedError() 163 | 164 | @classmethod 165 | def create_group(cls, level, name, options): 166 | return level.create_group(name) 167 | 168 | @classmethod 169 | def create_dataset(cls, data, level, name, **kwargs): 170 | ds = level.create_dataset(name, data=data, **kwargs) 171 | return ds 172 | 173 | @classmethod 174 | def is_group(cls): 175 | return cls.level_type == 'group' 176 | 177 | @classmethod 178 | def is_dataset(cls): 179 | return cls.level_type == 'dataset' 180 | 181 | @classmethod 182 | def get_type(cls, level): 183 | return cls.casting[get_classname(level)] 184 | 185 | 186 | class TypeHandler(Handler, metaclass=TypeRegister): 187 | """ Handles data of a specific type or class. 188 | """ 189 | types = [] 190 | casting = {} 191 | 192 | @classmethod 193 | def register(cls, t): 194 | global type_saver_handlers, loader_handlers 195 | if t in type_saver_handlers: 196 | raise ValueError('Type `%s` is already handled by `%s`.' 
class InstanceHandler(Handler, metaclass=InstanceRegister):
    """ Handles all instances of a specific (parent) class.

    If an instance is subclass to several classes for which a handler
    exists, no error will be raised (in contrast to TypeHandler). Rather,
    the first match in the global instance_saver_handlers OrderedDict is
    used.
    """
    instances = []
    casting = {}

    @classmethod
    def register(cls, t):
        """ Registers this handler for class ``t`` in the global maps. """
        global instance_saver_handlers, loader_handlers
        if t in instance_saver_handlers:
            raise ValueError('Instance `%s` is already handled by `%s`.' %
                             (str(t), str(instance_saver_handlers[t])))
        qualified = classname(t)
        instance_saver_handlers[t] = cls
        loader_handlers[qualified] = cls
        cls.casting[qualified] = t


# Specific handlers
class NoneHandler(TypeHandler):
    """ Stores ``None`` as a dummy scalar dataset. """
    types = [type(None)]

    def save(cls, val, level, options, name):
        # the stored value is irrelevant; only the __class__ attribute counts
        return cls.create_dataset(0, level, name, **options(0))

    def load_from_level(cls, level, obj=None):
        return None


class ScalarHandler(TypeHandler):
    """ Stores Python and NumPy scalar types as scalar datasets. """
    types = [float, bool, complex, np.int8, np.int16, np.int32,
             np.int64, np.uint8, np.uint16, np.uint32, np.uint64,
             np.float16, np.float32, np.float64, np.bool_, np.complex64,
             np.complex128]

    def save(cls, val, level, options, name):
        return cls.create_dataset(val, level, name, **options(val))

    def load_from_level(cls, level, obj=None):
        # retrieve the scalar dataset and cast back to the original type
        return cls.get_type(level)(level[()])
Python's variable size ints. 257 | 258 | They are stored as byte arrays. Probably not the most efficient solution... 259 | """ 260 | types = [int] 261 | 262 | def save(cls, val, level, options, name): 263 | val = val.to_bytes((val.bit_length() + 7) // 8, byteorder='little') 264 | data = np.frombuffer(val, dtype=np.uint8) 265 | ds = cls.create_dataset(data, level, name, **options(data)) 266 | return ds 267 | 268 | def load_from_level(cls, level, obj=None): 269 | return int.from_bytes(level[:].tobytes(), byteorder='little') 270 | 271 | 272 | class TimeHandler(TypeHandler): 273 | types = [np.datetime64, np.timedelta64] 274 | 275 | def save(cls, val, level, options, name): 276 | val2 = val.view(' self.compression_threshold: 45 | kwargs = self.compressed_dataset.__dict__ 46 | else: 47 | kwargs = self.dataset.__dict__ 48 | return kwargs 49 | 50 | 51 | DEFAULT_OPTIONS = HDF5Options() 52 | -------------------------------------------------------------------------------- /pypret/io/tests/test_io.py: -------------------------------------------------------------------------------- 1 | """ This module tests the io subpackage implementation. 
class IO1(io.IO):
    """ Minimal IO subclass used to test class-attribute round-tripping. """
    x = 1

    def squared(self):
        return self.x * self.x

    def __repr__(self):
        return "IO1(x={0})".format(self.x)


class Grid(io.IO):
    """ Example class that persists only a whitelist of attributes. """
    _io_store = ['N', 'dx', 'x0']

    def __init__(self, N, dx, x0=0.0):
        # This is _not_ called upon loading from storage
        self.N = N
        self.dx = dx
        self.x0 = x0
        self._post_init()

    def _post_init(self):
        # this is called upon loading from storage:
        # recalculate the derived grid from the stored attributes
        n = np.arange(self.N)
        self.x = self.x0 + n * self.dx

    def __repr__(self):
        # fixed: repr previously claimed the class was named "TestIO1",
        # which made test failure output misleading
        return "Grid(N={0}, dx={1}, x0={2})".format(
            self.N, self.dx, self.x0)


def test_io():
    """ Round-trips several structures through HDF5 and compares them. """
    # test flat arrays
    _assert_io(np.arange(5))
    _assert_io(np.arange(5, dtype=np.complex128))
    # test nested structures of various types
    _assert_io([{'a': 1.0, 'b': np.uint16(1)}, np.random.rand(10),
                True, None, "hello", 1231241512354134123412353124, b"bytes"])
    _assert_io([[[1]], [[[[1], 2], 3], 4], 5])
    # Test custom objects
    _assert_io(IO1())
    _assert_io(Grid(128, 0.23, x0=-2.3))
62 | """ 63 | io.save(x, "test.hdf5") 64 | x2 = io.load("test.hdf5") 65 | remove("test.hdf5") 66 | s1 = pformat(x) 67 | s2 = pformat(x2) 68 | if s1 != s2: 69 | print(s1) 70 | print(s2) 71 | assert False 72 | 73 | 74 | if __name__ == "__main__": 75 | test_io() 76 | -------------------------------------------------------------------------------- /pypret/lib.py: -------------------------------------------------------------------------------- 1 | """ Miscellaneous helper functions 2 | 3 | These functions fulfill small numerical tasks used in several places in the 4 | package. 5 | """ 6 | import numpy as np 7 | # make numba jit an optional dependence 8 | # see https://github.com/numba/numba/issues/3735 9 | try: 10 | from numba import jit 11 | except ImportError: 12 | def jit(pyfunc=None, **kwargs): 13 | def wrap(func): 14 | return func 15 | if pyfunc is not None: 16 | return wrap(pyfunc) 17 | else: 18 | return wrap 19 | 20 | # Constants for convenience (not more accurate) 21 | # two pi 22 | twopi = 6.2831853071795862 23 | # sqrt(2) 24 | sqrt2 = 1.4142135623730951 25 | # speed of light 26 | sol = 299792458.0 27 | 28 | # Constants that give slightly more accuracy (~1 ulp) 29 | # sqrt(2 * pi) 30 | sqrt2pi = 2.5066282746310007 31 | 32 | 33 | def as_list(x): 34 | """ Try to convert argument to list and return it. 35 | 36 | Useful to implement function arguments that could be scalar values 37 | or lists. 38 | """ 39 | try: 40 | return list(x) 41 | except TypeError: 42 | return list([x]) 43 | 44 | @jit(nopython=True, cache=True) 45 | def abs2(x): 46 | """ Calculates the squared magnitude of a complex array. 47 | """ 48 | return x.real * x.real + x.imag * x.imag 49 | 50 | 51 | def rms(x, y): 52 | """ Calculates the root mean square (rms) error between ``x`` and ``y``. 53 | """ 54 | return np.sqrt(abs2(x - y).mean()) 55 | 56 | 57 | @jit(nopython=True, cache=True) 58 | def norm2(x): 59 | """ Calculates the squared L2 or Euclidian norm of array ``x``. 
60 | """ 61 | return abs2(x).sum() 62 | 63 | 64 | @jit(nopython=True, cache=True) 65 | def norm(x): 66 | """ Calculates the L2 or Euclidian norm of array ``x``. 67 | """ 68 | return np.sqrt(abs2(x).sum()) 69 | 70 | 71 | def phase(x): 72 | """ The phase of a complex array.""" 73 | return np.unwrap(np.angle(x)) 74 | 75 | 76 | def nrms(x, y): 77 | """ Calculates the normalized rms error between ``x`` and ``y``. 78 | 79 | The convention for normalization varies. Here we use:: 80 | 81 | max |y| 82 | 83 | as normalization. 84 | """ 85 | n = np.abs(y).max() 86 | if n == 0.0: 87 | raise ValueError("Second array cannot be zero.") 88 | return rms(x, y) / n 89 | 90 | 91 | def mean(x, y): 92 | """ Calculates the mean of the distribution described by (x, y). 93 | """ 94 | return np.sum(x * y) / np.sum(y) 95 | 96 | 97 | def variance(x, y): 98 | """ Calculates the variance of the distribution described by (x, y). 99 | """ 100 | dx = x - mean(x, y) 101 | return np.sum(dx * dx * y) / np.sum(y) 102 | 103 | 104 | def standard_deviation(x, y): 105 | """ Calculates the standard deviation of the distribution described by 106 | (x, y). 107 | """ 108 | return np.sqrt(variance(x, y)) 109 | 110 | 111 | def gaussian(x, x0=0.0, sigma=1.0): 112 | """ Calculates a Gaussian function with center ``x0`` and standard 113 | deviation ``sigma``. 114 | """ 115 | d = (x - x0) / sigma 116 | return np.exp(-0.5 * d * d) 117 | 118 | 119 | def rescale(x, window=[0.0, 1.0]): 120 | """ Rescales a numpy array to the range specified by ``window``. 121 | 122 | Default is [0, 1]. 123 | """ 124 | maxx = np.max(x) 125 | minx = np.min(x) 126 | return (x - minx) / (maxx - minx) * (window[1] - window[0]) + window[0] 127 | 128 | 129 | def marginals(data, normalize=False, axes=None): 130 | """ Calculates the marginals of the data array. 131 | 132 | axes specifies the axes of the marginals, e.g., the axes on which the 133 | sum is projected. 134 | 135 | If axis is None a list of all marginals is returned. 
136 | """ 137 | if axes is None: 138 | axes = range(data.ndim) 139 | axes = as_list(axes) 140 | full_axes = list(range(data.ndim)) 141 | m = [] 142 | for i in axes: 143 | # for the marginal sum over all axes except the specified one 144 | margin_axes = tuple(j for j in full_axes if j != i) 145 | m.append(np.sum(data, axis=margin_axes)) 146 | if normalize: 147 | m = [rescale(mx) for mx in m] 148 | return tuple(m) if len(m) != 1 else m[0] 149 | 150 | 151 | def find(x, condition, n=1): 152 | """ Return the index of the nth element that fulfills the condition. 153 | """ 154 | search_n = 1 155 | for i in range(len(x)): 156 | if condition(x[i]): 157 | if search_n == n: 158 | return i 159 | search_n += 1 160 | return -1 161 | 162 | 163 | def best_scale(E, E0): 164 | """ Scales rho so that:: 165 | 166 | sum (rho * |E| - |E0|)^2 167 | 168 | is minimal. 169 | """ 170 | Eabs, E0abs = np.abs(E), np.abs(E0) 171 | return np.sum(Eabs * E0abs) / np.sum(Eabs * Eabs) 172 | 173 | 174 | def arglimit(y, threshold=1e-3, padding=0.0, normalize=True): 175 | """ Returns the first and last index where `y >= threshold * max(abs(y))`. 176 | """ 177 | t = np.abs(y) 178 | if normalize: 179 | t /= np.max(t) 180 | 181 | idx1 = find(t, lambda x: x >= threshold) 182 | if idx1 == -1: 183 | idx1 = 0 184 | idx2 = find(t[::-1], lambda x: x >= threshold) 185 | if idx2 == -1: 186 | idx2 = t.shape[0] - 1 187 | else: 188 | idx2 = t.shape[0] - 1 - idx2 189 | 190 | return (idx1, idx2) 191 | 192 | 193 | def limit(x, y=None, threshold=1e-3, padding=0.25, extend=True): 194 | """ Returns the maximum x-range where the y-values are sufficiently large. 195 | 196 | Parameters 197 | ---------- 198 | x : array_like 199 | The x values of the graph. 200 | y : array_like, optional 201 | The y values of the graph. If `None` the maximum range of `x` is 202 | used. That is only useful if `padding > 0`. 
203 | threshold : float 204 | The threshold relative to the maximum of `y` of values that should be 205 | included in the bracket. 206 | padding : float 207 | The relative padding on each side in fractions of the bracket size. 208 | extend : bool, optional 209 | Signals if the returned range can be larger than the values in ``x``. 210 | Default is `True`. 211 | 212 | Returns 213 | ------- 214 | xl, xr : float 215 | Lowest and biggest value of the range. 216 | 217 | """ 218 | if y is None: 219 | x1, x2 = np.min(x), np.max(x) 220 | if not extend: 221 | return (x1, x2) 222 | else: 223 | idx1, idx2 = arglimit(y, threshold=threshold) 224 | x1, x2 = sorted([x[idx1], x[idx2]]) 225 | 226 | # calculate the padding 227 | if padding != 0.0: 228 | pad = (x2 - x1) * padding 229 | x1 -= pad 230 | x2 += pad 231 | 232 | if not extend: 233 | x1 = max(x1, np.min(x)) 234 | x2 = min(x2, np.max(x)) 235 | 236 | return (x1, x2) 237 | 238 | 239 | def fwhm(x, y): 240 | """ Calculates the full width at half maximum of the distribution described 241 | by (x, y). 242 | """ 243 | xl, xr = limit(x, y, threshold=0.5, padding=0.0) 244 | return np.abs(xr - xl) 245 | 246 | 247 | def edges(x): 248 | """ Calculates the edges of the array elements. 249 | 250 | Assuming that the input array contains the midpoints of a supposed data 251 | set, the function returns the (N+1) edges of the data set points. 252 | """ 253 | diff = np.diff(x) 254 | reverse = False 255 | if np.any(np.sign(diff) != np.sign(diff[0])): 256 | raise ValueError("Input array must be sorted") 257 | elif diff[0] < 0.0: 258 | x = x[::-1] 259 | reverse = True 260 | 261 | result = np.concatenate(( 262 | [1.5 * x[0] - 0.5 * x[1]], 263 | 0.5 * (x[1:] + x[:-1]), 264 | [1.5 * x[-1] - 0.5 * x[-2]] 265 | )) 266 | if reverse: 267 | result = result[::-1] 268 | 269 | return result 270 | 271 | 272 | def build_coords(*axes): 273 | """ Builds a coordinate array from the axes. 
def mask_phase(x, amp, phase, threshold=1e-3):
    """ Masks the phase (and x) wherever the amplitude is negligible.

    Entries where ``amp`` falls below ``threshold`` relative to its maximum
    are masked out in both returned arrays.
    """
    negligible = amp / np.max(amp) < threshold
    masked_x = np.ma.masked_array(x, mask=negligible)
    masked_phase = np.ma.masked_array(phase, mask=negligible)
    return masked_x, masked_phase


def retrieval_report(res):
    """ Simple helper that prints out important information from the
    retrieval result object.
    """
    print("Retrieval report")
    print("trace error".ljust(15) + "R = %.17e".rjust(25) % res.trace_error)
    if hasattr(res, "trace_error_optimal"):
        # the optimal trace error is only available for synthetic tests
        print("min. trace error".ljust(15) + "R0 = %.17e".rjust(25) % res.trace_error_optimal)
        print("".ljust(15) + "R - R0 = %.17e".rjust(25) % (res.trace_error - res.trace_error_optimal))
    print()
    print("pulse error".ljust(15) + "ε = %.17e".rjust(25) % res.pulse_error)
25 | 26 | Parameters 27 | ---------- 28 | coefficients: ndarray 29 | The Sellmeier coefficients. 30 | freq_range : iterable 31 | The wavelength range in which the Sellmeier equation is valid 32 | (given in m). 33 | check_bounds : bool, optional 34 | Specifies if the frequency argument should be checked on every 35 | evaluation to match the allowed range. 36 | scaling : float, optional 37 | Specifies the scaling of the Sellmeier formula. E.g., most 38 | Sellmeier formulas are defined in terms of µm (micrometer), 39 | whereas our function interface works in meter. In that case the 40 | scaling would be `1e6`. Default is `1.0e6`. 41 | ''' 42 | if len(freq_range) != 2: 43 | raise ValueError("Frequency range must specified with two elements.") 44 | self._coefficients = np.array(coefficients) 45 | self._range = np.array(freq_range) 46 | self._scaling = scaling 47 | self.check = check_bounds 48 | self.name = name 49 | self.long_name = long_name 50 | 51 | def _check(self, x): 52 | if not self.check: 53 | return 54 | minx, maxx = np.min(x), np.max(x) 55 | if (minx < self._range[0]) or (maxx > self._range[1]): 56 | raise ValueError('Wavelength array [%e, %e] outside of valid range ' 57 | 'of the Sellmeier equation [%e, %e].' % 58 | (minx, maxx, self._range[0], self._range[1])) 59 | 60 | def _convert(self, x, unit): 61 | '''This is intended for conversion to be used in `self._func`.''' 62 | if unit != 'wl': 63 | x = convert(x, unit, 'wl') 64 | self._check(x) 65 | if self._scaling != 1.0: 66 | x = x * self._scaling 67 | return x 68 | 69 | def n(self, x, unit='wl'): 70 | '''The refractive index at frequency `x` specified in units `unit`. 
class SellmeierF1(BaseMaterial):
    ''' Defines a dispersive material via a specific Sellmeier equation.

    This subclass supports materials with a Sellmeier equation of the
    form::

        n^2(l) - 1 = c1 + c2 * l^2 / (l^2 - c3^2) + ...

    This is formula 1 from refractiveindex.info [DispersionFormulas]_.
    '''
    def _func(self, x):
        c = self._coefficients
        x2 = x * x
        # constant contribution c1
        n2 = np.full_like(x, 1.0 + c[0])
        # resonance terms in pairs: amplitude c[i], resonance position c[i+1]
        # (c.size used for consistency with SellmeierF2; identical to len(c)
        # for the 1-d coefficient array stored by BaseMaterial)
        for i in range(1, c.size - 1, 2):
            n2 += c[i] * x2 / (x2 - c[i + 1] * c[i + 1])
        return np.sqrt(n2)


class SellmeierF2(BaseMaterial):
    ''' Defines a dispersive material via a specific Sellmeier equation.

    This subclass supports materials with a Sellmeier equation of the
    form::

        n^2(l) - 1 = c1 + c2 * l^2 / (l^2 - c3) + ...

    This is formula 2 from refractiveindex.info [DispersionFormulas]_.
    '''
    def _func(self, x):
        c = self._coefficients
        x2 = x * x
        n2 = np.full_like(x, 1.0 + c[0])
        for i in range(1, c.size - 1, 2):
            n2 += c[i] * x2 / (x2 - c[i + 1])
        return np.sqrt(n2)


FS = SellmeierF1(coefficients=[0.0000000, 0.6961663,
                               0.0684043, 0.4079426,
                               0.1162414, 0.8974794,
                               9.8961610],
                 freq_range=[0.21e-6, 6.7e-6],
                 name="FS",
                 long_name="Fused silica (fused quartz)")
"""Material instance describing fused silica (fused quartz).

The data was taken from refractiveindex.info
"""
BK7 = SellmeierF2(coefficients=[0.00000000000, 1.039612120,
                                0.00600069867, 0.231792344,
                                0.02001791440, 1.010469450,
                                103.560653],
                  freq_range=[0.3e-6, 2.5e-6],
                  name="BK7", long_name="N-BK7 (SCHOTT)")
"""Material instance describing N-BK7 (SCHOTT).

The data was taken from refractiveindex.info
"""


class MeshData(io.IO):
    # attributes that are persisted by the io machinery
    _io_store = ["data", "axes", "labels", "units", "uncertainty"]

    def __init__(self, data, *axes, uncertainty=None, labels=None,
                 units=None):
        """ Creates a MeshData instance.

        Parameters
        ----------
        data : ndarray
            An at least two-dimensional array containing the data.
        *axes : ndarray
            Arrays specifying the coordinates of the data axes. Must be given
            in indexing order.
        uncertainty : ndarray, optional
            An ndarray of the same size as `data` that contains some measure
            of the uncertainty of the meshdata, e.g., the standard deviation
            of the data.
        labels : list of str, optional
            Strings labeling the axes. The last element labels the data
            itself, i.e., ``labels`` must have one more element than the
            number of axes.
        units : list of str, optional
            A list of unit strings.
        """
        self.data = data.copy()
        self.axes = [np.array(a).copy() for a in axes]
        self.uncertainty = None if uncertainty is None else uncertainty.copy()
        if self.ndim != len(axes):
            raise ValueError("Number of supplied axes is wrong!")
        if self.shape != tuple(ax.size for ax in self.axes):
            raise ValueError("Shape of supplied axes is wrong!")
        self.labels = ["" for ax in self.axes] if labels is None else labels
        self.units = ["" for ax in self.axes] if units is None else units

    @property
    def shape(self):
        """ The shape of the data as a tuple. """
        return self.data.shape

    @property
    def ndim(self):
        """ The dimension of the data as an integer. """
        return self.data.ndim

    def copy(self):
        """ Creates a copy of the MeshData instance. """
        return MeshData(self.data, *self.axes, uncertainty=self.uncertainty,
                        labels=self.labels, units=self.units)

    def marginals(self, normalize=False, axes=None):
        """ Calculates the marginals of the data.

        ``axes`` specifies the axes of the marginals, e.g., the axes onto
        which the sum is projected.
        """
        return lib.marginals(self.data, normalize=normalize, axes=axes)

    def normalize(self):
        """ Normalizes the maximum of the data to 1. """
        self.scale(1.0 / self.data.max())

    def scale(self, scale):
        """ Multiplies the data (and its uncertainty) by a constant. """
        if self.uncertainty is not None:
            self.uncertainty *= scale
        self.data *= scale
88 | """ 89 | if len(axes) == 0: 90 | # default: operate on all axes 91 | axes = list(range(self.ndim)) 92 | marginals = lib.marginals(self.data) 93 | limits = [] 94 | for i, j in enumerate(axes): 95 | limit = lib.limit(self.axes[j], marginals[j], 96 | threshold=threshold, padding=padding) 97 | limits.append(limit) 98 | self.limit(*limits, axes=axes) 99 | 100 | def limit(self, *limits, axes=None): 101 | """ Limits the data range of this instance. 102 | 103 | Parameters 104 | ---------- 105 | *limits : tuples 106 | The data limits in the axes as tuples. Has to match the dimension 107 | of the data or the number of axes specified in the `axes` 108 | parameter. 109 | axes : tuple or None 110 | The axes in which the limit is applied. Default is `None` in which 111 | case all axes are selected. 112 | """ 113 | if axes is None: 114 | # default: operate on all axes 115 | axes = list(range(self.ndim)) 116 | axes = lib.as_list(axes) 117 | if len(axes) != len(limits): 118 | raise ValueError("Number of limits must match the specified axes!") 119 | slices = [] 120 | for j in range(self.ndim): 121 | if j in axes: 122 | i = axes.index(j) 123 | ax = self.axes[j] 124 | x1, x2 = limits[i] 125 | # do it this way as we cannot assume them to be sorted... 126 | idx1 = np.argmin(np.abs(ax - x1)) 127 | idx2 = np.argmin(np.abs(ax - x2)) 128 | if idx1 > idx2: 129 | idx1, idx2 = idx2, idx1 130 | elif idx1 == idx2: 131 | raise ValueError('Selected empty slice along axis %d!' % i) 132 | slices.append(slice(idx1, idx2 + 1)) 133 | else: 134 | # empty slice 135 | slices.append(slice(None)) 136 | self.axes[j] = self.axes[j][slices[-1]] 137 | self.data = self.data[(*slices,)] 138 | if self.uncertainty is not None: 139 | self.uncertainty = self.uncertainty[(*slices,)] 140 | 141 | def interpolate(self, axis1=None, axis2=None, degree=2, sorted=False): 142 | """ Interpolates the data on a new two-dimensional, equidistantly 143 | spaced grid. 
144 | """ 145 | axes = [axis1, axis2] 146 | for i in range(self.ndim): 147 | if axes[i] is None: 148 | axes[i] = self.axes[i] 149 | # FITPACK can only deal with strictly increasing axes 150 | # so sort them beforehand if necessary... 151 | orig_axes = self.axes 152 | data = self.data.copy() 153 | if self.uncertainty is not None: 154 | uncertainty = self.uncertainty.copy() 155 | if not sorted: 156 | for i in range(len(orig_axes)): 157 | idx = np.argsort(orig_axes[i]) 158 | orig_axes[i] = orig_axes[i][idx] 159 | data = np.take(data, idx, axis=i) 160 | if self.uncertainty is not None: 161 | uncertainty = np.take(uncertainty, idx, axis=i) 162 | dataf = RegularGridInterpolator(tuple(orig_axes), data, 163 | bounds_error=False, fill_value=0.0) 164 | grid = lib.build_coords(*axes) 165 | self.data = dataf(grid) 166 | self.axes = axes 167 | if self.uncertainty is not None: 168 | dataf = RegularGridInterpolator(tuple(orig_axes), uncertainty, 169 | bounds_error=False, fill_value=0.0) 170 | self.uncertainty = dataf(grid) 171 | 172 | def flip(self, *axes): 173 | """ Flips the data on the specified axes. 174 | """ 175 | if len(axes) == 0: 176 | return 177 | axes = lib.as_list(axes) 178 | slices = [slice(None) for ax in self.axes] 179 | for ax in axes: 180 | self.axes[ax] = self.axes[ax][::-1] 181 | slices[ax] = slice(None, None, -1) 182 | self.data = self.data[slices] 183 | if self.uncertainty is not None: 184 | self.uncertainty = self.uncertainty[slices] 185 | -------------------------------------------------------------------------------- /pypret/pulse.py: -------------------------------------------------------------------------------- 1 | """ Provides a class to simulate an ultrashort optical pulse using its envelope 2 | description. 3 | 4 | The temporal envelope is denoted as `field` and the spectral envelope as 5 | `spectrum` in the code and the function signatures. 6 | """ 7 | import numpy as np 8 | from . import io 9 | from . 
import lib 10 | from .frequencies import convert 11 | from scipy.optimize import minimize_scalar, root_scalar 12 | 13 | 14 | class Pulse(io.IO): 15 | """ A class for modelling femtosecond pulses by their envelope. 16 | """ 17 | _io_store = ['ft', 'wl0', '_field', '_spectrum'] 18 | 19 | def __init__(self, ft, wl0, unit='wl'): 20 | """ Initializes an optical pulse described by its envelope. 21 | 22 | Parameters 23 | ---------- 24 | ft : FourierTransform 25 | A ``FourierTransform`` instance that specifies a temporal and 26 | spectral grid. 27 | wl0 : float 28 | The center frequency of the pulse. 29 | unit : str 30 | The unit in which the center frequency is specified. Can be either 31 | of ``wl``, ``om``, ``f``, or ``k``. See ``frequencies`` for more 32 | information. Default is ``wl``. 33 | """ 34 | self.ft = ft 35 | self.wl0 = convert(wl0, unit, 'wl') 36 | self._field = np.zeros(ft.N, dtype=np.complex128) 37 | self._spectrum = np.zeros(ft.N, dtype=np.complex128) 38 | self._post_init() 39 | 40 | def copy(self): 41 | """ Returns a copy of the pulse object. 42 | 43 | Note that they still reference the same `FourierTransform` instance, 44 | which is assumed to be immutable. 45 | """ 46 | p = Pulse(self.ft, self.wl0) 47 | p.spectrum = self.spectrum 48 | return p 49 | 50 | def _post_init(self): 51 | ft = self.ft 52 | self.t = ft.t 53 | self.w = ft.w 54 | self.dt = ft.dt 55 | self.dw = ft.dw 56 | self.N = ft.N 57 | 58 | self.w0 = convert(self.wl0, 'wl', 'om') 59 | self.wl = convert(self.w + self.w0, 'om', 'wl') 60 | 61 | @property 62 | def field(self): 63 | """ The complex-valued temporal envelope of the pulse. 64 | 65 | On read access returns a copy of the internal array. On write 66 | access the spectral envelope is automatically updated. 
67 | """ 68 | return self._field.copy() 69 | 70 | @field.setter 71 | def field(self, val): 72 | self._field[:] = val 73 | self.update_spectrum() 74 | 75 | def field_at(self, t): 76 | """ The complex-valued temporal envelope of the pulse at the 77 | times `t`. 78 | """ 79 | return self.ft.backward_at(self._spectrum, t) 80 | 81 | @property 82 | def spectrum(self): 83 | """ The complex-valued spectral envelope of the pulse. 84 | 85 | On read access returns a copy of the internal array. On write 86 | access the temporal envelope is automatically updated. 87 | """ 88 | return self._spectrum.copy() 89 | 90 | @spectrum.setter 91 | def spectrum(self, val): 92 | self._spectrum[:] = val 93 | self.update_field() 94 | 95 | def spectrum_at(self, w): 96 | """ The complex-valued spectral envelope of the pulse at the 97 | frequencies `w`. 98 | """ 99 | return self.ft.forward_at(self._field, w) 100 | 101 | def update_field(self): 102 | """ Manually updates the field from the (modified) spectrum. 103 | """ 104 | self.ft.backward(self._spectrum, out=self._field) 105 | 106 | def update_spectrum(self): 107 | """ Manually updates the spectrum from the (modified) field. 108 | """ 109 | self.ft.forward(self._field, out=self._spectrum) 110 | 111 | @property 112 | def intensity(self): 113 | """ The temporal intensity profile of the pulse in vacuum. 114 | 115 | Only read access. 116 | """ 117 | return lib.abs2(self._field) 118 | 119 | @property 120 | def amplitude(self): 121 | """ The temporal amplitude profile of the pulse in vacuum. 122 | 123 | Only read access. 124 | """ 125 | return self._field.abs() 126 | 127 | @property 128 | def phase(self): 129 | """ The temporal phase of the pulse. 130 | 131 | Only read access. 132 | """ 133 | return lib.phase(self._field) 134 | 135 | @property 136 | def spectral_intensity(self): 137 | """ The spectral intensity profile of the pulse in vacuum. 138 | 139 | Only read access. 
140 | """ 141 | return lib.abs2(self._spectrum) 142 | 143 | @property 144 | def spectral_amplitude(self): 145 | """ The spectral amplitude profile of the pulse in vacuum. 146 | 147 | Only read access. 148 | """ 149 | return self._spectrum.abs() 150 | 151 | @property 152 | def spectral_phase(self): 153 | """ The spectral phase of the pulse. 154 | 155 | Only read access. 156 | """ 157 | return lib.phase(self._spectrum) 158 | 159 | @property 160 | def time_bandwidth_product(self): 161 | """ Calculates the rms time-bandwidth product of the pulse. 162 | 163 | In this definition a transform-limited Gaussian pulse has a 164 | time-bandwidth product of 0.5. So the number returned by this 165 | function will always be >= 0.5. 166 | """ 167 | return (lib.standard_deviation(self.t, self.intensity) * 168 | lib.standard_deviation(self.w, self.spectral_intensity)) 169 | 170 | def fwhm(self, dt=None): 171 | """ Calculates the full width at half maximum (FWHM) of the temporal 172 | intensity profile. 173 | 174 | Parameters 175 | ---------- 176 | dt : float or None, optional 177 | Specifies the required accuracy of the calculation. If `None` (the 178 | default) it is only as good as the spacing of the underlying 179 | simulation grid - which can be quite coarse compared to the FWHM. 180 | If smaller it is calculated based on trigonometric interpolation. 
181 | 182 | """ 183 | t, intensity = self.t, self.intensity 184 | if dt is None or dt == self.dt: 185 | return lib.fwhm(t, intensity) 186 | 187 | # exact calculation 188 | def objective(tau): 189 | return lib.abs2(self.field_at(np.array([tau]))[0]) 190 | 191 | # determine the maximum accurately 192 | idx = np.argmax(intensity) 193 | res = minimize_scalar(lambda x: -objective(x), 194 | (t[idx] - self.dt, t[idx]), 195 | tol=dt / 100.0, method="brent") 196 | y0 = -res.fun 197 | # determine the right and left sided intersection points 198 | idx1, idx2 = lib.arglimit(intensity, threshold=0.5 * y0, 199 | padding=0.0, normalize=False) 200 | # left side 201 | res = root_scalar(lambda x: objective(x) - 0.5 * y0, 202 | bracket=(t[idx1] - self.dt, t[idx1]), 203 | xtol=dt, method="brentq") 204 | xl = res.root 205 | # right side 206 | res = root_scalar(lambda x: objective(x) - 0.5 * y0, 207 | bracket=(t[idx2], t[idx2] + self.dt), 208 | xtol=dt, method="brentq") 209 | xr = res.root 210 | return xr - xl 211 | -------------------------------------------------------------------------------- /pypret/pulse_error.py: -------------------------------------------------------------------------------- 1 | """ This module implements testing procedures for retrieval algorithms. 2 | """ 3 | import numpy as np 4 | from scipy import optimize 5 | from . import lib 6 | 7 | 8 | def pulse_error(E, E0, ft, dot_ambiguity=False, 9 | spectral_shift_ambiguity=False): 10 | ''' Calculates the normalized rms error between two pulse spectra while 11 | taking into account the retrieval ambiguities. 12 | 13 | One step in `optimal_rms_error` (the determination of the initial bracket) 14 | could probably be more efficient, see [Dorrer2002]_). We use the less 15 | elegant but maybe more straightforward way of simply sampling the 16 | range for a bracket that encloses a minimum. 17 | 18 | Parameters 19 | ---------- 20 | E, E0: 1d-array 21 | Complex-valued arrays that contain the spectra of the pulses. 
22 | ``E`` will be matched against ``E0``. 23 | ft : FourierTransform instance 24 | Performs Fourier transforms on the pulse grid. 25 | dot_ambiguity : bool, optional 26 | Takes the direction of time ambiguity into account. Default is 27 | ``False``. 28 | spectral_shift_ambiguity : bool, optional 29 | Takes the spectral shift ambiguity into account. Default is ``False``. 30 | ''' 31 | test_fields = [[ft.w, E, E0]] 32 | if spectral_shift_ambiguity: 33 | # spectrally shift by exactly half the grid size 34 | Et = ft.backward(E) 35 | Et *= np.exp(0.5j * ft.N * ft.dw * ft.t) 36 | test_fields.append([ft.w, ft.forward(Et), E0]) 37 | if dot_ambiguity: 38 | max_iter = len(test_fields) 39 | for i in range(max_iter): 40 | tf = test_fields[i] 41 | test_fields.append([tf[0], tf[1].conj(), tf[2]]) 42 | 43 | best_error = np.inf 44 | for w, spec1, spec2 in test_fields: 45 | error, matched = optimal_rms_error(w, spec1, spec2) 46 | if error < best_error: 47 | best_error = error 48 | best_match = matched 49 | return best_error, best_match 50 | 51 | 52 | def best_constant_phase(E, E0): 53 | """ Finds ``c`` with ``|c| = 1`` so that ``sum(abs2(c * y1 - y2))`` is 54 | minimal. 55 | 56 | Uses an analytic solution. 57 | """ 58 | A = np.sum(E.conj() * E0) 59 | c = A / np.abs(A) 60 | err1 = np.sum(lib.abs2(c * E - E0)) 61 | err2 = np.sum(lib.abs2(-c * E - E0)) 62 | if err2 < err1: 63 | c = -c 64 | return c 65 | 66 | 67 | def optimal_rms_error(w, E, E0): 68 | """ Calculates the RMS error of two arrays, ignoring scaling, constant 69 | and linear phase of one of them. 70 | 71 | Formally it calculates the minimal error:: 72 | 73 | R = sqrt(|rho * exp(i*(x*a + b)) * y1 - y2|^2 / |y2|^2) 74 | 75 | with respect to rho, a and b. If additionally ``conjugation = True`` then 76 | the error for conjugate(y1) is calculated and the best transformation of y1 77 | is also returned. 
78 | """ 79 | # E is rescaled so that the amplitudes match in the least-squares sense 80 | E = E * lib.best_scale(E, E0) 81 | 82 | # find optimal linear and constant phase 83 | # determine the frequency spacing 84 | dw = np.max(np.abs(np.diff(w))) 85 | # rescale the objective function to make it easier for the optimizer 86 | scale = 1.0 / np.sqrt(np.sum(lib.abs2(E0)) * E.shape[0]) 87 | 88 | def objective(alpha): 89 | linear = np.exp(1.0j * alpha / dw * w) 90 | phase0 = best_constant_phase(linear * E, E0) 91 | cresiduals = (phase0 * linear * E - E0) * scale 92 | return lib.norm2(cresiduals) 93 | 94 | # find an initial bracket 95 | alphas = np.linspace(-np.pi, np.pi, 2 * E.shape[0]) 96 | err = np.array([objective(a) for a in alphas]) 97 | idx = np.argmin(err) 98 | bracket = [ 99 | alphas[max(0, idx - 1)], 100 | alphas[min(alphas.shape[0] - 1, idx + 1)] 101 | ] 102 | # run a bounded optimization on that bracket to obtain high precision 103 | res = optimize.minimize_scalar( 104 | objective, 105 | bounds=bracket, 106 | method='bounded', 107 | options=dict(maxiter=100, xatol=1e-10) 108 | ) 109 | linear = np.exp(1.0j * res.x / dw * w) 110 | phase0 = best_constant_phase(linear * E, E0) 111 | E = phase0 * linear * E 112 | 113 | return lib.nrms(E, E0), E 114 | -------------------------------------------------------------------------------- /pypret/random_pulse.py: -------------------------------------------------------------------------------- 1 | """ Provides a function to generate random pulses with specified TBP. 2 | """ 3 | import numpy as np 4 | import scipy.optimize as opt 5 | from . import lib 6 | 7 | 8 | def random_pulse(pulse, tbp, edge_value=None, check=True): 9 | """ Creates a random pulse with a specified time-bandwidth product. 10 | 11 | Parameters 12 | ---------- 13 | pulse : Pulse instance 14 | tbp : float 15 | The specified time-bandwidth product. 16 | edge_value : float, optional 17 | The maximal value for the pulse amplitude at the edges of the grid. 
18 | It defaults to the double value epsilon ~2e-16. 19 | 20 | Returns 21 | ------- 22 | bool : True on success, False if an error occured. The resulting pulse 23 | is stored in the Pulse instance passed to the function. 24 | 25 | Notes 26 | ----- 27 | The function creates random pulses by iteratively restricting the bandwidth 28 | in time and frequency domain. It starts from random complex values in 29 | frequency domain, multiplies a Gaussian function, transforms in the 30 | time domain and multiplies a Gaussian function again. The filter functions 31 | are Gaussians with the specified time-bandwidth product. 32 | The TBP of the Gaussian filters, however, does not directly correspond 33 | to the TBP of the resulting pulse. To use this algorithm to generate a 34 | pulse with exactly the specified TBP, it is run in the range 0.5 * TBP to 35 | 1.5 * TBP using a scalar root search (brentq). Usually this guarantees 36 | convergence within a few tries. 37 | The larger the TBP the larger the number of points has to be. So the 38 | algorithm may fail to find a solution if pulse.N is too small. 39 | """ 40 | if edge_value is None: 41 | # this is roughly the roundoff error induced by an FFT 42 | edge_value = pulse.N * np.finfo(np.double).eps 43 | # access/calculate some fundamental grid parameters 44 | t, w = pulse.t, pulse.w 45 | t1, t2 = t[0], t[-1] 46 | w1, w2 = w[0], w[-1] 47 | t0, w0 = 0.5 * (t1 + t2), 0.5 * (w1 + w2) 48 | log_edge = np.log(edge_value) 49 | 50 | """ Calculate the width of a Gaussian function that drops exactly to 51 | edge_value at the edges of the grid. 
52 | """ 53 | spectral_width = np.sqrt(-0.125 * (w1 - w2)**2 / log_edge) 54 | # Now the same in the temporal domain 55 | max_temporal_width = np.sqrt(-0.125 * (t1 - t2)**2 / log_edge) 56 | # The actual temporal width is obtained by the uncertainty relation 57 | # from the specified TBP 58 | temporal_width = 2.0 * tbp / spectral_width 59 | 60 | if temporal_width > max_temporal_width: 61 | print("The required time-bandwidth product cannot be reached! " 62 | "Decrease edge_value or increase pulse.N!") 63 | return False 64 | 65 | # special case for TBP = 0.5 (transform-limited case) 66 | if tbp == 0.5: 67 | phase = np.exp(1.0j * lib.twopi * np.random.rand(pulse.N)) 68 | pulse.spectrum = lib.gaussian(w, w0, spectral_width) * phase 69 | return True 70 | 71 | # create the filter functions, the scaling by the number of rounds is 72 | # purely a heuristic 73 | spectral_filter = lib.gaussian(w, w0, spectral_width) 74 | 75 | """ The algorithm works by iteratively filtering in the frequency and time 76 | domain. However, the chosen filter functions only roughly give 77 | the correct TBP. To obtain the exact result we scale the temporal 78 | filter bandwidth by a factor and perform a scalar minimization on 79 | that value. 80 | """ 81 | spectrum = (np.random.rand(pulse.N) * 82 | np.exp(1j * lib.twopi * np.random.rand(pulse.N))) 83 | # rough guess for the relative range in which our optimal value lies 84 | factor_min, factor_max = 0.5, 1.5 85 | 86 | def create_pulse(factor): 87 | """ This performs the filtering. 
""" 88 | temporal_filter = lib.gaussian(t, t0, temporal_width * factor) 89 | 90 | pulse.spectrum = spectrum * spectral_filter 91 | pulse.field = pulse.field * temporal_filter 92 | 93 | def objective(factor): 94 | """ This function should be zero """ 95 | create_pulse(factor) 96 | return tbp - pulse.time_bandwidth_product 97 | 98 | # The objective function has to change sign in the bounds we chose 99 | i = 0 100 | while np.sign(objective(factor_min)) == np.sign(objective(factor_max)): 101 | # for some random arrays this condition is not always fulfilled. 102 | # just try again 103 | spectrum = (np.random.rand(pulse.N) * 104 | np.exp(1j * lib.twopi * np.random.rand(pulse.N))) 105 | if i == 10: 106 | # I have never observed this case. 107 | raise ValueError('Could not create a pulse for these parameters!') 108 | 109 | # actually perform the optimization 110 | factor = opt.brentq(objective, factor_min, factor_max) 111 | # and finally create the pulse 112 | create_pulse(factor) 113 | 114 | # The random pulse is stored in pulse 115 | return True 116 | 117 | 118 | def random_gaussian(pulse, fwhm, phase_max=0.1 * np.pi): 119 | """ Generates a Gaussian pulse with random phase. 120 | 121 | Its pulse of duration is given by ``fwhm``. 122 | """ 123 | # convert intensity fwhm to field std-dev. 124 | sigma = 0.5 * fwhm / np.sqrt(np.log(2.0)) 125 | phase = np.exp(1.0j * np.random.uniform(-phase_max, phase_max, pulse.N)) 126 | pulse.field = lib.gaussian(pulse.t, sigma=sigma) * phase 127 | return True 128 | -------------------------------------------------------------------------------- /pypret/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | """ This sub-package implements different pulse retrieval algorithms 2 | using a common interface. This facilitates direct comparison. 3 | 4 | All algorithms are implemented as a subclass of :class:`BaseRetriever`. 
The 5 | algorithms which are implemented step-by-step, i.e., do not rely on some 6 | monolithic minimization algorithm implemented elsewhere, are further 7 | subclassed from :class:`StepRetriever`. 8 | 9 | The public function to instantiate the correct retriever is :func:`Retriever`. 10 | 11 | This sub-package does not implement any form of data pre-processing. It expects 12 | correctly interpolated data in form of a MeshData object. 13 | """ 14 | from .retriever import Retriever 15 | from .step_retriever import (COPRARetriever, PCGPARetriever, GPARetriever, 16 | GPDSCANRetriever, PIERetriever) 17 | from .nlo_retriever import LMRetriever, NMRetriever, DERetriever, BFGSRetriever -------------------------------------------------------------------------------- /pypret/retrieval/nlo_retriever.py: -------------------------------------------------------------------------------- 1 | """ This module implements retrieval algorithms based on general 2 | nonlinear optimization algorithms such as Levenberg-Marquadt, 3 | differential evolution, or Nelder-Mead. 
4 | """ 5 | import numpy as np 6 | from scipy import optimize 7 | from .retriever import BaseRetriever 8 | 9 | 10 | class NLORetriever(BaseRetriever): 11 | 12 | def _scalar_objective(self, x): 13 | # rename 14 | rs = self._retrieval_state 15 | log = self.log 16 | # calculate trace error 17 | En = x.view(np.complex128) 18 | r = self._objective_function(En) 19 | R = self._Rr(r) 20 | # printing and logging 21 | if rs.nfev % 100 == 0 and self.verbose: 22 | print(rs.nfev, R) 23 | if self.logging: 24 | log.trace_error.append(R) 25 | rs.nfev += 1 26 | # normalize the error to avoid ill-scaling 27 | return r / self.Tmn_meas.max()**2 28 | 29 | def _vector_objective(self, x): 30 | # rename 31 | rs = self._retrieval_state 32 | log = self.log 33 | # calculate the error vector 34 | En = x.view(np.complex128) 35 | Tmn = self.pnps.calculate(En, self.parameter) 36 | diff = self._error_vector(Tmn, store=False) 37 | R = self._Rr(np.sum(diff * diff)) 38 | # printing and logging 39 | if rs.nfev % 100 == 0 and self.verbose: 40 | print(rs.nfev, R) 41 | if self.logging: 42 | log.trace_error.append(R) 43 | rs.nfev += 1 44 | # normalize the error vector to avoid ill-scaling 45 | return diff / np.sqrt(self.M * self.N) / self.Tmn_meas.max() 46 | 47 | def _retrieve_begin(self, measurement, initial_guess, weights): 48 | super()._retrieve_begin(measurement, initial_guess, weights) 49 | rs = self._retrieval_state 50 | rs.nfev = 0 51 | 52 | def _retrieve_end(self): 53 | super()._retrieve_end() 54 | self._result.nfev = self._retrieval_state.nfev 55 | 56 | def result(self, pulse_original=None, full=True): 57 | res = super().result(pulse_original=pulse_original, full=full) 58 | res.nfev = self._retrieval_state.nfev 59 | return res 60 | 61 | 62 | class LMRetriever(NLORetriever): 63 | """ Implements pulse retrieval based on the Levenberg-Marquadt algorithm. 64 | 65 | This is an efficient nonlinear least-squares solver, however, it will still 66 | be *very* slow for large pulses (N > 256). 
The reason is that the 67 | (MN x N) Jacobian is evaluated using numerical differentiation. 68 | 69 | The recommendation is to use this method either on small problems or to 70 | refine or verify solutions provided by a different algorithm. 71 | """ 72 | method = "lm" 73 | 74 | def __init__(self, pnps, ftol=1e-08, xtol=1e-08, gtol=1e-08, lm_verbose=0, 75 | **kwargs): 76 | """ For a full documentation of the arguments see :class:`Retriever`. 77 | 78 | For the documentation of `ftol`, `xtol`, `gtol` see the documentation 79 | of :func:`scipy.optimize.least_squares`. They are passed directly 80 | to the optimizer. If you want to run the optimizer for a fixed 81 | number of iterations, set all values to 1e-14 to effectively 82 | disable the stopping criteria. 83 | """ 84 | super().__init__(pnps, ftol=ftol, xtol=xtol, gtol=gtol, 85 | lm_verbose=lm_verbose, **kwargs) 86 | 87 | def _retrieve(self): 88 | # local rename 89 | o = self.options 90 | res = self._result 91 | # store current guess in attribute 92 | spectrum = self.initial_guess.copy() 93 | # This algorithm is not robust against the scaling of the input vector! 94 | spectrum /= np.abs(spectrum).max() 95 | x0 = spectrum.view(np.float64).copy() 96 | # calculate the maximum number of function evaluations 97 | max_nfev = None 98 | if o.maxfev is not None: 99 | max_nfev = o.maxfev // x0.shape[0] 100 | optres = optimize.least_squares( 101 | self._vector_objective, 102 | x0, 103 | method='trf', 104 | jac='2-point', 105 | max_nfev=max_nfev, 106 | tr_solver='exact', 107 | ftol=o.ftol, 108 | gtol=o.gtol, 109 | xtol=o.xtol, 110 | verbose=o.lm_verbose 111 | ) 112 | res.approximate_error = False 113 | res.spectrum = optres.x.view(dtype=np.complex128) 114 | res.trace_error = self.trace_error(res.spectrum) 115 | res.approximate_error = False 116 | return res.spectrum 117 | 118 | 119 | class DERetriever(NLORetriever): 120 | """ This retriever uses the gradient-free differential evolution algorithm. 
121 | 122 | It tries to match the parameters described in [Escoto2018]_ as far as 123 | they are mentioned. No further effort was made to optimize them. If you 124 | are interested in using DE as a pulse retrieval algorithm you are 125 | advised to study the documentation at 126 | :func:`scipy.optimize.differential_evolution`. 127 | 128 | The initial population in our implementation is based on the provided guess 129 | with added complex, Gaussian noise of 5% of the maximum amplitude. 130 | In our tests we saw no convergence when starting from completely 131 | random initial guesses. 132 | """ 133 | method = "de" 134 | 135 | def _retrieve(self): 136 | # local rename 137 | o = self.options 138 | res = self._result 139 | # calculate the maximum number of function evaluations 140 | max_nfev = None 141 | if o.maxfev is not None: 142 | max_nfev = int(round(o.maxfev / 10 - 1)) 143 | # generate initial population 144 | init = [self.initial_guess.view(np.float64).copy()] 145 | amp = np.abs(self.initial_guess).max() 146 | for i in range(9): 147 | sol = (self.initial_guess + 148 | 0.05 * amp * np.random.normal(size=self.N) + 149 | 0.05j * amp * np.random.normal(size=self.N)) 150 | init.append(sol.view(np.float64)) 151 | optres = optimize.differential_evolution( 152 | self._scalar_objective, 153 | [(-1.0, 1.0) for i in range(2 * self.N)], 154 | strategy='rand1bin', 155 | maxiter=max_nfev, 156 | recombination=0.5, 157 | popsize=10, # is overwritten by init 158 | tol=1e-7, 159 | polish=False, 160 | init=np.array(init) 161 | ) 162 | res.approximate_error = False 163 | res.spectrum = optres.x.view(dtype=np.complex128) 164 | res.trace_error = self.trace_error(res.spectrum) 165 | res.approximate_error = False 166 | return res.spectrum 167 | 168 | 169 | class NMRetriever(NLORetriever): 170 | """ This retriever uses the gradient-free Nelder-Mead algorithm. 
171 | """ 172 | method = "nm" 173 | 174 | def _retrieve(self): 175 | # local rename 176 | o = self.options 177 | res = self._result 178 | # store current guess in attribute 179 | spectrum = self.initial_guess.copy() 180 | # This algorithm is not robust against the scaling of the input vector! 181 | spectrum /= np.abs(spectrum).max() 182 | x0 = spectrum.view(np.float64).copy() 183 | # calculate the maximum number of function evaluations 184 | max_nfev = None 185 | if o.maxfev is not None: 186 | max_nfev = o.maxfev 187 | optres = optimize.minimize( 188 | self._scalar_objective, 189 | x0, 190 | method='Nelder-Mead', 191 | options={'maxfev': max_nfev}, 192 | ) 193 | res.approximate_error = False 194 | res.spectrum = optres.x.view(dtype=np.complex128) 195 | res.trace_error = self.trace_error(res.spectrum) 196 | res.approximate_error = False 197 | return res.spectrum 198 | 199 | 200 | class BFGSRetriever(NLORetriever): 201 | """ This retriever uses the BFGS algorithm with numerical differentiation. 202 | """ 203 | method = "bfgs" 204 | 205 | def _retrieve(self): 206 | # local rename 207 | o = self.options 208 | res = self._result 209 | # store current guess in attribute 210 | spectrum = self.initial_guess.copy() 211 | # This algorithm is not robust against the scaling of the input vector! 
212 | spectrum /= np.abs(spectrum).max() 213 | x0 = spectrum.view(np.float64).copy() 214 | # calculate the maximum number of function evaluations 215 | max_nfev = None 216 | if o.maxfev is not None: 217 | max_nfev = o.maxfev // x0.shape[0] 218 | optres = optimize.minimize( 219 | self._scalar_objective, 220 | x0, 221 | method='BFGS', 222 | options={'maxiter': max_nfev}, 223 | ) 224 | res.approximate_error = False 225 | res.spectrum = optres.x.view(dtype=np.complex128) 226 | res.trace_error = self.trace_error(res.spectrum) 227 | res.approximate_error = False 228 | return res.spectrum 229 | -------------------------------------------------------------------------------- /pypret/retrieval/retriever.py: -------------------------------------------------------------------------------- 1 | """ This module provides the basic classes for the pulse retrieval algorithms. 2 | """ 3 | import numpy as np 4 | from types import SimpleNamespace 5 | from .. import io 6 | from ..mesh_data import MeshData 7 | from ..pulse_error import pulse_error 8 | from .. import lib 9 | from ..pnps import BasePNPS 10 | 11 | # global dictionary that contains all PNPS classes 12 | _RETRIEVER_CLASSES = {} 13 | 14 | 15 | # ============================================================================= 16 | # Metaclass and factory 17 | # ============================================================================= 18 | class MetaRetriever(type): 19 | """ Metaclass that registers Retriever classes in a global dictionary. 20 | """ 21 | def __new__(cls, clsmethod, bases, attrs): 22 | global _RETRIEVER_CLASSES 23 | newclass = super().__new__(cls, clsmethod, bases, attrs) 24 | method = newclass.method 25 | if method is None: 26 | return newclass 27 | # register the Retriever method, e.g. 'copra' 28 | if method in _RETRIEVER_CLASSES: 29 | raise ValueError("Two retriever classes implement retriever '%s'." 
30 | % method) 31 | _RETRIEVER_CLASSES[method] = newclass 32 | return newclass 33 | 34 | 35 | class MetaIORetriever(io.MetaIO, MetaRetriever): 36 | # to fix metaclass conflicts 37 | pass 38 | 39 | 40 | # ============================================================================= 41 | # Retriever Base class 42 | # ============================================================================= 43 | class BaseRetriever(io.IO, metaclass=MetaIORetriever): 44 | """ The abstract base class for pulse retrieval. 45 | 46 | This class implements common functionality for different retrieval 47 | algorithms. 48 | """ 49 | method = None 50 | supported_schemes = None 51 | _io_store = ['pnps', 'options', 'logging', 'log', 52 | '_retrieval_state', '_result'] 53 | 54 | def __init__(self, pnps, logging=False, verbose=False, **kwargs): 55 | self.pnps = pnps 56 | self.ft = self.pnps.ft 57 | self.options = SimpleNamespace(**kwargs) 58 | self._result = None 59 | self.logging = logging 60 | self.verbose = verbose 61 | self.log = None 62 | rs = self._retrieval_state = SimpleNamespace() 63 | rs.running = False 64 | if (self.supported_schemes is not None and 65 | pnps.scheme not in self.supported_schemes): 66 | raise ValueError("Retriever '%s' does not support scheme '%s'. " 67 | "It only supports %s." % 68 | (self.method, pnps.scheme, self.supported_schemes) 69 | ) 70 | 71 | def retrieve(self, measurement, initial_guess, weights=None, 72 | **kwargs): 73 | """ Retrieve pulse from ``measurement`` starting at ``initial_guess``. 74 | 75 | Parameters 76 | ---------- 77 | measurement : MeshData 78 | A MeshData instance that contains the PNPS measurement. The first 79 | axis has to correspond to the PNPS parameter, the second to the 80 | frequency. The data has to be the measured _intensity_ over the 81 | frequency (not wavelength!). The second axis has to match exactly 82 | the frequency axis of the underlying PNPS instance. No 83 | interpolation is done. 
84 | initial_guess : 1d-array 85 | The spectrum of the pulse that is used as initial guess in the 86 | iterative retrieval. 87 | weights : 1d-array 88 | Weights that are attributed to the measurement for retrieval. 89 | In the case of (assumed) Gaussian uncertainties with standard 90 | deviation sigma they should correspond to 1/sigma. 91 | Not all algorithms support using the weights. 92 | kwargs : dict 93 | Can override retrieval options specified in :func:`__init__`. 94 | 95 | Notes 96 | ----- 97 | This function provides no interpolation or data processing. You have 98 | to write a retriever wrapper for that purpose. 99 | """ 100 | self.options.__dict__.update(**kwargs) 101 | if not isinstance(measurement, MeshData): 102 | raise ValueError("measurement has to be a MeshData instance!") 103 | self._retrieve_begin(measurement, initial_guess, weights) 104 | self._retrieve() 105 | self._retrieve_end() 106 | 107 | def _retrieve_begin(self, measurement, initial_guess, weights): 108 | pnps = self.pnps 109 | if not np.all(pnps.process_w == measurement.axes[1]): 110 | raise ValueError("Measurement has to lie on simulation grid!") 111 | # Store measurement 112 | self.measurement = measurement 113 | self.parameter = measurement.axes[0] 114 | self.Tmn_meas = measurement.data 115 | 116 | self.initial_guess = initial_guess 117 | # set the size 118 | self.M, self.N = self.Tmn_meas.shape 119 | # Setup the weights 120 | if weights is None: 121 | self._weights = np.ones((self.M, self.N)) 122 | else: 123 | self._weights = weights.copy() 124 | # Retrieval state 125 | rs = self._retrieval_state 126 | rs.approximate_error = False 127 | rs.running = True 128 | rs.steps_since_improvement = 0 129 | # Initialize result 130 | res = self._result = SimpleNamespace() 131 | res.trace_error = self.trace_error(self.initial_guess) 132 | res.approximate_error = False 133 | res.spectrum = self.initial_guess.copy() 134 | # Setup the logger 135 | if self.logging: 136 | log = self.log = 
SimpleNamespace() 137 | log.trace_error = [] 138 | log.initial_guess = self.initial_guess.copy() 139 | else: 140 | self.log = None 141 | if self.verbose: 142 | print("Started retriever '%s'" % self.method) 143 | print("Options:") 144 | print(self.options) 145 | print("Initial trace error R = {:.10e}".format(res.trace_error)) 146 | print("Starting retrieval...") 147 | print() 148 | 149 | def _retrieve_end(self): 150 | rs = self._retrieval_state 151 | rs.running = False 152 | res = self._result 153 | if res.approximate_error: 154 | res.trace_error = self.trace_error(res.spectrum) 155 | res.approximate_error = False 156 | 157 | def _project(self, measured, Smk): 158 | """ Performs the projection on the measured intensity. 159 | """ 160 | # in frequency domain 161 | Smn = self.ft.forward(Smk) 162 | # project and specially handle values with zero amplitude 163 | absSmn = np.abs(Smn) 164 | f = (absSmn > 0.0) 165 | Smn[~f] = np.sqrt(measured[~f] + 0.0j) 166 | Smn[f] = Smn[f] / absSmn[f] * np.sqrt(measured[f] + 0.0j) 167 | # back in time domain 168 | Smk2 = self.ft.backward(Smn) 169 | return Smk2 170 | 171 | def _objective_function(self, spectrum): 172 | """ Calculates the minimization objective from the pulse spectrum. 173 | 174 | This is Eq. 11 in the paper: 175 | 176 | r = sum (Tmn^meas - mu * Tmn) 177 | """ 178 | # calculate the PNPS trace 179 | Tmn = self.pnps.calculate(spectrum, self.parameter) 180 | return self._r(Tmn) 181 | 182 | def trace_error(self, spectrum, store=True): 183 | """ Calculates the trace error from the pulse spectrum. 184 | """ 185 | Tmn = self.pnps.calculate(spectrum, self.parameter) 186 | return self._R(Tmn, store=store) 187 | 188 | def _r(self, Tmn, store=True): 189 | """ Calculates the minimization objective r from a simulated trace Tmn. 
190 | """ 191 | diff = self._error_vector(Tmn, store=store) 192 | return np.sum(diff * diff) 193 | 194 | def _error_vector(self, Tmn, store=True): 195 | """ Calculates the residual vector from measured to simulated 196 | intensity. 197 | """ 198 | # rename 199 | rs = self._retrieval_state 200 | Tmn_meas = self.Tmn_meas 201 | # scaling factor 202 | w2 = self._weights * self._weights 203 | mu = np.sum(Tmn_meas * Tmn * w2) / np.sum(Tmn * Tmn * w2) 204 | # store intermediate results in current retrieval state 205 | if store: 206 | rs.mu = mu 207 | rs.Tmn = Tmn 208 | rs.Smk = self.pnps.Smk 209 | return np.ravel((Tmn_meas - mu * Tmn) * self._weights) 210 | 211 | def _R(self, Tmn, store=True): 212 | """ Calculates the trace error from a simulated trace Tmn. 213 | """ 214 | r = self._r(Tmn, store=store) 215 | return self._Rr(r) 216 | 217 | def _Rr(self, r): 218 | """ Calculates the trace error from the minimization objective r. 219 | """ 220 | return np.sqrt(r / (self.M * self.N * 221 | (self.Tmn_meas * self._weights).max()**2)) 222 | 223 | def result(self, pulse_original=None, full=True): 224 | """ Analyzes the retrieval results in one retrieval instance 225 | and processes it for plotting or storage. 
226 | """ 227 | rs = self._retrieval_state 228 | if self._result is None or self._retrieval_state.running: 229 | return None 230 | res = SimpleNamespace() 231 | # the meta data 232 | res.parameter = self.parameter 233 | res.options = self.options 234 | res.logging = self.logging 235 | res.measurement = self.measurement 236 | # store the retriever itself 237 | if full: 238 | res.pnps = self.pnps 239 | else: 240 | res.pnps = None 241 | 242 | # the pulse spectra 243 | # 1 - the retrieved pulse 244 | res.pulse_retrieved = self._result.spectrum 245 | # 2 - the original test pulse, optional 246 | res.pulse_original = pulse_original 247 | # 3 - the initial guess 248 | res.pulse_initial = self.initial_guess 249 | 250 | # the measurement traces 251 | # 1 - the original data used for retrieval 252 | res.trace_input = self.Tmn_meas 253 | # 2 - the trace error and the trace calculated from the retrieved pulse 254 | res.trace_error = self.trace_error(res.pulse_retrieved) 255 | res.trace_retrieved = rs.mu * rs.Tmn 256 | res.response_function = rs.mu 257 | # the weights 258 | res.weights = self._weights 259 | 260 | # this is set if the original spectrum is provided 261 | if res.pulse_original is not None: 262 | # the trace error of the test pulse (non-zero for noisy input) 263 | res.trace_error_optimal = self.trace_error(res.pulse_original) 264 | # 3 - the optimal trace calculated from the test pulse 265 | res.trace_original = rs.mu * rs.Tmn 266 | dot_ambiguity = False 267 | if self.pnps.method == "ifrog" or self.pnps.scheme == "shg-frog": 268 | dot_ambiguity = True 269 | # the pulse error to the test pulse 270 | res.pulse_error, res.pulse_retrieved = pulse_error( 271 | res.pulse_retrieved, res.pulse_original, self.ft, 272 | dot_ambiguity=dot_ambiguity) 273 | 274 | if res.logging: 275 | # the logged trace errors 276 | res.trace_errors = np.array(self.log.trace_error) 277 | # the running minimum of the trace errors (for plotting) 278 | res.rm_trace_errors = 
np.minimum.accumulate(res.trace_errors, 279 | axis=-1) 280 | if self.verbose: 281 | lib.retrieval_report(res) 282 | return res 283 | 284 | 285 | def Retriever(pnps: BasePNPS, method: str = "copra", maxiter=300, maxfev=None, 286 | logging=False, verbose=False, **kwargs) -> BaseRetriever: 287 | """ Creates a retriever instance. 288 | 289 | Parameters 290 | ---------- 291 | pnps : PNPS 292 | A PNPS instance that is used to simulate a PNPS measurement. 293 | method : str, optional 294 | Type of solver. Should be one of 295 | - 'copra' :class:`(see here) ` 296 | - 'gpa' :class:`(see here) ` 297 | - 'gp-dscan' :class:`(see here) ` 298 | - 'pcgpa' :class:`(see here) ` 299 | - 'pie' :class:`(see here) ` 300 | - 'lm' :class:`(see here) ` 301 | - 'bfgs' :class:`(see here) ` 302 | - 'de' :class:`(see here) ` 303 | - 'nelder-mead' :class:`(see here) ` 304 | 305 | 'copra' is the default choice. 306 | maxiter : int, optional 307 | The maximum number of algorithm iterations. The default is 300. 308 | maxfev : int, optional 309 | The maximum number of function evaluations. If given, the algorithms 310 | stop before this number is reached. Not all algorithms support this 311 | feature. Default is ``None``, in which case it is ignored. 312 | logging : bool, optional 313 | Stores trace errors and pulses over the iterations if supported 314 | by the retriever class. Default is `False`. 315 | verbose : bool, optional 316 | Prints out trace errors during the iteration if supported by the 317 | retriever class. Default is `False`. 318 | """ 319 | method = method.lower() 320 | try: 321 | cls = _RETRIEVER_CLASSES[method] 322 | except KeyError: 323 | raise ValueError("Retriever '%s' is unknown!" 
""" This module implements specific pulse retrieval algorithms, e.g.,
COPRA, GPA, PCGPA, etc.
"""
import numpy as np
from scipy.optimize import minimize_scalar

from .retriever import BaseRetriever
from .. import lib


class StepRetriever(BaseRetriever):
    """ Base class for retrievers that iterate a single, self-contained
    retrieval step (COPRA, GPA, PCGPA, PIE, ...).

    Subclasses implement ``_retrieve_step`` which receives the current
    spectrum and returns the trace error of that (old) spectrum together
    with an updated spectrum.
    """

    def _retrieve(self):
        """ Runs the main retrieval loop.

        Returns the best spectrum found so far. For a more detailed
        analysis call ``self.result()`` afterwards.
        """
        # local rename
        o = self.options
        log = self.log
        res = self._result
        rs = self._retrieval_state
        # store current guess in attribute
        spectrum = self.initial_guess.copy()
        # initialize R
        R = res.trace_error
        for i in range(o.maxiter):
            # store trace error and spectrum for later analysis
            if self.logging:
                # if the trace error was only approximated, calculate it here
                if rs.approximate_error:
                    R = self.trace_error(spectrum, store=False)
                log.trace_error.append(R)
            # Perform a single retriever step in one of the algorithms.
            # NOTE: R is calculated for the *input*, i.e., the old spectrum.
            R, new_spectrum = self._retrieve_step(i, spectrum.copy())
            # update the solution if the result is better
            if R < res.trace_error:
                res.trace_error = R
                res.approximate_error = rs.approximate_error
                res.spectrum[:] = spectrum  # store the old spectrum
                rs.steps_since_improvement = 0
            else:
                rs.steps_since_improvement += 1
            # accept the new spectrum for the next iteration
            spectrum[:] = new_spectrum
            if self.verbose:
                if i == 0:
                    print("iteration".ljust(10) + "trace error".ljust(20))
                s = "{:d}".format(i + 1).ljust(10)
                if rs.approximate_error:
                    s += "~"
                else:
                    s += " "
                s += "{:.10e}".format(R)
                if R == res.trace_error:
                    s += "*"
                print(s)
            if not rs.running:
                break
        if self.verbose:
            print()
            print("~ approximate trace error")
            print("* accepted as best trace error")
            print()

        # return the retrieved spectrum
        # for a more detailed analysis call self.result()
        return res.spectrum


class COPRARetriever(StepRetriever):
    """ This module implements the common pulse retrieval algorithm
    [Geib2019]_.
    """
    method = "copra"

    def __init__(self, pnps, alpha=0.25, **kwargs):
        """ For a full documentation of the arguments see :class:`Retriever`.

        Parameters
        ----------
        alpha : float, optional
            Scales the step size in the global stage of COPRA. Higher values
            mean potentially faster convergence but less accuracy. Lower
            values provide higher accuracy for the cost of speed. Default is
            0.25.
        """
        super().__init__(pnps, alpha=alpha, **kwargs)

    def _retrieve_begin(self, measurement, initial_guess, weights):
        super()._retrieve_begin(measurement, initial_guess, weights)
        pnps = self.pnps
        rs = self._retrieval_state
        rs.mode = "local"  # COPRA starts with local mode
        # calculate the maximum gradient norm
        # self.trace_error() was called beforehand -> rs.Tmn and rs.Smk exist!
        Smk2 = self._project(self.Tmn_meas / rs.mu, rs.Smk)
        nablaZnm = pnps.gradient(Smk2, self.parameter)
        rs.current_max_gradient = np.max(np.sum(lib.abs2(nablaZnm), axis=1))

    def _retrieve_step(self, iteration, En):
        """ Perform a single COPRA step.

        Parameters
        ----------
        iteration : int
            The current iteration number - mainly for logging.
        En : 1d-array
            The current pulse spectrum.

        Returns
        -------
        (float, 1d-array)
            The trace error of the input spectrum (possibly approximate in
            local mode) and the updated spectrum.
        """
        # local rename
        ft = self.ft
        options = self.options
        pnps = self.pnps
        rs = self._retrieval_state
        Tmn_meas = self.Tmn_meas
        # current gradient -> last gradient
        rs.previous_max_gradient = rs.current_max_gradient
        rs.current_max_gradient = 0.0
        # switch to the global stage after 5 stagnating iterations
        if rs.steps_since_improvement == 5:
            rs.mode = "global"
        # local iteration
        if rs.mode == "local":
            # running estimate for the trace
            Tmn = np.zeros((self.M, self.N))
            for m in np.random.permutation(np.arange(self.M)):
                p = self.parameter[m]
                Tmn[m, :] = pnps.calculate(En, p)
                Smk2 = self._project(Tmn_meas[m, :] / rs.mu, pnps.Smk)
                nablaZnm = pnps.gradient(Smk2, p)
                # calculate the step size
                Zm = lib.norm2(Smk2 - pnps.Smk)
                gradient_norm = lib.norm2(nablaZnm)
                if gradient_norm > rs.current_max_gradient:
                    rs.current_max_gradient = gradient_norm
                gamma = Zm / max(rs.current_max_gradient,
                                 rs.previous_max_gradient)
                # update the spectrum
                En -= gamma * nablaZnm
            # Tmn is only an approximation as En changed in the iteration!
            rs.approximate_error = True
            R = self._R(Tmn)  # updates rs.mu!!!
        # global iteration
        elif rs.mode == "global":
            Tmn = pnps.calculate(En, self.parameter)
            r = self._r(Tmn)
            R = self._Rr(r)  # updates rs.mu!!!
            rs.approximate_error = False
            # gradient descent w.r.t. Smk
            w2 = self._weights * self._weights
            gradrmk = (-4 * ft.dt / (ft.dw * lib.twopi) *
                       ft.backward(rs.mu * ft.forward(pnps.Smk) *
                                   (Tmn_meas - rs.mu * Tmn) * w2))
            etar = options.alpha * r / lib.norm2(gradrmk)
            Smk2 = pnps.Smk - etar * gradrmk
            # gradient descent w.r.t. En
            nablaZn = pnps.gradient(Smk2, self.parameter).sum(axis=0)
            # calculate the step size
            Z = lib.norm2(Smk2 - pnps.Smk)
            etaz = options.alpha * Z / lib.norm2(nablaZn)
            # update the spectrum
            En -= etaz * nablaZn
        return R, En


class PCGPARetriever(StepRetriever):
    """ This class implements the principal components generalized projections
    algorithm (PCGPA) for SHG-FROG.

    We follow the algorithm as described in [Kane1999]_ but use the PNPS
    formalism from [Geib2019]_ and some minor modifications:

    - it supports both the singular value decomposition and the power
      method to find/approximate the largest eigenvector.
    - the projection includes the scaling factor µ. This makes the method
      robust against initial guesses with the wrong magnitude. It should
      have no adverse effect.
    """
    method = "pcgpa"
    supported_schemes = ["shg-frog"]

    def __init__(self, pnps, decomposition="power", **kwargs):
        """ For a full documentation of the arguments see :class:`Retriever`.

        Parameters
        ----------
        decomposition : str, optional
            It specifies how the FROG signal is decomposed. If `power` (the
            default) the power method is used to find the largest eigenvalue.
            If `svd` a full singular value decomposition is performed. This
            is potentially much slower but more accurate.
        """
        super().__init__(pnps, decomposition=decomposition, **kwargs)

    def _retrieve_begin(self, measurement, initial_guess, weights):
        super()._retrieve_begin(measurement, initial_guess, weights)
        if np.any(self.parameter != measurement.axes[0]):
            raise ValueError("The delay has to be sampled exactly at the "
                             "temporal simulation grid.")

    def _retrieve_step(self, iteration, En):
        """ Perform a single PCGPA step. """
        # local rename
        ft = self.ft
        options = self.options
        pnps = self.pnps
        Tmn_meas = self.Tmn_meas
        rs = self._retrieval_state

        R = self.trace_error(En)  # updates rs.mu!!!
        # project on measured intensity
        Smk2 = self._project(Tmn_meas / rs.mu, pnps.Smk)
        # to outer product form
        for n in range(ft.N):
            Smk2[:, n] = np.roll(Smk2[::-1, n], n)
        if options.decomposition == "power":
            # apply power method iteration once
            Ek = ft.backward(En)
            Ek[:] = Smk2.conj() @ Ek
            Ek[:] = Ek.conj() / lib.norm(Ek)
        elif options.decomposition == "svd":
            # use full svd (slow!)
            U, s, V = np.linalg.svd(Smk2)
            Ek = U[:, 0] * np.sqrt(s[0])  # select U
        ft.forward(Ek, out=En)
        return R, En


class GPARetriever(StepRetriever):
    """ Implements the classical generalized projections algorithm for
    SHG-FROG as described in [DeLong1994]_ and [Trebino2000]_.

    As far as I know the determination of the step size in GPA is not
    made explicit in the publications. It is usually done in a line search.
    In this implementation we offer three different options:

    - an exact line search using a Brent style minimizer
    - a backtracking (inexact) line search using the Armijo-Goldstein
      condition with c=0.5 and tau=0.5.
    - the same heuristic choice for the step size used in copra.

    The last method is the fastest, but as the first is the classic choice
    for GPA, it is the default.
    """
    method = "gpa"
    supported_schemes = ["shg-frog"]

    def __init__(self, pnps, step_size="exact", **kwargs):
        """ For a full documentation of the arguments see :class:`Retriever`.

        Parameters
        ----------
        step_size : str, optional
            Specifies how the step size of the gradient step in GPA is
            determined. Default is `exact` which performs an exact line search.
            `inexact` performs a backtracking line search and `copra` uses the
            ad-hoc estimates for the step size used in COPRA.
        """
        super().__init__(pnps, step_size=step_size, **kwargs)

    def _retrieve_begin(self, measurement, initial_guess, weights=None):
        # FIX: accept `weights` for consistency with the base class and the
        # other retrievers (COPRA/PCGPA); the previous signature omitted it,
        # which breaks a positional three-argument call. Forward it only when
        # given to keep the original two-argument behavior unchanged.
        if weights is None:
            super()._retrieve_begin(measurement, initial_guess)
        else:
            super()._retrieve_begin(measurement, initial_guess, weights)
        if np.any(self.parameter != measurement.axes[0]):
            raise ValueError("The delay has to be sampled exactly at the "
                             "temporal simulation grid.")

    def _retrieve_step(self, iteration, En):
        """ Perform a single GPA step. """
        # local rename
        ft = self.ft
        options = self.options
        pnps = self.pnps
        Tmn_meas = self.Tmn_meas
        rs = self._retrieval_state

        R = self.trace_error(En)  # updates rs.mu!!!
        # obtain intermediate results
        delay, Amk, Ek, Smk, Tmn = pnps.intermediate(self.parameter)
        Ek = Ek[0, :]  # the same for every parameter
        # project on measured intensity
        Smk2 = self._project(Tmn_meas / rs.mu, Smk)
        # calculate the gradient of Z w.r.t. the temporal pulse envelope
        # by directly implementing (S58) of [Geib2019]_
        dS = Smk2 - Smk
        indices = np.array(np.rint(self.parameter / ft.dt), dtype=np.int32)
        gradient = np.zeros((self.M, self.N), dtype=np.complex128)
        for m in range(self.M):
            gradient[m, :] = np.roll(dS[m, :] * Ek.conj(), -indices[m])
        gradient = -2 * np.sum(gradient + dS * Amk.conj(), axis=0)

        # approximate the step size with the "copra" step size
        gamma0 = lib.norm2(dS) / lib.norm2(gradient)
        if options.step_size == "copra":
            # directly choose the copra step size
            gamma = gamma0
        else:
            # do a line search
            def objective(gamma):
                # trace_error() recalculates pnps.Smk as a side effect
                self.trace_error(ft.forward(Ek - gamma * gradient))
                return lib.norm2(Smk2 - pnps.Smk)
            if options.step_size == "exact":
                # perform an exact line search
                bracket = [0.9 * gamma0, gamma0]
                ret = minimize_scalar(objective,
                                      bracket=bracket,
                                      method="brent")
                gamma = ret.x
            elif options.step_size == "inexact":
                # perform a back-tracking line search until the Armijo
                # condition is fulfilled.
                t = 0.5 * lib.norm2(gradient)
                tau = 0.5
                if iteration == 0:
                    rs.old_gamma = 5.0 * gamma0
                gamma = 2 * rs.old_gamma
                objective0 = objective(0.0)
                while objective(gamma) - objective0 > -gamma * t:
                    gamma = gamma * tau
                rs.old_gamma = gamma
        Ek = Ek - gamma * gradient

        ft.forward(Ek, out=En)
        return R, En


class GPDSCANRetriever(StepRetriever):
    """ This class implements a pulse retrieval algorithm for SHG and THG
    d-scan based on the paper [Miranda2017]_.

    In our tests we found that it does not converge in the noiseless case.
    In other words the global solution to the least-squares problem is not a
    fixed point of the iteration.
    """
    supported_schemes = ["shg-dscan", "thg-dscan"]
    method = "gp-dscan"

    def _retrieve_begin(self, measurement, initial_guess, weights=None):
        # FIX: accept `weights` for consistency with the base class and the
        # other retrievers; forward it only when given to keep the original
        # two-argument behavior unchanged.
        if weights is None:
            super()._retrieve_begin(measurement, initial_guess)
        else:
            super()._retrieve_begin(measurement, initial_guess, weights)
        pnps = self.pnps
        rs = self._retrieval_state
        # calculate phase mask once
        rs.Hmn = np.zeros((self.M, self.N), dtype=np.complex128)
        for i, p in enumerate(self.parameter):
            rs.Hmn[i, :] = pnps.mask(p)

    def _retrieve_step(self, iteration, En):
        """ Perform a single GP-DSCAN step. """
        # local rename
        ft = self.ft
        pnps = self.pnps
        Tmn_meas = self.Tmn_meas
        rs = self._retrieval_state
        Hmn = rs.Hmn

        R = self.trace_error(En)  # updates rs.mu!!!
        # project on measured intensity
        Smk2 = self._project(Tmn_meas / rs.mu, pnps.Smk)
        # modify En
        if pnps.process == "shg":
            Smk2 *= ft.backward(Hmn * En).conj()
            Smk2 /= np.abs(Smk2)**(2/3)
            En = np.sum(Hmn.conj() * ft.forward(Smk2), axis=0)
        elif pnps.process == "thg":
            Smk2 *= ft.backward(Hmn * En).conj()**2
            Smk2 /= np.abs(Smk2)**(4/5)
            En = np.sum(Hmn.conj() * ft.forward(Smk2), axis=0)
        return R, En


class PIERetriever(StepRetriever):
    """ This class implements a pulse retrieval algorithm for SHG-FROG based on
    the ptychographical iterative engine (PIE). It is based on the paper
    [Sidorenko2016]_ and its erratum [Sidorenko2017]_.

    We modified the algorithm to include the scaling factor µ in the
    projection. This makes the method robust against initial guesses with the
    wrong magnitude. It should have no adverse effect.
    """
    method = "pie"
    supported_schemes = ["shg-frog"]

    def _retrieve_step(self, iteration, En):
        """ Perform a single PIE step. """
        # local rename
        ft = self.ft
        pnps = self.pnps
        rs = self._retrieval_state
        Tmn_meas = self.Tmn_meas

        # running estimate for the trace
        Tmn = np.zeros((self.M, self.N))
        # random choice for the step size scaling
        beta = np.random.uniform(0.1, 0.5)
        for m in np.random.permutation(np.arange(self.M)):
            p = self.parameter[m]
            Tmn[m, :] = pnps.calculate(En, p)
            # get intermediate results from private attribute
            delay, Amk, Ek, Smk, _ = pnps._tmp[p]
            # project
            Smk2 = self._project(Tmn_meas[m, :] / rs.mu, Smk)
            # perform update
            Ek += beta * Amk.conj() * (Smk2 - Smk) / lib.abs2(Ek).max()
            # update the spectrum
            ft.forward(Ek, out=En)

        # Tmn is only an approximation as En changed in the iteration!
        rs.approximate_error = True
        R = self._R(Tmn)  # updates rs.mu!!!

        return R, En
https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/dscan-thg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/dscan-thg-gp-dscan-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/dscan-thg-gp-dscan-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/dscan-thg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/dscan-thg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/frog-pg-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/frog-pg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/frog-pg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/frog-pg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/frog-shg-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/frog-shg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/frog-shg-gpa-retrieved.hdf5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/frog-shg-gpa-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/frog-shg-pcgpa-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/frog-shg-pcgpa-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/frog-shg-pie-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/frog-shg-pie-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/frog-shg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/frog-shg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/ifrog-sd-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/ifrog-sd-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/ifrog-sd-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/ifrog-sd-trace.hdf5 -------------------------------------------------------------------------------- 
/pypret/tests/data/ifrog-shg-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/ifrog-shg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/ifrog-shg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/ifrog-shg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/ifrog-thg-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/ifrog-thg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/ifrog-thg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/ifrog-thg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/initial.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/initial.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/miips-sd-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/miips-sd-copra-retrieved.hdf5 
-------------------------------------------------------------------------------- /pypret/tests/data/miips-sd-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/miips-sd-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/miips-shg-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/miips-shg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/miips-shg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/miips-shg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/miips-thg-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/miips-thg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/miips-thg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/miips-thg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/pulse.hdf5: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/pulse.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/tdp-shg-copra-retrieved.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/tdp-shg-copra-retrieved.hdf5 -------------------------------------------------------------------------------- /pypret/tests/data/tdp-shg-trace.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/pypret/tests/data/tdp-shg-trace.hdf5 -------------------------------------------------------------------------------- /pypret/tests/test_fourier.py: -------------------------------------------------------------------------------- 1 | """ This module tests the Fourier implementation. 2 | 3 | Author: Nils Geib, nils.geib@uni-jena.de 4 | """ 5 | import numpy as np 6 | from pypret.fourier import FourierTransform, Gaussian 7 | from pypret import lib 8 | 9 | 10 | def test_gaussian_transformation(): 11 | """ This test compares the numerical approximation of the Fourier transform 12 | to the analytic solution for a Gaussian function. It uses non-centered 13 | grids and a non-centered Gaussian on purpose. 
14 | """ 15 | # define the grid parameters 16 | # choose some arbitrary values to break symmetries 17 | dt = 0.32 18 | N = 205 19 | dw = np.pi / (0.5 * N * dt) 20 | t0 = -(N//2 + 2.1323) * dt 21 | w0 = -(N//2 - 1.23) * dw 22 | # and actually create it 23 | ft = FourierTransform(N, dt, t0=t0, w0=w0) 24 | # create and calculate a non-centered Gaussian distribution 25 | gaussian = Gaussian(10 * dt, 0.1 * t0, 0.12 * w0) 26 | temporal0 = gaussian.temporal(ft.t) 27 | spectral0 = gaussian.spectral(ft.w) 28 | 29 | # calculate the numerical approximations 30 | spectral1 = ft.forward(temporal0) 31 | temporal1 = ft.backward(spectral0) 32 | 33 | temporal_error = lib.nrms(temporal1, temporal0) 34 | spectral_error = lib.nrms(spectral1, spectral0) 35 | 36 | # calculate the error (actual error depends on the FFT implementation) 37 | assert temporal_error < 1e-14 38 | assert spectral_error < 1e-14 39 | 40 | 41 | if __name__ == "__main__": 42 | test_gaussian_transformation() 43 | -------------------------------------------------------------------------------- /pypret/tests/test_mesh_data.py: -------------------------------------------------------------------------------- 1 | """ This module tests the MeshData implementation. 
""" This module tests the MeshData implementation.

Author: Nils Geib, nils.geib@uni-jena.de
"""
import numpy as np
from pypret.mesh_data import MeshData


def test_mesh_data():
    """ Smoke-tests the MeshData operations on a simple 2d paraboloid. """
    axis1 = np.linspace(-1.0, 2.0, 100)
    axis2 = np.linspace(2.0, -1.0, 110)
    grid1, grid2 = np.meshgrid(axis1, axis2, indexing='ij')
    values = grid1**2 + 0.4 * grid2**2 - 1.0

    md = MeshData(values, axis1, axis2, labels=['delay', 'wavelength'],
                  units=['s', 'm'])

    md2 = md.copy()
    # exercise the processing methods in sequence
    md.marginals()
    md.normalize()
    md.autolimit()
    md.limit((0.0, 0.5), (-0.5, 1.0))

    # interpolate onto coarser axes
    coarse1 = np.linspace(-1.0, 2.0, 50)
    coarse2 = np.linspace(2.0, -1.0, 60)
    md.interpolate(coarse2, coarse1)

    md.flip()


if __name__ == "__main__":
    test_mesh_data()
7 | """ 8 | from pathlib import Path 9 | import numpy as np 10 | import pypret 11 | from pypret import PNPS, material 12 | 13 | 14 | def get_pnps(pulse, method, process): 15 | if method == "miips": 16 | parameter = np.linspace(0.0, 2.0*np.pi, pulse.N//2) # delta in rad 17 | pnps = PNPS(pulse, method, process, gamma=22.5e-15, alpha=1.5 * np.pi) 18 | elif method == "dscan": 19 | parameter = np.linspace(-0.025, 0.025, pulse.N//2) # insertion in m 20 | pnps = PNPS(pulse, method, process, material=material.BK7) 21 | elif method == "ifrog": 22 | parameter = pulse.t 23 | pnps = PNPS(pulse, method, process) 24 | elif method == "frog": 25 | parameter = pulse.t 26 | pnps = PNPS(pulse, method, process) 27 | elif method == "tdp": 28 | parameter = np.linspace(pulse.t[0], pulse.t[-1], pulse.N//2) 29 | pnps = PNPS(pulse, method, process, center=790e-9, width=10.6e-9) 30 | else: 31 | raise ValueError("Method not supported!") 32 | return pnps, parameter 33 | 34 | 35 | def test_regression(): 36 | # test if a test pulse already exists 37 | dirname = Path(__file__).parent / Path("data") 38 | pulse_path = dirname / Path("pulse.hdf5") 39 | if not pulse_path.exists(): 40 | # create simulation grid 41 | ft = pypret.FourierTransform(64, dt=5.0e-15) 42 | # instantiate a pulse object, central wavelength 800 nm 43 | pulse = pypret.Pulse(ft, 800e-9) 44 | # create a random pulse with time-bandwidth product of 2. 
45 | pypret.random_pulse(pulse, 1.0, edge_value=1e-8) 46 | # store the pulse 47 | pulse.save(pulse_path, archive=True) 48 | # initial pulse 49 | initial_path = dirname / Path("initial.hdf5") 50 | if not initial_path.exists(): 51 | pulse = pypret.load(pulse_path, archive=True) 52 | pypret.random_gaussian(pulse, 50e-15) 53 | pulse.save(initial_path, archive=True) 54 | initial = pypret.load(initial_path, archive=True).spectrum 55 | 56 | # ========================================================================= 57 | # Test for regressions in the trace calculation 58 | # ========================================================================= 59 | pulse = pypret.load(pulse_path, archive=True) 60 | for method, dct in pypret.pnps._PNPS_CLASSES.items(): 61 | for process, cls in dct.items(): 62 | scheme = process + "-" + method 63 | pnps, parameter = get_pnps(pulse, method, process) 64 | pnps.calculate(pulse.spectrum, parameter) 65 | trace = pnps.trace 66 | trace_path = (dirname / 67 | Path("%s-%s-trace.hdf5" % (method, process))) 68 | # store if not available 69 | if not trace_path.exists(): 70 | pypret.save(trace.data, trace_path, archive=True) 71 | else: 72 | trace2 = pypret.load(trace_path, archive=True) 73 | assert pypret.lib.nrms(trace.data, trace2) < 1e-15 74 | # use the stored values for the next test 75 | trace.data = trace2 76 | 77 | # test for regressions in the retrieval algorithms 78 | retrieval_algorithms = ["copra"] 79 | if scheme in ["shg-dscan", "thg-dscan"]: 80 | retrieval_algorithms += ["gp-dscan"] 81 | if scheme == "shg-frog": 82 | retrieval_algorithms += ["gpa", "pcgpa", "pie"] 83 | 84 | for algorithm in retrieval_algorithms: 85 | fname = (dirname / 86 | Path("%s-%s-%s-retrieved.hdf5" % 87 | (method, process, algorithm))) 88 | ret = pypret.Retriever(pnps, algorithm, maxiter=10, maxfev=256) 89 | np.random.seed(1234) 90 | ret.retrieve(trace, initial) 91 | result = ret._result.spectrum 92 | if not fname.exists(): 93 | pypret.save(result, fname) 94 | 
else: 95 | result2 = pypret.load(fname) 96 | assert pypret.lib.nrms(result, result2) < 1e-8 97 | 98 | 99 | if __name__ == "__main__": 100 | test_regression() 101 | -------------------------------------------------------------------------------- /pypret/tests/test_sellmeier.py: -------------------------------------------------------------------------------- 1 | """ This module tests the Sellmeier equation implementation of the BK7 material. 2 | 3 | Author: Nils Geib, nils.geib@uni-jena.de 4 | """ 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from pypret.material import BK7 8 | 9 | 10 | def test_sellmeier(): 11 | assert abs(BK7.n(500e-9) - 1.5214) < 1e-4 12 | assert abs(BK7.n(800e-9) - 1.5108) < 1e-4 13 | assert abs(BK7.n(1200e-9) - 1.5049) < 1e-4 14 | 15 | 16 | if __name__ == "__main__": 17 | test_sellmeier() 18 | wl = np.linspace(300e-9, 1200e-9, 1000) 19 | 20 | fig, ax = plt.subplots() 21 | ax.plot(wl * 1e9, BK7.n(wl)) 22 | ax.set_xlabel("wavelength (nm)") 23 | ax.set_ylabel("refractive index") 24 | ax.set_title("BK7") 25 | ax.grid() 26 | fig.tight_layout() 27 | plt.show() 28 | -------------------------------------------------------------------------------- /scripts/benchmarking.py: -------------------------------------------------------------------------------- 1 | """ This module implements testing procedures for retrieval algorithms. 2 | """ 3 | import path_helper 4 | from types import SimpleNamespace 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib.gridspec as gridspec 8 | from matplotlib.ticker import EngFormatter 9 | from pypret import (FourierTransform, Pulse, random_gaussian, random_pulse, 10 | PNPS, material, Retriever, lib) 11 | from pypret.graphics import plot_complex 12 | 13 | 14 | def benchmark_retrieval(pulse, scheme, algorithm, additive_noise=0.0, 15 | repeat=10, maxiter=300, verbose=False, 16 | initial_guess="random_gaussian", **kwargs): 17 | """ Benchmarks a pulse retrieval algorithm. Uses the parameters from our 18 | paper.
19 | 20 | If you want to benchmark other pulses/configurations you can use the 21 | procedure below as a starting point. 22 | """ 23 | # instantiate the result object 24 | res = SimpleNamespace() 25 | res.pulse = pulse.copy() 26 | res.original_spectrum = pulse.spectrum 27 | 28 | # split the scheme 29 | process, method = scheme.lower().split("-") 30 | 31 | if method == "miips": 32 | # MIIPS 33 | parameter = np.linspace(0.0, 2.0*np.pi, 128) # delta in rad 34 | pnps = PNPS(pulse, method, process, gamma=22.5e-15, alpha=1.5 * np.pi) 35 | elif method == "dscan": 36 | # d-scan 37 | parameter = np.linspace(-0.025, 0.025, 128) # insertion in m 38 | pnps = PNPS(pulse, method, process, material=material.BK7) 39 | elif method == "ifrog": 40 | # ifrog 41 | if process == "sd": 42 | parameter = np.linspace(pulse.t[0], pulse.t[-1], pulse.N * 4) 43 | else: 44 | parameter = pulse.t # delay in s 45 | pnps = PNPS(pulse, method, process) 46 | elif method == "frog": 47 | # frog 48 | parameter = pulse.t # delay in s 49 | pnps = PNPS(pulse, method, process) 50 | elif method == "tdp": 51 | # d-scan 52 | parameter = np.linspace(pulse.t[0], pulse.t[-1], 128) # delay in s 53 | pnps = PNPS(pulse, method, process, center=790e-9, width=10.6e-9) 54 | else: 55 | raise ValueError("Method not supported!") 56 | pnps.calculate(pulse.spectrum, parameter) 57 | measurement = pnps.trace 58 | 59 | # add noise 60 | std = measurement.data.max() * additive_noise 61 | measurement.data += std * np.random.normal(size=measurement.data.shape) 62 | 63 | ret = Retriever(pnps, algorithm, verbose=verbose, logging=True, 64 | maxiter=maxiter, **kwargs) 65 | 66 | res.retrievals = [] 67 | for i in range(repeat): 68 | if initial_guess == "random_gaussian": 69 | # create random Gaussian pulse 70 | random_gaussian(pulse, 50e-15, 0.3 * np.pi) 71 | elif initial_guess == "random": 72 | pulse.spectrum = (np.random.uniform(size=pulse.N) * 73 | np.exp(2.0j * np.pi * 74 | np.random.uniform(size=pulse.N))) 75 | elif initial_guess == 
"original": 76 | pulse.spectrum = res.original_spectrum 77 | else: 78 | raise ValueError("Initial guess mode '%s' not supported." % initial_guess) 79 | ret.retrieve(measurement, pulse.spectrum) 80 | res.retrievals.append(ret.result(res.original_spectrum)) 81 | 82 | return res 83 | 84 | 85 | class RetrievalResultPlot: 86 | 87 | def __init__(self, retrieval_result, plot=True, **kwargs): 88 | rr = self.retrieval_result = retrieval_result 89 | if rr.pulse_original is None: 90 | raise ValueError("This plot requires an original pulse to compare" 91 | " to.") 92 | if plot: 93 | self.plot(**kwargs) 94 | 95 | def plot(self, xaxis='wavelength', yaxis='intensity', limit=True, 96 | phase_blanking=False, phase_blanking_threshold=1e-3, show=True): 97 | rr = self.retrieval_result 98 | # reconstruct a pulse from that 99 | pulse = Pulse(rr.pnps.ft, rr.pnps.w0, unit="om") 100 | 101 | # construct the figure 102 | fig = plt.figure(figsize=(30.0/2.54, 20.0/2.54)) 103 | gs1 = gridspec.GridSpec(2, 2) 104 | gs2 = gridspec.GridSpec(2, 6) 105 | ax1 = plt.subplot(gs1[0, 0]) 106 | ax2 = plt.subplot(gs1[0, 1]) 107 | ax3 = plt.subplot(gs2[1, :2]) 108 | ax4 = plt.subplot(gs2[1, 2:4]) 109 | ax5 = plt.subplot(gs2[1, 4:]) 110 | ax12 = ax1.twinx() 111 | ax22 = ax2.twinx() 112 | 113 | # Plot in time domain 114 | pulse.spectrum = rr.pulse_original # the test pulse 115 | li011, li012, samp, spha = plot_complex(pulse.t, pulse.field, ax1, ax12, yaxis=yaxis, 116 | phase_blanking=phase_blanking, limit=limit, 117 | phase_blanking_threshold=phase_blanking_threshold, 118 | amplitude_line="k-", phase_line="k--") 119 | pulse.spectrum = rr.pulse_retrieved # the retrieved pulse 120 | li11, li12, samp, spha = plot_complex(pulse.t, pulse.field, ax1, ax12, yaxis=yaxis, 121 | phase_blanking=phase_blanking, limit=limit, 122 | phase_blanking_threshold=phase_blanking_threshold) 123 | li11.set_linewidth(3.0) 124 | li11.set_color("#1f77b4") 125 | li11.set_alpha(0.6) 126 | li12.set_linewidth(3.0) 127 | 
li12.set_color("#ff7f0e") 128 | li12.set_alpha(0.6) 129 | 130 | fx = EngFormatter(unit="s") 131 | ax1.xaxis.set_major_formatter(fx) 132 | ax1.set_title("time domain") 133 | ax1.set_xlabel("time") 134 | ax1.set_ylabel(yaxis) 135 | ax12.set_ylabel("phase (rad)") 136 | ax1.legend([li011, li11, li12], ["original", "intensity", 137 | "phase"]) 138 | 139 | # frequency domain 140 | if xaxis == "wavelength": 141 | x = pulse.wl 142 | unit = "m" 143 | label = "wavelength" 144 | elif xaxis == "frequency": 145 | x = pulse.w 146 | unit = " rad Hz" 147 | label = "frequency" 148 | # Plot in spectral domain 149 | li021, li022, samp, spha = plot_complex(x, rr.pulse_original, ax2, ax22, yaxis=yaxis, 150 | phase_blanking=phase_blanking, limit=limit, 151 | phase_blanking_threshold=phase_blanking_threshold, 152 | amplitude_line="k-", phase_line="k--") 153 | li21, li22, samp, spha = plot_complex(x, rr.pulse_retrieved, ax2, ax22, yaxis=yaxis, 154 | phase_blanking=phase_blanking, limit=limit, 155 | phase_blanking_threshold=phase_blanking_threshold) 156 | li21.set_linewidth(3.0) 157 | li21.set_color("#1f77b4") 158 | li21.set_alpha(0.6) 159 | li22.set_linewidth(3.0) 160 | li22.set_color("#ff7f0e") 161 | li22.set_alpha(0.6) 162 | 163 | fx = EngFormatter(unit=unit) 164 | ax2.xaxis.set_major_formatter(fx) 165 | ax2.set_title("frequency domain") 166 | ax2.set_xlabel(label) 167 | ax2.set_ylabel(yaxis) 168 | ax22.set_ylabel("phase (rad)") 169 | ax2.legend([li021, li21, li22], ["original", "intensity", 170 | "phase"]) 171 | 172 | axes = [ax3, ax4, ax5] 173 | sc = 1.0 / rr.trace_input.max() 174 | traces = [rr.trace_input * sc, rr.trace_retrieved * sc, 175 | (rr.trace_input - rr.trace_retrieved) * sc] 176 | titles = ["input", "retrieved", "difference"] 177 | cmaps = ["nipy_spectral", "nipy_spectral", "RdBu"] 178 | md = rr.measurement 179 | for ax, trace, title, cmap in zip(axes, traces, titles, cmaps): 180 | x, y = lib.edges(rr.pnps.process_w), lib.edges(rr.parameter) 181 | im = ax.pcolormesh(x, y, 
trace, cmap=cmap) 182 | fig.colorbar(im, ax=ax) 183 | ax.set_xlabel(md.labels[1]) 184 | ax.set_ylabel(md.labels[0]) 185 | fx = EngFormatter(unit=md.units[1]) 186 | ax.xaxis.set_major_formatter(fx) 187 | fy = EngFormatter(unit=md.units[0]) 188 | ax.yaxis.set_major_formatter(fy) 189 | ax.set_title(title) 190 | 191 | self.fig = fig 192 | self.ax1, self.ax2 = ax1, ax2 193 | self.ax12, self.ax22 = ax12, ax22 194 | self.li11, self.li12, self.li21, self.li22 = li11, li12, li21, li22 195 | self.ax3, self.ax4, self.ax5 = ax3, ax4, ax5 196 | 197 | if show: 198 | #gs.tight_layout(fig) 199 | gs1.update(left=0.05, right=0.95, top=0.9, bottom=0.1, 200 | hspace=0.25, wspace=0.5) 201 | gs2.update(left=0.1, right=0.95, top=0.9, bottom=0.1, 202 | hspace=0.5, wspace=1.0) 203 | plt.show() 204 | -------------------------------------------------------------------------------- /scripts/create_pulse_bank.py: -------------------------------------------------------------------------------- 1 | """ This script generates 100 pulses with time-bandwidth product 2 and 2 | stores them in an HDF5 file. Those pulses are used by other scripts 3 | in this folder. 4 | """ 5 | import path_helper 6 | import pypret 7 | 8 | ft = pypret.FourierTransform(256, dt=5.0e-15) 9 | pulse = pypret.Pulse(ft, 800e-9) 10 | 11 | pulses = [] 12 | for i in range(100): 13 | pypret.random_pulse(pulse, 2.0) 14 | pulses.append(pulse.copy()) 15 | 16 | pypret.save(pulses, "pulse_bank.hdf5") 17 | -------------------------------------------------------------------------------- /scripts/nlo_retrievers.py: -------------------------------------------------------------------------------- 1 | """ This script compares pulse retrieval using different general-purpose, 2 | nonlinear optimization algorithms. 3 | 4 | It can be used to reproduce the results shown in Fig. 2 in our paper 5 | [Geib2019]_. 
6 | """ 7 | import path_helper 8 | from types import SimpleNamespace 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from pypret import (FourierTransform, Pulse, random_gaussian, random_pulse, 12 | PNPS, Retriever, load, save) 13 | from pypret.graphics import plot_meshdata 14 | 15 | # CONFIG 16 | repeat = 10 # how often to repeat the retrieval 17 | algorithms = ['lm', 'bfgs', 'de', 'nm'] # algorithms to test 18 | verbose = False # print out the error during iteration 19 | 20 | # %% 21 | # Create a simulation grid 22 | ft = FourierTransform(64, dt=20.0e-15) 23 | # instantiate a pulse with central frequency 800 nm 24 | pulse = Pulse(ft, 800e-9) 25 | # create a random, localized pulse with time-bandwidth product 1 26 | random_pulse(pulse, 1.0, edge_value=1e-7) 27 | 28 | # instantiate an SHG-FROG measurement 29 | parameter = pulse.t # delay in s 30 | pnps = PNPS(pulse, 'frog', 'shg') 31 | # simulate the noiseless measurement 32 | original_spectrum = pulse.spectrum 33 | pnps.calculate(pulse.spectrum, parameter) 34 | measurement = pnps.trace 35 | 36 | data = SimpleNamespace(measurement=measurement, pulse=pulse, 37 | results={}) 38 | for algorithm in algorithms: 39 | ret = Retriever(pnps, algorithm, verbose=verbose, logging=True, 40 | maxfev=20000) 41 | for i in range(repeat): 42 | print("Running algorithm %s run %d/%d" % (algorithm.upper(), i+1, 43 | repeat)) 44 | # create initial pulse, Gaussian in time domain 45 | random_gaussian(pulse, 50e-15) 46 | ret.retrieve(measurement, pulse.spectrum) 47 | res = ret.result(original_spectrum) 48 | print("Finished after %d function evaluations." 
% res.nfev) 49 | print("final trace error R=%.15e" % res.trace_error) 50 | # store the result with the best trace error 51 | if i == 0 or res.trace_error < data.results[algorithm].trace_error: 52 | data.results[algorithm] = res 53 | 54 | # save simulation data for further plotting 55 | save(data, "nlo_retriever_data.hdf5") 56 | 57 | # %% 58 | # Plot: this part can be run separately 59 | data = load("nlo_retriever_data.hdf5") 60 | fig, (ax1, ax2) = plt.subplots(1, 2) 61 | 62 | # mesh plot of the SHG-FROG trace 63 | im = plot_meshdata(ax1, data.measurement, cmap="nipy_spectral") 64 | ax1.set_title("SHG-FROG") 65 | 66 | for algorithm in algorithms: 67 | res = data.results[algorithm] 68 | iterations = np.arange(res.trace_errors.size) 69 | ax2.plot(iterations, np.minimum.accumulate(res.trace_errors), 70 | label=algorithm.upper()) 71 | ax2.set_yscale('log') 72 | ax2.set_xlabel("function evaluations") 73 | ax2.set_ylabel("trace error R") 74 | ax2.legend(loc="best") 75 | fig.tight_layout() 76 | plt.show() 77 | -------------------------------------------------------------------------------- /scripts/path_helper.py: -------------------------------------------------------------------------------- 1 | """ A small helper that allows to import from the pypret package 2 | in this git repository without any changes to PYTHONPATH. 
3 | """ 4 | # first try if pypret is installed 5 | import importlib 6 | spec = importlib.util.find_spec("pypret") 7 | if spec is None or spec.origin == "namespace": 8 | # if pypret is not installed 9 | # add relative path with high priority 10 | import sys 11 | from pathlib import Path 12 | pypret_folder = Path(__file__).resolve().parents[1] 13 | sys.path.insert(0, str(pypret_folder)) 14 | -------------------------------------------------------------------------------- /scripts/pulse_bank.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/scripts/pulse_bank.hdf5 -------------------------------------------------------------------------------- /scripts/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ncgeib/pypret/d47deb675640439df7c8b7c08d71f45ecea3c568/scripts/result.png -------------------------------------------------------------------------------- /scripts/simple_example.py: -------------------------------------------------------------------------------- 1 | """ This shows a simple application of pypret. 2 | 3 | It calculates a PNPS trace from a pulse and displays it. 
4 | """ 5 | import path_helper 6 | import numpy as np 7 | import pypret 8 | from pypret import (FourierTransform, Pulse, random_pulse, PNPS, MeshDataPlot) 9 | 10 | ft = FourierTransform(256, dt=2.5e-15) 11 | pulse = Pulse(ft, 800e-9) 12 | random_pulse(pulse, 2.0) 13 | 14 | method = "ifrog" 15 | process = "shg" 16 | 17 | if method == "miips": 18 | # MIIPS 19 | parameter = np.linspace(0.0, 2.0*np.pi, 128) # delta in rad 20 | pnps = PNPS(pulse, method, process, gamma=22.5e-15, alpha=1.5 * np.pi) 21 | elif method == "dscan": 22 | # d-scan 23 | parameter = np.linspace(-0.025, 0.025, 128) # insertion in m 24 | pnps = PNPS(pulse, method, process, material=pypret.material.BK7) 25 | elif method == "ifrog": 26 | # ifrog 27 | if process == "sd": 28 | parameter = np.linspace(pulse.t[0], pulse.t[-1], pulse.N * 4) 29 | else: 30 | parameter = pulse.t # delay in s 31 | pnps = PNPS(pulse, method, process) 32 | elif method == "frog": 33 | # frog 34 | parameter = pulse.t # delay in s 35 | pnps = PNPS(pulse, method, process) 36 | elif method == "tdp": 37 | # d-scan 38 | parameter = np.linspace(pulse.t[0], pulse.t[-1], 128) # delay in s 39 | pnps = PNPS(pulse, method, process, center=790e-9, width=10.6e-9) 40 | else: 41 | raise ValueError("Method not supported!") 42 | pnps.calculate(pulse.spectrum, parameter) 43 | 44 | # Example how to save the calculation 45 | #pnps.save("test_pnps.hdf5") 46 | #pnps2 = pypret.load("test_pnps.hdf5") 47 | 48 | md = pnps.trace 49 | md.autolimit(1) 50 | mdp = MeshDataPlot(md, show=False) 51 | mdp.ax.set_title("SHG-iFROG") 52 | 53 | mdp.show() -------------------------------------------------------------------------------- /scripts/test_retrieval_algorithms.py: -------------------------------------------------------------------------------- 1 | """ This script tests COPRA against a lot of different PNPS schemes 2 | and compares it against PCGPA and ptychographic retrieval for SHG-FROG. 3 | It reproduces the data of Fig. 4, 5 and 7 from [Geib2019]_. 
4 | 5 | Notes 6 | ----- 7 | As we are using multiprocessing to speed up the parameter scan you may not be 8 | able to run this script inside of an IDE such as spyder. In that case 9 | please run the script from the commandline using standard Python. 10 | 11 | For plotting the results see `test_retrieval_algorithms_plot.py`. 12 | """ 13 | import path_helper 14 | import pypret 15 | from benchmarking import benchmark_retrieval, RetrievalResultPlot 16 | from pathlib import Path 17 | from concurrent import futures 18 | 19 | # the configs to test: (scheme, algorithm) 20 | configs = [ 21 | ("shg-frog", "copra"), 22 | ("shg-frog", "pcgpa"), 23 | ("shg-frog", "pie"), 24 | ("pg-frog", "copra"), 25 | ("shg-tdp", "copra"), 26 | ("shg-dscan", "copra"), 27 | ("thg-dscan", "copra"), 28 | ("sd-dscan", "copra"), 29 | ("shg-ifrog", "copra"), 30 | ("thg-ifrog", "copra"), 31 | ("sd-ifrog", "copra"), 32 | ("shg-miips", "copra"), 33 | ("thg-miips", "copra"), 34 | ("sd-miips", "copra") 35 | ] 36 | maxworkers = 10 # number of processes used 37 | maxiter = 100 # 300 in the paper 38 | npulses = 10 # 100 in the paper 39 | repeat = 3 # 10 in the paper 40 | # [0.0, 1e-3, 3e-3, 5e-3, 1e-2, 3e-2, 5e-2] in the paper 41 | noise_levels = [1e-2, 3e-2] 42 | 43 | # block the main routine as we are using multiprocessing 44 | if __name__ == "__main__": 45 | pulses = pypret.load("pulse_bank.hdf5") 46 | 47 | path = Path("results") 48 | if not path.exists(): 49 | path.mkdir() 50 | 51 | for scheme, algorithm in configs: 52 | for noise in noise_levels: 53 | print("Testing %s with %s and noise level %.1f%%" % 54 | (scheme.upper(), algorithm.upper(), noise * 100)) 55 | results = [] 56 | # run the different pulses in different processes, 57 | # not optimal but better than no parallelism 58 | fs = {} 59 | with futures.ProcessPoolExecutor(max_workers=maxworkers) as executor: 60 | for i, pulse in enumerate(pulses[:npulses]): 61 | future = executor.submit(benchmark_retrieval, 62 | pulses, scheme, algorithm, 
repeat=repeat, 63 | verbose=False, maxiter=maxiter, 64 | additive_noise=noise) 65 | fs[future] = i 66 | for future in futures.as_completed(fs): 67 | i = fs[future] 68 | try: 69 | res = future.result() 70 | except Exception as exc: 71 | print('Retrieval generated an exception: %s' % exc) 72 | results.append(res) 73 | print("Finished pulse %d/%d" % (i+1, npulses) ) 74 | fname = "%s_%s_noise_%.1e.hdf5.7z" % (scheme.upper(), 75 | algorithm.upper(), noise) 76 | pypret.save(results, path / fname, archive=True) 77 | print("Stored in %s" % fname) 78 | -------------------------------------------------------------------------------- /scripts/test_retrieval_algorithms_plot.py.py: -------------------------------------------------------------------------------- 1 | """ This script plots the results obtained by `test_retrieval_algorithms.py`. 2 | Therefore, that script has to be run first. 3 | 4 | It reproduces Fig. 4, 5 and 7 from [Geib2019]_. 5 | """ 6 | import path_helper 7 | from pathlib import Path 8 | from types import SimpleNamespace 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import pypret 12 | 13 | # the configs to plot: (scheme, algorithm) separated by subplot 14 | configs = [ 15 | ("shg-frog", "copra"), 16 | ("shg-frog", "pcgpa"), 17 | ("shg-frog", "pie"), 18 | ("pg-frog", "copra"), 19 | ("shg-tdp", "copra"), 20 | ("shg-dscan", "copra"), 21 | ("thg-dscan", "copra"), 22 | ("sd-dscan", "copra"), 23 | ("shg-ifrog", "copra"), 24 | ("thg-ifrog", "copra"), 25 | ("sd-ifrog", "copra"), 26 | ("shg-miips", "copra"), 27 | ("thg-miips", "copra"), 28 | ("sd-miips", "copra") 29 | ] 30 | noise_levels = [0.0, 1e-2, 3e-2] 31 | 32 | def get_fname(scheme, algorithm, noise): 33 | return "%s_%s_noise_%.1e.hdf5.7z" % (scheme.upper(), 34 | algorithm.upper(), noise) 35 | 36 | # %% 37 | path = Path("results") 38 | results = {} 39 | for scheme, algorithm in configs: 40 | for noise in noise_levels: 41 | # iterate over configs 42 | fname = get_fname(scheme, algorithm, noise) 
43 | ares = SimpleNamespace(pulse_errors=[], trace_errors=[], 44 | relative_trace_errors=[], 45 | relative_rm_trace_errors=[]) 46 | # iterate over retrieved pulses 47 | for res in pypret.load(path / fname, archive=True): 48 | # for every pulse several retrievals were made 49 | # select solution with lowest trace error 50 | pres = res.retrievals[np.argmin([r.trace_error 51 | for r in res.retrievals])] 52 | # now store the final trace error of the solution 53 | ares.trace_errors.append(pres.trace_error) 54 | # relative trace error (R - R0 in the paper) 55 | ares.relative_trace_errors.append(pres.trace_error - 56 | pres.trace_error_optimal) 57 | # the running minimum of the trace error minus the optimal 58 | # trace error (R - R0 in the paper) - for plotting 59 | ares.relative_rm_trace_errors.append(pres.rm_trace_errors - 60 | pres.trace_error_optimal) 61 | # the pulse error of the solution 62 | ares.pulse_errors.append(pres.pulse_error) 63 | # convert to numpy arrays and calculate the median 64 | ares.trace_errors = np.array(ares.trace_errors) 65 | ares.relative_trace_errors = np.array(ares.relative_trace_errors) 66 | ares.relative_rm_trace_errors = np.array(ares.relative_rm_trace_errors) 67 | ares.pulse_errors = np.array(ares.pulse_errors) 68 | # store in dictionary 69 | results[fname] = ares 70 | 71 | pypret.save(results, path / "plot_results.hdf5") 72 | 73 | # %% 74 | # print out the results (equivalent to Fig. 
7) 75 | results = pypret.load(path / "plot_results.hdf5") 76 | # print numerical results 77 | s = "filename".ljust(40) 78 | s += "trace error".ljust(15) 79 | s += "R - R0".ljust(15) 80 | s += "pulse_error".ljust(15) 81 | print(s) 82 | for scheme, algorithm in configs: 83 | for noise in noise_levels: 84 | fname = "%s_%s_noise_%.1e.hdf5.7z" % (scheme.upper(), 85 | algorithm.upper(), noise) 86 | ares = results[fname] 87 | s = fname.ljust(40) 88 | s += ("%.3e" % np.median(ares.trace_errors)).ljust(15) 89 | s += ("%.3e" % np.median(ares.relative_trace_errors)).ljust(15) 90 | s += ("%.3e" % np.median(ares.pulse_errors)).ljust(15) 91 | print(s) 92 | 93 | # %% 94 | # do the plots (equivalent to Fig. 4 and 5) 95 | plot_configs = [ 96 | [("shg-frog", "copra"), 97 | ("shg-frog", "pcgpa"), 98 | ("shg-frog", "pie")], 99 | [("shg-ifrog", "copra"), 100 | ("thg-ifrog", "copra"), 101 | ("sd-ifrog", "copra")], 102 | [("shg-dscan", "copra"), 103 | ("thg-dscan", "copra"), 104 | ("sd-dscan", "copra")], 105 | [("shg-miips", "copra"), 106 | ("thg-miips", "copra"), 107 | ("sd-miips", "copra")] 108 | ] 109 | 110 | # first plot no noise case 111 | noise = 0.0 112 | fig, axs = plt.subplots(2, 2, figsize=(20.0/2.54, 20.0/2.54)) 113 | for config, ax in zip(plot_configs, axs.flat): 114 | for scheme, algorithm in config: 115 | fname = get_fname(scheme, algorithm, noise) 116 | ares = results[fname] 117 | errors = np.median(ares.relative_rm_trace_errors, axis=0) 118 | iterations = np.arange(errors.size) 119 | ax.plot(iterations, errors, label=scheme + " " + algorithm) 120 | ax.set_xlabel("iterations") 121 | ax.set_ylabel("R (running minimum)") 122 | ax.legend(loc="best") 123 | ax.set_yscale("log") 124 | fig.tight_layout() 125 | plt.show() 126 | 127 | # second plot 1% noise case 128 | noise = 0.01 129 | fig, axs = plt.subplots(2, 2, figsize=(20.0/2.54, 20.0/2.54)) 130 | for config, ax in zip(plot_configs, axs.flat): 131 | for scheme, algorithm in config: 132 | fname = get_fname(scheme, algorithm, 
noise) 133 | ares = results[fname] 134 | errors = np.median(ares.relative_rm_trace_errors, axis=0) 135 | iterations = np.arange(errors.size) 136 | ax.plot(iterations, errors, label=scheme + " " + algorithm) 137 | ax.set_xlabel("iterations") 138 | ax.set_ylabel("R - R0 (running minimum)") 139 | ax.legend(loc="best") 140 | ax.set_yscale("symlog", linthreshy=1e-5) 141 | fig.tight_layout() 142 | plt.show() 143 | -------------------------------------------------------------------------------- /scripts/test_single_retrieval.py: -------------------------------------------------------------------------------- 1 | """ This module shows how to use the benchmark script. It triggers 2 | a single retrieval simulation and plots the results. It can be used 3 | to quickly assess the performance of a single algorithm/measurement scheme 4 | combination. 5 | """ 6 | import path_helper 7 | import pypret 8 | from benchmarking import benchmark_retrieval, RetrievalResultPlot 9 | 10 | scheme = ( # can be one of the following 11 | "shg-frog" 12 | # "pg-frog" 13 | # "tg-frog" 14 | # "shg-tdp" 15 | # "shg-dscan" 16 | # "thg-dscan" 17 | # "sd-dscan" 18 | # "shg-ifrog" 19 | # "thg-ifrog" 20 | # "sd-ifrog" 21 | # "shg-miips" 22 | # "thg-miips" 23 | # "sd-miips" 24 | ) 25 | 26 | pulses = pypret.load("pulse_bank.hdf5") 27 | res = benchmark_retrieval(pulses[2], scheme, "copra", repeat=1, 28 | verbose=True, maxiter=300, 29 | additive_noise=0.01) 30 | rrp = RetrievalResultPlot(res.retrievals[0]) 31 | rrp.fig.savefig("result.png") 32 | --------------------------------------------------------------------------------