├── astra
│   ├── interfaces
│   │   ├── __init__.py
│   │   └── bmad.py
│   ├── __init__.py
│   ├── install.py
│   ├── tools.py
│   ├── generator.py
│   ├── control.py
│   ├── archive.py
│   ├── astra_distgen.py
│   ├── evaluate.py
│   ├── writers.py
│   ├── fieldmaps.py
│   ├── plot.py
│   ├── astra_calc.py
│   ├── parsers.py
│   ├── _version.py
│   └── astra.py
├── docs
│   ├── api
│   │   ├── astra.md
│   │   ├── install.md
│   │   └── generator.md
│   ├── assets
│   │   └── apex-gun-lume-astra.png
│   ├── stylesheets
│   │   └── extra.css
│   └── index.md
├── requirements.txt
├── MANIFEST.in
├── docs-requirements.txt
├── setup.cfg
├── overrides
│   └── main.html
├── scripts
│   └── execute_notebooks.bash
├── environment.yml
├── setup.py
├── README.md
├── mkdocs.yml
└── LICENSE
/astra/interfaces/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/api/astra.md:
--------------------------------------------------------------------------------
1 | ::: astra.astra
--------------------------------------------------------------------------------
/docs/api/install.md:
--------------------------------------------------------------------------------
1 | ::: astra.install
--------------------------------------------------------------------------------
/docs/api/generator.md:
--------------------------------------------------------------------------------
1 | ::: astra.generator
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | distgen
3 | numpy
4 | h5py
5 | openpmd-beamphysics
6 | lume-base
7 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include LICENSE
3 | include versioneer.py
4 | include astra/_version.py
5 |
--------------------------------------------------------------------------------
/docs-requirements.txt:
--------------------------------------------------------------------------------
1 | pygments
2 | mkdocs
3 | mkdocstrings
4 | mkdocs-material
5 | livereload
6 | pytkdocs[numpy-style]
--------------------------------------------------------------------------------
/docs/assets/apex-gun-lume-astra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChristopherMayes/lume-astra/HEAD/docs/assets/apex-gun-lume-astra.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [versioneer]
2 | VCS = git
3 | style = pep440
4 | versionfile_source = astra/_version.py
5 | versionfile_build = astra/_version.py
6 | tag_prefix = v
7 |
--------------------------------------------------------------------------------
/astra/__init__.py:
--------------------------------------------------------------------------------
1 | from .astra import *
2 | from .generator import AstraGenerator
3 | from .evaluate import evaluate_astra_with_generator
4 | from .astra_distgen import run_astra_with_distgen, evaluate_astra_with_distgen
5 |
6 | from . import _version
7 | __version__ = _version.get_versions()['version']
8 |
--------------------------------------------------------------------------------
/overrides/main.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block content %}
4 | {% if page.nb_url %}
5 |
6 |
7 |
8 | {% endif %}
9 |
10 | {{ super() }}
11 | {% endblock content %}
--------------------------------------------------------------------------------
/scripts/execute_notebooks.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | NOTEBOOKS=$(find . -type f -name "*.ipynb" -not -path '*/.*')
4 |
5 | SKIP="autophase"
6 |
7 | echo $NOTEBOOKS
8 |
9 | for file in $NOTEBOOKS
10 | do
11 | if [[ "$file" == *"$SKIP"* ]]; then
12 | echo "Skipping $file"
13 | continue
14 | fi
15 |
16 | echo "Executing $file"
17 | jupyter nbconvert --to notebook --execute $file --inplace
18 | done
19 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # conda env create -f environment.yml
2 | name: lume-astra-dev
3 | channels:
4 | - conda-forge
5 | dependencies:
6 | - python=3.9
7 | - openpmd-beamphysics
8 | - lume-base
9 | - distgen
10 | - matplotlib
11 | # Developer
12 | - bmad
13 | - pytest
14 | - jupyterlab>=3
15 | - ipywidgets
16 | - pygments
17 | - mkdocs
18 | - mkdocstrings
19 | - mkdocs-material
20 | - mkdocs-jupyter
21 | - pip
22 | - pip:
23 | - mkdocstrings-python
24 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import versioneer
2 | from setuptools import setup, find_packages
3 | from os import path, environ
4 |
5 | cur_dir = path.abspath(path.dirname(__file__))
6 |
7 | with open(path.join(cur_dir, 'requirements.txt'), 'r') as f:
8 | requirements = f.read().split()
9 |
10 |
11 |
12 | setup(
13 | name='lume-astra',
14 | version=versioneer.get_version(),
15 | cmdclass=versioneer.get_cmdclass(),
16 | packages=find_packages(),
17 |     package_dir={'astra': 'astra'},
18 | url='https://github.com/ChristopherMayes/lume-astra',
19 | long_description=open('README.md').read(),
20 | long_description_content_type='text/markdown',
21 | install_requires=requirements,
22 | include_package_data=True,
23 | python_requires='>=3.6'
24 | )
25 |
--------------------------------------------------------------------------------
/docs/stylesheets/extra.css:
--------------------------------------------------------------------------------
1 | /*
2 | Hide In and Out prefix
3 | https://oceanumeric.github.io/blog/2022/jupyter-style/
4 | */
5 | .jp-CodeCell > .jp-Cell-inputWrapper {
6 | direction: rtl;
7 | width:113%;
8 | }
9 |
10 | .jp-InputArea-prompt {
11 | visibility: hidden;
12 | }
13 |
14 | .jp-OutputArea-prompt {
15 | visibility: hidden; /* disable this to tune the position */
16 | background-color:red;
17 | position:absolute;
18 | right: 0;
19 |
20 | }
21 |
22 | .jp-CodeCell > .jp-Cell-outputWrapper {
23 | margin-top: -10px;
24 | padding-top:0;
25 | display: table-cell;
26 | text-align: left;
27 | }
28 | .jp-Cell-outputWrapper > .jp-Cell-outputCollapser {
29 | /* background-color:red; */
30 | margin-top: -17px;
31 | }
32 |
33 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # LUME-Astra
2 | Python wrapper for [ASTRA](http://www.desy.de/~mpyflo/) (A Space Charge Tracking Algorithm, DESY) for use in LUME.
3 |
4 |
5 | ```python
6 | from astra import Astra
7 |
8 | A = Astra('Astra.in')
9 | A.run()
10 | A.plot(y=['norm_emit_x', 'norm_emit_y'], y2=['sigma_x', 'sigma_y'])
11 | ```
12 | 
13 |
14 | Installing lume-astra
15 | =====================
16 |
17 | Installing `lume-astra` from the `conda-forge` channel can be achieved by adding `conda-forge` to your channels with:
18 |
19 | ```
20 | conda config --add channels conda-forge
21 | ```
22 |
23 | Once the `conda-forge` channel has been enabled, `lume-astra` can be installed with:
24 |
25 | ```
26 | conda install lume-astra
27 | ```
28 |
29 | It is possible to list all of the versions of `lume-astra` available on your platform with:
30 |
31 | ```
32 | conda search lume-astra --channel conda-forge
33 |
34 | ```
35 |
36 | Installing Astra Executables
37 | =====================
38 |
39 | For convenience, you can set `$ASTRA_BIN` and `$GENERATOR_BIN` to point to the Astra and generator binaries for your system. See the [install_astra.ipynb](./examples/install_astra.ipynb) example for easy installation.
40 |
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
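The executables can also be installed programmatically via `astra.install` (see `/astra/install.py` below). A minimal sketch, assuming network access and a writable destination directory:

```python
# Sketch only: downloads the platform binaries and sets the env vars.
from astra.install import install_astra, install_generator

install_astra(dest_dir='.', verbose=True)      # sets $ASTRA_BIN
install_generator(dest_dir='.', verbose=True)  # sets $GENERATOR_BIN
```

--------------------------------------------------------------------------------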
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # LUME-Astra
5 | Python wrapper for [ASTRA](http://www.desy.de/~mpyflo/) (A Space Charge Tracking Algorithm, DESY) for use in LUME.
6 |
7 | **`Documentation`** |
8 | ------------------- |
9 | [](https://christophermayes.github.io/lume-astra/) |
10 |
11 | ```python
12 | from astra import Astra
13 |
14 | A = Astra('Astra.in')
15 | A.run()
16 | A.plot(y=['norm_emit_x', 'norm_emit_y'], y2=['sigma_x', 'sigma_y'])
17 | ```
18 | 
19 |
20 |
21 | Installing lume-astra
22 | =====================
23 |
24 | Installing `lume-astra` from the `conda-forge` channel can be achieved by adding `conda-forge` to your channels with:
25 |
26 | ```
27 | conda config --add channels conda-forge
28 | ```
29 |
30 | Once the `conda-forge` channel has been enabled, `lume-astra` can be installed with:
31 |
32 | ```
33 | conda install lume-astra
34 | ```
35 |
36 | It is possible to list all of the versions of `lume-astra` available on your platform with:
37 |
38 | ```
39 | conda search lume-astra --channel conda-forge
40 |
41 | ```
42 |
43 |
44 | Installing Astra Executables
45 | =====================
46 |
47 | For convenience, you can set `$ASTRA_BIN` and `$GENERATOR_BIN` to point to the Astra and generator binaries for your system. See the [install_astra.ipynb](./examples/install_astra.ipynb) example for easy installation.
48 |
49 |
50 | ## Basic usage
51 |
52 | See [basic_astra_examples.ipynb](./examples/basic_astra_examples.ipynb). In short:
53 |
54 | ```python
55 | from astra import Astra
56 |
57 | A = Astra('../templates/Astra.in')
58 |
59 | A.verbose = True
60 | A.run()
61 | ...
62 | output = A.output
63 | ```
64 |
65 |
66 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: lume-astra
2 | repo_url: https://github.com/ChristopherMayes/lume-astra
3 | repo_name: "ChristopherMayes/lume-astra"
4 |
5 | nav:
6 | - Home: index.md
7 | - Examples:
8 | - Basic:
9 | - examples/basic_astra_examples.ipynb
10 | - examples/install_astra.ipynb
11 | - examples/functional_astra_run.ipynb
12 | - examples/simple_distgen_example.ipynb
13 | - examples/simple_generator_example.ipynb
14 | - Interfaces:
15 | - examples/interfaces/astra_to_bmad.ipynb
16 | - examples/interfaces/bmad_to_astra.ipynb
17 | - Low Level:
18 | - examples/scan_example.ipynb
19 | - examples/plot_examples.ipynb
20 | - Elements:
21 | - examples/elements/apex_gun.ipynb
22 | - examples/elements/drift.ipynb
23 | - examples/elements/tws.ipynb
24 | - examples/elements/tesla_9cell_cavity.ipynb
25 |
26 |
27 |
28 |
29 | - API:
30 | - Astra: api/astra.md
31 | - Generator: api/generator.md
32 | - Install: api/install.md
33 |
34 | theme:
35 | icon:
36 | repo: fontawesome/brands/github
37 | name: material
38 |
39 | custom_dir: overrides
40 | features:
41 | - navigation.top
42 | - navigation.tabs
43 | - navigation.indexes
44 | palette:
45 | - media: "(prefers-color-scheme: light)"
46 | scheme: default
47 | primary: black
48 | toggle:
49 | icon: material/toggle-switch-off-outline
50 | name: Switch to dark mode
51 | - media: "(prefers-color-scheme: dark)"
52 | scheme: slate
53 | primary: black
54 | toggle:
55 | icon: material/toggle-switch
56 | name: Switch to light mode
57 |
58 | markdown_extensions:
59 | - pymdownx.highlight
60 | - pymdownx.superfences
61 | - pymdownx.arithmatex: # Enable MathJAX https://squidfunk.github.io/mkdocs-material/reference/mathjax/
62 | generic: true
63 |
64 | extra_javascript:
65 | - javascripts/mathjax.js
66 | - https://polyfill.io/v3/polyfill.min.js?features=es6
67 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
68 |
69 | extra:
70 | generator: false
71 | social:
72 | - icon: fontawesome/brands/github
73 | link: https://github.com/ChristopherMayes/lume-astra
74 | name: LUME-Astra
75 |
76 | extra_css:
77 | - stylesheets/extra.css
78 |
79 | plugins:
80 | - search
81 |
82 | - mkdocs-jupyter:
83 | include_source: True
84 |
85 | - mkdocstrings:
86 | default_handler: python
87 | handlers:
88 | python:
89 | selection:
90 | docstring_style: "numpy"
91 | inherited_members: false
92 | filters:
93 | - "!^_" # exclude all members starting with _
94 | - "^__init__$" # but always include __init__ modules and methods
95 | rendering:
96 | show_source: true
97 | show_root_heading: true
--------------------------------------------------------------------------------
/astra/install.py:
--------------------------------------------------------------------------------
1 | from .tools import make_executable
2 |
3 | import os, sys, platform
4 | import urllib.request
5 |
6 |
7 | ASTRA_URL = {
8 | 'Linux': 'http://www.desy.de/~mpyflo/Astra_for_64_Bit_Linux/Astra',
9 | 'Darwin': 'http://www.desy.de/~mpyflo/Astra_for_Mac_OSX/Astra',
10 | 'Windows': 'http://www.desy.de/~mpyflo/Astra_for_WindowsPC/Astra.exe'
11 | }
12 | GENERATOR_URL = {
13 | 'Linux': 'http://www.desy.de/~mpyflo/Astra_for_64_Bit_Linux/generator',
14 | 'Darwin': 'http://www.desy.de/~mpyflo/Astra_for_Mac_OSX/generator',
15 | 'Windows': 'http://www.desy.de/~mpyflo/Astra_for_WindowsPC/generator.exe'
16 | }
17 |
18 | EXAMPLES_URL = "https://ChristopherMayes.github.io/lume-astra/assets/lume-astra-examples.zip"
19 |
20 |
21 |
22 | def install_executable(url, dest, verbose=False):
23 | """
24 | Downloads a url into a destination and makes it executable. Will not overwrite.
25 | """
26 |
27 | if os.path.exists(dest):
28 | print(os.path.abspath(dest), 'exists, will not overwrite')
29 | else:
30 | if verbose:
31 | print(f'Downloading {url} to {dest}')
32 | urllib.request.urlretrieve(url, dest)
33 | make_executable(dest)
34 |
35 | return dest
36 |
37 | def install_astra(dest_dir=None, name='Astra', verbose=False):
38 | """
39 | Installs Astra from Klaus Floettmann's DESY website for the detected platform.
40 |
41 | Sets environmental variable ASTRA_BIN
42 | """
43 | system = platform.system()
44 | url=ASTRA_URL[system]
45 | dest = os.path.abspath(os.path.join(dest_dir, name))
46 |
47 | install_executable(url, dest, verbose=verbose)
48 |
49 | os.environ['ASTRA_BIN'] = dest
50 |
51 | if verbose:
52 | print(f'Installed Astra in {dest}, and set $ASTRA_BIN equal to this.')
53 |
54 | return dest
55 |
56 | def install_generator(dest_dir=None, name='generator', verbose=False):
57 | """
58 | Installs Astra's generator from Klaus Floettmann's DESY website for the detected platform.
59 |
60 | Sets environmental variable GENERATOR_BIN
61 | """
62 | system = platform.system()
63 | url=GENERATOR_URL[system]
64 | dest = os.path.abspath(os.path.join(dest_dir, name))
65 | install_executable(url, dest, verbose=verbose)
66 |
67 | os.environ['GENERATOR_BIN'] = dest
68 |
69 | if verbose:
70 | print(f'Installed Astra\'s generator in {dest}, and set $GENERATOR_BIN equal to this.')
71 |
72 | return dest
73 |
74 |
75 | def install_examples(location):
76 | """
77 | Install lume-astra examples on the informed location.
78 |
79 | Parameters
80 | ----------
81 | location : str
82 | The folder in which to install the examples
83 | """
84 | import requests
85 | import zipfile
86 | import io
87 | import os
88 |
89 | loc = os.path.expanduser(os.path.expandvars(location))
90 | os.makedirs(loc, exist_ok=True)
91 |
92 | response = requests.get(EXAMPLES_URL, stream=True)
93 | if response.status_code == 200:
94 | zip_file = zipfile.ZipFile(io.BytesIO(response.content))
95 | zip_file.extractall(loc)
96 | else:
97 | raise RuntimeError(f"Could not download examples archive. Status code was: {response.status_code}")
98 |
--------------------------------------------------------------------------------
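A usage sketch for `install_examples` above; the destination path is arbitrary and network access is assumed:

```python
from astra.install import install_examples

# Downloads EXAMPLES_URL and extracts it into the given folder.
install_examples('~/lume-astra-examples')
```

--------------------------------------------------------------------------------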
/astra/tools.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import errno
3 | import os
4 | import subprocess
5 | import sys
6 | import traceback
7 |
8 | def execute(cmd, cwd=None):
9 | """
10 |
11 | Constantly print Subprocess output while process is running
12 | from: https://stackoverflow.com/questions/4417546/constantly-print-subprocess-output-while-process-is-running
13 |
14 | # Example usage:
15 | for path in execute(["locate", "a"]):
16 | print(path, end="")
17 |
18 | Useful in Jupyter notebook
19 |
20 | """
21 | popen = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, universal_newlines=True, cwd=cwd)
22 | if os.name == 'nt':
23 | # When running Astra with Windows it requires us to Press return at the end of execution
24 | popen.stdin.write("\n")
25 | popen.stdin.flush()
26 | popen.stdin.close()
27 | for stdout_line in iter(popen.stdout.readline, ""):
28 | yield stdout_line
29 | popen.stdin.close()
30 | popen.stdout.close()
31 | return_code = popen.wait()
32 | if return_code:
33 | raise subprocess.CalledProcessError(return_code, cmd)
34 |
35 |
36 | # Alternative execute
37 | def execute2(cmd, timeout=None, cwd=None):
38 | """
39 | Execute with time limit (timeout) in seconds, catching run errors.
40 | """
41 |
42 | output = {'error': True, 'log': ''}
43 | try:
44 | p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True,
45 | timeout=timeout, cwd=cwd)
46 | output['log'] = p.stdout
47 | output['error'] = False
48 | output['why_error'] = ''
49 | except subprocess.TimeoutExpired as ex:
50 | output['log'] = ex.stdout + '\n' + str(ex)
51 | output['why_error'] = 'timeout'
52 | except:
53 | #exc_tuple = sys.exc_info()
54 | error_str = traceback.format_exc()
55 | output['log'] = 'unknown run error'
56 | output['why_error'] = error_str
57 | return output
58 |
59 |
60 | def runs_script(runscript=[], dir=None, log_file=None, verbose=True):
61 | """
62 |     Basic driver for running a script in a directory. Prints output as it runs and optionally writes it to log_file.
63 | """
64 |
65 | # Save init dir
66 | init_dir = os.getcwd()
67 |
68 | if dir:
69 | os.chdir(dir)
70 |
71 | log = []
72 |
73 | for path in execute(runscript):
74 | if verbose:
75 | print(path, end="")
76 | log.append(path)
77 | if log_file:
78 | with open(log_file, 'w') as f:
79 | for line in log:
80 | f.write(line)
81 |
82 | # Return to init dir
83 | os.chdir(init_dir)
84 | return log
85 |
86 |
87 | def mkdir_p(path):
88 | try:
89 | os.makedirs(path)
90 | except OSError as exc: # Python >2.5
91 | if exc.errno == errno.EEXIST and os.path.isdir(path):
92 | pass
93 | else:
94 | raise
95 |
96 |
97 | def make_executable(path):
98 | """
99 | https://stackoverflow.com/questions/12791997/how-do-you-do-a-simple-chmod-x-from-within-python
100 | """
101 | mode = os.stat(path).st_mode
102 | mode |= (mode & 0o444) >> 2 # copy R bits to X
103 | os.chmod(path, mode)
104 |
105 |
106 | def native_type(value):
107 | """
108 | Converts a numpy type to a native python type.
109 | See:
110 | https://stackoverflow.com/questions/9452775/converting-numpy-dtypes-to-native-python-types/11389998
111 | """
112 | return getattr(value, 'tolist', lambda: value)()
113 |
114 |
115 | def isotime():
116 | """UTC to ISO 8601 with Local TimeZone information without microsecond"""
117 | return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone().replace(
118 | microsecond=0).isoformat()
119 |
120 |
121 |
122 |
123 |
124 | def make_symlink(src, path):
125 | """
126 | Makes a symlink from a source file `src` into a path.
127 |
128 | Will not overwrite real files.
129 |
130 | Parameters
131 | ----------
132 | src: source filename
133 | path: path to make symlink into
134 |
135 | Returns
136 | -------
137 |     success: bool
138 |
139 | """
140 |
141 | _, file = os.path.split(src)
142 |
143 | dest = os.path.join(path, file)
144 |
145 | # Replace old symlinks.
146 | if os.path.islink(dest):
147 | os.unlink(dest)
148 | elif os.path.exists(dest):
149 | return False
150 |
151 | os.symlink(src, dest)
152 |
153 | return True
154 |
155 |
--------------------------------------------------------------------------------
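A minimal sketch of `execute2`, which never raises for run errors and instead reports them in the returned dict; the command here is arbitrary:

```python
from astra.tools import execute2

res = execute2(['echo', 'hello'], timeout=10)
if res['error']:
    print('run failed:', res['why_error'])
else:
    print(res['log'])  # captured stdout (stderr is merged in)
```

--------------------------------------------------------------------------------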
/astra/generator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | import traceback
5 |
6 | from lume import tools as lumetools
7 | from lume.base import CommandWrapper
8 | from astra import archive
9 | from pmd_beamphysics import ParticleGroup
10 | from pmd_beamphysics.interfaces.astra import parse_astra_phase_file
11 | import h5py
12 |
13 | from astra import parsers, writers, tools
14 |
15 |
16 | class AstraGenerator(CommandWrapper):
17 | """
18 | Class to run Astra's particle generator
19 |
20 |
21 | The file for Astra is written in:
22 | .output_file
23 |
24 | """
25 | COMMAND = "$GENERATOR_BIN"
26 |
27 | def __init__(self, input_file=None, **kwargs):
28 | super().__init__(input_file=input_file, **kwargs)
29 | # Save init
30 |
31 | if isinstance(input_file, str):
32 | self.original_input_file = self.input_file
33 | else:
34 | self.original_input_file = 'generator.in'
35 |
36 | # These will be filled
37 | self.input = {}
38 | self.output = {}
39 |
40 | # Call configure
41 | if self.input_file:
42 | self.load_input(self.input_file)
43 | self.configure()
44 |
45 | def load_input(self, input_filepath, absolute_paths=True, **kwargs):
46 | # Allow dict
47 | if isinstance(input_filepath, dict):
48 | self.input = input_filepath
49 | return
50 |
51 | super().load_input(input_filepath, **kwargs)
52 | if absolute_paths:
53 | parsers.fix_input_paths(self.input, root=self.original_path)
54 | self.input = self.input['input']
55 |
56 |
57 | def input_parser(self, path):
58 | return parsers.parse_astra_input_file(path)
59 |
60 | def configure(self):
61 | # Check that binary exists
62 | self.command = lumetools.full_path(self.command)
63 | self.setup_workdir(self.path)
64 |
65 | self.input_file = os.path.join(self.path, self.original_input_file)
66 |
67 | # We will change directories to work in the local directory
68 | self.input['fname'] = 'generator.part'
69 |
70 | self.configured = True
71 |
72 | def run(self):
73 | """
74 | Runs Generator
75 |
76 | Note: do not use os.chdir
77 | """
78 | self.write_input_file()
79 |
80 | runscript = self.get_run_script()
81 |
82 | try:
83 | res = tools.execute2(runscript, timeout=None, cwd=self.path)
84 | self.log = res['log']
85 |
86 | self.vprint(self.log)
87 |
88 | # This is the file that should be written
89 | if os.path.exists(self.output_file):
90 | self.finished = True
91 | else:
92 | print(f'AstraGenerator.output_file {self.output_file} does not exist.')
93 | print(f'Here is what the current working dir looks like: {os.listdir(self.path)}')
94 | self.load_output()
95 |
96 | except Exception as ex:
97 | print('AstraGenerator.run exception:', traceback.format_exc())
98 | self.error = True
99 | finally:
100 | pass
101 |
102 | @property
103 | def output_file(self):
104 | return os.path.join(self.path, self.input['fname'])
105 |
106 | def load_output(self):
107 | pfile = self.output_file
108 | data = parse_astra_phase_file(pfile)
109 | # Clock time is used when at cathode
110 | data['t'] = data['t_clock']
111 | P = ParticleGroup(data=data)
112 |
113 | self.output['particles'] = P
114 |
115 | def write_input_file(self):
116 | writers.write_namelists({'input': self.input}, self.input_file)
117 |
118 | # Methods from CommandWrapper not implemented here
119 | def archive(self, h5=None):
120 | """
121 | Archive all data to an h5 handle or filename.
122 |
123 | If no file is given, a file based on the fingerprint will be created.
124 |
125 | """
126 | if not h5:
127 | h5 = 'astra_generator_' + self.fingerprint() + '.h5'
128 |
129 | if isinstance(h5, str):
130 | h5 = os.path.expandvars(h5)
131 | g = h5py.File(h5, 'w')
132 | self.vprint(f'Archiving to file {h5}')
133 | else:
134 | # store directly in the given h5 handle
135 | g = h5
136 |
137 | # Write basic attributes
138 | archive.astra_init(g)
139 |
140 | # All input
141 | g2 = g.create_group('generator_input')
142 | for k, v in self.input.items():
143 | g2.attrs[k] = v
144 |
145 | return h5
146 |
147 | def load_archive(self, h5, configure=True):
148 | """
149 | Loads input and output from archived h5 file.
150 | """
151 | if isinstance(h5, str):
152 | h5 = os.path.expandvars(h5)
153 | g = h5py.File(h5, 'r')
154 | else:
155 | g = h5
156 |
157 | attrs = dict(g['generator_input'].attrs)
158 | self.input = {}
159 | for k, v, in attrs.items():
160 | self.input[k] = tools.native_type(v)
161 |
162 | if configure:
163 | self.configure()
164 |
165 |
166 | def plot(self, y=..., x=None, xlim=None, ylim=None, ylim2=None, y2=..., nice=True, include_layout=True, include_labels=False, include_particles=True, include_legend=True, return_figure=False):
167 | return super().plot(y=y, x=x, xlim=xlim, ylim=ylim, ylim2=ylim2, y2=y2, nice=nice, include_layout=include_layout, include_labels=include_labels, include_particles=include_particles, include_legend=include_legend, return_figure=return_figure)
168 |
169 | def write_input(self, input_filename):
170 | return super().write_input(input_filename)
171 |
--------------------------------------------------------------------------------
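A minimal run sketch for `AstraGenerator`; it assumes `$GENERATOR_BIN` is set and that `generator.in` is a valid generator namelist file (both are assumptions):

```python
from astra import AstraGenerator

G = AstraGenerator('generator.in')  # placeholder input file
G.run()
P = G.output['particles']  # a pmd_beamphysics ParticleGroup
```

--------------------------------------------------------------------------------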
/astra/control.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | class ControlGroup:
4 | """
5 | Group elements to control the attributes for a list of elements.
6 |
7 | Based on Bmad's Ovelay and Group elements
8 |
9 | If absolute, the underlying attributes will be set absolutely.
10 |
11 | Othereise, underlying attributes will be set to changes from reference_values.
12 |
13 | If reference values are not given, they will be set when linking elements.
14 |
15 | Otherwise, only changes will be set.
16 |
17 | Optionally, a list of factors can be used
18 |
19 | Example 1:
20 | ELES = {'a':{'x':1}, 'b':{'x':2}}
21 | G = ControlGroup(ele_names=['a', 'b'], var_name='x')
22 | G.link(ELES) # This will set .reference_values = [1, 2]
23 | G['x'] = 3
24 | G.eles
25 | Returns:
26 | {'a': {'x': 4.0}, 'b': {'x': 5.0}}
27 |
28 | Example 2:
29 | ELES = {'a':{'x':1}, 'b':{'x':2}}
30 | G = ControlGroup(ele_names=['a', 'b'], var_name='dx', attributes='x', factors = [1.0, 2.0], absolute=False)
31 | G.link(ELES)
32 | G['dx'] = 3
33 | G.eles
34 | Returns:
35 | {'a': {'x': 4.0}, 'b': {'x': 8.0}})
36 |
37 | """
38 | def __init__(self,
39 | ele_names=[],
40 | var_name=None,
41 | # If underlying attribute is different
42 | attributes=None,
43 | # If factors != 1
44 | factors = None,
45 | reference_values = None,
46 | value=0,
47 | absolute=False #
48 | ):
49 |
50 | # Allow single element
51 | if isinstance(ele_names, str):
52 | ele_names = [ele_names]
53 |
54 | self.ele_names = ele_names # Link these.
55 | self.var_name = var_name
56 |
57 | self.attributes = attributes
58 | self.factors = factors
59 | self.reference_values = None
60 |
61 | self.absolute = absolute
62 |
63 | n_ele = len(self.ele_names)
64 |
65 | if not self.attributes:
66 | self.attributes = n_ele * [self.var_name]
67 | elif isinstance(self.attributes, str):
68 | self.attributes = n_ele * [self.attributes]
69 |
70 | assert len(self.attributes) == n_ele, 'attributes should be a list with the same length as ele_names'
71 |
72 | if not self.factors:
73 | self.factors = n_ele * [1.0]
74 | else:
75 | self.factors = [float(f) for f in self.factors] # Cast to float for YAML safety
76 | assert len(self.factors) == n_ele, 'factors should be a list with the same length as ele_names'
77 |
78 | if reference_values:
79 | self.reference_values = [float(f) for f in self.reference_values]
80 |
81 | self.value = float(value) # Cast to float for YAML safety
82 |
83 | # These need to be linked by the .link function
84 | self.ele_dict=None
85 |
86 | def link(self, ele_dict):
87 | """
88 | Link and ele dict, so that update will work
89 | """
90 | self.ele_dict=ele_dict
91 | # Populate reference values if none were defined
92 | if not self.reference_values:
93 | self.reference_values = self.ele_values
94 |
95 | # call setter
96 | self[self.var_name] = self.value
97 |
98 | @property
99 | def eles(self):
100 | """Return a list of the controlled eles"""
101 | return [self.ele_dict[name] for name in self.ele_names]
102 |
103 | @property
104 | def ele_values(self):
105 | """Returns the underlying element values"""
106 | return [self.ele_dict[ele_name][attrib] for ele_name, attrib in zip(self.ele_names, self.attributes)]
107 |
108 |
109 | def set_absolute(self, key, item):
110 | """
111 | Sets the underlying attributes directly.
112 | """
113 | self.value = item
114 |
115 | for name, attrib, f in zip(self.ele_names, self.attributes, self.factors):
116 | self.ele_dict[name][attrib] = f * self.value
117 |
118 | def set_delta(self, key, item):
119 | """
120 | Sets a change (delta) in the underlying attributes.
121 | """
122 |
123 | self.value = item
124 | for name, attrib, f, ref in zip(self.ele_names, self.attributes, self.factors, self.reference_values):
125 | self.ele_dict[name][attrib] = ref + f * self.value
126 |
127 | def __setitem__(self, key, item):
128 | """
129 | Calls the appropriate set routine: set_absolute or set_delta
130 | """
131 | assert key == self.var_name, f'{key} mismatch var_name: {self.var_name}'
132 |
133 | assert self.eles, 'No eles are linked. Please call .link(eles)'
134 |
135 | if self.absolute:
136 | self.set_absolute(key, item)
137 | else:
138 | self.set_delta(key, item)
139 |
140 | def __getitem__(self, key):
141 | assert key == self.var_name
142 | return self.value
143 |
144 | def __str__(self):
145 |
146 | if self.absolute:
147 | s2 = 'absolute'
148 | else:
149 | s2 = 'changes in'
150 |
151 | s = f'{self.__class__.__name__} of eles {self.ele_names} with variable {self.var_name} controlling {s2} {self.attributes} with factors {self.factors}'
152 | return s
153 |
154 | def dumps(self):
155 | """
156 | Dump the internal data as a JSON string
157 | """
158 | ele_dict = self.__dict__.pop('ele_dict')
159 | d = self.__dict__
160 | s = json.dumps(d)
161 | # Relink
162 | self.ele_dict = ele_dict
163 | return s
164 |
165 | def loads(self, s):
166 | """
167 | Loads from a JSON string. See .dumps()
168 | """
169 | d = json.loads(s)
170 | self.__dict__.update(d)
171 |
172 | def __repr__(self):
173 |
174 | s0 = self.dumps()
175 | s = f'{self.__class__.__name__}(**{s0})'
176 |
177 | return s
178 |
179 |
180 |
--------------------------------------------------------------------------------
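A runnable version of Example 1 from the docstring above, including a JSON round trip via `dumps`/`loads`:

```python
from astra.control import ControlGroup

eles = {'a': {'x': 1}, 'b': {'x': 2}}
G = ControlGroup(ele_names=['a', 'b'], var_name='x')
G.link(eles)   # captures reference_values = [1, 2]
G['x'] = 3     # delta mode: each x becomes reference + 3
print(eles)    # {'a': {'x': 4.0}, 'b': {'x': 5.0}}

s = G.dumps()  # JSON string, e.g. for h5 archiving
G2 = ControlGroup()
G2.loads(s)
G2.link(eles)  # relink before further use
```

--------------------------------------------------------------------------------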
/astra/archive.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pmd_beamphysics import ParticleGroup, pmd_init
4 | from pmd_beamphysics.units import read_dataset_and_unit_h5, write_dataset_and_unit_h5
5 |
6 | from .control import ControlGroup
7 | from .parsers import OutputUnits
8 | from .tools import isotime, native_type
9 |
10 |
11 | def fstr(s):
12 | """
13 | Makes a fixed string for h5 files
14 | """
15 | return np.bytes_(s)
16 |
17 |
18 |
19 |
20 | def astra_init(h5, version=None):
21 | """
22 | Set basic information to an open h5 handle
23 |
24 | """
25 |
26 | if not version:
27 | from astra import __version__
28 | version = __version__
29 |
30 |
31 | d = {
32 | 'dataType':'lume-astra',
33 | 'software':'lume-astra',
34 | 'version':version,
35 | 'date':isotime()
36 | }
37 | for k,v in d.items():
38 | h5.attrs[k] = fstr(v)
39 |
40 |
41 | def opmd_init(h5, basePath='/screen/%T/', particlesPath='./' ):
42 | """
43 | Root attribute initialization.
44 |
45 | h5 should be the root of the file.
46 | """
47 | d = {
48 | 'basePath':basePath,
49 | 'dataType':'openPMD',
50 | 'openPMD':'2.0.0',
51 | 'openPMDextension':'BeamPhysics;SpeciesType',
52 | 'particlesPath':particlesPath
53 | }
54 | for k,v in d.items():
55 | h5.attrs[k] = fstr(v)
56 |
57 |
58 |
59 | #----------------------------
60 | # Searching archives
61 |
62 | def is_astra_archive(h5, key='dataType', value=np.bytes_('lume-astra')):
63 | """
64 | Checks if an h5 handle is a lume-astra archive
65 | """
66 | return key in h5.attrs and h5.attrs[key]==value
67 |
68 |
69 | def find_astra_archives(h5):
70 | """
71 |     Searches an h5 handle, one level deep, for lume-astra archives.
72 | """
73 | if is_astra_archive(h5):
74 | return ['./']
75 | else:
76 | return [g for g in h5 if is_astra_archive(h5[g])]
77 |
78 |
79 | #----------------------------
80 | # input
81 | def write_input_h5(h5, astra_input, name='input'):
82 | """
83 |
84 | Writes astra input to h5.
85 |
86 | astra_input is a dict with dicts
87 |
88 | See: read_input_h5
89 | """
90 | g0 = h5.create_group(name)
91 | for n in astra_input:
92 | namelist = astra_input[n]
93 | g = g0.create_group(n)
94 | for k in namelist:
95 | g.attrs[k] = namelist[k]
96 |
97 | def read_input_h5(h5):
98 | """
99 |     Reads astra input from h5.
100 |
101 | See: write_input_h5
102 | """
103 | d = {}
104 | for g in h5:
105 | d[g] = dict(h5[g].attrs)
106 |
107 | # Convert to native types
108 | for k, v in d[g].items():
109 | d[g][k] = native_type(v)
110 |
111 | return d
112 |
113 | #----------------------------
114 | # fieldmaps
115 | def write_fieldmap_h5(h5, fieldmap_dict, name='fieldmap'):
116 | """
117 | Writes all fieldmaps as simple datasets
118 | """
119 | g = h5.create_group(name)
120 | for k, v in fieldmap_dict.items():
121 | g[k] = v['data']
122 | for k2, a in v['attrs'].items():
123 | g[k].attrs[k2] = a
124 |
125 | def read_fieldmap_h5(h5):
126 | d = {}
127 | for fmap in h5:
128 | d[fmap] = {}
129 | d[fmap]['data'] = h5[fmap][:]
130 |
131 | # Handle legacy fieldmaps without attrs
132 | attrs = dict(h5[fmap].attrs)
133 | if not attrs:
134 | attrs = {'type': 'astra_1d'}
135 | d[fmap]['attrs'] = attrs
136 |
137 | return d
138 |
139 |
140 | #----------------------------
141 | # output
142 | def write_output_h5(h5, astra_output, name='output'):
143 | """
144 | Writes all of astra_output dict to an h5 handle
145 |
146 | """
147 | g = h5.create_group(name)
148 |
149 | for name2 in ['stats', 'other']:
150 | if name2 not in astra_output:
151 | continue
152 |
153 | g2 = g.create_group(name2)
154 | for key, data in astra_output[name2].items():
155 | unit = OutputUnits[key]
156 | write_dataset_and_unit_h5(g2, key, data, unit)
157 | if 'run_info' in astra_output:
158 | for k, v in astra_output['run_info'].items():
159 | g.attrs[k] = v
160 |
161 | write_particles_h5(g, astra_output['particles'], name='particles')
162 |
163 | def read_output_h5(h5):
164 | """
165 | Reads a properly archived astra output and returns a dict that corresponds to Astra.output
166 | """
167 |
168 | o = {}
169 | o['run_info'] = dict(h5.attrs)
170 | for name2 in ['stats', 'other']:
171 | if name2 not in h5:
172 | continue
173 | g = h5[name2]
174 | o[name2] = {}
175 | for key in g:
176 | expected_unit = OutputUnits[key] # expected unit
177 | o[name2][key], _ = read_dataset_and_unit_h5(g[key], expected_unit=expected_unit)
178 |
179 | if 'particles' in h5:
180 | o['particles'] = read_particles_h5(h5['particles'])
181 |
182 | return o
183 |
184 | #------------------------------------------
185 | # control groups
186 |
187 | def write_control_groups_h5(h5, group_data, name='control_groups'):
188 | """
189 | Writes the ControlGroup object data to the attrs in
190 | an h5 group for archiving.
191 |
192 | See: read_control_groups_h5
193 | """
194 | if name:
195 | g = h5.create_group(name)
196 | else:
197 | g = h5
198 |
199 | for name, G in group_data.items():
200 | g.attrs[name] = fstr(G.dumps())
201 |
202 |
203 |
204 | def read_control_groups_h5(h5, verbose=False):
205 | """
206 | Reads ControlGroup object data
207 |
208 | See: write_control_groups_h5
209 | """
210 | group_data = {}
211 | for name in h5.attrs:
212 | dat = h5.attrs[name]
213 | G = ControlGroup()
214 | G.loads(dat)
215 | group_data[name] = G
216 |
217 | if verbose:
218 | print('h5 read control_groups:', name, '=', G)
219 |
220 | return group_data
221 |
222 |
223 | #----------------------------
224 | # particles
225 |
226 | def write_particles_h5(h5, particles, name='particles'):
227 | """
228 | Write all screens to file, simply named by their index
229 |
230 | See: read_particles_h5
231 | """
232 | g = h5.create_group(name)
233 |
234 | # Set base attributes
235 | opmd_init(h5, basePath=name+'/%T/', particlesPath='/' )
236 |
237 | # Loop over screens
238 | for i, particle_group in enumerate(particles):
239 | name = str(i)
240 | particle_group.write(g, name=name)
241 |
242 |
243 | def read_particles_h5(h5):
244 | """
245 | Reads particles from h5
246 |
247 | See: write_particles_h5
248 | """
249 | # This should be a list of '0', '1', etc.
250 | # Cast to int, sort, reform to get the list order correct.
251 | ilist = sorted([int(x) for x in list(h5)])
252 | glist = [str(i) for i in ilist]
253 |
254 | return [ParticleGroup(h5=h5[g]) for g in glist]
255 |
--------------------------------------------------------------------------------
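A round-trip sketch for the input archiving helpers above; the file name and namelist content are arbitrary:

```python
import h5py
from astra import archive

astra_input = {'newrun': {'run': 1, 'head': 'example'},
               'output': {'zstart': 0.0}}

with h5py.File('archive_demo.h5', 'w') as h5:
    archive.write_input_h5(h5, astra_input, name='input')

with h5py.File('archive_demo.h5', 'r') as h5:
    restored = archive.read_input_h5(h5['input'])  # dict of dicts, native types
```

--------------------------------------------------------------------------------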
/astra/astra_distgen.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | from astra import Astra
5 | from . import tools
6 | from .astra import run_astra, recommended_spacecharge_mesh
7 | from .evaluate import default_astra_merit
8 |
9 | from distgen import Generator
10 | from distgen.writers import write_astra
11 | from distgen.tools import update_nested_dict
12 |
13 | from lume import tools as lumetools
14 |
15 | from pmd_beamphysics import ParticleGroup
16 |
17 | from h5py import File
18 |
19 | import json
20 | import os
21 |
22 | def set_astra_and_distgen(astra_input, distgen_input, settings, verbose=False):
23 | """
24 | Searches astra and distgen input for keys in settings, and sets their values to the appropriate input.
25 | """
26 | for k, v in settings.items():
27 | found=False
28 | for nl in astra_input:
29 | if k in astra_input[nl]:
30 | found = True
31 | if verbose:
32 | print(k, 'is in astra', nl)
33 | astra_input[nl][k] = settings[k]
34 |
35 | if not found:
36 | distgen_input = update_nested_dict(distgen_input, {k:v}, verbose=verbose)
37 | #set_nested_dict(distgen_input, k, v)
38 |
39 | return astra_input, distgen_input
40 |
41 | def run_astra_with_distgen(settings=None,
42 | astra_input_file=None,
43 | distgen_input_file=None,
44 | workdir=None,
45 | astra_bin='$ASTRA_BIN',
46 | timeout=2500,
47 | verbose=False,
48 | auto_set_spacecharge_mesh=True):
49 | """
50 | Run Astra with particles generated by distgen.
51 |
52 |     settings: dict with keys that can appear in an Astra input file,
53 |         or distgen keys with the prefix 'distgen:'
54 |
55 | Example usage:
56 | A = run_astra_with_distgen({'lspch':False, 'distgen:n_particle':1000},
57 | astra_input_file='astra.yaml',
58 | distgen_input_file='distgen.yaml',
59 | verbose=True,
60 | timeout=None
61 | )
62 |
63 | """
64 |
65 | # Call simpler evaluation if there is no generator:
66 | if not distgen_input_file:
67 | return run_astra(settings=settings,
68 | astra_input_file=astra_input_file,
69 | workdir=workdir,
70 | command=astra_bin,
71 | timeout=timeout,
72 | verbose=verbose)
73 |
74 |
75 | if verbose:
76 |         print('run_astra_with_distgen')
77 |
78 | # Distgen generator
79 | G = Generator(distgen_input_file, verbose=verbose)
80 |
81 | # Make astra objects
82 | if astra_input_file.endswith('.yaml'):
83 | if verbose:
84 |             print(f'loading Astra as yaml: {astra_input_file}')
85 | A = Astra.from_yaml(astra_input_file)
86 | if workdir:
87 | A.workdir = workdir
88 | A.configure() # again to make sure things are set properly
89 |
90 | else:
91 | A = Astra(command=astra_bin, input_file=astra_input_file, workdir=workdir)
92 |
93 |
94 |
95 | A.timeout=timeout
96 | A.verbose = verbose
97 |
98 | # Special
99 | A.input['newrun']['l_rm_back'] = True # Remove backwards particles
100 |
101 | #
102 | if settings:
103 | for key, val in settings.items():
104 |
105 | found = False
106 | # Check distgen
107 | if key.startswith('distgen:'):
108 | key = key[len('distgen:'):]
109 | if verbose:
110 | print(f'Setting distgen {key} = {val}')
111 | G[key] = val
112 | continue
113 |
114 | # Check for direct settable attribute
115 | if ':' in key:
116 | A[key] = val
117 | continue
118 |
119 | for nl in A.input:
120 | if key in A.input[nl]:
121 | found = True
122 | if verbose:
123 | print(key, 'is in astra', nl)
124 | A.input[nl][key] = val
125 |
126 | if not found:
127 | raise ValueError(f'Key not found: {key}')
128 |
129 | # Attach distgen input. This is non-standard.
130 | A.distgen_input = G.input
131 |
132 | # Run distgen
133 | G.run()
134 | P = G.particles
135 | # Special flag for cathode start
136 | if G['start:type'] == 'cathode':
137 | P.status[:] = -1
138 |
139 | # Attach to Astra object
140 | A.initial_particles = P
141 |
142 | if auto_set_spacecharge_mesh:
143 | n_particles = len(P)
144 | sc_settings = recommended_spacecharge_mesh(n_particles)
145 | A.input['charge'].update(sc_settings)
146 | if verbose:
147 | print('set spacecharge mesh for n_particles:', n_particles, 'to', sc_settings)
148 |
149 | A.run()
150 |
151 | return A
152 |
153 | # Same as run_astra_with_distgen
154 | # Additional options
155 | def evaluate_astra_with_distgen(settings,
156 | astra_input_file=None,
157 | distgen_input_file=None,
158 | workdir=None,
159 | astra_bin='$ASTRA_BIN',
160 | timeout=2500,
161 | verbose=False,
162 | auto_set_spacecharge_mesh=True,
163 | archive_path=None,
164 | merit_f=None):
165 | """
166 | Similar to run_astra_with_distgen, but returns a flat dict of outputs as processed by merit_f.
167 |
168 | If no merit_f is given, a default one will be used. See:
169 | astra.evaluate.default_astra_merit
170 |
171 | Will raise an exception if there is an error.
172 |
173 | """
174 | A = run_astra_with_distgen(settings=settings,
175 | astra_input_file=astra_input_file,
176 | distgen_input_file=distgen_input_file,
177 | workdir=workdir,
178 | astra_bin=astra_bin,
179 | timeout=timeout,
180 | auto_set_spacecharge_mesh=auto_set_spacecharge_mesh,
181 | verbose=verbose)
182 |
183 | if merit_f:
184 | output = merit_f(A)
185 | else:
186 | output = default_astra_merit(A)
187 |
188 | if output['error']:
189 | raise ValueError('run_astra_with_distgen returned error in output')
190 |
191 | #Recreate Generator object for fingerprint, proper archiving
192 | # TODO: make this cleaner
193 | G = Generator(A.distgen_input)
194 |
195 | fingerprint = fingerprint_astra_with_distgen(A, G)
196 | output['fingerprint'] = fingerprint
197 |
198 | if archive_path:
199 | path = lumetools.full_path(archive_path)
200 | assert os.path.exists(path), f'archive path does not exist: {path}'
201 | archive_file = os.path.join(path, fingerprint+'.h5')
202 | output['archive'] = archive_file
203 |
204 | # Call the composite archive method
205 | archive_astra_with_distgen(A, G, archive_file=archive_file)
206 |
207 | return output
208 |
209 |
210 |
211 | def fingerprint_astra_with_distgen(astra_object, distgen_object):
212 | """
213 | Calls fingerprint() of each of these objects
214 | """
215 | f1 = astra_object.fingerprint()
216 | f2 = distgen_object.fingerprint()
217 | d = {'f1':f1, 'f2':f2}
218 | return lumetools.fingerprint(d)
219 |
220 |
221 |
222 | def archive_astra_with_distgen(astra_object,
223 | distgen_object,
224 | archive_file=None,
225 | astra_group ='astra',
226 | distgen_group ='distgen'):
227 | """
228 | Creates a new archive_file (hdf5) with groups for
229 | astra and distgen.
230 |
231 | Calls .archive method of Astra and Distgen objects, into these groups.
232 | """
233 |
234 | h5 = File(archive_file, 'w')
235 |
236 | #fingerprint = tools.fingerprint(astra_object.input.update(distgen.input))
237 |
238 | g = h5.create_group(distgen_group)
239 | distgen_object.archive(g)
240 |
241 | g = h5.create_group(astra_group)
242 | astra_object.archive(g)
243 |
244 | h5.close()
245 |
246 |
247 |
--------------------------------------------------------------------------------
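A usage sketch for `evaluate_astra_with_distgen`; the input file names are placeholders, and the available output keys depend on the run and the merit function:

```python
from astra import evaluate_astra_with_distgen

out = evaluate_astra_with_distgen(
    {'lspch': False, 'distgen:n_particle': 1000},
    astra_input_file='Astra.in',        # placeholder template
    distgen_input_file='distgen.yaml',  # placeholder template
    verbose=False)
print(out['error'], out['fingerprint'])
```

--------------------------------------------------------------------------------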
/astra/evaluate.py:
--------------------------------------------------------------------------------
1 | from astra.astra import run_astra, run_astra_with_generator
2 | from astra.generator import AstraGenerator
3 | from astra.astra_calc import calc_ho_energy_spread
4 | from lume.tools import full_path
5 | from lume import tools as lumetools
6 | import numpy as np
7 | import json
8 | from inspect import getfullargspec
9 | import os
10 | from h5py import File
11 |
12 |
13 |
14 | def end_output_data(output):
15 | """
16 | Some outputs are lists. Get the last item.
17 | """
18 | o = {}
19 | for k in output:
20 | val = output[k]
21 | if isinstance(val, str): # Encode strings
22 | o[k] = val.encode()
23 | elif np.isscalar(val):
24 | o[k]=val
25 | else:
26 | o['end_'+k]=val[-1]
27 |
28 | return o
29 |
30 |
31 |
32 | def default_astra_merit(A):
33 | """
34 | merit function to operate on an evaluated LUME-Astra object A.
35 |
36 | Returns dict of scalar values
37 | """
38 | # Check for error
39 | if A.error:
40 | return {'error':True}
41 | else:
42 | m= {'error':False}
43 |
44 | # Gather output
45 | m.update(end_output_data(A.output['stats']))
46 |
47 | # Return early if no particles found
48 | if not A.particles:
49 | return m
50 |
51 | P = A.particles[-1]
52 |
53 | # Lost particles have status < -6
54 | nlost = len(np.where(P['status'] < -6)[0])
55 | m['end_n_particle_loss'] = nlost
56 |
57 | # Get live only for stat calcs
58 | P = P.where(P.status==1)
59 |
60 | # No live particles
61 | if len(P) == 0:
62 | return {'error':True}
63 |
64 |
65 |
66 | m['end_total_charge'] = P['charge']
67 | m['end_higher_order_energy_spread'] = P['higher_order_energy_spread']
68 | # Old method:
69 | #m['end_higher_order_energy_spread'] = calc_ho_energy_spread( {'t':P['z'], 'Energy':(P['pz'])*1e-3},verbose=False) # eV
70 |
71 | # Remove annoying strings
72 | if 'why_error' in m:
73 | m.pop('why_error')
74 |
75 |
76 | return m
77 |
78 |
79 |
80 |
81 | # Get defaults for **params in evaluate for each type of simulation
82 | def _get_defaults(run_f, extra=None):
83 | spec = getfullargspec(run_f)
84 | d = dict(zip(spec.args, spec.defaults))
85 | d.pop('settings')
86 | if extra:
87 | d.update(extra)
88 | return d
89 |
90 |
91 |
92 |
93 |
94 | def evaluate(settings, simulation='astra', archive_path=None, merit_f=None, **params):
95 | """
96 | Evaluate astra using possible simulations:
97 | 'astra'
98 | 'astra_with_generator'
99 | 'astra_with_distgen'
100 |
101 | Returns a flat dict of outputs.
102 |
103 | If merit_f is provided, this function will be used to form the outputs.
104 |     Otherwise a default function will be applied.
105 |
106 | Will raise an exception if there is an error.
107 |
108 | """
109 |
110 | # Pick simulation to run
111 |
112 | if simulation=='astra':
113 | A = run_astra(settings, **params)
114 |
115 | elif simulation=='astra_with_generator':
116 | A = run_astra_with_generator(settings, **params)
117 |
118 | elif simulation == 'astra_with_distgen':
119 |
120 | # Import here to limit dependency on distgen
121 | from .astra_distgen import run_astra_with_distgen
122 | A = run_astra_with_distgen(settings, **params)
123 |
124 | else:
125 | raise ValueError(f'simulation not recognized: {simulation}')
126 |
127 | if merit_f:
128 | output = merit_f(A)
129 | else:
130 | output = default_astra_merit(A)
131 |
132 | if output['error']:
133 | raise ValueError(f'Error returned from Astra evaluate')
134 |
135 | fingerprint = A.fingerprint()
136 |
137 | output['fingerprint'] = fingerprint
138 |
139 | if archive_path:
140 | path = full_path(archive_path)
141 | assert os.path.exists(path), f'archive path does not exist: {path}'
142 | archive_file = os.path.join(path, fingerprint+'.h5')
143 | A.archive(archive_file)
144 | output['archive'] = archive_file
145 |
146 | return output
147 |
148 |
149 | # Convenience wrappers, and their full options
150 |
151 | # Get all kwargs from run_astra routines. Save these as the complete set of options
152 | EXTRA = {'archive_path':None, 'merit_f':None}
153 | EXTRA2 = {'archive_path':None, 'merit_f':None, 'distgen_input_file':None}
154 | DEFAULTS = {
155 | 'evaluate_astra': _get_defaults(run_astra, EXTRA ),
156 | 'evaluate_astra_with_generator': _get_defaults(run_astra_with_generator, EXTRA) ,
157 | 'evaluate_astra_with_distgen': _get_defaults(run_astra, EXTRA2)
158 | }
159 |
160 |
161 | def evaluate_astra(settings, archive_path=None, merit_f=None, **params):
162 | """
163 | Convenience wrapper. See evaluate.
164 | """
165 | return evaluate(settings, simulation='astra',
166 | archive_path=archive_path, merit_f=merit_f, **params)
167 |
168 |
169 | def old_evaluate_astra_with_generator(settings, archive_path=None, merit_f=None, **params):
170 | """
171 | Convenience wrapper. See evaluate.
172 | """
173 | return evaluate(settings, simulation='astra_with_generator',
174 | archive_path=archive_path, merit_f=merit_f, **params)
175 |
176 |
177 | def evaluate_astra_with_generator(settings,
178 | astra_input_file=None,
179 | generator_input_file=None,
180 | workdir=None,
181 | astra_bin='$ASTRA_BIN',
182 | generator_bin='$GENERATOR_BIN',
183 | timeout=2500,
184 | verbose=False,
185 | auto_set_spacecharge_mesh=True,
186 | archive_path=None,
187 | merit_f=None):
188 | """
189 | Similar to run_astra_with_generator, but returns a flat dict of outputs as processed by merit_f.
190 |
191 | If no merit_f is given, a default one will be used. See:
192 | astra.evaluate.default_astra_merit
193 |
194 | Will raise an exception if there is an error.
195 | """
196 |
197 | A = run_astra_with_generator(settings=settings,
198 | astra_input_file=astra_input_file,
199 | generator_input_file=generator_input_file,
200 | workdir=workdir,
201 | command=astra_bin,
202 | command_generator=generator_bin,
203 | timeout=timeout,
204 | auto_set_spacecharge_mesh=auto_set_spacecharge_mesh,
205 | verbose=verbose)
206 |
207 | if merit_f:
208 | output = merit_f(A)
209 | else:
210 | output = default_astra_merit(A)
211 |
212 | if output['error']:
213 | raise ValueError(f'Error returned from Astra evaluate')
214 |
215 | #Recreate Generator object for fingerprint, proper archiving
216 | # TODO: make this cleaner
217 | G = AstraGenerator(A.generator_input)
218 |
219 | fingerprint = fingerprint_astra_with_generator(A, G)
220 |
221 | output['fingerprint'] = fingerprint
222 |
223 | if archive_path:
224 | path = full_path(archive_path)
225 | assert os.path.exists(path), f'archive path does not exist: {path}'
226 | archive_file = os.path.join(path, fingerprint+'.h5')
227 | output['archive'] = archive_file
228 |
229 | # Call the composite archive method
230 | archive_astra_with_generator(A, G, archive_file=archive_file)
231 |
232 |
233 |
234 | return output
235 |
236 |
237 | def fingerprint_astra_with_generator(astra_object,generator_object):
238 | """
239 | Calls fingerprint() of each of these objects
240 | """
241 | f1 = astra_object.fingerprint()
242 | f2 = generator_object.fingerprint()
243 | d = {'f1':f1, 'f2':f2}
244 | return lumetools.fingerprint(d)
245 |
246 |
247 | def archive_astra_with_generator(astra_object,
248 | generator_object,
249 | archive_file=None,
250 | generator_group ='generator'):
251 | """
252 | Creates a new archive_file (hdf5) with groups for
253 | astra and generator.
254 |
255 | Calls .archive method of Astra and Generator objects, into these groups.
256 | """
257 |
258 | with File(archive_file, 'w') as h5:
259 | astra_object.archive(h5)
260 | generator_object.archive(h5)
261 |
262 |
263 |
264 |
265 |
266 |
267 |
--------------------------------------------------------------------------------
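A behavior sketch for `end_output_data`; the values are made up:

```python
from astra.evaluate import end_output_data

output = {'run_time': 1.2, 'norm_emit_x': [1e-6, 2e-6]}
print(end_output_data(output))
# -> {'run_time': 1.2, 'end_norm_emit_x': 2e-06}
```

--------------------------------------------------------------------------------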
/astra/writers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numbers import Number
3 |
4 | from astra.tools import make_symlink
5 | from astra.fieldmaps import fieldmap3d_filenames
6 | import glob
7 |
8 | import os
9 |
10 | def namelist_lines(namelist_dict, name):
11 | """
12 | Converts namelist dict to output lines, for writing to file.
13 |
14 |     Only scalars, lists, and numpy arrays are allowed.
15 | 
16 |     Other types raise a ValueError, for simplicity.
17 | """
18 | lines = []
19 | lines.append('&'+name)
20 | # parse
21 |
22 |
23 | for key, value in namelist_dict.items():
24 | #if type(value) == type(1) or type(value) == type(1.): # numbers
25 |
26 | if isinstance(value, Number): # numbers
27 | line= key + ' = ' + str(value)
28 | elif type(value) == type([]) or isinstance(value, np.ndarray): # lists or np arrays
29 | liststr = ''
30 | # Special case for dipole Double Complex items
31 | if key.upper() in ('D1', 'D2', 'D3', 'D4'):
32 | for item in value:
33 | liststr = liststr + str(item) + ','
34 | line = key + ' = ' + "(" + liststr[:-1] + ")"
35 | else:
36 | for item in value:
37 | liststr += str(item) + ' '
38 | line = key + ' = ' + liststr
39 | elif type(value) == type('a'): # strings
40 | line = key + ' = ' + "'" + value.strip("''") + "'" # input may need apostrophes
41 |
42 | elif bool(value) == value:
43 | line= key + ' = ' + str(value)
44 | else:
45 | #print 'skipped: key, value = ', key, value
46 | raise ValueError(f'Problem writing input key: {key}, value: {value}, type: {type(value)}')
47 |
48 | lines.append(line)
49 |
50 | lines.append('/')
51 | return lines
52 |
53 |
54 |
55 | def make_namelist_symlinks(namelist, path, prefixes=('file_', 'distribution', 'q_type'), verbose=False):
56 | """
57 | Looks for keys that start with prefixes.
58 | If the value is a path that exists, a symlink will be made.
59 | Old symlinks will be replaced.
60 |
61 | A replacement dict is returned
62 | """
63 |
64 | replacements = {}
65 | for key in namelist:
66 | if any([key.startswith(prefix) for prefix in prefixes]):
67 | src = namelist[key]
68 |
69 | _, file = os.path.split(src)
70 |
71 | # Special for 3D fieldmaps
72 | if file.lower().startswith('3d_'):
73 | flist = fieldmap3d_filenames(src)
74 | for f in flist:
75 | make_symlink(f, path)
76 | elif not os.path.exists(src):
77 | continue
78 | else:
79 | make_symlink(src, path)
80 |
81 | replacements[key] = file
82 |
83 |
84 | return replacements
85 |
86 |
87 |
88 |
89 | def write_namelists(namelists, filePath, make_symlinks=False, prefixes=['file_', 'distribution'], verbose=False):
90 | """
91 | Simple function to write namelist lines to a file
92 |
93 | If make_symlinks, prefixes will be searched for paths and the appropriate links will be made.
94 |     On Windows, make_symlinks is ignored and treated as False. See the note at https://docs.python.org/3/library/os.html#os.symlink .
95 | """
96 |     # On Windows 10, users need Administrator privileges or Developer mode
97 | # in order to be able to create symlinks.
98 | # More info: https://docs.python.org/3/library/os.html#os.symlink
99 | if os.name == 'nt':
100 | make_symlinks = False
101 |
102 | with open(filePath, 'w') as f:
103 | for key in namelists:
104 | namelist = namelists[key]
105 |
106 | if make_symlinks:
107 | # Work on a copy
108 | namelist = namelist.copy()
109 | path, _ = os.path.split(filePath)
110 | replacements = make_namelist_symlinks(namelist, path, prefixes=prefixes, verbose=verbose)
111 | namelist.update(replacements)
112 |
113 |
114 | lines = namelist_lines(namelist, key)
115 | for l in lines:
116 | f.write(l+'\n')
117 |
118 |
119 |
120 |
121 | def fstr(s):
122 | """
123 | Makes a fixed string for h5 files
124 | """
125 | return np.bytes_(s)
126 |
127 |
128 |
129 | def opmd_init(h5, basePath='/screen/%T/', particlesPath='/' ):
130 | """
131 | Root attribute initialization.
132 |
133 | h5 should be the root of the file.
134 | """
135 | d = {
136 | 'basePath':basePath,
137 | 'dataType':'openPMD',
138 | 'openPMD':'2.0.0',
139 | 'openPMDextension':'BeamPhysics;SpeciesType',
140 | 'particlesPath':particlesPath
141 | }
142 | for k,v in d.items():
143 | h5.attrs[k] = fstr(v)
144 |
145 |
146 |
147 | def write_astra_particles_h5(h5, name, astra_data, species='electron'):
148 | """
149 | Write particle data at a screen in openPMD BeamPhysics format
150 | https://github.com/DavidSagan/openPMD-standard/blob/EXT_BeamPhysics/EXT_BeamPhysics.md
151 | """
152 |
153 | g = h5.create_group(name)
154 |
155 | n_particle = len(astra_data['x'])
156 | # Indices of good particles
157 | good = np.where(astra_data['status'] == 5)
158 |
159 | #-----------
160 | # Attributes
161 | g.attrs['speciesType'] = fstr(species)
162 | g.attrs['numParticles'] = n_particle
163 | g.attrs['chargeLive'] = abs(np.sum(astra_data['qmacro'][good])) # Make positive
164 | g.attrs['chargeUnitSI'] = 1
165 | #g.attrs['chargeUnitDimension']=(0., 0., 1, 1., 0., 0., 0.) # Amp*s = Coulomb
166 | g.attrs['totalCharge'] = abs(np.sum(astra_data['qmacro']))
167 |
168 | #---------
169 | # Datasets
170 |
171 | # Position
172 | g['position/x']=astra_data['x'] # in meters
173 | g['position/y']=astra_data['y']
174 | g['position/z']=astra_data['z_rel']
175 | for component in ['position/x', 'position/y', 'position/z', 'position']: # Add units to all components
176 | g[component].attrs['unitSI'] = 1.0
177 | g[component].attrs['unitDimension']=(1., 0., 0., 0., 0., 0., 0.) # m
178 |
179 |
180 | # positionOffset (Constant record)
181 | # Just z
182 | g2 = g.create_group('positionOffset/z')
183 | g2.attrs['value'] = astra_data['z_ref']
84 | g2.attrs['shape'] = (n_particle,)
185 | g2.attrs['unitSI'] = g['position'].attrs['unitSI']
186 | g2.attrs['unitDimension'] = g['position'].attrs['unitDimension']
187 |
188 | # momenta
189 | g['momentum/x']=astra_data['px'] # m*c*gamma*beta_x in eV/c
190 | g['momentum/y']=astra_data['py']
191 | g['momentum/z']=astra_data['pz_rel']
192 | for component in ['momentum/x', 'momentum/y', 'momentum/z', 'momentum']:
193 | g[component].attrs['unitSI']= 5.34428594864784788094e-28 # eV/c in J/(m/s) = kg*m / s
194 | g[component].attrs['unitDimension']=(1., 1., -1., 0., 0., 0., 0.) # kg*m / s
195 |
196 | # momentumOffset (Constant record)
197 | # Just pz
198 | g2 = g.create_group('momentumOffset/z')
199 | g2.attrs['value'] = astra_data['pz_ref']
200 | g2.attrs['shape'] = (n_particle,)
201 | g2.attrs['unitSI'] = g['momentum'].attrs['unitSI']
202 | g2.attrs['unitDimension'] = g['momentum'].attrs['unitDimension']
203 |
204 | # Time
205 | g['time'] = astra_data['t_rel']
206 | g['time'].attrs['unitSI'] = 1.0 # s
207 | g['time'].attrs['unitDimension'] = (0., 0., 1., 0., 0., 0., 0.) # s
208 |
209 | # Time offset (Constant record)
210 | g2 = g.create_group('timeOffset')
211 | g2.attrs['value'] = astra_data['t_ref']
212 | g2.attrs['shape'] = (n_particle,)
213 | g2.attrs['unitSI'] = g['time'].attrs['unitSI']
214 | g2.attrs['unitDimension'] = g['time'].attrs['unitDimension']
215 |
216 | # Weights
217 | g['weight'] = astra_data['qmacro']
218 | g['weight'].attrs['unitSI'] = 1.0
219 | g['weight'].attrs['unitDimension']=(0., 0., 1, 1., 0., 0., 0.) # Amp*s = Coulomb
220 |
221 |
222 | # Status
223 | # The standard defines 1 as a live particle, but astra uses 1 as a 'passive' particle
224 | # and 5 as a 'standard' particle. 2 is not used.
225 | # To preserve this information, make 1->2 and then 5->1
226 |
227 | status = astra_data['status'].copy()
228 | where_1 = np.where(status==1)
229 | where_5 = good # was defined above
230 | status[where_1] = 2
231 | status[where_5] = 1
232 | g['particleStatus'] = status
233 | g['particleStatus'].attrs['unitSI'] = 1.0
234 | g['particleStatus'].attrs['unitDimension']=(0., 0., 0, 0., 0., 0., 0.) # Dimensionless
235 |
236 |
237 |
238 |
239 | def write_screens_h5(h5, astra_screens, name='screen'):
240 | """
241 | Write all screens to file, simply named by their index
242 | """
243 | g = h5.create_group(name)
244 |
245 | # Set base attributes
246 | opmd_init(h5, basePath='/'+name+'/%T/', particlesPath='/' )
247 |
248 | # Loop over screens
249 | for i in range(len(astra_screens)):
250 | name = str(i)
251 | write_astra_particles_h5(g, name, astra_screens[i])
252 |
--------------------------------------------------------------------------------
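A minimal usage sketch for the writer helpers above (an assumption here: these helpers live in astra/writers.py, and a screen dict carries the keys referenced in write_astra_particles_h5; the values below are synthetic):

    import h5py
    import numpy as np
    from astra.writers import write_screens_h5  # assumed module location

    # Synthetic screen dict with the keys used by write_astra_particles_h5
    n = 100
    screen = {
        'x': np.random.normal(0, 1e-3, n), 'y': np.random.normal(0, 1e-3, n),
        'z_rel': np.zeros(n), 'z_ref': 1.0,
        'px': np.zeros(n), 'py': np.zeros(n),
        'pz_rel': np.zeros(n), 'pz_ref': 6e6,
        't_rel': np.zeros(n), 't_ref': 3.3e-9,
        'qmacro': np.full(n, -1e-12),
        'status': np.full(n, 5),  # 5 = 'standard' particle in Astra
    }

    with h5py.File('screens.h5', 'w') as h5:
        write_screens_h5(h5, [screen], name='screen')

--------------------------------------------------------------------------------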
/astra/fieldmaps.py:
--------------------------------------------------------------------------------
1 | """
2 | Tools for loading fieldmap data
3 | """
4 | import numpy as np
5 | import re
6 | import os
7 | import glob
8 |
9 | # Prefix helpers
10 |
11 | POS_PREFIX = {'cavity':'c_pos', 'solenoid':'s_pos'}
12 | def pos_(section='cavity', index=1):
13 | prefix = POS_PREFIX[section]
14 |
15 | return f'{prefix}({index})'
16 |
17 | FILE_PREFIX = {'cavity':'file_efield', 'solenoid':'file_bfield'}
18 |
19 | def file_(section='cavity', index=1):
20 | prefix = FILE_PREFIX[section]
21 |
22 | return f'{prefix}({index})'
23 |
24 | MAX_PREFIX = {'cavity':'maxe', 'solenoid':'maxb'}
25 | def max_(section='cavity', index=1):
26 | prefix = MAX_PREFIX[section]
27 |
28 | return f'{prefix}({index})'
29 |
30 |
31 | def find_fieldmap_ixlist(astra_input, section='cavity'):
32 | """
33 |
34 | Looks for the appropriate file_efield(i) / file_bfield(i) keys and extracts the integer i
35 | """
36 |
37 | # Works for any section with a FILE_PREFIX entry ('cavity' or 'solenoid')
38 | prefix = FILE_PREFIX[section]
39 | ixlist = []
40 | for k in astra_input[section]:
41 | if k.startswith(prefix):
42 | m = re.search(r"\(([0-9]+)\)", k)
43 | ix = int(m.group(1))
44 | ixlist.append(ix)
45 | return ixlist
46 |
47 | def load_fieldmaps(astra_input, search_paths=[], fieldmap_dict={}, sections=['cavity', 'solenoid'], verbose=False, strip_path=False):
48 | """
49 | Loads all found fieldmaps into a dict with the filenames as keys
50 | """
51 | fmap = {}
52 | for sec in sections:
53 |
54 | if sec not in astra_input:
55 | continue
56 |
57 | ixlist = find_fieldmap_ixlist(astra_input, sec)
58 | for ix in ixlist:
59 | k = file_(section=sec, index=ix)
60 | file = astra_input[sec][k]
61 |
62 | # Skip 3D fieldmaps. These are symlinked
63 | if os.path.split(file)[1].lower().startswith('3d_'):
64 | continue
65 |
66 | if file not in fmap:
67 | # Look inside dict
68 | if file in fieldmap_dict:
69 | if verbose:
70 | print(f'Fieldmap inside dict: {file}')
71 | fmap[file] = fieldmap_dict[file]
72 | continue
73 |
74 | if verbose:
75 | print(f'Loading fieldmap file {file}')
76 |
77 | # Look in search path
78 | if not os.path.exists(file):
79 | if verbose:
80 | print(f'{file} not found, searching:')
81 | for path in search_paths:
82 | _, file = os.path.split(file)
83 | tryfile = os.path.join(path, file)
84 |
85 | if os.path.exists(tryfile):
86 | if verbose:
87 | print('Found:', tryfile)
88 | file = tryfile
89 | break
90 |
91 |
92 | # Set input
93 | astra_input[sec][k] = file
94 |
95 | fmap[file] = parse_fieldmap(file)
96 |
97 | # Loop again
98 | if strip_path:
99 | # Make a secondary dict with the shorter names.
100 | # Protect against /path1/dat1, /path2/dat1 overwriting
101 | fmap2 = {}
102 | translate = {}
103 | for k in fmap:
104 | _, k2 = os.path.split(k)
105 | i=0 # Check for uniqueness
106 | while k2 in fmap2:
107 | i+=1
108 | k2 = f'{k2}_{i}'
109 | # Collect translation
110 | translate[k] = k2
111 | fmap2[k2] = fmap[k]
112 |
113 | for sec in sections:
114 | ixlist = find_fieldmap_ixlist(astra_input, sec)
115 | for ix in ixlist:
116 | k = file_(section=sec, index=ix)
117 | file = astra_input[sec][k]
118 | astra_input[sec][k] = translate[file]
119 |
120 | return fmap2
121 |
122 | else:
123 | return fmap
124 |
125 |
126 | def write_fieldmaps(fieldmap_dict, path):
127 | """
128 | Writes fieldmap dict to path
129 |
130 | """
131 | assert os.path.exists(path)
132 |
133 | for k, fmap in fieldmap_dict.items():
134 | file = os.path.join(path, k)
135 |
136 | # Remove any previous symlinks
137 | if os.path.islink(file):
138 | os.unlink(file)
139 |
140 | write_fieldmap(file, fmap)
141 |
142 | def write_fieldmap(fname, fmap):
143 |
144 | attrs = fmap['attrs']
145 | ftype = attrs['type']
146 | if ftype == 'astra_tws':
147 | header = f"{attrs['z1']} {attrs['z2']} {attrs['n']} {attrs['m']}"
148 | np.savetxt(fname, fmap['data'], header=header, comments='')
149 | elif ftype == 'astra_1d':
150 | np.savetxt(fname, fmap['data'])
151 | else:
152 | raise ValueError(f'Unknown fieldmap type: {ftype}')
153 |
154 |
155 | def parse_fieldmap(filePath):
156 | """
157 | Parses 1D fieldmaps, including TWS fieldmaps.
158 |
159 | See p. 70 in the Astra manual for TWS
160 |
161 | Returns a dict of:
162 | attrs
163 | data
164 |
165 | See: write_fieldmap
166 |
167 | """
168 |
169 | with open(filePath) as f: header = list(map(float, f.readline().split()))
170 |
171 | attrs = {}
172 |
173 | if len(header) == 4:
174 | attrs['type'] = 'astra_tws'
175 | attrs['z1'] = header[0]
176 | attrs['z2'] = header[1]
177 | attrs['n'] = int(header[2])
178 | attrs['m'] = int(header[3])
179 | data = np.loadtxt(filePath, skiprows=1)
180 | else:
181 | attrs['type'] = 'astra_1d'
182 | data = np.loadtxt(filePath)
183 |
184 | return dict(attrs=attrs, data=data)
185 |
186 |
187 |
188 |
189 | def fieldmap_data(astra_input, section='cavity', index=1, fieldmaps={}, verbose=False):
190 | """
191 | Loads the fieldmap in absolute coordinates.
192 |
193 | If a fieldmaps dict is given, these will be used instead of loading the file.
194 |
195 | Returns tuple:
196 | attrs, data
197 |
198 | """
199 |
200 | adat = astra_input[section] # convenience pointer
201 |
202 | # Position
203 | k = pos_(section, index)
204 | if k in adat:
205 | offset = adat[k]
206 | else:
207 | offset = 0
208 |
209 | file = adat[file_(section, index)]
210 |
211 | # TODO: 3D fieldmaps
212 | if os.path.split(file)[1].lower().startswith('3d_'):
213 | return None
214 |
215 | # Scaling
216 | k = max_(section, index)
217 | if k in adat:
218 | scale = adat[k]
219 | else:
220 | scale = 1
221 |
222 |
223 | if file in fieldmaps:
224 | fmap = fieldmaps[file].copy()
225 | else:
226 | if verbose: print(f'Loading from file {file}')
227 | fmap = parse_fieldmap(file)
228 |
229 | dat = fmap['data'].copy()
230 |
231 | # TWS special case
232 | # From the manual:
233 | # field map can be n times periodically repeated by specifying C_numb( ) = n.
234 | if section == 'cavity':
235 | # Look for this key
236 | k2 = f'c_numb({index})'
237 | if k2 in astra_input[section]:
238 | n_cell = astra_input[section][k2]
239 |
240 | if n_cell > 1:
241 | zfull, Ezfull = expand_tws_fmap(fmap, n_cell)
242 | dat = np.array([zfull, Ezfull]).T
243 |
244 |
245 | dat[:,0] += offset
246 | dat[:,1] *= scale/max(abs(dat[:,1]))
247 |
248 | return dat
249 |
250 |
251 |
252 |
253 | def expand_tws_fmap(fmap, n_cell):
254 | """
255 | Expands periodic TWS fieldmap data over a number of repeating cells:
256 | |Entrance | Cells | Exit | -> |Entrance | Cells| ...|Cells| Exit |
257 | z0 z1 --- n_repeat ----
258 | Takes care not to overlap points.
259 |
260 | The body of the fieldmap has:
261 | m_cells_in_body = fmap['attrs']['m']
262 | So:
263 | n_repeat = int(n_cell / m_cells_in_body)
264 |
265 | Returns
266 | -------
267 | z, Ez : tuple of arrays
268 |
269 | """
270 |
271 | z0, Ez0 = fmap['data'].T
272 | zmin = z0.min()
273 | zmax = z0.max()
274 | # Beg and end of cell
275 | z1 = fmap['attrs']['z1']
276 | z2 = fmap['attrs']['z2']
277 | m_cells_in_body = fmap['attrs']['m']
278 | # Approximate spacing
279 | dz = np.mean(np.diff(z0))
280 | Lentrance = z1-zmin
281 | Lexit = zmax-z2
282 | Lcell = z2-z1
283 |
284 | #
285 | n_repeat = int(n_cell / m_cells_in_body)
286 |
287 | # Z arrays to be used to construct the full map
288 | zentrance = np.linspace(zmin, z1, int(round(Lentrance/dz+1)))
289 | zcell = np.linspace(z1, z2, int(round(Lcell/dz+1)))
290 | zexit = np.linspace(z2, zmax, int(round(Lexit/dz+1)))
291 |
292 | Ezentrance = np.interp(zentrance, z0, Ez0)
293 | Ezcell = np.interp(zcell, z0, Ez0)
294 | Ezexit = np.interp(zexit, z0, Ez0)
295 |
296 | # Collect data, not overlapping points
297 | ztot = [zentrance[:-1]]
298 | Eztot = [Ezentrance[:-1]]
299 | for i in range(n_repeat):
300 | ztot.append(zcell[:-1] + i*Lcell)
301 | Eztot.append(Ezcell[:-1])
302 | ztot.append(zexit + (n_repeat-1)*Lcell)
303 | Eztot.append(Ezexit)
304 |
305 | return np.concatenate(ztot), np.concatenate(Eztot)
306 |
307 |
308 | def fieldmap3d_filenames(base_filename):
309 | """
310 | Returns a list of existing 3D fieldmap filenames corresponding to a base filename.
311 |
312 | Example:
313 | fieldmap3d_filenames('3D_7500Vchopmap') returns:
314 | ['/abs/path/to/3D_7500Vchopmap.ey', ...]
315 |
316 | """
317 |
318 | _, name = os.path.split(base_filename)
319 | assert name.lower().startswith('3d_')
320 |
321 | flist = glob.glob(base_filename+'.*')
322 |
323 | files = []
324 | for file in flist:
325 | for ext in ['.ex', '.ey', '.ez', '.bx', '.by', '.bz']:
326 | if file.lower().endswith(ext):
327 | files.append(os.path.abspath(file))
328 | return files
329 |
--------------------------------------------------------------------------------
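A short sketch of the fieldmap round trip and the TWS expansion above, using a synthetic map (the file name is hypothetical):

    import numpy as np
    from astra.fieldmaps import parse_fieldmap, write_fieldmap, expand_tws_fmap

    # Synthetic 1D E_z(z): two columns, z (m) and field amplitude
    z = np.linspace(0, 0.5, 101)
    Ez = np.sin(2 * np.pi * z / 0.1)

    fmap = {'attrs': {'type': 'astra_1d'}, 'data': np.array([z, Ez]).T}
    write_fieldmap('example_map.dat', fmap)
    assert parse_fieldmap('example_map.dat')['attrs']['type'] == 'astra_1d'

    # TWS maps carry z1, z2, n, m in attrs; expand the body cell 3 times
    tws = {'attrs': {'type': 'astra_tws', 'z1': 0.1, 'z2': 0.2, 'n': 101, 'm': 1},
           'data': np.array([z, Ez]).T}
    zfull, Ezfull = expand_tws_fmap(tws, n_cell=3)

--------------------------------------------------------------------------------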
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/astra/plot.py:
--------------------------------------------------------------------------------
1 | from astra.fieldmaps import find_fieldmap_ixlist, fieldmap_data, load_fieldmaps
2 | from pmd_beamphysics.units import nice_array, nice_scale_prefix
3 | import numpy as np
4 | import os
5 |
6 |
7 |
8 | from pmd_beamphysics.labels import mathlabel
9 |
10 |
11 | import matplotlib.pyplot as plt
12 |
13 |
14 | # Suggested for notebooks:
15 | #import matplotlib
16 | #matplotlib.rcParams['figure.figsize'] = (16,4)
17 | #%config InlineBackend.figure_format = 'retina'
18 |
19 |
20 | def old_plot_fieldmaps(astra_input, sections=['cavity', 'solenoid'], fieldmap_dict = {}, verbose=False):
21 | """
22 | Plots Cavity and Solenoid fieldmaps.
23 |
24 | TODO: quadrupoles
25 |
26 | """
27 |
28 | if fieldmap_dict:
29 | fmaps = fieldmap_dict
30 | else:
31 | fmaps = load_fieldmaps(astra_input, sections=sections, verbose=verbose)
32 |
33 | assert len(sections) == 2, 'TODO: more general'
34 |
35 | fig, ax0 = plt.subplots()
36 |
37 | # Make RHS axis for the solenoid field.
38 | ax1 = ax0.twinx()
39 | ax = [ax0, ax1]
40 |
41 | ylabel = {'cavity': '$E_z$ (MV/m)', 'solenoid':'$B_z$ (T)'}
42 | color = {'cavity': 'green', 'solenoid':'blue'}
43 |
44 | for i, section in enumerate(sections):
45 | a = ax[i]
46 | ixlist = find_fieldmap_ixlist(astra_input, section)
47 | for ix in ixlist:
48 | dat = fieldmap_data(astra_input, section=section, index=ix, fieldmaps=fmaps, verbose=verbose)
49 | label = f'{section}_{ix}'
50 | c = color[section]
51 | a.plot(*dat.T, label=label, color=c)
52 | a.set_ylabel(ylabel[section])
53 | ax0.set_xlabel('$z$ (m)')
54 |
55 |
56 |
57 | def add_fieldmaps_to_axes(astra_object, axes, bounds=None,
58 | sections=['cavity', 'solenoid'],
59 | include_labels=True):
60 | """
61 | Adds fieldmaps to an axes.
62 |
63 | """
64 |
65 | astra_input = astra_object.input
66 |
67 | verbose=astra_object.verbose
68 |
69 | if astra_object.fieldmap:
70 | fmaps = astra_object.fieldmap
71 | else:
72 | fmaps = load_fieldmaps(astra_input, sections=sections, verbose=verbose)
73 | ax1 = axes
74 |
75 | ax1rhs = ax1.twinx()
76 | ax = [ax1, ax1rhs]
77 |
78 | ylabel = {'cavity': '$E_z$ (MV/m)', 'solenoid':'$B_z$ (T)'}
79 | color = {'cavity': 'green', 'solenoid':'blue'}
80 |
81 | for i, section in enumerate(sections):
82 | if section not in astra_input:
83 | continue
84 |
85 | a = ax[i]
86 | ixlist = find_fieldmap_ixlist(astra_input, section)
87 | for ix in ixlist:
88 | dat = fieldmap_data(astra_input, section=section, index=ix, fieldmaps=fmaps, verbose=verbose)
89 | if dat is None:
90 | continue
91 | label = f'{section}_{ix}'
92 | c = color[section]
93 | a.plot(*dat.T, label=label, color=c)
94 | a.set_ylabel(ylabel[section])
95 | ax1.set_xlabel('$z$ (m)')
96 |
97 | if bounds:
98 | ax1.set_xlim(bounds[0], bounds[1])
99 |
100 |
101 | def plot_fieldmaps(astra_object, include_labels=True, xlim=None, figsize=(12,4), **kwargs):
102 | """
103 | Simple fieldmap plot
104 | """
105 |
106 | fig, axes = plt.subplots(figsize=figsize, **kwargs)
107 |
108 | add_fieldmaps_to_axes(astra_object, axes, bounds=xlim, include_labels=include_labels,
109 | sections=['cavity', 'solenoid'])
110 |
111 |
112 | def plot_stats(astra_object, keys=['norm_emit_x', 'sigma_z'], sections=['cavity', 'solenoid'], fieldmaps = {}, verbose=False, tex=True):
113 | """
114 | Plots stats, with fieldmaps plotted from sections.
115 |
116 | TODO: quadrupoles
117 |
118 | """
119 |
120 | astra_input = astra_object.input
121 |
122 | fmaps = load_fieldmaps(astra_input, sections=sections, verbose=verbose)
123 |
124 | assert len(sections) == 2, 'TODO: more general'
125 |
126 | nplots = len(keys) + 1
127 |
128 | fig, axs = plt.subplots(nplots)
129 |
130 | # Make RHS axis for the solenoid field.
131 |
132 |
133 | xdat = astra_object.stat('mean_z')
134 | xmin = min(xdat)
135 | xmax = max(xdat)
136 | for i, key in enumerate(keys):
137 | ax = axs[i]
138 |
139 |
140 | ydat = astra_object.stat(key)
141 |
142 | ndat, factor, prefix = nice_array(ydat)
143 | unit = astra_object.units(key)
144 | units=f'{prefix}{unit}'
145 | # Handle label
146 | ylabel = mathlabel(key, units=units, tex=tex)
147 | ax.set_ylabel(ylabel)
148 | ax.set_xlim(xmin, xmax)
149 | ax.plot(xdat, ndat)
150 |
151 | add_fieldmaps_to_axes(astra_object, axs[-1], bounds=(xmin, xmax),
152 | sections=['cavity', 'solenoid'],
153 | include_labels=True)
154 |
155 |
156 | def plot_stats_with_layout(astra_object, ykeys=['sigma_x', 'sigma_y'], ykeys2=['sigma_z'],
157 | xkey='mean_z', xlim=None,
158 | ylim=None, ylim2=None,
159 | nice=True,
160 | tex=True,
161 | include_layout=False,
162 | include_labels=True,
163 | include_particles=True,
164 | include_legend=True,
165 | return_figure=False,
166 | **kwargs):
167 | """
168 | Plots stat output multiple keys.
169 |
170 | If a list of ykeys2 is given, these will be put on the right hand axis. This can also be given as a single key.
171 |
172 | Logical switches:
173 | nice: a nice SI prefix and scaling will be used to make the numbers reasonably sized. Default: True
174 |
175 | tex: use mathtext (TeX) for plot labels. Default: True
176 |
177 | include_legend: The plot will include the legend. Default: True
178 |
179 | include_particles: Plot the particle statistics as dots. Default: True
180 |
181 | include_layout: the layout plot will be displayed at the bottom. Default: False
182 |
183 | include_labels: the layout will include element labels. Default: True
184 |
185 | return_figure: return the figure object for further manipulation. Default: False
186 |
187 | Copied almost verbatim from lume-impact's Impact.plot.plot_stats_with_layout
188 |
189 | """
190 | I = astra_object # convenience
191 |
192 | if include_layout:
193 | fig, all_axis = plt.subplots(2, gridspec_kw={'height_ratios': [4, 1]}, **kwargs)
194 | ax_layout = all_axis[-1]
195 | ax_plot = [all_axis[0]]
196 | else:
197 | fig, all_axis = plt.subplots( **kwargs)
198 | ax_plot = [all_axis]
199 |
200 | # collect axes
201 | if isinstance(ykeys, str):
202 | ykeys = [ykeys]
203 |
204 | if ykeys2:
205 | if isinstance(ykeys2, str):
206 | ykeys2 = [ykeys2]
207 | ax_twinx = ax_plot[0].twinx()
208 | ax_plot.append(ax_twinx)
209 |
210 | # No need for a legend if there is only one plot
211 | if len(ykeys)==1 and not ykeys2:
212 | include_legend=False
213 |
214 | #assert xkey == 'mean_z', 'TODO: other x keys'
215 |
216 | X = I.stat(xkey)
217 |
218 | # Only get the data we need
219 | if xlim:
220 | good = np.logical_and(X >= xlim[0], X <= xlim[1])
221 | X = X[good]
222 | else:
223 | xlim = X.min(), X.max()
224 | good = slice(None,None,None) # everything
225 |
226 | # Try particles within these bounds
227 | Pnames = []
228 | X_particles = []
229 |
230 | if include_particles:
231 | try:
232 | for pname in range(len(I.particles)): # Modified from Impact
233 | xp = I.particles[pname][xkey]
234 | if xp >= xlim[0] and xp <= xlim[1]:
235 | Pnames.append(pname)
236 | X_particles.append(xp)
237 | X_particles = np.array(X_particles)
238 | except Exception:
239 | Pnames = []
240 | else:
241 | Pnames = []
242 |
243 | # X axis scaling
244 | units_x = str(I.units(xkey))
245 | if nice:
246 | X, factor_x, prefix_x = nice_array(X)
247 | units_x = prefix_x+units_x
248 | else:
249 | factor_x = 1
250 |
251 | # set all but the layout
252 | for ax in ax_plot:
253 | ax.set_xlim(xlim[0]/factor_x, xlim[1]/factor_x)
254 |
255 | xlabel = mathlabel(xkey, units=units_x, tex=tex)
256 |
257 | ax.set_xlabel(xlabel)
258 |
259 |
260 | # Draw for Y1 and Y2
261 |
262 | linestyles = ['solid','dashed']
263 |
264 | ii = -1 # counter for colors
265 | for ix, keys in enumerate([ykeys, ykeys2]):
266 | if not keys:
267 | continue
268 | ax = ax_plot[ix]
269 | linestyle = linestyles[ix]
270 |
271 | # Check that units are compatible
272 | ulist = [I.units(key) for key in keys]
273 | if len(ulist) > 1:
274 | for u2 in ulist[1:]:
275 | assert ulist[0] == u2, f'Incompatible units: {ulist[0]} and {u2}'
276 | # String representation
277 | units = str(ulist[0])
278 |
279 | # Data
280 | data = [I.stat(key)[good] for key in keys]
281 |
282 |
283 |
284 | if nice:
285 | factor, prefix = nice_scale_prefix(np.ptp(data))
286 | units = prefix+units
287 | else:
288 | factor = 1
289 |
290 | # Make a line and point
291 | for key, dat in zip(keys, data):
292 | #
293 | ii += 1
294 | color = 'C'+str(ii)
295 | label = mathlabel(key, units=units, tex=tex)
296 | ax.plot(X, dat/factor, label=label, color=color, linestyle=linestyle)
297 |
298 | # Particles
299 | if Pnames:
300 | try:
301 | Y_particles = np.array([I.particles[name][key] for name in Pnames])
302 | ax.scatter(X_particles/factor_x, Y_particles/factor, color=color)
303 | except Exception:
304 | pass
305 | ylabel = mathlabel(*keys, units=units, tex=tex)
306 | ax.set_ylabel(ylabel)
307 |
308 | # Set limits, considering the scaling.
309 | if ix==0 and ylim:
310 | new_ylim = np.array(ylim)/factor
311 | ax.set_ylim(new_ylim)
312 | # Set limits for the right-hand axis, considering the scaling.
313 | if ix==1 and ylim2:
314 | new_ylim2 = np.array(ylim2)/factor
315 | ax_twinx.set_ylim(new_ylim2)
316 |
317 |
318 |
319 |
320 |
321 |
322 | # Collect legend
323 | if include_legend:
324 | lines = []
325 | labels = []
326 | for ax in ax_plot:
327 | a, b = ax.get_legend_handles_labels()
328 | lines += a
329 | labels += b
330 | ax_plot[0].legend(lines, labels, loc='best')
331 |
332 | # Layout
333 | if include_layout:
334 |
335 | # Gives some space to the top plot
336 | #ax_layout.set_ylim(-1, 1.5)
337 |
338 | if xkey == 'mean_z':
339 | #ax_layout.set_axis_off()
340 | ax_layout.set_xlim(xlim[0], xlim[1])
341 | else:
342 | ax_layout.set_xlabel(mathlabel('mean_z', units='m'))
343 | xlim = (0, I.stop)
344 | add_fieldmaps_to_axes(I, ax_layout, bounds=xlim, include_labels=include_labels)
345 |
346 | if return_figure:
347 | return fig
--------------------------------------------------------------------------------
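A sketch of the plotting entry points above, assuming a completed run held in an Astra object (the input file name is hypothetical, and the constructor/run signatures follow lume-base conventions, which is an assumption here):

    from astra import Astra
    from astra.plot import plot_fieldmaps, plot_stats_with_layout

    A = Astra(input_file='astra.in')  # assumed constructor signature
    A.run()                           # assumed lume-base style run()

    plot_fieldmaps(A, xlim=(0, 2))
    fig = plot_stats_with_layout(A, ykeys=['norm_emit_x'], ykeys2=['sigma_z'],
                                 include_layout=True, return_figure=True)
    fig.savefig('stats.png')

--------------------------------------------------------------------------------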
/astra/astra_calc.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 |
3 | import sys
4 | import math
5 | import re
6 | import os
7 | import subprocess
8 | import numpy
9 | #import scipy
10 | #import scipy.stats
11 | #import scipy.optimize as sp
12 | from optparse import OptionParser
13 | #import matplotlib.pyplot as plt
14 | import numpy.polynomial.polynomial as poly
15 |
16 |
17 | # ---------------------------------------------------------------------------- #
18 | # Get screen data
19 | # ---------------------------------------------------------------------------- #
20 |
21 | def get_screen_data(phase_import_file, verbose):
22 |
23 | if (verbose):
24 | print('Loading file: ' + phase_import_file)
25 |
26 | # Import the file
27 | phase_data = numpy.loadtxt(phase_import_file)
28 |
29 | if (verbose):
30 | print('Finished loading.')
31 | print('reference: ', phase_data[0])
32 | # Note: assumes the reference particle is at index = 0
33 |
34 | x = phase_data[1:,0]
35 | y = phase_data[1:,1]
36 | z = phase_data[1:,2]
37 | z_ref = phase_data[0,2]
38 | z_rel = z
39 | z = z + z_ref
40 |
41 | px = phase_data[1:,3]
42 | py = phase_data[1:,4]
43 | pz = phase_data[1:,5]
44 | pz_ref = phase_data[0,5]
45 | pz = pz + pz_ref
46 |
47 | qmacro = phase_data[1:,7]
48 | astra_index = phase_data[1:,8]
49 | status = phase_data[1:,9]
50 |
51 | good_particles = numpy.where(status == 5)
52 |
53 | x = x[good_particles]
54 | y = y[good_particles]
55 | z = z[good_particles]
56 | z_rel = z_rel[good_particles]
57 | px = px[good_particles]
58 | py = py[good_particles]
59 | pz = pz[good_particles]
60 | qmacro = qmacro[good_particles]
61 | astra_index = astra_index[good_particles]
62 |
63 | MC = 510998.928 # in eV / c
64 | MC2 = 0.510998928 # in MeV
65 |
66 | GBx = px / MC
67 | GBy = py / MC
68 | GBz = pz / MC
69 | GB2 = GBx**2 + GBy**2 + GBz**2
70 | GB = numpy.sqrt(GB2)
71 | G = numpy.sqrt(GB2 + 1.0)
72 | Energy = G*MC2
73 | Bx = GBx/G
74 | By = GBy/G
75 | Bz = GBz/G
76 |
77 | c = 0.299792458 # in meters / nanosecond
78 |
79 | if numpy.any(Bz <= 0):
80 | if verbose:
81 | print('ERROR: negative or zero velocity detected')
82 | return {'error':True}
83 |
84 | t = -z_rel/(Bz * c)
85 | x = x + (Bx * c)*t
86 | y = y + (By * c)*t
87 |
88 | units={'x':'m','y':'m','z':'m','px':'eV/c','py':'eV/c','pz':'eV/c','t':'ns','Energy':'MeV'}
89 |
90 | screen_data = {}
91 | screen_data['x'] = x
92 | screen_data['y'] = y
93 | screen_data['z'] = z
94 | screen_data['z_rel'] = z_rel
95 | screen_data['px'] = px
96 | screen_data['py'] = py
97 | screen_data['pz'] = pz
98 | screen_data['qmacro'] = qmacro
99 | screen_data['astra_index'] = astra_index
100 | screen_data['MC'] = MC
101 | screen_data['MC2'] = MC2
102 | screen_data['GBx'] = GBx
103 | screen_data['GBy'] = GBy
104 | screen_data['G'] = G
105 | screen_data['GB'] = GB
106 | screen_data['Energy'] = Energy
107 | screen_data['Bx'] = Bx
108 | screen_data['By'] = By
109 | screen_data['Bz'] = Bz
110 | screen_data['c'] = c
111 | screen_data['t'] = t
112 |
113 | screen_data["units"]=units
114 |
115 | screen_data['error'] = False
116 |
117 | return screen_data
118 |
119 | # ---------------------------------------------------------------------------- #
120 | # Compute the 6x6 Sigma (Second Moment) Matrix
121 | # ---------------------------------------------------------------------------- #
122 | def calc_sigma_matrix(screen_data, verbose, variables=None):
123 |
124 | if(variables is None):
125 | variables = ['x','GBx','y','GBy','t','Energy'] # Note: t is arrival time
126 |
127 | # Use relative energy coordinates:
128 | if(variables[5]=="deltaE"):
129 | screen_data["deltaE"]=screen_data["Energy"]/screen_data["Energy"].mean()
130 | elif(variables[5]=="deltaP"):
131 | screen_data["deltaP"]=screen_data["GB"]/screen_data["GB"].mean()
132 |
133 | sigma = numpy.empty(shape=(6,6))
134 | sigma[:] = numpy.nan
135 |
136 | for ii in range(6):
137 | for jj in range(6):
138 |
139 | if(numpy.isnan(sigma[ii,jj])):
140 |
141 | ustr = variables[ii]
142 | vstr = variables[jj]
143 |
144 | sigma[ii,jj] = numpy.mean( (screen_data[ustr]-screen_data[ustr].mean()) * (screen_data[vstr]-screen_data[vstr].mean()) )
145 | sigma[jj,ii] = sigma[ii,jj]
146 |
147 | return sigma
148 |
149 | def calc_avg_norm_Lz(screen_data,verbose,sigma=None):
150 |
151 | if sigma is None:
152 | sigma = calc_sigma_matrix(screen_data,verbose)
153 |
154 | avg_norm_Lz = sigma[0,3] - sigma[2,1]
155 | return avg_norm_Lz
156 |
157 | def calc_uncorr_espread(sigma2x2):
158 |
159 | sig_t = numpy.sqrt(sigma2x2[0,0])
160 | sig_E = numpy.sqrt(sigma2x2[1,1])
161 |
162 | sig_max = max([sig_t,sig_E])
163 | sig_min = min([sig_t,sig_E])
164 |
165 | if(sig_max/sig_min < 10):
166 | print("Warning in calc_uncorr_epsread -> may have bad scaling of phase space variables.")
167 |
168 | # Compute the eigenvalues of 2x2 sigma matrix:
169 | #a = sigma2x2[0,0]
170 | #b = sigma2x2[0,1]
171 | #c = sigma2x2[1,1]
172 | #eps = a*c - b*b
173 |
174 | #lp = 0.5*( (a+c) + numpy.sqrt( (a+c)*(a+c) - 4*eps*eps) )
175 | #lm = 0.5*( (a+c) - numpy.sqrt( (a+c)*(a+c) - 4*eps*eps) )
176 |
177 | eigs = numpy.linalg.eig(sigma2x2)
178 | ls = eigs[0]
179 | lp = max(ls)
180 | lm = min(ls)
181 |
182 | if(sig_t > sig_E):
183 | sig_E_uncorr = math.sqrt(lm)
184 | else:
185 | sig_E_uncorr = math.sqrt(lp)
186 |
187 | return sig_E_uncorr
188 |
189 | # ---------------------------------------------------------------------------- #
190 | # Fits a quadratic to the Energy vs. time, subtracts it, finds the rms of the residual in keV
191 | # ---------------------------------------------------------------------------- #
192 |
193 | def calc_ho_energy_spread(screen_data, verbose):
194 |
195 | Energy = screen_data["Energy"]
196 | t = screen_data["t"]
197 |
198 | # Calculate higher order energy spread
199 | best_fit_coeffs = poly.polyfit(t, Energy, 2)
200 | best_fit = poly.polyval(t, best_fit_coeffs)
201 |
202 | #if (verbose):
203 | #t_plot = numpy.linspace(min(t), max(t), 100)
204 | #Energy_plot = poly.polyval(t_plot, best_fit_coeffs)
205 | #plt.plot(t, Energy, '.', t_plot, Energy_plot, '-')
206 | #plt.show()
207 |
208 | Energy_higher_order = Energy - best_fit
209 |
210 | energy_ho_rms = numpy.std(Energy_higher_order)*1000.0 # in keV
211 |
212 | if (verbose):
213 | print("Energy rms (higher order) = " + str(energy_ho_rms) )
214 |
215 | return energy_ho_rms
216 |
217 | # ---------------------------------------------------------------------------- #
218 | # Returns peak current in amps
219 | # ---------------------------------------------------------------------------- #
220 |
221 | def calc_peak_current(screen_data, verbose):
222 | t = screen_data["t"]
223 | qmacro = screen_data["qmacro"] # handles cases where macro particle charge varies
224 |
225 | # Calculate peak current
226 |
227 | hist_bins = 20 # number of histogram bins. 2000 particles -> 20 bins
228 |
229 | hist_edges = numpy.linspace(min(t), max(t), hist_bins)
230 | dt = hist_edges[1] - hist_edges[0]
231 | hist_edges = numpy.linspace(min(t)-0.5*dt, max(t)+0.5*dt, hist_bins+1)
232 | dt = hist_edges[1] - hist_edges[0]
233 |
234 | bin_index = numpy.searchsorted(hist_edges,t, "right")
235 | current = numpy.bincount(bin_index-1, weights=numpy.abs(qmacro))/dt
236 |
237 | peak_current = max(current)
238 |
239 | if (verbose):
240 | print("Peak current = " + str(peak_current) + " Amps")
241 |
242 | #if (verbose):
243 | #t_plot = hist_edges[0:-1] + 0.5*dt
244 | #print str(len(t_plot)) + " " + str(len(current))
245 | #plt.plot(t_plot, current, '-')
246 | #plt.show()
247 |
248 | return peak_current
249 |
250 | # ---------------------------------------------------------------------------- #
251 | # Returns skewness of the temporal distribution (see wikipedia on 'skewness')
252 | # ---------------------------------------------------------------------------- #
253 |
254 | def calc_skewness(screen_data, verbose):
255 | t = screen_data["t"]
256 |
257 | skewness = 0 ###FIXME numpy.abs(scipy.stats.skew(t))
258 |
259 | if (verbose):
260 | print("Skewness = " + str(skewness) + " (unitless)")
261 |
262 | return skewness
263 |
264 | # ---------------------------------------------------------------------------- #
265 | # Main function
266 | # ---------------------------------------------------------------------------- #
267 |
268 | def main():
269 |
270 | parser = OptionParser()
271 | parser.add_option("-v", "--verbose",
272 | action="store_true", dest="verbose", default=False,
273 | help="don't print status messages to stdout")
274 | parser.add_option("-n", "--noastra",
275 | action="store_true", dest="noastra", default=False,
276 | help="don't run astra")
277 | parser.add_option("-a", "--astra", dest="astra_name", default="astra",
278 | help="name of astra executable", metavar="FILE")
279 |
280 | (options, args) = parser.parse_args()
281 |
282 | verbose = options.verbose
283 | noastra = options.noastra
284 | astra_name = options.astra_name
285 |
286 | if (verbose):
287 | print("")
288 | print("Beginning Astra wrapper...")
289 |
290 | # Interpret input arguments
291 | path_to_input_file = args[0]
292 |
293 | # Directory / Filenames
294 |
295 | script_directory = os.path.dirname(sys.argv[0]) + '/'
296 | input_directory = os.path.dirname(path_to_input_file) + '/'
297 | script_name = os.path.basename(sys.argv[0])
298 | input_name = os.path.basename(path_to_input_file)
299 | astra_binary = script_directory + astra_name
300 | input_basename = input_name.replace('.in', '')
301 | merit_name = input_basename + '.merit'
302 |
303 | # Call Astra
304 |
305 | if not noastra:
306 | command = astra_binary + ' ' + path_to_input_file
307 | subprocess.call(command.split())
308 |
309 |
310 | # Get Output
311 |
312 | energy_ho_rms = 0 # Default good value
313 | peak_current = 3e14 # Default good value
314 | skewness = 0
315 |
316 | phase_import_file = get_final_screen_name(input_directory, input_basename)
317 |
318 | if (len(phase_import_file) > 0):
319 |
320 | screen_data = get_screen_data(phase_import_file, verbose)
321 |
322 | energy_ho_rms = calc_ho_energy_spread(screen_data, verbose)
323 |
324 | peak_current = calc_peak_current(screen_data, verbose)
325 |
326 | skewness = calc_skewness(screen_data, verbose)
327 |
328 | variables = ['x','GBx','y','GBy','t','deltaE']
329 |
330 | sigma = calc_sigma_matrix(screen_data, verbose,variables)
331 | avgLz = calc_avg_norm_Lz(None,verbose,sigma)
332 | uncorr_espread = calc_uncorr_espread(sigma[4:,4:])
333 |
334 | for ii in range(6):
335 | for jj in range(6):
336 | if(sigma[ii,jj]!=sigma[jj,ii]):
337 | print("sigma matrix is not symetric!")
338 |
339 | else:
340 | # File did not exist, so something got screwed up. Output a bad number
341 | energy_ho_rms = 3e14
342 | peak_current = 0.0
343 | skewness = 3e14
344 |
345 | # Output merit file
346 |
347 | merit = []
348 |
349 | merit.append(energy_ho_rms)
350 | merit.append(peak_current)
351 | merit.append(skewness)
352 |
353 | output_file = open(input_directory + merit_name, "w")
354 | for value in merit:
355 | output_file.write( str(value) + " ")
356 | output_file.write('\n')
357 | output_file.close()
358 |
359 | # Test function for calculating uncorrelated energy spread:
360 | sigT = 10
361 | sigE = 1e-4
362 | npart = 100000
363 |
364 | t = numpy.random.normal(0,sigT,(1,npart))
365 | E = numpy.random.normal(0,sigE,(1,npart))
366 |
367 | theta = 35*(math.pi/180)
368 | C = math.cos(theta)
369 | S = math.sin(theta)
370 |
371 | tr = C*t + S*E
372 | Er = -S*t + C*E
373 |
374 | dt = tr-tr.mean()
375 | dE = Er-Er.mean()
376 |
377 | t2 = numpy.mean(dt*dt)
378 | E2 = numpy.mean(dE*dE)
379 | tE = numpy.mean(dt*dE)
380 |
381 | sigma2x2 = numpy.array([[t2, tE],[tE, E2]])
382 | print(sigma2x2)
383 | espread = calc_uncorr_espread(sigma2x2)
384 | print( (espread-sigE)/sigE )
385 |
386 | # ---------------------------------------------------------------------------- #
387 | # This allows the main function to be at the beginning of the file
388 | # ---------------------------------------------------------------------------- #
389 |
390 | if __name__ == '__main__':
391 | main()
392 |
393 |
--------------------------------------------------------------------------------
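A self-contained sketch of the sigma-matrix helpers on synthetic data, mirroring the test at the bottom of main() above:

    import numpy
    from astra.astra_calc import calc_sigma_matrix, calc_uncorr_espread

    n = 50000
    screen_data = {
        'x':      numpy.random.normal(0, 1e-3, n),
        'GBx':    numpy.random.normal(0, 1e-4, n),
        'y':      numpy.random.normal(0, 1e-3, n),
        'GBy':    numpy.random.normal(0, 1e-4, n),
        't':      numpy.random.normal(0, 10.0, n),    # arrival time, ns
        'Energy': numpy.random.normal(5.0, 1e-3, n),  # MeV
    }

    sigma = calc_sigma_matrix(screen_data, verbose=False)
    sig_E_uncorr = calc_uncorr_espread(sigma[4:, 4:])  # t-Energy block

--------------------------------------------------------------------------------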
/astra/interfaces/bmad.py:
--------------------------------------------------------------------------------
1 |
2 | from pmd_beamphysics import FieldMesh
3 | from pmd_beamphysics.fields.analysis import accelerating_voltage_and_phase
4 | from astra.parsers import find_max_pos
5 | import numpy as np
6 | import os
7 |
8 |
9 | def bmad_cavity(astra_object, ix,
10 | name='CAV{ix}',
11 | keyword='lcavity',
12 | superimpose=True,
13 | ele_origin='center',
14 | ref_offset = 0):
15 | """
16 |
17 | Parameters
18 | ----------
19 |
20 | astra_object : Astra class
21 |
22 | name : str
23 | Name of the element. This can be a format string containing {ix},
24 | e.g. the default 'CAV{ix}', which is filled in with the cavity index.
25 | ix : int
26 | Cavity element index in the astra input, as in `c_pos(3)` for ix=3
27 |
28 |
29 | keyword : str
30 | Element keyword
31 | 'e_gun' : make a Bmad e_gun element
32 | 'lcavity' : make a Bmad lcavity element
33 | 'some_existing_ele' : use a pre-defined Bmad element
34 |
35 |
36 | superimpose : bool
37 | Use superposition in Bmad to place the element. Default = True
38 |
39 | ref_offset : float
40 | Additional offset to be used for superposition
41 |
42 | Returns
43 | -------
44 |
45 | line : str
46 | Bmad line text (may contain \\n)
47 |
48 |
49 | """
50 | cav = astra_object.input['cavity']
51 |
52 | pos = cav[f'c_pos({ix})'] #m
53 | freq = cav[f'nue({ix})'] * 1e9 # Hz
54 | emax = cav[f'maxe({ix})']*1e6 # V/m
55 | _, fieldmap = os.path.split(cav[f'file_efield({ix})'])
56 | fmap = astra_object.fieldmap[fieldmap]
57 | phi0 = cav[f'phi({ix})']/360
58 |
59 | # v=c voltage estimate; not accurate in injectors, where beta < 1
60 | z0 = fmap['data'][:,0]
61 | Ez0 = fmap['data'][:,1]
62 | Ez0 = Ez0/np.abs(Ez0).max() * emax # Normalize
63 | voltage, _ = accelerating_voltage_and_phase(z0, Ez0, freq)
64 |
65 | z_start = pos + z0.min()
66 | L = round(np.ptp(z0), 9)
67 |
68 | # Fill in name
69 | name = name.format(ix=ix)
70 |
71 | # Attributes
72 | attrs = dict(
73 | rf_frequency = freq,
74 | phi0 = phi0,
75 | )
76 | if emax == 0:
77 | attrs['voltage'] = 0
78 | else:
79 | attrs['autoscale_amplitude'] = False
80 | attrs['field_autoscale'] = emax
81 |
82 | if keyword in ['e_gun', 'lcavity']:
83 | attrs["L"] = L
84 |
85 | if superimpose:
86 | # Superimpose at the absolute fieldmap position
87 | attrs['offset'] = ref_offset + pos
88 | attrs['superimpose'] = True
89 | attrs['ele_origin'] = ele_origin
90 |
91 |
92 |
93 | dat = {
94 | 'attrs': attrs,
95 | 'ele_key': keyword,
96 | 'ele_name': name,
97 | }
98 |
99 | dat['line'] = ele_line(dat)
100 |
101 | return dat
102 |
103 |
104 | def bmad_solenoid(astra_object, ix,
105 | name='SOL{ix}',
106 | keyword='solenoid',
107 | superimpose=True,
108 | ele_origin='center',
109 | ref_offset = 0):
110 | """
111 | Creates a Bmad solenoid element from the Astra solenoid entry with index ix.
112 | Parameters mirror those of bmad_cavity.
113 | Returns
114 | -------
115 | dat : dict
116 | With keys 'attrs', 'ele_key', 'ele_name', 'info', and 'line'.
117 | """
118 | sol = astra_object.input['solenoid']
119 |
120 | pos = sol[f's_pos({ix})'] #m
121 | bmax = sol[f'maxb({ix})'] # T
122 | _, fieldmap = os.path.split(sol[f'file_bfield({ix})'])
123 | fmap = astra_object.fieldmap[fieldmap]
124 |
125 | z0 = fmap['data'][:,0]
126 | Bz0 = fmap['data'][:,1]
127 | Bz0 = Bz0/Bz0.max()
128 |
129 | BL = np.trapz(Bz0, z0)
130 | B2L = np.trapz(Bz0**2, z0)
131 |
132 | L_hard = BL**2/B2L
133 | B_hard = B2L/BL * bmax
134 |
135 | z_start = pos + z0.min()
136 | L = round(np.ptp(z0), 9)
137 |
138 | # Prevent wrapping (will fix in Bmad in the future)
139 | if z_start < 0:
140 | L += z_start # shorten by the clipped (negative) start
141 | z_start = 0
142 |
143 | # Fill in name
144 | name = name.format(ix=ix)
145 |
146 |
147 | attrs = {}
148 | if keyword == 'solenoid':
149 | attrs['L'] = L
150 | else:
151 | attrs['bs_field'] = bmax
152 |
153 | if superimpose:
154 | attrs['superimpose'] = True
155 | attrs['ele_origin'] = ele_origin
156 | attrs['offset'] = ref_offset + pos
157 |
158 | info = f"""! B_max = {bmax} T
159 | ! \int B dL = B_max * {BL} m
160 | ! \int B^2 dL = B_max^2 * {B2L} m
161 | ! Hard edge L = {L_hard} m
162 | ! Hard edge B = {B_hard} T"""
163 |
164 |
165 | dat = {
166 | 'attrs': attrs,
167 | 'ele_key': keyword,
168 | 'ele_name': name,
169 | 'info': info
170 | }
171 |
172 | dat['line'] = ele_line(dat)
173 |
174 | return dat
175 |
176 |
177 | def ele_line(ele_dat):
178 | lines = []
179 | lines.append(f'{ele_dat["ele_name"]}: {ele_dat["ele_key"]}')
180 | if "info" in ele_dat:
181 | lines.append(ele_dat["info"])
182 | for k, v in ele_dat["attrs"].items():
183 | lines.append(f" {k} = {v}")
184 | line = ',\n '.join(lines)
185 | return line
186 |
187 |
188 |
189 |
190 | # ------------------
191 | # PyTao
192 |
193 | def ele_info(tao, ele_id):
194 | """
195 | Returns a dict of element attributes from ele_head and ele_gen_attribs
196 | """
197 | edat = tao.ele_head(ele_id)
198 | edat.update(tao.ele_gen_attribs(ele_id))
199 | s = edat['s']
200 | L = edat['L']
201 | edat['s_begin'] = s-L
202 | edat['s_center'] = (s + edat['s_begin'])/2
203 |
204 | return edat
205 |
206 | def tao_create_astra_fieldmap_ele(tao,
207 | ele_id,
208 | *,
209 | cache=None):
210 |
211 | # Ele info from Tao
212 | edat = ele_info(tao, ele_id)
213 | ix_ele = edat['ix_ele']
214 | ele_key = edat['key'].upper()
215 |
216 | # FieldMesh
217 | grid_params = tao.ele_grid_field(ix_ele, 1, 'base', as_dict=False)
218 | field_file = grid_params['file'].value
219 |
220 | short_fieldmap_name = edat['name']+'_fieldmap.dat'
221 | if cache is None:
222 | field_mesh = FieldMesh(field_file)
223 | # Convert to 1D fieldmap
224 | fieldmap = field_mesh.to_astra_1d()
225 | fieldmap['attrs']['eleAnchorPt'] = field_mesh.attrs['eleAnchorPt']
226 | else:
227 | # Check for existence
228 | if field_file in cache:
229 | # Already found
230 | fieldmap = cache[field_file]
231 | short_fieldmap_name = cache['short_fieldmap_name:'+field_file]
232 | else:
233 | # New fieldmap, add to cache
234 | field_mesh = FieldMesh(field_file)
235 | fieldmap = field_mesh.to_astra_1d()
236 | # Add anchor
237 | fieldmap['attrs']['eleAnchorPt'] = field_mesh.attrs['eleAnchorPt']
238 | cache[field_file] = fieldmap
239 | cache['short_fieldmap_name:'+field_file] = short_fieldmap_name
240 |
241 |
242 | eleAnchorPt = fieldmap['attrs']['eleAnchorPt']
243 |
244 | # Frequency
245 | freq = edat.get('RF_FREQUENCY', 0)
246 | #assert np.allclose(freq, field_mesh.frequency), f'{freq} != {field_mesh.frequency}'
247 |
248 | # Master parameter
249 | master_parameter = grid_params['master_parameter'].value
250 | if master_parameter == '':
251 | master_parameter = None
252 |
253 | # Find z_anchor
254 | if eleAnchorPt == 'beginning':
255 | z_anchor = edat['s_begin']
256 | elif eleAnchorPt == 'center':
257 | z_anchor = edat['s_center']
258 | else:
259 | raise NotImplementedError(f'{eleAnchorPt} not implemented')
260 |
261 | # Phase and scale
262 | if ele_key == 'SOLENOID':
263 | assert master_parameter is not None
264 | scale = edat[master_parameter]
265 |
266 | bfactor = np.abs(fieldmap['data'][:,1]).max()
267 | if not np.isclose(bfactor, 1):
268 | scale *= bfactor
269 |
270 | astra_ele = {
271 | 'astra_type': 'solenoid',
272 | 'file_bfield': short_fieldmap_name,
273 | 's_pos': z_anchor,
274 | 'maxb': scale,
275 | 's_xoff': edat['X_OFFSET'],
276 | 's_yoff': edat['Y_OFFSET'],
277 | 's_smooth': 0,
278 | 's_higher_order': True}
279 |
280 |
281 | elif ele_key in ('E_GUN', 'LCAVITY'):
282 | if master_parameter is None:
283 | scale = edat['FIELD_AUTOSCALE']
284 | else:
285 | scale = edat[master_parameter]
286 |
287 | efactor = np.abs(fieldmap['data'][:,1]).max()
288 | if not np.isclose(efactor, 1):
289 | scale *= efactor
290 |
291 | phi0_user = sum([edat['PHI0'], edat['PHI0_ERR'] ])
292 |
293 | astra_ele = {
294 | 'astra_type': 'cavity', # This will be later extracted
295 | 'file_efield': short_fieldmap_name,
296 | 'c_pos': z_anchor,
297 | 'maxe': -scale/1e6,
298 | 'nue': freq/1e9,
299 | 'phi': phi0_user * 360,
300 | 'c_xoff': edat['X_OFFSET'],
301 | 'c_yoff': edat['Y_OFFSET'],
302 | 'c_smooth': 0,
303 | 'c_higher_order': True}
304 |
305 | return astra_ele, {short_fieldmap_name: fieldmap}
306 |
307 | def tao_create_astra_quadrupole_ele(tao,
308 | ele_id):
309 |
310 | edat = ele_info(tao, ele_id)
311 |
312 | astra_ele = {
313 | 'q_pos': edat['s_center'],
314 | 'q_grad': -edat['B1_GRADIENT'],
315 | 'q_length': edat['L'],
316 | }
317 |
318 | return astra_ele
319 |
320 |
321 | def tao_create_astra_lattice_and_fieldmaps(tao,
322 | fieldmap_eles='E_GUN::*,SOLENOID::*,LCAVITY::*',
323 | quadrupole_eles = 'quad::*',
324 | ):
325 | """
326 | Create LUME-Astra style lattice (input namelists) and fieldmaps from a PyTao Tao instance.
327 |
328 |
329 | Parameters
330 | ----------
331 | tao: Tao object
332 |
333 | fieldmap_eles: str, default = 'E_GUN::*,SOLENOID::*,LCAVITY::*'
334 | Bmad match string to find fieldmap elements
335 |
336 |
337 | Returns
338 | -------
339 | dict with of dict with keys:
340 | 'cavity'
341 | 'solenoid'
342 | 'fieldmap'
343 |
344 | """
345 |
346 | # Extract elements to use
347 | ele_ixs = tao.lat_list(fieldmap_eles, 'ele.ix_ele', flags='-array_out -no_slaves')
348 |
349 | # Form lattice and fieldmaps
350 | cache = {}
351 | fieldmaps = {}
352 | cavity = []
353 | solenoid = []
354 | for ix_ele in ele_ixs:
355 | astra_ele, fieldmap_dict = tao_create_astra_fieldmap_ele(tao,
356 | ele_id=ix_ele,
357 | cache=cache)
358 |
359 | fieldmaps.update(fieldmap_dict)
360 |
361 | astra_type = astra_ele.pop('astra_type')
362 | if astra_type == 'cavity':
363 | cavity.append(astra_ele)
364 | elif astra_type == 'solenoid':
365 | solenoid.append(astra_ele)
366 |
367 |
368 | # Quadrupoles
369 | quad_ix_eles = tao.lat_list(quadrupole_eles, 'ele.ix_ele', flags='-array_out -no_slaves')
370 | quadrupole = []
371 | for ele_id in quad_ix_eles:
372 | quadrupole.append(tao_create_astra_quadrupole_ele(tao, ele_id))
373 |
374 | # convert to dicts
375 | cavity_dict = {'lefield':True}
376 | for ix, ele in enumerate(cavity):
377 | for key in ele:
378 | cavity_dict[f"{key}({ix+1})"] = ele[key]
379 |
380 | solenoid_dict = {'lbfield':True}
381 | for ix, ele in enumerate(solenoid):
382 | for key in ele:
383 | solenoid_dict[f"{key}({ix+1})"] = ele[key]
384 |
385 | quadrupole_dict = {'lquad':True}
386 | for ix, ele in enumerate(quadrupole):
387 | for key in ele:
388 | quadrupole_dict[f"{key}({ix+1})"] = ele[key]
389 |
390 | return {'cavity': cavity_dict,
391 | 'solenoid': solenoid_dict,
392 | 'quadrupole': quadrupole_dict,
393 | 'fieldmap': fieldmaps}
394 |
395 |
396 | def astra_from_tao(tao, cls=None):
397 | """
398 | Create a complete Astra object from a running Pytao Tao instance.
399 |
400 | Parameters
401 | ----------
402 | tao: Tao object
403 |
404 | Returns
405 | -------
406 | astra_object: Astra
407 | Converted Astra object
408 | """
409 |
410 | # Create blank object
411 | if cls is None:
412 | from astra import Astra as cls
413 | A = cls() # This has some defaults.
414 |
415 | # Check for cathode start
416 | if len(tao.lat_list('e_gun::*', 'ele.ix_ele')) > 0:
417 | cathode_start = True
418 | else:
419 | cathode_start = False
420 |
421 | # Special settings for cathode start.
422 | # TODO: pass these in more elegantly.
423 | if cathode_start:
424 | A.input['output']['cathodes'] = True
425 | A.input['charge']['lmirror'] = True
426 | else:
427 | A.input['output']['cathodes'] = False
428 |
429 | # Get elements
430 | res = tao_create_astra_lattice_and_fieldmaps(tao)
431 | A.fieldmap.update(res.pop('fieldmap'))
432 | A.input.update(res)
433 |
434 | # Update zstop
435 | zmax = find_max_pos(A.input)
436 | A.input['output']['zstop'] = zmax + 1 # Some padding
437 |
438 | return A
--------------------------------------------------------------------------------
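A sketch of driving the conversion from a live Tao session; assumes pytao is installed and 'tao.init' points at a Bmad injector lattice (a hypothetical file name):

    from pytao import Tao
    from astra.interfaces.bmad import astra_from_tao

    tao = Tao('-init tao.init -noplot')  # pytao Tao session
    A = astra_from_tao(tao)              # builds input namelists and fieldmaps
    A.run()                              # assumed lume-base style run()

--------------------------------------------------------------------------------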
/astra/parsers.py:
--------------------------------------------------------------------------------
1 | """
2 | Astra output parsing
3 |
4 | References:
5 |
6 | PHYSICAL REVIEW SPECIAL TOPICS - ACCELERATORS AND BEAMS,VOLUME 6, 034202 (2003)
7 | https://journals.aps.org/prab/pdf/10.1103/PhysRevSTAB.6.034202
8 |
9 | """
10 |
11 | from pmd_beamphysics.units import unit
12 |
13 | import os
14 |
15 | from math import isnan
16 | import numpy as np
17 | import re
18 |
19 | # ------------
20 |
21 |
22 | def unit_dict(keys, unit_symbols):
23 | """
24 | Forms a dict mapping keys to pmd_unit objects
25 | """
26 | d = {}
27 | for k, s in zip(keys, unit_symbols):
28 | d[k] = unit(s)
29 | return d
30 |
31 |
32 | # New style
33 | OutputColumnNames = {}
34 | OutputColumnFactors = {}
35 | OutputUnits = {} # This collects all units
36 |
37 | CemitColumnNames = ['mean_z', 'norm_emit_x', 'core_emit_95percent_x', 'core_emit_90percent_x', 'core_emit_80percent_x',
38 | 'norm_emit_y', 'core_emit_95percent_y', 'core_emit_90percent_y', 'core_emit_80percent_y',
39 |                     'norm_emit_z', 'core_emit_95percent_z', 'core_emit_90percent_z', 'core_emit_80percent_z']
40 | CemitOriginalUnits = ['m'] + 8*['mm-mrad'] + 4*['kev-mm'] # Units that Astra writes
41 | CemitColumnFactors = [1] + 8*[1e-6] + 4*[1]  # Factors to make standard units (mm-mrad -> m is 1e-6; keV-mm -> eV*m is already unity, as in Zemit below)
42 | CemitColumnUnits = ['m'] + 8*['m'] + 4*['eV*m']
43 | CemitColumn = dict( zip(CemitColumnNames, list(range(1, 1+len(CemitColumnNames) ) ) ) )
44 | OutputUnits.update(unit_dict(CemitColumnNames, CemitColumnUnits))
45 |
46 | XemitColumnNames = ['mean_z', 'mean_t', 'mean_x', 'sigma_x', 'sigma_xp', 'norm_emit_x', 'cov_x__xp/sigma_x']
47 | XemitOriginalUnits = ['m', 'ns', 'mm', 'mm', 'mrad', 'mm-mrad', 'mrad' ] # Units that Astra writes
48 | XemitColumnFactors = [1, 1e-9, 1e-3, 1e-3, 1e-3, 1e-6, 1e-3] # Factors to make standard units
49 | XemitColumnUnits = ['m', 's', 'm', 'm', '1', 'm', 'rad' ]
50 | XemitColumn = dict( zip(XemitColumnNames, list(range(1, 1+len(XemitColumnNames) ) ) ) )
51 | OutputUnits.update(unit_dict(XemitColumnNames, XemitColumnUnits))
52 |
53 |
54 | YemitColumnNames = ['mean_z', 'mean_t', 'mean_y', 'sigma_y', 'sigma_yp', 'norm_emit_y', 'cov_y__yp/sigma_y']
55 | YemitOriginalUnits = ['m', 'ns', 'mm', 'mm', 'mrad', 'mm-mrad', 'mrad' ] # Units that Astra writes
56 | YemitColumnFactors = [1, 1e-9, 1e-3, 1e-3, 1e-3, 1e-6, 1e-3] # Factors to make standard units
57 | YemitColumnUnits = ['m', 's', 'm', 'm', '1', 'm', 'rad' ]
58 | YemitColumn = dict( zip(YemitColumnNames, list(range(1, 1+len(YemitColumnNames) ) ) ) )
59 | OutputUnits.update(unit_dict(YemitColumnNames, YemitColumnUnits) )
60 |
61 | ZemitColumnNames = ['mean_z', 'mean_t', 'mean_kinetic_energy', 'sigma_z', 'sigma_energy', 'norm_emit_z', 'cov_z__energy/sigma_z']
62 | ZemitOriginalUnits = ['m', 'ns', 'MeV', 'mm', 'keV', 'mm-keV', 'keV' ]
63 | ZemitColumnFactors = [1, 1e-9, 1e6, 1e-3, 1e3, 1, 1e3] # Factors to make standard units
64 | ZemitColumnUnits = ['m', 's', 'eV', 'm', 'eV', 'm*eV', 'eV' ]
65 | ZemitColumn = dict( zip(ZemitColumnNames, list(range(1, 1+len(ZemitColumnNames) ) ) ) )
66 | OutputUnits.update(unit_dict(ZemitColumnNames, ZemitColumnUnits))
67 |
68 | LandFColumnNames = ['landf_z', 'landf_n_particles', 'landf_total_charge', 'landf_n_lost', 'landf_energy_deposited', 'landf_energy_exchange']
69 | LandFOriginalUnits = ['m', '1', 'nC', '1', 'J', 'J']
70 | LandFColumnFactors = [1, 1, -1e-9, 1, 1, 1]
71 | LandFColumnUnits = ['m', '1', 'C', '1', 'J', 'J']
72 | OutputUnits.update(unit_dict(LandFColumnNames, LandFColumnUnits))
73 |
74 | OutputColumnNames['Cemit'] = CemitColumnNames
75 | OutputColumnNames['Xemit'] = XemitColumnNames
76 | OutputColumnNames['Yemit'] = YemitColumnNames
77 | OutputColumnNames['Zemit'] = ZemitColumnNames
78 | OutputColumnNames['LandF'] = LandFColumnNames
79 |
80 | OutputColumnFactors['Cemit'] = CemitColumnFactors
81 | OutputColumnFactors['Xemit'] = XemitColumnFactors
82 | OutputColumnFactors['Yemit'] = YemitColumnFactors
83 | OutputColumnFactors['Zemit'] = ZemitColumnFactors
84 | OutputColumnFactors['LandF'] = LandFColumnFactors
85 |
86 | ERROR = {'error': True}
87 |
88 |
89 | # Special additions
90 | OutputUnits['cov_x__xp'] = unit('m')
91 | OutputUnits['cov_y__yp'] = unit('m')
92 | OutputUnits['cov_z__energy'] = unit('m*eV')
93 |
94 |
95 | def astra_run_extension(run_number):
96 | """
97 | Astra adds an extension according to the run number: 1 -> '001'
98 | """
99 | return str(run_number).zfill(3)
100 |
101 |
102 |
103 |
104 | def find_astra_output_files(input_filePath, run_number,
105 | types= ['Cemit', 'Xemit', 'Yemit', 'Zemit', 'LandF']
106 | ):
107 | """
108 | Finds the existing output files, based on standard Astra extensions.
109 | """
110 |
111 | extensions = ['.'+x+'.'+astra_run_extension(run_number) for x in types]
112 |
113 | # List of output files
114 | path, infile = os.path.split(input_filePath)
115 | prefix = infile.split('.')[0] # Astra uses inputfile to name output
116 | outfiles = [os.path.join(path, prefix+x) for x in extensions]
117 |
118 | return [o for o in outfiles if os.path.exists(o)]
119 |
120 |
121 |
122 | def astra_output_type(filename):
123 | return filename.split('.')[-2]
124 |
125 |
126 |
127 |
128 |
129 | def parse_astra_output_file(filePath, standardize_labels=True,):
130 | """
131 | Simple parsing of tabular output files, according to names in this file.
132 |
133 |     If standardize_labels is True, the covariance columns written as cov/sigma are relabeled and converted to true covariances.
134 |
135 | """
136 |
137 | # Check for empty file
138 | if os.stat(filePath).st_size == 0:
139 | raise ValueError(f'ERROR: Empty output file: {filePath}')
140 |
141 | data = np.loadtxt(filePath, ndmin=2)
142 | if data.shape == ():
143 | raise ValueError(f'No data in file: {filePath}')
144 |
145 | if len(data) == 0:
146 | raise ValueError(f'No data in file (zero length): {filePath}')
147 |
148 | d = {}
149 | type = astra_output_type(filePath)
150 |
151 | # Get the appropriate keys and factors
152 | keys = OutputColumnNames[type]
153 | factors = OutputColumnFactors[type]
154 |
155 | for i in range(len(keys)):
156 | #print(filePath, keys[i])
157 |
158 | d[keys[i]] = data[:,i]*factors[i]
159 |
160 |
161 | if standardize_labels:
162 | if type == 'Xemit':
163 | d['cov_x__xp'] = d.pop('cov_x__xp/sigma_x')*d['sigma_x']
164 |
165 | if type == 'Yemit':
166 | d['cov_y__yp'] = d.pop('cov_y__yp/sigma_y')*d['sigma_y']
167 |
168 | if type == 'Zemit':
169 | d['cov_z__energy'] = d.pop('cov_z__energy/sigma_z')*d['sigma_z']
170 |
171 | # Special modifications
172 | #if type in ['Xemit', 'Yemit', 'Zemit']:
173 |
174 |
175 | return d
176 |
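# Editor's example (a sketch; the file name 'astra.Xemit.001' is hypothetical,
# corresponding to run number 1):
#
#   d = parse_astra_output_file('astra.Xemit.001')
#   d['sigma_x']    # ndarray in meters (converted from Astra's mm)
#   d['cov_x__xp']  # standardized from the 'cov_x__xp/sigma_x' column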
177 |
178 |
179 |
180 |
181 |
182 |
183 | # ------ Number parsing ------
184 | def isfloat(value):
185 | try:
186 | float(value)
187 | return True
188 | except ValueError:
189 | return False
190 |
191 | def isbool(x):
192 | z = x.strip().strip('.').upper()
193 | if z in ['T', 'TRUE', 'F', 'FALSE']:
194 | return True
195 | else:
196 | return False
197 |
198 | def try_int(x):
199 | if x == int(x):
200 | return int(x)
201 | else:
202 | return x
203 |
204 | def try_bool(x):
205 | z = x.strip().strip('.').upper()
206 | if z in ['T', 'TRUE']:
207 | return True
208 | elif z in ['F', 'FALSE']:
209 | return False
210 | else:
211 | return x
212 |
213 |
214 | # Simple function to try casting to a float, bool, or int
215 | def number(x):
216 | z = x.replace('D', 'E') # Some floating numbers use D
217 | if isfloat(z):
218 | val = try_int(float(z))
219 | elif isbool(x):
220 | val = try_bool(x)
221 | else:
222 | # must be a string. Strip quotes.
223 | val = x.strip().strip('\'').strip('\"')
224 | return val
225 |
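# Editor's examples of the casting rules above (traceable through the code):
#   number('2')         -> 2          (int, via try_int)
#   number('1.0D3')     -> 1000       (Fortran 'D' exponent handled)
#   number('.T.')       -> True       (Fortran logical)
#   number("'sol.dat'") -> 'sol.dat'  (quotes stripped)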
226 |
227 | # ------ Astra input file (namelist format) parsing
228 |
229 | def clean_namelist_key_value(line):
230 | """
231 | Cleans up a namelist "key = value line"
232 |
233 | Removes all spaces, and makes the key lower case.
234 |
235 | """
236 | z = line.split('=')
237 | # Make key lower case, strip
238 |
239 | key = z[0].strip().lower().replace(' ', '')
240 | value = ''.join(z[1:])
241 |
242 | return f'{key} = {value}'
243 |
244 | def unroll_namelist_line(line, commentchar='!', condense=False ):
245 | """
246 | Unrolls namelist lines. Looks for vectors, or multiple keys per line.
247 | """
248 | lines = []
249 | # Look for comments
250 | x = line.strip().strip(',').split(commentchar)
251 | if len(x) ==1:
252 | # No comments
253 | x = x[0].strip()
254 | else:
255 | # Unroll comment first
256 | comment = ''.join(x[1:])
257 | if not condense:
258 | lines.append('!'+comment)
259 | x = x[0].strip()
260 | if x == '':
261 | pass
262 | elif x[0] == '&' or x[0]=='/':
263 | # This is namelist control. Write.
264 | lines.append(x.lower())
265 | else:
266 | # Content line. Should contain =
267 | # unroll.
268 | # Check for multiple keys per line, or vectors.
269 | # TODO: handle both
270 | n_keys = len(x.split('='))
271 | if n_keys ==2:
272 | # Single key
273 | lines.append(clean_namelist_key_value(x))
274 | elif n_keys >2:
275 | for y in x.strip(',').split(','):
276 | lines.append(clean_namelist_key_value(y))
277 |
278 | return lines
279 |
280 | def parse_simple_namelist(filePath, commentchar='!', condense=False ):
281 | """
282 |     Unrolls a namelist-style file and returns the lines.
283 |     Keys are made lower case.
284 |
285 | Example:
286 |
287 | &my_namelist
288 |
289 | x=1, YY = 4 ! this is a comment:
290 | /
291 |
292 | unrolls to:
293 | &my_namelist
294 | ! this is a comment
295 | x = 1
296 | yy = 4
297 | /
298 |
299 | """
300 |
301 | lines = []
302 | with open(filePath, 'r') as f:
307 |
308 | for line in f:
309 | ulines = unroll_namelist_line(line, commentchar=commentchar, condense=condense)
310 | lines = lines + ulines
311 |
312 |
313 | return lines
314 |
315 |
316 |
317 |
318 |
319 | def parse_unrolled_namelist(unrolled_lines):
320 | """
321 | Parses an unrolled namelist into a dict
322 |
323 | """
324 | namelists={}
325 | for line in unrolled_lines:
326 |         if line[0]=='/' or line[0]=='!':
327 | # Ignore
328 | continue
329 | if line[0]=='&':
330 | name = line[1:].lower()
331 | namelists[name]={}
332 | # point to current namelist
333 | n = namelists[name]
334 | continue
335 | # content line
336 | key, val = line.split('=')
337 |
338 | # look for vector
339 | vals = val.split()
340 | if len(vals) == 1:
341 | val = number(vals[0])
342 | else:
343 | if isfloat(vals[0].replace(',',' ')):
344 | # Vector. Remove commas
345 | val = [number(z) for z in val.replace(',',' ').split()]
346 | else:
347 | # This is just a string. Just strip
348 | val = val.strip()
349 | n[key.strip()] = val
350 |
351 |
352 | return namelists
353 |
354 |
355 | def parse_astra_input_file(filePath, condense=False):
356 | """
357 | Parses an Astra input file into separate dicts for each namelist.
358 | Returns a dict of namelists.
359 | """
360 | lines = parse_simple_namelist(filePath, condense=condense)
361 | namelists = parse_unrolled_namelist(lines)
362 | return namelists
363 |
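# Editor's example (a sketch; 'astra.in' is a hypothetical input deck):
#
#   nl = parse_astra_input_file('astra.in')
#   nl['newrun']['run']    # e.g. 1 (namelist names and keys are lower-cased)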
364 |
365 |
366 |
367 |
368 |
369 |
370 | def fix_input_paths(input_dict, root='', prefixes=['file_', 'distribution', 'q_type']):
371 | """
372 |     Looks for keys in the input dict of dicts that start with any of the strings in the prefixes list.
373 |     Such keys should indicate a file; the value is replaced with its absolute path.
374 |     root should be the directory of the original input file.
375 |
376 | Does not replace absolute paths, or paths where the file does not exist.
377 |
378 | """
379 | for nl in input_dict:
380 | for key in input_dict[nl]:
381 | if any([key.startswith(prefix) for prefix in prefixes]):
382 | val = input_dict[nl][key]
383 |
384 | # Skip absolute paths
385 | if os.path.isabs(val):
386 | continue
387 | newval = os.path.abspath(os.path.join(root, val))
388 |
389 | # Skip if does not exist
390 | if not os.path.exists(newval):
391 | continue
392 |
393 | #assert os.path.exists(newval)
394 | #print(key, val, newval)
395 | input_dict[nl][key] = newval
396 |
397 |
398 |
399 | def find_max_pos(astra_input):
400 | """
401 | Find the maximum center position of elements in the Astra input dict
402 | """
403 | zmax = 0
404 | for group in ['cavity', 'solenoid', 'quadrupole']:
405 | if group in astra_input:
406 | nl = astra_input[group]
407 | for key in nl:
408 | if '_pos' in key:
409 | zmax = max(zmax, nl[key])
410 |
411 | return zmax
412 |
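# Editor's example, traceable through the loop above (any key containing
# '_pos' in the cavity/solenoid/quadrupole namelists is considered):
#
#   find_max_pos({'cavity':   {'c_pos(1)': 0.5},
#                 'solenoid': {'s_pos(1)': 1.2}})   # -> 1.2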
413 | # ------------------------------------------------------------------
414 | # ------------------------- Astra particles ------------------------
415 | def find_phase_files(input_filePath, run_number=1):
416 | """
417 |     Returns a list of the phase space files, sorted by z position, as
418 |     (filename, z_approx) tuples.
419 | """
420 | path, infile = os.path.split(input_filePath)
421 | prefix = infile.split('.')[0] # Astra uses inputfile to name output
422 |     phase_files = []
424 | run_extension = astra_run_extension(run_number)
425 | for file in os.listdir(path):
426 |         if re.match(re.escape(prefix) + r'\.\d\d\d\d\.' + run_extension, file):
427 | # Get z position
428 | z = float(file.replace(prefix+ '.', '').replace('.'+run_extension,''))
429 | phase_file=os.path.join(path, file)
430 | phase_files.append((phase_file, z))
431 | # Sort by z
432 | return sorted(phase_files, key=lambda x: x[1])
433 |
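# Editor's example (a sketch): for a hypothetical input file 'astra.in' and
# run_number=1, a dump named 'astra.0100.001' is returned as
# (full_path, 100.0), with the list sorted by that numeric field.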
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
--------------------------------------------------------------------------------
/astra/_version.py:
--------------------------------------------------------------------------------
1 |
2 | # This file helps to compute a version number in source trees obtained from
3 | # git-archive tarball (such as those provided by githubs download-from-tag
4 | # feature). Distribution tarballs (built by setup.py sdist) and build
5 | # directories (produced by setup.py build) will contain a much shorter file
6 | # that just contains the computed version number.
7 |
8 | # This file is released into the public domain. Generated by
9 | # versioneer-0.20 (https://github.com/python-versioneer/python-versioneer)
10 |
11 | """Git implementation of _version.py."""
12 |
13 | import errno
14 | import os
15 | import re
16 | import subprocess
17 | import sys
18 |
19 |
20 | def get_keywords():
21 | """Get the keywords needed to look up the version information."""
22 | # these strings will be replaced by git during git-archive.
23 | # setup.py/versioneer.py will grep for the variable names, so they must
24 | # each be defined on a line of their own. _version.py will just call
25 | # get_keywords().
26 | git_refnames = " (HEAD -> master)"
27 | git_full = "0af6c8454819a1cd22b789db7065ef6f0970dfbf"
28 | git_date = "2024-10-19 08:38:41 -0700"
29 | keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
30 | return keywords
31 |
32 |
33 | class VersioneerConfig: # pylint: disable=too-few-public-methods
34 | """Container for Versioneer configuration parameters."""
35 |
36 |
37 | def get_config():
38 | """Create, populate and return the VersioneerConfig() object."""
39 | # these strings are filled in when 'setup.py versioneer' creates
40 | # _version.py
41 | cfg = VersioneerConfig()
42 | cfg.VCS = "git"
43 | cfg.style = "pep440"
44 | cfg.tag_prefix = "v"
45 | cfg.parentdir_prefix = "None"
46 | cfg.versionfile_source = "astra/_version.py"
47 | cfg.verbose = False
48 | return cfg
49 |
50 |
51 | class NotThisMethod(Exception):
52 | """Exception raised if a method is not valid for the current scenario."""
53 |
54 |
55 | LONG_VERSION_PY = {}
56 | HANDLERS = {}
57 |
58 |
59 | def register_vcs_handler(vcs, method): # decorator
60 | """Create decorator to mark a method as the handler of a VCS."""
61 | def decorate(f):
62 | """Store f in HANDLERS[vcs][method]."""
63 | if vcs not in HANDLERS:
64 | HANDLERS[vcs] = {}
65 | HANDLERS[vcs][method] = f
66 | return f
67 | return decorate
68 |
69 |
70 | # pylint:disable=too-many-arguments,consider-using-with # noqa
71 | def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
72 | env=None):
73 | """Call the given command(s)."""
74 | assert isinstance(commands, list)
75 | process = None
76 | for command in commands:
77 | try:
78 | dispcmd = str([command] + args)
79 | # remember shell=False, so use git.cmd on windows, not just git
80 | process = subprocess.Popen([command] + args, cwd=cwd, env=env,
81 | stdout=subprocess.PIPE,
82 | stderr=(subprocess.PIPE if hide_stderr
83 | else None))
84 | break
85 | except EnvironmentError:
86 | e = sys.exc_info()[1]
87 | if e.errno == errno.ENOENT:
88 | continue
89 | if verbose:
90 | print("unable to run %s" % dispcmd)
91 | print(e)
92 | return None, None
93 | else:
94 | if verbose:
95 | print("unable to find command, tried %s" % (commands,))
96 | return None, None
97 | stdout = process.communicate()[0].strip().decode()
98 | if process.returncode != 0:
99 | if verbose:
100 | print("unable to run %s (error)" % dispcmd)
101 | print("stdout was %s" % stdout)
102 | return None, process.returncode
103 | return stdout, process.returncode
104 |
105 |
106 | def versions_from_parentdir(parentdir_prefix, root, verbose):
107 | """Try to determine the version from the parent directory name.
108 |
109 | Source tarballs conventionally unpack into a directory that includes both
110 | the project name and a version string. We will also support searching up
111 | two directory levels for an appropriately named parent directory
112 | """
113 | rootdirs = []
114 |
115 | for _ in range(3):
116 | dirname = os.path.basename(root)
117 | if dirname.startswith(parentdir_prefix):
118 | return {"version": dirname[len(parentdir_prefix):],
119 | "full-revisionid": None,
120 | "dirty": False, "error": None, "date": None}
121 | rootdirs.append(root)
122 | root = os.path.dirname(root) # up a level
123 |
124 | if verbose:
125 | print("Tried directories %s but none started with prefix %s" %
126 | (str(rootdirs), parentdir_prefix))
127 | raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
128 |
129 |
130 | @register_vcs_handler("git", "get_keywords")
131 | def git_get_keywords(versionfile_abs):
132 | """Extract version information from the given file."""
133 | # the code embedded in _version.py can just fetch the value of these
134 | # keywords. When used from setup.py, we don't want to import _version.py,
135 | # so we do it with a regexp instead. This function is not used from
136 | # _version.py.
137 | keywords = {}
138 | try:
139 | with open(versionfile_abs, "r") as fobj:
140 | for line in fobj:
141 | if line.strip().startswith("git_refnames ="):
142 | mo = re.search(r'=\s*"(.*)"', line)
143 | if mo:
144 | keywords["refnames"] = mo.group(1)
145 | if line.strip().startswith("git_full ="):
146 | mo = re.search(r'=\s*"(.*)"', line)
147 | if mo:
148 | keywords["full"] = mo.group(1)
149 | if line.strip().startswith("git_date ="):
150 | mo = re.search(r'=\s*"(.*)"', line)
151 | if mo:
152 | keywords["date"] = mo.group(1)
153 | except EnvironmentError:
154 | pass
155 | return keywords
156 |
157 |
158 | @register_vcs_handler("git", "keywords")
159 | def git_versions_from_keywords(keywords, tag_prefix, verbose):
160 | """Get version information from git keywords."""
161 | if "refnames" not in keywords:
162 | raise NotThisMethod("Short version file found")
163 | date = keywords.get("date")
164 | if date is not None:
165 | # Use only the last line. Previous lines may contain GPG signature
166 | # information.
167 | date = date.splitlines()[-1]
168 |
169 | # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
170 | # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
171 | # -like" string, which we must then edit to make compliant), because
172 | # it's been around since git-1.5.3, and it's too difficult to
173 | # discover which version we're using, or to work around using an
174 | # older one.
175 | date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
176 | refnames = keywords["refnames"].strip()
177 | if refnames.startswith("$Format"):
178 | if verbose:
179 | print("keywords are unexpanded, not using")
180 | raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
181 | refs = {r.strip() for r in refnames.strip("()").split(",")}
182 | # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
183 | # just "foo-1.0". If we see a "tag: " prefix, prefer those.
184 | TAG = "tag: "
185 | tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
186 | if not tags:
187 | # Either we're using git < 1.8.3, or there really are no tags. We use
188 | # a heuristic: assume all version tags have a digit. The old git %d
189 | # expansion behaves like git log --decorate=short and strips out the
190 | # refs/heads/ and refs/tags/ prefixes that would let us distinguish
191 | # between branches and tags. By ignoring refnames without digits, we
192 | # filter out many common branch names like "release" and
193 | # "stabilization", as well as "HEAD" and "master".
194 | tags = {r for r in refs if re.search(r'\d', r)}
195 | if verbose:
196 | print("discarding '%s', no digits" % ",".join(refs - tags))
197 | if verbose:
198 | print("likely tags: %s" % ",".join(sorted(tags)))
199 | for ref in sorted(tags):
200 | # sorting will prefer e.g. "2.0" over "2.0rc1"
201 | if ref.startswith(tag_prefix):
202 | r = ref[len(tag_prefix):]
203 | # Filter out refs that exactly match prefix or that don't start
204 | # with a number once the prefix is stripped (mostly a concern
205 | # when prefix is '')
206 | if not re.match(r'\d', r):
207 | continue
208 | if verbose:
209 | print("picking %s" % r)
210 | return {"version": r,
211 | "full-revisionid": keywords["full"].strip(),
212 | "dirty": False, "error": None,
213 | "date": date}
214 | # no suitable tags, so version is "0+unknown", but full hex is still there
215 | if verbose:
216 | print("no suitable tags, using unknown + full revision id")
217 | return {"version": "0+unknown",
218 | "full-revisionid": keywords["full"].strip(),
219 | "dirty": False, "error": "no suitable tags", "date": None}
220 |
221 |
222 | @register_vcs_handler("git", "pieces_from_vcs")
223 | def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
224 | """Get version from 'git describe' in the root of the source tree.
225 |
226 | This only gets called if the git-archive 'subst' keywords were *not*
227 | expanded, and _version.py hasn't already been rewritten with a short
228 | version string, meaning we're inside a checked out source tree.
229 | """
230 | GITS = ["git"]
231 | if sys.platform == "win32":
232 | GITS = ["git.cmd", "git.exe"]
233 |
234 | _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
235 | hide_stderr=True)
236 | if rc != 0:
237 | if verbose:
238 | print("Directory %s not under git control" % root)
239 | raise NotThisMethod("'git rev-parse --git-dir' returned error")
240 |
241 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
242 | # if there isn't one, this yields HEX[-dirty] (no NUM)
243 | describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty",
244 | "--always", "--long",
245 | "--match", "%s*" % tag_prefix],
246 | cwd=root)
247 | # --long was added in git-1.5.5
248 | if describe_out is None:
249 | raise NotThisMethod("'git describe' failed")
250 | describe_out = describe_out.strip()
251 | full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
252 | if full_out is None:
253 | raise NotThisMethod("'git rev-parse' failed")
254 | full_out = full_out.strip()
255 |
256 | pieces = {}
257 | pieces["long"] = full_out
258 | pieces["short"] = full_out[:7] # maybe improved later
259 | pieces["error"] = None
260 |
261 | branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
262 | cwd=root)
263 | # --abbrev-ref was added in git-1.6.3
264 | if rc != 0 or branch_name is None:
265 | raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
266 | branch_name = branch_name.strip()
267 |
268 | if branch_name == "HEAD":
269 | # If we aren't exactly on a branch, pick a branch which represents
270 | # the current commit. If all else fails, we are on a branchless
271 | # commit.
272 | branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
273 | # --contains was added in git-1.5.4
274 | if rc != 0 or branches is None:
275 | raise NotThisMethod("'git branch --contains' returned error")
276 | branches = branches.split("\n")
277 |
278 | # Remove the first line if we're running detached
279 | if "(" in branches[0]:
280 | branches.pop(0)
281 |
282 | # Strip off the leading "* " from the list of branches.
283 | branches = [branch[2:] for branch in branches]
284 | if "master" in branches:
285 | branch_name = "master"
286 | elif not branches:
287 | branch_name = None
288 | else:
289 | # Pick the first branch that is returned. Good or bad.
290 | branch_name = branches[0]
291 |
292 | pieces["branch"] = branch_name
293 |
294 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
295 | # TAG might have hyphens.
296 | git_describe = describe_out
297 |
298 | # look for -dirty suffix
299 | dirty = git_describe.endswith("-dirty")
300 | pieces["dirty"] = dirty
301 | if dirty:
302 | git_describe = git_describe[:git_describe.rindex("-dirty")]
303 |
304 | # now we have TAG-NUM-gHEX or HEX
305 |
306 | if "-" in git_describe:
307 | # TAG-NUM-gHEX
308 | mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
309 | if not mo:
310 | # unparseable. Maybe git-describe is misbehaving?
311 | pieces["error"] = ("unable to parse git-describe output: '%s'"
312 | % describe_out)
313 | return pieces
314 |
315 | # tag
316 | full_tag = mo.group(1)
317 | if not full_tag.startswith(tag_prefix):
318 | if verbose:
319 | fmt = "tag '%s' doesn't start with prefix '%s'"
320 | print(fmt % (full_tag, tag_prefix))
321 | pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
322 | % (full_tag, tag_prefix))
323 | return pieces
324 | pieces["closest-tag"] = full_tag[len(tag_prefix):]
325 |
326 | # distance: number of commits since tag
327 | pieces["distance"] = int(mo.group(2))
328 |
329 | # commit: short hex revision ID
330 | pieces["short"] = mo.group(3)
331 |
332 | else:
333 | # HEX: no tags
334 | pieces["closest-tag"] = None
335 | count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
336 | pieces["distance"] = int(count_out) # total number of commits
337 |
338 | # commit date: see ISO-8601 comment in git_versions_from_keywords()
339 | date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
340 | # Use only the last line. Previous lines may contain GPG signature
341 | # information.
342 | date = date.splitlines()[-1]
343 | pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
344 |
345 | return pieces
346 |
347 |
348 | def plus_or_dot(pieces):
349 | """Return a + if we don't already have one, else return a ."""
350 | if "+" in pieces.get("closest-tag", ""):
351 | return "."
352 | return "+"
353 |
354 |
355 | def render_pep440(pieces):
356 | """Build up version string, with post-release "local version identifier".
357 |
358 | Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
359 | get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
360 |
361 | Exceptions:
362 | 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
363 | """
364 | if pieces["closest-tag"]:
365 | rendered = pieces["closest-tag"]
366 | if pieces["distance"] or pieces["dirty"]:
367 | rendered += plus_or_dot(pieces)
368 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
369 | if pieces["dirty"]:
370 | rendered += ".dirty"
371 | else:
372 | # exception #1
373 | rendered = "0+untagged.%d.g%s" % (pieces["distance"],
374 | pieces["short"])
375 | if pieces["dirty"]:
376 | rendered += ".dirty"
377 | return rendered
378 |
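# Editor's note, traceable through the branches above:
#   pieces = {'closest-tag': '1.2', 'distance': 3,
#             'short': 'abc1234', 'dirty': True}
#   render_pep440(pieces)   # -> '1.2+3.gabc1234.dirty'
# A clean build exactly on the tag (distance 0, not dirty) renders as '1.2'.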
379 |
380 | def render_pep440_branch(pieces):
381 | """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
382 |
383 | The ".dev0" means not master branch. Note that .dev0 sorts backwards
384 | (a feature branch will appear "older" than the master branch).
385 |
386 | Exceptions:
387 | 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
388 | """
389 | if pieces["closest-tag"]:
390 | rendered = pieces["closest-tag"]
391 | if pieces["distance"] or pieces["dirty"]:
392 | if pieces["branch"] != "master":
393 | rendered += ".dev0"
394 | rendered += plus_or_dot(pieces)
395 | rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
396 | if pieces["dirty"]:
397 | rendered += ".dirty"
398 | else:
399 | # exception #1
400 | rendered = "0"
401 | if pieces["branch"] != "master":
402 | rendered += ".dev0"
403 | rendered += "+untagged.%d.g%s" % (pieces["distance"],
404 | pieces["short"])
405 | if pieces["dirty"]:
406 | rendered += ".dirty"
407 | return rendered
408 |
409 |
410 | def render_pep440_pre(pieces):
411 | """TAG[.post0.devDISTANCE] -- No -dirty.
412 |
413 | Exceptions:
414 | 1: no tags. 0.post0.devDISTANCE
415 | """
416 | if pieces["closest-tag"]:
417 | rendered = pieces["closest-tag"]
418 | if pieces["distance"]:
419 | rendered += ".post0.dev%d" % pieces["distance"]
420 | else:
421 | # exception #1
422 | rendered = "0.post0.dev%d" % pieces["distance"]
423 | return rendered
424 |
425 |
426 | def render_pep440_post(pieces):
427 | """TAG[.postDISTANCE[.dev0]+gHEX] .
428 |
429 | The ".dev0" means dirty. Note that .dev0 sorts backwards
430 | (a dirty tree will appear "older" than the corresponding clean one),
431 | but you shouldn't be releasing software with -dirty anyways.
432 |
433 | Exceptions:
434 | 1: no tags. 0.postDISTANCE[.dev0]
435 | """
436 | if pieces["closest-tag"]:
437 | rendered = pieces["closest-tag"]
438 | if pieces["distance"] or pieces["dirty"]:
439 | rendered += ".post%d" % pieces["distance"]
440 | if pieces["dirty"]:
441 | rendered += ".dev0"
442 | rendered += plus_or_dot(pieces)
443 | rendered += "g%s" % pieces["short"]
444 | else:
445 | # exception #1
446 | rendered = "0.post%d" % pieces["distance"]
447 | if pieces["dirty"]:
448 | rendered += ".dev0"
449 | rendered += "+g%s" % pieces["short"]
450 | return rendered
451 |
452 |
453 | def render_pep440_post_branch(pieces):
454 | """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
455 |
456 | The ".dev0" means not master branch.
457 |
458 | Exceptions:
459 | 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
460 | """
461 | if pieces["closest-tag"]:
462 | rendered = pieces["closest-tag"]
463 | if pieces["distance"] or pieces["dirty"]:
464 | rendered += ".post%d" % pieces["distance"]
465 | if pieces["branch"] != "master":
466 | rendered += ".dev0"
467 | rendered += plus_or_dot(pieces)
468 | rendered += "g%s" % pieces["short"]
469 | if pieces["dirty"]:
470 | rendered += ".dirty"
471 | else:
472 | # exception #1
473 | rendered = "0.post%d" % pieces["distance"]
474 | if pieces["branch"] != "master":
475 | rendered += ".dev0"
476 | rendered += "+g%s" % pieces["short"]
477 | if pieces["dirty"]:
478 | rendered += ".dirty"
479 | return rendered
480 |
481 |
482 | def render_pep440_old(pieces):
483 | """TAG[.postDISTANCE[.dev0]] .
484 |
485 | The ".dev0" means dirty.
486 |
487 | Exceptions:
488 | 1: no tags. 0.postDISTANCE[.dev0]
489 | """
490 | if pieces["closest-tag"]:
491 | rendered = pieces["closest-tag"]
492 | if pieces["distance"] or pieces["dirty"]:
493 | rendered += ".post%d" % pieces["distance"]
494 | if pieces["dirty"]:
495 | rendered += ".dev0"
496 | else:
497 | # exception #1
498 | rendered = "0.post%d" % pieces["distance"]
499 | if pieces["dirty"]:
500 | rendered += ".dev0"
501 | return rendered
502 |
503 |
504 | def render_git_describe(pieces):
505 | """TAG[-DISTANCE-gHEX][-dirty].
506 |
507 | Like 'git describe --tags --dirty --always'.
508 |
509 | Exceptions:
510 | 1: no tags. HEX[-dirty] (note: no 'g' prefix)
511 | """
512 | if pieces["closest-tag"]:
513 | rendered = pieces["closest-tag"]
514 | if pieces["distance"]:
515 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
516 | else:
517 | # exception #1
518 | rendered = pieces["short"]
519 | if pieces["dirty"]:
520 | rendered += "-dirty"
521 | return rendered
522 |
523 |
524 | def render_git_describe_long(pieces):
525 | """TAG-DISTANCE-gHEX[-dirty].
526 |
527 |     Like 'git describe --tags --dirty --always --long'.
528 | The distance/hash is unconditional.
529 |
530 | Exceptions:
531 | 1: no tags. HEX[-dirty] (note: no 'g' prefix)
532 | """
533 | if pieces["closest-tag"]:
534 | rendered = pieces["closest-tag"]
535 | rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
536 | else:
537 | # exception #1
538 | rendered = pieces["short"]
539 | if pieces["dirty"]:
540 | rendered += "-dirty"
541 | return rendered
542 |
543 |
544 | def render(pieces, style):
545 | """Render the given version pieces into the requested style."""
546 | if pieces["error"]:
547 | return {"version": "unknown",
548 | "full-revisionid": pieces.get("long"),
549 | "dirty": None,
550 | "error": pieces["error"],
551 | "date": None}
552 |
553 | if not style or style == "default":
554 | style = "pep440" # the default
555 |
556 | if style == "pep440":
557 | rendered = render_pep440(pieces)
558 | elif style == "pep440-branch":
559 | rendered = render_pep440_branch(pieces)
560 | elif style == "pep440-pre":
561 | rendered = render_pep440_pre(pieces)
562 | elif style == "pep440-post":
563 | rendered = render_pep440_post(pieces)
564 | elif style == "pep440-post-branch":
565 | rendered = render_pep440_post_branch(pieces)
566 | elif style == "pep440-old":
567 | rendered = render_pep440_old(pieces)
568 | elif style == "git-describe":
569 | rendered = render_git_describe(pieces)
570 | elif style == "git-describe-long":
571 | rendered = render_git_describe_long(pieces)
572 | else:
573 | raise ValueError("unknown style '%s'" % style)
574 |
575 | return {"version": rendered, "full-revisionid": pieces["long"],
576 | "dirty": pieces["dirty"], "error": None,
577 | "date": pieces.get("date")}
578 |
579 |
580 | def get_versions():
581 | """Get version information or return default if unable to do so."""
582 | # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
583 | # __file__, we can work backwards from there to the root. Some
584 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
585 | # case we can only use expanded keywords.
586 |
587 | cfg = get_config()
588 | verbose = cfg.verbose
589 |
590 | try:
591 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
592 | verbose)
593 | except NotThisMethod:
594 | pass
595 |
596 | try:
597 | root = os.path.realpath(__file__)
598 | # versionfile_source is the relative path from the top of the source
599 | # tree (where the .git directory might live) to this file. Invert
600 | # this to find the root from __file__.
601 | for _ in cfg.versionfile_source.split('/'):
602 | root = os.path.dirname(root)
603 | except NameError:
604 | return {"version": "0+unknown", "full-revisionid": None,
605 | "dirty": None,
606 | "error": "unable to find root of source tree",
607 | "date": None}
608 |
609 | try:
610 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
611 | return render(pieces, cfg.style)
612 | except NotThisMethod:
613 | pass
614 |
615 | try:
616 | if cfg.parentdir_prefix:
617 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
618 | except NotThisMethod:
619 | pass
620 |
621 | return {"version": "0+unknown", "full-revisionid": None,
622 | "dirty": None,
623 | "error": "unable to compute version", "date": None}
624 |
--------------------------------------------------------------------------------
/astra/astra.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import tempfile
4 | import shutil
5 | import os
6 | import platform
7 | import traceback
8 | from time import time
9 | from copy import deepcopy
10 | import functools
11 |
12 | import h5py
13 | import numpy as np
14 | import yaml
15 | from lume import tools as lumetools
16 | from lume.base import CommandWrapper
17 | from pmd_beamphysics import ParticleGroup
18 |
19 | from . import parsers, writers, tools, archive
20 | from .control import ControlGroup
21 | from .fieldmaps import load_fieldmaps, write_fieldmaps
22 | from .generator import AstraGenerator
23 | from .plot import plot_stats_with_layout, plot_fieldmaps
24 | from .interfaces.bmad import astra_from_tao
25 |
26 | from pmd_beamphysics import single_particle
27 | from pmd_beamphysics.interfaces.astra import parse_astra_phase_file
28 |
29 | import warnings
36 |
37 |
38 |
39 | class Astra(CommandWrapper):
40 |
41 | """
42 | Astra simulation object. Essential methods:
43 | .__init__(...)
44 | .configure()
45 | .run()
46 |
47 | Input deck is held in .input
48 | Output data is parsed into .output
49 | .load_particles() will load particle data into .output['particles'][...]
50 |
51 | The Astra binary file can be set on init. If it doesn't exist, configure will check the
52 | $ASTRA_BIN
53 |     environment variable.
54 |
55 |
56 | """
57 | MPI_SUPPORTED = False
58 | COMMAND = "$ASTRA_BIN"
59 | INPUT_PARSER = parsers.parse_astra_input_file
60 |
61 | def __init__(self,
62 | input_file=None,
63 | *,
64 | group=None,
65 | **kwargs
66 | ):
67 | super().__init__(input_file=input_file, **kwargs)
68 | # Save init
69 | self.original_input_file = self._input_file
70 |
71 | # These will be set
72 | self.log = []
73 | self.output = {'stats': {}, 'particles': {}, 'run_info': {}}
74 | self.group = {} # Control Groups
75 | self.fieldmap = {} # Fieldmaps
76 |
77 | # Call configure
78 | if self.input_file:
79 | self.load_input(self.input_file)
80 | self.configure()
81 |
82 | # Add groups, if any.
83 | if group:
84 | for k, v in group.items():
85 | self.add_group(k, **v)
86 | else:
87 | self.vprint('Using default input: 1 m drift lattice')
88 | self.original_input_file = 'astra.in'
89 | self.input = deepcopy(DEFAULT_INPUT)
90 | self.configure()
91 |
92 | def add_group(self, name, **kwargs):
93 | """
94 | Add a control group. See control.py
95 |
96 | Parameters
97 | ----------
98 | name : str
99 |             The group name
100 |         **kwargs
101 |             Passed to ControlGroup. See control.py.
102 |         """
101 | assert name not in self.input, f'{name} not allowed to be overwritten by group.'
102 | if name in self.group:
103 | self.vprint(f'Warning: group {name} already exists, overwriting.')
104 |
105 | g = ControlGroup(**kwargs)
106 | g.link(self.input)
107 | self.group[name] = g
108 |
109 | return self.group[name]
110 |
111 | def clean_output(self):
112 | run_number = parsers.astra_run_extension(self.input['newrun']['run'])
113 | outfiles = parsers.find_astra_output_files(self.input_file, run_number)
114 | for f in outfiles:
115 | os.remove(f)
116 |
117 | def clean_particles(self):
118 | run_number = parsers.astra_run_extension(self.input['newrun']['run'])
119 | phase_files = parsers.find_phase_files(self.input_file, run_number)
120 | files = [x[0] for x in phase_files] # This is sorted by approximate z
121 | for f in files:
122 | os.remove(f)
123 |
124 | # Convenience routines
125 | @property
126 | def particles(self):
127 | return self.output['particles']
128 |
129 | def stat(self, key):
130 | return self.output['stats'][key]
131 |
132 | def particle_stat(self, key, alive_only=True):
133 | """
134 | Compute a statistic from the particles.
135 |
136 | Alive particles have status == 1. By default, statistics will only be computed on these.
137 |
138 | n_dead will override the alive_only flag,
139 | and return the number of particles with status < -6 (Astra convention)
140 | """
141 |
142 | if key == 'n_dead':
143 | return np.array([len(np.where(P.status < -6)[0]) for P in self.particles])
144 |
145 | if key == 'n_alive':
146 | return np.array([len(np.where(P.status > -6)[0]) for P in self.particles])
147 |
148 | pstats = []
149 | for P in self.particles:
150 | if alive_only and P.n_dead > 0:
151 | P = P.where(P.status == 1)
152 | pstats.append(P[key])
153 | return np.array(pstats)
154 |
155 |
156 | def configure(self):
157 | self.setup_workdir(self._workdir)
158 | self.command = lumetools.full_path(self.command)
159 | self.vprint("Configured to run in:", self.path)
160 | self.input_file = os.path.join(self.path, self.original_input_file)
161 | self.configured = True
162 |
163 |
164 | # def configure(self):
165 | # self.configure_astra(workdir=self.path)
166 | #
167 | # def configure_astra(self, input_filepath=None, workdir=None):
168 | #
169 | # # if input_filepath:
170 | # # self.load_input(input_filepath)
171 | #
172 | # # Check that binary exists
173 | # #self.command = lumetools.full_path(self.command)
174 | #
175 | # self.setup_workdir(self._workdir)
176 | # #self.input_file = os.path.join(self.path, self.original_input_file)
177 | # self.configured = True
178 |
179 | def load_fieldmaps(self, search_paths=[]):
180 | """
181 | Loads fieldmaps into Astra.fieldmap as a dict.
182 |
183 | Optionally, a list of paths can be included that will search for these. The default will search self.path.
184 | """
185 |
186 | # Do not consider files if fieldmaps have been loaded.
187 | if self.fieldmap:
188 | strip_path = False
189 | else:
190 | strip_path = True
191 |
192 | if not search_paths:
193 | search_paths = [self.path]
194 |
195 | self.fieldmap = load_fieldmaps(self.input, fieldmap_dict=self.fieldmap, search_paths=search_paths,
196 | verbose=self.verbose, strip_path=strip_path)
197 |
198 | def load_initial_particles(self, h5):
199 | """Loads a openPMD-beamphysics particle h5 handle or file"""
200 | P = ParticleGroup(h5=h5)
201 | self.initial_particles = P
202 |
203 | def input_parser(self, path):
204 | return parsers.parse_astra_input_file(path)
205 |
206 | def load_input(self, input_filepath, absolute_paths=True, **kwargs):
207 | super().load_input(input_filepath, **kwargs)
208 | if absolute_paths:
209 | parsers.fix_input_paths(self.input, root=self.original_path)
210 |
211 | def load_output(self, include_particles=True):
212 | """
213 | Loads Astra output files into .output
214 |
215 | .output is a dict with dicts:
216 | .stats
217 | .run_info
218 | .other
219 |
220 | and if include_particles,
221 | .particles = list of ParticleGroup objects
222 |
223 | """
224 | run_number = parsers.astra_run_extension(self.input['newrun']['run'])
225 | outfiles = parsers.find_astra_output_files(self.input_file, run_number)
226 |
227 | # assert len(outfiles)>0, 'No output files found'
228 |
229 | stats = self.output['stats'] = {}
230 |
231 | for f in outfiles:
232 | type = parsers.astra_output_type(f)
233 | d = parsers.parse_astra_output_file(f)
234 | if type in ['Cemit', 'Xemit', 'Yemit', 'Zemit']:
235 | stats.update(d)
236 | elif type in ['LandF']:
237 | self.output['other'] = d
238 | else:
239 | raise ValueError(f'Unknown output type: {type}')
240 |
241 | # Check that the lengths of all arrays are the same
242 | nlist = {len(stats[k]) for k in stats}
243 |
244 | assert len(nlist) == 1, f'Stat keys do not all have the same length: {[len(stats[k]) for k in stats]}'
245 |
246 | if include_particles:
247 | self.load_particles()
248 |
249 | def load_particles(self, end_only=False):
250 | # Clear existing particles
251 | self.output['particles'] = []
252 |
253 | # Sort files by approximate z
254 | run_number = parsers.astra_run_extension(self.input['newrun']['run'])
255 | phase_files = parsers.find_phase_files(self.input_file, run_number)
256 | files = [x[0] for x in phase_files] # This is sorted by approximate z
257 | zapprox = [x[1] for x in phase_files]
258 |
259 | if end_only:
260 | files = files[-1:]
261 | if self.verbose:
262 | print('loading ' + str(len(files)) + ' particle files')
263 | print(zapprox)
264 | for f in files:
265 | pdat = parse_astra_phase_file(f)
266 | P = ParticleGroup(data=pdat)
267 | self.output['particles'].append(P)
268 |
269 | def run(self):
270 | self.run_astra()
271 |
272 | def run_astra(self, verbose=False, parse_output=True, timeout=None):
273 | """
274 | Runs Astra
275 |
276 | Changes directory, so does not work with threads.
277 | """
278 | if not self.configured:
279 | print('not configured to run')
280 | return
281 |
282 | run_info = self.output['run_info'] = {}
283 |
284 | t1 = time()
285 | run_info['start_time'] = t1
286 |
287 | # Write all input
288 | self.write_input()
289 |
290 | runscript = self.get_run_script()
291 | tools.make_executable(os.path.join(self.path, 'run'))
292 | run_info['run_script'] = ' '.join(runscript)
293 |
294 |         if self.timeout:
295 |             res = tools.execute2(runscript, timeout=timeout if timeout is not None else self.timeout, cwd=self.path)
296 | log = res['log']
297 | self.error = res['error']
298 | run_info['why_error'] = res['why_error']
299 | # Log file must have this to have finished properly
300 | if log.find('finished simulation') == -1:
301 | raise ValueError("Couldn't find finished simulation")
302 |
303 | else:
304 | # Interactive output, for Jupyter
305 | log = []
306 | for path in tools.execute(runscript, cwd=self.path):
307 | self.vprint(path, end="")
308 | log.append(path)
309 |
310 | self.log = log
311 |
312 | if parse_output:
313 | self.load_output()
314 |
315 | run_info['run_time'] = time() - t1
316 |
317 | self.finished = True
318 |
319 | self.vprint(run_info)
320 |
321 | def units(self, key):
322 | if key in parsers.OutputUnits:
323 | return parsers.OutputUnits[key]
324 | else:
325 | return 'unknown unit'
326 |
327 | def load_archive(self, h5=None, configure=False):
328 | """
329 | Loads input and output from archived h5 file.
330 |
331 | See: Astra.archive
332 | """
333 | if isinstance(h5, str):
334 | h5 = os.path.expandvars(h5)
335 | g = h5py.File(h5, 'r')
336 |
337 | glist = archive.find_astra_archives(g)
338 | n = len(glist)
339 | if n == 0:
340 | # legacy: try top level
341 | message = 'legacy'
342 | elif n == 1:
343 | gname = glist[0]
344 | message = f'group {gname} from'
345 | g = g[gname]
346 | else:
347 | raise ValueError(f'Multiple archives found in file {h5}: {glist}')
348 |
349 | self.vprint(f'Reading {message} archive file {h5}')
350 | else:
351 | g = h5
352 |
353 | self.input = archive.read_input_h5(g['input'])
354 | self.output = archive.read_output_h5(g['output'])
355 | if 'initial_particles' in g:
356 | self.initial_particles = ParticleGroup(h5=g['initial_particles'])
357 |
358 | if 'fieldmap' in g:
359 | self.fieldmap = archive.read_fieldmap_h5(g['fieldmap'])
360 |
361 | if 'control_groups' in g:
362 | self.group = archive.read_control_groups_h5(g['control_groups'], verbose=self.verbose)
363 |
364 | self.vprint('Loaded from archive. Note: Must reconfigure to run again.')
365 | self.configured = False
366 |
367 | # Re-link groups
368 | # TODO: cleaner logic
369 | for _, cg in self.group.items():
370 | cg.link(self.input)
371 |
372 | if configure:
373 | self.configure()
374 |
375 | def archive(self, h5=None):
376 | """
377 | Archive all data to an h5 handle or filename.
378 |
379 | If no file is given, a file based on the fingerprint will be created.
380 |
381 | """
382 | if not h5:
383 | h5 = 'astra_' + self.fingerprint() + '.h5'
384 |
385 | if isinstance(h5, str):
386 | h5 = os.path.expandvars(h5)
387 | g = h5py.File(h5, 'w')
388 | self.vprint(f'Archiving to file {h5}')
389 | else:
390 | # store directly in the given h5 handle
391 | g = h5
392 |
393 | # Write basic attributes
394 | archive.astra_init(g)
395 |
396 | # Initial particles
397 | if self.initial_particles:
398 | self.initial_particles.write(g, name='initial_particles')
399 |
400 | # Fieldmaps
401 | if self.fieldmap:
402 | archive.write_fieldmap_h5(g, self.fieldmap, name='fieldmap')
403 |
404 | # All input
405 | archive.write_input_h5(g, self.input)
406 |
407 | # All output
408 | archive.write_output_h5(g, self.output)
409 |
410 | # Control groups
411 | if self.group:
412 | archive.write_control_groups_h5(g, self.group, name='control_groups')
413 |
414 | return h5
415 |
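    # Editor's round-trip sketch ('run.h5' is hypothetical):
    #   fname = A.archive('run.h5')
    #   A2 = Astra()
    #   A2.load_archive(fname)  # input/output restored; reconfigure to run again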
416 | def write_fieldmaps(self, path=None):
417 | """
418 | Writes any loaded fieldmaps to path
419 | """
420 | if path is None:
421 | path = self.path
422 |
423 | if self.fieldmap:
424 | write_fieldmaps(self.fieldmap, path)
425 | self.vprint(f'{len(self.fieldmap)} fieldmaps written to {path}')
426 |
427 | def write_input(self, input_filename=None, path=None, make_symlinks=True):
428 | """
429 | Writes all input. If fieldmaps have been loaded, these will also be written.
430 | """
431 |
432 | if path is None:
433 | path = self.path
434 |
435 | if self.initial_particles:
436 | fname = self.write_initial_particles(path=path)
437 | self.input['newrun']['distribution'] = fname
438 |
439 | self.write_fieldmaps(path=path)
440 |
441 | self.write_input_file(path=path, make_symlinks=make_symlinks)
442 |
443 | def write_input_file(self, path=None, make_symlinks=True):
444 | if path is None:
445 | path = self.path
446 | input_file = self.input_file
447 | else:
448 | input_file = os.path.join(path, 'astra.in')
449 |
450 | writers.write_namelists(self.input, input_file, make_symlinks=make_symlinks, verbose=self.verbose)
451 |
452 | def write_initial_particles(self, fname=None, path=None):
453 | if path is None:
454 | path = self.path
455 |
456 | fname = fname or os.path.join(path, 'astra.particles')
457 | #
458 | if len(self.initial_particles) == 1:
459 | probe = True
460 | else:
461 | probe = False
462 | self.initial_particles.write_astra(fname, probe=probe)
463 | self.vprint(f'Initial particles written to {fname}')
464 | return fname
465 |
466 | def plot(self, y=['sigma_x', 'sigma_y'], x='mean_z', xlim=None, ylim=None, ylim2=None, y2=[],
467 | nice=True,
468 | include_layout=True,
469 | include_labels=False,
470 | include_particles=True,
471 | include_legend=True,
472 | return_figure=False,
473 | **kwargs):
474 | """
475 |         Plots stat output for multiple keys.
476 |
477 |         If a list is given for y2, those keys will be put on the right-hand axis. y2 can also be given as a single key.
478 |
479 | Logical switches:
480 | nice: a nice SI prefix and scaling will be used to make the numbers reasonably sized. Default: True
481 |
482 | include_legend: The plot will include the legend. Default: True
483 |
484 | include_particles: Plot the particle statistics as dots. Default: True
485 |
486 | include_layout: the layout plot will be displayed at the bottom. Default: True
487 |
488 | include_labels: the layout will include element labels. Default: False
489 |
490 | return_figure: return the figure object for further manipulation. Default: False
491 |
492 | If there is no output to plot, the fieldmaps will be plotted with .plot_fieldmaps
493 |
494 | """
495 |
496 | # Just plot fieldmaps if there are no stats
497 | if not self.output['stats']:
498 | return plot_fieldmaps(self, xlim=xlim, **kwargs)
499 |
500 | return plot_stats_with_layout(self, ykeys=y, ykeys2=y2,
501 | xkey=x, xlim=xlim, ylim=ylim, ylim2=ylim2,
502 | nice=nice,
503 | include_layout=include_layout,
504 | include_labels=include_labels,
505 | include_particles=include_particles,
506 | include_legend=include_legend,
507 | return_figure=return_figure,
508 | **kwargs)
509 |
510 | def plot_fieldmaps(self, **kwargs):
511 | return plot_fieldmaps(self, **kwargs)
512 |
513 | def __getitem__(self, key):
514 | """
515 | Convenience syntax to get a header or element attribute.
516 |
517 | Special syntax:
518 |
519 | end_X
520 | will return the final item in a stat array X
521 | Example:
522 | 'end_norm_emit_x'
523 |
524 | particles:N
525 | will return a ParticleGroup N from the .particles list
526 | Example:
527 | 'particles:-1'
528 | returns the readback of the final particles
529 | particles:N:Y
530 | ParticleGroup N's property Y
531 | Example:
532 | 'particles:-1:sigma_x'
533 | returns sigma_x from the end of the particles list.
534 |
535 |
536 | See: __setitem__
537 | """
538 |
539 | # Object attributes
540 | if hasattr(self, key):
541 | return getattr(self, key)
542 |
543 | # Send back top level input (namelist) or group object.
544 | # Do not add these to __setitem__. The user shouldn't be allowed to change them as a whole,
545 | # because it will break all the links.
546 | if key in self.group:
547 | return self.group[key]
548 | if key in self.input:
549 | return self.input[key]
550 |
551 | if key.startswith('end_'):
552 | key2 = key[len('end_'):]
553 | assert key2 in self.output['stats'], f'{key} does not have valid output stat: {key2}'
554 | return self.output['stats'][key2][-1]
555 |
556 | if key.startswith('particles:'):
557 | key2 = key[len('particles:'):]
558 | x = key2.split(':')
559 | if len(x) == 1:
560 | return self.particles[int(x[0])]
561 | else:
562 | return self.particles[int(x[0])][x[1]]
563 |
564 |         # key isn't a group or namelist, so it should have the syntax 'name:attrib'
565 |
566 |         x = key.split(':')
567 |         assert len(x) == 2, f'{key} was not found in group or input dict; expected the syntax name:attrib'
568 | name, attrib = x[0], x[1]
569 |
570 | # Look in input and group
571 | if name in self.input:
572 | return self.input[name][attrib]
573 | elif name in self.group:
574 | return self.group[name][attrib]
575 |
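    # Editor's examples of the item syntax above (A is an Astra instance):
    #   A['newrun:run']            # namelist attribute
    #   A['end_norm_emit_x']       # final value of a stat array
    #   A['particles:-1:sigma_x']  # property of the last ParticleGroup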
576 | def __setitem__(self, key, item):
577 | """
578 | Convenience syntax to set namelist or group attribute.
579 | attribute_string should be 'header:key' or 'ele_name:key'
580 |
581 | Examples of attribute_string: 'header:Np', 'SOL1:solenoid_field_scale'
582 |
583 | Settable attributes can also be given:
584 |
585 |         ['stop'] = 1.2345 will set Astra.stop = 1.2345
586 |
587 | """
588 |
589 | # Set attributes
590 | if hasattr(self, key):
591 | setattr(self, key, item)
592 | return
593 |
594 | # Must be in input or group
595 | name, attrib = key.split(':')
596 | if name in self.input:
597 | self.input[name][attrib] = item
598 | elif name in self.group:
599 | self.group[name][attrib] = item
600 | else:
601 |             raise ValueError(f'{name} does not exist in the namelists or groups of the Astra object.')
602 |
603 |
604 |
605 | # Tracking
606 | #---------
607 | def track(self, particles, z=None):
608 | """
609 | Track a ParticleGroup. An optional stopping z can be given.
610 |
611 | If successful, returns a ParticleGroup with the final particles.
612 |
613 | Otherwise, returns None
614 |
615 | """
616 |
617 | self.initial_particles = particles
618 | if z is not None:
619 | self['output:zstop'] = z
620 |
621 | # Assure phase space output is turned on
622 | nr = self.input['newrun']
623 | if 'zphase' not in nr:
624 | nr['zphase'] = 1
625 | if nr['zphase'] < 1:
626 | nr['zphase'] = 1
627 | # Turn particle output on.
628 | nr['phases'] = True
629 |
630 | self.run()
631 |
632 | if 'particles' in self.output:
633 | if len(self.output['particles']) == 0:
634 | return None
635 |
636 | final_particles = self.output['particles'][-1]
637 |
638 | # Special case to remove probe particles
639 | if len(self.initial_particles) == 1:
640 | final_particles = final_particles[-1]
641 | return final_particles
642 |
643 | else:
644 | return None
645 |
646 | def track1(self,
647 | x0=0,
648 | px0=0,
649 | y0=0,
650 | py0=0,
651 | z0=0,
652 | pz0=1e-15,
653 | t0=0,
654 | weight=1,
655 | status=1,
656 | z=None, # final z
657 | species='electron'):
658 | """
659 | Tracks a single particle with starting coordinates:
660 | x0, y0, z0 in meters
661 | px0, py0, pz0 in eV/c
662 | t0 in seconds
663 |
664 | to a position 'z' in meters
665 |
666 | If successful, returns a ParticleGroup with the final particle.
667 |
668 | Otherwise, returns None
669 |
670 | """
671 | p0 = single_particle(x=x0, px=px0, y=y0, py=py0, z=z0, pz=pz0, t=t0, weight=weight, status=status, species=species)
672 | return self.track(p0, z=z)
673 |
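    # Editor's sketch: track a single ~6 MeV/c electron to z = 0.5 m:
    #   P = A.track1(pz0=6e6, z=0.5)   # ParticleGroup, or None on failure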
674 |
675 |
676 | @classmethod
677 | @functools.wraps(astra_from_tao)
678 | def from_tao(cls, tao):
679 | return astra_from_tao(tao, cls=cls)
680 |
681 |
682 |
683 | def set_astra(astra_object, generator_input, settings, verbose=False):
684 | """
685 | Searches astra and generator objects for keys in settings, and sets their values to the appropriate input
686 | """
687 | astra_input = astra_object.input # legacy syntax
688 |
689 | for k, v in settings.items():
690 | found = False
691 |
692 | # Check for direct settable attribute
693 | if ':' in k:
694 | astra_object[k] = v
695 | continue
696 |
697 | for nl in astra_input:
698 | if k in astra_input[nl]:
699 | found = True
700 | if verbose:
701 | print(k, 'is in astra', nl)
702 | astra_input[nl][k] = settings[k]
703 | if not found:
704 | if k in generator_input:
705 | found = True
706 | generator_input[k] = settings[k]
707 | if verbose:
708 | print(k, 'is in generator')
709 |
710 | if not found and verbose:
711 | print(k, 'not found')
712 | assert found
713 |
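# Editor's sketch: plain keys are searched across all namelists; colon syntax
# addresses a namelist directly. ('zstop' lives in 'output', and 'lspch' in
# 'charge', for a typical deck.)
#   set_astra(A, {}, {'zstop': 2.0, 'lspch': False, 'newrun:run': 2})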
714 |
715 | def recommended_spacecharge_mesh(n_particles):
716 | """
717 | ! --------------------------------------------------------
718 | ! Suggested Nrad, Nlong_in settings from:
719 | ! A. Bartnik and C. Gulliford (Cornell University)
720 | !
721 | ! Nrad = 35, Nlong_in = 75 !28K
722 | ! Nrad = 29, Nlong_in = 63 !20K
723 | ! Nrad = 20, Nlong_in = 43 !10K
724 | ! Nrad = 13, Nlong_in = 28 !4K
725 | ! Nrad = 10, Nlong_in = 20 !2K
726 | ! Nrad = 8, Nlong_in = 16 !1K
727 | !
728 | ! Nrad ~ round(3.3*(n_particles/1000)^(2/3) + 5)
729 | ! Nlong_in ~ round(9.2*(n_particles/1000)^(0.603) + 6.5)
730 | !
731 | !
732 | """
733 | if n_particles < 1000:
734 | # Set a minimum
735 | nrad = 8
736 | nlong_in = 16
737 | else:
738 | # Prefactors were recalculated from above note.
739 | nrad = round(3.3e-2 * n_particles ** (2 / 3) + 5)
740 | nlong_in = round(0.143 * n_particles ** (0.603) + 6.5)
741 | return {'nrad': nrad, 'nlong_in': nlong_in}
742 |
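# Example: for 10,000 particles the fits reproduce the 10K row of the
# table above:
#
#   recommended_spacecharge_mesh(10_000)
#   # -> {'nrad': 20, 'nlong_in': 43}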
743 |
744 | def run_astra(settings=None,
745 | astra_input_file=None,
746 | workdir=None,
747 | command='$ASTRA_BIN',
748 | timeout=2500,
749 | verbose=False):
750 |     """
751 |     Run Astra on astra_input_file, optionally in workdir, using the
752 |     executable given by command (default: the $ASTRA_BIN environment variable).
753 |     settings: dict with keys that can appear in an Astra input file.
754 |     """
755 | if verbose:
756 | print('run_astra')
757 |
758 | # Make astra object
759 | A = Astra(command=command, input_file=astra_input_file, workdir=workdir)
760 |
761 | A.timeout = timeout
762 | A.verbose = verbose
763 |
764 | A.input['newrun']['l_rm_back'] = True # Remove backwards particles
765 |
766 | # Set inputs
767 | if settings:
768 | set_astra(A, {}, settings, verbose=verbose)
769 |
770 | # Run
771 | A.run()
772 |
773 | return A
774 |
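# Usage sketch for run_astra(); ASTRA_TEMPLATE stands for a path to an
# Astra input file:
#
#   A = run_astra(settings={'zstop': 1.0, 'lspch': False},
#                 astra_input_file=ASTRA_TEMPLATE,
#                 verbose=True)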
775 |
776 | def run_astra_with_generator(settings=None,
777 | astra_input_file=None,
778 | generator_input_file=None,
779 | workdir=None,
780 | command='$ASTRA_BIN',
781 | command_generator='$GENERATOR_BIN',
782 | timeout=2500, verbose=False,
783 | auto_set_spacecharge_mesh=True):
784 |     """
785 |     Run Astra with particles generated by Astra's generator. Without a generator_input_file, this falls back to plain run_astra.
786 |     If auto_set_spacecharge_mesh, the space-charge grid is sized from the generator's ipart (see recommended_spacecharge_mesh).
787 |     settings: dict with keys that can appear in an Astra or Generator input file.
788 |     """
789 |
790 | assert astra_input_file, 'No astra input file'
791 |
792 | # Call simpler evaluation if there is no generator:
793 | if not generator_input_file:
794 | return run_astra(settings=settings,
795 | astra_input_file=astra_input_file,
796 | workdir=workdir,
797 | command=command,
798 | timeout=timeout,
799 | verbose=verbose)
800 |
801 | if verbose:
802 | print('run_astra_with_generator')
803 |
804 | # Make astra and generator objects
805 | A = Astra(command=command, input_file=astra_input_file, workdir=workdir)
806 | A.timeout = timeout
807 | A.verbose = verbose
808 | G = AstraGenerator(command=command_generator, input_file=generator_input_file, workdir=workdir)
809 | G.verbose = verbose
810 |
811 | A.input['newrun']['l_rm_back'] = True # Remove backwards particles
812 |
813 | # Set inputs
814 | if settings:
815 | set_astra(A, G.input, settings, verbose=verbose)
816 |
817 | # Attach generator input. This is non-standard.
818 | A.generator_input = G.input
819 |
820 | if auto_set_spacecharge_mesh:
821 | n_particles = G.input['ipart']
822 | sc_settings = recommended_spacecharge_mesh(n_particles)
823 | A.input['charge'].update(sc_settings)
824 | if verbose:
825 | print('set spacecharge mesh for n_particles:', n_particles, 'to', sc_settings)
826 |
827 | # Run Generator
828 | G.run()
829 | A.initial_particles = G.output['particles']
830 | A.run()
831 | if verbose:
832 | print('run_astra_with_generator finished')
833 |
834 | return A
835 | # Usage:
836 | # Aout = run_astra_with_generator(settings={'zstop': 2, 'lspch': False}, astra_input_file=ASTRA_TEMPLATE, generator_input_file=GENERATOR_TEMPLATE, command=ASTRA_BIN, command_generator=GENERATOR_BIN, verbose=True)
837 |
838 |
839 |
840 | DEFAULT_INPUT = {
841 | 'newrun': {'auto_phase': True,
842 | 'check_ref_part': False,
843 | 'distribution': 'astra.particles',
844 | 'h_max': 0.0075,
845 | 'h_min': 0,
846 | 'head': "'Drift example'",
847 | 'phase_scan': True,
848 | 'q_schottky': 0,
849 | 'run': 1,
850 | 'toff': 0,
851 | 'track_all': True,
852 | 'xoff': 0,
853 | 'yoff': 0},
854 |
855 | 'output': {'c_emits': True,
856 | 'cathodes': False,
857 | 'emits': True,
858 | 'high_res': True,
859 | 'landfs': True,
860 | 'larmors': False,
861 | 'lmagnetized': True,
862 | 'lproject_emit': False,
863 | 'lsub_rot': False,
864 | 'phases': True,
865 | 'refs': True,
866 | 'tchecks': False,
867 | 'tracks': True,
868 | 'zemit': 100,
869 | 'zphase': 1,
870 | 'zstart': 0,
871 | 'zstop': 1},
872 | 'charge': {
873 | 'lspch': False,
874 | 'cell_var': 2,
875 | 'lmirror': False,
876 | 'lspch3d': False,
877 | 'max_count': 10,
878 | 'max_scale': 0.01,
879 | 'min_grid': 4e-07,
880 | 'nlong_in': 43,
881 | 'nrad': 20,
882 | 'nxf': 32,
883 | 'nyf': 32,
884 | 'nzf': 32},
885 | }
886 |
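# Example sketch: start from a copy of DEFAULT_INPUT and adjust it, e.g.
# enabling space charge with a mesh sized for the particle count:
#
#   import copy
#   inp = copy.deepcopy(DEFAULT_INPUT)
#   inp['charge']['lspch'] = True
#   inp['charge'].update(recommended_spacecharge_mesh(10_000))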
887 |
888 |
889 |
890 |
891 |
--------------------------------------------------------------------------------