├── .github └── workflows │ ├── ci_cd.yml │ └── ci_cd_test.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── Makefile ├── make.bat └── source │ ├── _static │ ├── communications.pdf │ ├── communications.png │ ├── custom.css │ ├── icon.png │ ├── icon.svg │ ├── logo.png │ └── logo.svg │ ├── _templates │ └── navigation.html │ ├── citing.rst │ ├── conf.py │ ├── contacts.rst │ ├── faqs.rst │ ├── fundings.rst │ ├── index.rst │ ├── install.rst │ ├── license.rst │ └── publications.rst ├── lyncs ├── __init__.py ├── field │ ├── __init__.py │ ├── array.py │ ├── base.py │ ├── contractions.py │ ├── random.py │ ├── reductions.py │ ├── types │ │ ├── __init__.py │ │ ├── base.py │ │ ├── generic.py │ │ └── quantum.py │ └── ufuncs.py ├── io │ ├── __init__.py │ ├── base.py │ └── lime │ │ ├── __init__.py │ │ └── lime.py └── lattice.py ├── notebooks ├── .gitignore └── Lyncs.ipynb ├── pyproject.toml ├── setup.cfg ├── setup.py └── test ├── __init__.py ├── field ├── test_array.py ├── test_base.py └── test_numpy.py ├── test_lattice.py └── test_lime.py /.github/workflows/ci_cd.yml: -------------------------------------------------------------------------------- 1 | # This workflow updates the packages on PyPI 2 | 3 | name: build & test 4 | 5 | on: 6 | push: 7 | branches: 8 | - 'master' 9 | paths-ignore: 10 | - 'docs/**' 11 | 12 | jobs: 13 | build-n-publish: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: '3.x' 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | sudo apt-get install -y libopenmpi-dev openmpi-bin 28 | 29 | - name: Install from source 30 | run: | 31 | pip install -e .[all] 32 | 33 | - name: Run tests 34 | run: | 35 | pytest -v --cov-report=xml 36 | export CODECOV_TOKEN="${{ secrets.CODECOV_TOKEN }}" 37 | bash <(curl -s https://codecov.io/bash) -f ./coverage.xml -n lyncs 38 | 39 | - name: Upload if 
not up to date 40 | env: 41 | TWINE_USERNAME: __token__ 42 | TWINE_PASSWORD: ${{ secrets.pypi_password }} 43 | run: | 44 | pip uninstall -y lyncs 45 | pip install lyncs==$(lyncs_setuptools version) || ( 46 | pip install twine 47 | python setup.py sdist 48 | twine upload dist/* 49 | count=0 50 | while ! pip install lyncs==$(lyncs_setuptools version) && [ $count -lt 20 ]; do 51 | sleep 1 52 | count=$((count+1)) 53 | done 54 | ) 55 | 56 | clean-run: 57 | 58 | needs: build-n-publish 59 | runs-on: ubuntu-latest 60 | strategy: 61 | matrix: 62 | python-version: [3.6, 3.7, 3.8] 63 | 64 | steps: 65 | - uses: actions/checkout@v2 66 | 67 | - name: Set up Python ${{ matrix.python-version }} 68 | uses: actions/setup-python@v2 69 | with: 70 | python-version: ${{ matrix.python-version }} 71 | 72 | - name: Install dependencies 73 | run: | 74 | python -m pip install --upgrade pip 75 | sudo apt-get install -y libopenmpi-dev openmpi-bin 76 | 77 | - name: Install via pip 78 | run: | 79 | pip install lyncs[all] 80 | 81 | - name: Run tests 82 | run: | 83 | pytest -v --import-mode=importlib 84 | -------------------------------------------------------------------------------- /.github/workflows/ci_cd_test.yml: -------------------------------------------------------------------------------- 1 | # This workflow builds and tests PRs 2 | 3 | name: PR build & test 4 | 5 | on: 6 | pull_request: 7 | branches: 8 | - 'master' 9 | paths-ignore: 10 | - 'docs/**' 11 | 12 | jobs: 13 | build-n-publish: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | with: 19 | ref: ${{ github.head_ref }} 20 | 21 | - name: Set up Python 22 | uses: actions/setup-python@v1 23 | with: 24 | python-version: '3.x' 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | sudo apt-get install -y libopenmpi-dev openmpi-bin 30 | 31 | - name: Applying black formatting 32 | run: | 33 | pip install black 34 | black --diff . 35 | black . 
36 | 37 | - name: Pushing changes if any 38 | uses: stefanzweifel/git-auto-commit-action@v4 39 | with: 40 | commit_message: Applying black formatting (from Github Action) 41 | commit_user_name: sbacchio 42 | commit_user_email: s.bacchio@gmail.com 43 | commit_author: Simone Bacchio 44 | 45 | - name: Install from source 46 | run: | 47 | pip install -e . 48 | 49 | - name: Pylint output 50 | run: | 51 | pip install lyncs_setuptools[pylint] 52 | badge=$(lyncs_pylint_badge --fail-under 8 . | sed "s/\&/\\\&/g") 53 | badge_line=$(awk '/!\[pylint\]/ {print FNR}' README.md) 54 | sed -i "${badge_line}s#.*#${badge}#" README.md 55 | 56 | - name: Pushing changes if any 57 | uses: stefanzweifel/git-auto-commit-action@v4 58 | with: 59 | commit_message: Updating pylint score (from Github Action) 60 | commit_user_name: sbacchio 61 | commit_user_email: s.bacchio@gmail.com 62 | commit_author: Simone Bacchio 63 | 64 | - name: Clean notebooks 65 | run: | 66 | pip install lyncs[notebook] 67 | jupyter-nbconvert --clear-output notebooks/*.ipynb 68 | 69 | - name: Pushing changes if any 70 | uses: stefanzweifel/git-auto-commit-action@v4 71 | with: 72 | commit_message: Cleaning notebooks, no output wanted (from Github Action) 73 | commit_user_name: sbacchio 74 | commit_user_email: s.bacchio@gmail.com 75 | commit_author: Simone Bacchio 76 | 77 | - name: Testing notebooks 78 | run: | 79 | jupyter-nbconvert --to html --execute notebooks/*.ipynb 80 | 81 | - name: Run tests 82 | run: | 83 | pip install -e .[test] 84 | pytest -v 85 | 86 | - name: Run tests for all 87 | run: | 88 | pip install -e .[all] 89 | pytest -v 90 | 91 | - name: Run lyncs_setuptools 92 | run: | 93 | lyncs_setuptools 94 | 95 | - name: Upload if not up to date 96 | env: 97 | TWINE_USERNAME: __token__ 98 | TWINE_PASSWORD: ${{ secrets.test_pypi_password }} 99 | run: | 100 | pip uninstall -y lyncs 101 | pip install --extra-index-url https://test.pypi.org/simple/ lyncs==$(lyncs_setuptools version) || ( 102 | pip install twine 103 | 
python setup.py sdist 104 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 105 | count=0 106 | while ! pip install --extra-index-url https://test.pypi.org/simple/ lyncs==$(lyncs_setuptools version) && [ $count -lt 20 ]; do 107 | sleep 1 108 | count=$((count+1)) 109 | done 110 | ) 111 | 112 | clean-run: 113 | 114 | needs: build-n-publish 115 | runs-on: ubuntu-latest 116 | strategy: 117 | matrix: 118 | python-version: [3.6, 3.7, 3.8] 119 | 120 | steps: 121 | - uses: actions/checkout@v2 122 | 123 | - name: Set up Python ${{ matrix.python-version }} 124 | uses: actions/setup-python@v2 125 | with: 126 | python-version: ${{ matrix.python-version }} 127 | 128 | - name: Install dependencies 129 | run: | 130 | python -m pip install --upgrade pip 131 | sudo apt-get install -y libopenmpi-dev openmpi-bin 132 | 133 | - name: Install via pip 134 | run: | 135 | pip install --extra-index-url https://test.pypi.org/simple/ lyncs[all] 136 | 137 | - name: Run tests 138 | run: | 139 | pytest -v --import-mode=importlib 140 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | # Default build directory 35 | build/ 36 | tmp/ 37 | 38 | # Edited files 39 | *~ 40 | 41 | # Byte-compiled / optimized / DLL files 42 | __pycache__/ 43 | *.py[cod] 44 | *$py.class 45 | 46 | # Distribution / packaging 47 | .Python 48 | build/ 49 | develop-eggs/ 50 | dist/ 51 | downloads/ 52 | eggs/ 53 | .eggs/ 54 | lib/ 55 | lib64/ 56 | parts/ 57 | sdist/ 58 
| var/ 59 | wheels/ 60 | *.egg-info/ 61 | .installed.cfg 62 | *.egg 63 | MANIFEST 64 | 65 | # PyInstaller 66 | # Usually these files are written by a python script from a template 67 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .coverage 79 | .coverage.* 80 | .cache 81 | nosetests.xml 82 | coverage.xml 83 | *.cover 84 | .hypothesis/ 85 | .pytest_cache/ 86 | .ipynb_checkpoints/ 87 | .tmp/ 88 | 89 | # Dask visualize 90 | mydask.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, Simone Bacchio 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A python API for Lattice QCD applications 2 | 3 | [![python](https://img.shields.io/pypi/pyversions/lyncs.svg?logo=python&logoColor=white)](https://pypi.org/project/lyncs/) 4 | [![pypi](https://img.shields.io/pypi/v/lyncs.svg?logo=python&logoColor=white)](https://pypi.org/project/lyncs/) 5 | [![license](https://img.shields.io/github/license/Lyncs-API/lyncs?logo=github&logoColor=white)](https://github.com/Lyncs-API/lyncs/blob/master/LICENSE) 6 | [![build & test](https://img.shields.io/github/workflow/status/Lyncs-API/lyncs/build%20&%20test?logo=github&logoColor=white)](https://github.com/Lyncs-API/lyncs/actions) 7 | [![codecov](https://img.shields.io/codecov/c/github/Lyncs-API/lyncs?logo=codecov&logoColor=white)](https://codecov.io/gh/Lyncs-API/lyncs) 8 | [![pylint](https://img.shields.io/badge/pylint%20score-7.4%2F10-yellow?logo=python&logoColor=white)](http://pylint.pycqa.org/) 9 | [![black](https://img.shields.io/badge/code%20style-black-000000.svg?logo=codefactor&logoColor=white)](https://github.com/ambv/black) 10 | 11 | ![alt text](https://github.com/sbacchio/lyncs/blob/master/docs/source/_static/logo.png "Lyncs") 12 | 13 | Lyncs is a Python API for Lattice QCD applications currently under development with a first 14 | released version expected 
by the end of Q2 of 2020. Lyncs aims to bring several popular 15 | libraries for Lattice QCD under a common framework. Lyncs will interface with libraries for 16 | GPUs and CPUs in a way that can accommodate additional computing architectures as these 17 | arise, achieving the best performance for the calculations while maintaining the same high- 18 | level workflow. Lyncs is one of 10 applications supported by PRACE-6IP, WP8 "Forward 19 | Looking Software Solutions". 20 | 21 | Lyncs distributes calculations using Dask, with bindings to the libraries performed 22 | automatically via Cppyy. Multiple distributed tasks can be executed in parallel and different 23 | computing units can be used at the same time to fully exploit the machine allocation. The data 24 | redistribution is efficiently managed by the API. We expect this model of distributing tasks to 25 | be well suited for modular architectures, allowing to flexibly distribute 26 | work between the different modules. 27 | While Lyncs is designed to quite generally allow linking to multiple libraries, we will 28 | focus on a set of targeted packages that include tmLQCD, DDalphaAMG, PLEGMA and QUDA. 29 | 30 | 31 | ## Installation: 32 | 33 | The package can be installed via `pip`: 34 | 35 | ``` 36 | pip install [--user] lyncs 37 | ``` 38 | 39 | ### Sub-modules and plugins 40 | 41 | Sub-modules and plugins can also be installed via `pip` with: 42 | 43 | ``` 44 | pip install [--user] lyncs[NAME] 45 | ``` 46 | 47 | where NAME is the name of the sub-module. Hereafter the list of the available sub-modules. 48 | 49 | #### Groups 50 | 51 | - `all`: installs all the plugins enabling all Lyncs' functionalities (e.g. hmc, visualization, etc..). 52 | Note it does not install libraries with strong dependencies like MPI, GPUs, etc. 53 | Simple CPUs libraries may be installed. 54 | 55 | - `mpi`: installs all MPI libraries. 56 | 57 | - `cuda`: installs all NVIDIA GPUs libraries. 
58 | 59 | - `io`: installs all IO libraries for full support of IO formats (clime, HDF5, etc..). 60 | 61 | #### LQCD libraries 62 | 63 | - `DDalphaAMG`: multigrid solver library for Wilson and Twisted mass fermions. 64 | 65 | - `QUDA`: NVIDIA GPUs library for LQCD. 66 | 67 | - `clime`: IO library for c-lime format. 68 | 69 | - `tmLQCD`: legacy code of the Extended Twisted Mass collaboration. 70 | 71 | ## Goals: 72 | 73 | - Include several Lattice QCD libraries under a single framework 74 | - Provide crosschecks and benchmarks of different libraries' implementations 75 | - Handle memory distribution and mapping 76 | - Allow for multitasking parallelization and unequal distribution 77 | 78 | 79 | ## Dependencies: 80 | 81 | ### Python utils: 82 | 83 | - numpy: Multidimensional arrays in python 84 | - dask: Utility for scheduling distributed tasks 85 | - cppyy: Automatic binding to C/C++ libraries 86 | - (optional) dask-mpi, mpi4py: MPI for python 87 | - (under consideration) numba: JIT compilation of python code 88 | - others: xmltodict, 89 | 90 | ### LQCD libraries: 91 | 92 | - QUDA: Lattice QCD operators and solvers on GPUs 93 | - DDalphaAMG: Multigrid solver on CPUs 94 | - tmLQCD: HMC routines on CPUs 95 | - PLEGMA: contraction kernels on GPUs 96 | 97 | ### Extras requirements: 98 | 99 | - Jupyter notebook/lab: for visualizing and running the available notebooks 100 | - dask-labextension: utils for profiling the task execution in Jupyter lab 101 | 102 | 103 | ## Fundings: 104 | 105 | - PRACE-6IP, Grant agreement ID: 823767, Project name: LyNcs.
106 | 107 | 108 | ## Authors: 109 | 110 | - Simone Bacchio, The Cyprus Institute 111 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/communications.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyncs-API/lyncs/12accf4dbc0dc247dba8cf183bc5007e293c6034/docs/source/_static/communications.pdf -------------------------------------------------------------------------------- /docs/source/_static/communications.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyncs-API/lyncs/12accf4dbc0dc247dba8cf183bc5007e293c6034/docs/source/_static/communications.png -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | @import url("default.css"); 2 | div.sphinxsidebar span.caption-text { 3 | font-family: Georgia, serif; 4 | color: #444; 5 | font-size: 24px; 6 | font-weight: normal; 7 | } 8 | 9 | div.body p.caption { 10 | text-align: justify; 11 | } 12 | -------------------------------------------------------------------------------- /docs/source/_static/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyncs-API/lyncs/12accf4dbc0dc247dba8cf183bc5007e293c6034/docs/source/_static/icon.png -------------------------------------------------------------------------------- /docs/source/_static/icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 23 | 25 | 33 | 39 | 40 | 48 | 54 | 55 | 58 | 62 | 66 | 67 | 75 
| 81 | 82 | 90 | 96 | 97 | 106 | 112 | 113 | 121 | 127 | 128 | 131 | 135 | 139 | 140 | 143 | 147 | 151 | 152 | 155 | 159 | 163 | 164 | 167 | 171 | 175 | 176 | 179 | 183 | 187 | 188 | 191 | 195 | 199 | 200 | 210 | 220 | 230 | 240 | 250 | 260 | 271 | 282 | 293 | 304 | 315 | 326 | 327 | 349 | 351 | 352 | 354 | image/svg+xml 355 | 357 | 358 | 359 | 360 | 361 | 367 | 370 | 376 | 382 | 388 | 393 | 399 | 405 | 410 | 415 | 421 | 426 | 432 | 437 | 443 | 449 | 455 | 461 | 467 | 468 | 473 | 474 | 475 | -------------------------------------------------------------------------------- /docs/source/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lyncs-API/lyncs/12accf4dbc0dc247dba8cf183bc5007e293c6034/docs/source/_static/logo.png -------------------------------------------------------------------------------- /docs/source/_templates/navigation.html: -------------------------------------------------------------------------------- 1 | {{ toctree(includehidden=theme_sidebar_includehidden, collapse=theme_sidebar_collapse) }} 2 | {% if theme_extra_nav_links %} 3 |
4 |
    5 | {% for text, uri in theme_extra_nav_links.items() %} 6 |
  • {{ text }}
  • 7 | {% endfor %} 8 |
9 | {% endif %} 10 | -------------------------------------------------------------------------------- /docs/source/citing.rst: -------------------------------------------------------------------------------- 1 | Citing & Acknowledgment 2 | ======================= 3 | 4 | If you need to cite Lyncs in your work, please use the following reference. 5 | 6 | **TODO** 7 | 8 | Acknowledgment 9 | -------------- 10 | 11 | If your work used Lyncs and you want to acknowledge the software, please use the following sentence. 12 | 13 | **TODO** 14 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "Lyncs" 21 | copyright = "2019, Simone Bacchio" 22 | author = "Simone Bacchio" 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = "0.0.0" 26 | master_doc = "index" 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 
33 | extensions = [ 34 | "sphinx.ext.autosectionlabel", 35 | ] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ["_templates"] 39 | 40 | # List of patterns, relative to source directory, that match files and 41 | # directories to ignore when looking for source files. 42 | # This pattern also affects html_static_path and html_extra_path. 43 | exclude_patterns = [] 44 | 45 | 46 | # -- Options for HTML output ------------------------------------------------- 47 | 48 | # The theme to use for HTML and HTML Help pages. See the documentation for 49 | # a list of builtin themes. 50 | # 51 | html_theme = "alabaster" 52 | 53 | # Add any paths that contain custom static files (such as style sheets) here, 54 | # relative to this directory. They are copied after the builtin static files, 55 | # so a file named "default.css" will overwrite the builtin "default.css". 56 | html_static_path = ["_static"] 57 | 58 | html_favicon = "_static/icon.png" 59 | 60 | html_theme_options = { 61 | "logo": "logo.png", 62 | "body_text_align": "justify", 63 | "github_user": "sbacchio", 64 | "github_repo": "lyncs", 65 | "touch_icon": "icon.png", 66 | } 67 | -------------------------------------------------------------------------------- /docs/source/contacts.rst: -------------------------------------------------------------------------------- 1 | Contacts 2 | ======== 3 | 4 | For any help on Lyncs, please refer to this :ref:`documentation `. 5 | 6 | FAQs are collected in :ref:`FAQs `. 7 | 8 | If you are facing any issue running Lyncs, please open an issue on the GitHub page `Issues `_. 
9 | 10 | Authors 11 | ------- 12 | 13 | Simone Bacchio 14 | ^^^^^^^^^^^^^^ 15 | 16 | - `Institution `_: The Cyprus Institute 17 | - `Email `_: s.bacchio AT gmail.com 18 | -------------------------------------------------------------------------------- /docs/source/faqs.rst: -------------------------------------------------------------------------------- 1 | FAQs 2 | ==== 3 | 4 | 5 | -------------------------------------------------------------------------------- /docs/source/fundings.rst: -------------------------------------------------------------------------------- 1 | Fundings 2 | ======== 3 | 4 | Lyncs acknowledges the following fundings: 5 | 6 | - **PRACE-6IP**: Lyncs is one of 10 applications supported by PRACE-6IP, WP8 "*Forward Looking Software Solutions*". Grant agreement ID: 823767, Project name: LyNcs. 7 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | ===== 2 | Lyncs 3 | ===== 4 | 5 | 6 | A python API for lattice QCD applications 7 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 | 9 | Lyncs is a Python API for Lattice QCD applications currently under development with a first 10 | released version expected by the end of Q2 of 2020. Lyncs aims to bring several popular 11 | libraries for Lattice QCD under a common framework. Lyncs will interface with libraries for 12 | GPUs and CPUs in a way that can accommodate additional computing architectures as these 13 | arise, achieving the best performance for the calculations while maintaining the same high- 14 | level workflow. Lyncs is one of 10 applications supported by PRACE-6IP, WP8 "Forward 15 | Looking Software Solutions". 16 | 17 | 18 | 19 | .. toctree:: 20 | :maxdepth: 1 21 | :hidden: 22 | :caption: Getting Started 23 | 24 | install.rst 25 | 26 | 27 | .. toctree:: 28 | :maxdepth: 1 29 | :hidden: 30 | :caption: User Guide 31 | 32 | 33 | 34 | .. 
"""
Lyncs, a python API for LQCD applications
"""
__version__ = "0.0.3"

from importlib import import_module

import lyncs_utils as utils
from .lattice import *
from .field import *

# Local sub-modules
from . import field
from . import io

# Importing available Lyncs packages:
# every optional lyncs_<pkg> extension found in the environment
# is exposed as lyncs.<pkg>; missing extensions are silently skipped.
for pkg in [
    "mpi",
    "cppyy",
    "clime",
    "DDalphaAMG",
    "tmLQCD",
]:
    # Sanity check: the star-imports above must not have defined this name
    assert pkg not in globals(), f"{pkg} already defined"
    try:
        globals()[pkg] = import_module(f"lyncs_{pkg}")
    except ModuleNotFoundError:
        # The extension is not installed: best-effort import, skip it
        pass

# Clean up the module namespace of the loop helpers
# (the original deleted import_module but leaked the loop variable pkg)
del import_module
del pkg
    @add_kwargs_of(BaseField.__init__)
    def __init_attributes__(
        self, field=None, dtype=None, indexes_order=None, labels_order=None, **kwargs
    ):
        """
        Initializes the field class.

        Parameters
        ----------
        dtype: str or numpy dtype compatible
            Data type of the field.
        indexes_order: tuple
            The order of the field indexes (field.indexes).
            This also fixes the field shape (field.ordered_shape).
            It is a tunable parameter and the decision can be postponed.
        copy: bool
            Whether the input field should be copied.
            If False the field is copied only if needed
            otherwise the input field will be used;
            if True, the field is copied.
        """
        # Let BaseField consume its kwargs first; leftovers belong here
        kwargs = super().__init_attributes__(field, **kwargs)

        # The indexes_order is inherited from the input field only when
        # the input is itself a BaseField
        indexes_order = self._get_indexes_order(
            field if isinstance(field, BaseField) else None, indexes_order
        )
        if indexes_order is not None:
            self.indexes_order = indexes_order

        # _get_labels_order also pops <label>_order entries from kwargs
        self._labels_order, kwargs = self._get_labels_order(
            field if isinstance(field, BaseField) else None, labels_order, **kwargs
        )

        # dtype precedence: explicit argument, then input field, then class default
        self._dtype = np.dtype(
            dtype
            if dtype is not None
            else field.dtype
            if hasattr(field, "dtype")
            else ArrayField.default_dtype
        )

        # Unused kwargs are returned so the caller can detect leftovers
        return kwargs
    def __validate_value__(self, value, **kwargs):
        """
        Checks if the field is well defined to have a value.

        Raises ValueError when a value is given while the tunable
        indexes_order (or a used label order) is neither fixed nor a
        dependency of the value's computation graph.
        """
        # NOTE(review): the dependency checks below inspect self.value
        # (set by TunableClass.__init__) rather than the incoming `value`
        # assigned at the end — confirm this is intended.
        if not self.indexes_order.fixed and not finalize(self.value).depends_on(
            self.indexes_order
        ):
            raise ValueError("Value has been given but indexes_order is not fixed.")

        # Same check for every label order; a label is relevant only when
        # some variable name starts with the label key
        # (self.variables presumably lists the tunable variable names — TODO confirm)
        for key, val in self.labels_order:
            if (
                not val.fixed
                and any((var.startswith(key) for var in self.variables))
                and not finalize(self.value).depends_on(val)
            ):
                raise ValueError(
                    "Value has been given but %s order is not fixed." % key
                )

        self.value = value

        return kwargs
    @add_kwargs_of(__init_attributes__)
    def __init__(self, field=None, value=None, **kwargs):
        """
        Initializes the field class.

        Parameters
        ----------
        value: Tunable
            The underlying value of the field. Not for the faint of heart.
            If it is given, then all the attributes of the initialization
            are considered properties of the value and no transformation
            will be applied.
        """
        TunableClass.__init__(self, field)

        kwargs = self.__init_attributes__(field, **kwargs)

        # Three initialization paths:
        # - an explicit value is validated as-is,
        # - another field is transformed into this field's layout,
        # - otherwise a fresh value is allocated (zeros or given array).
        if value is not None:
            kwargs = self.__validate_value__(value, **kwargs)
        elif isinstance(field, BaseField):
            kwargs = self.__update_value__(field, **kwargs)
        else:
            kwargs = self.__initialize_value__(field, **kwargs)

        # Any kwargs not consumed by the steps above is an error
        if kwargs:
            raise ValueError("Could not resolve the following kwargs.\n %s" % kwargs)
209 | Use field.any() or field.all() 210 | """ 211 | ) 212 | 213 | def copy(self, value=None, **kwargs): 214 | "Creates a shallow copy of the field" 215 | return super().copy(value=value, **kwargs) 216 | 217 | @wraps(TunableClass.compute) 218 | def compute(self, **kwargs): 219 | "Adds consistency checks on the value" 220 | super().compute(**kwargs) 221 | array = self.node.value.obj 222 | assert array.shape == self.ordered_shape, "Mistmatch in the shape %s != %s" % ( 223 | array.shape, 224 | self.shape, 225 | ) 226 | assert array.dtype == self.dtype, "Mistmatch in the dtype %s != %s" % ( 227 | array.dtype, 228 | self.dtype, 229 | ) 230 | 231 | @property 232 | def backend(self): 233 | "Returns the computational backend of the field (numpy)." 234 | return NumpyBackend(self) 235 | 236 | @property 237 | def dtype(self): 238 | "Data type of the field (numpy style)" 239 | return self._dtype 240 | 241 | @dtype.setter 242 | def dtype(self, value): 243 | if self.dtype != value: 244 | self._dtype = np.dtype(value) 245 | self.value = self.backend.astype(self.dtype) 246 | 247 | def astype(self, dtype): 248 | "Changes the dtype of the field." 249 | if self.dtype == dtype: 250 | return self 251 | return self.copy(dtype=dtype) 252 | 253 | @compute_property 254 | def bytes(self): 255 | "Returns the size of the field in bytes" 256 | return self.size * self.dtype.itemsize 257 | 258 | @tunable_property 259 | def indexes_order(self): 260 | "Order of the field indexes" 261 | return Permutation(self.indexes) 262 | 263 | def reorder(self, indexes_order=None, **kwargs): 264 | "Changes the indexes_order of the field." 
265 | if indexes_order is None: 266 | indexes_order = self.indexes_order.copy(reset=True) 267 | return self.copy(indexes_order=indexes_order, **kwargs) 268 | 269 | def _get_indexes_order(self, field=None, indexes_order=None): 270 | if indexes_order is not None: 271 | if ( 272 | not isinstance(indexes_order, Variable) 273 | and not isinstance(indexes_order, Tunable) 274 | and set(indexes_order) != set(self.indexes) 275 | ): 276 | raise ValueError( 277 | "Not valid indexes_order. It has %s, while expected %s" 278 | % (indexes_order, self.indexes) 279 | ) 280 | return indexes_order 281 | if field is None: 282 | return None 283 | if len(self.indexes) <= 1: 284 | return self.indexes 285 | if set(self.indexes) == set(field.indexes): 286 | return field.indexes_order 287 | if set(self.indexes) <= set(field.indexes): 288 | return function(filter, self.indexes.__contains__, field.indexes_order) 289 | return None 290 | 291 | @property 292 | def labels_order(self): 293 | "Order of the field indexes" 294 | return self._labels_order 295 | 296 | def reorder_label(self, label, label_order=None, **kwargs): 297 | "Changes the order of the label." 
298 | rng = self.get_range(label) 299 | if not isiterable(self.get_range(label), str): 300 | raise KeyError("%s is not a label of the field" % label) 301 | if len(rng) <= 1: 302 | return self.copy() 303 | if label_order is None: 304 | label_order = Permutation(rng, label=label) 305 | labels_order = kwargs.pop("labels_order", {}) 306 | labels_order[label] = label_order 307 | return self.copy(labels_order=labels_order, **kwargs) 308 | 309 | def _get_labels_order(self, field=None, labels_order=None, **kwargs): 310 | if labels_order is not None and not isinstance(labels_order, dict): 311 | raise TypeError("labels_order must be a dict") 312 | if labels_order is None: 313 | labels_order = {} 314 | 315 | # Checking for keys in kwargs 316 | for key in self.labels: 317 | if key + "_order" in kwargs: 318 | labels_order[key] = kwargs.pop(key + "_order") 319 | continue 320 | key = self.index_to_axis(key) 321 | if key + "_order" in kwargs: 322 | labels_order[key] = kwargs.pop(key + "_order") 323 | continue 324 | 325 | # Here we check the given values and unpack axes into indexes 326 | given_values = labels_order 327 | labels_order = {} 328 | for key, val in given_values.items(): 329 | rng = self.get_range(key) # This does also some quality control on the key 330 | if ( 331 | not isinstance(val, Variable) 332 | and not isinstance(val, Tunable) 333 | and set(val) != set(rng) 334 | ): 335 | raise ValueError( 336 | "Not valid %s order. 
    def _get_labels_order(self, field=None, labels_order=None, **kwargs):
        """
        Deduces the order of the field labels.

        Sources, in order of precedence: the labels_order argument and
        <label>_order entries in kwargs, then the input field.
        Returns a tuple of (label, Permutation) pairs and the unused kwargs.
        """
        if labels_order is not None and not isinstance(labels_order, dict):
            raise TypeError("labels_order must be a dict")
        if labels_order is None:
            labels_order = {}

        # Checking for keys in kwargs: both the index name (e.g. "label_0")
        # and the plain axis name are accepted as <key>_order
        for key in self.labels:
            if key + "_order" in kwargs:
                labels_order[key] = kwargs.pop(key + "_order")
                continue
            key = self.index_to_axis(key)
            if key + "_order" in kwargs:
                labels_order[key] = kwargs.pop(key + "_order")
                continue

        # Here we check the given values and unpack axes into indexes
        given_values = labels_order
        labels_order = {}
        for key, val in given_values.items():
            rng = self.get_range(key)  # This does also some quality control on the key
            # Tunable/Variable orders are resolved later; explicit values
            # must be a permutation of the label's range
            if (
                not isinstance(val, Variable)
                and not isinstance(val, Tunable)
                and set(val) != set(rng)
            ):
                raise ValueError(
                    "Not valid %s order. It has %s, while expected %s" % (key, val, rng)
                )
            # An order given for an axis applies to all its indexes that
            # were not given explicitly
            for _k in self.get_indexes(key):
                if _k == key or _k not in given_values:
                    labels_order[_k] = val

        # Getting labels_order from the field (only for labels not already set)
        if field is not None:
            for key, val in field.labels_order:
                if key in self.labels and key not in labels_order:
                    rng = self.get_range(key)
                    if len(rng) <= 1:
                        # A single value needs no ordering
                        continue
                    if set(rng) == set(field.get_range(key)):
                        labels_order[key] = val
                    elif set(rng) <= set(field.get_range(key)):
                        # Subset of the field's range: filter the old order lazily
                        labels_order[key] = function(filter, rng.__contains__, val)

        # Creating variables: every label ends up with a Permutation variable,
        # pre-set to the deduced order when one was found
        for key in self.labels:
            if key in labels_order and isinstance(labels_order[key], Permutation):
                continue
            rng = self.get_range(key)
            var = Permutation(rng, label=key)
            if key in labels_order:
                var.value = labels_order[key]
            labels_order[key] = var

        return tuple(labels_order.items()), kwargs
    def transpose(self, *axes, **axes_order):
        """
        Transposes the matrix/tensor indexes of the field.

        *NOTE*: this is conceptually different from numpy.transpose
        where all the axes are transposed.

        Parameters
        ----------
        axes: str
            If given, only the listed axes are transposed,
            otherwise all the tensorial axes are changed.
            By default the order of the indexes is inverted.
        axes_order: dict
            Same as axes, but specifying the reordering of the indexes.
            The key must be one of the axis and the value the order using
            an index per repetition of the axis numbering from 0,1,...
        """
        counts = dict(self.axes_counts)
        # Validate every explicit ordering: the axis must exist and the
        # value must be a permutation of range(count) for that axis
        for (axis, val) in axes_order.items():
            if not axis in counts:
                raise KeyError("Axis %s not in field" % (axis))
            if not isiterable(val):
                raise TypeError("Type of value for axis %s not valid" % (axis))
            val = tuple(val)
            if not len(val) == counts[axis]:
                raise ValueError(
                    "%d indexes have been given for axis %s but it has count %d"
                    % (len(val), axis, counts[axis])
                )
            if not set(val) == set(range(counts[axis])):
                raise ValueError(
                    "%s has been given for axis %s. Not a permutation of %s."
                    % (val, axis, tuple(range(counts[axis])))
                )

        # Default: transpose all the degrees of freedom
        if not axes and not axes_order:
            axes = ("dofs",)

        # Only repeated axes (count > 1) without an explicit order are inverted
        axes = [
            axis
            for axis in self.get_axes(*axes)
            if axis not in axes_order and counts[axis] > 1
        ]

        # Drop identity orderings, they are no-ops
        # NOTE(review): val here is the caller's original object (the tuple()
        # normalization above was local), so a list equal to the identity
        # order would not compare equal to the tuple — confirm callers always
        # pass tuples.
        for (axis, val) in tuple(axes_order.items()):
            if val == tuple(range(counts[axis])):
                del axes_order[axis]

        # Nothing left to do: return a plain copy
        if not axes and not axes_order:
            return self.copy()
        return self.copy(
            self.backend.transpose(self.indexes_order, axes=axes, **axes_order)
        )
class backend_method:
    """
    Decorator for backend methods.

    Class-level access returns the plain function; instance-level access
    returns a lazy tuneit Function bound to the field's underlying value,
    so inside the decorated functions `self` is the numpy array.
    """

    def __init__(self, fnc, cls=None):
        self.fnc = fnc
        self.__name__ = self.fnc.__name__
        if cls:
            # Register the function on cls as if it was defined there
            self.fnc.__qualname__ = cls.__name__ + "." + self.key
            setattr(cls, self.key, self)

    @property
    def key(self):
        "Name of the method"
        return self.__name__

    def __get__(self, obj, owner):
        if obj is None:
            return self.fnc
        return Function(self.fnc, args=(obj.field.value,), label=self.key)


class NumpyBackend:
    "Numpy array backend for the field class"

    def __init__(self, field):
        self.field = field

    @classmethod
    def init(cls, field, shape, dtype):
        "Initializes a new field (zeros when no input array is given)"
        if field is None:
            return function(np.zeros, shape, dtype=dtype)

        return function(np.array, field, dtype=dtype)

    @backend_method
    def copy(self):
        "Returns a copy of the field"
        return self.copy()

    @backend_method
    def astype(self, dtype):
        "Changes the dtype of the field"
        return self.astype(dtype)

    @backend_method
    def conj(self):
        "Conjugates the field"
        return self.conj()

    @backend_method
    def zeros(self, dtype):
        "Fills the field with zeros"
        return np.zeros_like(self, dtype=dtype)

    @backend_method
    def ones(self, dtype):
        "Fills the field with ones"
        return np.ones_like(self, dtype=dtype)

    @backend_method
    def rand(self):
        "Fills the field with random numbers"
        return np.random.rand(*self.shape)

    @backend_method
    def getitem(self, indexes_order, coords, **labels):
        "Direct implementation of getitem"
        # Label coordinates are given by value: translate them to positions
        for label, order in labels.items():
            coords[label] = tuple(order.index(val) for val in coords[label])
        indexes = tuple(coords.pop(idx, slice(None)) for idx in indexes_order)
        assert not coords, "Coords didn't empty"
        return self.__getitem__(indexes)

    @backend_method
    def setitem(self, indexes_order, coords, value, **labels):
        "Direct implementation of setitem"
        for label, order in labels.items():
            coords[label] = tuple(order.index(val) for val in coords[label])
        indexes = tuple(coords.pop(idx, slice(None)) for idx in indexes_order)
        assert not coords, "Coords didn't empty"
        self.__setitem__(indexes, value)
        return self

    @backend_method
    def reorder(self, new_order, old_order):
        "Direct implementation of reordering"
        indexes = tuple(old_order.index(idx) for idx in new_order)
        # Fixed: ndarray.transpose takes the permutation positionally,
        # it has no axes= keyword (that is numpy.transpose, the function)
        return self.transpose(indexes)

    @backend_method
    def reorder_label(self, key, new_order, old_order, indexes_order):
        "Direct implementation of label reordering"
        indexes = tuple(old_order.index(idx) for idx in new_order)
        return self.take(indexes, axis=indexes_order.index(key))

    @backend_method
    def reshape(self, shape, new_order=None, old_order=None):
        "Direct implementation of reshaping"
        # Indexes common to both orders, in the order required by new_order.
        # Fixed: iterating old_order here made `order` always sorted and the
        # reordering branch below unreachable.
        common = tuple(idx for idx in new_order if idx in old_order)
        order = list(old_order.index(idx) for idx in common)

        if order != sorted(order):
            # We need to reorder first; axes not in common keep their place
            for idx in range(len(old_order)):
                if idx not in order:
                    order.insert(idx, idx)
            # Fixed: the transposed array was discarded (ndarray.transpose
            # returns a new view, it does not reorder in place)
            self = self.transpose(order)

        return self.reshape(shape)

    @backend_method
    def transpose(self, indexes_order=None, axes=None, **axes_order):
        "Direct implementation of transposing"
        axes = axes or ()
        # Group the index names per axis (BaseField is resolved at runtime
        # in the real module)
        indexes = defaultdict(list)
        axes_order = {key: list(val) for key, val in axes_order.items()}
        for idx in indexes_order:
            axis = BaseField.index_to_axis(idx)
            indexes[axis].append(idx)
        new_order = []
        for idx in indexes_order:
            axis = BaseField.index_to_axis(idx)
            if axis in axes:
                # Invert the order of the repeated axis
                idx = indexes[axis].pop()
            elif axis in axes_order:
                # Explicit order: pick the requested repetition
                idx = indexes[axis][axes_order[axis].pop(0)]
            else:
                # Untouched axis: keep the original order
                idx = indexes[axis].pop(0)
            # Fixed: new_order was never filled, making the lookup below
            # raise ValueError unconditionally
            new_order.append(idx)
        indexes = tuple(new_order.index(idx) for idx in indexes_order)
        # Fixed: positional permutation instead of invalid axes= keyword
        return self.transpose(indexes)

    @backend_method
    def roll(self, shift, indexes=None, indexes_order=None):
        "Direct implementation of rolling"
        indexes = tuple(indexes_order.index(idx) for idx in indexes)
        # Fixed: ndarray has no .roll method; use numpy.roll
        return np.roll(self, shift, indexes)

    def __getattr__(self, key):
        raise AttributeError("Unknown %s" % key)
    def __init_attributes__(
        self, field=None, axes=None, lattice=None, coords=None, **kwargs
    ):
        """
        Initializes the field class.

        Parameters
        ----------
        axes: list(str)
            List of axes of the field.
        lattice: Lattice
            The lattice on which the field is defined.
        coords: list/dict
            Coordinates of the field.
        kwargs: dict
            Extra parameters that will be passed to the field types.
        """

        # Local import, presumably to avoid a circular import with lattice.py
        from ..lattice import Lattice, default_lattice

        if lattice is not None and not isinstance(lattice, Lattice):
            raise TypeError("lattice must be of Lattice type")

        # Lattice precedence: explicit argument, then input field, then default
        self._lattice = (
            lattice
            if lattice is not None
            else field.lattice
            if isinstance(field, BaseField)
            else default_lattice()
        ).freeze()

        # Axes precedence: explicit argument (expanded), input field, lattice dims
        self._axes = tuple(
            self.lattice.expand(axes)
            if axes is not None
            else field.axes
            if isinstance(field, BaseField)
            else self.lattice.dims
        )

        # Coordinates shared with the input field are carried over first,
        # then resolved against the lattice
        # NOTE(review): the second assignment overwrites self._coords; it
        # presumably merges via the field=self argument — confirm.
        if isinstance(field, BaseField):
            same_indexes = set(self.indexes).intersection(field.indexes)
            self._coords = field.coords.extract(same_indexes)
        else:
            self._coords = {}
        self._coords = self.lattice.coords.resolve(coords, field=self)

        # Field types the field is an instance of (FieldType.s is the registry)
        self._types = tuple(
            (name, ftype)
            for name, ftype in FieldType.s.items()
            if isinstance(self, ftype)
        )

        # ordering types by relevance
        self._types = tuple(
            (name, ftype)
            for name, ftype in sorted(
                self.types,
                key=lambda item: len(tuple(self.lattice.expand(item[1].axes.expand))),
                reverse=True,
            )
        )

        # Let every matching type consume its own kwargs; types without an
        # __init_attributes__ are skipped
        for (_, ftype) in self.types:
            try:
                kwargs = ftype.__init_attributes__(self, field=field, **kwargs)
            except AttributeError:
                continue

        return kwargs
    def reshape(self, *axes, **kwargs):
        """
        Reshapes the field changing the axes.
        Note: only axes with size 1 can be removed and
        new axes are added with size 1 and coord=None
        """
        axes = kwargs.pop("axes", axes)
        indexes = self.axes_to_indexes(axes)
        shape = dict(self.shape)
        # Indexes dropped by the reshape must have size one
        _squeeze = (index for index in self.indexes if index not in indexes)
        for index in _squeeze:
            if not shape[index] == 1:
                raise ValueError("Can only remove axes which size is 1")
        # Indexes added by the reshape get coordinate None (size one)
        _extend = (index for index in indexes if index not in self.indexes)
        coords = kwargs.pop("coords", {})
        for index in _extend:
            coords.setdefault(index, None)
        return self.copy(axes=axes, coords=coords, **kwargs)
227 | axes = kwargs.pop("axes", axes) 228 | return self.reshape(self.axes + axes, **kwargs) 229 | 230 | def get_axes(self, *axes): 231 | "Returns the corresponding field axes to the given axes/dimensions" 232 | if not isiterable(axes, str): 233 | raise TypeError("The arguments need to be a list of strings") 234 | if "all" in axes: 235 | return tuple(sorted(self.axes)) 236 | axes = (axis for axis in self.lattice.expand(*axes) if axis in self.axes) 237 | return tuple(sorted(axes)) 238 | 239 | def get_indexes(self, *axes): 240 | "Returns the corresponding indexes of the given axes/indexes/dimensions" 241 | if not isiterable(axes, str): 242 | raise TypeError("The arguments need to be a list of strings") 243 | if "all" in axes: 244 | return tuple(sorted(self.indexes)) 245 | indexes = set(axis for axis in axes if axis in self.indexes) 246 | axes = tuple(self.lattice.expand(set(axes).difference(indexes))) 247 | indexes.update([idx for idx in self.indexes if self.index_to_axis(idx) in axes]) 248 | return tuple(sorted(indexes)) 249 | 250 | def get_range(self, key): 251 | "Returns the range of the given index/axis." 252 | tmp = self.get_indexes(key) 253 | if len(tmp) > 1: 254 | tmp = set(self.get_range(_k) for _k in tmp) 255 | if len(tmp) == 1: 256 | return tuple(tmp)[0] 257 | raise ValueError( 258 | "%s corresponds to more than one index with different size" % key 259 | ) 260 | if len(tmp) == 0: 261 | raise KeyError("%s not in field" % key) 262 | key = tmp[0] 263 | val = self.coords[key] 264 | if val == slice(None): 265 | axis = self.index_to_axis(key) 266 | return self.lattice.get_axis_range(axis) 267 | if isinstance(val, (int, str, type(None))): 268 | return (val,) 269 | if isinstance(val, slice): 270 | return range(val.start, val.stop, val.step) 271 | return val 272 | 273 | def get_size(self, key): 274 | "Returns the size of the given index/axis." 
275 | return len(self.get_range(key)) 276 | 277 | @compute_property 278 | def shape(self): 279 | "Returns the list of indexes with size. Order is not significant." 280 | return tuple((key, self.get_size(key)) for key in self.indexes) 281 | 282 | @compute_property 283 | def size(self): 284 | "Returns the number of elements in the field." 285 | prod = 1 286 | for _, length in self.shape: 287 | prod *= length 288 | return prod 289 | 290 | @property 291 | def types(self): 292 | "List of field types that the field is instance of, ordered per relevance" 293 | return getattr(self, "_types", ()) 294 | 295 | @property 296 | def coords(self): 297 | "List of coordinates of the field." 298 | return self._coords 299 | 300 | def __dir__(self): 301 | attrs = set(super().__dir__()) 302 | for _, ftype in self.types: 303 | attrs.update((key for key in dir(ftype) if not key.startswith("_"))) 304 | return sorted(attrs) 305 | 306 | def __getattr__(self, key): 307 | "Looks up for methods in the field types" 308 | if key == "_types": 309 | raise AttributeError 310 | 311 | for _, ftype in self.types: 312 | if isinstance(self, ftype): 313 | try: 314 | return getattr(ftype, key).__get__(self) 315 | except AttributeError: 316 | continue 317 | 318 | raise AttributeError("%s not found" % key) 319 | 320 | def __setattr__(self, key, val): 321 | "Looks up for methods in the field types" 322 | for _, ftype in self.types: 323 | if isinstance(self, ftype): 324 | try: 325 | getattr(ftype, key).__set__(self, val) 326 | except AttributeError: 327 | continue 328 | 329 | super().__setattr__(key, val) 330 | 331 | @property 332 | def type(self): 333 | "Name of the Field. Equivalent to the most relevant field type." 
        return self.types[0][0]

    def copy(self, **kwargs):
        "Creates a shallow copy of the field"
        # Any given kwargs override the corresponding property in the copy
        return type(self)(self, **kwargs)

    def __getitem__(self, coords):
        # field[coords] is a shortcut for field.get(coords)
        return self.get(coords)

    def get(self, *keys, **coords):
        "Gets the components at the given coordinates"
        return self.copy(coords=(keys, coords))

    def __pos__(self):
        # Unary plus is the identity
        return self

    def __eq__(self, other):
        # Fields are equal when they share lattice, indexes (order is not
        # significant) and coordinates.
        # NOTE(review): defining __eq__ implicitly sets __hash__ to None,
        # making instances unhashable -- confirm this is intended.
        return self is other or (
            isinstance(other, type(self))
            and self.lattice == other.lattice
            and set(self.indexes) == set(other.indexes)
            and self.coords == other.coords
        )


# Makes BaseField reachable from the FieldType metaclass machinery
FieldType.BaseField = BaseField
FieldType.Field = BaseField


def wrap_method(method, ftype):
    "Wrapper for field methods"

    fnc = getattr(ftype, method)

    @wraps(fnc)
    def wrapped(field, *args, **kwargs):
        # The wrapped function is exposed at module level, hence the explicit
        # type check on the first argument
        if not isinstance(field, ftype):
            raise TypeError("First argument of %s must be a field."
% method) 372 | return fnc(field, *args, **kwargs) 373 | 374 | return wrapped 375 | 376 | 377 | METHODS = ( 378 | "squeeze", 379 | "extend", 380 | "reshape", 381 | ) 382 | 383 | for _ in METHODS: 384 | __all__.append(_) 385 | globals()[_] = wrap_method(_, BaseField) 386 | -------------------------------------------------------------------------------- /lyncs/field/contractions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Set of contraction functions for array fields 3 | """ 4 | # pylint: disable=C0303,C0330 5 | 6 | __all__ = [ 7 | "dot", 8 | "einsum", 9 | ] 10 | 11 | from collections import defaultdict 12 | import numpy as np 13 | from tuneit import Permutation 14 | from lyncs_utils import count 15 | from .array import ArrayField, NumpyBackend, backend_method 16 | 17 | 18 | def prepare_fields(*fields): 19 | "Auxiliary function for preparing the fields" 20 | 21 | fields = list(fields) 22 | fields[0], fields[1:] = fields[0].prepare(*fields[1:], elemwise=False) 23 | return tuple(fields) 24 | 25 | 26 | def dot_indexes(*fields, closed_indexes=None, open_indexes=None): 27 | "Auxiliary function for formatting the indexes of the dot product" 28 | 29 | axes = set() 30 | for field in fields: 31 | axes.update(field.axes) 32 | 33 | if closed_indexes is not None: 34 | if isinstance(closed_indexes, str): 35 | closed_indexes = [closed_indexes] 36 | tmp = set() 37 | for axis in closed_indexes: 38 | for field in fields: 39 | tmp.update(field.get_axes(axis)) 40 | 41 | closed_indexes = tmp 42 | assert closed_indexes.issubset(axes), "Trivial assertion." 
43 | axes = axes.difference(closed_indexes) 44 | else: 45 | closed_indexes = set() 46 | 47 | if open_indexes is not None: 48 | if isinstance(open_indexes, str): 49 | open_indexes = [open_indexes] 50 | tmp = set() 51 | for axis in open_indexes: 52 | for field in fields: 53 | tmp.update(field.get_axes(axis)) 54 | 55 | open_indexes = tmp 56 | if open_indexes.intersection(closed_indexes): 57 | raise ValueError("Close and open indexes cannot have axes in common.") 58 | assert open_indexes.issubset(axes), "Trivial assertion." 59 | axes = axes.difference(open_indexes) 60 | else: 61 | open_indexes = set() 62 | 63 | return axes, closed_indexes, open_indexes 64 | 65 | 66 | def dot_prepare(*fields, axes=None, axis=None, closed_indexes=None, open_indexes=None): 67 | "Auxiliary function that prepares for a dot product checking the input" 68 | 69 | if (axis, axes, closed_indexes).count(None) < 2: 70 | raise KeyError( 71 | """ 72 | Only one between axis, axes or closed_indexes can be used. They are the same parameters. 73 | """ 74 | ) 75 | 76 | if closed_indexes is None and open_indexes is None: 77 | closed_indexes = "dofs" 78 | 79 | closed_indexes = ( 80 | axis if axis is not None else axes if axes is not None else closed_indexes 81 | ) 82 | 83 | axes, closed_indexes, open_indexes = dot_indexes( 84 | *fields, closed_indexes=closed_indexes, open_indexes=open_indexes 85 | ) 86 | 87 | counts = {} 88 | for field in fields: 89 | for key, num in field.axes_counts: 90 | if key in axes: 91 | if key not in counts: 92 | counts[key] = num 93 | continue 94 | 95 | if counts[key] != num: 96 | raise ValueError( 97 | """ 98 | Axis %s has count %s and %s for different fields. 99 | Axes that are neither closes or open indexes, 100 | must have the same count between all fields. 
101 | """ 102 | % (key, num, counts[key]) 103 | ) 104 | 105 | return axes, closed_indexes, open_indexes 106 | 107 | 108 | def trace_indexes(*field_indexes, axes=None): 109 | "Auxiliary function that traces field_indexes given in a einsum style" 110 | 111 | if not axes: 112 | return (field_indexes,) 113 | 114 | for key, val in tuple(field_indexes[-1].items()): 115 | if key in axes and len(val) > 1: 116 | idx = val[-1] 117 | if len(val) > 2: 118 | field_indexes[-1][key] = val[:-1] + (val[0],) 119 | else: 120 | del field_indexes[-1][key] 121 | for field in reversed(field_indexes[:-1]): 122 | if key in field and idx in field[key]: 123 | assert idx == field[key][-1], "Trivial Assertion" 124 | field[key] = field[key][:-1] + (val[0],) 125 | 126 | return field_indexes 127 | 128 | 129 | def dot( 130 | *fields, 131 | axes=None, 132 | axis=None, 133 | closed_indexes=None, 134 | open_indexes=None, 135 | reduced_indexes=None, 136 | trace=False, 137 | average=False, 138 | debug=False 139 | ): 140 | """ 141 | Performs the dot product between fields. 142 | 143 | Default behaviors: 144 | ------------------ 145 | 146 | Contractions are performed between only degree of freedoms of the fields, e.g. field.dofs. 147 | For each field, indexes are always contracted in pairs combining the outer-most free index 148 | of the left with the inner-most of the right. 149 | 150 | I.e. dot(*fields) = dot(*fields, axes="dofs") 151 | 152 | Parameters: 153 | ----------- 154 | fields: Field 155 | List of fields to perform dot product between. 156 | axes: str, list 157 | Axes where the contraction is performed on. 158 | Indexes are contracted in pairs combining the outer-most free index 159 | of the left with the inner-most of the right. 160 | axis: str, list 161 | Same as axes. 162 | closed_indexes: str, list 163 | Same as axes. 164 | open_indexes: str, list 165 | Opposite of close indexes, i.e. the axes that are left open. 
166 | reduced_indexes: str, list 167 | List of indexes to sum over and not available in the output field. 168 | average: bool 169 | If True, the reduced_indexes are averaged, i.e. result/prod(reduced_indexes.size). 170 | trace: bool 171 | If True, then the closed indexes are also traced 172 | debug: bool 173 | If True, then the output are the contraction indexes 174 | 175 | Examples: 176 | --------- 177 | dot(vector, vector, axes="color") 178 | [x,y,z,t,spin,color] x [x,y,z,t,spin,color] -> [x,y,z,t,spin] 179 | [X,Y,Z,T, mu , c_0 ] x [X,Y,Z,T, mu , c_0 ] -> [X,Y,Z,T, mu ] 180 | 181 | dot(vector, vector, closed_indexes="color", open_indexes="spin") 182 | [x,y,z,t,spin,color] x [x,y,z,t,spin,color] -> [x,y,z,t,spin,spin] 183 | [X,Y,Z,T, mu , c_0 ] x [X,Y,Z,T, nu , c_0 ] -> [X,Y,Z,T, mu , nu ] 184 | 185 | dot(gauge, gauge, closed_indexes="color", trace=True) 186 | [x,y,z,t,color,color] x [x,y,z,t,color,color] -> [x,y,z,t] 187 | [X,Y,Z,T, c_0 , c_1 ] x [X,Y,Z,T, c_1 , c_0 ] -> [X,Y,Z,T] 188 | """ 189 | fields = prepare_fields(*fields) 190 | axes, closed_indexes, open_indexes = dot_prepare( 191 | *fields, 192 | axis=axis, 193 | axes=axes, 194 | closed_indexes=closed_indexes, 195 | open_indexes=open_indexes, 196 | ) 197 | 198 | counter = count() 199 | field_indexes = [] 200 | new_field_indexes = defaultdict(tuple) 201 | for field in fields: 202 | field_indexes.append({}) 203 | for key, num in field.axes_counts: 204 | 205 | if key in axes: 206 | if key not in new_field_indexes: 207 | new_field_indexes[key] = tuple(counter(num)) 208 | field_indexes[-1][key] = tuple(new_field_indexes[key]) 209 | 210 | elif key in open_indexes: 211 | field_indexes[-1][key] = tuple(counter(num)) 212 | new_field_indexes[key] += field_indexes[-1][key] 213 | 214 | else: 215 | assert key in closed_indexes, "Trivial assertion." 
216 | if key not in new_field_indexes: 217 | new_field_indexes[key] = tuple(counter(num)) 218 | field_indexes[-1][key] = tuple(new_field_indexes[key]) 219 | else: 220 | assert len(new_field_indexes[key]) > 0, "Trivial assertion." 221 | field_indexes[-1][key] = (new_field_indexes[key][-1],) + tuple( 222 | counter(num - 1) 223 | ) 224 | new_field_indexes[key] = ( 225 | new_field_indexes[key][:-1] + field_indexes[-1][key][1:] 226 | ) 227 | if len(new_field_indexes[key]) == 0: 228 | del new_field_indexes[key] 229 | 230 | field_indexes.append(dict(new_field_indexes)) 231 | 232 | if trace: 233 | field_indexes = trace_indexes(*field_indexes, axes=closed_indexes) 234 | 235 | if average: 236 | pass 237 | 238 | return einsum(*fields, indexes=field_indexes, debug=debug) 239 | 240 | 241 | ArrayField.dot = dot 242 | ArrayField.__matmul__ = dot 243 | 244 | 245 | def einsum(*fields, indexes=None, debug=False): 246 | """ 247 | Performs the einsum product between fields. 248 | 249 | Parameters: 250 | ----------- 251 | fields: Field 252 | List of fields to perform the einsum between. 253 | indexes: list of dicts of indexes 254 | List of dictionaries for each field plus one for output field if not scalar. 255 | Each dictionary should have a key per axis of the field. 256 | Every key should have a list of indexes for every repetition of the axis in the field. 257 | Indexes must be integers. 
258 | 259 | Examples: 260 | --------- 261 | einsum(vector, vector, indexes=[{'x':0,'y':1,'z':2,'t':3,'spin':4,'color':5}, 262 | {'x':0,'y':1,'z':2,'t':3,'spin':4,'color':6}, 263 | {'x':0,'y':1,'z':2,'t':3,'color':(5,6)} ]) 264 | 265 | [x,y,z,t,spin,color] x [x,y,z,t,spin,color] -> [x,y,z,t,color,color] 266 | [0,1,2,3, 4 , 5 ] x [0,1,2,3, 4 , 6 ] -> [0,1,2,3, 5 , 6 ] 267 | """ 268 | fields = prepare_fields(*fields) 269 | if isinstance(indexes, dict): 270 | indexes = (indexes,) 271 | indexes = tuple(indexes) 272 | 273 | if not len(indexes) in (len(fields), len(fields) + 1): 274 | raise ValueError("A set of indexes per field must be given.") 275 | 276 | if not all((isinstance(idxs, dict) for idxs in indexes)): 277 | raise TypeError("Each set of indexes list must be a dictionary") 278 | 279 | for idxs in indexes: 280 | for key, val in list(idxs.items()): 281 | if isinstance(val, int) or len(val) == 1: 282 | new_key = fields[0].axes_to_indexes(key)[0] 283 | idxs[new_key] = val if isinstance(val, int) else val[0] 284 | if new_key != key: 285 | del idxs[key] 286 | continue 287 | for i, _id in enumerate(val): 288 | new_key = key + "_%d" % i 289 | assert new_key not in idxs 290 | idxs[new_key] = _id 291 | del idxs[key] 292 | 293 | for (i, field) in enumerate(fields): 294 | if not set(indexes[i].keys()) == set(field.indexes): 295 | raise ValueError( 296 | """ 297 | Indexes must be specified for all the field axes/indexes. 
298 | For field %d, 299 | Got indexes: %s 300 | Field indexes: %s 301 | """ 302 | % (i, tuple(indexes[i].keys()), field.indexes) 303 | ) 304 | 305 | if debug: 306 | return indexes 307 | 308 | indexes_order = Permutation( 309 | list(indexes[-1].keys()), label="indexes_order", uid=True 310 | ).value 311 | # TODO: coords 312 | return fields[0].copy( 313 | fields[0].backend.contract( 314 | *(field.value for field in fields[1:]), 315 | *(field.indexes_order for field in fields), 316 | indexes=indexes, 317 | indexes_order=indexes_order, 318 | ), 319 | axes=fields[0].indexes_to_axes(*indexes[-1].keys()), 320 | indexes_order=indexes_order, 321 | ) 322 | 323 | 324 | ArrayField.einsum = einsum 325 | 326 | SYMBOLS = list("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 327 | 328 | 329 | @backend_method 330 | def contract(*fields_orders, indexes=None, indexes_order=None): 331 | "Implementation of contraction via einsum" 332 | assert len(fields_orders) % 2 == 0 333 | fields = fields_orders[: len(fields_orders) // 2] 334 | orders = fields_orders[len(fields_orders) // 2 :] 335 | 336 | symbols = [] 337 | for order, idxs in zip(orders + (indexes_order,), indexes): 338 | symbols.append("") 339 | for idx in order: 340 | symbols[-1] += SYMBOLS[idxs[idx]] 341 | 342 | string = ",".join(symbols[:-1]) + "->" + symbols[-1] 343 | return np.einsum(string, *fields) 344 | 345 | 346 | NumpyBackend.contract = contract 347 | -------------------------------------------------------------------------------- /lyncs/field/random.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from functools import wraps 3 | import re 4 | from numpy import random 5 | from tuneit import Function 6 | from .base import BaseField 7 | from .array import backend_method 8 | 9 | 10 | @dataclass(frozen=True) 11 | class RandomFieldGenerator: 12 | field: BaseField 13 | seed: int = None 14 | 15 | @property 16 | def backend(self): 17 | "Returns 
the computational backend of the field (numpy.random)." 18 | return RandomBackend(self.field) 19 | 20 | @property 21 | def backend_kwargs(self): 22 | "Returns the list of field variables to be passed to the backend function" 23 | if self.seed is None: 24 | return {} 25 | kwargs = dict(seed=self.seed, indexes_order=self.field.indexes_order) 26 | kwargs.update(dict(self.field.labels_order)) 27 | return kwargs 28 | 29 | def shuffle(self, *axes): 30 | "Randomly shuffles the content of the field along the axes" 31 | return self.field.copy(self.backend.shuffle(*axes, **self.backend_kwargs)) 32 | 33 | def bytes(self): 34 | "Fills up the field with random bytes" 35 | return self.field.copy(self.backend.bytes(**self.backend_kwargs)) 36 | 37 | 38 | def random_method(fnc): 39 | "Wraps numpy random methods" 40 | 41 | @wraps(fnc) 42 | def method(self, *args, **kwargs): 43 | 44 | if "size" in kwargs: 45 | raise KeyError("The parameter 'size' has been disabled") 46 | if "out" in kwargs: 47 | raise KeyError("The parameter 'out' has been disabled") 48 | 49 | # Getting dtype of the resulting field 50 | dtype = kwargs.get("dtype", self.field.dtype) 51 | try: 52 | out = fnc(random.default_rng(), *args, dtype=dtype, size=1, **kwargs) 53 | kwargs["dtype"] = dtype 54 | except TypeError: 55 | out = fnc(random.default_rng(), *args, size=1, **kwargs) 56 | if dtype != self.field.dtype: 57 | raise TypeError( 58 | "Cannot change dtype using '%s' function" % fnc.__name__ 59 | ) 60 | 61 | dtype = out.dtype 62 | return self.field.copy( 63 | getattr(self.backend, fnc.__name__)(*args, **kwargs, **self.backend_kwargs), 64 | dtype=dtype, 65 | ) 66 | 67 | # Editing doc 68 | doc = method.__doc__.split("\n") 69 | assert doc[1].strip().startswith(method.__name__ + "(") 70 | assert "size=None" in doc[1] 71 | 72 | params = [] 73 | for param in "size", "out": 74 | if param in doc[1]: 75 | params.append(param) 76 | doc[1] = re.sub("(, )?" 
+ param + "=?['a-zA-Z_0-9']*", "", doc[1]) 77 | doc.insert(1, "") 78 | if params: 79 | doc.insert( 80 | 1, 81 | " " * 8 82 | + "The parameter(s) '%s' is disabled since deduced from the field." 83 | % "', '".join(params), 84 | ) 85 | doc.insert(1, " " * 8 + "Note: this documentation has been copied from numpy.") 86 | 87 | method.__doc__ = "\n".join(doc) 88 | 89 | return method 90 | 91 | 92 | @dataclass 93 | class RandomBackend: 94 | field: BaseField 95 | 96 | def generate(self, fnc, *args, seed=None, indexes_order=None, **kwargs): 97 | if seed is None: 98 | return getattr(random.default_rng(), fnc)(*args, **kwargs, size=field.shape) 99 | raise NotImplementedError("Reproducible random number not implemented") 100 | 101 | @backend_method 102 | def shuffle(self, *axes, **kwargs): 103 | "Randomly shuffles the content of the field along the axes" 104 | pass 105 | 106 | @backend_method 107 | def bytes(self, **kwargs): 108 | "Fills up the field with random bytes" 109 | pass 110 | 111 | def __getattr__(self, key): 112 | if hasattr(random.Generator, key): 113 | return Function(self.generate, args=(key), label=key) 114 | return super().__getattr__(key) 115 | 116 | def __getstate__(self): 117 | return self.field 118 | 119 | 120 | for _fnc in dir(random.Generator): 121 | if not _fnc.startswith("_") and _fnc not in ( 122 | "bit_generator", 123 | "bytes", 124 | "shuffle", 125 | "permutation", 126 | ): 127 | setattr( 128 | RandomFieldGenerator, _fnc, random_method(getattr(random.Generator, _fnc)) 129 | ) 130 | -------------------------------------------------------------------------------- /lyncs/field/reductions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reductions for the fields 3 | """ 4 | # pylint: disable=C0303,C0330 5 | 6 | __all__ = [ 7 | "trace", 8 | ] 9 | 10 | import numpy as np 11 | from lyncs_utils import count 12 | from .array import ArrayField, NumpyBackend, backend_method 13 | from .contractions import einsum, 
trace_indexes, dot_prepare 14 | 15 | 16 | def trace(self, *axes, **kwargs): 17 | """ 18 | Performs the trace over repeated axes contracting the outer-most index with the inner-most. 19 | 20 | Parameters 21 | ---------- 22 | axes: str 23 | If given, only the listed axes are traced. 24 | If the axes are two indexes of the field, then those two indexes are traced. 25 | """ 26 | tmp, kwargs = ArrayField.get_input_axes(*axes, **kwargs) 27 | if kwargs: 28 | raise ValueError("Unknown kwargs %s" % kwargs) 29 | _, axes, _ = dot_prepare(self, axes=tmp) 30 | 31 | counts = dict(self.axes_counts) 32 | axes = tuple(axis for axis in axes if counts[axis] > 1) 33 | 34 | if not axes: 35 | return self 36 | 37 | if len(axes) == 1: 38 | if ( 39 | len(tmp) == 2 40 | and set(tmp) <= set(self.indexes) 41 | and self.index_to_axis(tmp[0]) == self.index_to_axis(tmp[1]) 42 | ): 43 | indexes = tmp 44 | else: 45 | indexes = sorted(self.get_indexes()) 46 | indexes = (indexes[0], indexes[-1]) 47 | axes = tuple( 48 | self.index_to_axis(idx) for idx in self.indexes if idx not in indexes 49 | ) 50 | return self.copy(self.backend.trace(indexes, self.indexes_order), axes=axes) 51 | 52 | counter = count() 53 | indexes = {} 54 | for axis, num in counts.items(): 55 | indexes[axis] = tuple(counter(num)) 56 | 57 | indexes = trace_indexes(indexes, indexes, axes=axes) 58 | return einsum(self, indexes=indexes) 59 | 60 | 61 | ArrayField.trace = trace 62 | 63 | _ = backend_method( 64 | lambda self, indexes, indexes_order: self.trace( 65 | axis1=indexes_order.index(indexes[0]), axis2=indexes_order.index(indexes[1]) 66 | ) 67 | ) 68 | _.__name__ = "trace" 69 | NumpyBackend.trace = _ 70 | 71 | 72 | def reduction_method(key, fnc=None, doc=None): 73 | """ 74 | Default implementation for field reductions 75 | 76 | Parameters 77 | ---------- 78 | key: str 79 | The key of the method 80 | fnc: callable 81 | Fallback for the method in case self it is not a field 82 | """ 83 | 84 | def method(self, *axes, **kwargs): 85 | 
        if not isinstance(self, ArrayField):
            if fnc is None:
                raise TypeError(
                    "First argument of %s must be of type Field. Given %s"
                    % (key, type(self).__name__)
                )

            # Fallback to the plain numpy function for non-field inputs
            return fnc(self, *axes, **kwargs)

        axes, kwargs = self.get_input_axes(*axes, **kwargs)
        # Indexes to reduce over; all of them when no axes are given
        indexes = self.get_indexes(*axes) if axes else self.indexes
        # Axes that survive the reduction
        axes = tuple(
            self.index_to_axis(idx) for idx in self.indexes if idx not in indexes
        )

        # Deducing the dtype of the output via a one-element trial run
        if fnc is not None:
            trial = fnc(np.ones((1), dtype=self.dtype), **kwargs)
        else:
            trial = getattr(np.ones((1), dtype=self.dtype), key)(**kwargs)

        if axes:
            result = getattr(self.backend, key)(indexes, self.indexes_order, **kwargs)
        else:
            # No surviving axis: full reduction, no indexes needed
            result = getattr(self.backend, key)(**kwargs)

        # Some reductions return a tuple of arrays; copy each element
        if isinstance(trial, tuple):
            return tuple(
                (
                    self.copy(result[i], dtype=trial.dtype, axes=axes)
                    for i, trial in enumerate(trial)
                )
            )
        return self.copy(result, dtype=trial.dtype, axes=axes)

    method.__name__ = key

    if doc:
        method.__doc__ = doc
    elif fnc:
        method.__doc__ = fnc.__doc__

    return method


def backend_reduction_method(key, fnc=None, doc=None):
    """
    Returns a method for the backend that calls
    the given reduction (key) of the field value.
134 | """ 135 | 136 | def method(self, indexes=None, indexes_order=None, **kwargs): 137 | if indexes is not None: 138 | kwargs["axis"] = tuple(indexes_order.index(idx) for idx in indexes) 139 | if fnc is None: 140 | return getattr(self, key)(**kwargs) 141 | return fnc(self, **kwargs) 142 | 143 | method.__name__ = key 144 | if doc is not None: 145 | method.__doc__ = doc 146 | elif fnc is not None: 147 | method.__doc__ = fnc.__doc__ 148 | 149 | return method 150 | 151 | 152 | REDUCTIONS = ( 153 | ("any",), 154 | ("all",), 155 | ("min",), 156 | ("max",), 157 | ("argmin",), 158 | ("argmax",), 159 | ("sum",), 160 | ("prod",), 161 | ("mean",), 162 | ("std",), 163 | ("var",), 164 | ) 165 | 166 | for (reduction,) in REDUCTIONS: 167 | __all__.append(reduction) 168 | globals()[reduction] = reduction_method(reduction, fnc=getattr(np, reduction)) 169 | setattr(ArrayField, reduction, globals()[reduction]) 170 | if hasattr(np.ndarray, reduction): 171 | fnc = backend_reduction_method(reduction, doc=getattr(np, reduction).__doc__) 172 | else: 173 | fnc = backend_reduction_method(reduction, fnc=getattr(np, reduction)) 174 | backend_method(fnc, NumpyBackend) 175 | -------------------------------------------------------------------------------- /lyncs/field/types/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | from .generic import * 3 | from .quantum import * 4 | -------------------------------------------------------------------------------- /lyncs/field/types/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base tools for defining field types. 3 | """ 4 | 5 | __all__ = [ 6 | "FieldType", 7 | "Axes", 8 | ] 9 | 10 | import re 11 | from types import MappingProxyType 12 | from collections import OrderedDict 13 | 14 | 15 | class FieldType(type): 16 | """ 17 | Metaclass for the field types. 
18 | Field types are special classes used to define properties of fields 19 | depending on the axes. 20 | 21 | Field types have the following restrictions 22 | - The name of the class must be **unique**. All the field 23 | types are stored in dictionary (FieldType.s) and this list is 24 | looked up for searching properties of the field. 25 | - The FieldType needs to define an attribute __axes__ to specify 26 | which axes are needed for the properties of the type. 27 | 28 | 29 | Behaviour of field type special attributes: 30 | __axes__: list of axes that identify the field type. 31 | One can use generic names as "dims", "dofs" or properties 32 | like "space" or any of the special dimensions that may 33 | be defined on the lattice. If the current lattice does not 34 | have dimensions with these names, the field type will be simply 35 | ignored. 36 | Special characters may follow the name of the dimension as 37 | "+", "!"... 38 | The "+" means that the dimension can appear more than once. 39 | If a dimension is not followed by any special character then 40 | the "+" behavious is applied; e.g. "dims" -> "dims+". 41 | The "!" means that the specific dimension can be repeated only 42 | once. E.g. "spin!" means that only one spin dimension must be 43 | present to be of this type. Repetition of the dimension with 44 | "!" increase the counter, e.g. ["spin!", "spin!"] = "spin!!" 45 | means that the spin dimension must appear twice to be of the type. 46 | When "!" is used for a group of dimensions, i.e. "dofs!" then 47 | means that all the dofs must appear and only once. Then 48 | ["dofs!", "dofs"] = "dofs!+" means that all the dofs must appear 49 | but repetitions are allowed. 
50 | 51 | - __init__ = 52 | """ 53 | 54 | __types__ = OrderedDict() 55 | s = MappingProxyType(__types__) 56 | BaseField = None 57 | Field = None 58 | 59 | @classmethod 60 | def __prepare__(cls, name, bases, **kwargs): 61 | """ 62 | Checks that the name of the class is unique and construct from 63 | the axes in the base classes (bases) the axes for the new class. 64 | """ 65 | assert not kwargs, "kwargs not used" 66 | assert name not in cls.__types__, "A FieldType named %s already exists" % name 67 | 68 | axes = Axes() 69 | for base in bases: 70 | if isinstance(base, FieldType): 71 | axes += base.axes 72 | 73 | return {"__axes__": axes} 74 | 75 | def __new__(cls, name, bases, attrs, **kwargs): 76 | "Checks that __axes__ is a valid Axes" 77 | assert not kwargs, "kwargs not used" 78 | 79 | assert "__axes__" in attrs 80 | attrs["__axes__"] = Axes(attrs["__axes__"]) 81 | return super().__new__(cls, name, bases, attrs) 82 | 83 | def __init__(cls, name, bases, attrs, **kwargs): 84 | """ 85 | Adds the class to the list of all FieldType and checks 86 | that is subclass of bases. 87 | """ 88 | assert not kwargs, "kwargs not used" 89 | 90 | for base in bases: 91 | assert issubclass( 92 | cls, base 93 | ), """ 94 | The axes defined in the class %s are not compatible 95 | with the parent class %s. 
96 | Axes of %s: %s 97 | Axes of %s: %s 98 | """ % ( 99 | cls, 100 | base, 101 | cls, 102 | cls.axes, 103 | base, 104 | base.axes, 105 | ) 106 | 107 | FieldType.__types__[name] = cls 108 | FieldType.__types__.move_to_end(name, last=False) 109 | super().__init__(name, bases, attrs) 110 | 111 | def __call__(cls, *args, **kwargs): 112 | "Returns a Field with the correct axes" 113 | return cls.Field(*args, axes=kwargs.pop("axes", cls.axes.expand), **kwargs) 114 | 115 | def __subclasscheck__(cls, child): 116 | "Checks if child is subclass of class" 117 | return isinstance(child, FieldType) and cls.axes in child.axes 118 | 119 | def __instancecheck__(cls, field): 120 | "Checks if field is compatible with the class" 121 | if not isinstance(field, cls.BaseField): 122 | return False 123 | if cls.axes.labels not in field.lattice: 124 | return False 125 | axes = list(field.axes) 126 | for axis in field.lattice.expand(cls.axes.must): 127 | if axis not in axes: 128 | return False 129 | axes.remove(axis) 130 | axes = set(axes) 131 | for axis in cls.axes.may: 132 | if not axes.intersection(field.lattice.expand(axis)): 133 | return False 134 | return True 135 | 136 | @property 137 | def axes(cls): 138 | return cls.__axes__ 139 | 140 | 141 | class Axes(tuple): 142 | """ 143 | Functionalities to parse the axes. 144 | """ 145 | 146 | _get_label = re.compile(r"[a-zA-Z]([a-zA-Z0-9]|_[0-9]*[a-zA-Z])*") 147 | 148 | @classmethod 149 | def get_label(cls, key): 150 | return cls._get_label.match(key)[0] 151 | 152 | _get_count = re.compile( 153 | # r"([\+\!\?\*]|({([0-9]+,?)+(,...)?}))?$" 154 | r"[\!]*[\+]?$" 155 | ) 156 | 157 | @classmethod 158 | def get_count(cls, key): 159 | return cls._get_count.search(key)[0] 160 | 161 | _check_key = re.compile(_get_label.pattern + _get_count.pattern) 162 | 163 | @classmethod 164 | def check_keys(cls, keys): 165 | for key in keys: 166 | if not cls._check_key.match(key): 167 | raise KeyError("Invalid key: %s." 
% key) 168 | 169 | def __new__(cls, axes=()): 170 | if isinstance(axes, cls): 171 | return axes 172 | 173 | if isinstance(axes, str): 174 | axes = (axes,) 175 | 176 | cls.check_keys(axes) 177 | 178 | tmp = dict() 179 | for axis in axes: 180 | clean = cls.get_label(axis) 181 | sym = "!" * axis.count("!") + tmp.get(clean, "") 182 | if axis == clean or axis[-1] == "+": 183 | if sym.endswith("+"): 184 | sym = sym[:-1] + "!" 185 | sym += "+" 186 | tmp[clean] = sym 187 | 188 | axes = tuple((key + val for key, val in tmp.items())) 189 | 190 | return super().__new__(Axes, axes) 191 | 192 | def __add__(self, axes): 193 | return Axes(super().__add__(Axes(axes))) 194 | 195 | def __contains__(self, axes): 196 | axes = Axes(axes) 197 | axes = dict(zip(axes.labels, axes.counts)) 198 | this = dict(zip(self.labels, self.counts)) 199 | return all((axis in this for axis in axes)) and all( 200 | ( 201 | len(count) <= len(this[axis]) 202 | if count[-1] == "+" 203 | else count == this[axis] 204 | for axis, count in axes.items() 205 | ) 206 | ) 207 | 208 | @property 209 | def expand(self): 210 | axes = [] 211 | for axis, count in zip(self.labels, self.counts): 212 | axes += [axis] * len(count) 213 | return tuple(axes) 214 | 215 | @property 216 | def must(self): 217 | axes = [] 218 | for axis, count in zip(self.labels, self.counts): 219 | axes += [axis] * count.count("!") 220 | return tuple(axes) 221 | 222 | @property 223 | def may(self): 224 | axes = [] 225 | for axis, count in zip(self.labels, self.counts): 226 | axes += [axis] * count.count("+") 227 | return tuple(axes) 228 | 229 | @property 230 | def labels(self): 231 | return tuple((self.get_label(axis) for axis in self)) 232 | 233 | @property 234 | def counts(self): 235 | return tuple((self.get_count(axis) for axis in self)) 236 | -------------------------------------------------------------------------------- /lyncs/field/types/generic.py: -------------------------------------------------------------------------------- 1 | from 
.base import FieldType 2 | 3 | __all__ = [ 4 | "Scalar", 5 | "Degrees", 6 | "Sites", 7 | "Links", 8 | "Vector", 9 | "Propagator", 10 | ] 11 | 12 | 13 | class Scalar(metaclass=FieldType): 14 | "Scalar field, no axes specified" 15 | __axes__ = [] 16 | 17 | # scalar methods 18 | 19 | 20 | class Degrees(Scalar): 21 | "Field with degrees of freedom (dofs)" 22 | __axes__ = ["dofs"] 23 | 24 | # dofs operations, i.e. trace, dot 25 | 26 | 27 | class Sites(Scalar): 28 | "Field on all the sites volume" 29 | __axes__ = ["dims"] 30 | 31 | # volume methods, i.e. reductions 32 | 33 | 34 | class Links(Sites): 35 | "Field on all the links between sites" 36 | __axes__ += ["dirs"] 37 | # Oriented links 38 | 39 | 40 | class Vector(Sites): 41 | "Vector Field on all the sites volume" 42 | __axes__ += ["dofs!"] 43 | 44 | 45 | class Propagator(Sites): 46 | "Propagator Field on all the sites volume" 47 | __axes__ += ["dofs!!"] 48 | -------------------------------------------------------------------------------- /lyncs/field/types/quantum.py: -------------------------------------------------------------------------------- 1 | from .generic import Scalar, Links 2 | 3 | __all__ = [ 4 | "Gauge", 5 | "GaugeLinks", 6 | "Spinor", 7 | "SpinMatrix", 8 | ] 9 | 10 | 11 | class Gauge(Scalar): 12 | __axes__ = ["gauge"] 13 | 14 | 15 | class GaugeMatrix(Gauge): 16 | __axes__ = ["gauge!", "gauge!"] 17 | 18 | 19 | class GaugeLinks(Links, GaugeMatrix): 20 | def plaquette(self, dirs=None): 21 | dirs = dirs or self.get_range("dirs") 22 | if not set(dirs).issubset(self.get_range("dirs")): 23 | raise ValueError("Dirs not part of the field") 24 | if not len(dirs) > 1: 25 | raise ValueError("At least two dirs needed for computing plaquette") 26 | 27 | plaq = tuple( 28 | self[dir1].dot( 29 | self[dir2].roll(-1, dir1), 30 | self[dir1].roll(-1, dir2).H, 31 | self[dir2].H, 32 | trace=True, 33 | axes="all", 34 | mean=True, 35 | ) 36 | for i, dir1 in enumerate(dirs[:-1]) 37 | for dir2 in dirs[i + 1 :] 38 | ).real 39 | 
return plaq[0].add(plaq[1:]) / len(plaq) 40 | 41 | 42 | class Spinor(Scalar): 43 | __axes__ = ["spin"] 44 | 45 | 46 | class SpinMatrix(Spinor): 47 | __axes__ = ["spin!", "spin!"] 48 | -------------------------------------------------------------------------------- /lyncs/field/ufuncs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Universal functions for the fields 3 | """ 4 | 5 | __all__ = [ 6 | "prepare", 7 | ] 8 | 9 | import numpy as np 10 | from .array import ArrayField, NumpyBackend, backend_method 11 | 12 | 13 | def prepare(self, *fields, elemwise=True, **kwargs): 14 | """ 15 | Prepares a set of fields for a calculation. 16 | 17 | Returns: 18 | -------- 19 | fields, out_field 20 | where fields is a tuple of the fields to use in the calculation 21 | and out_field is the Field type where to store the result 22 | 23 | Parameters 24 | ---------- 25 | fields: Field(s) 26 | List of fields involved in the calculation. 27 | elemwise: bool 28 | Whether the calculation is performed element-wise, 29 | i.e. all the fields must have the same axes and in the same order. 30 | kwargs: dict 31 | List of field parameters fixed in the calculation (e.g. 
specific indexes_order) 32 | """ 33 | if not isinstance(self, ArrayField): 34 | raise ValueError("First field is not of ArrayField type") 35 | for idx, field in enumerate(fields): 36 | if not isinstance(field, type(self)): 37 | raise ValueError( 38 | "Field #%d of type %s is not compatible with %s" 39 | % (idx + 1, type(field), type(self)) 40 | ) 41 | 42 | # TODO: add more checks for compatibility 43 | 44 | if not fields and not kwargs: 45 | return self, () 46 | if not fields: 47 | # TODO: should check kwargs and do a copy only if needed 48 | return self.copy(**kwargs), () 49 | 50 | if elemwise: 51 | # TODO: should reorder the field giving the same order 52 | pass 53 | 54 | # TODO: should check for coords and restrict all the fields to the intersection 55 | 56 | return self, fields 57 | 58 | 59 | ArrayField.prepare = prepare 60 | 61 | 62 | def ufunc_method(key, elemwise=True, fnc=None, doc=None): 63 | """ 64 | Implementation of a field ufunc 65 | 66 | Parameters 67 | ---------- 68 | key: str 69 | The key of the method 70 | elemwise: bool 71 | Whether the calculation is performed element-wise, 72 | i.e. all the fields must have the same axes and in the same order. 73 | fnc: callable 74 | Fallback for the method in case self it is not a field 75 | """ 76 | 77 | def method(self, *args, **kwargs): 78 | if not isinstance(self, ArrayField): 79 | if fnc is None: 80 | raise TypeError( 81 | "First argument of %s must be of type Field. 
Given %s" 82 | % (key, type(self).__name__) 83 | ) 84 | 85 | return fnc(self, *args, **kwargs) 86 | 87 | # Deducing the dtype of the output 88 | tmp_args = ( 89 | np.ones((1), dtype=arg.dtype) if isinstance(arg, ArrayField) else arg 90 | for arg in args 91 | ) 92 | if fnc is not None: 93 | trial = fnc(np.ones((1), dtype=self.dtype), *tmp_args, **kwargs) 94 | else: 95 | trial = getattr(np.ones((1), dtype=self.dtype), key)(*tmp_args, **kwargs) 96 | 97 | # Uniforming the fields involved 98 | args = list(args) 99 | i_fields = tuple( 100 | (i, arg) for i, arg in enumerate(args) if isinstance(arg, ArrayField) 101 | ) 102 | self, fields = self.prepare( 103 | *(field for (_, field) in i_fields), elemwise=elemwise 104 | ) 105 | 106 | for (i, _), field in zip(i_fields, fields): 107 | args[i] = field.value 108 | 109 | result = getattr(self.backend, key)(*args, **kwargs) 110 | if isinstance(trial, tuple): 111 | return tuple( 112 | ( 113 | self.copy(result[i], dtype=trial.dtype) 114 | for i, trial in enumerate(trial) 115 | ) 116 | ) 117 | return self.copy(result, dtype=trial.dtype) 118 | 119 | method.__name__ = key 120 | 121 | if doc: 122 | method.__doc__ = doc 123 | elif fnc: 124 | method.__doc__ = fnc.__doc__ 125 | 126 | return method 127 | 128 | 129 | def comparison(key, eq=True): 130 | "Additional wrapper for comparisons" 131 | fnc = ufunc_method(key) 132 | 133 | def method(self, other): 134 | if self.__is__(other): 135 | return eq 136 | return fnc(self, other) 137 | 138 | return method 139 | 140 | 141 | def backend_ufunc_method(key, fnc=None, doc=None): 142 | """ 143 | Returns a method for the backend that calls 144 | the given method (key) of the field value. 
145 | """ 146 | 147 | def method(self, *args, **kwargs): 148 | if fnc is None: 149 | return getattr(self, key)(*args, **kwargs) 150 | return fnc(self, *args, **kwargs) 151 | 152 | method.__name__ = key 153 | if doc is not None: 154 | method.__doc__ = doc 155 | elif fnc is not None: 156 | method.__doc__ = fnc.__doc__ 157 | 158 | return method 159 | 160 | 161 | OPERATORS = ( 162 | ("__abs__",), 163 | ("__add__",), 164 | ("__radd__",), 165 | ("__mod__",), 166 | ("__rmod__",), 167 | ("__mul__",), 168 | ("__rmul__",), 169 | ("__neg__",), 170 | ("__pow__",), 171 | ("__rpow__",), 172 | ("__sub__",), 173 | ("__rsub__",), 174 | ("__truediv__",), 175 | ("__rtruediv__",), 176 | ("__floordiv__",), 177 | ("__rfloordiv__",), 178 | ) 179 | 180 | for (op,) in OPERATORS: 181 | setattr(ArrayField, op, ufunc_method(op)) 182 | backend_method(backend_ufunc_method(op), NumpyBackend) 183 | 184 | COMPARISONS = ( 185 | ("__eq__", True), 186 | ("__gt__", False), 187 | ("__ge__", True), 188 | ("__lt__", False), 189 | ("__le__", True), 190 | ("__ne__", False), 191 | ) 192 | 193 | for (op, eq) in COMPARISONS: 194 | setattr(ArrayField, op, comparison(op, eq)) 195 | backend_method(backend_ufunc_method(op), NumpyBackend) 196 | 197 | 198 | UFUNCS = ( 199 | # math operations 200 | ("add", True), 201 | ("subtract", True), 202 | ("multiply", True), 203 | ("divide", True), 204 | ("logaddexp", False), 205 | ("logaddexp2", False), 206 | ("true_divide", True), 207 | ("floor_divide", True), 208 | ("negative", True), 209 | ("power", True), 210 | ("float_power", True), 211 | ("remainder", True), 212 | ("mod", True), 213 | ("fmod", True), 214 | ("conj", False), 215 | ("exp", False), 216 | ("exp2", False), 217 | ("log", False), 218 | ("log2", False), 219 | ("log10", False), 220 | ("log1p", False), 221 | ("expm1", False), 222 | ("sqrt", True), 223 | ("square", True), 224 | ("cbrt", False), 225 | ("reciprocal", True), 226 | # trigonometric functions 227 | ("sin", False), 228 | ("cos", False), 229 | ("tan", 
False), 230 | ("arcsin", False), 231 | ("arccos", False), 232 | ("arctan", False), 233 | ("arctan2", False), 234 | ("hypot", False), 235 | ("sinh", False), 236 | ("cosh", False), 237 | ("tanh", False), 238 | ("arcsinh", False), 239 | ("arccosh", False), 240 | ("arctanh", False), 241 | ("deg2rad", False), 242 | ("rad2deg", False), 243 | # comparison functions 244 | ("greater", True), 245 | ("greater_equal", True), 246 | ("less", True), 247 | ("less_equal", True), 248 | ("not_equal", True), 249 | ("equal", True), 250 | ("isneginf", False), 251 | ("isposinf", False), 252 | ("logical_and", False), 253 | ("logical_or", False), 254 | ("logical_xor", False), 255 | ("logical_not", False), 256 | ("maximum", False), 257 | ("minimum", False), 258 | ("fmax", False), 259 | ("fmin", False), 260 | # floating functions 261 | ("isfinite", True), 262 | ("isinf", True), 263 | ("isnan", True), 264 | ("signbit", False), 265 | ("copysign", False), 266 | ("nextafter", False), 267 | ("spacing", False), 268 | ("modf", False), 269 | ("ldexp", False), 270 | ("frexp", False), 271 | ("fmod", False), 272 | ("floor", True), 273 | ("ceil", True), 274 | ("trunc", False), 275 | ("round", True), 276 | # more math routines 277 | ("degrees", False), 278 | ("radians", False), 279 | ("rint", True), 280 | ("fabs", True), 281 | ("sign", True), 282 | ("absolute", True), 283 | # non-ufunc elementwise functions 284 | ("clip", True), 285 | ("isreal", False), 286 | ("iscomplex", False), 287 | ("real", False), 288 | ("imag", False), 289 | ("fix", False), 290 | ("i0", False), 291 | ("sinc", False), 292 | ("nan_to_num", True), 293 | ("isclose", True), 294 | ("allclose", True), 295 | ) 296 | 297 | for (ufunc, is_member) in UFUNCS: 298 | __all__.append(ufunc) 299 | globals()[ufunc] = ufunc_method(ufunc, fnc=getattr(np, ufunc)) 300 | if is_member: 301 | setattr(ArrayField, ufunc, globals()[ufunc]) 302 | if hasattr(np.ndarray, ufunc): 303 | fnc = backend_ufunc_method(ufunc, doc=getattr(np, ufunc).__doc__) 304 | else: 
305 | fnc = backend_ufunc_method(ufunc, fnc=getattr(np, ufunc)) 306 | backend_method(fnc, NumpyBackend) 307 | 308 | setattr(ArrayField, "real", property(globals()["real"])) 309 | setattr(ArrayField, "imag", property(globals()["imag"])) 310 | -------------------------------------------------------------------------------- /lyncs/io/__init__.py: -------------------------------------------------------------------------------- 1 | "I/O functionalities for lyncs" 2 | 3 | from .base import * 4 | -------------------------------------------------------------------------------- /lyncs/io/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base tools for saving and loading data 3 | """ 4 | 5 | __all__ = [ 6 | "load", 7 | "loads", 8 | "dump", 9 | "dumps", 10 | ] 11 | 12 | from contextlib import contextmanager 13 | from importlib import import_module 14 | 15 | 16 | DOC = """ 17 | dformat: str 18 | The format of the file to read. Allowed formats (case non-sensitive): 19 | - None: deduced from file 20 | - "pkl": pickle file format 21 | - "txt", "ASCII": txt file format 22 | - "HDF5", "H5": HDF5 file format 23 | - "lime": lime file format 24 | kwargs: dict 25 | Additional list of information for performing the reading/writing. 26 | """ 27 | 28 | 29 | def load( 30 | filein, 31 | dformat=None, 32 | **kwargs, 33 | ): 34 | """ 35 | Loads data from a file and returns it as a lyncs object. 36 | 37 | Parameters 38 | ---------- 39 | filein: str, file-object 40 | The filename of the data file to read. It can also be a file-like object.""" 41 | 42 | return Format(dformat, filein, read=True).load(filein, **kwargs) 43 | 44 | 45 | def loads( 46 | data, 47 | dformat=None, 48 | **kwargs, 49 | ): 50 | """ 51 | Loads data from a raw string of bytes and returns it as a lyncs object. 
52 | 53 | Parameters 54 | ----------""" 55 | 56 | return Format(dformat, data=data).loads(data, **kwargs) 57 | 58 | 59 | def dump(obj, fileout, dformat=None, **kwargs): 60 | """ 61 | Saves a lyncs object into a file. 62 | 63 | Parameters 64 | ---------- 65 | fileout: str 66 | The filename of the data file to write. It can also be a file-like object.""" 67 | 68 | return Format(dformat, fileout, obj=obj).dump(obj, fileout, **kwargs) 69 | 70 | 71 | def dumps(obj, dformat=None, **kwargs): 72 | """ 73 | Dumps a lyncs object as a string of bytes. 74 | 75 | Parameters 76 | ----------""" 77 | 78 | return Format(dformat, obj=obj).dumps(obj, **kwargs) 79 | 80 | 81 | load.__doc__ += DOC 82 | loads.__doc__ += DOC 83 | dump.__doc__ += DOC 84 | dumps.__doc__ += DOC 85 | 86 | 87 | @contextmanager 88 | def fopen(filename, *args, **kwargs): 89 | "Opens filename if is a string otherwise consider it as a file-like object" 90 | 91 | if isinstance(filename, str): # filename 92 | with open(filename, *args, **kwargs) as fin: 93 | yield fin 94 | else: # file-like object 95 | yield filename 96 | 97 | 98 | class Format(str): 99 | """ 100 | Holder for file formats. 101 | Deduces the format from the input and returns functions from the respective module. 102 | See Format.s for the file formats available. 103 | The key is the module name and the value is a tuple of the aliases (lower case). 104 | """ 105 | 106 | # module: (aliases,) !!NOTE!! 
use only lower cases for aliases 107 | s = { 108 | "pickle": ( 109 | "pkl", 110 | "pickle", 111 | ), 112 | "lyncs.io.json": ( 113 | "json", 114 | "txt", 115 | "ascii", 116 | ), 117 | "lyncs.io.hdf5": ("hdf5", "h5"), 118 | "lyncs.io.lime": ("lime",), 119 | } 120 | 121 | def __new__(cls, value, filename=None, read=False, data=None, obj=None): 122 | "Multiple ways to deduce the file format" 123 | 124 | if isinstance(value, str): 125 | if value in Format.s: 126 | return super().__new__(cls, value) 127 | value = value.lower() 128 | for key, aliases in Format.s.items(): 129 | if value in aliases: 130 | return cls(key) 131 | raise ValueError("Could not deduce the format from %s" % value) 132 | 133 | if isinstance(filename, str): 134 | try: # Deducing from the extension 135 | return cls(filename.split(".")[-1]) 136 | except ValueError: 137 | pass 138 | 139 | if read: 140 | with fopen(filename, "r") as fin: 141 | return cls(value, data=fin.read(1024)) 142 | 143 | if data is not None: 144 | for key in Format.s: 145 | try: 146 | if cls(key).is_compatible(data): 147 | return cls(key) 148 | except (ValueError, ImportError, AttributeError): 149 | pass 150 | 151 | if obj is not None and hasattr("__lyncs_file_format__"): 152 | return cls(obj.__lyncs_file_format__()) 153 | 154 | raise ValueError( 155 | "Not enough information has been given to deduce the file format." 
156 | ) 157 | 158 | def __init__(self, *args, **kwargs): 159 | assert self in Format.s 160 | 161 | try: 162 | import_module(self) 163 | except ImportError as err: 164 | raise err 165 | 166 | super().__init__() 167 | 168 | def __getattr__(self, key): 169 | return getattr(import_module(self), key) 170 | 171 | def __eq__(self, other): 172 | return super().__eq__(other) or other in Format.s[self] 173 | -------------------------------------------------------------------------------- /lyncs/io/lime/__init__.py: -------------------------------------------------------------------------------- 1 | from ...tunable import Tunable, computable, delayed 2 | 3 | # List of available implementations 4 | engines = [ 5 | "pylime", 6 | ] 7 | 8 | from . import lime as pylime 9 | 10 | default_engine = "pylime" 11 | 12 | try: 13 | from . import DDalphaAMG 14 | 15 | engines.append("DDalphaAMG") 16 | except: 17 | pass 18 | # from lyncs import config 19 | # if config.clime_enabled: 20 | # import .clime 21 | # engines.append("clime") 22 | 23 | # TODO: add more, e.g. 
lemon 24 | 25 | 26 | @computable 27 | def engine(engine=None): 28 | import sys 29 | 30 | engine = engine or default_engine 31 | self = sys.modules[__name__] 32 | return getattr(self, engine) 33 | 34 | 35 | def is_compatible(filename): 36 | try: 37 | return pylime.is_lime_file(filename) 38 | except: 39 | return False 40 | 41 | 42 | def get_lattice(records): 43 | assert "ildg-format" in records, "ildg-format not found" 44 | 45 | from lyncs import Lattice 46 | import xmltodict 47 | 48 | info = xmltodict.parse(records["ildg-format"])["ildgFormat"] 49 | 50 | return Lattice( 51 | dims={ 52 | "t": int(info["lt"]), 53 | "x": int(info["lx"]), 54 | "y": int(info["ly"]), 55 | "z": int(info["lz"]), 56 | }, 57 | dofs="QCD", 58 | ) 59 | 60 | 61 | def get_field_type(records): 62 | assert "ildg-format" in records, "ildg-format not found" 63 | 64 | import xmltodict 65 | 66 | info = xmltodict.parse(records["ildg-format"])["ildgFormat"] 67 | 68 | if info["field"] in [ 69 | "su3gauge", 70 | ]: 71 | return "gauge_links" 72 | else: 73 | # TODO 74 | assert False, "To be implemented" 75 | 76 | 77 | def get_type(filename, lattice=None, field_type=None, **kwargs): 78 | records = pylime.scan_file(filename) 79 | records = { 80 | r["lime_type"]: (r["data"] if "data" in r else r["data_length"]) 81 | for r in records 82 | } 83 | 84 | assert "ildg-binary-data" in records, "ildg-binary-data not found" 85 | 86 | read_lattice = None 87 | try: 88 | read_lattice = get_lattice(records) 89 | except AssertionError: 90 | if not lattice: 91 | raise 92 | 93 | if lattice and read_lattice: 94 | assert ( 95 | lattice == read_lattice 96 | ), "Given lattice not compatible with the one read from file" 97 | else: 98 | lattice = read_lattice 99 | 100 | read_field_type = None 101 | try: 102 | read_field_type = get_field_type(records) 103 | except AssertionError: 104 | if not field_type: 105 | raise 106 | 107 | if field_type and read_field_type: 108 | assert ( 109 | field_type == read_field_type 110 | ), "Given 
field_type not compatible with the one read from file" 111 | else: 112 | field_type = read_field_type 113 | 114 | from lyncs import Field 115 | 116 | import xmltodict 117 | 118 | info = xmltodict.parse(records["ildg-format"])["ildgFormat"] 119 | 120 | field = Field( 121 | lattice=lattice, 122 | field_type=field_type, 123 | dtype="complex%d" % (int(info["precision"]) * 2), 124 | ) 125 | 126 | assert ( 127 | field.byte_size == records["ildg-binary-data"] 128 | ), """ 129 | Size of deduced field (%s) is not compatible with size of data (%s). 130 | """ % ( 131 | field.byte_size, 132 | records["ildg-binary-data"], 133 | ) 134 | 135 | return field 136 | 137 | 138 | def fixed_options(field, key): 139 | dims_order = ["t", "z", "y", "x"] 140 | 141 | if key == "axes_order": 142 | if field.field_type == "gauge_links": 143 | return dims_order + ["n_dims", "color", "color"] 144 | elif key == "color_order": 145 | return [0, 1] 146 | elif key == "dirs_order": 147 | return dims_order 148 | else: 149 | # TODO 150 | assert False, "To be implemented" 151 | 152 | 153 | class file_manager(Tunable): 154 | def __init__(self, field, **kwargs): 155 | self.field = field 156 | from ...tunable import Choice 157 | 158 | self.add_option("lime_engine", Choice(engines)) 159 | 160 | for key, val in kwargs.items(): 161 | assert hasattr(self, key), "Attribute %s not found" % key 162 | setattr(self, key, val) 163 | 164 | @property 165 | def filename(self): 166 | return getattr(self, "_filename", None) 167 | 168 | @filename.setter 169 | def filename(self, value): 170 | assert isinstance(value, str), "Filename must be a string" 171 | from os.path import abspath 172 | 173 | self._filename = abspath(value) 174 | 175 | @property 176 | def engine(self): 177 | return engine(self.lime_engine) 178 | 179 | def read(self, **kwargs): 180 | from ...field import Field 181 | 182 | filename = kwargs.get("filename", self.filename) 183 | field = Field(self.field, zeros_init=True, **self.fixed_options) 184 | 185 | 
field.field = self.engine.read( 186 | filename, 187 | shape=field.field_shape, 188 | chunks=field.field_chunks, 189 | ) 190 | return field 191 | 192 | @property 193 | def fixed_options(self): 194 | opts = {} 195 | opts["axes_order"] = fixed_options(self.field, "axes_order") 196 | 197 | if self.field.field_type == "gauge_links": 198 | opts["color_order"] = fixed_options(self.field, "color_order") 199 | opts["dirs_order"] = fixed_options(self.field, "dirs_order") 200 | else: 201 | # TODO 202 | assert False, "To be implemented" 203 | 204 | return opts 205 | 206 | def __dask_tokenize__(self): 207 | from dask.base import normalize_token 208 | 209 | return normalize_token((type(self), self.filename)) 210 | -------------------------------------------------------------------------------- /lyncs/io/lime/lime.py: -------------------------------------------------------------------------------- 1 | from ...tunable import computable 2 | 3 | # Constants 4 | lime_header_size = 144 5 | lime_magic_number = 1164413355 6 | lime_file_version_number = 1 7 | lime_type_length = 128 8 | 9 | 10 | def read_type(file, type): 11 | import struct 12 | 13 | data = file.read(struct.calcsize(type)) 14 | return struct.unpack_from(type, data)[0] 15 | 16 | 17 | def is_lime_file(filename): 18 | with open(filename, "rb") as f: 19 | magic_number = read_type(f, ">l") 20 | return magic_number == lime_magic_number 21 | 22 | 23 | def scan_file(filename): 24 | "Scans the content of a lime file and returns the list of records" 25 | 26 | def read_record_data(record): 27 | "Conditions when the data of a record should be read in this function" 28 | if record["lime_type"] in [ 29 | "ildg-binary-data", 30 | ]: 31 | return False 32 | elif record["lime_type"] in [ 33 | "ildg-format", 34 | ]: 35 | return True 36 | elif records[-1]["data_length"] < 1000: 37 | return True 38 | return False 39 | 40 | import os 41 | 42 | fsize = os.path.getsize(filename) 43 | records = [] 44 | with open(filename, "rb") as f: 45 | pos = 0 
46 | while pos + lime_header_size < fsize: 47 | f.seek(pos) 48 | records.append( 49 | dict( 50 | pos=pos + lime_header_size, 51 | magic_number=read_type(f, ">l"), 52 | file_version_number=read_type(f, ">h"), 53 | msg_bits=read_type(f, ">h"), 54 | data_length=read_type(f, ">q"), 55 | lime_type=read_type(f, "%ds" % (lime_type_length,)) 56 | .decode() 57 | .split("\0")[0], 58 | ) 59 | ) 60 | assert records[-1]["magic_number"] == lime_magic_number 61 | records[-1]["MBbit"] = (records[-1]["msg_bits"] & (1 << 15)) >> 15 62 | records[-1]["MEbit"] = (records[-1]["msg_bits"] & (1 << 14)) >> 14 63 | if read_record_data(records[-1]): 64 | records[-1]["data"] = read_type( 65 | f, "%ds" % (records[-1]["data_length"],) 66 | ).decode() 67 | pos = ( 68 | records[-1]["pos"] 69 | + records[-1]["data_length"] 70 | + ((8 - records[-1]["data_length"] % 8) % 8) 71 | ) 72 | 73 | return records 74 | 75 | 76 | def read_chunk(filename, shape, dtype, data_offset, chunks=None, chunk_id=None): 77 | import numpy as np 78 | from itertools import product 79 | 80 | shape = np.array(shape) 81 | chunks = np.array(chunks or shape) 82 | chunk_id = np.array(chunk_id or np.zeros_like(shape)) 83 | 84 | n_chunks = shape // chunks 85 | 86 | start = ( 87 | [ 88 | 0, 89 | ] 90 | + list(np.where(n_chunks > 1)[0]) 91 | )[-1] 92 | consecutive = np.prod(chunks[start:]) 93 | n_reads = np.prod(chunks) // consecutive 94 | 95 | if n_reads == 1: 96 | offset = 0 97 | for i, l, L in zip(chunk_id, chunks, shape): 98 | offset = offset * L + i * l 99 | offset *= dtype.itemsize 100 | offset += data_offset 101 | return np.fromfile( 102 | filename, dtype=dtype, count=consecutive, offset=offset 103 | ).reshape(chunks) 104 | 105 | arr = np.ndarray(tuple(chunks[:start]) + (consecutive,), dtype=dtype) 106 | read_ids = list(product(*[range(l) for l in chunks[:start]])) 107 | assert len(read_ids) == n_reads 108 | 109 | for read_id in read_ids: 110 | offset = 0 111 | for i, j, l, L in zip( 112 | chunk_id, 113 | read_id + tuple(0 
for i in range(len(shape) - start)), 114 | chunks, 115 | shape, 116 | ): 117 | offset = offset * L + i * l + j 118 | offset *= dtype.itemsize 119 | offset += data_offset 120 | 121 | arr[read_id] = np.fromfile( 122 | filename, dtype=dtype, count=consecutive, offset=offset 123 | ) 124 | 125 | return arr.reshape(chunks) 126 | 127 | 128 | @computable 129 | def read(filename, shape, chunks): 130 | from dask.highlevelgraph import HighLevelGraph 131 | from dask.array.core import normalize_chunks, Array 132 | from itertools import product 133 | from ...tunable import delayed 134 | from numpy import prod, dtype 135 | import xmltodict 136 | 137 | records = scan_file(filename) 138 | records = {r["lime_type"]: r for r in records} 139 | 140 | data_record = records["ildg-binary-data"] 141 | data_offset = data_record["pos"] 142 | 143 | info = xmltodict.parse(records["ildg-format"]["data"])["ildgFormat"] 144 | dtype = dtype("complex%d" % (int(info["precision"]) * 2)) 145 | 146 | assert data_record["data_length"] == prod(shape) * dtype.itemsize 147 | 148 | normal_chunks = normalize_chunks(chunks, shape=shape) 149 | chunks_id = list(product(*[range(len(bd)) for bd in normal_chunks])) 150 | 151 | reads = [ 152 | delayed(read_chunk)(filename, shape, dtype, data_offset, chunks, chunk_id) 153 | for chunk_id in chunks_id 154 | ] 155 | 156 | keys = [(filename, *chunk_id) for chunk_id in chunks_id] 157 | vals = [read.key for read in reads] 158 | dsk = dict(zip(keys, vals)) 159 | 160 | graph = HighLevelGraph.from_collections(filename, dsk, dependencies=reads) 161 | 162 | return Array(graph, filename, normal_chunks, dtype=dtype) 163 | -------------------------------------------------------------------------------- /lyncs/lattice.py: -------------------------------------------------------------------------------- 1 | """ 2 | Definition of the Lattice class and related routines 3 | """ 4 | 5 | __all__ = [ 6 | "default_lattice", 7 | "Lattice", 8 | ] 9 | 10 | import re 11 | import random 12 | 
from types import MappingProxyType 13 | from functools import partial, wraps 14 | from typing import Callable 15 | from inspect import signature 16 | from lyncs_utils import default_repr_pretty, isiterable, FreezableDict, compact_indexes 17 | from .field.base import BaseField 18 | from .field.types.base import Axes, FieldType 19 | 20 | 21 | def default_lattice(): 22 | "Returns the last defined lattice if any" 23 | assert Lattice.last_defined is not None, "Any lattice has been defined yet." 24 | return Lattice.last_defined 25 | 26 | 27 | class LatticeDict(FreezableDict): 28 | "Dictionary for lattice attributes. Checks the given keys." 29 | regex = re.compile(Axes._get_label.pattern + "$") 30 | 31 | def __new__(cls, *args, **kwargs): 32 | self = super().__new__(cls, *args, **kwargs) 33 | self.lattice = None 34 | return self 35 | 36 | def __init__(self, val=None, lattice=None, check=True): 37 | if lattice is not None and not isinstance(lattice, Lattice): 38 | raise ValueError("Lattice must be of Lattice type") 39 | self.lattice = lattice 40 | 41 | if check: 42 | super().__init__() 43 | if val is not None: 44 | for _k, _v in dict(val).items(): 45 | self[_k] = _v 46 | else: 47 | super().__init__(val) 48 | 49 | def __setitem__(self, key, val): 50 | if not type(self).regex.match(key): 51 | raise KeyError( 52 | """ 53 | Invalid key: %s. Keys can only contain letters, numbers or '_'. 54 | Keys must start with a letter and cannot end with '_' followed by number. 
55 | """ 56 | % key 57 | ) 58 | if key not in self and self.lattice is not None: 59 | if key in self.lattice.__dir__(): 60 | raise KeyError("%s is already in use" % (key)) 61 | super().__setitem__(key, val) 62 | 63 | @wraps(dict.copy) 64 | def copy(self): 65 | return type(self)(self, lattice=self.lattice, check=False) 66 | 67 | def rename(self, key, new_key): 68 | "Renames a key of the dictionary" 69 | self[new_key] = self.pop(key) 70 | 71 | def reset(self, val=None): 72 | "Resets the content of the dictionary" 73 | # TODO: in case of reset/delitem the it should check if all the other 74 | # entries of the lattice are still valid. And in case remove them. 75 | tmp = self.copy() 76 | self.clear() 77 | try: 78 | self.update(val) 79 | except (ValueError, TypeError, KeyError): 80 | self.clear() 81 | self.update(tmp) 82 | raise 83 | 84 | 85 | class LatticeAxes(LatticeDict): 86 | "Dictionary for lattice axes. Values must be positive integers." 87 | 88 | def __setitem__(self, key, val): 89 | if not isinstance(val, int) or val <= 0: 90 | raise ValueError( 91 | "%s = %s not allowed. The value must be a positive int." % (key, val) 92 | ) 93 | super().__setitem__(key, val) 94 | 95 | 96 | class LatticeLabels(LatticeDict): 97 | "Dictionary for lattice labels. Values must be unique strings." 
98 | 99 | def labels(self): 100 | "Returns all the field labels" 101 | for value in self.values(): 102 | yield from value 103 | 104 | def __setitem__(self, key, val): 105 | if isinstance(val, str): 106 | val = (val,) 107 | if not isiterable(val, str): 108 | raise TypeError("Labels value can only be a list of strings") 109 | 110 | val = tuple(val) 111 | if not len(set(val)) == len(val): 112 | raise ValueError("%s contains repeated labels" % (val,)) 113 | 114 | labels = set(self.labels()) 115 | if key in self: 116 | labels = labels.difference(self[key]) 117 | inter = labels.intersection(val) 118 | if inter: 119 | raise ValueError("%s are labels already in use" % inter) 120 | 121 | super().__setitem__(key, val) 122 | 123 | 124 | class LatticeGroups(LatticeDict): 125 | "Dictionary for lattice groups. Values must be a set of lattice keys." 126 | regex = re.compile("[a-zA-Z_][a-zA-Z0-9_]*$") 127 | 128 | def __setitem__(self, key, val): 129 | if key in self and isinstance(val, int): 130 | for _k in self[key]: 131 | self.lattice[_k] = val 132 | return 133 | if isinstance(val, str): 134 | val = (val,) 135 | if not isiterable(val, str): 136 | raise TypeError("Groups value can only be a list of strings") 137 | 138 | if self.lattice is not None: 139 | val = tuple(val) 140 | keys = set(self.lattice.keys()) 141 | if not keys >= set(val): 142 | raise ValueError("%s are not lattice keys" % set(val).difference(keys)) 143 | 144 | super().__setitem__(key, val) 145 | 146 | def replace(self, key, new_key): 147 | "Replaces a key with a the new key" 148 | 149 | for _key, val in self.items(): 150 | if key in val: 151 | val = list(val) 152 | val[val.index(key)] = new_key 153 | self[_key] = val 154 | 155 | 156 | class Lattice: 157 | """ 158 | Lattice base class. 159 | A container for all the lattice information. 
160 | """ 161 | 162 | last_defined = None 163 | default_dims_labels = ["t", "x", "y", "z"] 164 | theories = { 165 | "QCD": { 166 | "spin": 4, 167 | "color": 3, 168 | "groups": { 169 | "gauge": ["color"], 170 | }, 171 | } 172 | } 173 | 174 | __slots__ = [ 175 | "_dims", 176 | "_dofs", 177 | "_labels", 178 | "_groups", 179 | "_coords", 180 | "_maps", 181 | "_fields", 182 | "_frozen", 183 | ] 184 | _repr_pretty_ = default_repr_pretty 185 | 186 | def __new__(cls, *args, **kwargs): 187 | # pylint: disable=W0613 188 | self = super().__new__(cls) 189 | self._fields = None 190 | self._frozen = False 191 | self.dims = None 192 | self.dofs = None 193 | self.labels = None 194 | self.groups = None 195 | self.coords = None 196 | self.maps = None 197 | return self 198 | 199 | def __init__( 200 | self, dims=4, dofs="QCD", labels=None, groups=None, coords=None, maps=None 201 | ): 202 | """ 203 | Lattice initializer. 204 | 205 | Notation 206 | -------- 207 | Dimensions: (dims) are axes of the Lattice which size is variable. 208 | The volume of the lattice, i.e. number of sites, is given by the product 209 | of dims. Dims are usually the axes where one can parallelize on. 210 | Degrees of Freedoms: (dofs) are local axes with fixed size (commonly small). 211 | Labels: (labels) are labelled axes of the lattice. Similar to dofs but instead 212 | of having a size they have a list of unique labels (str, int, hash-able) 213 | Axes: Any of the above, i.e. list of axes of the field. 214 | 215 | Parameters 216 | ---------- 217 | dims: int, list or dict (default 4) 218 | Dimensions (default names: t,x,y,z if less than 5 or dim0/1/2...) 219 | - int: number of dimensions. The default names will be used. 220 | - list: size of the dimensions. The default names will be used. 221 | - dict: names of the dimensions (keys) and sizes (value) 222 | dofs: str, int, list, dict (default QCD) 223 | Specifies local degree of freedoms. (default naming: dof0/1/2...) 
224 | - str: one of the defined theories (QCD,...). See Lattice.theories 225 | - int: size of one degree of freedom 226 | - list: size per dimension of the degrees of freedom 227 | - dict: names of the degree of freedom (keys) and sizes (value) 228 | labels: dict 229 | Labelled dimensions of the lattice. A dictionary with keys the names of the 230 | dimensions and with values the list of labels. The labels must be unique. 231 | The size of the dimension is the number of labels. 232 | groups: dict 233 | Grouping of the dimensions. Each entry of the dictionary must contain a str 234 | or a list of strings that refer to either another label a dimension. 235 | coords: dict 236 | Coordinates of the lattice. Each entry of the dictionary must contain a set 237 | of coordinates 238 | """ 239 | self.dims = dims 240 | self.dofs = dofs 241 | self.labels.update(labels) 242 | self.groups.update(groups) 243 | self.coords.update(coords) 244 | self.maps.update(maps) 245 | 246 | Lattice.last_defined = self 247 | 248 | @property 249 | def frozen(self): 250 | """ 251 | Returns if the current lattice instance is frozen, i.e. cannot be changed anymore. 252 | To unfreeze it use lattice.copy. 253 | """ 254 | return self._frozen 255 | 256 | @frozen.setter 257 | def frozen(self, value): 258 | if value != self.frozen: 259 | if value is False: 260 | raise ValueError( 261 | "Frozen can only be changed to True. To unfreeze do a copy." 
262 | ) 263 | self.dims.frozen = True 264 | self.dofs.frozen = True 265 | self.labels.frozen = True 266 | self.groups.allows_changes = False 267 | self.groups.allows_changes = False 268 | self._fields = self.fields 269 | self._frozen = True 270 | 271 | def freeze(self): 272 | "Returns a frozen copy of the lattice" 273 | if self.frozen: 274 | return self 275 | copy = self.copy() 276 | copy.frozen = True 277 | return copy 278 | 279 | @property 280 | def dims(self): 281 | "Map of lattice dimensions and their size" 282 | return self._dims 283 | 284 | @dims.setter 285 | def dims(self, value): 286 | if self.frozen: 287 | raise RuntimeError("The lattice has been frozen and dims cannot be changed") 288 | 289 | if not value: 290 | self._dims = LatticeAxes(lattice=self) 291 | return 292 | 293 | if isinstance(value, (dict, MappingProxyType)): 294 | self.dims.reset(value) 295 | 296 | # Adding default labels and groups 297 | dirs = list(self.dims) 298 | self.labels.setdefault("dirs", dirs) 299 | if len(dirs) > 1: 300 | self.groups.setdefault("time", (dirs[0],)) 301 | self.groups.setdefault("space", tuple(dirs[1:])) 302 | return 303 | 304 | if isinstance(value, int): 305 | if value < 0: 306 | raise ValueError("Non-positive number of dims") 307 | self.dims = [1] * value 308 | return 309 | 310 | if isiterable(value, int): 311 | if len(value) <= len(Lattice.default_dims_labels): 312 | self.dims = { 313 | Lattice.default_dims_labels[i]: v for i, v in enumerate(value) 314 | } 315 | else: 316 | self.dims = {"dim%d" % i: v for i, v in enumerate(value)} 317 | return 318 | 319 | if isiterable(value, str): 320 | self.dims = {v: 1 for v in value} 321 | return 322 | 323 | raise TypeError("Not allowed type %s for dims" % type(value)) 324 | 325 | @property 326 | def dofs(self): 327 | "Map of lattice degrees of freedom and their size" 328 | return self._dofs 329 | 330 | @dofs.setter 331 | def dofs(self, value): 332 | if self.frozen: 333 | raise RuntimeError("The lattice has been frozen and 
dofs cannot be changed") 334 | 335 | if not value: 336 | self._dofs = LatticeAxes(lattice=self) 337 | return 338 | 339 | if isinstance(value, (dict, MappingProxyType)): 340 | self.dofs.reset(value) 341 | return 342 | 343 | if isinstance(value, str): 344 | assert value in Lattice.theories, "Unknown dofs name" 345 | value = Lattice.theories[value].copy() 346 | labels = value.pop("labels", {}) 347 | groups = value.pop("groups", {}) 348 | self.dofs = value 349 | self.labels.update(labels) 350 | self.groups.update(groups) 351 | return 352 | 353 | if isinstance(value, int): 354 | if value < 0: 355 | raise ValueError("Non-positive number of dofs") 356 | self.dofs = [1] * value 357 | return 358 | 359 | if isiterable(value, int): 360 | self.dofs = {"dof%d" % i: v for i, v in enumerate(value)} 361 | return 362 | 363 | if isiterable(value, str): 364 | self.dofs = {v: 1 for v in value} 365 | return 366 | 367 | raise TypeError("Not allowed type %s for dofs" % type(value)) 368 | 369 | @property 370 | def labels(self): 371 | "List of labels of the lattice" 372 | return self._labels 373 | 374 | @labels.setter 375 | def labels(self, value): 376 | if self.frozen: 377 | raise RuntimeError( 378 | "The lattice has been frozen and labels cannot be changed" 379 | ) 380 | 381 | if not value: 382 | self._labels = LatticeLabels(lattice=self) 383 | return 384 | 385 | if isinstance(value, (dict, MappingProxyType)): 386 | self.labels.reset(value) 387 | return 388 | 389 | raise TypeError("Not allowed type %s for labels" % type(value)) 390 | 391 | def add_label(self, key, value): 392 | "Adds a label to the lattice" 393 | self.labels[key] = value 394 | 395 | @property 396 | def groups(self): 397 | "List of groups of the lattice" 398 | return self._groups 399 | 400 | @groups.setter 401 | def groups(self, value): 402 | if self.frozen: 403 | raise RuntimeError( 404 | "The lattice has been frozen and groups cannot be changed" 405 | ) 406 | 407 | if not value: 408 | self._groups = 
LatticeGroups(lattice=self) 409 | return 410 | 411 | if isinstance(value, (dict, MappingProxyType)): 412 | self.groups.reset(value) 413 | return 414 | 415 | raise TypeError("Not allowed type %s for groups" % type(value)) 416 | 417 | def add_group(self, key, value): 418 | "Adds a group to the lattice" 419 | self.groups[key] = value 420 | 421 | @property 422 | def coords(self): 423 | "List of coordinates of the lattice" 424 | return self._coords 425 | 426 | @coords.setter 427 | def coords(self, value): 428 | if self.frozen: 429 | raise RuntimeError( 430 | "The lattice has been frozen and coords cannot be changed" 431 | ) 432 | 433 | if not value: 434 | self._coords = LatticeCoords(lattice=self) 435 | return 436 | 437 | if isinstance(value, (dict, MappingProxyType)): 438 | self.coords.reset(value) 439 | return 440 | 441 | raise TypeError("Not allowed type %s for coordinates" % type(value)) 442 | 443 | def add_coord(self, key, value): 444 | "Adds a coord to the lattice" 445 | self.coords[key] = value 446 | 447 | @property 448 | def maps(self): 449 | "List of maps of the lattice" 450 | return self._maps 451 | 452 | @maps.setter 453 | def maps(self, value): 454 | if not value: 455 | self._maps = LatticeMaps(lattice=self) 456 | return 457 | 458 | if isinstance(value, (dict, MappingProxyType)): 459 | self.maps.reset(value) 460 | return 461 | 462 | raise TypeError("Not allowed type %s for maps" % type(value)) 463 | 464 | def add_map(self, new_lattice, mapping, unmapping=None, label=None, unlabel=None): 465 | "Adds a map to the lattice" 466 | 467 | if not isinstance(new_lattice, Lattice): 468 | raise TypeError( 469 | f"Given new_lattice of type {type(new_lattice)} is not a Lattice" 470 | ) 471 | new_lattice = new_lattice.copy() 472 | 473 | key = 1 474 | while "map%d" % key in self: 475 | key += 1 476 | key = label or getattr(mapping, "__name__", "map%d" % key) 477 | self.maps[key] = LatticeMap(self, new_lattice, mapping) 478 | 479 | if unmapping: 480 | 
new_lattice.add_map(self, unmapping, label=unlabel) 481 | 482 | def __eq__(self, other): 483 | return self is other or ( 484 | isinstance(other, Lattice) 485 | and self.dims == other.dims 486 | and self.dofs == other.dofs 487 | and self.labels == other.labels 488 | and self.groups == other.groups 489 | ) 490 | 491 | @property 492 | def axes(self): 493 | "Complete list of axes of the lattice" 494 | axes = list(self.dims.keys()) 495 | axes.extend(self.dofs.keys()) 496 | axes.extend(self.labels.keys()) 497 | return tuple(axes) 498 | 499 | def keys(self): 500 | "Complete list of keys of the lattice" 501 | yield "dims" 502 | yield "dofs" 503 | yield "labels" 504 | yield from self.dims.keys() 505 | yield from self.dofs.keys() 506 | yield from self.labels.keys() 507 | yield from self.groups.keys() 508 | yield from self.coords.keys() 509 | yield from self.maps.keys() 510 | 511 | def rename(self, key, new_key): 512 | "Renames a dimension within the lattice" 513 | 514 | if key == new_key: 515 | return 516 | 517 | if self.frozen: 518 | raise RuntimeError( 519 | "The lattice has been frozen and dimensions cannot be renamed" 520 | ) 521 | 522 | if new_key in dir(self): 523 | raise KeyError("%s is already in use" % (new_key)) 524 | 525 | key_is_axis = key in self.axes 526 | key_found = False 527 | for dct in ( 528 | self.dims, 529 | self.dofs, 530 | self.labels, 531 | self.groups, 532 | self.coords, 533 | self.maps, 534 | ): 535 | if key in dct: 536 | dct.rename(key, new_key) 537 | key_found = True 538 | 539 | if not key_found: 540 | raise KeyError("%s not found" % (key)) 541 | 542 | if key_is_axis: 543 | for dct in self.groups, self.coords, self.maps: 544 | dct.replace(key, new_key) 545 | 546 | def expand(self, *dimensions): 547 | "Expand the list of dimensions into the fundamental dimensions and degrees of freedom" 548 | for dim in dimensions: 549 | if isinstance(dim, str): 550 | if dim not in self.keys(): 551 | raise ValueError("Unknown dimension: %s" % dim) 552 | if dim in 
self.axes: 553 | yield dim 554 | else: 555 | yield from self.expand(self[dim]) 556 | elif isiterable(dim): 557 | yield from self.expand(*dim) 558 | else: 559 | raise TypeError("Unexpected type %s with value %s" % (type(dim), dim)) 560 | 561 | def get_axis_range(self, axis): 562 | "Returns the range of the given axis" 563 | if axis not in self.axes: 564 | raise ValueError("%s is not a lattice axis" % axis) 565 | if axis in self.labels: 566 | return self.labels[axis] 567 | return range(self[axis]) 568 | 569 | def get_axis_size(self, axis): 570 | "Returns the range of the given axis" 571 | if axis not in self.axes: 572 | raise ValueError("%s is not a lattice axis" % axis) 573 | if axis in self.labels: 574 | return len(self.labels[axis]) 575 | return self[axis] 576 | 577 | @property 578 | def fields(self): 579 | "List of available field types on the lattice" 580 | if self._fields is not None: 581 | return self._fields 582 | fields = ["Field"] 583 | for name, ftype in FieldType.s.items(): 584 | if ftype.axes.labels in self: 585 | fields.append(name) 586 | return tuple(sorted(fields)) 587 | 588 | @property 589 | def Field(self): 590 | "Returns the base Field type class initializer" 591 | return partial(FieldType.Field, lattice=self) 592 | 593 | def __dir__(self): 594 | yield from dir(type(self)) 595 | yield from self.keys() 596 | yield from self.coords 597 | yield from self.maps 598 | yield from self.fields 599 | 600 | def __contains__(self, key): 601 | if isinstance(key, str): 602 | return key in self.keys() 603 | keys = list(self.keys()) 604 | return all((k in keys for k in key)) 605 | 606 | def __getitem__(self, key): 607 | try: 608 | return getattr(type(self), key).__get__(self) 609 | except AttributeError: 610 | for attr in ["dims", "dofs", "labels", "groups", "coords", "maps"]: 611 | if key in getattr(self, attr): 612 | return getattr(self, attr)[key] 613 | if key in self.fields: 614 | return partial(FieldType.s[key], lattice=self) 615 | raise 616 | 617 | 
__getattr__ = __getitem__ 618 | 619 | def __setitem__(self, key, value): 620 | try: 621 | getattr(type(self), key).__set__(self, value) 622 | except AttributeError: 623 | for attr in ["dims", "dofs", "labels", "groups", "coords"]: 624 | if key in getattr(self, attr): 625 | getattr(self, attr).__setitem__(key, value) 626 | return 627 | raise 628 | 629 | __setattr__ = __setitem__ 630 | 631 | def copy(self, **kwargs): 632 | "Returns a copy of the lattice." 633 | kwargs.setdefault("dims", self.dims) 634 | kwargs.setdefault("dofs", self.dofs) 635 | kwargs.setdefault("groups", self.groups) 636 | kwargs.setdefault("labels", self.labels) 637 | kwargs.setdefault("coords", self.coords) 638 | kwargs.setdefault("maps", self.maps) 639 | return Lattice(**kwargs) 640 | 641 | def __copy__(self): 642 | return self.copy() 643 | 644 | def __getstate__(self): 645 | return ( 646 | self.dims, 647 | self.dofs, 648 | self.labels, 649 | self.groups, 650 | self.coords, 651 | self.maps, 652 | self.frozen, 653 | ) 654 | 655 | def __setstate__(self, state): 656 | ( 657 | self.dims, 658 | self.dofs, 659 | self.labels, 660 | self.groups, 661 | self.coords, 662 | self.maps, 663 | self.frozen, 664 | ) = state 665 | 666 | 667 | class Coordinates(FreezableDict): 668 | "Dictionary for coordinates" 669 | 670 | def __init__(self, val=None): 671 | super().__init__() 672 | if val is not None: 673 | for _k, _v in dict(val).items(): 674 | self[_k] = _v 675 | 676 | @classmethod 677 | def expand(cls, *indexes): 678 | "Expands all the indexes in the list." 
679 | for idx in indexes: 680 | if isinstance(idx, (int, str, slice, range, type(None))): 681 | yield idx 682 | elif isiterable(idx): 683 | yield from cls.expand(*idx) 684 | else: 685 | raise TypeError("Unexpected type %s" % type(idx)) 686 | 687 | def __setitem__(self, key, value): 688 | value = list(self.expand(value)) 689 | while None in value and len(value) > 1: 690 | value.remove(None) 691 | if len(value) == 1: 692 | value = value[0] 693 | else: 694 | value = tuple(value) 695 | super().__setitem__(key, value) 696 | 697 | def __getitem__(self, key): 698 | try: 699 | return super().__getitem__(key) 700 | except KeyError: 701 | return slice(None) 702 | 703 | def update(self, value=None): 704 | "Updates the values of existing keys" 705 | if not value: 706 | return 707 | for key, val in dict(value).items(): 708 | if key in self: 709 | self[key] = (self[key], val) 710 | continue 711 | self[key] = val 712 | 713 | def finalize(self, key, interval): 714 | "Finalizes the list of values for the coordinate" 715 | if key not in self: 716 | raise KeyError("Unknown key %s" % key) 717 | if self[key] is None or self[key] == slice(None): 718 | return 719 | 720 | values = set() 721 | interval = tuple(interval) 722 | for value in self.expand(self[key]): 723 | if isinstance(value, str): 724 | if value not in interval: 725 | raise ValueError("Value %s not in interval" % value) 726 | values.add(value) 727 | continue 728 | if isinstance(value, int): 729 | values.add(interval[value]) 730 | continue 731 | assert isinstance(value, (slice, range)), "Trivial assertion" 732 | if isinstance(value, range): 733 | value = slice(value.start, value.stop, value.step) 734 | values.update(interval[value]) 735 | assert values <= set(interval), "Trivial assertion" 736 | if values == set(interval): 737 | values = slice(None) 738 | elif isiterable(values, str): 739 | values = tuple(sorted(values, key=interval.index)) 740 | else: 741 | tmp = tuple(compact_indexes(sorted(values))) 742 | if len(tmp) == 1: 
743 | values = tmp[0] 744 | self[key] = values 745 | 746 | def cleaned(self): 747 | "Removes keys that are slice(None)" 748 | res = self.copy() 749 | for key in self.keys(): 750 | if self[key] == slice(None): 751 | del res[key] 752 | return res 753 | 754 | def intersection(self, coords): 755 | "Returns the intersection with the given set of coords" 756 | res = self.copy() 757 | for key, val in Coordinates(coords).items(): 758 | if key not in res: 759 | res[key] = val 760 | continue 761 | if val is None or val == slice(None): 762 | continue 763 | if res[key] is None: 764 | if isinstance(val, (int, str)): 765 | continue 766 | raise ValueError("None can only be assigned axis of size one") 767 | if res[key] == slice(None): 768 | raise ValueError("slice(None) can only be assigned to slice(None)") 769 | if isinstance(val, (str, int)): 770 | raise ValueError("%s=%s not compatible with %s" % (key, res[key], val)) 771 | if isinstance(res[key], (str, int)): 772 | if res[key] in val: 773 | continue 774 | raise ValueError("%s=%s not in %s" % (key, res[key], val)) 775 | if not set(val) >= set(res[key]): 776 | raise ValueError( 777 | "%s=%s not in field coordinates" 778 | % (key, set(res[key]).difference(set(val))) 779 | ) 780 | return res 781 | 782 | def extract(self, keys): 783 | "Returns a copy of self including the given keys" 784 | return type(self)({key: self[key] for key in keys}) 785 | 786 | def get_indexes(self, coords): 787 | "Returns the indexes of the values of coords" 788 | if self == coords: 789 | return {} 790 | 791 | indexes = coords.copy() 792 | for key, val in coords.items(): 793 | if self[key] == val: 794 | continue 795 | if val is None: 796 | if isinstance(self[key], (int, str)): 797 | continue 798 | raise ValueError("None can only be assigned axis of size one") 799 | if self[key] == slice(None): 800 | continue 801 | if self[key] is None: 802 | indexes[key] = None 803 | continue 804 | if isinstance(self[key], (str, int)): 805 | raise ValueError( 806 | "Key %s 
class LatticeCoords(LatticeDict):
    "LatticeCoords class"
    regex = re.compile("[a-zA-Z_][a-zA-Z0-9_]*$")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.lattice is None:
            raise ValueError("LatticeCoords requires a lattice")

    def __getitem__(self, key):
        "Unknown keys are deduced (e.g. from lattice labels)."
        try:
            return super().__getitem__(key)
        except KeyError:
            return self.deduce(key)

    def __setitem__(self, key, value):
        if key in self.lattice.labels.labels():
            raise KeyError("%s is already used in lattice labels" % key)
        # Values are always stored in resolved (frozen Coordinates) form
        super().__setitem__(key, self.resolve(value))

    @classmethod
    def format_coords(cls, *keys, **coords):
        "Returns a list of keys, coords from the given keys and coords"
        args = set()
        coords = Coordinates(coords)
        for key in keys:
            if key is None:
                continue
            if isinstance(key, str):
                args.add(key)
            elif isinstance(key, dict):
                coords.update(key)
            elif isiterable(key):
                # Nested iterables are flattened recursively
                _args, _coords = cls.format_coords(*key)
                coords.update(_coords)
                args.update(_args)
            else:
                raise TypeError(
                    "keys can be str, dict or iterables. %s not accepted." % key
                )
        return tuple(args), coords

    def replace(self, key, new_key):
        "Replaces a key with a the new key"
        # TODO

    def random(self, *axes, label=None):
        "A random coordinate in the lattice dims and dofs"
        axes = self.lattice.expand(axes) if axes else self.lattice.axes
        coord = {key: random.choice(self.lattice.get_axis_range(key)) for key in axes}
        if label is not None:
            self[label] = coord
        return coord

    def random_source(self, label=None):
        "A random coordinate in the lattice dims"
        return self.random("dims", label=label)

    def resolve(self, *keys, field=None, **coords):
        "Combines a set of coordinates into a single frozen Coordinates"
        if field is not None and not isinstance(field, BaseField):
            raise ValueError("field must be a Field type")

        keys, coords = self.format_coords(*keys, **coords)
        if not keys and not coords:
            if field is not None:
                return Coordinates(field.coords).cleaned().freeze()
            return Coordinates().freeze()

        # Adding to resolved all the coordinates
        resolved = Coordinates()
        for axis, val in coords.items():
            if field is not None:
                indexes = field.get_indexes(axis)
                if not indexes:
                    raise KeyError("Index '%s' not in field" % axis)
            else:
                indexes = self.lattice.expand(axis)
            resolved.update({idx: val for idx in indexes})

        for key in keys:
            coords = self.deduce(key)
            if field is not None:
                coords = {
                    index: val
                    for axis, val in coords.items()
                    for index in field.get_indexes(axis)
                }
                if not coords:
                    raise KeyError("Coord '%s' not in field" % key)
            resolved.update(coords)

        # Finalizing the coordinates values
        for key, val in resolved.items():
            interval = self.lattice.get_axis_range(BaseField.index_to_axis(key))
            resolved.finalize(key, interval=interval)

        if field is not None:
            resolved = resolved.intersection(field.coords)

        return resolved.cleaned().freeze()
interval = self.lattice.get_axis_range(BaseField.index_to_axis(key)) 933 | resolved.finalize(key, interval=interval) 934 | 935 | if field is not None: 936 | resolved = resolved.intersection(field.coords) 937 | 938 | return resolved.cleaned().freeze() 939 | 940 | def deduce(self, key): 941 | """ 942 | Deduces the coordinates from the key. 943 | 944 | E.g. 945 | ---- 946 | "random source" 947 | "color diagonal" 948 | "x=0" 949 | """ 950 | if key in self: 951 | return dict(self[key]) 952 | 953 | # Looking up in lattice labels 954 | for name, labels in self.lattice.labels.items(): 955 | if key in labels: 956 | return {name: key} 957 | 958 | # TODO 959 | raise NotImplementedError 960 | 961 | 962 | class LatticeMap: 963 | "Class for defining maps between a lattice and another" 964 | 965 | def __init__(self, lat_from: Lattice, lat_to: Lattice, mapping: Callable): 966 | 967 | self.lat_from = lat_from 968 | self.lat_to = lat_to 969 | self.mapping = mapping 970 | 971 | annotations = self.mapping.__annotations__ 972 | params = signature(self.mapping).parameters 973 | if len(params) > 1 or "return" not in annotations: 974 | raise TypeError( 975 | """ 976 | Mapping uses annotations to deduce the input/output coordinates 977 | E.g. 
map(**kwargs: ["dims"]) -> ["dims"] 978 | """ 979 | ) 980 | if annotations["return"] not in self.lat_to: 981 | raise ValueError( 982 | f"Output dimentions {annotations['return']} not in second lattice" 983 | ) 984 | self.out = tuple(self.lat_to.expand(annotations["return"])) 985 | 986 | self.args = [] 987 | for key in params: 988 | if key in annotations: 989 | key = annotations[key] 990 | if key not in self.lat_from: 991 | raise ValueError(f"Input dimentions {key} not in first lattice") 992 | self.args.append(key) 993 | self.args = tuple(self.lat_from.expand(self.args)) 994 | 995 | def get(self, **coords): 996 | "Returns the coords after applying the map" 997 | to_transform = dict() 998 | for key in tuple(coords): 999 | if key in self.args: 1000 | to_transform[key] = coords.pop(key) 1001 | coords.update(self.mapping(**to_transform)) 1002 | return coords 1003 | 1004 | def __repr__(self): 1005 | return f"{self.args} -> {self.out}" 1006 | 1007 | def __call__(self): 1008 | for key, coords in self.lat_from.coords.items(): 1009 | new_coords = self.get(**coords) 1010 | if key not in self.lat_to: 1011 | self.lat_to.add_coord(key, new_coords) 1012 | else: 1013 | assert new_coords == self.lat_to[key] 1014 | 1015 | return self.lat_to 1016 | 1017 | 1018 | class LatticeMaps(LatticeDict): 1019 | "LatticeMaps class" 1020 | 1021 | def __setitem__(self, key, value): 1022 | if key in self.lattice.labels.labels(): 1023 | raise KeyError("%s is already used in lattice labels" % key) 1024 | if not isinstance(value, LatticeMap): 1025 | raise TypeError("Expected a LatticeMap for map") 1026 | 1027 | super().__setitem__(key, value) 1028 | 1029 | def replace(self, key, new_key): 1030 | "Replaces a key with a the new key" 1031 | # TODO 1032 | 1033 | def rename(self, key1, key2): 1034 | "Returns a lattice with key1 renamed to key2" 1035 | lattice = self.lattice.copy() 1036 | lattice.rename(key1, key2) 1037 | map_lattice(self.lattice, lattice, {"%s -> %s" % (key1, key2): None}) 1038 | return 
lattice 1039 | 1040 | def evenodd(self, axis=None): 1041 | """ 1042 | Returns a lattice with even-odd decomposition on the given axis. 1043 | Axis must be an even-sized dimension. 1044 | """ 1045 | lattice = self.lattice.copy() 1046 | 1047 | # Getting first dim with even size 1048 | if axis is None: 1049 | for key in lattice.dims(): 1050 | if lattice[key] % 2 == 0: 1051 | axis = key 1052 | break 1053 | elif axis not in lattice.dims: 1054 | raise KeyError("Axis must be one of the dims") 1055 | 1056 | if axis is None: 1057 | raise ValueError( 1058 | "Even-odd decomposition can be applied only if at least one dim is even" 1059 | ) 1060 | if lattice[axis] % 2 != 0: 1061 | raise ValueError( 1062 | "Even-odd decomposition can be applied only on a even-sized dimension" 1063 | ) 1064 | 1065 | lattice[axis] //= 2 1066 | return lattice 1067 | -------------------------------------------------------------------------------- /notebooks/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !*.ipynb -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 40.6.0", "wheel", "lyncs_setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | testpaths = test 3 | addopts = --cov=lyncs --cov-report term-missing -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from lyncs_setuptools import setup 2 | 3 | import os 4 | 5 | install_requires = [ 6 | "numpy", 7 | "xmltodict", 8 | "tuneit", 9 | "lyncs_utils", 10 | ] 11 | 12 | # Extras 13 | 
# Extras: optional dependency groups exposed as pip extras
lyncs = dict(
    dask=["dask", "dask[array]"],
    clime=["lyncs_clime"],
    DDalphaAMG=["lyncs_DDalphaAMG"],
    test=["pytest", "pytest-cov"],
)

# Groups built on top of the basic extras
lyncs["io"] = lyncs["clime"]

lyncs["mpi"] = ["lyncs_mpi"] + lyncs["DDalphaAMG"]

lyncs["notebook"] = [
    "jupyterlab",
    "tuneit[graph]",
    "graphviz",
    "perfplot",
]

lyncs["all"] = lyncs["notebook"] + lyncs["test"]

setup(
    "lyncs",
    install_requires=install_requires,
    extras_require=lyncs,
)
def test_init_value():
    # Rows of a 3x3 identity used as initial value
    eye3 = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
    # These constructions raise ValueError
    with pytest.raises(ValueError):
        ArrayField(eye3, axes=["color", "color"], lattice=lat)
    with pytest.raises(ValueError):
        ArrayField(eye3, axes=["color"], lattice=lat)
    with pytest.raises(ValueError):
        ArrayField(value=eye3, axes=["color", "color"], lattice=lat)
    with pytest.raises(ValueError):
        ArrayField([0, 1, 2, 3], axes=["dirs"], lattice=lat)
    # These are accepted
    fld = ArrayField(
        eye3, axes=["color", "color"], lattice=lat, indexes_order=("color_0", "color_1")
    )
    fld = ArrayField(eye3[0], axes=["color"], lattice=lat)


def test_reorder():
    fld = ArrayField(axes=["dims", "dofs"], lattice=lat)
    other = fld.copy()
    assert fld.indexes_order == other.indexes_order
    other = fld.reorder(fld.indexes)
    assert other.indexes_order.value == fld.indexes
    with pytest.raises(ValueError):
        fld.reorder(fld.indexes[:-1])
    other = fld.reorder()
    assert other.indexes_order != fld.indexes_order
    with pytest.raises(ValueError):
        other.indexes_order = fld.indexes[:-1]
    other.indexes_order = fld.indexes
    assert other.indexes_order == fld.indexes
    other = fld.reorder()
    other.indexes_order = reversed(fld.indexes)
    assert other.indexes_order == tuple(reversed(fld.indexes))


def test_reorder_label():
    fld = ArrayField(axes=["dirs", "dirs", "color"], lattice=lat)
    other = fld.copy()
    assert fld.labels_order == other.labels_order
    other = fld.reorder_label("dirs", fld.get_range("dirs"))
    assert dict(other.labels_order)["dirs_0"] == fld.get_range("dirs")
    assert dict(other.labels_order)["dirs_1"] == fld.get_range("dirs")
    assert (
        other.labels_order == fld.copy(dirs_order=fld.get_range("dirs")).labels_order
    )
    assert (
        other.labels_order
        == fld.copy(
            dirs_0_order=fld.get_range("dirs"), dirs_1_order=fld.get_range("dirs")
        ).labels_order
    )
    other = fld.reorder_label("dirs_0")
    assert dict(other.labels_order)["dirs_0"] != dict(fld.labels_order)["dirs_0"]
    assert dict(other.labels_order)["dirs_1"] == dict(fld.labels_order)["dirs_1"]
    with pytest.raises(KeyError):
        fld.reorder_label("color")
    with pytest.raises(ValueError):
        fld.reorder_label("foo")
    with pytest.raises(TypeError):
        fld.copy(labels_order="foo")
    with pytest.raises(ValueError):
        other = fld.reorder_label("dirs_0", fld.get_range("dirs")[:-1])
    with pytest.raises(ValueError):
        order = other.labels_order[0][1].copy(reset=True)
        other.copy(other.value, labels_order={"dirs_0": order})

    other = fld[{"dirs_0": "x"}]
    assert other.__is__(other.reorder_label("dirs_0"))
    with pytest.raises(ValueError):
        other.reorder_label("dirs")

    other = fld[{"dirs_0": ("x", "y")}]
    assert dict(other.labels_order)["dirs_0"].fixed


def test_reshape():
    fld = ArrayField(axes=["dims", "dofs"], lattice=lat)
    extended = fld.extend("dofs")
    assert fld.type == "Vector"
    assert extended.type == "Propagator"
    assert extended.squeeze().indexes == fld.indexes


def test_transpose():
    fld = ArrayField(axes=["dofs", "dofs"], lattice=lat)
    assert fld.T.__is__(fld.transpose())
    assert fld.T.indexes_order == fld.indexes_order
    assert fld.transpose(spin=(0, 1)).__is__(fld)
    # assert fld.transpose(spin=(1, 0)) == fld.transpose("spin")

    fld = ArrayField(axes=["dofs"], lattice=lat)
    assert fld.T.__is__(fld)

    with pytest.raises(KeyError):
        fld.transpose(foo=(0, 1))
    with pytest.raises(TypeError):
        fld.transpose(spin=0)
    with pytest.raises(ValueError):
        fld.transpose(spin=(0, 1, 2))
    with pytest.raises(ValueError):
        fld.transpose(spin=(22, 33))

    assert fld.H.__is__(fld.dagger())

    real_part = fld.real
    assert not real_part.iscomplex
    assert real_part.H.__is__(real_part.T)
def test_roll():
    fld = ArrayField(axes=["dims", "dofs"], lattice=lat)
    assert fld.roll(1).__is__(fld.roll(1, "all"))
    assert fld.roll(1).__is__(fld.roll(1, "dims", "dofs", "labels"))
    assert fld.roll(1, "color").__is__(fld.roll(1, axis="color"))
    assert fld.roll(1, "color").__is__(fld.roll(1, axes="color"))

    with pytest.raises(ValueError):
        fld.roll(1, "color", axes="dofs")
    with pytest.raises(KeyError):
        fld.roll(1, "color", foo="bar")
    with pytest.raises(TypeError):
        fld.roll(1, 2, 3)


# --- test/field/test_base.py ---

lat = Lattice()
lat.space = 4
lat.time = 8


def test_init():
    # Invalid constructions
    with pytest.raises(TypeError):
        squeeze("foo")
    with pytest.raises(TypeError):
        BaseField(lattice="foo")
    with pytest.raises(ValueError):
        BaseField(foo="bar")

    fld = BaseField(axes=["dims", "dofs", "dofs"], lattice=lat)

    assert fld == fld
    assert fld == +fld
    assert fld == fld.copy()
    assert dir(fld) == fld.__dir__()

    assert fld.size == 4 * 4 * 4 * 8 * 3 * 3 * 4 * 4
    assert len(fld.axes) == 8
    assert len(fld.axes) == len(fld.indexes)
    assert fld.indexes_to_axes(*fld.indexes) == fld.axes
    assert fld.axes_to_indexes(*fld.axes) == fld.indexes
    assert fld.indexes_to_axes(*fld.dims) == tuple(fld.lattice.expand("dims"))
    assert fld.indexes_to_axes(*fld.dofs) == tuple(
        fld.lattice.expand("dofs", "dofs")
    )
    assert fld.labels == ()
    assert set(fld.get_axes(*fld.axes)) == set(fld.axes)
    assert set(fld.get_axes("all")) == set(fld.axes)
    assert set(fld.get_indexes("all")) == set(fld.indexes)
    with pytest.raises(TypeError):
        fld.get_axes(1)
    with pytest.raises(TypeError):
        fld.get_indexes([1, 2])

    types = dict(fld.types)
    assert fld.type == "Propagator"
    assert fld.type == next(iter(types))
    assert "Sites" in types
    assert "Scalar" in types
    assert "Degrees" in types

    source = fld.lattice.coords.random_source("source")
    point = fld[source]
    assert fld["source"] == fld[source]
    shape = dict(point.shape)
    for dim in point.dims:
        assert shape[dim] == 1

    dofs = point.squeeze()
    assert point.size == dofs.size
    assert dofs == squeeze(dofs)
    assert not dofs.dims
    assert dofs.dofs == dofs.indexes
    assert point.reshape(dofs.axes) == dofs
    with pytest.raises(ValueError):
        point.reshape(point.axes[:-1])
    assert point == dofs.reshape(point.axes)[source]
    assert dofs.reshape(point.axes) == dofs.extend(lat.dims)

    extended = dofs.extend(lat.dims)[{"x": (0, 1)}]
    assert extended.size == 2 * dofs.size

    everywhere = point.unsqueeze()
    assert everywhere == dofs.extend("dims")


def test_size():
    fld = BaseField(axes=["dims", "dirs", "color", "color"], lattice=lat)
    with pytest.raises(KeyError):
        fld.get_size("spin")
    assert fld.get_size("color") == 3
    with pytest.raises(ValueError):
        fld.get_size("dims")


def test_coords():
    fld = BaseField(axes=["dims", "dirs", "color", "color"], lattice=lat)

    fld.lattice.coords["spin0"] = {"spin": 0}
    assert "spin0" in fld.lattice.coords
    with pytest.raises(KeyError):
        fld.get(spin=0)
    with pytest.raises(KeyError):
        fld["spin0"]

    fld.lattice.coords["col0"] = {"color": 0}
    assert "col0" in fld.lattice.coords
    assert fld["col0"] == fld.get(color=0)

    high_x = fld.get(x=(2, 3))
    assert high_x.get_size("x") == 2
    assert high_x.get(x=(2, 3)) == high_x
    assert high_x != fld
    assert high_x.get(x=(2)).get_size("x") == 1
field.get(spin=0) 92 | with pytest.raises(KeyError): 93 | field["spin0"] 94 | 95 | field.lattice.coords["col0"] = {"color": 0} 96 | assert "col0" in field.lattice.coords 97 | assert field["col0"] == field.get(color=0) 98 | 99 | highX = field.get(x=(2, 3)) 100 | assert highX.get_size("x") == 2 101 | assert highX.get(x=(2, 3)) == highX 102 | assert highX != field 103 | assert highX.get(x=(2)).get_size("x") == 1 104 | -------------------------------------------------------------------------------- /test/field/test_numpy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from lyncs.field import ArrayField 3 | from lyncs import Lattice 4 | import numpy as np 5 | 6 | 7 | def init_field(): 8 | lat = Lattice() 9 | lat.space = 4 10 | lat.time = 8 11 | field = ArrayField(axes=["dims", "color", "color"], lattice=lat) 12 | indexes = field.indexes 13 | field.indexes_order = indexes 14 | shape = field.ordered_shape 15 | return field, indexes, shape 16 | 17 | 18 | def test_init(): 19 | field, indexes, shape = init_field() 20 | 21 | assert field == field.copy(copy=True) 22 | 23 | assert field.zeros() == np.zeros(shape) 24 | assert field.ones() == np.ones(shape) 25 | 26 | field = field.rand() 27 | random = field.result 28 | assert field == field.copy() 29 | assert field == random 30 | 31 | vals = np.arange(9).reshape(3, 3) 32 | field = ArrayField( 33 | vals, axes=["color", "color"], indexes_order=["color_0", "color_1"] 34 | ) 35 | assert field == vals 36 | assert field.astype("float") == vals.astype("float") 37 | 38 | 39 | def getitem(arr, indexes, **coords): 40 | return arr.__getitem__(tuple(coords.pop(idx, slice(None)) for idx in indexes)) 41 | 42 | 43 | def test_getitem(): 44 | field, indexes, shape = init_field() 45 | field = field.rand() 46 | random = field.result 47 | 48 | assert field[{"x": 0}] == getitem(random, indexes, x_0=0) 49 | assert field[{"y": (0, 1, 2), "z": -1}] == getitem( 50 | random, indexes, y_0=range(3), 
z_0=-1 51 | ) 52 | assert field[{"color": 0}] == getitem(random, indexes, color_0=0, color_1=0) 53 | assert field[{"color_0": 0}] == getitem(random, indexes, color_0=0) 54 | 55 | 56 | def setitem(arr, value, indexes, **coords): 57 | arr.__setitem__(tuple(coords.pop(idx, slice(None)) for idx in indexes), value) 58 | return arr 59 | 60 | 61 | def test_setitem(): 62 | field, indexes, shape = init_field() 63 | field = field.rand() 64 | random = field.result 65 | 66 | field[{"x": 0}] = 0 67 | assert field == setitem(random, 0, indexes, x_0=0) 68 | 69 | field[{"x": (0, 1)}] = 0 70 | assert field == setitem(random, 0, indexes, x_0=(0, 1)) 71 | -------------------------------------------------------------------------------- /test/test_lattice.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import pytest 3 | from lyncs import Lattice, default_lattice 4 | from lyncs.lattice import ( 5 | LatticeDict, 6 | LatticeAxes, 7 | LatticeLabels, 8 | LatticeGroups, 9 | Coordinates, 10 | ) 11 | 12 | 13 | def test_keys(): 14 | keys = LatticeDict() 15 | keys["a0"] = None 16 | keys["a_0b"] = "foo" 17 | keys["a_b_c"] = 1 18 | 19 | assert keys == LatticeDict(keys) 20 | assert keys == keys.copy() 21 | 22 | for key in ["a_0", "_a", "a!", "a+", "a_"]: 23 | with pytest.raises(KeyError): 24 | keys[key] = 1 25 | 26 | with pytest.raises(KeyError): 27 | keys.update({"a_0": 1}) 28 | 29 | with pytest.raises(KeyError): 30 | keys = LatticeDict({"a_0": 1}) 31 | 32 | with pytest.raises(KeyError): 33 | keys.setdefault("a_0", 1) 34 | 35 | copy = keys.copy() 36 | assert copy == keys 37 | with pytest.raises(KeyError): 38 | copy.reset({"a_0": 1}) 39 | assert copy == keys 40 | 41 | frozen = keys.freeze() 42 | assert frozen.freeze() is frozen 43 | frozen.frozen = True 44 | with pytest.raises(ValueError): 45 | frozen.frozen = False 46 | with pytest.raises(RuntimeError): 47 | frozen["a0"] = 5 48 | 49 | del keys["a0"] 50 | with pytest.raises(RuntimeError): 
51 | del frozen["a0"] 52 | assert "a0" in frozen 53 | assert "a0" not in keys 54 | 55 | with pytest.raises(ValueError): 56 | LatticeDict(None, "foo") 57 | 58 | 59 | def test_axes(): 60 | axes = LatticeAxes() 61 | axes["x"] = 1 62 | with pytest.raises(ValueError): 63 | axes["x"] = -1 64 | 65 | 66 | def test_labels(): 67 | labels = LatticeLabels() 68 | labels["x"] = "x" 69 | assert labels["x"] == ("x",) 70 | with pytest.raises(TypeError): 71 | labels["x"] = 1 72 | with pytest.raises(ValueError): 73 | labels["x"] = ("x", "x") 74 | assert labels["x"] == ("x",) 75 | with pytest.raises(ValueError): 76 | labels["y"] = ("x",) 77 | 78 | 79 | def test_init(): 80 | lat = Lattice(dims=4, dofs=[4, 3]) 81 | assert len(lat.dims) == 4 82 | assert len(lat.dofs) == 2 83 | assert set(["t", "x", "y", "z"]) == set(lat.dims.keys()) 84 | assert default_lattice() == lat 85 | 86 | with pytest.raises(AttributeError): 87 | lat.foo 88 | 89 | with pytest.raises(ValueError): 90 | list(lat.expand("foo")) 91 | 92 | with pytest.raises(ValueError): 93 | lat.get_axis_range("foo") 94 | 95 | with pytest.raises(ValueError): 96 | lat.get_axis_size("foo") 97 | 98 | with pytest.raises(AttributeError): 99 | lat.foo = "bar" 100 | 101 | lat.x = 5 102 | assert lat.x == 5 and lat.dims["x"] == lat.x 103 | lat.dof0 = 6 104 | assert lat.dof0 == 6 105 | lat.dirs = lat.dims 106 | assert lat.dirs == tuple(lat.dims) 107 | lat.space = 6 108 | assert lat.x == lat.y and lat.y == lat.z and lat.z == 6 109 | 110 | lat.dims["w"] = 8 111 | assert "w" in lat 112 | assert lat.space in lat 113 | 114 | with pytest.raises(KeyError): 115 | lat.dofs["x"] = 3 116 | 117 | assert lat == pickle.loads(pickle.dumps(lat)) 118 | 119 | assert all((hasattr(lat, key) for key in dir(lat))) 120 | assert set(lat.keys()).issubset(dir(lat)) 121 | 122 | 123 | def test_freeze(): 124 | lat = Lattice() 125 | lat2 = lat.freeze() 126 | assert lat2.freeze() is lat2 127 | 128 | lat.frozen = False 129 | lat2.frozen = True 130 | assert lat.fields == 
lat2.fields 131 | with pytest.raises(ValueError): 132 | lat2.frozen = False 133 | with pytest.raises(RuntimeError): 134 | lat2.x = 5 135 | with pytest.raises(RuntimeError): 136 | lat2.dims = None 137 | with pytest.raises(RuntimeError): 138 | lat2.dofs = None 139 | with pytest.raises(RuntimeError): 140 | lat2.labels = None 141 | with pytest.raises(RuntimeError): 142 | lat2.groups = None 143 | with pytest.raises(RuntimeError): 144 | lat2.coords = None 145 | 146 | 147 | def test_init_dims(): 148 | no_dims = Lattice(dims=None) 149 | assert no_dims == Lattice(dims=False) 150 | assert no_dims == Lattice(dims=[]) 151 | assert no_dims == Lattice(dims=0) 152 | 153 | lattice = Lattice(dims=5) 154 | assert len(lattice.dims) == 5 155 | 156 | lattice = Lattice(dims=["x", "y", "z"]) 157 | assert len(lattice.dims) == 3 158 | lattice.x = 8 159 | assert lattice.x == lattice.get_axis_size("x") 160 | 161 | with pytest.raises(TypeError): 162 | Lattice(dims=3.5) 163 | 164 | with pytest.raises(ValueError): 165 | Lattice(dims=-1) 166 | 167 | 168 | def test_init_dofs(): 169 | no_dofs = Lattice(dofs=None) 170 | assert no_dofs == Lattice(dofs=False) 171 | assert no_dofs == Lattice(dofs=[]) 172 | assert no_dofs == Lattice(dofs=0) 173 | 174 | lattice = Lattice(dofs=5) 175 | assert len(lattice.dofs) == 5 176 | 177 | lattice = Lattice(dofs=["a", "b", "c"]) 178 | assert len(lattice.dofs) == 3 179 | 180 | with pytest.raises(TypeError): 181 | Lattice(dofs=3.5) 182 | 183 | with pytest.raises(ValueError): 184 | Lattice(dofs=-1) 185 | 186 | 187 | def test_init_labels(): 188 | no_labels = Lattice(labels=None) 189 | assert no_labels == Lattice(labels=False) 190 | assert no_labels == Lattice(labels=[]) 191 | assert no_labels == Lattice(labels=0) 192 | 193 | lattice = Lattice(labels={"trial": ["foo", "bar"]}) 194 | assert "trial" in lattice.labels 195 | assert lattice.trial == ("foo", "bar") 196 | assert lattice.trial == lattice.get_axis_range("trial") 197 | assert len(lattice.trial) == 
lattice.get_axis_size("trial") 198 | 199 | lattice.add_label("another", ["one", "two"]) 200 | assert lattice["another"] == ("one", "two") 201 | 202 | with pytest.raises(TypeError): 203 | lattice.labels = 3.5 204 | 205 | with pytest.raises(TypeError): 206 | Lattice(labels={"trial": 3.5}) 207 | 208 | 209 | def test_init_groups(): 210 | no_groups = Lattice(groups=None) 211 | assert no_groups == Lattice(groups=False) 212 | assert no_groups == Lattice(groups=[]) 213 | assert no_groups == Lattice(groups=0) 214 | 215 | lattice = Lattice(groups={"trial": "x"}) 216 | assert "trial" in lattice.groups 217 | assert lattice.trial == ("x",) 218 | 219 | lattice.add_group("another", ["x", "y"]) 220 | assert lattice["another"] == ("x", "y") 221 | 222 | with pytest.raises(TypeError): 223 | lattice.groups = 3.5 224 | 225 | with pytest.raises(TypeError): 226 | Lattice(groups={"trial": 3.5}) 227 | 228 | with pytest.raises(ValueError): 229 | Lattice(groups={"trial": "foo"}) 230 | 231 | 232 | def test_init_coords(): 233 | no_coords = Lattice(coords=None) 234 | assert no_coords == Lattice(coords=False) 235 | assert no_coords == Lattice(coords=[]) 236 | assert no_coords == Lattice(coords=0) 237 | 238 | lattice = Lattice(coords={"trial": {"x": 0, "y": 0}}) 239 | assert "trial" in lattice.coords 240 | 241 | with pytest.raises(KeyError): 242 | lattice.labels["trial"] = ("foo", "bar") 243 | 244 | val = lattice.coords.random_source("source") 245 | assert "source" in lattice.coords 246 | set(dict(lattice.source).keys()) == set(lattice.dims) 247 | 248 | lattice.coords = {"trial": {"x": 0, "y": 0}} 249 | assert "source" not in lattice.coords 250 | lattice.trial = {"x": 0, "y": 0} 251 | 252 | lattice.add_coord("another", {"x": 0, "y": 0}) 253 | assert lattice["another"] == lattice["trial"] 254 | 255 | with pytest.raises(TypeError): 256 | lattice.coords = 3.5 257 | 258 | with pytest.raises(TypeError): 259 | Lattice(coords={"trial": 3.5}) 260 | 261 | 262 | def test_default(): 263 | lat = Lattice() 
264 | assert len(lat.dims) == 4 265 | assert set(lat.dims) == set(["x", "y", "z", "t"]) 266 | assert len(lat.dofs) == 2 267 | assert set(lat.dofs) == set(["spin", "color"]) 268 | assert lat.spin == 4 269 | assert lat.color == 3 270 | assert "dirs" in lat and len(lat.dirs) == 4 271 | 272 | 273 | def test_coords(): 274 | lat = Lattice() 275 | lat.space = 4 276 | lat.time = 8 277 | 278 | assert lat.coords.resolve(y=slice(None)) == {} 279 | assert lat.coords.resolve({"y": (2, 3)}, y=(0, 1)) == {} 280 | assert lat.coords.resolve(lat.dirs) == {} 281 | 282 | 283 | def test_expand(): 284 | lat = Lattice() 285 | assert set(lat.expand("space")) == set(["x", "y", "z"]) 286 | assert set(lat.expand("dims")) == set(lat.expand("space", "time")) 287 | with pytest.raises(TypeError): 288 | list(lat.expand(1)) 289 | 290 | 291 | def test_coordinates(): 292 | coords = Coordinates() 293 | with pytest.raises(TypeError): 294 | coords["x"] = 3.5 295 | for val in ( 296 | None, 297 | (None,), 298 | (None, None), 299 | [ 300 | None, 301 | [ 302 | None, 303 | ], 304 | ], 305 | ): 306 | coords["x"] = val 307 | assert coords["x"] == None 308 | coords.update({"x": val}) 309 | assert coords["x"] == None 310 | assert coords["y"] == slice(None) 311 | 312 | 313 | def test_rename(): 314 | lat = Lattice() 315 | lat.rename("t", "t") 316 | lat.rename("t", "T") 317 | assert "T" in lat 318 | assert "t" not in lat 319 | assert lat["time"] == ("T",) 320 | 321 | with pytest.raises(RuntimeError): 322 | lat.freeze().rename("T", "t") 323 | -------------------------------------------------------------------------------- /test/test_lime.py: -------------------------------------------------------------------------------- 1 | def _test_read_config(): 2 | import lyncs 3 | import numpy 4 | 5 | conf_path = lyncs.__path__[0] + "/../tests/conf.1000" 6 | conf = lyncs.load(conf_path) 7 | assert conf.size == 4 ** 4 * 4 * 3 * 3 8 | 9 | # without distribution 10 | conf.chunks = conf.dims 11 | reference = conf.result() 12 | 
assert reference.shape == conf.field_shape 13 | 14 | from itertools import product 15 | 16 | chunks = product(*[[1, 2, 4]] * 4) 17 | for chunk in chunks: 18 | chunk = dict(zip(["t", "z", "y", "x"], chunk)) 19 | conf = lyncs.load(conf_path) 20 | conf.chunks = chunk 21 | read = conf.result() 22 | assert numpy.all(read == reference) 23 | --------------------------------------------------------------------------------