├── tests ├── __init__.py ├── README ├── test_psutilprofiler.py ├── test_interferometry.py ├── test_metadata.py ├── test_fftw.py ├── test_selection.py ├── test_lint.py ├── test_misc.py ├── test_tod.py ├── test_pipeline.py ├── test_config.py ├── test_selection_parallel.py ├── test_tools.py ├── test_truncate.py ├── conftest.py ├── test_memh5_parallel.py ├── test_moving_weighted_median.py └── test_time.py ├── caput ├── scripts │ └── __init__.py ├── __init__.py ├── MedianTree.pxd ├── cache.py ├── _fast_tools.pyx ├── MedianTreeNodeData.hpp ├── MedianTreeNode.hpp ├── pfb.py ├── interferometry.py ├── tools.py ├── truncate.hpp ├── truncate.pyx ├── fileformats.py ├── fftw.py ├── misc.py └── profile.py ├── doc ├── reference.rst ├── _templates │ └── autosummary │ │ ├── module.rst │ │ └── class.rst ├── installation.rst ├── index.rst ├── config.rst ├── Makefile └── conf.py ├── README.md ├── .readthedocs.yaml ├── .gitignore ├── setup.py ├── CITATION.cff ├── .github └── workflows │ └── main.yml └── pyproject.toml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the `caput` module.""" 2 | -------------------------------------------------------------------------------- /tests/README: -------------------------------------------------------------------------------- 1 | Unit tests using the python `unittest` framework. 2 | -------------------------------------------------------------------------------- /caput/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools for running caput piplines.""" 2 | -------------------------------------------------------------------------------- /doc/reference.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ------------- 3 | 4 | .. automodule:: caput 5 | -------------------------------------------------------------------------------- /doc/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | {% extends "!autosummary/module.rst" %} 2 | 3 | {% block classes %} 4 | {% endblock %} 5 | 6 | {% block exceptions %} 7 | {% endblock %} 8 | 9 | {% block functions %} 10 | {% endblock %} 11 | -------------------------------------------------------------------------------- /tests/test_psutilprofiler.py: -------------------------------------------------------------------------------- 1 | """Test running the caput.pipeline.""" 2 | 3 | 4 | def test_pipeline(run_pipeline): 5 | """Test profiling a very simple pipeline.""" 6 | result = run_pipeline(["--psutil"]) 7 | print(result.output) 8 | assert result.exit_code == 0 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # caput 2 | 3 | Cluster Astronomical Python Utilities 4 | 5 | Package contains useful utilities for dealing with large datasets on computer 6 | clusters with applications to radio astronomy in mind. Includes modules for 7 | dynamically importing and utilizing mpi4py, in-memory mock-ups of h5py objects, 8 | and infrastructure for running data analysis pipelines on computer clusters. 
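A minimal sketch of two of these pieces, based on the usage in this repository's test suite (the file name is illustrative, and the real APIs offer many more options):

```python
import numpy as np

from caput import config
from caput.memh5 import MemGroup

# An in-memory, h5py-like group that can be written out to, and read back
# from, an HDF5 file.
g = MemGroup()
g.create_dataset("dset1", data=np.arange(10))
g.to_hdf5("example.h5")
g2 = MemGroup.from_file("example.h5")

# Declarative configuration for pipeline tasks: attributes are read and
# type-checked from a parameter dictionary (e.g. parsed from YAML).
class Person(config.Reader):
    name = config.Property(default="Bill", proptype=str)
    age = config.Property(default=26, proptype=float, key="ageinyears")

person = Person()
person.read_config({"name": "Richard", "ageinyears": 40})
```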
9 | 10 | Documentation can be found at http://caput.readthedocs.org/ 11 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build info 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: doc/conf.py 17 | fail_on_warning: false 18 | 19 | python: 20 | install: 21 | - method: pip 22 | path: . 23 | extra_requirements: 24 | - docs 25 | -------------------------------------------------------------------------------- /caput/__init__.py: -------------------------------------------------------------------------------- 1 | """caput. 2 | 3 | Submodules 4 | ---------- 5 | .. autosummary:: 6 | :toctree: _autosummary 7 | 8 | config 9 | interferometry 10 | memh5 11 | misc 12 | mpiarray 13 | mpiutil 14 | pfb 15 | pipeline 16 | time 17 | tod 18 | weighted_median 19 | """ 20 | 21 | from importlib.metadata import PackageNotFoundError, version 22 | 23 | try: 24 | __version__ = version("caput") 25 | except PackageNotFoundError: 26 | # package is not installed 27 | pass 28 | 29 | del version, PackageNotFoundError 30 | -------------------------------------------------------------------------------- /doc/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | :: 5 | 6 | pip install git+https://github.com/radiocosmology/caput.git 7 | 8 | caput depends on h5py_, numpy_ and PyYAML_. For full functionality it also 9 | requires click_, mpi4py_ and Skyfield_. 10 | 11 | .. _GitHub: https://github.com/radiocosmology/caput 12 | .. _h5py: http://www.h5py.org/ 13 | .. _numpy: http://www.numpy.org/ 14 | .. _PyYAML: http://pyyaml.org/ 15 | .. _mpi4py: http://mpi4py.readthedocs.io/en/stable/ 16 | .. _click: http://click.palletsprojects.com/ 17 | .. _Skyfield: http://rhodesmill.org/skyfield/ 18 | .. 
_Freenode: http://freenode.net 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # editor rubble 2 | *~ 3 | .*.swp 4 | 5 | *.py[cod] 6 | 7 | # C extensions 8 | *.so 9 | *.c 10 | *.cpp 11 | 12 | # Data files 13 | caput/data/ 14 | 15 | # Packages 16 | *.egg 17 | *.egg-info 18 | dist 19 | build 20 | eggs 21 | parts 22 | bin 23 | var 24 | sdist 25 | develop-eggs 26 | .installed.cfg 27 | lib 28 | lib64 29 | 30 | # Installer logs 31 | pip-log.txt 32 | 33 | # Unit test / coverage reports 34 | .coverage 35 | .tox 36 | nosetests.xml 37 | 38 | # Unit test files 39 | *.tmp_test* 40 | */perf* 41 | 42 | # Translations 43 | *.mo 44 | 45 | # Mr Developer 46 | .mr.developer.cfg 47 | .project 48 | .pydevproject 49 | 50 | # Misc 51 | *.DS_Store 52 | -------------------------------------------------------------------------------- /caput/MedianTree.pxd: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # cython: language_level = 2 3 | from libcpp cimport bool 4 | from libcpp.memory cimport shared_ptr 5 | 6 | cdef extern from "MedianTree.hpp" namespace "MedianTree": 7 | cdef cppclass Tree[T]: 8 | Tree() nogil 9 | shared_ptr[Data[T]] insert(const T& element, const double weight) nogil 10 | bool remove(const shared_ptr[Data[T]] node) nogil 11 | T weighted_median(char method) nogil 12 | int size() 13 | 14 | cdef extern from "MedianTreeNodeData.hpp" namespace "MedianTree": 15 | cdef cppclass Data[T]: 16 | Data(const T& value, const double weight) nogil 17 | double weight 18 | -------------------------------------------------------------------------------- /caput/cache.py: -------------------------------------------------------------------------------- 1 | """Tools for caching expensive calculations.""" 2 | 3 | import numpy as np 4 | from cachetools import LRUCache 5 | 6 | 7 | class NumpyCache(LRUCache): # pylint: disable=too-many-ancestors 8 | """An LRU cache for numpy arrays that will expand to a maximum size in bytes. 9 | 10 | This should be used like a dictionary except that the least recently used entries 11 | are evicted to restrict memory usage to the specified maximum. 12 | 13 | Parameters 14 | ---------- 15 | size_bytes 16 | The maximum size of the cache in bytes. 17 | """ 18 | 19 | def __init__(self, size_bytes: int): 20 | def _array_size(arr: np.ndarray): 21 | if not isinstance(arr, np.ndarray): 22 | raise ValueError("Item must be a numpy array.") 23 | 24 | return arr.nbytes 25 | 26 | super().__init__(maxsize=size_bytes, getsizeof=_array_size) 27 | -------------------------------------------------------------------------------- /doc/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {% extends "!autosummary/class.rst" %} 2 | 3 | {% block methods %} 4 | {% if methods %} 5 | .. autosummary:: 6 | :toctree: 7 | {% for item in methods %} 8 | {{ name }}.{{ item }} 9 | {%- endfor %} 10 | .. HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. 11 | {% for item in all_methods %} 12 | {{ name }}.{{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block attributes %} 18 | {% if attributes %} 19 | .. autosummary:: 20 | :toctree: 21 | {% for item in attributes %} 22 | {{ name }}.{{ item }} 23 | {%- endfor %} 24 | .. 
HACK -- the point here is that we don't want this to appear in the output, but the autosummary should still generate the pages. 25 | {% for item in all_attributes %} 26 | {{ name }}.{{ item }} 27 | {%- endfor %} 28 | {% endif %} 29 | {% endblock %} 30 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | caput 2 | ===== 3 | 4 | A collection of utilities for building data analysis pipelines. 5 | 6 | Features 7 | -------- 8 | 9 | - A generic container for holding self-documenting datasets in memory with 10 | straightforward syncing to h5py_ files (:mod:`~caput.memh5`). Plus some 11 | specialisation to holding time stream data (:mod:`~caput.tod`). 12 | 13 | - Tools to make MPI-parallel analysis a little easier (:mod:`caput.mpiutil` and 14 | :mod:`~caput.mpiarray`). 15 | 16 | - Infrastructure for building, managing and configuring pipelines for data 17 | processing (:mod:`~caput.pipeline` and :mod:`~caput.config`). 18 | 19 | - Routines for converting to between different time representations, dealing 20 | with leap seconds, and calculating celestial times (:mod:`~caput.time`) 21 | 22 | 23 | Index 24 | ----- 25 | 26 | .. toctree:: 27 | :maxdepth: 2 28 | 29 | installation 30 | config 31 | reference 32 | 33 | 34 | Indices and tables 35 | ------------------ 36 | 37 | * :ref:`genindex` 38 | * :ref:`modindex` 39 | * :ref:`search` 40 | 41 | .. _h5py: http:/www.h5py.org/ 42 | -------------------------------------------------------------------------------- /caput/_fast_tools.pyx: -------------------------------------------------------------------------------- 1 | from cython.parallel import prange 2 | cimport cython 3 | 4 | from libc.math cimport fabs 5 | 6 | cdef extern from "float.h" nogil: 7 | double DBL_MAX 8 | double FLT_MAX 9 | 10 | ctypedef fused real_or_complex: 11 | double 12 | double complex 13 | float 14 | float complex 15 | 16 | @cython.wraparound(False) 17 | @cython.boundscheck(False) 18 | @cython.cdivision(True) 19 | cpdef _invert_no_zero(real_or_complex [:] array, real_or_complex [:] out): 20 | 21 | cdef bint cond 22 | cdef Py_ssize_t i = 0 23 | cdef Py_ssize_t n = array.shape[0] 24 | cdef double thresh, ar, ai 25 | if (real_or_complex is cython.doublecomplex) or (real_or_complex is cython.double): 26 | thresh = 1.0 / DBL_MAX 27 | else: 28 | thresh = 1.0 / FLT_MAX 29 | 30 | if (real_or_complex is cython.doublecomplex) or (real_or_complex is cython.floatcomplex): 31 | for i in prange(n, nogil=True): 32 | cond = (fabs(array[i].real) < thresh) and (fabs(array[i].imag) < thresh) 33 | out[i] = 0.0 if cond else 1.0 / array[i] 34 | else: 35 | for i in prange(n, nogil=True): 36 | cond = fabs(array[i]) < thresh 37 | out[i] = 0.0 if cond else 1.0 / array[i] -------------------------------------------------------------------------------- /caput/MedianTreeNodeData.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MEDIAN_NODE_DATA 2 | #define MEDIAN_NODE_DATA 3 | 4 | #include "MedianTree.hpp" 5 | 6 | namespace MedianTree { 7 | 8 | // Forward declarations 9 | template 10 | class Tree; 11 | 12 | template 13 | class Node; 14 | 15 | template 16 | class Data { 17 | public: 18 | friend class Tree; 19 | friend class Node; 20 | 21 | Data(const T& value, const double weight); 22 | double weight; 23 | 24 | private: 25 | T value; 26 | 27 | inline bool operator<(const Data c) const; 28 | inline bool operator==(const Data c) const; 29 | 
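/* Comparison operators (defined below): order by value, breaking ties by weight; equality requires both value and weight to match. */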
}; 30 | 31 | template 32 | using data_ptr = std::shared_ptr>; 33 | 34 | template 35 | Data::Data(const T& value, const double weight) { 36 | this->value = value; 37 | this->weight = weight; 38 | } 39 | 40 | template 41 | inline bool Data::operator<(const Data c) const { 42 | if (value == c.value) 43 | return weight < c.weight; 44 | return value < c.value; 45 | } 46 | 47 | template 48 | inline bool Data::operator==(const Data c) const { 49 | if (value == c.value && weight == c.weight) 50 | return true; 51 | return false; 52 | } 53 | 54 | } // namespace MedianTree 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /doc/config.rst: -------------------------------------------------------------------------------- 1 | .. _config: 2 | 3 | Configuration 4 | ============= 5 | 6 | The caput pipeline runner script accepts a YAML file for configuration. The structure of this file 7 | is documented below. 8 | 9 | General options 10 | --------------- 11 | 12 | 13 | Pipeline 14 | -------- 15 | 16 | Logging 17 | ....... 18 | The log levels can be configured in multiple ways: 19 | 20 | - Use the `logging` section directly in the pipeline block to define the root log level with either 21 | `DEBUG`, `INFO`, `WARNING` or `ERROR`. You can also set log levels for individual modules here 22 | and may add a root log level with the key `"root"`. The default is `{"root": "WARNING"}`. 23 | 24 | Examples: 25 | 26 | :: 27 | 28 | pipeline: 29 | logging: 30 | root: DEBUG 31 | annoying.module: INFO 32 | 33 | would show `DEBUG` messages for everything, but `INFO` only for a module called `annoying.module`. 34 | 35 | :: 36 | 37 | pipeline: 38 | logging: ERROR 39 | 40 | would reduce all logging to `ERROR` messages. 41 | 42 | - Set the `log_level` parameter of any task of type 43 | `draco.core.task.LoggedTask `_. 44 | 45 | - Further filter logging by MPI ranks using 46 | `draco.core.task.SetMPILogging `_. 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Build cython extensions. 2 | 3 | The full project config can be found in `pyproject.toml`. `setup.py` is still 4 | required to build cython extensions. 
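Set the `CAPUT_NO_OPENMP` environment variable to build without OpenMP; the build also skips OpenMP automatically on macOS or when the compiler is not gcc (see the check below).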
5 | """ 6 | 7 | import os 8 | import re 9 | import sys 10 | import sysconfig 11 | 12 | import numpy 13 | from Cython.Build import cythonize 14 | from setuptools import setup 15 | from setuptools.extension import Extension 16 | 17 | # Decide whether to use OpenMP or not 18 | if ( 19 | ("CAPUT_NO_OPENMP" in os.environ) 20 | or (re.search("gcc", sysconfig.get_config_var("CC")) is None) 21 | or (sys.platform == "darwin") 22 | ): 23 | print("Not using OpenMP") 24 | omp_args = [] 25 | else: 26 | omp_args = ["-fopenmp"] 27 | 28 | # Set up project extensions 29 | extensions = [ 30 | Extension( 31 | name="caput.weighted_median", 32 | sources=["caput/weighted_median.pyx"], 33 | include_dirs=[numpy.get_include()], 34 | language="c++", 35 | extra_compile_args=[*omp_args, "-std=c++11", "-g0", "-O3"], 36 | extra_link_args=[*omp_args, "-std=c++11"], 37 | ), 38 | Extension( 39 | name="caput.truncate", 40 | sources=["caput/truncate.pyx"], 41 | include_dirs=[numpy.get_include()], 42 | extra_compile_args=[*omp_args, "-g0", "-O3"], 43 | extra_link_args=omp_args, 44 | ), 45 | Extension( 46 | name="caput._fast_tools", 47 | sources=["caput/_fast_tools.pyx"], 48 | include_dirs=[numpy.get_include()], 49 | extra_compile_args=[*omp_args, "-g0", "-O3"], 50 | extra_link_args=omp_args, 51 | ), 52 | ] 53 | 54 | setup( 55 | name="caput", # required 56 | ext_modules=cythonize(extensions), 57 | ) 58 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: caput 3 | message: >- 4 | If you use this software, please cite it using the 5 | metadata from this file. 6 | type: software 7 | doi: "10.5281/zenodo.5846374" 8 | repository-code: "https://github.com/radiocosmology/caput" 9 | abstract: >- 10 | Useful utilities for dealing with large datasets on computer clusters with 11 | applications to radio astronomy in mind. Includes modules for dynamically importing 12 | and utilizing mpi4py, in-memory mock-ups of h5py objects, and infrastructure for 13 | running data analysis pipelines on computer clusters. 14 | license: MIT 15 | authors: 16 | - given-names: "J. Richard" 17 | family-names: "Shaw" 18 | affiliation: "University of British Columbia" 19 | orcid: "https://orcid.org/0000-0002-4543-4588" 20 | 21 | - given-names: "Kiyoshi" 22 | family-names: "Masui" 23 | affiliation: "MIT" 24 | orcid: "https://orcid.org/0000-0002-4279-6946" 25 | 26 | - given-names: "Liam" 27 | family-names: "Gray" 28 | affiliation: "University of British Columbia" 29 | orcid: "https://orcid.org/0000-0003-3986-954X" 30 | 31 | - given-names: "Rick" 32 | family-names: "Nitsche" 33 | affiliation: "University of British Columbia" 34 | 35 | - given-names: "Anja" 36 | family-names: "Boskovic" 37 | affiliation: "University of British Columbia" 38 | 39 | - given-names: "Shifan" 40 | family-names: "Zuo" 41 | 42 | - given-names: "Donald V." 
43 | family-names: "Wiebe" 44 | affiliation: "University of British Columbia" 45 | orcid: "https://orcid.org/0000-0002-6669-3159" 46 | 47 | - given-names: "Tristan" 48 | family-names: "Pinsonneault-Marotte" 49 | affiliation: "University of British Columbia" 50 | orcid: "https://orcid.org/0000-0002-9516-3245" 51 | 52 | - given-names: "Mateus" 53 | family-names: "Fandino" 54 | affiliation: "Thompson Rivers University" 55 | orcid: "https://orcid.org/0000-0002-6899-1176" 56 | 57 | - given-names: "Simon" 58 | family-names: "Foreman" 59 | affiliation: "Arizona State University" 60 | orcid: "https://orcid.org/0000-0002-0190-2271" 61 | 62 | - given-names: "Seth R." 63 | family-names: "Siegel" 64 | affiliation: "McGill University" 65 | orcid: "https://orcid.org/0000-0003-2631-6217" -------------------------------------------------------------------------------- /tests/test_interferometry.py: -------------------------------------------------------------------------------- 1 | """Test interferometry routines.""" 2 | 3 | import pytest 4 | from math import pi, sqrt 5 | 6 | from caput import interferometry 7 | 8 | 9 | def test_sphdist(): 10 | from skyfield.units import Angle 11 | 12 | # 90 degrees 13 | assert interferometry.sphdist( 14 | Angle(radians=0), Angle(radians=0), Angle(radians=pi / 2), Angle(radians=0) 15 | ).radians == pytest.approx(pi / 2) 16 | assert interferometry.sphdist( 17 | Angle(radians=0), Angle(radians=0), Angle(radians=0), Angle(radians=pi / 2) 18 | ).radians == pytest.approx(pi / 2) 19 | 20 | # 60 degrees 21 | assert interferometry.sphdist( 22 | Angle(radians=0), Angle(radians=0), Angle(radians=pi / 4), Angle(radians=pi / 4) 23 | ).radians == pytest.approx(pi / 3) 24 | assert interferometry.sphdist( 25 | Angle(radians=pi / 4), Angle(radians=pi / 4), Angle(radians=0), Angle(radians=0) 26 | ).radians == pytest.approx(pi / 3) 27 | 28 | 29 | def test_rotate_ypr(): 30 | 31 | # No rotation 32 | basis = interferometry.rotate_ypr([0, 0, 0], 1, 2, 3) 33 | assert basis == pytest.approx([1, 2, 3]) 34 | 35 | # Rotate into +Y two ways 36 | basis = interferometry.rotate_ypr([pi / 2, 0, 0], 1, 0, 0) 37 | assert basis == pytest.approx([0, 1, 0]) 38 | 39 | basis = interferometry.rotate_ypr([0, pi / 2, 0], 0, 0, 1) 40 | assert basis == pytest.approx([0, 1, 0]) 41 | 42 | # General rotation 43 | x0 = sqrt(2) / 2 44 | y0 = 0.5 45 | z0 = 0.5 46 | basis = interferometry.rotate_ypr([pi / 2, pi / 3, pi / 4], x0, y0, z0) 47 | 48 | # The calculation goes like this ("v" denotes a square root): 49 | 50 | # After yawing by 90 degrees: 51 | # x1 = -y0 y1 = x0 z1 = z0 52 | 53 | # Then pitching by 60 degrees: 54 | # x2 = x1 y2 = 0.5 * y1 + v3/2 z1 z2 = -v3/2 y1 + 0.5 z1 55 | 56 | # Then rolling by 45 degrees: 57 | # x3 = v2/2 (x2 - z2) y3 = y2 z3 = v2/2 (x2 + z2) 58 | 59 | assert basis[0] == pytest.approx(sqrt(2) / 2 * (-y0 + sqrt(3) * x0 / 2 - z0 / 2)) 60 | assert basis[1] == pytest.approx(x0 / 2 + sqrt(3) * z0 / 2) 61 | assert basis[2] == pytest.approx(sqrt(2) / 2 * (-y0 - sqrt(3) * x0 / 2 + z0 / 2)) 62 | -------------------------------------------------------------------------------- /tests/test_metadata.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import yaml 3 | 4 | import numpy 5 | import caput 6 | 7 | 8 | class TestConfig(unittest.TestCase): 9 | def test_default_params(self): 10 | testconfig = """ 11 | pipeline: 12 | bla: "foo" 13 | """ 14 | 15 | man = caput.pipeline.Manager.from_yaml_str(testconfig) 16 | 17 | self.assertIn("versions", 
man.all_tasks_params) 18 | self.assertIn("pipeline_config", man.all_tasks_params) 19 | 20 | self.assertDictEqual(man.all_tasks_params["versions"], {}) 21 | # remove line numbers 22 | pipeline_config = man.all_tasks_params["pipeline_config"] 23 | self.assertDictEqual( 24 | pipeline_config, 25 | yaml.load(testconfig, Loader=yaml.SafeLoader), 26 | ) 27 | 28 | def test_metadata_params(self): 29 | testconfig = """ 30 | foo: bar 31 | pipeline: 32 | save_versions: 33 | - numpy 34 | - caput 35 | bla: "foo" 36 | """ 37 | 38 | man = caput.pipeline.Manager.from_yaml_str(testconfig) 39 | 40 | self.assertIn("versions", man.all_tasks_params) 41 | self.assertIn("pipeline_config", man.all_tasks_params) 42 | 43 | self.assertDictEqual( 44 | man.all_tasks_params["versions"], 45 | {"numpy": numpy.__version__, "caput": caput.__version__}, 46 | ) 47 | 48 | # remove line numbers 49 | pipeline_config = man.all_tasks_params["pipeline_config"] 50 | self.assertDictEqual( 51 | pipeline_config, 52 | yaml.load(testconfig, Loader=yaml.SafeLoader), 53 | ) 54 | 55 | def test_metadata_params_no_config(self): 56 | testconfig = """ 57 | pipeline: 58 | save_versions: numpy 59 | save_config: False 60 | """ 61 | 62 | man = caput.pipeline.Manager.from_yaml_str(testconfig) 63 | 64 | self.assertIn("versions", man.all_tasks_params) 65 | self.assertIn("pipeline_config", man.all_tasks_params) 66 | 67 | self.assertDictEqual( 68 | man.all_tasks_params["versions"], {"numpy": numpy.__version__} 69 | ) 70 | self.assertIsNone(man.all_tasks_params["pipeline_config"], {}) 71 | -------------------------------------------------------------------------------- /tests/test_fftw.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the fftw module.""" 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from caput import fftw 7 | from scipy import fft as sfft 8 | 9 | 10 | ARRAY_SIZE = (100, 111) 11 | SEED = 12345 12 | ATOL = 1e-10 13 | rng = np.random.Generator(np.random.SFC64(SEED)) 14 | 15 | # NOTE: only complex->complex transforms are currently supported, 16 | # but we still want to test that a proper error is raised 17 | random_double_array = rng.standard_normal(size=ARRAY_SIZE, dtype=np.float64) 18 | random_complex_array = rng.standard_normal( 19 | size=ARRAY_SIZE 20 | ) + 1.0j * rng.standard_normal(size=ARRAY_SIZE) 21 | 22 | 23 | @pytest.mark.parametrize("x", [random_double_array]) 24 | def test_invalid_type(x): 25 | """Test that an error is raised with a non-complex type.""" 26 | with pytest.raises(TypeError): 27 | fftw.FFT(x.shape, x.dtype) 28 | 29 | 30 | @pytest.mark.parametrize("x", [random_complex_array]) 31 | @pytest.mark.parametrize("ax", [(0,), (1,), None]) 32 | def test_forward_backward(x, ax): 33 | """Test that ifft(fft(x)) returns the original array.""" 34 | # Test the direct class implementation 35 | fftobj = fftw.FFT(x.shape, x.dtype, ax) 36 | 37 | if np.isrealobj(x): 38 | # pyfftw will destroy the input array for 39 | # real->real inverse transform, but we want 40 | # to test that it _won't_ destroy the array in 41 | # the complex case 42 | xi = fftobj.ifft(fftobj.fft(x.copy())) 43 | else: 44 | xi = fftobj.ifft(fftobj.fft(x)) 45 | 46 | assert np.allclose(x, xi, atol=ATOL) 47 | 48 | # Test the api 49 | if np.isrealobj(x): 50 | xi = fftw.ifft(fftw.fft(x.copy())) 51 | else: 52 | xi = fftw.ifft(fftw.fft(x)) 53 | 54 | assert np.allclose(x, xi, atol=ATOL) 55 | 56 | 57 | @pytest.mark.parametrize("x", [random_complex_array]) 58 | @pytest.mark.parametrize("ax", [(0,), (1,), None]) 59 
| def test_scipy(x, ax): 60 | """Test that this produces the same results as `scipy.fft`.""" 61 | Xc = fftw.fft(x, ax) 62 | ixc = fftw.ifft(Xc, ax) 63 | 64 | # Scipy requires different calls for 1D, 2D, real, and complex cases 65 | if ax is not None and len(ax) == 1: 66 | Xs = sfft.fft(x, axis=ax[0]) 67 | ixs = sfft.ifft(Xs, axis=ax[0]) 68 | else: 69 | Xs = sfft.fft2(x) 70 | ixs = sfft.ifft2(Xs) 71 | 72 | assert np.allclose(Xc, Xs, atol=ATOL) 73 | assert np.allclose(ixc, ixs, atol=ATOL) 74 | -------------------------------------------------------------------------------- /tests/test_selection.py: -------------------------------------------------------------------------------- 1 | """Serial version of the selection tests.""" 2 | 3 | import pytest 4 | from pytest_lazy_fixtures import lf 5 | import numpy as np 6 | 7 | from caput.memh5 import MemGroup 8 | from caput import fileformats 9 | 10 | fsel = slice(1, 8, 2) 11 | isel = slice(1, 4) 12 | ind = [0, 2, 7] 13 | sel = {"dset1": (fsel, isel, slice(None)), "dset2": (fsel, slice(None))} 14 | index_sel = {"dset1": (fsel, ind, slice(None)), "dset2": (ind, slice(None))} 15 | 16 | 17 | @pytest.fixture 18 | def h5_file_select(datasets, h5_file, rm_all_files): 19 | """Provides an HDF5 file with some content for testing.""" 20 | container = MemGroup() 21 | container.create_dataset("dset1", data=datasets[0].view()) 22 | container.create_dataset("dset2", data=datasets[1].view()) 23 | container.to_hdf5(h5_file) 24 | yield h5_file, datasets 25 | rm_all_files(h5_file) 26 | 27 | 28 | @pytest.fixture 29 | def zarr_file_select(datasets, zarr_file, rm_all_files): 30 | """Provides a Zarr file with some content for testing.""" 31 | container = MemGroup() 32 | container.create_dataset("dset1", data=datasets[0].view()) 33 | container.create_dataset("dset2", data=datasets[1].view()) 34 | container.to_file(zarr_file, file_format=fileformats.Zarr) 35 | yield zarr_file, datasets 36 | rm_all_files(zarr_file) 37 | 38 | 39 | @pytest.mark.parametrize( 40 | "container_on_disk, file_format", 41 | [ 42 | (lf("h5_file_select"), fileformats.HDF5), 43 | (lf("zarr_file_select"), fileformats.Zarr), 44 | ], 45 | ) 46 | def test_file_select(container_on_disk, file_format): 47 | """Tests that makes hdf5 objects and tests selecting on their axes.""" 48 | 49 | m = MemGroup.from_file( 50 | container_on_disk[0], selections=sel, file_format=file_format 51 | ) 52 | assert np.all(m["dset1"][:] == container_on_disk[1][0][(fsel, isel, slice(None))]) 53 | assert np.all(m["dset2"][:] == container_on_disk[1][1][(fsel, slice(None))]) 54 | 55 | 56 | @pytest.mark.parametrize( 57 | "container_on_disk, file_format", 58 | [ 59 | (lf("h5_file_select"), fileformats.HDF5), 60 | pytest.param( 61 | lf("zarr_file_select"), 62 | fileformats.Zarr, 63 | marks=pytest.mark.xfail(reason="Zarr doesn't support index selections."), 64 | ), 65 | ], 66 | ) 67 | def test_file_select_index(container_on_disk, file_format): 68 | """Tests that makes hdf5 objects and tests selecting on their axes.""" 69 | 70 | # now test index selection 71 | m = MemGroup.from_file( 72 | container_on_disk[0], selections=index_sel, file_format=file_format 73 | ) 74 | assert np.all(m["dset1"][:] == container_on_disk[1][0][index_sel["dset1"]]) 75 | assert np.all(m["dset2"][:] == container_on_disk[1][1][index_sel["dset2"]]) 76 | -------------------------------------------------------------------------------- /tests/test_lint.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | import 
pytest 3 | import yaml 4 | 5 | from click.testing import CliRunner 6 | 7 | from caput.config import Property 8 | from caput.pipeline import TaskBase 9 | from caput.scripts import runner as caput_script 10 | 11 | 12 | class DoNothing(TaskBase): 13 | pass 14 | 15 | 16 | class DoNothing2(DoNothing): 17 | a_list = Property(proptype=list) 18 | 19 | 20 | @pytest.fixture() 21 | def simple_config(): 22 | yield { 23 | "pipeline": { 24 | "tasks": [ 25 | { 26 | "type": "tests.test_lint.DoNothing", 27 | "out": "out1", 28 | }, 29 | { 30 | "type": "tests.test_lint.DoNothing", 31 | "out": "out2", 32 | "in": "out1", 33 | }, 34 | { 35 | "type": "tests.test_lint.DoNothing2", 36 | "in": "out2", 37 | "out": "out3", 38 | "requires": "out1", 39 | "params": { 40 | "a_list": [1], 41 | }, 42 | }, 43 | ] 44 | } 45 | } 46 | 47 | 48 | def write_to_file(config_json): 49 | with tempfile.NamedTemporaryFile(mode="w+t", delete=False) as temp: 50 | yaml.safe_dump(config_json, temp, encoding="utf-8") 51 | temp.flush() 52 | return temp.name 53 | 54 | 55 | def test_load_yaml(simple_config): 56 | test_runner = CliRunner() 57 | config_file = write_to_file(simple_config) 58 | result = test_runner.invoke(caput_script.lint_config, [config_file]) 59 | assert result.exit_code == 0 60 | 61 | 62 | def test_unknown_task(simple_config): 63 | test_runner = CliRunner() 64 | simple_config["pipeline"]["tasks"][0]["type"] = "what.was.the.name.of.my.Task" 65 | config_file = write_to_file(simple_config) 66 | result = test_runner.invoke(caput_script.lint_config, [config_file]) 67 | assert result.exit_code != 0 68 | 69 | 70 | def test_wrong_type(simple_config): 71 | test_runner = CliRunner() 72 | simple_config["pipeline"]["tasks"][2]["params"]["a_list"] = 1 73 | config_file = write_to_file(simple_config) 74 | result = test_runner.invoke(caput_script.lint_config, [config_file]) 75 | assert result.exit_code != 0 76 | 77 | 78 | def test_lonely_in(simple_config): 79 | test_runner = CliRunner() 80 | simple_config["pipeline"]["tasks"][2]["in"] = "foo" 81 | config_file = write_to_file(simple_config) 82 | result = test_runner.invoke(caput_script.lint_config, [config_file]) 83 | assert result.exit_code != 0 84 | 85 | 86 | def test_lonely_requires(simple_config): 87 | test_runner = CliRunner() 88 | simple_config["pipeline"]["tasks"][2]["requires"] = "bar" 89 | config_file = write_to_file(simple_config) 90 | result = test_runner.invoke(caput_script.lint_config, [config_file]) 91 | assert result.exit_code != 0 92 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | """Test the miscellaneous tools.""" 2 | 3 | import unittest 4 | import tempfile 5 | import os 6 | import pytest 7 | import shutil 8 | 9 | from caput import misc 10 | 11 | 12 | class TestLock(unittest.TestCase): 13 | def setUp(self): 14 | self.dir = tempfile.mkdtemp() 15 | 16 | def test_lock_new(self): 17 | """Test the normal behaviour""" 18 | 19 | base = "newfile.dat" 20 | newfile_name = os.path.join(self.dir, base) 21 | lockfile_name = os.path.join(self.dir, "." 
+ base + ".lock") 22 | 23 | with misc.lock_file(newfile_name) as fname: 24 | # Check lock file has been created 25 | self.assertTrue(os.path.exists(lockfile_name)) 26 | 27 | # Create a stub file 28 | with open(fname, "w+") as fh: 29 | fh.write("hello") 30 | 31 | # Check the file exists only at the temporary path 32 | self.assertTrue(os.path.exists(fname)) 33 | self.assertFalse(os.path.exists(newfile_name)) 34 | 35 | # Check the file exists at the final path and the lock file removed 36 | self.assertTrue(os.path.exists(newfile_name)) 37 | self.assertFalse(os.path.exists(lockfile_name)) 38 | 39 | def test_lock_exception(self): 40 | """Check what happens in an exception""" 41 | 42 | base = "newfile2.dat" 43 | newfile_name = os.path.join(self.dir, base) 44 | lockfile_name = os.path.join(self.dir, "." + base + ".lock") 45 | 46 | with pytest.raises(RuntimeError): 47 | with misc.lock_file(newfile_name) as fname: 48 | # Create a stub file 49 | with open(fname, "w+") as fh: 50 | fh.write("hello") 51 | 52 | raise RuntimeError("Test error") 53 | 54 | # Check that neither the file, nor its lock exists 55 | self.assertFalse(os.path.exists(newfile_name)) 56 | self.assertFalse(os.path.exists(lockfile_name)) 57 | 58 | def test_lock_exception_preserve(self): 59 | """Check what happens in an exception when asked to preserve the temp file""" 60 | 61 | base = "newfile3.dat" 62 | newfile_name = os.path.join(self.dir, base) 63 | lockfile_name = os.path.join(self.dir, "." + base + ".lock") 64 | tmpfile_name = os.path.join(self.dir, "." + base) 65 | 66 | with pytest.raises(RuntimeError): 67 | with misc.lock_file(newfile_name, preserve=True) as fname: 68 | # Create a stub file 69 | with open(fname, "w+") as fh: 70 | fh.write("hello") 71 | 72 | raise RuntimeError("Test error") 73 | 74 | # Check that neither the file, nor its lock exists, but that the 75 | # temporary file does 76 | self.assertTrue(os.path.exists(tmpfile_name)) 77 | self.assertFalse(os.path.exists(newfile_name)) 78 | self.assertFalse(os.path.exists(lockfile_name)) 79 | 80 | def tearDown(self): 81 | shutil.rmtree(self.dir) 82 | 83 | 84 | if __name__ == "__main__": 85 | unittest.main() 86 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: caput-ci-build 2 | on: 3 | pull_request: 4 | branches: 5 | - master 6 | push: 7 | branches: 8 | - master 9 | 10 | jobs: 11 | 12 | lint-code: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Set up Python 3.13 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: "3.13" 21 | 22 | - name: Install apt dependencies 23 | run: | 24 | sudo apt-get update 25 | sudo apt-get install -y libopenmpi-dev openmpi-bin libhdf5-serial-dev 26 | 27 | - name: Install pip dependencies 28 | run: | 29 | pip install black ruff 30 | 31 | - name: Run ruff (flake8 and pydocstyle) 32 | run: ruff check . 33 | 34 | - name: Check code with black 35 | run: black --check . 
36 | 37 | run-tests: 38 | 39 | strategy: 40 | matrix: 41 | os: [ubuntu-latest, macos-latest] 42 | python-version: ["3.10", "3.13"] 43 | 44 | runs-on: ${{ matrix.os }} 45 | steps: 46 | - uses: actions/checkout@v4 47 | 48 | - name: Install apt dependencies 49 | if: matrix.os == 'ubuntu-latest' 50 | run: | 51 | sudo apt-get update 52 | sudo apt-get install -y libopenmpi-dev openmpi-bin libhdf5-serial-dev 53 | 54 | - name: Install brew dependencies 55 | if: matrix.os == 'macos-latest' 56 | run: | 57 | brew install open-mpi hdf5 58 | 59 | - name: Set up Python ${{ matrix.python-version }} 60 | uses: actions/setup-python@v5 61 | with: 62 | python-version: ${{ matrix.python-version }} 63 | 64 | - name: Update pip 65 | run: pip install --upgrade pip 66 | 67 | - name: Install pip dependencies 68 | run: | 69 | pip install -e . 70 | pip install -e .[compression,mpi,test,fftw] 71 | 72 | - name: Run serial tests 73 | run: pytest --doctest-modules . 74 | 75 | - name: Run parallel tests 76 | run: | 77 | mpirun --oversubscribe -np 4 pytest tests/test_memh5_parallel.py 78 | mpirun --oversubscribe -np 4 pytest tests/test_mpiarray.py 79 | mpirun -np 1 pytest tests/test_selection_parallel.py 80 | mpirun --oversubscribe -np 2 pytest tests/test_selection_parallel.py 81 | mpirun --oversubscribe -np 4 pytest tests/test_selection_parallel.py 82 | 83 | build-docs: 84 | 85 | runs-on: ubuntu-latest 86 | steps: 87 | - uses: actions/checkout@v4 88 | 89 | - name: Set up Python 3.13 90 | uses: actions/setup-python@v5 91 | with: 92 | python-version: "3.13" 93 | 94 | - name: Install apt dependencies 95 | run: | 96 | sudo apt-get update 97 | sudo apt-get install -y libhdf5-serial-dev 98 | 99 | - name: Install pip dependencies 100 | run: | 101 | pip install . 102 | pip install .[docs] 103 | 104 | - name: Build sphinx docs 105 | run: sphinx-build -W -b html doc/ doc/_build/html 106 | -------------------------------------------------------------------------------- /caput/MedianTreeNode.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MEDIAN_TREE_NODE 2 | #define MEDIAN_TREE_NODE 3 | 4 | #include "MedianTree.hpp" 5 | #include "MedianTreeNodeData.hpp" 6 | 7 | namespace MedianTree { 8 | 9 | /* AVL Node class, holds data and references to adjacent nodes. */ 10 | template 11 | class Node { 12 | public: 13 | friend class Tree; 14 | 15 | Node(const T& newData, const double weight); 16 | ~Node(); 17 | 18 | private: 19 | // Shared pointers to the children 20 | std::shared_ptr> left; 21 | std::shared_ptr> right; 22 | 23 | // Weak pointer to the parent for traversal 24 | std::weak_ptr> parent; 25 | 26 | data_ptr data; 27 | 28 | // Store these locally to reduce the time cost of lookups 29 | int left_height; 30 | int right_height; 31 | double left_weight; 32 | double right_weight; 33 | 34 | // If true this is the left child of its parent, False otherwise. 35 | bool left_child; 36 | 37 | int total_height(); 38 | 39 | /** 40 | * @brief Update the internal parameters after a structure change. 41 | * 42 | * Only looks at the direct children. If they have changed, update must be 43 | * called on them first. 44 | **/ 45 | void update(); 46 | 47 | double total_weight(); 48 | }; 49 | 50 | template 51 | using node_ptr = std::shared_ptr>; 52 | 53 | /* Constructor for Node, sets the node's data to element. 
*/ 54 | template 55 | Node::Node(const T& element, const double weight) : 56 | data(std::make_shared>(element, weight)) { 57 | // std::cout<<"new node: ("< 68 | Node::~Node() { 69 | left = nullptr; 70 | right = nullptr; 71 | } 72 | 73 | /* Gets the total total_height of the subtree the node is the root of. 74 | * Adds together 1, left_height & right_height. */ 75 | template 76 | int Node::total_height() { 77 | // return height + left_height + right_height; 78 | return std::max(left_height, right_height) + 1; 79 | } 80 | 81 | /* Gets the total total_weight of the subtree the node is the root of. 82 | * Adds together weight, left_weight & wight_height. */ 83 | template 84 | double Node::total_weight() { 85 | // std::cout<=2.0.0rc1", 8 | ] 9 | build-backend = "setuptools.build_meta" 10 | 11 | [project] 12 | name = "caput" 13 | description = "Cluster Astronomical Python Utilities" 14 | license = { file = "LICENSE" } 15 | authors = [ 16 | { name = "The CHIME Collaboration", email = "lgray@phas.ubc.ca" } 17 | ] 18 | maintainers = [ 19 | { name = "Liam Gray", email = "lgray@phas.ubc.ca" }, 20 | { name = "Don Wiebe", email = "dvw@phas.ubc.ca" } 21 | ] 22 | dynamic = ["readme", "version"] 23 | requires-python = ">=3.10" 24 | dependencies = [ 25 | "cachetools", 26 | "click", 27 | "cython", 28 | "h5py", 29 | "numpy>=1.24", 30 | "psutil", 31 | "PyYAML", 32 | "scipy>=1.13", 33 | "skyfield>=1.31", 34 | ] 35 | classifiers = [ 36 | "Development Status :: 5 - Production/Stable", 37 | "Intended Audience :: Developers", 38 | "Intended Audience :: Science/Research", 39 | "Programming Language :: Python :: 3", 40 | "Programming Language :: Python :: 3.10", 41 | "Programming Language :: Python :: 3.11", 42 | "Programming Language :: Python :: 3.12", 43 | "Programming Language :: Python :: 3.13", 44 | ] 45 | 46 | [project.optional-dependencies] 47 | mpi = ["mpi4py>=1.3"] 48 | compression = ["bitshuffle", "zarr>=2.11.0,<3", "numcodecs>=0.7.3,<0.16"] 49 | profiling = ["pyinstrument"] 50 | docs = ["Sphinx>=5.0", "sphinx_rtd_theme", "funcsigs", "mock"] 51 | lint = ["ruff", "black"] 52 | test = ["pytest", "pytest-lazy-fixtures"] 53 | fftw = ["pyfftw>=0.13.1"] 54 | 55 | [project.urls] 56 | Documentation = "https://caput.readthedocs.io/" 57 | Repository = "https://github.com/radiocosmology/caput" 58 | 59 | [project.scripts] 60 | caput-pipeline = "caput.scripts.runner:cli" 61 | 62 | [tool.setuptools.dynamic] 63 | readme = { file = ["README.md"], content-type = "text/markdown" } 64 | 65 | [tool.setuptools-git-versioning] 66 | enabled = true 67 | 68 | [tool.setuptools.packages] 69 | find = {} 70 | 71 | [tool.ruff] 72 | # Enable: 73 | # pycodestyle ('E') 74 | # pydocstyle ('D') 75 | # pyflakes ('F') 76 | # isort ('I') 77 | # pyupgrade ('UP') 78 | # numpy-specific ('NPY') 79 | # ruff-specific ('RUF') 80 | # flake8-blind-except ('BLE') 81 | # flake8-comprehensions ('C4') 82 | # flake8-return ('RET') 83 | lint.select = ["E", "D", "F", "I", "UP", "NPY", "RUF", "BLE", "C4", "RET"] 84 | lint.ignore = [ 85 | "E501", # E501: line length violations. 
Enforce these with `black` 86 | "E741", # E741: Ambiguous variable name 87 | "D105", # D105: Missing docstring in magic method 88 | "D107", # D107: Missing docstring in init 89 | "D203", # D203: 1 blank line required before class docstring 90 | "D213", # D213: Multi-line docstring summary should start at the second line 91 | "D400", # D400: First line should end with a period (only ignoring this because there's another error that catches the same thing) 92 | "D401", # D401: First line should be in imperative mood 93 | "D402", # D402: First line should not be the function’s “signature” 94 | "D413", # D413: Missing blank line after last section 95 | "D416", # D416: Section name should end with a colon 96 | "NPY002", # NPY002: replace legacy numpy.random calls with np.random.Generator 97 | ] 98 | exclude = [ 99 | ".git", 100 | ".github", 101 | "build", 102 | "doc", 103 | "tests", 104 | ] 105 | target-version = "py310" 106 | -------------------------------------------------------------------------------- /tests/test_config.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import pytest 4 | import yaml 5 | 6 | from caput import config 7 | 8 | 9 | # Test classes 10 | class Person(config.Reader): 11 | name = config.Property(default="Bill", proptype=str) 12 | age = config.Property(default=26, proptype=float, key="ageinyears") 13 | 14 | 15 | class PersonWithPet(Person): 16 | petname = config.Property(default="Molly", proptype=str) 17 | petage = 36 18 | 19 | 20 | class ListTypeTests(Person): 21 | list_max_length = config.list_type(maxlength=2) 22 | list_exact_length = config.list_type(length=2) 23 | list_type = config.list_type(type_=int) 24 | 25 | 26 | class DictTypeTests(config.Reader): 27 | dict_config = config.Property(proptype=dict) 28 | 29 | 30 | # Test data dict 31 | testdict = {"name": "Richard", "ageinyears": 40, "petname": "Sooty"} 32 | 33 | 34 | # Tests 35 | def test_default_params(): 36 | person1 = Person() 37 | 38 | assert person1.name == "Bill" 39 | assert person1.age == 26.0 40 | assert isinstance(person1.age, float) 41 | 42 | 43 | def test_set_params(): 44 | person = Person() 45 | person.name = "Mick" 46 | 47 | assert person.name == "Mick" 48 | 49 | 50 | def test_read_config(): 51 | person = Person() 52 | person.read_config(testdict) 53 | 54 | assert person.name == "Richard" 55 | assert person.age == 40.0 56 | 57 | 58 | def test_inherit_read_config(): 59 | person = PersonWithPet() 60 | person.read_config(testdict) 61 | 62 | assert person.name == "Richard" 63 | assert person.age == 40.0 64 | assert person.petname == "Sooty" 65 | 66 | 67 | def test_pickle(): 68 | person = PersonWithPet() 69 | person.read_config(testdict) 70 | person2 = pickle.loads(pickle.dumps(person)) 71 | 72 | assert person2.name == "Richard" 73 | assert person2.age == 40.0 74 | assert person2.petname == "Sooty" 75 | 76 | 77 | def test_list_type(): 78 | lt = ListTypeTests() 79 | 80 | with pytest.raises(config.CaputConfigError): 81 | lt.read_config({"list_max_length": [1, 3, 4]}) 82 | 83 | # Should work fine 84 | lt = ListTypeTests() 85 | lt.read_config({"list_max_length": [1, 2]}) 86 | 87 | with pytest.raises(config.CaputConfigError): 88 | lt.read_config({"list_exact_length": [3]}) 89 | 90 | # Work should fine 91 | lt = ListTypeTests() 92 | lt.read_config({"list_exact_length": [1, 2]}) 93 | 94 | with pytest.raises(config.CaputConfigError): 95 | lt.read_config({"list_type": ["hello"]}) 96 | 97 | # Work should fine 98 | lt = ListTypeTests() 99 | 
lt.read_config({"list_type": [1, 2]}) 100 | 101 | 102 | def test_no_line(): 103 | # This tests that dicts get set as config parameters as expected, and covers a flaw 104 | # in an earlier version of the linting code where `__line__` keys were getting 105 | # inserted into dict types config properties 106 | 107 | dt = DictTypeTests() 108 | 109 | # Test with an empty dict 110 | yaml_str = yaml.dump({"dict_config": {}}) 111 | yaml_params = yaml.load(yaml_str, Loader=config.SafeLineLoader) 112 | dt.read_config(yaml_params) 113 | 114 | assert len(dt.dict_config) == 0 115 | assert isinstance(dt.dict_config, dict) 116 | 117 | # Test with a non-empty dict 118 | yaml_str = yaml.dump({"dict_config": {"a": 3}}) 119 | yaml_params = yaml.load(yaml_str, Loader=config.SafeLineLoader) 120 | dt.read_config(yaml_params) 121 | 122 | assert len(dt.dict_config) == 1 123 | assert isinstance(dt.dict_config, dict) 124 | assert dt.dict_config["a"] == 3 125 | -------------------------------------------------------------------------------- /tests/test_selection_parallel.py: -------------------------------------------------------------------------------- 1 | """Parallel version of the selection tests. 2 | 3 | Needs to be run on 1, 2 or 4 MPI processes. 4 | """ 5 | 6 | from mpi4py import MPI 7 | import numpy as np 8 | import pytest 9 | from pytest_lazy_fixtures import lf 10 | 11 | from caput import mpiutil, mpiarray, fileformats 12 | from caput.memh5 import MemGroup 13 | 14 | 15 | comm = MPI.COMM_WORLD 16 | 17 | 18 | @pytest.fixture 19 | def container_on_disk(datasets, file_name, file_format, rm_all_files): 20 | """Prepare a file for the select_parallel tests.""" 21 | if comm.rank == 0: 22 | m1 = mpiarray.MPIArray.wrap(datasets[0], axis=0, comm=MPI.COMM_SELF) 23 | m2 = mpiarray.MPIArray.wrap(datasets[1], axis=0, comm=MPI.COMM_SELF) 24 | container = MemGroup(distributed=True, comm=MPI.COMM_SELF) 25 | container.create_dataset("dset1", data=m1, distributed=True) 26 | container.create_dataset("dset2", data=m2, distributed=True) 27 | container.to_file(file_name, file_format=file_format) 28 | 29 | comm.Barrier() 30 | 31 | yield file_name, datasets 32 | 33 | comm.Barrier() 34 | 35 | if comm.rank == 0: 36 | rm_all_files(file_name) 37 | 38 | 39 | @pytest.fixture 40 | def xfail_zarr_listsel(request): 41 | file_format = request.getfixturevalue("file_format") 42 | ind = request.getfixturevalue("ind") 43 | 44 | if file_format == fileformats.Zarr and isinstance(ind, (list, tuple)): 45 | request.node.add_marker( 46 | pytest.mark.xfail(reason="Zarr doesn't support list based indexing.") 47 | ) 48 | 49 | 50 | @pytest.mark.parametrize( 51 | "file_name, file_format", 52 | [ 53 | (lf("h5_file"), fileformats.HDF5), 54 | (lf("zarr_file"), fileformats.Zarr), 55 | ], 56 | ) 57 | @pytest.mark.parametrize("fsel", [slice(1, 8, 2), slice(5, 8, 2)]) 58 | @pytest.mark.parametrize("isel", [slice(1, 4), slice(5, 8, 2)]) 59 | @pytest.mark.parametrize("ind", [slice(None), [0, 2, 7]]) 60 | @pytest.mark.usefixtures("xfail_zarr_listsel") 61 | def test_FileSelect_distributed(container_on_disk, fsel, isel, file_format, ind): 62 | """Load H5/Zarr file into parallel container while down-selecting axes.""" 63 | 64 | if ind == slice(None): 65 | sel = {"dset1": (fsel, isel, slice(None)), "dset2": (fsel, slice(None))} 66 | else: 67 | sel = {"dset1": (fsel, ind, slice(None)), "dset2": (ind, slice(None))} 68 | 69 | # Tests are designed to run for 1, 2 or 4 processes 70 | assert 4 % comm.size == 0 71 | 72 | m = MemGroup.from_file( 73 | container_on_disk[0], 74 | 
selections=sel, 75 | distributed=True, 76 | comm=comm, 77 | file_format=file_format, 78 | ) 79 | 80 | d1 = container_on_disk[1][0][sel["dset1"]] 81 | d2 = container_on_disk[1][1][sel["dset2"]] 82 | 83 | _, s, e = mpiutil.split_local(d1.shape[0], comm=comm) 84 | d1slice = slice(s, e) 85 | 86 | _, s, e = mpiutil.split_local(d2.shape[0], comm=comm) 87 | d2slice = slice(s, e) 88 | 89 | # For debugging... 90 | # Need to dereference datasets as this is collective 91 | # md1 = m["dset1"][:] 92 | # md2 = m["dset2"][:] 93 | # for ri in range(comm.size): 94 | # if ri == comm.rank: 95 | # print(comm.rank) 96 | # print(md1.shape, d1.shape, d1[d1slice].shape) 97 | # print(md1[0, :2, :2] if md1.size else "Empty") 98 | # print(d1[d1slice][0, :2, :2] if d1[d1slice].size else "Empty") 99 | # print() 100 | # print(md2.shape, d2.shape, d2[d2slice].shape) 101 | # print(md2[0, :2] if md2.size else "Empty") 102 | # print(d2[d2slice][0, :2] if d2[d2slice].size else "Empty") 103 | # comm.Barrier() 104 | 105 | assert np.all(m["dset1"][:].local_array == d1[d1slice]) 106 | assert np.all(m["dset2"][:].local_array == d2[d2slice]) 107 | -------------------------------------------------------------------------------- /tests/test_tools.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the tools module.""" 2 | 3 | from mpi4py import MPI 4 | import numpy as np 5 | import pytest 6 | 7 | from caput import mpiarray, tools 8 | 9 | 10 | ARRAY_SIZE = (100, 111) 11 | SEED = 12345 12 | ATOL = 0.0 13 | rng = np.random.Generator(np.random.SFC64(SEED)) 14 | 15 | random_float_array = rng.standard_normal(size=ARRAY_SIZE, dtype=np.float32) 16 | random_double_array = rng.standard_normal(size=ARRAY_SIZE, dtype=np.float64) 17 | random_complex_array = rng.standard_normal( 18 | size=ARRAY_SIZE 19 | ) + 1.0j * rng.standard_normal(size=ARRAY_SIZE) 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "a", [random_complex_array, random_float_array, random_double_array] 24 | ) 25 | def test_invert_no_zero(a): 26 | zero_ind = ((0, 10, 12), (56, 34, 78)) 27 | good_ind = np.ones(a.shape, dtype=bool) 28 | good_ind[zero_ind] = False 29 | 30 | # set up some invalid values for inverse 31 | a[zero_ind[0][0], zero_ind[1][0]] = 0.0 32 | a[zero_ind[0][1], zero_ind[1][1]] = 0.5 / np.finfo(a.real.dtype).max 33 | 34 | if np.iscomplexobj(a): 35 | # these should be inverted fine 36 | a[10, 0] = 1.0 37 | a[10, 1] = 1.0j 38 | # also test invalid in the imaginary part 39 | a[zero_ind[0][2], zero_ind[1][2]] = 0.5j / np.finfo(a.real.dtype).max 40 | else: 41 | a[zero_ind[0][2], zero_ind[1][2]] = -0.5 / np.finfo(a.real.dtype).max 42 | 43 | b = tools.invert_no_zero(a) 44 | assert np.allclose(b[good_ind], 1.0 / a[good_ind], atol=ATOL) 45 | assert (b[zero_ind] == 0).all() 46 | 47 | 48 | def test_invert_no_zero_mpiarray(): 49 | comm = MPI.COMM_WORLD 50 | comm.Barrier() 51 | 52 | a = mpiarray.MPIArray((20, 30), axis=0, comm=comm) 53 | a[:] = comm.rank 54 | 55 | b = tools.invert_no_zero(a) 56 | 57 | assert b.shape == a.shape 58 | assert b.comm == a.comm 59 | assert b.axis == a.axis 60 | assert b.local_shape == a.local_shape 61 | 62 | assert (a * b).local_array == pytest.approx(0.0 if comm.rank == 0 else 1.0) 63 | comm.Barrier() 64 | 65 | 66 | def test_invert_no_zero_noncontiguous(): 67 | a = np.arange(100, dtype=np.float64).reshape(10, 10) 68 | 69 | res = np.ones((10, 10), dtype=np.float64) 70 | res[0, 0] = 0.0 71 | 72 | # Check the contiguous layout is working 73 | b_cont = tools.invert_no_zero(a.T.copy()) 74 | assert a.T * 
b_cont == pytest.approx(res) 75 | 76 | # Check that a Fortran contiguous array works 77 | b_noncont = tools.invert_no_zero(a.T) 78 | assert a.T * b_noncont == pytest.approx(res) 79 | 80 | # Check a complex sub slicing that is neither C nor F contiguous 81 | a_noncont = a.T[1::2, 1::2] 82 | b_noncont = tools.invert_no_zero(a_noncont) 83 | res_cont = tools.invert_no_zero(a_noncont.copy(order="C")) 84 | assert np.all(b_noncont == res_cont) 85 | assert b_noncont.flags["C_CONTIGUOUS"] 86 | 87 | 88 | def test_allequal(): 89 | # Test some basic types 90 | assert tools.allequal(1, 1) 91 | assert tools.allequal("x", "x") 92 | assert tools.allequal([1, 2, 3], [1, 2, 3]) 93 | assert tools.allequal({"a": 1}, {"a": 1}) 94 | assert tools.allequal({1, 2, 3}, {1, 2, 3}) 95 | assert tools.allequal((1, 2, 3), (1, 2, 3)) 96 | 97 | # Test numpy arrays and mpiarrays 98 | assert tools.allequal(np.array([1, 2, 3]), np.array([1, 2, 3])) 99 | assert tools.allequal( 100 | mpiarray.MPIArray.wrap(np.array([1, 2, 3]), axis=0), 101 | mpiarray.MPIArray.wrap(np.array([1, 2, 3]), axis=0), 102 | ) 103 | assert not tools.allequal( 104 | np.array([1, 2, 3]), 105 | mpiarray.MPIArray.wrap(np.array([1, 2, 3]), axis=0), 106 | ) 107 | 108 | # Test objects with numpy arrays in them 109 | assert tools.allequal( 110 | [np.array([1]), np.array([2])], [np.array([1]), np.array([2])] 111 | ) 112 | assert tools.allequal({"a": np.array([1])}, {"a": np.array([1])}) 113 | 114 | # Test objects with different types 115 | assert tools.allequal([1, 3.4, "a"], [1, 3.4, "a"]) 116 | 117 | # Test for items that are not equal 118 | assert not tools.allequal(1, 3) 119 | assert not tools.allequal(1, 1.1) 120 | assert not tools.allequal([np.array([1])], [np.array([2])]) 121 | assert not tools.allequal( 122 | np.array(["x"], dtype="U32"), np.array(["x"], dtype="S32") 123 | ) 124 | 125 | # Test for lengths that are not equal 126 | assert not tools.allequal([1], [1, 2]) 127 | assert not tools.allequal({"a": 1}, {"a": 1, "b": 2}) 128 | -------------------------------------------------------------------------------- /tests/test_truncate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from caput import truncate 5 | 6 | 7 | def test_bit_truncate(): 8 | assert truncate.bit_truncate_int(129, 1) == 128 9 | assert truncate.bit_truncate_int(-129, 1) == -128 10 | assert truncate.bit_truncate_int(1, 1) == 0 11 | 12 | assert truncate.bit_truncate_long(129, 1) == 128 13 | assert truncate.bit_truncate_long(-129, 1) == -128 14 | assert truncate.bit_truncate_long(576460752303423489, 1) == 576460752303423488 15 | assert ( 16 | truncate.bit_truncate_long(4520628863461491, 140737488355328) 17 | == 4503599627370496 18 | ) 19 | assert truncate.bit_truncate_long(1, 1) == 0 20 | 21 | assert truncate.bit_truncate_int(54321, 0) == 54321 22 | 23 | assert truncate.bit_truncate_long(576460752303423489, 0) == 576460752303423489 24 | 25 | # special cases 26 | assert truncate.bit_truncate_int(129, 0) == 129 27 | assert truncate.bit_truncate_int(0, 1) == 0 28 | assert truncate.bit_truncate_int(129, -1) == 0 29 | assert truncate.bit_truncate_long(129, 0) == 129 30 | assert truncate.bit_truncate_long(0, 1) == 0 31 | assert truncate.bit_truncate_long(129, -1) == 0 32 | 33 | 34 | def test_truncate_float(): 35 | assert truncate.bit_truncate_float(32.121, 1) == 32 36 | assert truncate.bit_truncate_float(-32.121, 1) == -32 37 | assert truncate.bit_truncate_float(32.125, 0) == 32.125 38 | assert 
truncate.bit_truncate_float(1, 1) == 0 39 | 40 | assert truncate.bit_truncate_float(1 + 1 / 1024, 1 / 2048) == 1 + 1 / 1024 41 | assert ( 42 | truncate.bit_truncate_float(1 + 1 / 1024 + 1 / 2048, 1 / 2048) == 1 + 2 / 1024 43 | ) 44 | assert truncate.bit_truncate_double(1 + 1 / 1024, 1 / 2048) == 1 + 1 / 1024 45 | assert ( 46 | truncate.bit_truncate_double(1 + 1 / 1024 + 1 / 2048, 1 / 2048) == 1 + 2 / 1024 47 | ) 48 | 49 | assert truncate.bit_truncate_double(32.121, 1) == 32 50 | assert truncate.bit_truncate_double(-32.121, 1) == -32 51 | assert truncate.bit_truncate_double(32.121, 0) == 32.121 52 | assert truncate.bit_truncate_double(0.9191919191, 0.001) == 0.919921875 53 | assert truncate.bit_truncate_double(0.9191919191, 0) == 0.9191919191 54 | assert truncate.bit_truncate_double(0.010101, 0) == 0.010101 55 | assert truncate.bit_truncate_double(1, 1) == 0 56 | 57 | # special cases 58 | assert truncate.bit_truncate_float(32.121, -1) == 0 59 | assert truncate.bit_truncate_double(32.121, -1) == 0 60 | assert truncate.bit_truncate_float(32.121, np.inf) == 0 61 | assert truncate.bit_truncate_double(32.121, np.inf) == 0 62 | assert truncate.bit_truncate_float(np.inf, 1) == np.inf 63 | assert truncate.bit_truncate_double(np.inf, 1) == np.inf 64 | assert np.isnan(truncate.bit_truncate_float(np.nan, 1)) 65 | assert np.isnan(truncate.bit_truncate_double(np.nan, 1)) 66 | 67 | assert truncate.bit_truncate_float(np.inf, np.inf) == 0 68 | assert truncate.bit_truncate_double(np.inf, np.inf) == 0 69 | assert truncate.bit_truncate_float(np.nan, np.inf) == 0 70 | assert truncate.bit_truncate_double(np.nan, np.inf) == 0 71 | 72 | # Test that an error is raised when `err` is `NaN` 73 | with pytest.raises(ValueError): 74 | truncate.bit_truncate_float(32.121, np.nan) 75 | 76 | with pytest.raises(ValueError): 77 | truncate.bit_truncate_double(32.121, np.nan) 78 | 79 | 80 | def test_truncate_array(): 81 | assert ( 82 | truncate.bit_truncate_relative( 83 | np.asarray([32.121, 32.5], dtype=np.float32), 1 / 32 84 | ) 85 | == np.asarray([32, 32], dtype=np.float32) 86 | ).all() 87 | assert ( 88 | truncate.bit_truncate_relative_double( 89 | np.asarray([32.121, 32.5], dtype=np.float64), 1 / 32 90 | ) 91 | == np.asarray([32, 32], dtype=np.float64) 92 | ).all() 93 | 94 | 95 | def test_truncate_weights(): 96 | assert ( 97 | truncate.bit_truncate_weights( 98 | np.asarray([32.121, 32.5], dtype=np.float32), 99 | np.asarray([1 / 32, 1 / 32], dtype=np.float32), 100 | 0.001, 101 | ) 102 | == np.asarray([32, 32], dtype=np.float32) 103 | ).all() 104 | assert ( 105 | truncate.bit_truncate_weights( 106 | np.asarray([32.121, 32.5], dtype=np.float64), 107 | np.asarray([1 / 32, 1 / 32], dtype=np.float64), 108 | 0.001, 109 | ) 110 | == np.asarray([32, 32], dtype=np.float64) 111 | ).all() 112 | 113 | 114 | def test_truncate_relative(): 115 | assert ( 116 | truncate.bit_truncate_relative( 117 | np.asarray([32.121, 32.5], dtype=np.float32), 118 | 0.1, 119 | ) 120 | == np.asarray([32, 32], dtype=np.float32) 121 | ).all() 122 | assert ( 123 | truncate.bit_truncate_relative( 124 | np.asarray([32.121, 32.5], dtype=np.float64), 125 | 0.1, 126 | ) 127 | == np.asarray([32, 32], dtype=np.float64) 128 | ).all() 129 | 130 | # Check the case where values are negative 131 | assert ( 132 | truncate.bit_truncate_relative( 133 | np.asarray([-32.121, 32.5], dtype=np.float32), 134 | 0.1, 135 | ) 136 | == np.asarray([-32, 32], dtype=np.float32) 137 | ).all() 138 | assert ( 139 | truncate.bit_truncate_relative( 140 | np.asarray([-32.121, 32.5], 
dtype=np.float64), 141 | 0.1, 142 | ) 143 | == np.asarray([-32, 32], dtype=np.float64) 144 | ).all() 145 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # Internal variables. 11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 14 | # the i18n builder cannot share the environment and doctrees with the others 15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 16 | 17 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 18 | 19 | help: 20 | @echo "Please use \`make ' where is one of" 21 | @echo " html to make standalone HTML files" 22 | @echo " dirhtml to make HTML files named index.html in directories" 23 | @echo " singlehtml to make a single large HTML file" 24 | @echo " pickle to make pickle files" 25 | @echo " json to make JSON files" 26 | @echo " htmlhelp to make HTML files and a HTML help project" 27 | @echo " qthelp to make HTML files and a qthelp project" 28 | @echo " devhelp to make HTML files and a Devhelp project" 29 | @echo " epub to make an epub" 30 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 31 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 32 | @echo " text to make text files" 33 | @echo " man to make manual pages" 34 | @echo " texinfo to make Texinfo files" 35 | @echo " info to make Texinfo files and run them through makeinfo" 36 | @echo " gettext to make PO message catalogs" 37 | @echo " changes to make an overview of all changed/added/deprecated items" 38 | @echo " linkcheck to check all external links for integrity" 39 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 40 | 41 | clean: 42 | -rm -rf $(BUILDDIR)/* _autosummary 43 | 44 | html: 45 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 46 | @echo 47 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 48 | 49 | dirhtml: 50 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 51 | @echo 52 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 53 | 54 | singlehtml: 55 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 56 | @echo 57 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 58 | 59 | pickle: 60 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 61 | @echo 62 | @echo "Build finished; now you can process the pickle files." 63 | 64 | json: 65 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 66 | @echo 67 | @echo "Build finished; now you can process the JSON files." 68 | 69 | htmlhelp: 70 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 71 | @echo 72 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 73 | ".hhp project file in $(BUILDDIR)/htmlhelp." 
74 | 75 | qthelp: 76 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 77 | @echo 78 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 79 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 80 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/ch_util.qhcp" 81 | @echo "To view the help file:" 82 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/ch_util.qhc" 83 | 84 | devhelp: 85 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 86 | @echo 87 | @echo "Build finished." 88 | @echo "To view the help file:" 89 | @echo "# mkdir -p $$HOME/.local/share/devhelp/ch_util" 90 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/ch_util" 91 | @echo "# devhelp" 92 | 93 | epub: 94 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 95 | @echo 96 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 97 | 98 | latex: 99 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 100 | @echo 101 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 102 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 103 | "(use \`make latexpdf' here to do that automatically)." 104 | 105 | latexpdf: 106 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 107 | @echo "Running LaTeX files through pdflatex..." 108 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 109 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 110 | 111 | text: 112 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 113 | @echo 114 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 115 | 116 | man: 117 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 118 | @echo 119 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 120 | 121 | texinfo: 122 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 123 | @echo 124 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 125 | @echo "Run \`make' in that directory to run these through makeinfo" \ 126 | "(use \`make info' here to do that automatically)." 127 | 128 | info: 129 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 130 | @echo "Running Texinfo files through makeinfo..." 131 | make -C $(BUILDDIR)/texinfo info 132 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 133 | 134 | gettext: 135 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 136 | @echo 137 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 138 | 139 | changes: 140 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 141 | @echo 142 | @echo "The overview file is in $(BUILDDIR)/changes." 143 | 144 | linkcheck: 145 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 146 | @echo 147 | @echo "Link check complete; look for any errors in the above output " \ 148 | "or in $(BUILDDIR)/linkcheck/output.txt." 149 | 150 | doctest: 151 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 152 | @echo "Testing of doctests in the sources finished, look at the " \ 153 | "results in $(BUILDDIR)/doctest/output.txt." 154 | -------------------------------------------------------------------------------- /caput/pfb.py: -------------------------------------------------------------------------------- 1 | """Tools for calculating the effects of the CASPER tools PFB. 
2 | 3 | This module can: 4 | - Evaluate the typical window functions used 5 | - Evaluate a python model of the PFB 6 | - Calculate the decorrelation effect for signals offset by a known time delay. 7 | 8 | Window functions 9 | ================ 10 | - :py:meth:`sinc_window` 11 | - :py:meth:`sinc_hann` 12 | - :py:meth:`sinc_hamming` 13 | 14 | PFB 15 | === 16 | - :py:meth:`pfb` 17 | - :py:meth:`decorrelation_ratio` 18 | """ 19 | 20 | import numpy as np 21 | from scipy.interpolate import interp1d 22 | 23 | 24 | def sinc_window(ntap, lblock): 25 |     """Sinc window function. 26 | 27 |     Parameters 28 |     ---------- 29 |     ntap : integer 30 |         Number of taps. 31 |     lblock: integer 32 |         Length of block. 33 | 34 |     Returns 35 |     ------- 36 |     window : np.ndarray[ntap * lblock] 37 |     """ 38 |     # Sampling locations of sinc function 39 |     X = np.linspace(-ntap / 2, ntap / 2, ntap * lblock, endpoint=False) 40 | 41 |     # np.sinc function is sin(pi*x)/(pi*x), not sin(x)/x, so we can just use X 42 |     return np.sinc(X) 43 | 44 | 45 | def sinc_hann(ntap, lblock): 46 |     """Hann-sinc window function. 47 | 48 |     Parameters 49 |     ---------- 50 |     ntap : integer 51 |         Number of taps. 52 |     lblock: integer 53 |         Length of block. 54 | 55 |     Returns 56 |     ------- 57 |     window : np.ndarray[ntap * lblock] 58 |     """ 59 |     return sinc_window(ntap, lblock) * np.hanning(ntap * lblock) 60 | 61 | 62 | def sinc_hamming(ntap, lblock): 63 |     """Hamming-sinc window function. 64 | 65 |     Parameters 66 |     ---------- 67 |     ntap : integer 68 |         Number of taps. 69 |     lblock: integer 70 |         Length of block. 71 | 72 |     Returns 73 |     ------- 74 |     window : np.ndarray[ntap * lblock] 75 |     """ 76 |     return sinc_window(ntap, lblock) * np.hamming(ntap * lblock) 77 | 78 | 79 | class PFB: 80 |     """Model for the CASPER PFB. 81 | 82 |     This is the PFB used in CHIME and other experiments. 83 | 84 |     Parameters 85 |     ---------- 86 |     ntap : int 87 |         Number of taps (i.e. blocks) used in one step of the PFB. 88 |     lblock : int 89 |         The length of a block that gets transformed. This is twice the number 90 |         of output frequencies. 91 |     window : function, optional 92 |         The window function being used. If not set, use a Sinc-Hamming window. 93 |     oversample : int, optional 94 |         The amount to oversample when calculating the decorrelation ratio. 95 |         This will improve accuracy. 96 |     """ 97 | 98 |     def __init__(self, ntap, lblock, window=None, oversample=4): 99 |         self.ntap = ntap 100 |         self.lblock = lblock 101 | 102 |         self.window = sinc_hamming if window is None else window 103 |         self.oversample = oversample 104 | 105 |     def apply(self, timestream): 106 |         """Apply the PFB to a timestream. 107 | 108 |         Parameters 109 |         ---------- 110 |         timestream : np.ndarray 111 |             Timestream to process. 112 | 113 |         Returns 114 |         ------- 115 |         pfb : np.ndarray[:, lblock // 2] 116 |             Array of PFB frequencies.
117 | """ 118 | # Number of blocks 119 | nblock = timestream.size // self.lblock - (self.ntap - 1) 120 | 121 | # Initialise array for spectrum 122 | spec = np.zeros((nblock, self.lblock // 2), dtype=np.complex128) 123 | 124 | # Window function 125 | w = self.window(self.ntap, self.lblock) 126 | 127 | # Iterate over blocks and perform the PFB 128 | for bi in range(nblock): 129 | # Cut out the correct timestream section 130 | ts_sec = timestream[(bi * self.lblock) : ((bi + self.ntap) * self.lblock)] 131 | 132 | # Perform a real FFT (with applied window function) 133 | ft = np.fft.rfft(ts_sec * w) 134 | 135 | # Choose every n-th frequency 136 | spec[bi] = ft[: ((self.lblock // 2) * self.ntap) : self.ntap] 137 | 138 | return spec 139 | 140 | _decorr_interp = None 141 | 142 | def decorrelation_ratio(self, delay): 143 | """Calculate the decorrelation caused by a relative relay of two timestreams. 144 | 145 | This is caused by the fact that the PFB is generated from a finite 146 | time window of data. 147 | 148 | Parameters 149 | ---------- 150 | delay : array_like 151 | The relative delay between the correlated streams in units of 152 | samples (not required to be an integer). 153 | 154 | Returns 155 | ------- 156 | decorrelation : array_like 157 | The decorrelation ratio. 158 | """ 159 | if self._decorr_interp is None: 160 | N = self.ntap * self.lblock 161 | 162 | # Calculate the window and zero pad the array by a factor of oversample 163 | window_extended = np.zeros(N * self.oversample) 164 | window_extended[:N] = self.window(self.ntap, self.lblock) 165 | 166 | # Calculate the FFT and copy into an array over padded by another factor of 167 | # oversample. As we are doing real/inverse-real FFTs the actual length of 168 | # this array has the usual 1/2 N + 1 sizing. 169 | wf = np.fft.rfft(window_extended) 170 | wfpad = np.zeros(N * self.oversample**2 // 2 + 1, dtype=np.complex128) 171 | wfpad[: wf.size] = np.abs(wf) ** 2 172 | 173 | # Calculate the ratio and the effective delays it is available at 174 | decorrelation_ratio = np.fft.irfft(wfpad) 175 | tau = np.fft.fftfreq( 176 | N * self.oversample**2, d=(1.0 / (N * self.oversample)) 177 | ) 178 | 179 | # Extract only the relevant range of time 180 | tau_r = tau[np.abs(tau) <= N] 181 | dc_r = decorrelation_ratio[np.abs(tau) <= N] / decorrelation_ratio[0] 182 | 183 | self._decorr_interp = interp1d( 184 | tau_r, 185 | dc_r, 186 | kind="linear", 187 | fill_value=0, 188 | assume_sorted=False, 189 | bounds_error=False, 190 | ) 191 | 192 | return self._decorr_interp(delay) 193 | -------------------------------------------------------------------------------- /caput/interferometry.py: -------------------------------------------------------------------------------- 1 | """Useful functions for radio interferometry. 2 | 3 | Coordinates 4 | ----------- 5 | - :py:meth:`sphdist` 6 | - :py:meth:`sph_to_ground` 7 | - :py:meth:`ground_to_sph` 8 | - :py:meth:`project_distance` 9 | - :py:meth:`rotate_ypr` 10 | 11 | Interferometry 12 | -------------- 13 | - :py:meth:`fringestop_phase` 14 | """ 15 | 16 | import numpy as np 17 | 18 | 19 | def sphdist(long1, lat1, long2, lat2): 20 | """Return the angular distance between two coordinates on the sphere. 21 | 22 | Parameters 23 | ---------- 24 | long1, lat1 : Skyfield Angle objects 25 | longitude and latitude of the first coordinate. Each should be the 26 | same length; can be one or longer. 27 | 28 | long2, lat2 : Skyfield Angle objects 29 | longitude and latitude of the second coordinate. 
Each should be the 30 | same length. If long1, lat1 have length longer than 1, long2 and 31 | lat2 should either have the same length as coordinate 1 or length 1. 32 | 33 | Returns 34 | ------- 35 | dist : Skyfield Angle object 36 | angle between the two coordinates 37 | """ 38 | from skyfield.units import Angle 39 | 40 | dsinb = np.sin((lat1.radians - lat2.radians) / 2.0) ** 2 41 | 42 | dsinl = ( 43 | np.cos(lat1.radians) 44 | * np.cos(lat2.radians) 45 | * (np.sin((long1.radians - long2.radians) / 2.0)) ** 2 46 | ) 47 | 48 | dist = np.arcsin(np.sqrt(dsinl + dsinb)) 49 | 50 | return Angle(radians=2 * dist) 51 | 52 | 53 | def sph_to_ground(ha, lat, dec): 54 | """Get the ground based XYZ coordinates. 55 | 56 | All input angles are radians. HA, DEC should be in CIRS coordinates. 57 | 58 | Parameters 59 | ---------- 60 | ha : array_like 61 | The Hour Angle of the source to fringestop too. 62 | lat : array_like 63 | The latitude of the observatory. 64 | dec : array_like 65 | The declination of the source. 66 | 67 | Returns 68 | ------- 69 | x, y, z : array_like 70 | The projected angular position in ground fixed XYZ coordinates. 71 | """ 72 | x = -1 * np.cos(dec) * np.sin(ha) 73 | y = np.cos(lat) * np.sin(dec) - np.sin(lat) * np.cos(dec) * np.cos(ha) 74 | z = np.sin(lat) * np.sin(dec) + np.cos(lat) * np.cos(dec) * np.cos(ha) 75 | 76 | return x, y, z 77 | 78 | 79 | def ground_to_sph(x, y, lat): 80 | """Get the CIRS coordinates. 81 | 82 | Latitude is given in radians. Assumes z is positive 83 | 84 | Parameters 85 | ---------- 86 | x : array_like 87 | The East projection of the angular position 88 | y : array_like 89 | The North projection of the angular position 90 | lat : array_like 91 | The latitude of the observatory. 92 | 93 | Returns 94 | ------- 95 | ha, dec: array_like 96 | Hour Angle and declination in radians 97 | """ 98 | z = np.sqrt(1 - x**2 - y**2) 99 | 100 | xe = z * np.cos(lat) - y * np.sin(lat) 101 | ye = x 102 | ze = y * np.cos(lat) + z * np.sin(lat) 103 | 104 | ha = -1 * np.arctan2(ye, xe) 105 | dec = np.arctan2(ze, np.sqrt(xe**2 + ye**2)) 106 | 107 | return ha, dec 108 | 109 | 110 | def projected_distance(ha, lat, dec, x, y, z=0.0): 111 | """Return the distance project in the direction of a source. 112 | 113 | Parameters 114 | ---------- 115 | ha : array_like 116 | The Hour Angle of the source to fringestop too. 117 | lat : array_like 118 | The latitude of the observatory. 119 | dec : array_like 120 | The declination of the source. 121 | x : array_like 122 | The EW coordinate in wavelengths (increases to the E) 123 | y : array_like 124 | The NS coordinate in wavelengths (increases to the N) 125 | z : array_like, optional 126 | The vertical coordinate on wavelengths (increases to the sky!) 127 | 128 | Returns 129 | ------- 130 | dist : np.ndarray 131 | The projected distance. Has whatever units x, y, z did. 132 | """ 133 | # We could use `sph_to_ground` here, but it's likely to be more memory 134 | # efficient to do this directly 135 | dist = x * (-1 * np.cos(dec) * np.sin(ha)) 136 | dist += y * (np.cos(lat) * np.sin(dec) - np.sin(lat) * np.cos(dec) * np.cos(ha)) 137 | dist += z * (np.sin(lat) * np.sin(dec) + np.cos(lat) * np.cos(dec) * np.cos(ha)) 138 | 139 | return dist 140 | 141 | 142 | def rotate_ypr(rot, xhat, yhat, zhat): 143 | """Rotate a basis by a yaw, pitch and roll. 144 | 145 | Parameters 146 | ---------- 147 | rot : [yaw, pitch, roll] 148 | Angles of rotation, in radians. 149 | xhat: np.ndarray 150 | X-component of the basis. X is the axis of rotation for pitch. 
151 | yhat: np.ndarray 152 | Y-component of the basis. Y is the axis of rotation for roll. 153 | zhat: np.ndarray 154 | Z-component of the basis. Z is the axis of rotation for yaw. 155 | 156 | Returns 157 | ------- 158 | xhat, yhat, zhat : np.ndarray[3] 159 | New basis vectors. 160 | """ 161 | yaw, pitch, roll = rot 162 | 163 | # Yaw rotation 164 | xhat1 = np.cos(yaw) * xhat - np.sin(yaw) * yhat 165 | yhat1 = np.sin(yaw) * xhat + np.cos(yaw) * yhat 166 | zhat1 = zhat 167 | 168 | # Pitch rotation 169 | xhat2 = xhat1 170 | yhat2 = np.cos(pitch) * yhat1 + np.sin(pitch) * zhat1 171 | zhat2 = -np.sin(pitch) * yhat1 + np.cos(pitch) * zhat1 172 | 173 | # Roll rotation 174 | xhat3 = np.cos(roll) * xhat2 - np.sin(roll) * zhat2 175 | yhat3 = yhat2 176 | zhat3 = np.sin(roll) * xhat2 + np.cos(roll) * zhat2 177 | 178 | return xhat3, yhat3, zhat3 179 | 180 | 181 | def fringestop_phase(ha, lat, dec, u, v, w=0.0): 182 | """Return the phase required to fringestop. All angle inputs are radians. 183 | 184 | Note that for a visibility V_{ij} = < E_i E_j^*>, this expects the u, v, 185 | w coordinates are the components of (d_i - d_j) / lambda. 186 | 187 | Parameters 188 | ---------- 189 | ha : array_like 190 | The Hour Angle of the source to fringestop too. 191 | lat : array_like 192 | The latitude of the observatory. 193 | dec : array_like 194 | The declination of the source. 195 | u : array_like 196 | The EW separation in wavelengths (increases to the E) 197 | v : array_like 198 | The NS separation in wavelengths (increases to the N) 199 | w : array_like, optional 200 | The vertical separation on wavelengths (increases to the sky!) 201 | 202 | Returns 203 | ------- 204 | phase : np.ndarray 205 | The phase required to *correct* the fringeing. Shape is 206 | given by the broadcast of the arguments together. 207 | """ 208 | phase = -2.0j * np.pi * projected_distance(ha, lat, dec, u, v, w) 209 | 210 | return np.exp(phase, out=phase) 211 | -------------------------------------------------------------------------------- /caput/tools.py: -------------------------------------------------------------------------------- 1 | """Collection of assorted tools.""" 2 | 3 | from collections import deque 4 | from collections.abc import Iterable, Mapping 5 | from itertools import chain 6 | from numbers import Number 7 | from sys import getsizeof 8 | from types import ModuleType 9 | 10 | import numpy as np 11 | 12 | from caput._fast_tools import _invert_no_zero 13 | 14 | 15 | def invert_no_zero(x, out=None): 16 | """Return the reciprocal, but ignoring zeros. 17 | 18 | Where `x != 0` return 1/x, or just return 0. Importantly this routine does 19 | not produce a warning about zero division. 20 | 21 | Parameters 22 | ---------- 23 | x : np.ndarray 24 | Array to invert 25 | out : np.ndarray, optional 26 | Output array to insert results 27 | 28 | Returns 29 | ------- 30 | r : np.ndarray 31 | Return the reciprocal of x. Where possible the output has the same memory layout 32 | as the input, if this cannot be preserved the output is C-contiguous. 33 | """ 34 | if not isinstance(x, np.generic | np.ndarray) or np.issubdtype(x.dtype, np.integer): 35 | with np.errstate(divide="ignore", invalid="ignore", over="ignore"): 36 | return np.where(x == 0, 0.0, 1.0 / x) 37 | 38 | if out is not None: 39 | if x.shape != out.shape: 40 | raise ValueError( 41 | f"Input and output arrays don't have same shape: {x.shape} != {out.shape}." 
42 | ) 43 | else: 44 | # This works even for MPIArrays, producing a correctly shaped MPIArray 45 | out = np.empty_like(x, order="A") 46 | 47 | # In order to be able to flatten the arrays to do element by element operations, we 48 | # need to ensure the inputs are numpy arrays, and so we take a view which will work 49 | # even if `x` (and thus `out`) are MPIArray's 50 | _invert_no_zero( 51 | x.view(np.ndarray).ravel(order="A"), out.view(np.ndarray).ravel(order="A") 52 | ) 53 | 54 | return out 55 | 56 | 57 | def unique_ordered(x: Iterable) -> list: 58 | """Take unique values from an iterable with order preserved. 59 | 60 | Parameters 61 | ---------- 62 | x : Iterable 63 | An iterable to get unique values from 64 | 65 | Returns 66 | ------- 67 | unique : list 68 | unique items in x with order preserved 69 | """ 70 | seen = set() 71 | # So the method is only resolved once 72 | seen_add = seen.add 73 | 74 | return [i for i in x if not (i in seen or seen_add(i))] 75 | 76 | 77 | def allequal(obj1, obj2): 78 | """Check if two objects are equal. 79 | 80 | This comparison can check standard python types and numpy types, 81 | even in the case where they are nested (ex: a dict with a numpy 82 | array as a value). 83 | 84 | Parameters 85 | ---------- 86 | obj1 : scalar, list, tuple, dict, or np.ndarray 87 | Object to compare 88 | obj2 : scalar, list, tuple, dict, or np.ndarray 89 | Object to compare 90 | """ 91 | try: 92 | _assert_equal(obj1, obj2) 93 | except AssertionError: 94 | return False 95 | return True 96 | 97 | 98 | def _assert_equal(obj1, obj2): 99 | """Assert two objects are equal. 100 | 101 | This comparison can check standard python types and numpy types, 102 | even in the case where they are nested (ex: a dict with a numpy 103 | array as a value). 104 | 105 | For more information: 106 | https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_equal.html 107 | 108 | Parameters 109 | ---------- 110 | obj1 : scalar, list, tuple, dict, or np.ndarray 111 | Object to compare 112 | obj2 : scalar, list, tuple, dict, or np.ndarray 113 | Object to compare 114 | """ 115 | __tracebackhide__ = True 116 | if isinstance(obj1, dict): 117 | # Check that the dict-type objects are equivalent 118 | if not isinstance(obj2, dict): 119 | raise AssertionError(repr(type(obj2))) 120 | _assert_equal(len(obj1), len(obj2)) 121 | # Check that each key:value pair are equal 122 | for k in obj1.keys(): 123 | if k not in obj2: 124 | raise AssertionError(repr(k)) 125 | _assert_equal(obj1[k], obj2[k]) 126 | 127 | return 128 | 129 | if isinstance(obj1, list | tuple) and isinstance(obj2, list | tuple): 130 | # Check that the sequence-type objects are equivalent 131 | _assert_equal(len(obj1), len(obj2)) 132 | # Check that each item is the same 133 | for k in range(len(obj1)): 134 | _assert_equal(obj1[k], obj2[k]) 135 | 136 | return 137 | 138 | # If both objects are np.ndarray subclasses, check that they 139 | # have the same type 140 | if isinstance(obj1, np.ndarray) and isinstance(obj2, np.ndarray): 141 | assert type(obj1) is type(obj2) 142 | 143 | obj1 = obj1.view(np.ndarray) 144 | obj2 = obj2.view(np.ndarray) 145 | 146 | # Check all other built-in types and numpy types are equal 147 | np.testing.assert_equal(obj1, obj2) 148 | 149 | 150 | def total_size(obj: object): 151 | """Return the approximate memory used by an object and anything it references. 
152 | 153 | Parameters 154 | ---------- 155 | obj 156 | Any object 157 | """ 158 | # types not to iterate - these can handle their own internal 159 | # recursion/iteration when calling `sys.getsizeof` 160 | exclude_types = (str, bytes, Number, range, bytearray) 161 | 162 | seen = set() 163 | default = getsizeof(0) 164 | 165 | def sum_(x): 166 | try: 167 | return sum(x) 168 | except TypeError: 169 | return 0 170 | 171 | def sizeof(x): 172 | if isinstance(x, ModuleType): 173 | raise TypeError( 174 | f"function `total_size` is not implemented for type {ModuleType}" 175 | ) 176 | # Don't check the same item twice 177 | if id(x) in seen: 178 | return 0 179 | 180 | seen.add(id(x)) 181 | # Get the base size of this object 182 | size = getsizeof(x, default) 183 | # Exclude certain types 184 | if isinstance(x, exclude_types): 185 | pass 186 | # Check basic types and iterate accordingly 187 | elif isinstance(x, tuple | list | set | deque): 188 | size += sum_(map(sizeof, iter(x))) 189 | elif isinstance(x, Mapping): 190 | size += sum_(map(sizeof, chain.from_iterable(x.items()))) 191 | 192 | # Check custom objects 193 | if hasattr(x, "__dict__"): 194 | size += sizeof(vars(x)) 195 | 196 | if hasattr(x, "__slots__"): 197 | size += sum_( 198 | sizeof(getattr(x, attr)) for attr in x.__slots__ if hasattr(x, attr) 199 | ) 200 | 201 | return size 202 | 203 | return sizeof(obj) 204 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Pytest fixtures and simple tasks that can be used by all unit tests.""" 2 | 3 | import glob 4 | import tempfile 5 | 6 | import numpy as np 7 | import pytest 8 | 9 | from caput.pipeline import PipelineStopIteration, TaskBase, IterBase, Manager 10 | from caput.scripts.runner import cli 11 | from caput import config, fileformats, mpiutil 12 | 13 | 14 | @pytest.fixture(scope="session") 15 | def datasets(): 16 | """A couple of simple numpy arrays.""" 17 | len_axis = 8 18 | 19 | dset1 = np.arange(len_axis * len_axis * len_axis) 20 | dset1 = dset1.reshape((len_axis, len_axis, len_axis)) 21 | 22 | dset2 = np.arange(len_axis * len_axis) 23 | dset2 = dset2.reshape((len_axis, len_axis)) 24 | 25 | return dset1, dset2 26 | 27 | 28 | class PrintEggs(TaskBase): 29 | """Simple task used for testing.""" 30 | 31 | eggs = config.Property(proptype=list) 32 | 33 | def __init__(self, *args, **kwargs): 34 | self.i = 0 35 | super().__init__(*args, **kwargs) 36 | 37 | def setup(self, requires=None): 38 | """Run setup.""" 39 | print("Setting up PrintEggs.") 40 | 41 | def next(self, _input=None): 42 | """Run next.""" 43 | if self.i >= len(self.eggs): 44 | raise PipelineStopIteration() 45 | print("Spam and %s eggs." 
% self.eggs[self.i]) 46 | self.i += 1 47 | 48 | def finish(self): 49 | """Run finish.""" 50 | print("Finished PrintEggs.") 51 | 52 | 53 | class GetEggs(TaskBase): 54 | """Simple task used for testing.""" 55 | 56 | eggs = config.Property(proptype=list) 57 | 58 | def __init__(self, *args, **kwargs): 59 | self.i = 0 60 | super().__init__(*args, **kwargs) 61 | 62 | def setup(self, requires=None): 63 | """Run setup.""" 64 | print("Setting up GetEggs.") 65 | 66 | def next(self, _input=None): 67 | """Run next.""" 68 | if self.i >= len(self.eggs): 69 | raise PipelineStopIteration() 70 | egg = self.eggs[self.i] 71 | self.i += 1 72 | return egg 73 | 74 | def finish(self): 75 | """Run finish.""" 76 | print("Finished GetEggs.") 77 | 78 | 79 | class CookEggs(IterBase): 80 | """Simple task used for testing.""" 81 | 82 | style = config.Property(proptype=str) 83 | 84 | def setup(self, requires=None): 85 | """Run setup.""" 86 | print("Setting up CookEggs.") 87 | 88 | def process(self, _input): 89 | """Run process.""" 90 | print("Cooking %s %s eggs." % (self.style, _input)) 91 | 92 | def finish(self): 93 | """Run finish.""" 94 | print("Finished CookEggs.") 95 | 96 | def read_input(self, filename): 97 | """Run read input not implemented.""" 98 | raise NotImplementedError() 99 | 100 | def read_output(self, filename): 101 | """Run read output not implemented.""" 102 | raise NotImplementedError() 103 | 104 | def write_output(self, filename, output, file_format=None, **kwargs): 105 | """Run write output not implemented.""" 106 | raise NotImplementedError() 107 | 108 | 109 | @pytest.fixture 110 | def run_pipeline(): 111 | """Provides the `run_pipeline` function which will run the pipeline.""" 112 | eggs_pipeline_conf = """ 113 | --- 114 | pipeline: 115 | tasks: 116 | - type: tests.conftest.PrintEggs 117 | params: eggs_params 118 | - type: tests.conftest.GetEggs 119 | params: eggs_params 120 | out: egg 121 | - type: tests.conftest.CookEggs 122 | params: cook_params 123 | in: egg 124 | eggs_params: 125 | eggs: ['green', 'duck', 'ostrich'] 126 | cook_params: 127 | style: 'fried' 128 | """ 129 | 130 | def _run_pipeline(parameters=None, configstr=eggs_pipeline_conf): 131 | """Run `caput.scripts.runner run` with given parameters and config. 132 | 133 | Parameters 134 | ---------- 135 | parameters : List[str] 136 | Parameters to pass to the cli, for example `["--profile"]` (see `--help`). 137 | configstr : str 138 | YAML string to use as a config. This function will write it to a file that is then passed to the cli. 139 | 140 | Returns 141 | ------- 142 | result : `click.testing.Result` 143 | Holds the captured result. Try accessing e.g. `result.exit_code`, `result.output`. 
144 | """ 145 | with tempfile.NamedTemporaryFile("w+") as configfile: 146 | configfile.write(configstr) 147 | configfile.flush() 148 | from click.testing import CliRunner 149 | 150 | runner = CliRunner() 151 | if parameters is None: 152 | return runner.invoke(cli, ["run", configfile.name]) 153 | else: 154 | return runner.invoke(cli, ["run", *parameters, configfile.name]) 155 | 156 | return _run_pipeline 157 | 158 | 159 | @pytest.fixture 160 | def get_pipeline(): 161 | """Provides the `get_pipeline` function which returns an initialized pipeline manager.""" 162 | eggs_pipeline_conf = """ 163 | --- 164 | pipeline: 165 | tasks: 166 | - type: tests.conftest.PrintEggs 167 | params: eggs_params 168 | - type: tests.conftest.GetEggs 169 | params: eggs_params 170 | out: egg 171 | - type: tests.conftest.CookEggs 172 | params: cook_params 173 | in: egg 174 | eggs_params: 175 | eggs: ['green', 'duck', 'ostrich'] 176 | cook_params: 177 | style: 'fried' 178 | """ 179 | 180 | def _get_pipeline(configstr=eggs_pipeline_conf): 181 | """Initialize a pipeline manager object. 182 | 183 | Parameters 184 | ---------- 185 | configstr : str 186 | YAML string to use as a config. 187 | 188 | Returns 189 | ------- 190 | manager : Manager 191 | Initialized pipeline manager. 192 | """ 193 | 194 | return Manager.from_yaml_str(configstr) 195 | 196 | return _get_pipeline 197 | 198 | 199 | @pytest.fixture 200 | def h5_file(rm_all_files): 201 | """Provides a file name and removes all files/dirs with the same prefix later.""" 202 | fname = "tmp_test_memh5.h5" 203 | yield fname 204 | rm_all_files(fname) 205 | 206 | 207 | @pytest.fixture 208 | def zarr_file(rm_all_files): 209 | """Provides a directory name and removes all files/dirs with the same prefix later.""" 210 | fname = "tmp_test_memh5.zarr" 211 | yield fname 212 | rm_all_files(fname) 213 | 214 | 215 | @pytest.fixture 216 | def h5_file_distributed(rm_all_files): 217 | """Provides a file name and removes all files/dirs with the same prefix later.""" 218 | fname = "tmp_test_memh5_distributed.h5" 219 | yield fname 220 | if mpiutil.rank == 0: 221 | rm_all_files(fname) 222 | 223 | 224 | @pytest.fixture 225 | def zarr_file_distributed(rm_all_files): 226 | """Provides a directory name and removes all files/dirs with the same prefix later.""" 227 | fname = "tmp_test_memh5.zarr" 228 | yield fname 229 | if mpiutil.rank == 0: 230 | rm_all_files(fname) 231 | 232 | 233 | @pytest.fixture 234 | def rm_all_files(): 235 | """Provides the `rm_all_files` function.""" 236 | 237 | def _rm_all_files(file_name): 238 | """Remove all files and directories starting with `file_name`.""" 239 | file_names = glob.glob(file_name + "*") 240 | for fname in file_names: 241 | fileformats.remove_file_or_dir(fname) 242 | 243 | return _rm_all_files 244 | -------------------------------------------------------------------------------- /tests/test_memh5_parallel.py: -------------------------------------------------------------------------------- 1 | """Unit tests for the parallel features of the memh5 module.""" 2 | 3 | import pytest 4 | from pytest_lazy_fixtures import lf 5 | import numpy as np 6 | import h5py 7 | import zarr 8 | import copy 9 | 10 | from caput import fileformats, memh5, mpiarray, mpiutil 11 | 12 | 13 | comm = mpiutil.world 14 | rank, size = mpiutil.rank, mpiutil.size 15 | 16 | 17 | def test_create_dataset(): 18 | """Test for creating datasets in MemGroup.""" 19 | global_data = np.arange(size * 5 * 10, dtype=np.float32) 20 | local_data = global_data.reshape(size, -1, 10)[rank] 21 | d_array = 
mpiarray.MPIArray.wrap(local_data, axis=0) 22 | d_array_T = d_array.redistribute(axis=1) 23 | 24 | # Check that we must specify in advance if the dataset is distributed 25 | g = memh5.MemGroup() 26 | if comm is not None: 27 | with pytest.raises(RuntimeError): 28 | g.create_dataset("data", data=d_array) 29 | 30 | g = memh5.MemGroup(distributed=True) 31 | 32 | # Create an array from data 33 | g.create_dataset("data", data=d_array, distributed=True) 34 | 35 | # Create an array from data with a different distribution 36 | g.create_dataset("data_T", data=d_array, distributed=True, distributed_axis=1) 37 | 38 | # Create an empty array with a specified shape 39 | g.create_dataset( 40 | "data2", 41 | shape=(size * 5, 10), 42 | dtype=np.float64, 43 | distributed=True, 44 | distributed_axis=1, 45 | ) 46 | assert np.allclose(d_array, g["data"][:]) 47 | assert np.allclose(d_array_T, g["data_T"][:]) 48 | if comm is not None: 49 | assert d_array_T.local_shape == g["data2"].local_shape 50 | 51 | # Test global indexing 52 | assert (g["data"][rank * 5] == local_data[0]).all() 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "compression,compression_opts,chunks", 57 | [ 58 | (None, None, None), 59 | ("bitshuffle", (None, "lz4"), (size // 2 + ((size // 2) == 0), 3)), 60 | ], 61 | ) 62 | @pytest.mark.parametrize( 63 | "test_file,file_open_function,file_format", 64 | [ 65 | (lf("h5_file_distributed"), h5py.File, fileformats.HDF5), 66 | ( 67 | lf("zarr_file_distributed"), 68 | zarr.open_group, 69 | fileformats.Zarr, 70 | ), 71 | ], 72 | ) 73 | def test_io( 74 | test_file, file_open_function, file_format, compression, compression_opts, chunks 75 | ): 76 | """Test for I/O in MemGroup.""" 77 | 78 | # Create distributed memh5 object 79 | g = memh5.MemGroup(distributed=True) 80 | g.attrs["rank"] = rank 81 | 82 | # Create an empty array with a specified shape 83 | pdset = g.create_dataset( 84 | "parallel_data", 85 | shape=(size, 10), 86 | dtype=np.float64, 87 | distributed=True, 88 | distributed_axis=0, 89 | compression=compression, 90 | compression_opts=compression_opts, 91 | chunks=chunks, 92 | ) 93 | pdset[:] = rank 94 | pdset.attrs["const"] = 17 95 | 96 | # Create an empty array with a specified shape 97 | sdset = g.create_dataset("serial_data", shape=(size * 5, 10), dtype=np.float64) 98 | sdset[:] = rank 99 | sdset.attrs["const"] = 18 100 | 101 | # Create nested groups 102 | g.create_group("hello/world") 103 | 104 | # Test round tripping unicode data 105 | g.create_dataset("unicode_data", data=np.array(["hello"])) 106 | 107 | g.to_file( 108 | test_file, 109 | convert_attribute_strings=True, 110 | convert_dataset_strings=True, 111 | file_format=file_format, 112 | ) 113 | 114 | # Test that the HDF5 file has the correct structure 115 | with file_open_function(test_file, "r") as f: 116 | # Test that the file attributes are correct 117 | assert f["parallel_data"].attrs["const"] == 17 118 | 119 | # Test that the parallel dataset has been written correctly 120 | assert (f["parallel_data"][:, 0] == np.arange(size)).all() 121 | assert f["parallel_data"].attrs["const"] == 17 122 | 123 | # Test that the common dataset has been written correctly (i.e. 
by rank=0) 124 | assert (f["serial_data"][:] == 0).all() 125 | assert f["serial_data"].attrs["const"] == 18 126 | 127 | # Check group structure is correct 128 | assert "hello" in f 129 | assert "world" in f["hello"] 130 | 131 | # Check compression/chunks 132 | if chunks is None: 133 | if file_format is fileformats.HDF5: 134 | assert f["parallel_data"].chunks is None 135 | 136 | elif file_format is fileformats.Zarr: 137 | assert f["parallel_data"].chunks == f["parallel_data"].shape 138 | assert f["parallel_data"].compressor is None 139 | else: 140 | assert f["parallel_data"].chunks == chunks 141 | 142 | if file_format is fileformats.Zarr: 143 | assert f["parallel_data"].compressor is not None 144 | 145 | # Test that the read in group has the same structure as the original 146 | g2 = memh5.MemGroup.from_file( 147 | test_file, 148 | distributed=True, 149 | convert_attribute_strings=True, 150 | convert_dataset_strings=True, 151 | file_format=file_format, 152 | ) 153 | 154 | # Check that the parallel data is still the same 155 | assert (g2["parallel_data"][:] == g["parallel_data"][:]).all() 156 | 157 | # Check that the serial data is all zeros (should not be the same as before) 158 | assert (g2["serial_data"][:] == np.zeros_like(sdset[:])).all() 159 | 160 | # Check group structure is correct 161 | assert "hello" in g2 162 | assert "world" in g2["hello"] 163 | 164 | # Check the unicode dataset 165 | assert g2["unicode_data"].dtype.kind == "U" 166 | assert g2["unicode_data"][0] == "hello" 167 | 168 | # Check the attributes 169 | assert g2["parallel_data"].attrs["const"] == 17 170 | assert g2["serial_data"].attrs["const"] == 18 171 | 172 | 173 | @pytest.mark.parametrize( 174 | "test_file,file_open_function,file_format", 175 | [ 176 | (lf("h5_file_distributed"), h5py.File, fileformats.HDF5), 177 | ( 178 | lf("zarr_file_distributed"), 179 | zarr.open_group, 180 | fileformats.Zarr, 181 | ), 182 | ], 183 | ) 184 | def test_misc(test_file, file_open_function, file_format): 185 | """Misc tests for MemDiskGroupDistributed""" 186 | 187 | dg = memh5.MemDiskGroup(distributed=True) 188 | 189 | pdset = dg.create_dataset( 190 | "parallel_data", 191 | shape=(10,), 192 | dtype=np.float64, 193 | distributed=True, 194 | distributed_axis=0, 195 | ) 196 | # pdset[:] = dg._data.comm.rank 197 | pdset[:] = rank 198 | # Test successfully added 199 | assert "parallel_data" in dg 200 | 201 | dg.save(test_file, file_format=file_format) 202 | 203 | dg2 = memh5.MemDiskGroup.from_file( 204 | test_file, distributed=True, file_format=file_format 205 | ) 206 | 207 | # Test successful load 208 | assert "parallel_data" in dg2 209 | assert (dg["parallel_data"][:] == dg2["parallel_data"][:]).all() 210 | 211 | # self.assertRaises(NotImplementedError, dg.to_disk, self.fname) 212 | 213 | # Test refusal to base off a h5py object when distributed 214 | with file_open_function(test_file, "r") as f: 215 | if comm is not None: 216 | with pytest.raises(ValueError): 217 | # MemDiskGroup will guess the file format 218 | memh5.MemDiskGroup(data_group=f, distributed=True) 219 | mpiutil.barrier() 220 | 221 | 222 | def test_redistribute(): 223 | """Test redistribute in BasicCont.""" 224 | 225 | g = memh5.BasicCont(distributed=True) 226 | 227 | # Create an array from data 228 | g.create_dataset("data", shape=(10, 10), distributed=True, distributed_axis=0) 229 | assert g["data"].distributed_axis == 0 230 | g.redistribute(1) 231 | assert g["data"].distributed_axis == 1 232 | 233 | 234 | # Unit test for MemDataset 235 | 236 | 237 | def 
test_dataset_copy(): 238 |     # Check for string types 239 |     x = memh5.MemDatasetDistributed(shape=(4, 5), dtype=np.float32) 240 |     x[:] = 0 241 | 242 |     # Check a deepcopy using .copy 243 |     y = x.copy() 244 |     assert x == y 245 |     y[:] = 1 246 |     # Check that this is in fact a deep copy 247 |     assert x != y 248 | 249 |     # This is a shallow copy 250 |     y = x.copy(shallow=True) 251 |     assert x == y 252 |     y[:] = 1 253 |     assert x == y 254 | 255 |     # Check a deepcopy using copy.deepcopy 256 |     y = copy.deepcopy(x) 257 |     assert x == y 258 |     y[:] = 2 259 |     assert x != y 260 | -------------------------------------------------------------------------------- /caput/truncate.hpp: -------------------------------------------------------------------------------- 1 | // 2**31 + 2**30 will be used to check for overflow 2 | const uint32_t HIGH_BITS = 3221225472; 3 | 4 | // 2**63 + 2**62 will be used to check for overflow 5 | const uint64_t HIGH_BITS_DOUBLE = 13835058055282163712UL; 6 | 7 | // The length of the part in a float that represents the exponent 8 | const int32_t LEN_EXPONENT_FLOAT = 8; 9 | 10 | // The length of the part in a double that represents the exponent 11 | const int64_t LEN_EXPONENT_DOUBLE = 11; 12 | 13 | // Starting bit (offset) of the part in a float that represents the exponent 14 | const int32_t POS_EXPONENT_FLOAT = 23; 15 | 16 | // Starting bit (offset) of the part in a double that represents the exponent 17 | const int64_t POS_EXPONENT_DOUBLE = 52; 18 | 19 | // A mask to apply on the exponent representation of a float, to get rid of the sign part 20 | const int32_t MASK_EXPONENT_W_O_SIGN_FLOAT = 255; 21 | 22 | // A mask to apply on the exponent representation of a double, to get rid of the sign part 23 | const int64_t MASK_EXPONENT_W_O_SIGN_DOUBLE = 2047; 24 | 25 | // A mask to apply on a float to get only the mantissa (2**23 - 1) 26 | const int32_t MASK_MANTISSA_FLOAT = 8388607; 27 | 28 | // A mask to apply on a double to get only the mantissa (2**52 - 1) 29 | const int64_t MASK_MANTISSA_DOUBLE = 4503599627370495L; 30 | 31 | // Implicit 24th bit of the mantissa in a float (2**23) 32 | const int32_t IMPLICIT_BIT_FLOAT = 8388608; 33 | 34 | // Implicit 53rd bit of the mantissa in a double (2**52) 35 | const int64_t IMPLICIT_BIT_DOUBLE = 4503599627370496L; 36 | 37 | // The maximum error we can have for the mantissa in a float (less than 2**30) 38 | const int32_t ERR_MAX_FLOAT = 1073741823; 39 | 40 | // The maximum error we can have for the mantissa in a double (less than 2**62) 41 | const int64_t ERR_MAX_DOUBLE = 4611686018427387903L; 42 | 43 | /** 44 |  * @brief Truncate the precision of *val* by rounding to a multiple of a power of 45 |  * two, keeping error less than or equal to *err*. 46 |  * 47 |  * @warning Undefined results for err < 0 and err > 2**30. 48 |  */ 49 | inline int32_t bit_truncate(int32_t val, int32_t err) { 50 |     // *gran* is the granularity. It is the power of 2 that is *larger than* the 51 |     // maximum error *err*. 52 |     int32_t gran = err; 53 |     gran |= gran >> 1; 54 |     gran |= gran >> 2; 55 |     gran |= gran >> 4; 56 |     gran |= gran >> 8; 57 |     gran |= gran >> 16; 58 |     gran += 1; 59 | 60 |     // Bitmask selects bits to be rounded. 61 |     int32_t bitmask = gran - 1; 62 | 63 |     // Determine if there is a round-up/round-down tie. 64 |     // This operation gets the `gran = 1` case correct (non tie). 65 |     int32_t tie = ((val & bitmask) << 1) == gran; 66 | 67 |     // The actual rounding.
68 | int32_t val_t = (val - (gran >> 1)) | bitmask; 69 | val_t += 1; 70 | // There is a bit of extra bit twiddling for the err == 0. 71 | val_t -= (err == 0); 72 | 73 | // Break any tie by rounding to even. 74 | val_t -= val_t & (tie * gran); 75 | 76 | return val_t; 77 | } 78 | 79 | 80 | /** 81 | * @brief Truncate the precision of *val* by rounding to a multiple of a power of 82 | * two, keeping error less than or equal to *err*. 83 | * 84 | * @warning Undefined results for err < 0 and err > 2**62. 85 | */ 86 | inline int64_t bit_truncate_64(int64_t val, int64_t err) { 87 | // *gran* is the granularity. It is the power of 2 that is *larger than* the 88 | // maximum error *err*. 89 | int64_t gran = err; 90 | gran |= gran >> 1; 91 | gran |= gran >> 2; 92 | gran |= gran >> 4; 93 | gran |= gran >> 8; 94 | gran |= gran >> 16; 95 | gran |= gran >> 32; 96 | gran += 1; 97 | 98 | // Bitmask selects bits to be rounded. 99 | int64_t bitmask = gran - 1; 100 | 101 | // Determine if there is a round-up/round-down tie. 102 | // This operation gets the `gran = 1` case correct (non tie). 103 | int64_t tie = ((val & bitmask) << 1) == gran; 104 | 105 | // The acctual rounding. 106 | int64_t val_t = (val - (gran >> 1)) | bitmask; 107 | val_t += 1; 108 | // There is a bit of extra bit twiddling for the err == 0. 109 | val_t -= (err == 0); 110 | 111 | // Break any tie by rounding to even. 112 | val_t -= val_t & (tie * gran); 113 | 114 | return val_t; 115 | } 116 | 117 | 118 | /** 119 | * @brief Count the number of leading zeros in a binary number. 120 | * Taken from https://stackoverflow.com/a/23857066 121 | */ 122 | inline int32_t count_zeros(int32_t x) { 123 | x = x | (x >> 1); 124 | x = x | (x >> 2); 125 | x = x | (x >> 4); 126 | x = x | (x >> 8); 127 | x = x | (x >> 16); 128 | return __builtin_popcount(~x); 129 | } 130 | 131 | 132 | /** 133 | * @brief Count the number of leading zeros in a binary number. 134 | * Taken from https://stackoverflow.com/a/23857066 135 | */ 136 | inline int64_t count_zeros_64(int64_t x) { 137 | x = x | (x >> 1); 138 | x = x | (x >> 2); 139 | x = x | (x >> 4); 140 | x = x | (x >> 8); 141 | x = x | (x >> 16); 142 | x = x | (x >> 32); 143 | return __builtin_popcountl(~x); 144 | } 145 | 146 | 147 | /** 148 | * @brief Fast power of two float. 149 | * 150 | * Result is undefined for e < -126. 151 | * 152 | * @param e Exponent 153 | * 154 | * @returns The result of 2^e 155 | */ 156 | inline float fast_pow(int8_t e) { 157 | float* out_f; 158 | // Construct float bitwise 159 | uint32_t out_i = ((uint32_t)(127 + e) << 23); 160 | // Cast into float 161 | out_f = (float*)&out_i; 162 | return *out_f; 163 | } 164 | 165 | 166 | /** 167 | * @brief Fast power of two double. 168 | * 169 | * Result is undefined for e < -1022 and e > 1023. 170 | * 171 | * @param e Exponent 172 | * 173 | * @returns The result of 2^e 174 | */ 175 | inline double fast_pow_double(int16_t e) { 176 | double* out_f; 177 | // Construct float bitwise 178 | uint64_t out_i = ((uint64_t)(1023 + e) << 52); 179 | // Cast into float 180 | out_f = (double*)&out_i; 181 | return *out_f; 182 | } 183 | 184 | 185 | /** 186 | * @brief Truncate precision of a floating point number by applying the algorithm of 187 | * `bit_truncate` to the mantissa. 188 | * 189 | * Note that NaN and inf are not explicitly checked for. According to the IEEE spec, it is 190 | * impossible for the truncation to turn an inf into a NaN. However, if the truncation 191 | * happens to remove all of the non-zero bits in the mantissa, a NaN can become inf. 
192 | * 193 | */ 194 | inline float _bit_truncate_float(float val, float err) { 195 | // cast float memory into an int 196 | int32_t* cast_val_ptr = (int32_t*)&val; 197 | // extract the exponent and sign 198 | int32_t val_pre = cast_val_ptr[0] >> POS_EXPONENT_FLOAT; 199 | // strip sign 200 | int32_t val_pow = val_pre & MASK_EXPONENT_W_O_SIGN_FLOAT; 201 | int32_t val_s = val_pre >> LEN_EXPONENT_FLOAT; 202 | // extract mantissa. mask is 2**23 - 1. Add back the implicit 24th bit 203 | int32_t val_man = (cast_val_ptr[0] & MASK_MANTISSA_FLOAT) + IMPLICIT_BIT_FLOAT; 204 | // scale the error to the integer representation of the mantissa 205 | // scale by 2**(23 + 127 - pow) 206 | int32_t int_err = (int32_t)(err * fast_pow(150 - val_pow)); 207 | // make sure hasn't overflowed. if set to 2**30-1, will surely round to 0. 208 | // must keep err < 2**30 for bit_truncate to work 209 | int_err = (int_err & HIGH_BITS) ? ERR_MAX_FLOAT : int_err; 210 | 211 | // truncate 212 | int32_t tr_man = bit_truncate(val_man, int_err); 213 | 214 | // count leading zeros 215 | int32_t z_count = count_zeros(tr_man); 216 | // adjust power after truncation to account for loss of implicit bit 217 | val_pow -= z_count - 8; 218 | // shift mantissa by same amount, remove implicit bit 219 | tr_man = (tr_man << (z_count - 8)) & MASK_MANTISSA_FLOAT; 220 | // round to zero case 221 | val_pow = ((z_count != 32) ? val_pow : 0); 222 | // restore sign and exponent 223 | int32_t tr_val = tr_man | ((val_pow | (val_s << 8)) << 23); 224 | // cast back to float 225 | float* tr_val_ptr = (float*)&tr_val; 226 | 227 | return tr_val_ptr[0]; 228 | } 229 | 230 | 231 | /** 232 | * @brief Truncate precision of a double floating point number by applying the algorithm of 233 | * `bit_truncate` to the mantissa. 234 | * 235 | * Note that NaN and inf are not explicitly checked for. According to the IEEE spec, it is 236 | * impossible for the truncation to turn an inf into a NaN. However, if the truncation 237 | * happens to remove all of the non-zero bits in the mantissa, a NaN can become inf. 238 | * 239 | */ 240 | inline double _bit_truncate_double(double val, double err) { 241 | // Step 1: Extract the sign, exponent and mantissa: 242 | // ------------------------------------------------ 243 | // cast float memory into an int 244 | int64_t* cast_val_ptr = (int64_t*)&val; 245 | // extract the exponent and sign 246 | int64_t val_pre = cast_val_ptr[0] >> POS_EXPONENT_DOUBLE; 247 | // strip sign 248 | int64_t val_pow = val_pre & MASK_EXPONENT_W_O_SIGN_DOUBLE; 249 | int64_t val_s = val_pre >> LEN_EXPONENT_DOUBLE; 250 | // extract mantissa. mask is 2**52 - 1. Add back the implicit 53rd bit 251 | int64_t val_man = (cast_val_ptr[0] & MASK_MANTISSA_DOUBLE) + IMPLICIT_BIT_DOUBLE; 252 | 253 | // Step 2: Scale the error to the integer representation of the mantissa: 254 | // ---------------------------------------------------------------------- 255 | // scale by 2**(52 + 1023 - pow) 256 | int64_t int_err = (int64_t)(err * fast_pow_double(1075 - val_pow)); 257 | // make sure hasn't overflowed. if set to 2**62-1, will surely round to 0. 258 | // must keep err < 2**62 for bit_truncate_double to work 259 | int_err = (int_err & HIGH_BITS_DOUBLE) ? 
ERR_MAX_DOUBLE : int_err; 260 | 261 | // Step 3: Truncate the mantissa: 262 | // ------------------------------ 263 | int64_t tr_man = bit_truncate_64(val_man, int_err); 264 | 265 | // Step 4: Put it back together: 266 | // ----------------------------- 267 | // count leading zeros 268 | int64_t z_count = count_zeros_64(tr_man); 269 | // adjust power after truncation to account for loss of implicit bit 270 | val_pow -= z_count - 11; 271 | // shift mantissa by same amount, remove implicit bit 272 | tr_man = (tr_man << (z_count - 11)) & MASK_MANTISSA_DOUBLE; 273 | // round to zero case 274 | val_pow = ((z_count != 64) ? val_pow : 0); 275 | // restore sign and exponent 276 | int64_t tr_val = tr_man | ((val_pow | (val_s << 11)) << 52); 277 | // cast back to double 278 | double* tr_val_ptr = (double*)&tr_val; 279 | 280 | return tr_val_ptr[0]; 281 | } 282 | -------------------------------------------------------------------------------- /caput/truncate.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | 3 | """Routines for truncating data to a specified precision.""" 4 | 5 | cimport cython 6 | from cython.parallel import prange 7 | 8 | import numpy as np 9 | cimport numpy as cnp 10 | 11 | from libc.math cimport fabs 12 | 13 | cdef extern from "truncate.hpp": 14 | inline int bit_truncate(int val, int err) nogil 15 | 16 | cdef extern from "truncate.hpp": 17 | inline long bit_truncate_64(long val, long err) nogil 18 | 19 | cdef extern from "truncate.hpp": 20 | inline float _bit_truncate_float(float val, float err) nogil 21 | 22 | 23 | cdef extern from "truncate.hpp": 24 | inline double _bit_truncate_double(double val, double err) nogil 25 | 26 | ctypedef double complex complex128 27 | 28 | cdef extern from "complex.h" nogil: 29 | double cabs(complex128) 30 | 31 | 32 | def bit_truncate_int(int val, int err): 33 | """ 34 | Bit truncation of a 32bit integer. 35 | 36 | Truncate the precision of `val` by rounding to a multiple of a power of 37 | two, keeping error less than or equal to `err`. 38 | 39 | Made available for testing. 40 | """ 41 | return bit_truncate(val, err) 42 | 43 | def bit_truncate_long(long val, long err): 44 | """ 45 | Bit truncation of a 64bit integer. 46 | 47 | Truncate the precision of `val` by rounding to a multiple of a power of 48 | two, keeping error less than or equal to `err`. 49 | 50 | Made available for testing. 51 | """ 52 | return bit_truncate_64(val, err) 53 | 54 | 55 | def bit_truncate_float(float val, float err): 56 | """Truncate using a fixed error. 57 | 58 | Parameters 59 | ---------- 60 | val 61 | The value to truncate. 62 | err 63 | The absolute precision to allow. 64 | 65 | Returns 66 | ------- 67 | val 68 | The truncated value. 69 | 70 | Raises 71 | ------ 72 | ValueError 73 | If `err` is a NaN. 74 | """ 75 | if err != err: 76 | raise ValueError(f"Error {err} is invalid.") 77 | 78 | return _bit_truncate_float(val, err) 79 | 80 | 81 | def bit_truncate_double(double val, double err): 82 | """Truncate using a fixed error. 83 | 84 | Parameters 85 | ---------- 86 | val 87 | The value to truncate. 88 | err 89 | The absolute precision to allow. 90 | 91 | Returns 92 | ------- 93 | val 94 | The truncated value. 95 | 96 | Raises 97 | ------ 98 | ValueError 99 | If `err` is a NaN. 
100 | """ 101 | if err != err: 102 | raise ValueError(f"Error {err} is invalid.") 103 | 104 | return _bit_truncate_double(val, err) 105 | 106 | 107 | def bit_truncate_weights(val, inv_var, fallback): 108 | if val.dtype == np.float32 and inv_var.dtype == np.float32: 109 | return bit_truncate_weights_float(val, inv_var, fallback) 110 | if val.dtype == np.float64 and inv_var.dtype == np.float64: 111 | return bit_truncate_weights_double(val, inv_var, fallback) 112 | else: 113 | raise RuntimeError(f"Can't truncate data of type {val.dtype}/{inv_var.dtype} " 114 | f"(expected float32 or float64).") 115 | 116 | 117 | @cython.boundscheck(False) 118 | @cython.wraparound(False) 119 | def bit_truncate_weights_float(float[:] val, float[:] inv_var, float fallback): 120 | """Truncate using a set of inverse variance weights. 121 | 122 | Giving the error as an inverse variance is particularly useful for data analysis. 123 | 124 | N.B. non-contiguous arrays are supported in order to allow real and imaginary parts 125 | of numpy arrays to be truncated without making a copy. 126 | 127 | Parameters 128 | ---------- 129 | val 130 | The array of values to truncate the precision of. These values are modified in place. 131 | inv_var 132 | The acceptable precision expressed as an inverse variance. 133 | fallback 134 | A relative precision to use for cases where the inv_var is zero. 135 | 136 | Returns 137 | ------- 138 | val 139 | The modified array. This shares the same underlying memory as the input. 140 | """ 141 | cdef Py_ssize_t n = val.shape[0] 142 | cdef Py_ssize_t i = 0 143 | 144 | if val.ndim != 1: 145 | raise ValueError("Input array must be 1-d.") 146 | if inv_var.shape[0] != n: 147 | raise ValueError( 148 | f"Weight and value arrays must have same shape ({inv_var.shape[0]} != {n})" 149 | ) 150 | 151 | for i in prange(n, nogil=True): 152 | if inv_var[i] != 0: 153 | val[i] = _bit_truncate_float(val[i], 1.0 / inv_var[i]**0.5) 154 | else: 155 | val[i] = _bit_truncate_float(val[i], fallback * val[i]) 156 | 157 | return np.asarray(val) 158 | 159 | @cython.boundscheck(False) 160 | @cython.wraparound(False) 161 | def bit_truncate_weights_double(double[:] val, double[:] inv_var, double fallback): 162 | """Truncate array of doubles using a set of inverse variance weights. 163 | 164 | Giving the error as an inverse variance is particularly useful for data analysis. 165 | 166 | N.B. non-contiguous arrays are supported in order to allow real and imaginary parts 167 | of numpy arrays to be truncated without making a copy. 168 | 169 | Parameters 170 | ---------- 171 | val 172 | The array of values to truncate the precision of. These values are modified in place. 173 | inv_var 174 | The acceptable precision expressed as an inverse variance. 175 | fallback 176 | A relative precision to use for cases where the inv_var is zero. 177 | 178 | Returns 179 | ------- 180 | val 181 | The modified array. This shares the same underlying memory as the input. 
182 |     """ 183 |     cdef Py_ssize_t n = val.shape[0] 184 |     cdef Py_ssize_t i = 0 185 | 186 |     if val.ndim != 1: 187 |         raise ValueError("Input array must be 1-d.") 188 |     if inv_var.shape[0] != n: 189 |         raise ValueError( 190 |             f"Weight and value arrays must have same shape ({inv_var.shape[0]} != {n})" 191 |         ) 192 | 193 |     for i in prange(n, nogil=True): 194 |         if inv_var[i] != 0: 195 |             val[i] = _bit_truncate_double(val[i], 1.0 / inv_var[i]**0.5) 196 |         else: 197 |             val[i] = _bit_truncate_double(val[i], fallback * val[i]) 198 | 199 |     return np.asarray(val) 200 | 201 | def bit_truncate_relative(val, prec): 202 |     if val.dtype == np.float32: 203 |         return bit_truncate_relative_float(val, prec) 204 |     if val.dtype == np.float64: 205 |         return bit_truncate_relative_double(val, prec) 206 |     else: 207 |         raise RuntimeError(f"Can't truncate data of type {val.dtype} (expected float32 or float64).") 208 | 209 | 210 | @cython.boundscheck(False) 211 | @cython.wraparound(False) 212 | def bit_truncate_relative_float(float[:] val, float prec): 213 |     """Truncate using a relative tolerance. 214 | 215 |     N.B. non-contiguous arrays are supported in order to allow real and imaginary parts 216 |     of numpy arrays to be truncated without making a copy. 217 | 218 |     Parameters 219 |     ---------- 220 |     val 221 |         The array of values to truncate the precision of. These values are modified in place. 222 |     prec 223 |         The fractional precision required. 224 | 225 |     Returns 226 |     ------- 227 |     val 228 |         The modified array. This shares the same underlying memory as the input. 229 |     """ 230 |     cdef Py_ssize_t n = val.shape[0] 231 |     cdef Py_ssize_t i = 0 232 | 233 |     for i in prange(n, nogil=True): 234 |         val[i] = _bit_truncate_float(val[i], fabs(prec * val[i])) 235 | 236 |     return np.asarray(val) 237 | 238 | 239 | @cython.boundscheck(False) 240 | @cython.wraparound(False) 241 | def bit_truncate_relative_double(cnp.float64_t[:] val, cnp.float64_t prec): 242 |     """Truncate doubles using a relative tolerance. 243 | 244 |     N.B. non-contiguous arrays are supported in order to allow real and imaginary parts 245 |     of numpy arrays to be truncated without making a copy. 246 | 247 |     Parameters 248 |     ---------- 249 |     val 250 |         The array of double values to truncate the precision of. These values are modified in place. 251 |     prec 252 |         The fractional precision required. 253 | 254 |     Returns 255 |     ------- 256 |     val 257 |         The modified array. This shares the same underlying memory as the input. 258 |     """ 259 |     cdef Py_ssize_t n = val.shape[0] 260 |     cdef Py_ssize_t i = 0 261 | 262 |     for i in prange(n, nogil=True): 263 |         val[i] = _bit_truncate_double(val[i], fabs(prec * val[i])) 264 | 265 |     return np.asarray(val, dtype=np.float64) 266 | 267 | 268 | @cython.boundscheck(False) 269 | @cython.wraparound(False) 270 | def bit_truncate_max_complex(complex128[:, :] val, float prec, float prec_max_row): 271 |     """Truncate using a relative precision per element and per the maximum of the last dimension. 272 | 273 |     This scheme allows elements to be truncated based on their own value and a 274 |     measure of their relative importance compared to other elements. In practice the 275 |     per element absolute precision for an element `val[i, j]` is given by `max(prec * 276 |     val[i, j], prec_max_row * val[i].max())` 277 | 278 |     Parameters 279 |     ---------- 280 |     val 281 |         The array of values to truncate the precision of. These values are modified in place. 282 |     prec 283 |         The fractional precision of each element.
284 | prec_max_row 285 | The precision to use relative to the maximum of the of each row. 286 | 287 | Returns 288 | ------- 289 | val 290 | The modified array. This shares the same underlying memory as the input. 291 | """ 292 | cdef Py_ssize_t n = val.shape[0] 293 | cdef Py_ssize_t m = val.shape[1] 294 | cdef Py_ssize_t i = 0, j = 0 295 | cdef float abs_prec 296 | cdef double vr, vi 297 | cdef double max_abs 298 | cdef double abs2 299 | 300 | for i in prange(n, nogil=True): 301 | 302 | max_abs = 0.0 303 | 304 | # Find the largest abs**2 value in the row, store in max_abs, but note that it is the *square* 305 | for j in range(m): 306 | vr = val[i, j].real 307 | vi = val[i, j].imag 308 | abs2 = vr * vr + vi * vi 309 | 310 | if abs2 > max_abs: 311 | max_abs = abs2 312 | 313 | max_abs = max_abs**0.5 314 | 315 | for j in range(m): 316 | # Get the precision to apply 317 | abs_prec = max(cabs(val[i, j]) * prec, prec_max_row * max_abs) 318 | 319 | vr = val[i, j].real 320 | vi = val[i, j].imag 321 | val[i, j].real = _bit_truncate_float(vr, abs_prec) 322 | val[i, j].imag = _bit_truncate_float(vi, abs_prec) 323 | 324 | return np.asarray(val) 325 | -------------------------------------------------------------------------------- /caput/fileformats.py: -------------------------------------------------------------------------------- 1 | """Interface for file formats supported by caput: HDF5 and Zarr.""" 2 | 3 | import logging 4 | import os 5 | import shutil 6 | 7 | import h5py 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | try: 12 | import zarr 13 | except ImportError: 14 | logger.debug("zarr support disabled. Install zarr to change this.") 15 | zarr_available = False 16 | else: 17 | zarr_available = True 18 | 19 | try: 20 | import numcodecs 21 | from bitshuffle.h5 import H5_COMPRESS_LZ4, H5FILTER 22 | except ModuleNotFoundError: 23 | logger.debug( 24 | "Install with 'compression' extra_require to use bitshuffle/numcodecs compression filters." 25 | ) 26 | compression_enabled = False 27 | H5FILTER, H5_COMPRESS_LZ4 = None, None 28 | else: 29 | compression_enabled = True 30 | 31 | if compression_enabled: 32 | # hdf5 parallel compression is broken before 1.13.1 33 | if h5py.version.hdf5_version_tuple < (1, 13, 1): 34 | import warnings 35 | 36 | warnings.warn( 37 | "HDF5 parallel compression has flaws prior to version 1.13.1, and can fail " 38 | f"unexpectedly. The current linked version is {h5py.version.hdf5_version_tuple}.", 39 | RuntimeWarning, 40 | ) 41 | 42 | 43 | class FileFormat: 44 | """Abstract base class for file formats supported by this module.""" 45 | 46 | module = None 47 | 48 | @staticmethod 49 | def open(*args, **vargs): 50 | """Open a file. 51 | 52 | Not implemented in base class 53 | """ 54 | raise NotImplementedError 55 | 56 | @staticmethod 57 | def compression_kwargs(compression=None, compression_opts=None, compressor=None): 58 | """Sort compression arguments in a format expected by file format module. 59 | 60 | Parameters 61 | ---------- 62 | compression : str or int 63 | Name or identifier of HDF5 compression filter. 64 | compression_opts 65 | See HDF5 documentation for compression filters. 66 | compressor : `numcodecs` compressor 67 | As required by `zarr`. 68 | 69 | Returns 70 | ------- 71 | dict 72 | Compression arguments as required by the file format module. 
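Examples
--------
A rough sketch of how the two subclasses translate the same request (illustrative only; the exact objects returned depend on which optional compression packages are installed):

>>> HDF5.compression_kwargs(compression="gzip", compression_opts=4)  # doctest: +SKIP
{'compression': 'gzip', 'compression_opts': 4}
>>> Zarr.compression_kwargs(compression="gzip", compression_opts=4)  # doctest: +SKIP
{'compressor': GZip(level=4)}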
73 | """ 74 | if compressor and (compression or compression_opts): 75 | raise ValueError( 76 | f"Found more than one kind of compression args: compression ({compression}, {compression_opts}) " 77 | f"and compressor {compressor}." 78 | ) 79 | 80 | 81 | class HDF5(FileFormat): 82 | """Interface for using HDF5 file format from caput.""" 83 | 84 | module = h5py 85 | 86 | @staticmethod 87 | def open(*args, **kwargs): 88 | """Open an HDF5 file using h5py.""" 89 | return h5py.File(*args, **kwargs) 90 | 91 | @staticmethod 92 | def compression_kwargs(compression=None, compression_opts=None, compressor=None): 93 | """Format compression arguments for h5py API.""" 94 | super(HDF5, HDF5).compression_kwargs(compression, compression_opts, compressor) 95 | if compressor: 96 | raise NotImplementedError 97 | 98 | if compression == "bitshuffle" and not compression_enabled: 99 | raise ValueError( 100 | "Install with 'compression' extra_require to use bitshuffle/numcodecs compression filters." 101 | ) 102 | 103 | if compression_enabled and compression in ( 104 | "bitshuffle", 105 | H5FILTER, 106 | str(H5FILTER), 107 | ): 108 | compression = H5FILTER 109 | try: 110 | blocksize, c = compression_opts 111 | except ValueError as e: 112 | raise ValueError( 113 | f"Failed to interpret compression_opts: {e}\ncompression_opts: {compression_opts}." 114 | ) from e 115 | if blocksize is None: 116 | blocksize = 0 117 | if c in (str(H5_COMPRESS_LZ4), "lz4"): 118 | c = H5_COMPRESS_LZ4 119 | compression_opts = (blocksize, c) 120 | 121 | if compression is not None: 122 | return {"compression": compression, "compression_opts": compression_opts} 123 | return {} 124 | 125 | 126 | class Zarr(FileFormat): 127 | """Interface for using zarr file format from caput.""" 128 | 129 | if zarr_available: 130 | module = zarr 131 | else: 132 | module = None 133 | 134 | @staticmethod 135 | def open(*args, **kwargs): 136 | """Open a zarr file.""" 137 | if not zarr_available: 138 | raise RuntimeError("Can't open zarr file. Please install zarr.") 139 | return zarr.open_group(*args, **kwargs) 140 | 141 | @staticmethod 142 | def compression_kwargs(compression=None, compression_opts=None, compressor=None): 143 | """Format compression arguments for zarr API.""" 144 | super(Zarr, Zarr).compression_kwargs(compression, compression_opts, compressor) 145 | if compression: 146 | if not compression_enabled: 147 | raise ValueError( 148 | "Install with 'compression' extra_require to use bitshuffle/numcodecs compression filters." 149 | ) 150 | if compression == "gzip": 151 | return {"compressor": numcodecs.gzip.GZip(level=compression_opts)} 152 | if compression in (H5FILTER, str(H5FILTER), "bitshuffle"): 153 | try: 154 | blocksize, c = compression_opts 155 | except ValueError as e: 156 | raise ValueError( 157 | f"Failed to interpret compression_opts: {e}\ncompression_opts: {compression_opts}" 158 | ) from e 159 | if c in (H5_COMPRESS_LZ4, str(H5_COMPRESS_LZ4)): 160 | c = "lz4" 161 | if blocksize is None: 162 | blocksize = 0 163 | return { 164 | "compressor": numcodecs.Blosc( 165 | c, 166 | shuffle=numcodecs.blosc.BITSHUFFLE, 167 | blocksize=int(blocksize) if blocksize is not None else None, 168 | ) 169 | } 170 | raise ValueError(f"Compression filter not supported in zarr: {compression}") 171 | 172 | return {"compressor": compressor} 173 | 174 | 175 | class ZarrProcessSynchronizer: 176 | """A context manager for Zarr's ProcessSynchronizer that removes the lock files when done. 
177 | 178 | If an MPI communicator is supplied, only rank 0 will attempt to remove files. 179 | 180 | Parameters 181 | ---------- 182 | name : str 183 | Name of the lockfile directory. 184 | comm : 185 | MPI communicator (optional). 186 | """ 187 | 188 | def __init__(self, name, comm=None): 189 | if not zarr_available: 190 | raise RuntimeError( 191 | "Can't use zarr process synchronizer. Please install zarr." 192 | ) 193 | self.name = name 194 | self._comm = comm 195 | 196 | def __enter__(self): 197 | return zarr.ProcessSynchronizer(self.name) 198 | 199 | def __exit__(self, exc_type, exc_val, exc_tb): 200 | if self._comm is None or self._comm.rank == 0: 201 | remove_file_or_dir(self.name) 202 | 203 | 204 | def remove_file_or_dir(name: str): 205 | """Remove the file or directory with the given name. 206 | 207 | Parameters 208 | ---------- 209 | name : str 210 | File or directory name to remove. 211 | """ 212 | if os.path.isdir(name): 213 | try: 214 | shutil.rmtree(name) 215 | except FileNotFoundError: 216 | pass 217 | else: 218 | try: 219 | os.remove(name) 220 | except FileNotFoundError: 221 | pass 222 | 223 | 224 | def guess_file_format(name, default=HDF5): 225 | """Guess the file format from the file name. 226 | 227 | Parameters 228 | ---------- 229 | name : str or pathlib.Path 230 | File name. 231 | default : FileFormat or None 232 | Fallback value if format can't be guessed. Default `fileformats.HDF5`. 233 | 234 | Returns 235 | ------- 236 | format : `FileFormat` 237 | File format guessed. 238 | """ 239 | import pathlib 240 | 241 | if isinstance(name, pathlib.Path): 242 | name = str(name) 243 | 244 | if name.endswith(".zarr.zip"): 245 | return Zarr 246 | if name.endswith(".zarr") or pathlib.Path(name).is_dir(): 247 | return Zarr 248 | if name.endswith(".h5") or name.endswith(".hdf5"): 249 | return HDF5 250 | return default 251 | 252 | 253 | def check_file_format(filename, file_format, data): 254 | """Compare file format with guess from filename and data. Return concluded format. 255 | 256 | Parameters 257 | ---------- 258 | filename : str 259 | File name. 260 | file_format : FileFormat or None 261 | File format. None if it should be guessed. 262 | data : any 263 | If this is an h5py.Group or zarr.Group, it will be used to guess or confirm the file format. 264 | 265 | Returns 266 | ------- 267 | file_format : HDF5 or Zarr 268 | File format. 269 | """ 270 | # check value of file_format 271 | if file_format not in (None, HDF5, Zarr): 272 | raise ValueError( 273 | f"Unexpected value for file_format: {file_format} " 274 | f"(expected caput.fileformats.HDF5, caput.fileformats.Zarr or None)." 275 | ) 276 | 277 | # guess file format from the data 278 | if isinstance(data, h5py.Group): 279 | file_format_guess_output = HDF5 280 | elif zarr_available and isinstance(data, zarr.Group): 281 | file_format_guess_output = Zarr 282 | else: 283 | file_format_guess_output = None 284 | 285 | # guess file format from the file name 286 | file_format_guess_name = guess_file_format(filename, None) 287 | 288 | # make sure guesses don't mismatch and decide on the format 289 | if ( 290 | file_format_guess_output 291 | and file_format_guess_name 292 | and file_format_guess_name != file_format_guess_output 293 | ): 294 | raise ValueError( 295 | f"file_format ({file_format}) and filename ({filename}) don't seem to match."
296 | ) 297 | file_format_guess = ( 298 | file_format_guess_output if file_format_guess_output else file_format_guess_name 299 | ) 300 | if file_format is None: 301 | file_format = file_format_guess 302 | elif file_format != file_format_guess: 303 | raise ValueError( 304 | f"Value of file_format ({file_format}) doesn't match filename ({filename}) " 305 | f"and type of data ({type(data).__name__})." 306 | ) 307 | 308 | return file_format 309 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # caput documentation build configuration file, created by 4 | # sphinx-quickstart on Thu Oct 10 12:52:16 2013. 5 | # 6 | # This file is execfile()d with the current directory set to its containing dir. 7 | # 8 | # Note that not all possible configuration values are present in this 9 | # autogenerated file. 10 | # 11 | # All configuration values have a default; values that are commented out 12 | # serve to show the default. 13 | 14 | 15 | import os, re 16 | 17 | # Check if we are on readthedocs 18 | on_rtd = os.environ.get("READTHEDOCS", None) == "True" 19 | 20 | import sys 21 | import caput 22 | 23 | # Mock up modules missing on readthedocs. 24 | from unittest.mock import Mock as MagicMock 25 | 26 | 27 | class Mock(MagicMock): 28 | @classmethod 29 | def __getattr__(cls, name): 30 | return Mock() 31 | 32 | 33 | # Do not mock up mpi4py. This is an "extra", and docs build without it. 34 | # MOCK_MODULES = ['h5py', 'mpi4py'] 35 | MOCK_MODULES = ["h5py"] 36 | if on_rtd: 37 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 38 | 39 | 40 | # If extensions (or modules to document with autodoc) are in another directory, 41 | # add these directories to sys.path here. If the directory is relative to the 42 | # documentation root, use os.path.abspath to make it absolute, like shown here. 43 | # sys.path.insert(0, os.path.abspath('../')) 44 | 45 | # -- General configuration ----------------------------------------------------- 46 | 47 | # If your documentation needs a minimal Sphinx version, state it here. 48 | # needs_sphinx = '1.0' 49 | 50 | # Add any Sphinx extension module names here, as strings. They can be extensions 51 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 52 | # 'numpydoc' does not ship with sphinx. To get it use `pip install numpydoc`.
53 | extensions = [ 54 | "sphinx.ext.autodoc", 55 | "sphinx.ext.autosummary", 56 | "sphinx.ext.intersphinx", 57 | "sphinx.ext.mathjax", 58 | "sphinx.ext.viewcode", 59 | "sphinx.ext.napoleon", 60 | "sphinx.ext.autosectionlabel", 61 | ] 62 | 63 | napoleon_google_docstring = False 64 | napoleon_numpy_docstring = True 65 | napoleon_include_init_with_doc = False 66 | napoleon_include_private_with_doc = False 67 | napoleon_include_special_with_doc = False 68 | napoleon_use_admonition_for_examples = False 69 | napoleon_use_admonition_for_notes = False 70 | napoleon_use_admonition_for_references = False 71 | napoleon_use_ivar = False 72 | napoleon_use_param = True 73 | napoleon_use_rtype = True 74 | napoleon_type_aliases = None 75 | napoleon_attr_annotations = True 76 | 77 | autoclass_content = "both" # include both class docstring and __init__ 78 | autodoc_default_options = { 79 | # Make sure that any autodoc declarations show the right members 80 | "members": True, 81 | "show-inheritance": True, 82 | } 83 | 84 | autosummary_generate = True # Make _autosummary files and include them 85 | autosummary_imported_members = False 86 | 87 | intersphinx_mapping = {"h5py": ("http://docs.h5py.org/en/latest/", None)} 88 | intersphinx_cache_limit = 1 89 | 90 | 91 | # This autodoc preprocessor replaces tokens like :class:`h5py.Dataset` with 92 | # :class:`h5py.Dataset <Dataset>`, as this is how h5py intersphinx domains are 93 | # set up. 94 | def process_docstring(app, what, name, obj, options, lines): 95 | for ii in range(len(lines)): 96 | lines[ii] = re.sub( 97 | r":([a-z]*):`h5py\.([a-zA-Z_\.]*)`", r":\1:`h5py.\2 <\2>`", lines[ii] 98 | ) 99 | 100 | 101 | def setup(app): 102 | app.connect("autodoc-process-docstring", process_docstring) 103 | 104 | 105 | # Add any paths that contain templates here, relative to this directory. 106 | templates_path = ["_templates"] 107 | 108 | # The suffix of source filenames. 109 | source_suffix = ".rst" 110 | 111 | # The encoding of source files. 112 | # source_encoding = 'utf-8-sig' 113 | 114 | # The master toctree document. 115 | master_doc = "index" 116 | 117 | # General information about the project. 118 | project = "caput" 119 | copyright = "2013-2016, Kiyoshi Masui and J. Richard Shaw" 120 | 121 | # The version info for the project you're documenting, acts as replacement for 122 | # |version| and |release|, also used in various other places throughout the 123 | # built documents. 124 | # 125 | # The short X.Y version. 126 | version = caput.__version__ 127 | # The full version, including alpha/beta/rc tags. 128 | release = caput.__version__ 129 | 130 | # The language for content autogenerated by Sphinx. Refer to documentation 131 | # for a list of supported languages. 132 | # language = None 133 | 134 | # There are two options for replacing |today|: either, you set today to some 135 | # non-false value, then it is used: 136 | # today = '' 137 | # Else, today_fmt is used as the format for a strftime call. 138 | # today_fmt = '%B %d, %Y' 139 | 140 | # List of patterns, relative to source directory, that match files and 141 | # directories to ignore when looking for source files. 142 | exclude_patterns = ["_build"] 143 | 144 | # The reST default role (used for this markup: `text`) to use for all documents. 145 | # default_role = None 146 | 147 | # If true, '()' will be appended to :func: etc. cross-reference text. 148 | # add_function_parentheses = True 149 | 150 | # If true, the current module name will be prepended to all description 151 | # unit titles (such as .. function::).
152 | # add_module_names = True 153 | 154 | # If true, sectionauthor and moduleauthor directives will be shown in the 155 | # output. They are ignored by default. 156 | # show_authors = False 157 | 158 | # The name of the Pygments (syntax highlighting) style to use. 159 | pygments_style = "sphinx" 160 | 161 | # A list of ignored prefixes for module index sorting. 162 | # modindex_common_prefix = [] 163 | 164 | 165 | # -- Options for HTML output --------------------------------------------------- 166 | 167 | # The theme to use for HTML and HTML Help pages. See the documentation for 168 | # a list of builtin themes. 169 | # html_theme = 'cloud' 170 | # html_theme_path = ["cloud"] 171 | # html_theme = 'default' 172 | 173 | html_theme = "sphinx_rtd_theme" 174 | 175 | # Theme options are theme-specific and customize the look and feel of a theme 176 | # further. For a list of options available for each theme, see the 177 | # documentation. 178 | # html_theme_options = {} 179 | 180 | # Add any paths that contain custom themes here, relative to this directory. 181 | # html_theme_path = [] 182 | 183 | # The name for this set of Sphinx documents. If None, it defaults to 184 | # " v documentation". 185 | # html_title = None 186 | 187 | # A shorter title for the navigation bar. Default is the same as html_title. 188 | # html_short_title = None 189 | 190 | # The name of an image file (relative to this directory) to place at the top 191 | # of the sidebar. 192 | # html_logo = None 193 | 194 | # The name of an image file (within the static path) to use as favicon of the 195 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 196 | # pixels large. 197 | # html_favicon = None 198 | 199 | # Add any paths that contain custom static files (such as style sheets) here, 200 | # relative to this directory. They are copied after the builtin static files, 201 | # so a file named "default.css" will overwrite the builtin "default.css". 202 | # html_static_path = ["_static"] 203 | 204 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 205 | # using the given strftime format. 206 | # html_last_updated_fmt = '%b %d, %Y' 207 | 208 | # If true, SmartyPants will be used to convert quotes and dashes to 209 | # typographically correct entities. 210 | # html_use_smartypants = True 211 | 212 | # Custom sidebar templates, maps document names to template names. 213 | # html_sidebars = {} 214 | 215 | # Additional templates that should be rendered to pages, maps page names to 216 | # template names. 217 | # html_additional_pages = {} 218 | 219 | # If false, no module index is generated. 220 | # html_domain_indices = True 221 | 222 | # If false, no index is generated. 223 | # html_use_index = True 224 | 225 | # If true, the index is split into individual pages for each letter. 226 | # html_split_index = False 227 | 228 | # If true, links to the reST sources are added to the pages. 229 | # html_show_sourcelink = True 230 | 231 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 232 | # html_show_sphinx = True 233 | 234 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 235 | # html_show_copyright = True 236 | 237 | # If true, an OpenSearch description file will be output, and all pages will 238 | # contain a tag referring to it. The value of this option must be the 239 | # base URL from which the finished HTML is served. 240 | # html_use_opensearch = '' 241 | 242 | # This is the file name suffix for HTML files (e.g. 
".xhtml"). 243 | # html_file_suffix = None 244 | 245 | # Output file base name for HTML help builder. 246 | htmlhelp_basename = "caputdoc" 247 | 248 | 249 | # -- Options for LaTeX output -------------------------------------------------- 250 | 251 | latex_elements = { 252 | # The paper size ('letterpaper' or 'a4paper'). 253 | #'papersize': 'letterpaper', 254 | # The font size ('10pt', '11pt' or '12pt'). 255 | #'pointsize': '10pt', 256 | # Additional stuff for the LaTeX preamble. 257 | #'preamble': '', 258 | } 259 | 260 | # Grouping the document tree into LaTeX files. List of tuples 261 | # (source start file, target name, title, author, documentclass [howto/manual]). 262 | latex_documents = [ 263 | ("index", "caput.tex", "caput Documentation", "Kiyoshi Masui", "manual") 264 | ] 265 | 266 | # The name of an image file (relative to this directory) to place at the top of 267 | # the title page. 268 | # latex_logo = None 269 | 270 | # For "manual" documents, if this is true, then toplevel headings are parts, 271 | # not chapters. 272 | # latex_use_parts = False 273 | 274 | # If true, show page references after internal links. 275 | # latex_show_pagerefs = False 276 | 277 | # If true, show URL addresses after external links. 278 | # latex_show_urls = False 279 | 280 | # Documents to append as an appendix to all manuals. 281 | # latex_appendices = [] 282 | 283 | # If false, no module index is generated. 284 | # latex_domain_indices = True 285 | 286 | 287 | # -- Options for manual page output -------------------------------------------- 288 | 289 | # One entry per manual page. List of tuples 290 | # (source start file, name, description, authors, manual section). 291 | man_pages = [("index", "caput", "caput Documentation", ["Kiyoshi Masui"], 1)] 292 | 293 | # If true, show URL addresses after external links. 294 | # man_show_urls = False 295 | 296 | 297 | # -- Options for Texinfo output ------------------------------------------------ 298 | 299 | # Grouping the document tree into Texinfo files. List of tuples 300 | # (source start file, target name, title, author, 301 | # dir menu entry, description, category) 302 | texinfo_documents = [ 303 | ( 304 | "index", 305 | "caput", 306 | "caput Documentation", 307 | "Kiyoshi Masui", 308 | "caput", 309 | "Cluster Astronomical Python Utilities", 310 | "Miscellaneous", 311 | ) 312 | ] 313 | 314 | # Documents to append as an appendix to all manuals. 315 | # texinfo_appendices = [] 316 | 317 | # If false, no module index is generated. 318 | # texinfo_domain_indices = True 319 | 320 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 321 | # texinfo_show_urls = 'footnote' 322 | -------------------------------------------------------------------------------- /caput/fftw.py: -------------------------------------------------------------------------------- 1 | """Fast FFT implementation using FFTW. 2 | 3 | This module adds some minor abstraction to use pyfftw in a way 4 | which seems to be faster than using the `pyfftw.builders` interface, 5 | and uses the same api as `scipy.fft` and `numpy.fft`. 6 | 7 | Only forward and reverse complex->complex transforms 8 | are currently supported. 9 | 10 | Examples 11 | -------- 12 | The core of this module is the :class:`FFT`, which essentially just 13 | abstracts the :class:`pyfftw:FFTW` in the simplest way. 
14 | 15 | >>> import numpy as np 16 | >>> from caput import fftw 17 | >>> 18 | >>> shape = (24, 50) 19 | >>> x = np.random.rand(*shape) + 1j * np.random.rand(*shape) 20 | >>> 21 | >>> fftobj = fftw.FFT(x.shape, x.dtype, axes=-1) 22 | >>> 23 | >>> X = fftobj.fft(x) 24 | >>> xi = fftobj.ifft(X) 25 | >>> 26 | >>> np.allclose(x, xi) 27 | True 28 | 29 | The direct API can also be used, although it is slower when doing repeated 30 | transforms for arrays of the same shape and type because a new :class:`FFT` 31 | has to be created each time. 32 | 33 | References 34 | ---------- 35 | .. https://pyfftw.readthedocs.io 36 | .. http://www.fftw.org 37 | 38 | Classes 39 | ======= 40 | - :py:class:`FFT` 41 | 42 | Functions 43 | ========= 44 | - :py:meth:`fft` 45 | - :py:meth:`ifft` 46 | - :py:meth:`fftconvolve` 47 | - :py:meth:`fftwindow` 48 | """ 49 | 50 | from __future__ import annotations 51 | 52 | # NOTE: Due to a bug in pyfftw, it needs to be imported before 53 | # numpy in order to avoid some sort of namespace collision. 54 | # If you run into a RuntimeError when trying to use the `FFT` 55 | # class, make sure that your environment imports `pyfftw` 56 | # before `numpy`. Hopefully this will be fixed soon. 57 | try: 58 | import pyfftw 59 | except ImportError as exc: 60 | raise ImportError( 61 | "`pyfftw` is not installed. Install `pyfftw` via `caput[fftw]`." 62 | ) from exc 63 | 64 | import numpy as np 65 | 66 | from caput import mpiutil 67 | 68 | 69 | class FFT: 70 | """Faster FFTs with FFTW.""" 71 | 72 | def __init__( 73 | self, 74 | shape: tuple, 75 | dtype: type, 76 | axes: None | int | tuple = None, 77 | forward: bool = True, 78 | backward: bool = True, 79 | ): 80 | """Create FFTW objects for repeat use. 81 | 82 | This implementation is most efficient when used to repeatedly 83 | apply ffts to arrays with the same shape and dtype, because a 84 | single, highly optimised pathway can be used with a single 85 | initialisation. 86 | 87 | Even for a single use this will typically 88 | be faster than the `scipy.fft` or `numpy.fft` implementations, 89 | especially when multiple cores can be used. 90 | 91 | Parameters 92 | ---------- 93 | shape 94 | The shape of the arrays to initialise for 95 | dtype 96 | Datatype to create a pathway for. At the moment, only 97 | complex -> complex or real -> real are supported. The 98 | `pyfftw` implementation of the real -> real backward 99 | transform will destroy the input array 100 | axes 101 | Axes over which to apply the fft. Default is all axes. 102 | forward 103 | If true, initialise the forward fft. Default is True. 104 | backward 105 | If true, initialise the backward fft. Default is True. 
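Examples
--------
A short sketch of re-using one plan for repeated transforms of same-shaped arrays (assumes `pyfftw` is installed; the speed-up over recreating the plan each time will vary by machine):

>>> import numpy as np
>>> x = np.random.standard_normal((16, 32)) + 0j
>>> fftobj = FFT(x.shape, x.dtype, axes=-1)
>>> for _ in range(4):
...     X = fftobj.fft(x)
>>> np.allclose(fftobj.ifft(X), x)
True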
106 | """ 107 | if not np.issubdtype(dtype, np.complexfloating): 108 | raise TypeError("Only complex->complex transforms are currently supported.") 109 | 110 | self._nsimd = pyfftw.simd_alignment 111 | ncpu = mpiutil.cpu_count() 112 | flags = ("FFTW_MEASURE",) 113 | 114 | if axes is None: 115 | axes = tuple(range(len(shape))) 116 | elif isinstance(axes, int): 117 | axes = (axes,) 118 | 119 | # Store fft params 120 | self._params = { 121 | "ncpu": ncpu, 122 | "simd_alignment": self._nsimd, 123 | "shape": shape, 124 | "dtype": dtype, 125 | "axes": axes, 126 | "flags": flags, 127 | } 128 | 129 | fftargs = { 130 | "input_array": pyfftw.empty_aligned(shape, dtype, n=self._nsimd), 131 | "output_array": pyfftw.empty_aligned(shape, dtype, n=self._nsimd), 132 | "axes": axes, 133 | "flags": flags, 134 | "threads": ncpu, 135 | } 136 | 137 | if forward: 138 | self._fft = pyfftw.FFTW(direction="FFTW_FORWARD", **fftargs) 139 | 140 | if backward: 141 | self._ifft = pyfftw.FFTW(direction="FFTW_BACKWARD", **fftargs) 142 | 143 | @property 144 | def params(self): 145 | """Display the parameters of this FFT. 146 | 147 | Returns 148 | ------- 149 | params: dict 150 | ncpu, simd alignment, shape, dtype, axes, and flags 151 | used by this FFT object. 152 | """ 153 | return self._params 154 | 155 | def fft(self, x): 156 | """Perform a forward FFT. 157 | 158 | Parameters 159 | ---------- 160 | x : np.ndarray 161 | Input array, must match the dtype specified 162 | at creation 163 | 164 | Returns 165 | ------- 166 | fft : np.ndarray 167 | DFT of the input array over specified axes 168 | """ 169 | try: 170 | return self._fft( 171 | input_array=x, 172 | output_array=pyfftw.empty_aligned(x.shape, x.dtype, n=self._nsimd), 173 | ) 174 | except AttributeError: 175 | raise RuntimeError("Forward fft not initialised.") 176 | 177 | def ifft(self, x): 178 | """Perform a backward FFT. 179 | 180 | When performing the backward real -> real IFFT, 181 | the input array is destroyed. 182 | 183 | Parameters 184 | ---------- 185 | x : np.ndarray 186 | Input array, must match the dtype specified 187 | at creation 188 | 189 | Returns 190 | ------- 191 | fft : np.ndarray 192 | IDFT of the input array over specified axes 193 | """ 194 | try: 195 | return self._ifft( 196 | input_array=x, 197 | output_array=pyfftw.empty_aligned(x.shape, x.dtype, n=self._nsimd), 198 | ) 199 | except AttributeError: 200 | raise RuntimeError("Backward fft not initialised.") 201 | 202 | def fftconvolve(self, in1, in2): 203 | """Convolve two arrays by multiplying in the Fourier domain. 204 | 205 | `in1` and `in2` must have the same dtype, and both the forward 206 | and backward FFTs must be initialised. 207 | 208 | Parameters 209 | ---------- 210 | in1 : np.ndarray 211 | First input array 212 | in2 : np.ndarray 213 | Second input array to by convolved with `x`. Must have 214 | the same dtype as `x`. 215 | 216 | Returns 217 | ------- 218 | out : np.ndarray 219 | Discrete convolution of `in1` and `in2` 220 | """ 221 | X1 = self.fft(in1) 222 | X2 = self.fft(in2) 223 | 224 | X1 *= X2 225 | 226 | return self.ifft(X1) 227 | 228 | def fftwindow(self, x, window): 229 | """Apply a window function in Fourier space. 230 | 231 | The only difference between this and `fftconvolve` is that 232 | this assumes that `window` is _already_ in the Fourier domain, 233 | and `window` can be real or complex when `x` is complex. 234 | 235 | Both the forward and backward FFTs must be initialised. 
236 | 237 | Parameters 238 | ---------- 239 | x : np.ndarray 240 | Input array 241 | window : np.ndarray 242 | Window to be applied in the Fourier domain. 243 | 244 | Returns 245 | ------- 246 | out : np.ndarray 247 | Input array `x` with `window` applied in the Fourier domain. 248 | """ 249 | X = self.fft(x) 250 | X *= window 251 | 252 | return self.ifft(X) 253 | 254 | 255 | def fft(x, axes=None): 256 | """Perform a forward discrete Fourer Transform. 257 | 258 | If the fourier transform is to be applied repeatedly to 259 | arrays with the same size and dtype, it is faster to use 260 | the `FFT` class directly to avoid creating new `FFT` objects. 261 | 262 | Parameters 263 | ---------- 264 | x : np.ndarray 265 | Input array, real or complex 266 | axes : None | int | tuple 267 | Axes over which to take the fft. Default is all axes. 268 | 269 | Returns 270 | ------- 271 | fft : np.ndarray 272 | DFT of the input array over specified axes 273 | """ 274 | fftobj = FFT(x.shape, x.dtype, axes, forward=True, backward=False) 275 | 276 | return fftobj.fft(x) 277 | 278 | 279 | def ifft(x, axes=None): 280 | """Perform an inverse discrete Fourier Transform. 281 | 282 | If the fourier transform is to be applied repeatedly to 283 | arrays with the same size and dtype, it is faster to use 284 | the `FFT` class directly to avoid creating new `FFT` objects. 285 | 286 | Parameters 287 | ---------- 288 | x : np.ndarray 289 | Input array, real or complex 290 | axes : None | int | tuple 291 | Axes over which to take the ifft. Default is all axes. 292 | 293 | Returns 294 | ------- 295 | fft : np.ndarray 296 | IDFT of the input array over specified axes 297 | """ 298 | fftobj = FFT(x.shape, x.dtype, axes, forward=False, backward=True) 299 | 300 | return fftobj.ifft(x) 301 | 302 | 303 | def fftconvolve(in1, in2, axes=None): 304 | """Convolve two arrays by multiplying in the Fourier domain. 305 | 306 | `in1` and `in2` must have the same dtype. 307 | 308 | If the convolution is to be applied repeatedly to 309 | arrays with the same size and dtype, it is faster to use 310 | the `FFT` class directly to avoid creating new `FFT` objects. 311 | 312 | Parameters 313 | ---------- 314 | in1 : np.ndarray 315 | First input array 316 | in2 : np.ndarray 317 | Second input array to by convolved with `x`. Must have 318 | the same dtype as `x`. 319 | axes : None | int | tuple 320 | Axes over which to do the convolution. Default is all axes. 321 | 322 | Returns 323 | ------- 324 | out : np.ndarray 325 | Discrete convolution of `in1` and `in2` 326 | """ 327 | fftobj = FFT(in1.shape, in1.dtype, axes, forward=True, backward=True) 328 | 329 | return fftobj.fftconvolve(in1, in2) 330 | 331 | 332 | def fftwindow(x, window, axes): 333 | """Apply a window function in Fourier space. 334 | 335 | The only difference between this and `fftconvolve` is that 336 | this assumes that `window` is _already_ in the Fourier domain, 337 | and `window` can be real or complex when `x` is complex. 338 | 339 | If the window is to be applied repeatedly to 340 | arrays with the same size and dtype, it is faster to use 341 | the `FFT` class directly to avoid creating new `FFT` objects. 342 | 343 | Parameters 344 | ---------- 345 | x : np.ndarray 346 | Input array 347 | window : np.ndarray 348 | Window to be applied in the Fourier domain. 349 | axes : None | int | tuple 350 | Axes over which to apply the window. Default is all axes. 351 | 352 | Returns 353 | ------- 354 | out : np.ndarray 355 | Input array `x` with `window` applied in the Fourier domain. 
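Examples
--------
A minimal sketch using a flat (all-ones) window, which should return the input unchanged up to floating point error (assumes `pyfftw` is installed):

>>> import numpy as np
>>> x = np.random.standard_normal(64) + 0j
>>> np.allclose(fftwindow(x, np.ones(64), axes=-1), x)
True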
356 | """ 357 | fftobj = FFT(x.shape, x.dtype, axes, forward=True, backward=True) 358 | 359 | return fftobj.fftwindow(x, window) 360 | -------------------------------------------------------------------------------- /caput/misc.py: -------------------------------------------------------------------------------- 1 | """A set of miscellaneous routines that don't really fit anywhere more specific.""" 2 | 3 | import importlib 4 | import os 5 | from pathlib import Path 6 | from typing import TYPE_CHECKING, Optional, overload 7 | 8 | import h5py 9 | import numpy as np 10 | 11 | if TYPE_CHECKING: 12 | from mpi4py import MPI 13 | 14 | 15 | def vectorize(**base_kwargs): 16 | """Improved vectorization decorator. 17 | 18 | Unlike the :class:`np.vectorize` decorator this version works on methods in 19 | addition to functions. It also gives an actual scalar value back for any 20 | scalar input, instead of returning a 0-dimension array. 21 | 22 | Parameters 23 | ---------- 24 | **base_kwargs 25 | Any keyword arguments accepted by :class:`np.vectorize` 26 | 27 | Returns 28 | ------- 29 | vectorized_function : func 30 | """ 31 | 32 | class _vectorize_desc: 33 | # See 34 | # http://www.ianbicking.org/blog/2008/10/decorators-and-descriptors.html 35 | # for a description of this pattern 36 | 37 | def __init__(self, func): 38 | # Save a reference to the function and set various properties so the 39 | # docstrings etc. get passed through 40 | self.func = func 41 | self.__doc__ = func.__doc__ 42 | self.__name__ = func.__name__ 43 | self.__module__ = func.__module__ 44 | 45 | def __call__(self, *args, **kwargs): 46 | # This gets called whenever the wrapped function is invoked 47 | arr = np.vectorize(self.func, **base_kwargs)(*args, **kwargs) 48 | 49 | if not arr.shape: 50 | arr = arr.item() 51 | 52 | return arr 53 | 54 | def __get__(self, obj, objtype=None): 55 | # As a descriptor, this gets called whenever this is used to wrap a 56 | # function, and simply binds it to the instance 57 | 58 | if obj is None: 59 | return self 60 | 61 | new_func = self.func.__get__(obj, objtype) 62 | return self.__class__(new_func) 63 | 64 | return _vectorize_desc 65 | 66 | 67 | def scalarize(dtype=np.float64): 68 | """Handle scalars and other iterables being passed to numpy requiring code. 69 | 70 | Parameters 71 | ---------- 72 | dtype : np.dtype, optional 73 | The output datatype. Used only to set the return type of zero-length arrays. 74 | 75 | Returns 76 | ------- 77 | vectorized_function : func 78 | """ 79 | 80 | class _scalarize_desc: 81 | # See 82 | # http://www.ianbicking.org/blog/2008/10/decorators-and-descriptors.html 83 | # for a description of this pattern 84 | 85 | def __init__(self, func): 86 | # Save a reference to the function and set various properties so the 87 | # docstrings etc. 
get passed through 88 | self.func = func 89 | self.__doc__ = func.__doc__ 90 | self.__name__ = func.__name__ 91 | self.__module__ = func.__module__ 92 | 93 | def __call__(self, *args, **kwargs): 94 | # This gets called whenever the wrapped function is invoked 95 | 96 | args, scalar, empty = zip(*[self._make_array(a) for a in args]) 97 | 98 | if all(empty): 99 | return np.array([], dtype=dtype) 100 | 101 | ret = self.func(*args, **kwargs) 102 | 103 | if all(scalar): 104 | ret = ret[0] 105 | 106 | return ret 107 | 108 | @staticmethod 109 | def _make_array(x): 110 | # Change iterables to arrays and scalars into length-1 arrays 111 | 112 | from skyfield import timelib 113 | 114 | # Special handling for the slightly awkward skyfield types 115 | if isinstance(x, timelib.Time): 116 | if isinstance(x.tt, np.ndarray): 117 | scalar = False 118 | else: 119 | scalar = True 120 | x = x.ts.tt_jd(np.array([x.tt])) 121 | 122 | elif isinstance(x, np.ndarray): 123 | scalar = False 124 | 125 | elif isinstance(x, list | tuple): 126 | x = np.array(x) 127 | scalar = False 128 | 129 | else: 130 | x = np.array([x]) 131 | scalar = True 132 | 133 | return (x, scalar, len(x) == 0) 134 | 135 | def __get__(self, obj, objtype=None): 136 | # As a descriptor, this gets called whenever this is used to wrap a 137 | # function, and simply binds it to the instance 138 | 139 | if obj is None: 140 | return self 141 | 142 | new_func = self.func.__get__(obj, objtype) 143 | return self.__class__(new_func) 144 | 145 | return _scalarize_desc 146 | 147 | 148 | def listize(**_): 149 | """Make functions that already work with `np.ndarray` or scalars accept lists. 150 | 151 | Also works with tuples. 152 | 153 | Returns 154 | ------- 155 | listized_function : func 156 | """ 157 | 158 | class _listize_desc: 159 | def __init__(self, func): 160 | # Save a reference to the function and set various properties so the 161 | # docstrings etc. get passed through 162 | self.func = func 163 | self.__doc__ = func.__doc__ 164 | self.__name__ = func.__name__ 165 | self.__module__ = func.__module__ 166 | 167 | def __call__(self, *args, **kwargs): 168 | # This gets called whenever the wrapped function is invoked 169 | 170 | new_args = [] 171 | for arg in args: 172 | if isinstance(arg, list | tuple): 173 | arg = np.array(arg) 174 | new_args.append(arg) 175 | 176 | return self.func(*new_args, **kwargs) 177 | 178 | def __get__(self, obj, objtype=None): 179 | # As a descriptor, this gets called whenever this is used to wrap a 180 | # function, and simply binds it to the instance 181 | 182 | if obj is None: 183 | return self 184 | 185 | new_func = self.func.__get__(obj, objtype) 186 | return self.__class__(new_func) 187 | 188 | return _listize_desc 189 | 190 | 191 | @overload 192 | def open_h5py_mpi( 193 | f: str | Path | h5py.File, 194 | mode: str, 195 | use_mpi: bool = True, 196 | comm: Optional["MPI.Comm"] = None, 197 | ) -> h5py.File: ... 198 | 199 | 200 | @overload 201 | def open_h5py_mpi( 202 | f: h5py.Group, mode: str, use_mpi: bool = True, comm: Optional["MPI.Comm"] = None 203 | ) -> h5py.Group: ... 204 | 205 | 206 | def open_h5py_mpi(f, mode, use_mpi=True, comm=None): 207 | """Ensure that we have an h5py File object. 208 | 209 | Opens with MPI-IO if possible. 210 | 211 | The returned file handle is annotated with two attributes: `.is_mpi` 212 | which says whether the file was opened as an MPI file and `.opened` which 213 | says whether it was opened in this call. 
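For example (a minimal sketch; the file name is illustrative and MPI-IO is only used if h5py was built with MPI support)::

    fh = open_h5py_mpi("example.h5", "r")
    try:
        datasets = list(fh)  # iterate the root group
    finally:
        if fh.opened:
            fh.close()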
214 | 215 | Parameters 216 | ---------- 217 | f : string, h5py.File or h5py.Group 218 | Filename to open, or already open file object. If already open this 219 | is just returned as is. 220 | mode : string 221 | Mode to open file in. 222 | use_mpi : bool, optional 223 | Whether to use MPI-IO or not (default True) 224 | comm : mpi4py.Comm, optional 225 | MPI communicator to use. Uses `COMM_WORLD` if not set. 226 | 227 | Returns 228 | ------- 229 | fh : h5py.File 230 | File handle for h5py.File, with two extra attributes `.is_mpi` and 231 | `.opened`. 232 | """ 233 | import h5py 234 | 235 | has_mpi = h5py.get_config().mpi 236 | 237 | if isinstance(f, str): 238 | # Open using MPI-IO if we can 239 | if has_mpi and use_mpi: 240 | from mpi4py import MPI 241 | 242 | comm = comm if comm is not None else MPI.COMM_WORLD 243 | fh = h5py.File(f, mode, libver="latest", driver="mpio", comm=comm) 244 | else: 245 | fh = h5py.File(f, mode, libver="latest") 246 | fh.opened = True 247 | elif isinstance(f, h5py.File | h5py.Group): 248 | fh = f 249 | fh.opened = False 250 | else: 251 | raise ValueError( 252 | f"Can't write to {f} (Expected a h5py.File, h5py.Group or str filename)." 253 | ) 254 | 255 | fh.is_mpi = fh.file.driver == "mpio" 256 | 257 | return fh 258 | 259 | 260 | class lock_file: 261 | """Manage a lock file around a file creation operation. 262 | 263 | Parameters 264 | ---------- 265 | filename : str 266 | Final name for the file. 267 | preserve : bool, optional 268 | Keep the temporary file in the event of failure. 269 | comm : MPI.COMM, optional 270 | If present only rank=0 will create/remove the lock file and move the 271 | file. 272 | 273 | Returns 274 | ------- 275 | tmp_name : str 276 | File name to use in the locked block. 277 | 278 | Examples 279 | -------- 280 | >>> from . import memh5 281 | >>> container = memh5.BasicCont() 282 | >>> with lock_file('file_to_create.h5') as fname: 283 | ... container.save(fname) 284 | ... 285 | """ 286 | 287 | def __init__(self, name, preserve=False, comm=None): 288 | if comm is not None and not hasattr(comm, "rank"): 289 | raise ValueError("comm argument does not seem to be an MPI communicator.") 290 | 291 | self.name = name 292 | # If comm not specified, set internal rank0 marker to True, 293 | # so that rank>0 tasks can open their own files 294 | self.rank0 = True if comm is None else comm.rank == 0 295 | self.preserve = preserve 296 | 297 | def __enter__(self): 298 | if self.rank0: 299 | with open(self.lockfile, "w+") as fh: 300 | fh.write("") 301 | 302 | return self.tmpfile 303 | 304 | def __exit__(self, exc_type, exc_val, exc_tb): 305 | if self.rank0: 306 | # Check if exception was raised and delete the temp file if needed 307 | if exc_type is not None: 308 | if not self.preserve: 309 | os.remove(self.tmpfile) 310 | # Otherwise things were successful and we should move the file over 311 | else: 312 | os.rename(self.tmpfile, self.name) 313 | 314 | # Finally remove the lock file 315 | os.remove(self.lockfile) 316 | 317 | return False 318 | 319 | @property 320 | def tmpfile(self): 321 | """Full path to the lockfile (without file extension).""" 322 | base, fname = os.path.split(self.name) 323 | return os.path.join(base, "." + fname) 324 | 325 | @property 326 | def lockfile(self): 327 | """Full path to the lockfile (with file extension).""" 328 | return self.tmpfile + ".lock" 329 | 330 | 331 | # TODO: remove this. This was to support a patching of this routine to support Python 2 332 | # that used to exist in here. 
This will be removed when all other repos are changed to 333 | # use the version from `inspect` 334 | def getfullargspec(*args, **kwargs): 335 | """See `inspect.getfullargspec`. 336 | 337 | This is a Python 2 patch that will be removed. 338 | """ 339 | import inspect 340 | import warnings 341 | 342 | warnings.warn( 343 | "This patch to support Python 2 is no longer needed and will be removed.", 344 | DeprecationWarning, 345 | ) 346 | 347 | return inspect.getfullargspec(*args, **kwargs) 348 | 349 | 350 | def import_class(class_path): 351 | """Import class dynamically from a string. 352 | 353 | Parameters 354 | ---------- 355 | class_path : str 356 | Fully qualified path to the class. If only a single component, look up in the 357 | globals. 358 | 359 | Returns 360 | ------- 361 | class : class object 362 | The class we want to load. 363 | """ 364 | path_split = class_path.split(".") 365 | module_path = ".".join(path_split[:-1]) 366 | class_name = path_split[-1] 367 | if module_path: 368 | m = importlib.import_module(module_path) 369 | task_cls = getattr(m, class_name) 370 | else: 371 | task_cls = globals()[class_name] 372 | return task_cls 373 | -------------------------------------------------------------------------------- /tests/test_moving_weighted_median.py: -------------------------------------------------------------------------------- 1 | """Unit tests for moving weighted average function.""" 2 | 3 | import time 4 | import unittest 5 | 6 | import numpy as np 7 | 8 | from caput.weighted_median import weighted_median, moving_weighted_median 9 | 10 | 11 | def py_weighted_median(data, weights): 12 | """Flattens the given arrays and calculates a weighted median from that.""" 13 | data = np.reshape(data, np.prod(np.shape(data))) 14 | weights = np.reshape(weights, np.prod(np.shape(weights))) 15 | 16 | data, weights = np.array(data).squeeze(), np.array(weights).squeeze() 17 | 18 | # remove values with 0-weights 19 | choice = weights != 0 20 | data = data[choice] 21 | weights = weights[choice] 22 | 23 | s_data, s_weights = map(np.array, zip(*sorted(zip(data, weights)))) 24 | midpoint = 0.5 * sum(s_weights) 25 | if any(weights > midpoint): 26 | w_median = (data[weights == np.max(weights)])[0] 27 | else: 28 | cs_weights = np.cumsum(s_weights) 29 | try: 30 | idx_upper = ( 31 | np.intersect1d( 32 | np.where(cs_weights > midpoint)[0], 33 | np.where(cs_weights != 0)[0], 34 | True, 35 | )[0] 36 | + 1 37 | ) 38 | except IndexError: 39 | idx_upper = len(data) 40 | # skip zero-weights 41 | try: 42 | while weights[idx_upper] == 0: 43 | idx_upper += 1 44 | except IndexError: 45 | pass 46 | try: 47 | idx_lower = ( 48 | np.intersect1d( 49 | np.where(cs_weights[-1] - cs_weights > midpoint)[0], 50 | np.where(cs_weights != 0)[0], 51 | True, 52 | )[-1] 53 | + 1 54 | ) 55 | except IndexError: 56 | idx_lower = 0 57 | # skip zero-weights 58 | try: 59 | while weights[idx_lower] == 0: 60 | idx_lower -= 1 61 | except IndexError: 62 | pass 63 | if idx_upper == len(data) and idx_lower == -1: 64 | # All weights are 0. 
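# (The compiled implementation is expected to do the same; see test_zero_weights below.)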
65 | return 0 66 | w_median = np.mean(s_data[idx_lower:idx_upper]) 67 | return w_median 68 | 69 | 70 | def py_mwm_1d(values, weights, size): 71 | """Moving weighted median (one-dimensional).""" 72 | medians = [] 73 | 74 | # slide a window of size over the value array and get weighted median inside window 75 | for i in range(len(values)): 76 | # size is bigger than value array 77 | if i + size // 2 >= len(values) and i - size // 2 < 0: 78 | medians.append(py_weighted_median(values, weights)) 79 | 80 | # window is sliding into the value array 81 | elif i - size // 2 < 0: 82 | medians.append( 83 | py_weighted_median( 84 | values[: i + size // 2 + 1], weights[: i + size // 2 + 1] 85 | ) 86 | ) 87 | 88 | # window is sliding over the end of the value array 89 | elif i + size // 2 >= len(values): 90 | medians.append( 91 | py_weighted_median(values[i - size // 2 :], weights[i - size // 2 :]) 92 | ) 93 | 94 | # the normal case: window is inside the value array 95 | else: 96 | medians.append( 97 | py_weighted_median( 98 | values[i - size // 2 : i + size // 2 + 1], 99 | weights[i - size // 2 : i + size // 2 + 1], 100 | ) 101 | ) 102 | 103 | return medians 104 | 105 | 106 | def py_mwm_nd(values, weights, size): 107 | """Moving weighted median (n-dimensional).""" 108 | values = np.asarray(values) 109 | weights = np.asarray(weights) 110 | medians = np.ndarray(values.shape) 111 | 112 | # window radius around the index we want to get a median for 113 | r = np.floor_divide(size, 2).astype(int) 114 | 115 | # iterate over n-dim array 116 | for index, _ in np.ndenumerate(values): 117 | # get the edge indides of the window 118 | lbound = np.subtract(index, r, dtype=int) 119 | hbound = np.add(index, r + 1, dtype=int) 120 | 121 | # make sure they are inside the array 122 | lbound = np.maximum(lbound, 0) 123 | hbound = np.minimum(hbound, np.shape(values)) 124 | 125 | window = tuple(slice(i, j, 1) for (i, j) in zip(lbound, hbound)) 126 | medians[index] = py_weighted_median( 127 | values[window].flatten(), weights[window].flatten() 128 | ) 129 | return medians 130 | 131 | 132 | class TestMWM(unittest.TestCase): 133 | def test_the_test(self): 134 | mwm = py_mwm_1d( 135 | values=[1, 2, 3, 4, 5, 6, 7, 8], weights=[1, 2, 3, 4, 5, 6, 7, 8], size=3 136 | ) 137 | assert mwm == [2, 2.5, 3, 4, 5, 6, 7, 8] 138 | 139 | assert ( 140 | py_weighted_median([1, 2, 3, 4, 5, 6, 7, 8], [1, 0, 0, 0, 0, 0, 0, 1]) 141 | == 4.5 142 | ) 143 | 144 | assert py_weighted_median([1, 3, 3, 7], [0, 7, 7, 0]) == 3 145 | 146 | def test_the_nd_test(self): 147 | # 2D 148 | values = [[1, 2, 3], [2, 3, 1], [3, 1, 2]] 149 | mwm = py_mwm_nd(values, values, (3, 3)) 150 | np.testing.assert_array_equal( 151 | [[2.0, 2.5, 3.0], [2.5, 2.5, 2.5], [3.0, 2.5, 2.0]], mwm 152 | ) 153 | 154 | def test_wm(self): 155 | # As we return the same type as the input, the split median will get rounded down 156 | assert weighted_median( 157 | [1, 2, 3, 4, 5, 6, 7, 8], [1, 0, 0, 0, 0, 0, 0, 1] 158 | ) == int(4.5) 159 | 160 | assert ( 161 | weighted_median( 162 | [1, 2, 3, 4, 5, 6, 7, 8], [1, 0, 0, 0, 0, 0, 0, 1], method="lower" 163 | ) 164 | == 1 165 | ) 166 | 167 | assert ( 168 | weighted_median( 169 | [1, 2, 3, 4, 5, 6, 7, 8], [1, 0, 0, 0, 0, 0, 0, 1], method="higher" 170 | ) 171 | == 8 172 | ) 173 | 174 | values = np.asarray([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float64) 175 | weights = np.asarray([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float64) 176 | np.testing.assert_array_equal( 177 | py_weighted_median(values, weights), weighted_median(values, weights) 178 | ) 179 | 180 | 
values = np.asarray([[1, 2, 3], [2, 3, 1], [3, 1, 2]], dtype=np.float64) 181 | np.testing.assert_array_equal( 182 | py_weighted_median(values, values), weighted_median(values, values) 183 | ) 184 | 185 | values = np.asarray([0.1, 0, 7.7, 0, 9.999, 42, 0, 1, 9, 1], dtype=np.float64) 186 | weights = np.asarray([0, 7.5, 0.33, 23.23, 0, 4, 7, 8, 9, 0], dtype=np.float64) 187 | np.testing.assert_array_equal( 188 | py_weighted_median(values, weights), weighted_median(values, weights) 189 | ) 190 | 191 | values = [0, 0, 7, 0, 9] 192 | weights = [0, 7, 1, 23, 0] 193 | np.testing.assert_array_equal( 194 | py_weighted_median(values, weights), weighted_median(values, weights) 195 | ) 196 | 197 | values = [0, 7, 0, 9, 42] 198 | weights = [7, 1, 23, 0, 4] 199 | np.testing.assert_array_equal( 200 | py_weighted_median(values, weights), weighted_median(values, weights) 201 | ) 202 | 203 | values = [7, 0, 9, 42, 0] 204 | weights = [1, 23, 0, 4, 7] 205 | np.testing.assert_array_equal( 206 | py_weighted_median(values, weights), weighted_median(values, weights) 207 | ) 208 | 209 | values = [0, 9, 42, 0, 1] 210 | weights = [23, 0, 4, 7, 8] 211 | np.testing.assert_array_equal( 212 | py_weighted_median(values, weights), weighted_median(values, weights) 213 | ) 214 | 215 | values = [0, 4, 6, 7, 8, 8] 216 | weights = [7, 6, 2, 0, 8, 7] 217 | np.testing.assert_array_equal( 218 | py_weighted_median(values, weights), weighted_median(values, weights) 219 | ) 220 | 221 | def test_weighted_median_methods(self): 222 | # Note this test 223 | values = np.array([9, 2, 5, 5, 2, 9]) 224 | weights = np.array([3, 0, 5, 0, 8, 0]) 225 | np.testing.assert_equal(weighted_median(values, weights, method="lower"), 2) 226 | np.testing.assert_equal(weighted_median(values, weights, method="higher"), 5) 227 | 228 | def test_1d_mwm_int(self): 229 | values = [0, 0, 7, 0, 9, 42, 0, 1, 9, 1] 230 | weights = [0, 7, 1, 23, 0, 4, 7, 8, 9, 0] 231 | np.testing.assert_array_almost_equal( 232 | py_mwm_1d(values, weights, 5), moving_weighted_median(values, weights, 5) 233 | ) 234 | 235 | def test_1d_mwm(self): 236 | values = [0.1, 0, 7.7, 0, 9.999, 42, 0, 1, 9, 1] 237 | weights = [0, 7.5, 0.33, 23.23, 0, 4, 7, 8, 9, 0] 238 | np.testing.assert_array_equal( 239 | py_mwm_1d(values, weights, 5), moving_weighted_median(values, weights, 5) 240 | ) 241 | 242 | # These two are just for measuring performance: 243 | def test_1d_mwm_big(self): 244 | N = 100 245 | values = np.random.random_sample(N) 246 | weights = np.random.random_sample(N) 247 | t0 = time.time() 248 | res_py = py_mwm_1d(values, weights, 5) 249 | t_py = time.time() - t0 250 | t0 = time.time() 251 | res_cython = moving_weighted_median(values, weights, 5) 252 | t_cython = time.time() - t0 253 | np.testing.assert_array_equal(res_py, res_cython) 254 | print( 255 | "1D moving weighted median with {} elements took {}s / {}s".format( 256 | N, t_py, t_cython 257 | ) 258 | ) 259 | 260 | def test_2d_mwm_big(self): 261 | N = 100 262 | M = 100 263 | N_w = 5 264 | values = np.random.random_sample((N, M)) 265 | weights = np.random.random_sample((N, M)) 266 | t0 = time.time() 267 | moving_weighted_median(values, weights, (N_w, N_w)) 268 | t_cython = time.time() - t0 269 | print( 270 | "2D moving {0}x{0} weighted median with {1}x{1} elements took {2}s".format( 271 | N_w, N, t_cython 272 | ) 273 | ) 274 | 275 | def test_zero_weights(self): 276 | values = [1, 1, 1] 277 | weights = [0, 0, 0] 278 | np.testing.assert_array_equal( 279 | [0, 0, 0], moving_weighted_median(values, weights, 1) 280 | ) 281 | 282 | def 
test_2d_mwm_small(self): 283 | values = [[9, 2], [5, 5], [2, 9]] 284 | weights = [[3, 0], [5, 0], [8, 0]] 285 | np.testing.assert_array_equal( 286 | py_mwm_nd(values, weights, 3), 287 | moving_weighted_median(values, weights, (3, 3)), 288 | ) 289 | 290 | values = [ 291 | [8.0, 6.0, 1.0, 5.0], 292 | [6.0, 6.0, 8.0, 10.0], 293 | [8.0, 4.0, 4.0, 7.0], 294 | [10.0, 1.0, 7.0, 8.0], 295 | ] 296 | weights = [ 297 | [8.0, 6.0, 1.0, 5.0], 298 | [6.0, 6.0, 8.0, 10.0], 299 | [8.0, 4.0, 4.0, 7.0], 300 | [10.0, 1.0, 7.0, 8.0], 301 | ] 302 | np.testing.assert_array_equal( 303 | py_mwm_nd(values, weights, 3), 304 | moving_weighted_median(values, weights, (3, 3)), 305 | ) 306 | values = [[9, 4], [2, 5]] 307 | weights = [[4, 8], [4, 5]] 308 | np.testing.assert_array_equal( 309 | py_mwm_nd(values, weights, 3), 310 | moving_weighted_median(values, weights, (3, 3)), 311 | ) 312 | 313 | def test_2d_mwm_small_large_window(self): 314 | # Try a window that is much larger than the input array. 315 | # This has caused crashes in older versions 316 | 317 | values = [ 318 | [8.0, 6.0, 1.0, 5.0], 319 | [6.0, 6.0, 8.0, 10.0], 320 | [8.0, 4.0, 4.0, 7.0], 321 | [10.0, 1.0, 7.0, 8.0], 322 | ] 323 | weights = [ 324 | [8.0, 6.0, 1.0, 5.0], 325 | [6.0, 6.0, 8.0, 10.0], 326 | [8.0, 4.0, 4.0, 7.0], 327 | [10.0, 1.0, 7.0, 8.0], 328 | ] 329 | py_res = py_mwm_nd(values, weights, 11) 330 | cy_res = moving_weighted_median(values, weights, (11, 11)) 331 | np.testing.assert_array_equal(py_res, cy_res) 332 | 333 | # The window is so large all values should be equal, double check that 334 | np.testing.assert_array_equal(cy_res, cy_res[0, 0]) 335 | 336 | def test_2d_mwm_int(self): 337 | values = np.asarray(np.random.randint(0, 10, (14, 8)), np.float64) 338 | weights = np.asarray(np.random.randint(0, 10, (14, 8)), np.float64) 339 | np.testing.assert_array_equal( 340 | py_mwm_nd(values, weights, (3, 3)), 341 | moving_weighted_median(values, weights, (3, 3)), 342 | ) 343 | 344 | def test_2d_mwm(self): 345 | values = np.random.rand(14, 8) 346 | weights = np.random.rand(14, 8) 347 | np.testing.assert_array_equal( 348 | py_mwm_nd(values, weights, 3), 349 | moving_weighted_median(values, weights, (3, 3)), 350 | ) 351 | 352 | # weights are all zeros for a region that is smaller than the window 353 | def test_small_zero_weight(self): 354 | window = (3, 3) 355 | zero_shape = (2, 2) 356 | shape = (10, 10) 357 | 358 | data = np.ones(shape, dtype=np.float64) 359 | weight = np.ones_like(data) 360 | weight[1 : zero_shape[0] + 1, 1 : zero_shape[1] + 1] = 0.0 361 | 362 | np.testing.assert_array_equal( 363 | data, moving_weighted_median(data, weight, window) 364 | ) 365 | 366 | # weights are all zeros for a region that is the size of the window, this has historically 367 | # caused a segfault 368 | def test_med_zero_weight(self): 369 | window = (3,) 370 | zero_shape = (3,) 371 | shape = (10,) 372 | 373 | data = np.ones(shape, dtype=np.float64) 374 | weight = np.ones_like(data) 375 | weight[1 : zero_shape[0] + 1] = 0.0 376 | 377 | result = data.copy() 378 | result[2] = np.nan 379 | 380 | np.testing.assert_array_equal( 381 | result, moving_weighted_median(data, weight, window) 382 | ) 383 | -------------------------------------------------------------------------------- /caput/profile.py: -------------------------------------------------------------------------------- 1 | """Helper routines for profiling the CPU and IO usage of code.""" 2 | 3 | import logging 4 | import math 5 | import os 6 | import threading 7 | import time 8 | from pathlib import Path 9 
| from typing import ClassVar, Optional 10 | 11 | import numpy as np 12 | import psutil 13 | 14 | from . import mpiutil 15 | 16 | 17 | class Profiler: 18 | """A context manager to profile a block of code using various profilers. 19 | 20 | Parameters 21 | ---------- 22 | profile 23 | Whether to run the profiler or not. 24 | profiler 25 | Which profiler to run. Currently `cProfile` and `pyinstrument` are supported. 26 | comm 27 | An optional MPI communicator. This is only used for labelling the output files. 28 | path 29 | The optional path under which to write the profiles. If not set use the 30 | current directory. 31 | """ 32 | 33 | profilers: ClassVar = ["cprofile", "pyinstrument"] 34 | 35 | def __init__( 36 | self, 37 | profile: bool = True, 38 | profiler: str = "cprofile", 39 | comm: Optional["mpiutil.MPI.IntraComm"] = None, 40 | path: os.PathLike | None = None, 41 | ): 42 | self.profile = profile 43 | 44 | if profiler not in self.profilers: 45 | raise ValueError(f"Unsupported profiler: {profiler}") 46 | 47 | self.profiler = profiler 48 | self.comm = comm 49 | self._pr = None 50 | 51 | if path is None: 52 | self.path = Path.cwd() 53 | else: 54 | self.path = Path(path) 55 | 56 | def __enter__(self): 57 | if not self.profile: 58 | return 59 | 60 | if self.profiler == "cprofile": 61 | import cProfile 62 | 63 | self._pr = cProfile.Profile() 64 | self._pr.enable() 65 | 66 | elif self.profiler == "pyinstrument": 67 | import pyinstrument 68 | 69 | self._pr = pyinstrument.Profiler() 70 | self._pr.start() 71 | 72 | def __exit__(self, *args, **kwargs): 73 | if not self.profile: 74 | return 75 | 76 | if self.comm is None: 77 | rank = mpiutil.rank 78 | size = mpiutil.size 79 | else: 80 | rank = self.comm.rank 81 | size = self.comm.size 82 | 83 | rank_length = int(math.log10(size)) + 1 84 | 85 | if self.profiler == "cprofile": 86 | self._pr.disable() 87 | self._pr.dump_stats(f"profile_{rank:0{rank_length}}.prof") 88 | 89 | elif self.profiler == "pyinstrument": 90 | self._pr.stop() 91 | with open(f"profile_{rank:0{rank_length}}.txt", "w") as fh: 92 | fh.write(self._pr.output_text(unicode=True)) 93 | 94 | 95 | class IOUsage: 96 | """A context manager that gives the amount of IO done. 97 | 98 | To access the IO usage the context manager object must be created and bound to a 99 | variable *before* the *with* statement. 100 | 101 | >>> u = IOUsage() 102 | >>> with u: 103 | ... print("do some IO in here") 104 | do some IO in here 105 | >>> print(u.usage) #doctest: +ELLIPSIS 106 | {...} 107 | 108 | Parameters 109 | ---------- 110 | logger 111 | If a logging object is passed the values of the IO done counters are logged 112 | at INFO level. 
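Examples
--------
Passing a logger causes the counters to be logged at INFO level when the block exits (a small sketch; the logger name is illustrative):

>>> import logging
>>> usage = IOUsage(logger=logging.getLogger("io-profiling"))
>>> with usage:
...     pass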
113 | """ 114 | 115 | _start = None 116 | 117 | def __init__(self, logger: logging.Logger | None = None): 118 | self._logger = logger 119 | self._usage = {} 120 | 121 | @staticmethod 122 | def _get_io(): # pylint: disable=no-self-use 123 | # Get the cumulative IO performed 124 | 125 | if psutil.MACOS: 126 | d = psutil.disk_io_counters() 127 | else: 128 | p = psutil.Process() 129 | d = p.io_counters() 130 | 131 | # There doesn't seem to be a public API for this 132 | return d._asdict() 133 | 134 | @staticmethod 135 | def _units(key): # pylint: disable=no-self-use 136 | # Try and infer the units for this particular counter 137 | 138 | suffix = key.split("_")[-1] 139 | 140 | if suffix == "count": 141 | return "" 142 | if suffix == "time": 143 | return "ms" 144 | if suffix == "bytes": 145 | return "bytes" 146 | return "" 147 | 148 | def __enter__(self): 149 | self._start = self._get_io() 150 | 151 | def __exit__(self, *args, **kwargs): 152 | f = self._get_io() 153 | 154 | for name in f: 155 | self._usage[name] = f[name] - self._start[name] 156 | 157 | if self._logger: 158 | for key, value in self.usage.items(): 159 | self._logger.info(f"IO usage({key}): {value} {self._units(key)}") 160 | 161 | @property 162 | def usage(self): 163 | """The IO usage within the block.""" 164 | return self._usage.copy() 165 | 166 | 167 | class PSUtilProfiler(psutil.Process): 168 | """A context manager that profiles using psutil. 169 | 170 | To access the profiling data the context manager object 171 | must be created and bound to a variable *before* the *with* statement. 172 | 173 | Dumps results into a csv file, one for each rank. 174 | 175 | >>> p = PSUtilProfiler(label='task-label') 176 | >>> with p: 177 | ... print("do some task in here") 178 | do some task in here 179 | >>> print(p.usage) #doctest: +ELLIPSIS 180 | {...} 181 | 182 | `start` and `stop` can be used to run the PSUtilProfiler outside of context management. 183 | 184 | >>> p = PSUtilProfiler(label='task-label') 185 | >>> p.start() 186 | >>> print('do some task in here') 187 | do some task in here 188 | >>> p.stop() 189 | >>> print(p.usage) #doctest: +ELLIPSIS 190 | {...} 191 | 192 | Parameters 193 | ---------- 194 | use_profiler : bool 195 | Whether to run the profiler or not. 196 | label : str 197 | Default description of what is being profiled. 198 | logger 199 | If a logging object is passed, the profiling results are logged 200 | at INFO level. 201 | comm 202 | An optional MPI communicator. This is only used for labelling the output files. 203 | path 204 | The optional directory path under which to write the profile csvs. If not set, use the 205 | current directory. 206 | """ 207 | 208 | def __init__( 209 | self, 210 | use_profiler: bool = True, 211 | label: str = "", 212 | logger: logging.Logger | None = None, 213 | comm: Optional["mpiutil.MPI.IntraComm"] = None, 214 | path: os.PathLike | None = None, 215 | ): 216 | self._use_profiler = use_profiler 217 | self._label = label 218 | self._usage = {} 219 | self._logger = logger 220 | self.comm = comm 221 | 222 | if self.comm is None: 223 | rank = mpiutil.rank 224 | else: 225 | rank = self.comm.rank 226 | 227 | if path is None: 228 | self.path = Path.cwd() 229 | else: 230 | if not Path(path).is_dir(): 231 | raise ValueError( 232 | f"Make sure {path} passed to PSUtilProfiler is a directory that exists."
233 | ) 234 | self.path = Path(path) 235 | 236 | self.path = self.path / f"perf_{rank}.csv" 237 | 238 | self._start_cpu_times = None 239 | self._start_memory = None 240 | self._start_disk_io = None 241 | self._start_time = None 242 | 243 | super().__init__() 244 | 245 | if self._use_profiler and not self.path.exists(): 246 | import csv 247 | 248 | with open(self.path, mode="w") as fp: 249 | colnames = [ 250 | "task_name", 251 | "time_s", 252 | "cpu_times_user", 253 | "cpu_times_system", 254 | "cpu_times_children_user", 255 | "cpu_times_children_system", 256 | "cpu_times_iowait", 257 | "average_cpu_load_percent", 258 | "disk_io_read_count", 259 | "disk_io_write_count", 260 | "disk_io_read_bytes", 261 | "disk_io_write_bytes", 262 | "disk_io_read_chars", 263 | "disk_io_write_chars", 264 | "memory_change_uss", 265 | "current_available_memory_bytes", 266 | "current_total_used_memory_bytes", 267 | "peak_used_memory_bytes", 268 | ] 269 | cw = csv.writer(fp) 270 | cw.writerow(colnames) 271 | 272 | if self._use_profiler and self._logger: 273 | self._logger.info(f"Profiling pipeline: {self.cpu_count} cores available.") 274 | 275 | def __eq__(self, other): 276 | if not isinstance(other, psutil.Process): 277 | return False 278 | return ( 279 | psutil.Process.__eq__(self, other) 280 | and self._start_cpu_times == other._start_cpu_times 281 | and self._start_memory == other._start_memory 282 | and self._start_disk_io == other._start_disk_io 283 | and self._start_time == other._start_time 284 | ) 285 | 286 | def __enter__(self): 287 | if not self._use_profiler: 288 | return 289 | self.start() 290 | 291 | def __exit__(self, *args, **kwargs): 292 | if not self._use_profiler: 293 | return 294 | self.stop() 295 | 296 | def start(self): 297 | """Start profiling. 298 | 299 | Results generated when `stop` is called are based on this start time. 300 | 301 | """ 302 | self._start_time = time.time() 303 | 304 | # Get all stats at the same time 305 | with self.oneshot(): 306 | self._start_cpu_times = self.cpu_times() 307 | self.cpu_percent() 308 | self._start_memory = self.memory_full_info().uss 309 | if psutil.MACOS: 310 | self._start_disk_io = psutil.disk_io_counters() 311 | else: 312 | self._start_disk_io = self.io_counters() 313 | 314 | self._thread_flag = threading.Event() 315 | self._thread_flag.set() 316 | 317 | self._peak_memory = self._start_memory 318 | self._monitor_thread = threading.Thread(target=self.monitor) 319 | self._monitor_thread.start() 320 | 321 | def stop(self): 322 | """Stop profiler. Dump results to csv file and/or log and/or set results on self.usage. 323 | 324 | `start` must be called first. 325 | 326 | Returns 327 | ------- 328 | Sets usage dictionary on `self` with the following attributes, under 'label' key. 329 | 330 | cpu_times : `dict` 331 | dict version of `psutil.cpu_times`. Process CPU times since `start` was called in seconds. 332 | cpu_percent : float 333 | Process CPU utilization since `start` was called as percentage. Can be >100 if multiple threads run on 334 | different cores. See `PSUtil.cpu_count` for available cores. 335 | disk_io : `dict` 336 | dict version of `psutil.io_counters` (on Linux) or `psutil.disk_io_counters` (on MacOS). 337 | Difference since `start` was called. 338 | memory : str 339 | Difference of memory in use by this process since `start` was called. If negative, 340 | less memory is in use now. 341 | used_memory : str 342 | Current used memory at the time of the task's end. 
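This is the system-wide figure reported by `psutil.virtual_memory().used`, not the memory used by this process alone.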
343 | available_memory : str 344 | Current memory available to the system at the time of the task's end. 345 | 346 | Raises 347 | ------ 348 | RuntimeError 349 | If stop was called before start. 350 | """ 351 | if self._start_cpu_times is None: 352 | raise RuntimeError("PSUtilProfiler.stop was called before start.") 353 | 354 | # Stop the monitoring process 355 | self._thread_flag.clear() 356 | self._monitor_thread.join() 357 | 358 | stop_time = time.time() 359 | 360 | # Get all stats at the same time 361 | with self.oneshot(): 362 | cpu_times = self.cpu_times() 363 | cpu_percent = self.cpu_percent() 364 | memory = self.memory_full_info().uss 365 | used_memory = psutil.virtual_memory().used 366 | available_memory = psutil.virtual_memory().available 367 | if psutil.MACOS: 368 | disk_io = psutil.disk_io_counters() 369 | else: 370 | disk_io = self.io_counters() 371 | 372 | # Construct results 373 | self._usage = {"task_name": self._label} 374 | 375 | cpu_times_arr = np.subtract(cpu_times, self._start_cpu_times) 376 | disk_io_arr = np.subtract(disk_io, self._start_disk_io) 377 | 378 | cpu_times = dict(zip(cpu_times._fields, cpu_times_arr)) 379 | # contain results in dictionary 380 | disk_io = dict(zip(disk_io._fields, disk_io_arr)) 381 | # contain results in dictionary 382 | 383 | memory = memory - self._start_memory 384 | 385 | self._usage["cpu_times"] = cpu_times 386 | self._usage["cpu_percent"] = cpu_percent 387 | self._usage["disk_io"] = disk_io 388 | 389 | def bytes2human(num): 390 | for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: 391 | if abs(num) < 1024.0: 392 | return f"{num:3.1f}{unit}B" 393 | num /= 1024.0 394 | return f"{num:.1f}YiB" 395 | 396 | self._usage["memory"] = memory 397 | self._usage["used_memory"] = used_memory 398 | self._usage["available_memory"] = available_memory 399 | self._usage["peak_memory"] = self._peak_memory 400 | 401 | time_s = stop_time - self._start_time 402 | 403 | self._usage["time_s"] = time_s 404 | 405 | if time_s < 0.1 and self._logger: 406 | self._logger.info( 407 | f"{self._label} ran for {time_s:.4f} < 0.1s, results might be inaccurate.\n" 408 | ) 409 | 410 | if self._logger: 411 | self._logger.info(f"{self._label} ran for {time_s:.4f}s") 412 | self._logger.info(f"{cpu_times}") 413 | self._logger.info(f"average CPU load: {cpu_percent}") 414 | self._logger.info(f"{disk_io}") 415 | self._logger.info(f"change in (uss) memory: {bytes2human(memory)}") 416 | self._logger.info( 417 | f"current available memory: {bytes2human(available_memory)}" 418 | ) 419 | self._logger.info(f"current total used memory: {bytes2human(used_memory)}") 420 | self._logger.info(f"peak used memory: {bytes2human(self._peak_memory)}") 421 | 422 | with open(self.path, mode="a", newline="") as fp: 423 | import csv 424 | 425 | cw = csv.writer(fp) 426 | cw.writerow( 427 | [ 428 | self._label, 429 | time_s, 430 | cpu_times["user"], 431 | cpu_times["system"], 432 | cpu_times["children_user"], 433 | cpu_times["children_system"], 434 | cpu_times.get("iowait", "-"), 435 | cpu_percent, 436 | disk_io["read_count"], 437 | disk_io["write_count"], 438 | disk_io["read_bytes"], 439 | disk_io["write_bytes"], 440 | disk_io.get("read_chars", "-"), 441 | disk_io.get("write_chars", "-"), 442 | memory, 443 | available_memory, 444 | used_memory, 445 | self._peak_memory, 446 | ] 447 | ) 448 | 449 | def monitor(self): 450 | """Track peak memory.""" 451 | while self._thread_flag.is_set(): 452 | current_mem = psutil.virtual_memory().used 453 | self._peak_memory = max(self._peak_memory, 
current_mem) 454 | time.sleep(0.5) 455 | 456 | @property 457 | def cpu_count(self): 458 | """Number of cores available to this process.""" 459 | if psutil.MACOS: 460 | return psutil.cpu_count() 461 | return len(self.cpu_affinity()) 462 | 463 | @property 464 | def usage(self): 465 | """The memory and cpu usage within the block.""" 466 | return self._usage.copy() 467 | -------------------------------------------------------------------------------- /tests/test_time.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from datetime import datetime, timezone 5 | 6 | import numpy as np 7 | import pytest 8 | from skyfield import earthlib, api 9 | from pytest import approx, raises 10 | 11 | from caput import time as ctime 12 | 13 | 14 | # Download the required Skyfield files from a mirror on a CHIME server. 15 | # 16 | # The upstream servers for the timescale and ephemeris data can be 17 | # flaky. Use this to ensure a copy will be downloaded at the risk of it 18 | # being potentially out of date. This is useful for things like CI 19 | # servers, but otherwise letting Skyfield do it's downloading is a 20 | # better idea. 21 | # 22 | mirror_url = "https://bao.chimenet.ca/skyfield/" 23 | 24 | files = ["Leap_Second.dat", "finals2000A.all", "de421.bsp"] 25 | 26 | loader = ctime.skyfield_wrapper.load 27 | for file in files: 28 | if not os.path.exists(loader.path_to(file)): 29 | loader.download(mirror_url + file) 30 | 31 | 32 | def test_epoch(): 33 | # At the J2000 epoch, sidereal time and transit RA should be the same. 34 | epoch = datetime(2000, 1, 1, 11, 58, 56) 35 | 36 | # Create an observer at an arbitrary location 37 | obs = ctime.Observer(118.3, 36.1) 38 | 39 | # Calculate the transit_RA 40 | unix_epoch = ctime.datetime_to_unix(epoch) 41 | TRA = obs.transit_RA(unix_epoch) 42 | LSA = obs.lsa(unix_epoch) 43 | 44 | # Calculate LST 45 | lst = obs.lst(unix_epoch) 46 | 47 | # Tolerance limited by stellar aberation 48 | assert lst == approx(TRA, abs=0.01, rel=1e-10) 49 | assert lst == approx(LSA, abs=0.01, rel=1e-10) 50 | 51 | 52 | def test_transit_array(): 53 | # Do a simple test of transit_RA in an array. Use the fact that the RA 54 | # advances predictably to predict the answers 55 | 56 | epoch = datetime(2000, 1, 1, 11, 58, 56) 57 | 58 | # Create an observer at an arbitrary location 59 | obs = ctime.Observer(118.3, 36.1) 60 | 61 | # Calculate LST 62 | lst = obs.lst(ctime.datetime_to_unix(epoch)) 63 | 64 | # Drift rate should be very close to 1 degree/4minutes. 65 | # Fetch times calculated by ephem 66 | delta_deg = np.arange(20) 67 | lst = lst + delta_deg 68 | 69 | # Calculate RA using transit_RA 70 | unix_epoch = ctime.datetime_to_unix(epoch) 71 | unix_times = unix_epoch + (delta_deg * 60 * 4 * ctime.STELLAR_S) 72 | TRA = obs.transit_RA(unix_times) 73 | LSA = obs.lsa(unix_times) 74 | 75 | # Compare 76 | assert lst == approx(TRA, abs=0.02, rel=1e-10) 77 | assert lst == approx(LSA, abs=0.02, rel=1e-10) 78 | 79 | 80 | def test_delta(): 81 | delta = np.arange(0, 200000, 1000) # Seconds. 82 | # time.time() when I wrote this. No leap seconds for the next few 83 | # days. 
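# The RA should drift at roughly 15 degrees per hour of elapsed time; the `expected` arrays below apply this rate to `delta`, converting from UTC seconds to sidereal/stellar seconds via the SIDEREAL_S and STELLAR_S constants.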
84 | 85 | obs = ctime.Observer(118.3, 36.1) 86 | 87 | start = 1383679008.816173 88 | times = start + delta 89 | 90 | # Test for transit RA 91 | start_ra = obs.transit_RA(start) 92 | ra = obs.transit_RA(times) 93 | delta_ra = ra - start_ra 94 | expected = delta / 3600.0 * 15.0 / ctime.SIDEREAL_S 95 | error = ((expected - delta_ra + 180.0) % 360) - 180 96 | # Tolerance limited by stellar aberation (40" peak to peak). 97 | assert error == approx(0, abs=0.02) 98 | 99 | # Test for lsa 100 | start_lsa = obs.lsa(start) 101 | lsa = obs.lsa(times) 102 | delta_lsa = lsa - start_lsa 103 | expected = delta / 3600.0 * 15.0 / ctime.STELLAR_S 104 | error = ((expected - delta_lsa + 180.0) % 360) - 180 105 | # Tolerance limited by stellar aberation (40" peak to peak). 106 | assert error == approx(0, abs=0.02) 107 | 108 | 109 | def test_lsa_skyfield(): 110 | # Check an lsa calculated by caput.time against one calculated by PyEphem 111 | 112 | dt = datetime(2014, 10, 2, 13, 4, 5) 113 | dt_utc = dt.replace(tzinfo=api.utc) 114 | 115 | t1 = ctime.datetime_to_unix(dt) 116 | obs = ctime.Observer(42.8, 4.7) 117 | lsa1 = obs.unix_to_lsa(t1) 118 | 119 | t = ctime.skyfield_wrapper.timescale.utc(dt_utc) 120 | lsa2 = (earthlib.earth_rotation_angle(t.ut1) * 360.0 + obs.longitude) % 360.0 121 | 122 | assert lsa1 == approx(lsa2, abs=1e-4) 123 | 124 | 125 | def test_lsa_tra(): 126 | # Near the epoch transit RA and LRA should be extremely close 127 | 128 | dt = datetime(2001, 2, 3, 4, 5, 6) 129 | 130 | t1 = ctime.datetime_to_unix(dt) 131 | obs = ctime.Observer(118.0, 31.0) 132 | lsa = obs.unix_to_lsa(t1) 133 | tra = obs.transit_RA(t1) 134 | 135 | assert lsa == approx(tra, abs=1e-5) 136 | 137 | 138 | def test_reverse_lsa(): 139 | # Check that the lsa_to_unix routine correctly inverts a call to 140 | # unix_to_lsa 141 | 142 | dt1 = datetime(2018, 3, 12, 1, 2, 3) 143 | t1 = ctime.datetime_to_unix(dt1) 144 | 145 | dt0 = datetime(2018, 3, 12) 146 | t0 = ctime.datetime_to_unix(dt0) 147 | 148 | obs = ctime.Observer(42.8, 4.7) 149 | lsa = obs.unix_to_lsa(t1) 150 | 151 | t2 = obs.lsa_to_unix(lsa, t0) 152 | 153 | assert t1 == approx(t2, abs=1e-2) 154 | 155 | 156 | def test_lsa_array(): 157 | dt = datetime(2000, 1, 1, 12, 0, 0) 158 | 159 | t1 = ctime.datetime_to_unix(dt) 160 | 161 | obs = ctime.Observer(0.0, 0.0) 162 | times = t1 + np.linspace(0, 24 * 3600.0, 25) 163 | lsas = obs.unix_to_lsa(times) 164 | 165 | # Check that the vectorization works correctly 166 | for t, lsa in zip(times, lsas): 167 | assert lsa == obs.unix_to_lsa(t) 168 | 169 | # Check the inverse is correct. 
The first 24 entries should be correct, 170 | # but the last one should be one sidereal day behind (because times[0] 171 | # was not in the correct sidereal day) 172 | itimes = obs.lsa_to_unix(lsas, times[0]) 173 | assert times[:-1] == approx(itimes[:-1], rel=1e-5, abs=1e-5) 174 | assert (times[-1] - itimes[-1]) == approx(24 * 3600.0 * ctime.SIDEREAL_S, abs=0.1) 175 | 176 | # Check that it works with zero length arrays 177 | assert obs.lsa_to_unix(np.array([]), np.array([])).size == 0 178 | 179 | 180 | def test_lsd(): 181 | """Test Local Earth Rotation Day (LSD) definition.""" 182 | 183 | obs = ctime.Observer(113.2, 62.4) 184 | obs.lsd_start_day = ctime.datetime_to_unix(datetime(2014, 1, 2)) 185 | 186 | # Check the zero point is correct 187 | assert obs.lsd_zero() == obs.lsa_to_unix(0.0, obs.lsd_start_day) 188 | 189 | dt = datetime(2017, 3, 4, 5, 6, 7) 190 | ut = ctime.datetime_to_unix(dt) 191 | 192 | # Check that the fractional part is equal to the transit RA 193 | 194 | assert 360.0 * (obs.unix_to_lsd(ut) % 1.0) == approx(obs.unix_to_lsa(ut), abs=1e-4) 195 | 196 | # Check a specific precalculated CSD 197 | # csd1 = -1.1848262244129479 198 | # self.assertAlmostEqual(ephemeris.csd(et1), csd1, places=7) 199 | 200 | 201 | def test_lsd_array(): 202 | dt = datetime(2025, 1, 1, 12, 0, 0) 203 | 204 | t1 = ctime.datetime_to_unix(dt) 205 | 206 | obs = ctime.Observer(0.0, 0.0) 207 | times = t1 + np.linspace(0, 48 * 3600.0, 25) 208 | lsds = obs.unix_to_lsd(times) 209 | 210 | # Check that the vectorization works correctly 211 | for t, lsd in zip(times, lsds): 212 | assert lsd == obs.unix_to_lsd(t) 213 | 214 | # Check the inverse is correct. 215 | itimes = obs.lsd_to_unix(lsds) 216 | assert times == approx(itimes, rel=1e-5, abs=1e-5) 217 | 218 | # Check that it works with zero length arrays 219 | assert obs.lsd_to_unix(np.array([])).size == 0 220 | 221 | # Check that it works with lists (this was previously a bug) 222 | assert obs.lsd_zero() == approx(obs.lsd_to_unix([0.0, 0.0]), abs=1e-3, rel=0) 223 | 224 | 225 | def test_era_accuracy(): 226 | # Pick a time to check the ERA around 227 | dts = ctime.ensure_unix(datetime(2000, 1, 1)) 228 | 229 | # These should give back the same time, but the accuracy of the STELLAR_S constant 230 | # and Skyfield's interpolation of dUT1 limit this. 231 | t0 = ctime.era_to_unix(0, dts) 232 | t1 = ctime.era_to_unix(0, dts - 5 * 3600) 233 | 234 | # The accuracy should be better than a millisecond 235 | assert t0 == approx(t1, abs=1e-3) 236 | 237 | 238 | def test_datetime_to_string(): 239 | dt = datetime(2014, 4, 21, 16, 33, 12, 12356) 240 | fdt = ctime.datetime_to_timestr(dt) 241 | assert fdt == "20140421T163312Z" 242 | 243 | 244 | def test_string_to_datetime(): 245 | dt = ctime.timestr_to_datetime("20140421T163312Z_stone") 246 | ans = datetime(2014, 4, 21, 16, 33, 12) 247 | assert dt == ans 248 | 249 | 250 | def test_from_unix_time(): 251 | """Make sure we are properly parsing the unix time. 252 | 253 | This is as much a test of Skyfield as our code. 254 | """ 255 | 256 | unix_time = random.random() * 2e6 257 | dt = datetime.fromtimestamp(unix_time, timezone.utc) 258 | st = ctime.unix_to_skyfield_time(unix_time) 259 | new_dt = st.utc_datetime() 260 | assert dt.year == new_dt.year 261 | assert dt.month == new_dt.month 262 | assert dt.day == new_dt.day 263 | assert dt.hour == new_dt.hour 264 | assert dt.minute == new_dt.minute 265 | assert dt.second == new_dt.second 266 | 267 | # Skyfield rounds its output at the millisecond level.
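# so only require agreement to within 1000 microseconds here.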
268 | assert dt.microsecond == approx(new_dt.microsecond, abs=1000) 269 | 270 | 271 | def test_time_precision(): 272 | """Make sure we have ~0.03 ms precision and that we aren't overflowing 273 | anything at double precision. This number comes from the precision on 274 | Julian date time representations: 275 | http://aa.usno.navy.mil/software/novas/USNOAA-TN2011-02.pdf 276 | """ 277 | 278 | delta = 0.001 # Try a 1 ms shift 279 | unix_time = time.time() 280 | unix_time2 = unix_time + delta 281 | tt1 = ctime.unix_to_skyfield_time(unix_time).tt_calendar() 282 | tt2 = ctime.unix_to_skyfield_time(unix_time2).tt_calendar() 283 | err = abs(tt2[-1] - tt1[-1] - delta) 284 | 285 | assert err < 4e-5 # Check that it is accurate at the 0.03 ms level. 286 | 287 | 288 | def test_datetime_to_unix(): 289 | unix_time = time.time() 290 | dt = datetime.fromtimestamp(unix_time, timezone.utc) 291 | new_unix_time = ctime.datetime_to_unix(dt) 292 | assert new_unix_time == approx(unix_time, abs=1e-5) 293 | 294 | 295 | def test_leap_seconds(): 296 | # 'test_' was previously removed from the name to deactivate this test until it 297 | # could be implemented. 298 | l_second_date = datetime(2009, 1, 1, 0, 0, 0) 299 | l_second_date = ctime.datetime_to_unix(l_second_date) 300 | before = l_second_date - 10000 301 | after = l_second_date + 10000 302 | after_after = l_second_date + 200 303 | assert ctime.leap_seconds_between(before, after) == 1 304 | assert ctime.leap_seconds_between(after, after_after) == 0 305 | 306 | # Check that a period including an extra leap second has two increments 307 | l_second2_date = ctime.datetime_to_unix(datetime(2012, 7, 1, 0, 0, 0)) 308 | after2 = l_second2_date + 10000 309 | 310 | assert ctime.leap_seconds_between(before, after2) == 2 311 | 312 | 313 | def test_era_known(): 314 | # Check an ERA calculated by caput.time against one calculated by 315 | # http://dc.zah.uni-heidelberg.de/apfs/times/q/form (note the latter 316 | # uses UT1, so we have maximum precision of 1s) 317 | 318 | dt = datetime(2016, 4, 3, 2, 1, 0) 319 | 320 | t1 = ctime.datetime_to_unix(dt) 321 | era1 = ctime.unix_to_era(t1) 322 | era2 = 221.0 + (52.0 + 50.828 / 60.0) / 60.0 323 | 324 | assert era1 == approx(era2, abs=1e-3) 325 | 326 | # Test another one 327 | dt = datetime(2001, 2, 3, 4, 5, 6) 328 | 329 | t1 = ctime.datetime_to_unix(dt) 330 | era1 = ctime.unix_to_era(t1) 331 | era2 = 194.0 + (40.0 + 11.549 / 60.0) / 60.0 332 | 333 | assert era1 == approx(era2, abs=1e-3) 334 | 335 | 336 | def test_era_inverse(): 337 | # Check a full forward/inverse cycle 338 | dt = datetime(2016, 4, 3, 2, 1, 0) 339 | t1 = ctime.datetime_to_unix(dt) 340 | era = ctime.unix_to_era(t1) 341 | t2 = ctime.era_to_unix(era, t1 - 3600.0) 342 | 343 | # Should be accurate at the 1 ms level 344 | assert t1 == approx(t2, abs=1e-3) 345 | 346 | # Check a full forward/inverse cycle over a leap second boundary 347 | dt = datetime(2009, 1, 1, 3, 0, 0) 348 | t1 = ctime.datetime_to_unix(dt) 349 | era = ctime.unix_to_era(t1) 350 | t2 = ctime.era_to_unix(era, t1 - 6 * 3600.0) 351 | 352 | # Should be accurate at the 10 ms level 353 | assert t1 == approx(t2, abs=1e-2) 354 | 355 | 356 | def test_ensure_unix(): 357 | # Check that ensure_unix is doing its job for both scalar and array 358 | # inputs 359 | 360 | dt = datetime(2016, 4, 3, 2, 1, 0) 361 | dt_list = [datetime(2016, 4, 3, 2, 1, 0), datetime(2016, 4, 3, 2, 1, 0)] 362 | 363 | ut = ctime.datetime_to_unix(dt) 364 | ut_array = ctime.datetime_to_unix(dt_list) 365 | 366 | sf = ctime.unix_to_skyfield_time(ut) 367 | sf_array =
ctime.unix_to_skyfield_time(ut_array) 368 | 369 | assert ctime.ensure_unix(dt) == ut 370 | assert ctime.ensure_unix(ut) == ut 371 | 372 | assert ctime.ensure_unix(sf) == approx(ut, abs=1e-3) 373 | 374 | assert (ctime.ensure_unix(dt_list) == ut_array).all() 375 | assert (ctime.ensure_unix(ut_array) == ut_array).all() 376 | assert ctime.ensure_unix(sf_array) == approx(ut_array, rel=1e-10, abs=1e-4) 377 | 378 | # Check that it works for zero length arrays 379 | assert ctime.ensure_unix(np.array([])).size == 0 380 | 381 | 382 | @pytest.fixture 383 | def chime(): 384 | # Position from ch_util.ephemeris on 2020/11/09 385 | return ctime.Observer(lon=-119.62, lat=49.32, alt=545.0) 386 | 387 | 388 | @pytest.fixture(scope="module") 389 | def eph(): 390 | return ctime.skyfield_wrapper.ephemeris 391 | 392 | 393 | def test_transit_times(chime, eph): 394 | # Routines to test the transit time calculations 395 | 396 | dts = datetime(2020, 11, 5) 397 | dte = datetime(2020, 11, 7) 398 | 399 | times = chime.transit_times(eph["sun"], dts, dte) 400 | 401 | # Calculated via the old version of `ch_util.ephemeris.solar_transit(dts, dte)`` 402 | precalc_times = [1604605326.0967, 1604691728.9071] 403 | 404 | # Check that the calculations agree within 2s. This criterion comes from the first 405 | # attempts to use the new routines, and seems reasonable enough that I'm not going 406 | # to track down the difference 407 | assert times == approx(precalc_times, abs=2) 408 | 409 | # Test automatic step calculation 410 | small_times = chime.transit_times( 411 | eph["sun"], precalc_times[0] - 300, precalc_times[0] + 60 412 | ) 413 | 414 | assert small_times == approx(precalc_times[:1], abs=2) 415 | 416 | # end <= start raises ValueError 417 | with raises(ValueError): 418 | chime.transit_times(eph["sun"], dte, dts) 419 | 420 | # step >= interval raises ValueError 421 | with raises(ValueError): 422 | chime.transit_times(eph["sun"], dts, dte, step=10) 423 | 424 | 425 | def test_rise_set_times(chime, eph): 426 | dts = datetime(2020, 11, 5) 427 | dte = datetime(2020, 11, 7) 428 | 429 | # From old version of `ch_util.ephemeris` 430 | precalc_times = np.array( 431 | [1604535925.5065, 1604588383.0504, 1604622231.5790, 1604674881.4216], 432 | dtype=np.float64, 433 | ) 434 | precalc_risings = np.array([False, True, False, True], dtype=bool) 435 | 436 | times, risings = chime.rise_set_times(eph["sun"], dts, dte) 437 | 438 | assert times == approx(precalc_times, abs=2) 439 | assert np.all(risings == precalc_risings) 440 | 441 | risings = chime.rise_times(eph["sun"], dts, dte) 442 | settings = chime.set_times(eph["sun"], dts, dte) 443 | 444 | assert risings == approx(precalc_times[precalc_risings], abs=2) 445 | assert settings == approx(precalc_times[~precalc_risings], abs=2) 446 | 447 | 448 | def test_solar_ephemerides(chime): 449 | """Test solar ephemerides""" 450 | 451 | dts = datetime(2020, 11, 5) 452 | dte = datetime(2020, 11, 7) 453 | 454 | times = chime.solar_transit(dts, dte) 455 | 456 | # Calculated via the old version of `ch_util.ephemeris.solar_transit(dts, dte)`` 457 | precalc_times = [1604605326.0967, 1604691728.9071] 458 | assert times == approx(precalc_times, abs=2) 459 | 460 | # From old version of `ch_util.ephemeris` 461 | # 462 | # NB: these times are different than the ones used in test_rise_set_times 463 | # above, because there both the sun's diameter and atmospheric refraction 464 | # are taken to be zero. 
465 | precalc_risings = np.array([1604588047, 1604674544], dtype=np.float64) 466 | 467 | risings = chime.solar_rising(dts, dte) 468 | assert risings == approx(precalc_risings, abs=2) 469 | 470 | # From old version of `ch_util.ephemeris` 471 | precalc_settings = np.array([1604536261, 1604622568], dtype=np.float64) 472 | 473 | settings = chime.solar_setting(dts, dte) 474 | assert settings == approx(precalc_settings, abs=2) 475 | 476 | 477 | def test_lunar_ephemerides(chime): 478 | """Test lunar ephemerides""" 479 | 480 | dts = datetime(2020, 11, 5) 481 | dte = datetime(2020, 11, 7) 482 | 483 | times = chime.lunar_transit(dts, dte) 484 | 485 | # Calculated via the old version of `ch_util.ephemeris.lunar_transit(dts, dte)`` 486 | precalc_times = [1604575728, 1604665306] 487 | assert times == approx(precalc_times, abs=2) 488 | 489 | # From old version of `ch_util.ephemeris`, with adjusted diameter 490 | precalc_risings = np.array([1604545414, 1604634887], dtype=np.float64) 491 | 492 | risings = chime.lunar_rising(dts, dte) 493 | assert risings == approx(precalc_risings, abs=2) 494 | 495 | # From old version of `ch_util.ephemeris`, with adjusted diameter 496 | precalc_settings = np.array([1604606169, 1604695470], dtype=np.float64) 497 | 498 | settings = chime.lunar_setting(dts, dte) 499 | assert settings == approx(precalc_settings, abs=2) 500 | --------------------------------------------------------------------------------