├── .coveragerc ├── .github └── workflows │ └── CI.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── clfd ├── __init__.py ├── _version.py ├── apps │ ├── __init__.py │ ├── cleanup.py │ └── functions.py ├── archive_handler.py ├── features │ ├── __init__.py │ ├── functions.py │ └── register.py ├── profile_masking.py ├── report.py ├── report_plotting.py ├── serialization.py └── spike_finding.py ├── docker └── Dockerfile ├── docs └── report.png ├── example_data ├── npy_example.npy └── psrchive_example.ar ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py ├── expected_profmask.txt ├── expected_tpmask.txt ├── test_cli_app.py ├── test_profile_masking.py ├── test_psrchive_processing.py ├── test_spike_finding_and_replacement.py └── utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | concurrency = multiprocessing 3 | parallel = true 4 | sigterm = true 5 | 6 | # Have to do this for some reason 7 | omit = psrchive.py 8 | -------------------------------------------------------------------------------- /.github/workflows/CI.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | # See: https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#example-running-a-job-within-a-container 9 | lint-and-test: 10 | runs-on: ubuntu-latest 11 | container: 12 | image: ghcr.io/v-morello/psrchive:latest 13 | credentials: 14 | username: ${{ github.actor }} 15 | password: ${{ secrets.GITHUB_TOKEN }} 16 | 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | - name: Install 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install .[dev] 24 | - name: Lint 25 | run: | 26 | make lint 27 | - name: Test 28 | run: | 29 | make test 30 | 31 | publish-to-pypi: 32 | # Inspired by: https://stackoverflow.com/a/73385644 33 | # Only run when a git tag is pushed 34 | if: startsWith(github.event.ref, 'refs/tags/') 35 | needs: [lint-and-test] 36 | runs-on: ubuntu-latest 37 | container: 38 | image: ghcr.io/v-morello/psrchive:latest 39 | credentials: 40 | username: ${{ github.actor }} 41 | password: ${{ secrets.GITHUB_TOKEN }} 42 | 43 | steps: 44 | - name: Checkout 45 | uses: actions/checkout@v4 46 | - name: Install pre-requisites 47 | run: | 48 | python -m pip install --upgrade pip 49 | pip install build twine 50 | - name: Build wheel 51 | # NOTE: build creates a .egg-info directory which interferes 52 | # with the pip install command in next step 53 | run: | 54 | python -m build 55 | rm -rf *.egg-info 56 | - name: Install wheel 57 | run: pip install dist/*.whl 58 | - name: Test wheel 59 | run: | 60 | clfd --despike example_data/psrchive_example.ar 61 | ls example_data/psrchive_example.ar.clfd 62 | ls example_data/psrchive_example_clfd_report.json 63 | ls example_data/psrchive_example_clfd_report.png 64 | - name: Upload to PyPI 65 | run: twine upload --repository pypi --username __token__ --password ${{ secrets.PYPI_TOKEN }} dist/*.whl 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .vscode 3 | .vscode/* 4 | *.pyc 5 | __pycache__ 6 | .venv 7 | dist/* 8 | build/ 9 | .pytest_cache 10 | .coverage 11 | scratch/ -------------------------------------------------------------------------------- /CHANGELOG.md: 
-------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 3 | 4 | 5 | ## 1.0.1 - 2025-01-01 6 | 7 | Slightly improved report plots. 8 | 9 | ### Changed 10 | 11 | - Improved report plot layout; removed frames around each subplot, they had questionable added value and did not always render well at the edges of the image. 12 | 13 | 14 | ## 1.0.0 - 2024-12-30 15 | 16 | First stable release after a complete rewrite. `clfd` is now extensively tested in a CI pipeline. 17 | 18 | ### Added 19 | 20 | - CLI app now saves both a report and an associated PNG plot for each archive processed. 21 | - Report plots now show all the data in one figure: feature values, profile mask and time-phase spike mask (if spike removal was performed). 22 | 23 | ### Fixed 24 | 25 | - Formula for autocorrelation feature `acf` has been corrected. 26 | - CLI app now checks whether PSRCHIVE python bindings are installed. 27 | - CLI app now checks that there are no duplicate input archive names when `--outdir` is specified. Before, there could be name collisions between output files. 28 | - CLI app now validates feature names. 29 | 30 | ### Changed 31 | 32 | - Minimum required Python version is now 3.9. 33 | - By default, CLI app now uses as many parallel processes as there are CPU cores. 34 | - CLI app won't crash if there are archives that cannot be processed. Instead, the name of the offending archive is logged with the error traceback. This choice was made because the app uses multiprocessing. 35 | - Reports are now in a custom JSON format. They can be read using `Report.load()`. 36 | - `matplotlib` is now a required dependency. 37 | - Removed dependencies: `pandas`, `tables`, `scipy`. 38 | 39 | ### Removed 40 | 41 | - CLI app argument `--fmt`. It was pointless considering that the only supported data format is PSRCHIVE. 42 | - CLI app argument `-e, --ext`. The additional extension given to clean output archives is now always `.clfd`, which used to be the default choice. 43 | 44 | ## 0.4.0 - 2024-12-17 45 | 46 | This new version fixes an issue caused by a change in PSRCHIVE Python API, and starts a much needed modernisation effort of `clfd`. **We are dropping support for Python 2.x from this point.** 47 | 48 | ### Fixed 49 | 50 | - Fixed a problem with recent PSRCHIVE versions where the Python API for loading archives has changed. Thanks to Bradley Meyers for discovering and fixing the issue [(#3)](https://github.com/v-morello/clfd/issues/3) 51 | - Fixed a warning from the `pandas` library about future changes to its HDF5 API. 52 | 53 | ### Changed 54 | 55 | - PEP 517/518 compliant packaging, i.e. no more `setup.py`. 56 | 57 | ### Removed 58 | 59 | - Support for Python 2.x 60 | - Removed function `clfd.test()`. Tests are not shipped with the module anymore. 61 | 62 | 63 | ## 0.3.3 - 2021-04-27 64 | ### Added 65 | - New features: skewness, excess kurtosis, autocorrelation. These are not sensitive to the scale of the data, and should perform better on high bit depth data (8-bit and more). 66 | - `scipy` is now an additional dependency (for `skew` and `kurtosis` functions). 
67 | 68 | 69 | ## 0.3.2 - 2019-12-06 70 | ### Fixed 71 | - Fixed a bug that can occur with the new psrchive python3 bindings, where the code was passing sub-integration and channel indices as `numpy.int64` to the `archive.get_Profile()` method instead of a python `int`, and a `TypeError` was raised. 72 | 73 | ### Changed 74 | - Updated statement in `README` about psrchive python bindings, as they are now compatible with python 3 in the latest psrchive version. 75 | 76 | ## 0.3.1 - 2019-11-19 77 | ### Added 78 | - The `clfd` command line application now has a `-o` / `--outdir` option to save all data products to a single output directory 79 | 80 | ## 0.3.0 - 2019-07-20 81 | ### Added 82 | - The latest release of `clfd` can now be easily installed via `pip install clfd`. 83 | - The setup script now defines a console script entry point `clfd` which points to the main function of `cleanup.py`. From the point of view of the user, this means a script called `clfd` is automatically placed in the `PATH`, which can be called from anywhere and simply executes `cleanup.py`. This only works if using an installation method that actually executes the setup script, i.e. it won't work when just cloning the repository and placing it in the `PYTHONPATH`, in which case an alias named `clfd` for `cleanup.py` has to be defined manually. 84 | 85 | ### Changed 86 | - The module name in `setup.py` has been changed from `clfd-pulsar` to simply `clfd`. The original decision was made to avoid potential name collisions, but `clfd` is not currently taken anywhere so let's make things simple from now on. *To avoid any trouble when upgrading to the new version, users should first cleanly uninstall any older versions of* `clfd`, by typing `pip uninstall clfd-pulsar`. 87 | - Rearranged directory structure to make the module installable via `pip`. `apps`, `tests` and `example_data` are now subdirectories of `clfd`. 88 | - Updated installation instructions in `README` 89 | - Updated `README` with a word of warning on the `--despike` command line option 90 | 91 | ## 0.2.3 - 2019-07-17 92 | ### Fixed 93 | - Dependency ``pytables`` in ``setup.py`` should be called ``tables`` when installed via ``pip``; apparently the same package has a different name on conda and PyPI. 94 | - Removed relative import in ``report_plots.py`` that raised an error when using ``clfd`` with python 3+. 95 | - Fixed a bug in ``cleanup.py`` where the list of features was not properly passed down to the core functions, which means that the list of features used was always the default triplet ``std, ptp, lfamp``. 96 | 97 | ### Added 98 | - ``Report`` corner plot now displays the name of the report file. 99 | 100 | ## 0.2.2 - 2019-03-02 101 | ### Added 102 | - ``Report`` now has two methods to generate nice plots: ``corner_plot()`` and ``profile_mask_plot()``. The corner plot shows pairwise scatter plots of profile features and individual histograms, and the other shows the 2D profile mask along with the fraction of data masked in each channel and sub-integration. 103 | - ``TODO.md`` with a list of planned features/upgrades/fixes. 104 | 105 | ### Changed 106 | - Improved fix for float saturation issues: do not scale the data anymore when loading an archive into a DataCube; instead use float64 accumulators when computing profile variance and standard deviation. The ``DataCube`` property ``data`` now returns the original data with the baselines (i.e.
profile median values) subtracted, while the property ``orig_data`` returns the data exactly as they are read from the archive. 107 | 108 | ## 0.2.1 - 2019-02-15 109 | ### Fixed 110 | - DataCubes are now divided by the median absolute deviation (MAD) of non-zero values only. This avoids problems with archives where more than 50% of the data are equal to zero (may happen on GMRT data for example). 111 | - psrchive interface ``get_frequencies()`` method now works with older psrchive versions, since the ``Archive.get_frequencies()`` method of psrchive seems to be only a recent addition (sometime in 2018). 112 | 113 | ### Added 114 | - ``cleanup.py`` now has a ``--version`` option to print the version number of ``clfd`` and exit. 115 | - Installing ``clfd`` using ``pip`` (via the ``make install`` command) now also installs the ``pytables`` module if not present. 116 | 117 | ## 0.2.0 - 2019-02-08 118 | ### Added 119 | - The ``cleanup.py`` executable script now saves for each archive a ``Report`` object in HDF5 format. Reports contain all outputs produced by the cleaning process in a practical format, including: features, feature statistics (including min and max acceptable values for each), profile mask and time-phase mask (if the script was called with the ``--despike`` option). Note that saving and loading reports require the ``pytables`` python library. 120 | - ``cleanup.py`` now has a ``--no-report`` option if the user does not wish to produce report(s), or does not have ``pytables``. 121 | - The ``stats`` DataFrame returned by the ``profile_mask()`` function now contains the median of each feature 122 | - New ``test()`` function that runs all unit tests 123 | - Version of the module now stored in a unique location (``_version.py``) 124 | - Format interfaces now have a ``get_frequencies()`` method 125 | 126 | ## 0.1.1 - 2019-01-16 127 | ### Added 128 | - When loading an archive into a DataCube object, the data are now divided by their overall median absolute deviation (MAD) after the baselines of each profile have been removed. This solves float32 saturation issues when computing some profile features (standard deviation and variance) on archives obtained from 8-bit Parkes ultra wide band receiver data. The data are still properly offset and scaled back before performing replacements (when using the `--despike` option). 129 | 130 | ## 0.1.0 - 2018-11-12 131 | ### Added 132 | - First release of clfd -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018-2019 Vincent Morello 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | global-exclude * 2 | global-include *.py 3 | exclude tests/* 4 | include LICENSE 5 | include pyproject.toml 6 | include README.md 7 | include CHANGELOG.md 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | PKG = clfd 3 | PKG_DIR = clfd/ 4 | LINE_LENGTH = 79 5 | TESTS_DIR = tests/ 6 | 7 | install: ## Install the package in development mode 8 | pip install -e .[dev] 9 | 10 | uninstall: ## Uninstall the package 11 | pip uninstall ${PKG} 12 | rm -rf ${PKG}.egg-info 13 | 14 | # GLORIOUS hack to autogenerate Makefile help 15 | # This simply parses the double hashtags that follow each Makefile command 16 | # https://marmelab.com/blog/2016/02/29/auto-documented-makefile.html 17 | help: ## Print this help message 18 | @echo "Makefile help for clfd" 19 | @echo "==========================" 20 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' 21 | 22 | lint: ## Run linting 23 | isort --check-only --profile black --line-length ${LINE_LENGTH} ${PKG_DIR} ${TESTS_DIR} 24 | flake8 --show-source --statistics --max-line-length ${LINE_LENGTH} ${PKG_DIR} ${TESTS_DIR} 25 | black --exclude .+\.ipynb --check --line-length ${LINE_LENGTH} ${PKG_DIR} ${TESTS_DIR} 26 | 27 | format: ## Apply automatic formatting 28 | black --exclude .+\.ipynb --line-length ${LINE_LENGTH} ${PKG_DIR} ${TESTS_DIR} 29 | isort --profile black --line-length ${LINE_LENGTH} ${PKG_DIR} ${TESTS_DIR} 30 | 31 | test: ## Run unit tests 32 | MPLBACKEND=Agg pytest -vv --cov ${PKG_DIR} 33 | 34 | .PHONY: install uninstall help lint format test 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![CI status](https://github.com/v-morello/clfd/actions/workflows/CI.yml/badge.svg?branch=master) 2 | [![arXiv](http://img.shields.io/badge/astro.ph-1811.04929-B31B1B.svg)](https://arxiv.org/abs/1811.04929) ![License](https://img.shields.io/badge/License-MIT-green.svg) ![Python versions](https://img.shields.io/pypi/pyversions/clfd.svg) 3 | 4 | # clfd 5 | 6 | ``clfd`` stands for **cl**ean **f**olded **d**ata, and implements two interference removal algorithms to be used on _folded_ pulsar search and pulsar timing data. They are based on a simple outlier detection method and require very little human input. These algorithms were initially developed for a re-processing of the High Time Resolution Universe (HTRU) survey, and can be credited with the discovery of several pulsars that would have otherwise been missed. 
7 | 8 | ## Citation 9 | 10 | If using ``clfd`` contributes to a project that leads to a scientific publication, please cite 11 | ["The High Time Resolution Universe survey XIV: Discovery of 23 pulsars through GPU-accelerated reprocessing"](https://arxiv.org/abs/1811.04929) 12 | 13 | ## How it works 14 | 15 | A detailed explanation of what ``clfd`` does can be found in section 2.4 of the above article. There are two algorithms: 16 | 17 | - **Profile masking**: Convert each profile (there is one per channel and per sub-integration) to a small set of summary statistics called _features_ (e.g. standard deviation, peak-to-peak difference, ...); then, identify outliers in the resulting feature space using [Tukey's rule for outliers](https://en.wikipedia.org/wiki/Outlier#Tukey's_fences). Outlier profiles are flagged, i.e. their PSRCHIVE weights are set to 0. 18 | - **Time-phase spike subtraction**: Without dedispersing, integrate the input folded data cube along the frequency axis; identify outliers in the resulting 2D array, which are flagged as bad (time, phase) bins; replace the data in those bins in the input data cube along the frequency dimension. Replacement values are automatically chosen. 19 | 20 | Here's the report plot from `clfd` after running it on the example archive provided in the repository. The orange dashed lines delimit the acceptable range of values for each feature. Points lying outside correspond to profiles that should be masked. 21 | 22 | ![Example clfd output plot](docs/report.png) 23 | 24 | ## Extra dependencies 25 | 26 | `clfd` has little reason to exist without the [PSRCHIVE](http://psrchive.sourceforge.net/) python bindings, which it uses to read and write folded archives. Unless you are working on an HPC facility where someone has already done the job for you, you will have to build PSRCHIVE with Python support. We hope to provide a guide on how to do this in the future. In the meantime, the Dockerfile in this repository can provide some guidance. 27 | 28 | ## Installation 29 | 30 | ### From PyPI 31 | 32 | To install the latest release: 33 | ``` 34 | pip install clfd 35 | ``` 36 | 37 | The main command-line application should now be available to run; see below for more details. 38 | ``` 39 | clfd --help 40 | ``` 41 | 42 | 43 | ### Editable installation / Contributing 44 | 45 | If you want to freely edit the code or perhaps contribute to development, clone the repository and in the base directory of `clfd` run: 46 | 47 | ```bash 48 | pip install -e .[dev] 49 | ``` 50 | 51 | This performs an [editable install](https://pip.pypa.io/en/latest/reference/pip_install/#editable-installs) with additional development dependencies. 52 | 53 | 54 | ### The PYTHONPATH method 55 | 56 | If you are not allowed to install packages with ``pip`` (this may be the case on some computing clusters), or perhaps want to work around the idiosyncratic installation process of the PSRCHIVE Python bindings, then you can clone the repository and add the base directory of ``clfd`` to your ``PYTHONPATH`` environment variable, but then: 57 | 1. You have to install the required dependencies manually. 58 | 2. The main command-line application `clfd` (see below) will **NOT** be made available in your `PATH` automatically. 59 | 60 | We warmly recommend using one of the methods above unless you have no other option. 
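
### Checking the installation

Whichever method you used, a quick sanity check is to import the module and print its version (re-exported in `clfd/__init__.py`):

```python
import clfd

print(clfd.__version__)
```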
61 | 62 | ## Usage 63 | 64 | For detailed help on every option: 65 | 66 | ``` 67 | clfd --help 68 | ``` 69 | 70 | A typical command might be: 71 | 72 | ``` 73 | clfd --features std ptp lfamp --despike *.ar 74 | ``` 75 | 76 | Note that profile masking is systematically performed, but spike subtraction is optional and can be enabled with `--despike`. 77 | 78 | By default, `clfd` saves three files per input archive: 79 | - A cleaned archive with an additional `.clfd` extension appended 80 | - A PNG report plot as shown above 81 | - A so-called report, which contains all intermediate results of the cleaning process, in a customised JSON format. 82 | 83 | Reports can be loaded and interacted with as follows: 84 | ```python 85 | from clfd import Report 86 | r = Report.load(".json") 87 | ``` 88 | 89 | ### What profile features to use 90 | 91 | The user may choose a subset of the following features on the command line (see below): 92 | - `ptp`: peak to peak difference 93 | - `std`: standard deviation 94 | - `var`: variance 95 | - `lfamp`: amplitude of second bin in the Fourier transform of the profile 96 | - `skew`: skewness 97 | - `kurtosis`: excess kurtosis 98 | - `acf`: autocorrelation with a lag of 1 phase bin 99 | 100 | **The choice of profile features should be motivated by the dynamic range of the data.** Depending on the digitization / scaling scheme employed by the telescope backend, on some observing systems the mean and scale of the data may vary wildly across the band and the mean / standard deviation of the data in a given channel do not correlate well with the presence of interference. In such cases, you should use features that are insensitive to the scale of the data. Our recommendation is below, but feel free to experiment: 101 | - Low dynamic range: `std`, `ptp` and/or `lfamp`. Recommended for older 1-bit and 2-bit Parkes multibeam receiver data for example. 102 | - High dynamic range: `skew`, `kurtosis`, and/or `acf`. Recommended for any Parkes UWL data in particular. 
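
If you are unsure which features suit your data, one approach is to run `clfd` on a representative archive and inspect the saved report. A minimal sketch, using the report produced for the example archive shipped in this repository:

```python
from clfd import Report

r = Report.load("psrchive_example_clfd_report.json")
pm = r.profile_masking_result

# Acceptable value range of each feature under Tukey's rule,
# for the q parameter used in that run
for name, stats in pm.feature_stats.items():
    print(name, stats.vmin(pm.q), stats.vmax(pm.q))
```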
103 | -------------------------------------------------------------------------------- /clfd/__init__.py: -------------------------------------------------------------------------------- 1 | # NOTE: this must be imported first 2 | from ._version import __version__ 3 | from .archive_handler import ArchiveHandler 4 | from .profile_masking import profile_mask 5 | from .report import Report 6 | from .spike_finding import find_time_phase_spikes 7 | 8 | __all__ = [ 9 | "profile_mask", 10 | "find_time_phase_spikes", 11 | "ArchiveHandler", 12 | "Report", 13 | "__version__", 14 | ] 15 | -------------------------------------------------------------------------------- /clfd/_version.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | 3 | __version__ = version(__package__) 4 | -------------------------------------------------------------------------------- /clfd/apps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v-morello/clfd/7338a80b3e834818ca17b83b8a051ccfebd405e9/clfd/apps/__init__.py -------------------------------------------------------------------------------- /clfd/apps/cleanup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import logging.config 4 | import multiprocessing 5 | import os 6 | import sys 7 | from collections import defaultdict 8 | from pathlib import Path 9 | from typing import Iterable 10 | 11 | from clfd import __version__ 12 | from clfd.features import available_features 13 | 14 | from .functions import Worker, load_zapfile 15 | 16 | log = logging.getLogger("clfd") 17 | 18 | 19 | def make_parser(): 20 | parser = argparse.ArgumentParser( 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 22 | description=( 23 | "Apply smart RFI cleaning algorithms to folded data archives. " 24 | "Version: {}".format(__version__) 25 | ), 26 | ) 27 | parser.add_argument( 28 | "-o", 29 | "--outdir", 30 | type=validate_output_dir, 31 | default=None, 32 | help=( 33 | "Output directory for the data products. If not specified, the " 34 | "products corresponding to a given input file are written in the " 35 | "same directory as that file." 36 | ), 37 | ) 38 | parser.add_argument( 39 | "-z", 40 | "--zapfile", 41 | type=str, 42 | default=None, 43 | help=( 44 | "Optional text file with a list of frequency channels " 45 | "to forcibly mask and exclude from the analysis. Every line " 46 | "must be a channel index to mask." 47 | ), 48 | ) 49 | parser.add_argument( 50 | "-f", 51 | "--features", 52 | type=str, 53 | action="store", 54 | choices=sorted(available_features().keys()), 55 | metavar="FEAT_NAME", 56 | nargs="+", 57 | default=["std", "ptp", "lfamp"], 58 | help=( 59 | "List of profile features to use for the profile masking " 60 | "algorithm, separated by spaces. " 61 | f"Choices: {sorted(available_features().keys())}" 62 | ), 63 | ) 64 | parser.add_argument( 65 | "-q", 66 | "--qmask", 67 | type=float, 68 | default=2.0, 69 | help=( 70 | "Tukey's rule parameter for the profile masking algorithm. " 71 | "Larger values result in fewer outliers." 72 | ), 73 | ) 74 | parser.add_argument( 75 | "--despike", 76 | action="store_true", 77 | default=False, 78 | help=( 79 | "Apply the time-phase spike subtraction algorithm. " 80 | "WARNING: can attenuate / remove bright individual pulses " 81 | "from a low-DM pulsar. 
Can also fail in particularly bad RFI " 82 |             "environments." 83 |         ), 84 |     ) 85 |     parser.add_argument( 86 |         "--qspike", 87 |         type=float, 88 |         default=4.0, 89 |         help=( 90 |             "Tukey's rule parameter for the time-phase spike subtraction " 91 |             "algorithm. Larger values result in fewer outliers." 92 |         ), 93 |     ) 94 |     parser.add_argument( 95 |         "-p", 96 |         "--processes", 97 |         type=int, 98 |         default=default_num_processes(), 99 |         help="Number of parallel processes to use.", 100 |     ) 101 |     parser.add_argument( 102 |         "--no-report", 103 |         action="store_true", 104 |         default=False, 105 |         help="Do not save reports and associated plots.", 106 |     ) 107 |     parser.add_argument( 108 |         "--version", 109 |         action="version", 110 |         version=__version__, 111 |         help="Print version number and exit", 112 |     ) 113 |     parser.add_argument( 114 |         "archives", 115 |         type=str, 116 |         nargs="+", 117 |         help="Input PSRCHIVE archive(s)", 118 |     ) 119 |     return parser 120 | 121 | 122 | def default_num_processes() -> int: 123 |     n = os.cpu_count() 124 |     return n if n is not None else 1 125 | 126 | 127 | def validate_output_dir(path): 128 |     """Function that checks the outdir argument""" 129 |     if not os.path.isdir(path): 130 |         msg = f"Specified output directory {path!r} does not exist" 131 |         raise argparse.ArgumentTypeError(msg) 132 |     return path 133 | 134 | 135 | def configure_logging(): 136 |     config = { 137 |         "version": 1, 138 |         "formatters": { 139 |             "default": { 140 |                 "format": "[%(levelname)5s - %(asctime)s] %(message)s", 141 |             }, 142 |         }, 143 |         "handlers": { 144 |             "console": { 145 |                 "class": "logging.StreamHandler", 146 |                 "formatter": "default", 147 |                 "stream": "ext://sys.stdout", 148 |             }, 149 |         }, 150 |         "loggers": { 151 |             "clfd": { 152 |                 "level": "DEBUG", 153 |                 "handlers": ["console"], 154 |             } 155 |         }, 156 |     } 157 |     logging.config.dictConfig(config) 158 | 159 | 160 | def paths_with_duplicate_file_names(paths: Iterable[Path]) -> list[Path]: 161 |     filename_path_mapping = defaultdict(list) 162 |     for path in paths: 163 |         filename_path_mapping[path.name].append(path) 164 | 165 |     duplicates = [] 166 |     for path_list in filename_path_mapping.values(): 167 |         if len(path_list) > 1: 168 |             duplicates.extend(path_list) 169 | 170 |     return list(map(Path.resolve, duplicates)) 171 | 172 | 173 | def assert_psrchive_installed(): 174 |     try: 175 |         import psrchive  # noqa: F401 176 |     except ImportError: 177 |         raise ImportError( 178 |             "Could not import the PSRCHIVE Python bindings, which clfd's CLI " 179 |             "app requires." 180 |         ) from None 181 | 182 | 183 | def run_program(cli_args: list[str]): 184 |     configure_logging() 185 |     parser = make_parser() 186 |     args = parser.parse_args(cli_args) 187 |     assert_psrchive_installed() 188 | 189 |     archive_paths = set(Path(ar) for ar in args.archives) 190 | 191 |     if args.outdir and ( 192 |         duplicates := paths_with_duplicate_file_names(archive_paths) 193 |     ): 194 |         msg = ( 195 |             "There are duplicate input archive file names, while 'outdir' has " 196 |             "been specified. This is not allowed, otherwise there would be " 197 |             "output file name collisions. 
Offending paths:\n" 198 | ) 199 | msg += "\n".join(map(str, duplicates)) 200 | raise ValueError(msg) 201 | 202 | log.info(f"Files to process: {len(archive_paths)}") 203 | 204 | zap_channels = load_zapfile(args.zapfile) if args.zapfile else [] 205 | log.info(f"Ignoring {len(zap_channels)} channel indices: {zap_channels}") 206 | log.info(f"Using profile features: {', '.join(args.features)}") 207 | log.info(f"Using {args.processes} parallel processes") 208 | 209 | worker = Worker( 210 | { 211 | "zap_channels": zap_channels, 212 | "outdir": args.outdir, 213 | "features": args.features, 214 | "qmask": args.qmask, 215 | "despike": args.despike, 216 | "qspike": args.qspike, 217 | "save_report": not args.no_report, 218 | } 219 | ) 220 | 221 | with multiprocessing.Pool(processes=args.processes) as pool: 222 | for result in pool.imap_unordered(worker, archive_paths): 223 | if result.output_path: 224 | log.info(f"Finished: {result.output_path.resolve()!s}") 225 | elif result.traceback: 226 | log.error( 227 | f"Failed to process: {result.input_path.resolve()!s}\n" 228 | f"{result.traceback}" 229 | ) 230 | pool.close() 231 | pool.join() 232 | log.info("Done.") 233 | 234 | 235 | def main(): 236 | run_program(sys.argv[1:]) 237 | -------------------------------------------------------------------------------- /clfd/apps/functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Optional, Sequence, Union 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | from clfd import ArchiveHandler, Report, find_time_phase_spikes, profile_mask 10 | 11 | 12 | def load_zapfile(fname: Union[str, os.PathLike]) -> list[int]: 13 | """ 14 | Load zapfile into a list of channel indices. 15 | """ 16 | with open(fname, "r") as file: 17 | zap_channels = list(map(int, file.read().strip().split())) 18 | return zap_channels 19 | 20 | 21 | def process_file( 22 | path: Union[str, os.PathLike], 23 | zap_channels: list[int], 24 | outdir: Optional[Union[str, os.PathLike]] = None, 25 | features: Sequence[str] = ("std", "ptp", "lfamp"), 26 | qmask: float = 2.0, 27 | despike: bool = False, 28 | qspike: float = 4.0, 29 | save_report: bool = False, 30 | ) -> Path: 31 | """ 32 | Process a single archive file end to end. Returns the path to the output 33 | archive. 
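
    Example (hypothetical input path; without ``outdir`` the output lands
    next to the input):

    >>> process_file("/data/obs.ar", zap_channels=[], despike=True)
    PosixPath('/data/obs.ar.clfd')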
34 | """ 35 | path = Path(path).resolve() 36 | handler = ArchiveHandler(path) 37 | cube = handler.data_cube() 38 | pm_result = profile_mask( 39 | cube, features=features, q=qmask, zap_channels=zap_channels 40 | ) 41 | handler.apply_profile_mask(pm_result.mask) 42 | 43 | if despike: 44 | sf_result, plan = find_time_phase_spikes( 45 | cube, q=qspike, zap_channels=zap_channels 46 | ) 47 | handler.apply_spike_subtraction_plan(plan) 48 | else: 49 | sf_result = None 50 | 51 | outdir = Path(outdir).resolve() if outdir else path.parent 52 | output_path = outdir / (path.name + ".clfd") 53 | handler.save(output_path) 54 | 55 | if save_report: 56 | report = Report( 57 | profile_masking_result=pm_result, 58 | spike_finding_result=sf_result, 59 | archive_path=path, 60 | ) 61 | report_path = outdir / (path.stem + "_clfd_report.json") 62 | report.save(report_path) 63 | 64 | plot_path = report_path.with_suffix(".png") 65 | plt.switch_backend("Agg") 66 | report.plot().savefig(plot_path) 67 | 68 | return output_path 69 | 70 | 71 | @dataclass 72 | class WorkerResult: 73 | input_path: Path 74 | output_path: Optional[Path] = None 75 | traceback: Optional[str] = None 76 | 77 | 78 | class Worker: 79 | """ 80 | Callable applied by a process Pool. 81 | """ 82 | 83 | def __init__(self, kwargs: dict): 84 | self.kwargs = kwargs 85 | 86 | def __call__(self, path: Path) -> WorkerResult: 87 | kwargs = self.kwargs | {"path": path} 88 | try: 89 | output_path = process_file(**kwargs) 90 | return WorkerResult(path, output_path, None) 91 | except Exception: 92 | return WorkerResult(path, None, traceback.format_exc()) 93 | -------------------------------------------------------------------------------- /clfd/archive_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | 4 | import numpy as np 5 | from numpy.typing import NDArray 6 | 7 | from clfd.spike_finding import SpikeSubtractionPlan 8 | 9 | 10 | class ArchiveHandler: 11 | """ 12 | Simple wrapper for a psrchive.Archive object, which allows editing it. 13 | """ 14 | 15 | def __init__(self, path: Union[str, os.PathLike]): 16 | import psrchive 17 | 18 | if hasattr(psrchive, "Archive_load"): 19 | loader = psrchive.Archive_load 20 | else: 21 | loader = psrchive.Archive.load 22 | 23 | self._archive = loader(str(path)) 24 | 25 | def data_cube(self) -> NDArray: 26 | """ 27 | Return the archive data as a 3-dimensional numpy array of shape 28 | (num_subints, num_chans, num_bins). Only Stokes I data is read. 29 | """ 30 | return self._archive.get_data()[:, 0, :, :] 31 | 32 | def apply_profile_mask(self, mask: NDArray): 33 | """ 34 | Apply profile mask to underlying archive, setting the weights of masked 35 | profiles to zero. 36 | """ 37 | ipol = 0 38 | for isub, ichan in zip(*np.where(mask)): 39 | # NOTE: cast indices from numpy.int64 to int, otherwise 40 | # get_Profile() complains about argument type 41 | prof = self._archive.get_Profile(int(isub), ipol, int(ichan)) 42 | prof.set_weight(0.0) 43 | 44 | def apply_spike_subtraction_plan(self, plan: SpikeSubtractionPlan): 45 | """ 46 | Set the values of data inside bad time-phase bins to appropriate 47 | replacement values. 
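
        Only the channels listed in the plan's ``valid_channels`` are
        modified; zapped channels are left untouched.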
48 | """ 49 | ipol = 0 50 | repvals = plan.replacement_values 51 | mapping = plan.subint_to_bad_phase_bins_mapping() 52 | 53 | for isub, bad_bins in mapping.items(): 54 | for ichan in plan.valid_channels: 55 | # NOTE: cast indices from numpy.int64 to int, otherwise 56 | # get_Profile() complains about argument type 57 | prof = self._archive.get_Profile(isub, ipol, int(ichan)) 58 | amps = prof.get_amps() 59 | amps[bad_bins] = repvals[isub, ichan, bad_bins] 60 | 61 | def save(self, path: Union[str, os.PathLike]): 62 | """ 63 | Save archive to given path. 64 | """ 65 | self._archive.unload(str(path)) 66 | -------------------------------------------------------------------------------- /clfd/features/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import acf, kurtosis, lfamp, ptp, skew, std, var 2 | from .register import available_features, get_feature, register_feature 3 | 4 | for func in (acf, kurtosis, lfamp, ptp, skew, std, var): 5 | register_feature(func) 6 | 7 | 8 | __all__ = [ 9 | "available_features", 10 | "get_feature", 11 | "register_feature", 12 | ] 13 | 14 | __all__.extend(list(available_features().keys())) 15 | -------------------------------------------------------------------------------- /clfd/features/functions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | from numpy.typing import NDArray 5 | 6 | 7 | def ptp(cube: NDArray): 8 | """Peak-to-peak difference""" 9 | return cube.ptp(axis=-1) 10 | 11 | 12 | def std(cube: NDArray): 13 | """Standard deviation""" 14 | # NOTE: use a float64 accumulator to avoid saturation issues 15 | return cube.std(axis=-1, dtype=np.float64) 16 | 17 | 18 | def var(cube: NDArray): 19 | """Variance""" 20 | # NOTE: use a float64 accumulator to avoid saturation issues 21 | return cube.var(axis=-1, dtype=np.float64) 22 | 23 | 24 | def lfamp(cube: NDArray): 25 | """Amplitude of second Fourier bin""" 26 | ft = np.fft.rfft(cube, axis=-1) 27 | return abs(ft[:, :, 1]) 28 | 29 | 30 | def acf(cube: NDArray): 31 | """Autocorrelation function with a lag of 1 phase bin""" 32 | X = cube 33 | m = X.mean(axis=-1, dtype=np.float64, keepdims=True) 34 | v = X.var(axis=-1, dtype=np.float64) 35 | v[v == 0] = np.inf 36 | acov = np.mean( 37 | (X[..., :-1] - m) * (X[..., 1:] - m), dtype=np.float64, axis=-1 38 | ) 39 | return acov / v 40 | 41 | 42 | def skew(cube: NDArray): 43 | """ 44 | Skewness. Sample bias is not removed. Returns 0 for constant data. 45 | """ 46 | # Work in float64 to avoid overflow 47 | m1 = cube.mean(axis=-1, keepdims=True, dtype=np.float64) 48 | m2 = _moment(cube, m1, 2, axis=-1) 49 | m3 = _moment(cube, m1, 3, axis=-1) 50 | with np.errstate(invalid="ignore"): 51 | return np.where(m2 == 0, 0.0, m3 / m2**1.5) 52 | 53 | 54 | def kurtosis(cube: NDArray): 55 | """ 56 | Excess kurtosis. Sample bias is not removed. Returns +inf for constant 57 | data. 58 | """ 59 | # Work in float64 to avoid overflow 60 | m1 = cube.mean(axis=-1, keepdims=True, dtype=np.float64) 61 | m2 = _moment(cube, m1, 2, axis=-1) 62 | m4 = _moment(cube, m1, 4, axis=-1) 63 | with np.errstate(invalid="ignore"): 64 | return np.where(m2 == 0, np.inf, m4 / m2**2 - 3) 65 | 66 | 67 | def _moment( 68 | x: NDArray, mean: NDArray, order: int, *, axis: Optional[int] = None 69 | ) -> NDArray: 70 | """ 71 | Raw statistical moment; mean of data must be externally provided with the 72 | correct shape. 
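
    In this module, the callers (``skew`` and ``kurtosis``) pass a float64
    mean computed with ``keepdims=True`` so that it broadcasts against ``x``.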
73 | """ 74 | return np.mean((x - mean) ** order, axis=axis) 75 | -------------------------------------------------------------------------------- /clfd/features/register.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | AVAILABLE_FEATURES: dict[str, Callable] = {} 4 | 5 | 6 | def register_feature(func: Callable): 7 | """ 8 | Register feature function for use in profile masking. 9 | """ 10 | if func.__name__ in AVAILABLE_FEATURES: 11 | raise ValueError(f"A feature named {func.__name__!r} already exists") 12 | AVAILABLE_FEATURES[func.__name__] = func 13 | 14 | 15 | def available_features() -> dict[str, Callable]: 16 | """ 17 | Returns a dictionary {name: func} of available feature functions. 18 | """ 19 | return dict(AVAILABLE_FEATURES) 20 | 21 | 22 | def get_feature(name: str) -> Callable: 23 | """ 24 | Get feature function with given name. 25 | """ 26 | func = AVAILABLE_FEATURES.get(name, None) 27 | if func is None: 28 | raise KeyError(f"No feature named {name!r}") 29 | return func 30 | -------------------------------------------------------------------------------- /clfd/profile_masking.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from dataclasses import dataclass 3 | from typing import Iterable 4 | 5 | import numpy as np 6 | from numpy.typing import NDArray 7 | 8 | from clfd.features import get_feature 9 | from clfd.serialization import JSONSerializableDataclass 10 | 11 | 12 | @dataclass(frozen=True) 13 | class Stats(JSONSerializableDataclass): 14 | """ 15 | Stores quantiles of a profile feature. 16 | """ 17 | 18 | q1: float 19 | med: float 20 | q3: float 21 | 22 | @property 23 | def iqr(self) -> float: 24 | return self.q3 - self.q1 25 | 26 | def vmin(self, q: float) -> float: 27 | return self.q1 - q * self.iqr 28 | 29 | def vmax(self, q: float) -> float: 30 | return self.q3 + q * self.iqr 31 | 32 | 33 | @dataclass(frozen=True) 34 | class ProfileMaskingResult(JSONSerializableDataclass): 35 | """ 36 | Stores inputs, intermediate results, and outputs of the profile masking 37 | process. 38 | """ 39 | 40 | q: float 41 | """ 42 | Tukey parameter that was used for deriving the mask. 43 | """ 44 | 45 | zap_channels: list[int] 46 | """ 47 | List of channel indices that were ignored in the analysis and then 48 | forcibly masked. 49 | """ 50 | 51 | feature_values: dict[str, NDArray] 52 | """ 53 | A dictionary where keys are feature names (e.g., `std`, `ptp`, etc.) 54 | and values are arrays containing the computed feature values for the data 55 | cube. 56 | """ 57 | 58 | feature_stats: dict[str, Stats] 59 | """ 60 | A dictionary where keys are feature names and values are `Stats` objects 61 | for the corresponding features. 62 | """ 63 | 64 | mask: NDArray 65 | """ 66 | A boolean array of shape (num_subints, num_chans), where `True` indicates 67 | bad profiles that should be masked. 
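
    For example (indices purely illustrative), ``mask[3, 17] == True`` means
    that the profile in sub-integration 3, channel 17 was flagged as an
    outlier.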
68 |     """ 69 | 70 | 71 | def make_feature_values_dict( 72 |     cube: NDArray, features: Iterable[str] 73 | ) -> dict[str, NDArray]: 74 |     return {name: get_feature(name)(cube) for name in features} 75 | 76 | 77 | def make_feature_stats(feature_values: NDArray, keep_mask: NDArray) -> Stats: 78 |     q1, med, q3 = np.percentile(feature_values[:, keep_mask], (25, 50, 75)) 79 |     return Stats(q1, med, q3) 80 | 81 | 82 | def make_feature_stats_dict( 83 |     feature_values_dict: dict[str, NDArray], keep_mask: NDArray 84 | ) -> dict[str, Stats]: 85 |     return { 86 |         name: make_feature_stats(values, keep_mask) 87 |         for name, values in feature_values_dict.items() 88 |     } 89 | 90 | 91 | def make_feature_mask( 92 |     feature_values: NDArray, feature_stats: Stats, q: float 93 | ) -> NDArray: 94 |     vmin = feature_stats.vmin(q) 95 |     vmax = feature_stats.vmax(q) 96 |     return (feature_values < vmin) | (feature_values > vmax) 97 | 98 | 99 | def make_feature_masks( 100 |     feature_values_dict: dict[str, NDArray], 101 |     feature_stats_dict: dict[str, Stats], 102 |     q: float, 103 | ) -> list[NDArray]: 104 |     result = [] 105 |     for name in feature_values_dict: 106 |         values = feature_values_dict[name] 107 |         stats = feature_stats_dict[name] 108 |         result.append(make_feature_mask(values, stats, q=q)) 109 |     return result 110 | 111 | 112 | def make_in_bounds_zap_indices_and_mask( 113 |     zap_channels: Iterable[int], num_chan: int 114 | ) -> tuple[list[int], NDArray]: 115 |     zap_channels = [i for i in zap_channels if i in range(0, num_chan)] 116 |     zap_mask = np.zeros(num_chan, dtype=bool) 117 |     zap_mask[zap_channels] = True 118 |     return zap_channels, zap_mask 119 | 120 | 121 | def profile_mask( 122 |     cube: NDArray, 123 |     features: Iterable[str] = ("std", "ptp", "lfamp"), 124 |     q: float = 2.0, 125 |     zap_channels: Iterable[int] = (), 126 | ) -> ProfileMaskingResult: 127 |     """ 128 |     Generate a profile mask for a data cube based on statistical features. 129 | 130 |     This function computes statistical features for each profile of a 3D 131 |     data cube (one per sub-integration and channel). It then applies thresholds 132 |     to these features to generate a mask, optionally "zapping" specific channels. 133 | 134 |     Parameters 135 |     ---------- 136 |     cube : NDArray 137 |         Input data cube with shape (num_subints, num_chans, num_bins). 138 |     features : Iterable[str], optional 139 |         A list of profile feature names to calculate and use for masking. 140 |     q : float, optional 141 |         Parameter that controls the min and max values that define the 'inlier' 142 |         or 'normality' range. For every feature, the first and third quartiles 143 |         (Q1 and Q3) are calculated, and R = Q3 - Q1 is the interquartile range. 144 |         The min and max acceptable values are then defined as: 145 | 146 |         vmin = Q1 - q x R 147 |         vmax = Q3 + q x R 148 | 149 |         The original recommendation of Tukey is q = 1.5. 150 |     zap_channels : Iterable[int], optional 151 |         A list of channel indices to be ignored in the analysis and then 152 |         forcibly masked at the end. 
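
        As a purely illustrative example: with q = 2.0 and feature quartiles
        Q1 = 1.0, Q3 = 3.0 (so R = 2.0), profiles whose feature value falls
        outside [-3.0, 7.0] are flagged.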
153 | 154 |     Returns 155 |     ------- 156 |     ProfileMaskingResult 157 |     """ 158 |     num_chans = cube.shape[1] 159 |     zap_channels, zap_mask = make_in_bounds_zap_indices_and_mask( 160 |         zap_channels, num_chans 161 |     ) 162 |     if len(zap_channels) == num_chans: 163 |         raise ValueError("Cannot run profile masking with all channels zapped") 164 | 165 |     feature_values = make_feature_values_dict(cube, features) 166 |     feature_stats = make_feature_stats_dict( 167 |         feature_values, np.logical_not(zap_mask) 168 |     ) 169 | 170 |     profile_mask = functools.reduce( 171 |         np.logical_or, make_feature_masks(feature_values, feature_stats, q) 172 |     ) 173 |     profile_mask[:, zap_channels] = True 174 | 175 |     return ProfileMaskingResult( 176 |         q=float(q), 177 |         zap_channels=sorted(zap_channels), 178 |         feature_values=feature_values, 179 |         feature_stats=feature_stats, 180 |         mask=profile_mask, 181 |     ) -------------------------------------------------------------------------------- /clfd/report.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | 6 | from clfd import __version__ 7 | from clfd.profile_masking import ProfileMaskingResult 8 | from clfd.serialization import JSONSerializableDataclass, json_dump, json_load 9 | from clfd.spike_finding import SpikeFindingResult 10 | 11 | 12 | @dataclass(frozen=True) 13 | class Report(JSONSerializableDataclass): 14 |     """ 15 |     The intermediate and final results of the whole data cleaning process. 16 |     """ 17 | 18 |     profile_masking_result: ProfileMaskingResult 19 |     spike_finding_result: Optional[SpikeFindingResult] = None 20 |     archive_path: Optional[str] = None 21 |     version: str = __version__ 22 | 23 |     def __post_init__(self): 24 |         if self.archive_path is not None: 25 |             path = str(Path(self.archive_path).resolve()) 26 |             object.__setattr__(self, "archive_path", path) 27 | 28 |     def plot(self): 29 |         """ 30 |         Plot the report, returning a matplotlib Figure. 31 |         """ 32 |         from clfd.report_plotting import plot_report 33 | 34 |         return plot_report(self) 35 | 36 |     def save(self, path: Union[str, os.PathLike]): 37 |         """ 38 |         Save to file in JSON format. 39 |         """ 40 |         with open(path, "w") as file: 41 |             json_dump(self, file, indent=4) 42 | 43 |     @classmethod 44 |     def load(cls, path: Union[str, os.PathLike]) -> "Report": 45 |         """ 46 |         Load from file in JSON format. 
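
        Example (hypothetical file name):

        >>> report = Report.load("obs_clfd_report.json")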
47 | """ 48 | with open(path, "r") as file: 49 | return json_load(file) 50 | -------------------------------------------------------------------------------- /clfd/report_plotting.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections.abc import Sequence 3 | from dataclasses import dataclass 4 | from typing import Iterable 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from matplotlib.gridspec import GridSpec, SubplotSpec 9 | from matplotlib.patches import Rectangle 10 | from matplotlib.pyplot import Axes, Figure 11 | from numpy.typing import NDArray 12 | 13 | from clfd import Report 14 | from clfd.profile_masking import Stats 15 | 16 | 17 | @dataclass 18 | class PlotRegion: 19 | left: float 20 | right: float 21 | bottom: float 22 | top: float 23 | 24 | @property 25 | def width(self) -> float: 26 | return self.right - self.left 27 | 28 | @property 29 | def height(self) -> float: 30 | return self.top - self.bottom 31 | 32 | @property 33 | def xmid(self) -> float: 34 | return (self.left + self.right) / 2 35 | 36 | @property 37 | def ymid(self) -> float: 38 | return (self.top + self.bottom) / 2 39 | 40 | def scaled(self, xscale: float, yscale: float) -> "PlotRegion": 41 | return PlotRegion( 42 | left=self.xmid - xscale * self.width / 2, 43 | right=self.xmid + xscale * self.width / 2, 44 | bottom=self.ymid - yscale * self.height / 2, 45 | top=self.ymid + yscale * self.height / 2, 46 | ) 47 | 48 | 49 | class Frame: 50 | """ 51 | Container for a plot, delimited by a rectangle and with a title box at the 52 | top. 53 | """ 54 | 55 | def __init__( 56 | self, 57 | title: str, 58 | *, 59 | left: float, 60 | right: float, 61 | top: float, 62 | bottom: float, 63 | ): 64 | self.title = title 65 | self.gridspec = GridSpec( 66 | 2, 67 | 1, 68 | left=left, 69 | right=right, 70 | top=top, 71 | bottom=bottom, 72 | height_ratios=(1, 24), 73 | hspace=0, 74 | ) 75 | title_box_axes = draw_axes_with_outer_frame_only(self.gridspec[0, 0]) 76 | draw_axes_with_outer_frame_only(self.gridspec[1, 0]) 77 | draw_title_box_filling_axes(title_box_axes, title) 78 | 79 | def usable_plot_region(self) -> PlotRegion: 80 | g = self.gridspec 81 | hup, hdown = g.get_height_ratios() 82 | usable_height_fraction = hdown / (hdown + hup) 83 | return PlotRegion( 84 | left=g.left, 85 | right=g.right, 86 | bottom=g.bottom, 87 | top=g.bottom + usable_height_fraction * (g.top - g.bottom), 88 | ) 89 | 90 | 91 | class FrameRow(Sequence[Frame]): 92 | """ 93 | Row of frames. 
94 | """ 95 | 96 | def __init__( 97 | self, 98 | titles: Iterable[str], 99 | width_ratios: Iterable[float], 100 | ): 101 | left = 0 102 | widths = [r / sum(width_ratios) for r in width_ratios] 103 | 104 | self._frames: list[Frame] = [] 105 | for title, width in zip(titles, widths): 106 | frame = Frame( 107 | title, left=left, right=left + width, bottom=0.0, top=1.0 108 | ) 109 | self._frames.append(frame) 110 | left += width 111 | 112 | def __getitem__(self, index: int) -> Frame: 113 | return self._frames[index] 114 | 115 | def __len__(self): 116 | return len(self._frames) 117 | 118 | 119 | def draw_axes_with_outer_frame_only(spec: SubplotSpec) -> plt.Axes: 120 | axes = plt.subplot(spec) 121 | axes.set_facecolor("w") 122 | axes.tick_params( 123 | axis="both", 124 | which="both", 125 | bottom=False, 126 | top=False, 127 | left=False, 128 | right=False, 129 | labelleft=False, 130 | labelright=False, 131 | labeltop=False, 132 | labelbottom=False, 133 | ) 134 | axes.set_frame_on(False) 135 | return axes 136 | 137 | 138 | def draw_title_box_filling_axes(axes: Axes, title: str): 139 | rectangle = Rectangle( 140 | (0, 0), 141 | width=1, 142 | height=1, 143 | facecolor="#054a91", 144 | ) 145 | axes.add_patch(rectangle) 146 | axes.text( 147 | 0.5, 148 | 0.5, 149 | title, 150 | fontsize=16, 151 | fontweight="bold", 152 | color="w", 153 | horizontalalignment="center", 154 | verticalalignment="center", 155 | ) 156 | 157 | 158 | def plot_report(report: Report) -> Figure: 159 | fig = plt.figure(figsize=(20, 10), dpi=80) 160 | feature_values_frame, profile_mask_frame, spike_mask_frame = FrameRow( 161 | ["Feature Values", "Profile Mask", "Time-Phase Spikes"], [9, 4, 4] 162 | ) 163 | 164 | pm = report.profile_masking_result 165 | region = feature_values_frame.usable_plot_region().scaled( 166 | xscale=0.8, yscale=0.9 167 | ) 168 | kwargs = { 169 | "left": region.left, 170 | "right": region.right, 171 | "bottom": region.bottom, 172 | "top": region.top, 173 | "wspace": 0.04, 174 | "hspace": 0.04, 175 | } 176 | corner_plot(pm.feature_values, pm.feature_stats, pm.q, **kwargs) 177 | 178 | region = profile_mask_frame.usable_plot_region().scaled( 179 | xscale=0.8, yscale=0.9 180 | ) 181 | kwargs = { 182 | "left": region.left, 183 | "right": region.right, 184 | "bottom": region.bottom, 185 | "top": region.top, 186 | } 187 | mask_plot(pm.mask, "subint", "channel", **kwargs) 188 | 189 | if report.spike_finding_result is not None: 190 | region = spike_mask_frame.usable_plot_region().scaled( 191 | xscale=0.8, yscale=0.9 192 | ) 193 | kwargs = { 194 | "left": region.left, 195 | "right": region.right, 196 | "bottom": region.bottom, 197 | "top": region.top, 198 | } 199 | mask_plot( 200 | report.spike_finding_result.mask, "subint", "phase bin", **kwargs 201 | ) 202 | return fig 203 | 204 | 205 | def corner_plot( 206 | feature_values: dict[str, NDArray], 207 | feature_stats: dict[str, Stats], 208 | q: float, 209 | **gridspec_kwargs, 210 | ): 211 | """ 212 | Make a corner plot of the feature values. 
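
    Pairwise scatter plots are drawn below the diagonal and per-feature
    histograms on the diagonal; dashed lines mark the Tukey rejection
    boundaries.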
213 | """ 214 | rejection_boundaries_kwargs = { 215 | "linestyles": "--", 216 | "lw": 1, 217 | "color": "#e57a44", 218 | } 219 | 220 | num_features = len(feature_values) 221 | grid = GridSpec(num_features, num_features, **gridspec_kwargs) 222 | 223 | names = feature_values.keys() 224 | feature_limits = { 225 | name: (stats.vmin(4 * q), stats.vmax(4 * q)) 226 | for name, stats in feature_stats.items() 227 | } 228 | 229 | for (ix, xname), (iy, yname) in itertools.combinations( 230 | enumerate(names), 2 231 | ): 232 | # NOTE: 'ix' increases left-right and 'iy' top-bottom 233 | # Scatter plots 234 | axes = plt.subplot(grid[iy, ix]) 235 | xdata = feature_values[xname].ravel() 236 | ydata = feature_values[yname].ravel() 237 | axes.scatter(xdata, ydata, color="#303030", s=2, alpha=0.1) 238 | 239 | # Rejection boundaries 240 | stats_x = feature_stats[xname] 241 | stats_y = feature_stats[yname] 242 | axes.vlines( 243 | [stats_x.vmin(q), stats_x.vmax(q)], 244 | stats_y.vmin(q), 245 | stats_y.vmax(q), 246 | **rejection_boundaries_kwargs, 247 | ) 248 | axes.hlines( 249 | [stats_y.vmin(q), stats_y.vmax(q)], 250 | stats_x.vmin(q), 251 | stats_x.vmax(q), 252 | **rejection_boundaries_kwargs, 253 | ) 254 | 255 | # Set limits 256 | axes.set_xlim(*feature_limits[xname]) 257 | axes.set_ylim(*feature_limits[yname]) 258 | 259 | if iy == num_features - 1: 260 | axes.set_xlabel(xname, fontweight="bold") 261 | else: 262 | axes.set_xticklabels([]) 263 | 264 | if ix == 0: 265 | axes.set_ylabel(yname, fontweight="bold") 266 | else: 267 | axes.set_yticklabels([]) 268 | 269 | # Histograms 270 | num_bins = 50 271 | for i, name in enumerate(names): 272 | axes = plt.subplot(grid[i, i]) 273 | data = feature_values[name].ravel() 274 | xmin, xmax = feature_limits[name] 275 | axes.hist( 276 | data, 277 | bins=np.linspace(xmin, xmax, num_bins), 278 | histtype="step", 279 | color="#303030", 280 | ) 281 | 282 | # Rejection boundaries 283 | stats = feature_stats[name] 284 | ymin, ymax = axes.get_ylim() 285 | axes.vlines( 286 | [stats.vmin(q), stats.vmax(q)], 287 | ymin, 288 | ymax, 289 | **rejection_boundaries_kwargs, 290 | ) 291 | axes.set_ylim(ymin, ymax) 292 | axes.set_yticks([]) 293 | axes.set_yticklabels([]) 294 | 295 | axes.set_xlim(xmin, xmax) 296 | if i == num_features - 1: 297 | axes.set_xlabel(name, fontweight="bold") 298 | if i != num_features - 1: 299 | axes.set_xticklabels([]) 300 | 301 | 302 | def mask_plot( 303 | mask: NDArray, dim0_name: str, dim1_name: str, **gridspec_kwargs 304 | ): 305 | """ 306 | Plot a two-dimensional binary mask. 
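
    The mask is drawn transposed, so that the first dimension (``dim0_name``)
    runs along the horizontal axis.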
307 | """ 308 | grid = GridSpec(1, 1, **gridspec_kwargs) 309 | dim0, dim1 = mask.shape 310 | xlim = (-0.5, dim0 - 0.5) 311 | ylim = (-0.5, dim1 - 0.5) 312 | 313 | axes = plt.subplot(grid[0, 0]) 314 | axes.imshow( 315 | mask.T, 316 | aspect="auto", 317 | cmap="binary", 318 | interpolation="nearest", 319 | alpha=0.75, 320 | ) 321 | axes.set_xlim(*xlim) 322 | axes.set_ylim(*ylim) 323 | axes.set_xlabel(f"{dim0_name.capitalize()} index", fontweight="bold") 324 | axes.set_ylabel(f"{dim1_name.capitalize()} index", fontweight="bold") 325 | -------------------------------------------------------------------------------- /clfd/serialization.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import base64 3 | import json 4 | from dataclasses import fields 5 | from typing import Any, Callable 6 | 7 | import numpy as np 8 | from numpy.typing import NDArray 9 | 10 | SerializableDict = dict[str, Any] 11 | Serializer = Callable[[Any], SerializableDict] 12 | DeSerializer = Callable[[SerializableDict], Any] 13 | 14 | TYPE_KEY = "__type__" 15 | 16 | 17 | def serialize_ndarray(obj: NDArray) -> SerializableDict: 18 | return { 19 | TYPE_KEY: "ndarray", 20 | "shape": list(obj.shape), 21 | "dtype": str(obj.dtype), 22 | "base64_data": base64.b64encode(obj).decode(encoding="utf-8"), 23 | } 24 | 25 | 26 | def deserialize_ndarray(mapping: SerializableDict) -> NDArray: 27 | data_bytes = bytearray(mapping["base64_data"], encoding="utf-8") 28 | return np.frombuffer( 29 | base64.b64decode(data_bytes), dtype=mapping["dtype"] 30 | ).reshape(mapping["shape"]) 31 | 32 | 33 | SERIALIZERS: dict[str, Serializer] = {"ndarray": serialize_ndarray} 34 | DESERIALIZERS: dict[str, DeSerializer] = {"ndarray": deserialize_ndarray} 35 | 36 | 37 | def type_key_adder(serializer: Serializer) -> Serializer: 38 | def decorated(obj: Any) -> SerializableDict: 39 | return serializer(obj) | {TYPE_KEY: type(obj).__name__} 40 | 41 | return decorated 42 | 43 | 44 | def type_key_remover(deserializer: DeSerializer) -> DeSerializer: 45 | def decorated(mapping: SerializableDict) -> Any: 46 | mapping.pop(TYPE_KEY) 47 | return deserializer(mapping) 48 | 49 | return decorated 50 | 51 | 52 | class JSONSerializable(abc.ABC): 53 | """ 54 | Mixin to make any class JSON serializable. 55 | """ 56 | 57 | def __init_subclass__(cls): 58 | SERIALIZERS[cls.__name__] = type_key_adder(cls._to_dict) 59 | DESERIALIZERS[cls.__name__] = type_key_remover(cls._from_dict) 60 | 61 | @abc.abstractmethod 62 | def _to_dict(self) -> dict[str, Any]: 63 | """ 64 | Convert to JSON-serializable dict. 65 | """ 66 | 67 | @classmethod 68 | @abc.abstractmethod 69 | def _from_dict(cls, mapping: dict[str, Any]): 70 | """ 71 | Initialize object from dict loaded from JSON. 72 | """ 73 | 74 | 75 | def shallow_asdict(obj) -> dict[str, Any]: 76 | """ 77 | Non-recursive version of dataclasses.asdict(). 78 | """ 79 | return {f.name: getattr(obj, f.name) for f in fields(obj)} 80 | 81 | 82 | class JSONSerializableDataclass(JSONSerializable): 83 | def _to_dict(self) -> dict[str, Any]: 84 | # We don't want to use asdict() here, because it also applies to fields 85 | # that are dataclasses such as Stats, converting them into dicts 86 | # along the way. However, we want our code to add the special dict key 87 | # that defines the data type. 
88 |         return shallow_asdict(self)
89 | 
90 |     @classmethod
91 |     def _from_dict(cls, mapping: dict[str, Any]):
92 |         return cls(**mapping)
93 | 
94 | 
95 | class Encoder(json.JSONEncoder):
96 |     def default(self, obj):
97 |         serialize = SERIALIZERS.get(type(obj).__name__, None)
98 |         if serialize:
99 |             return serialize(obj)
100 |         return json.JSONEncoder.default(self, obj)
101 | 
102 | 
103 | def object_hook(obj):
104 |     if isinstance(obj, dict) and TYPE_KEY in obj:
105 |         cls_name = obj[TYPE_KEY]
106 |         deserialize = DESERIALIZERS[cls_name]
107 |         return deserialize(obj)
108 |     return obj
109 | 
110 | 
111 | def json_dumps(obj, **kwargs):
112 |     kwargs["cls"] = Encoder
113 |     return json.dumps(obj, **kwargs)
114 | 
115 | 
116 | def json_loads(s: str, **kwargs):
117 |     kwargs["object_hook"] = object_hook
118 |     return json.loads(s, **kwargs)
119 | 
120 | 
121 | def json_dump(obj, fp, **kwargs):
122 |     kwargs["cls"] = Encoder
123 |     return json.dump(obj, fp, **kwargs)
124 | 
125 | 
126 | def json_load(fp, **kwargs):
127 |     kwargs["object_hook"] = object_hook
128 |     return json.load(fp, **kwargs)
129 | 
--------------------------------------------------------------------------------
/clfd/spike_finding.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Iterable
3 | 
4 | import numpy as np
5 | from numpy.typing import NDArray
6 | 
7 | from clfd.profile_masking import make_in_bounds_zap_indices_and_mask
8 | from clfd.serialization import JSONSerializableDataclass
9 | 
10 | 
11 | @dataclass(frozen=True)
12 | class SpikeFindingResult(JSONSerializableDataclass):
13 |     """
14 |     Inputs to and result of the spike identification process.
15 |     """
16 | 
17 |     q: float
18 |     """
19 |     Tukey parameter that was used for identifying bad time-phase bins.
20 |     """
21 | 
22 |     zap_channels: list[int]
23 |     """
24 |     List of channel indices that were ignored in the analysis.
25 |     """
26 | 
27 |     mask: NDArray
28 |     """
29 |     Binary mask with shape (num_subints, num_bins) where 'True' denotes a
30 |     bad time-phase bin.
31 |     """
32 | 
33 | 
34 | @dataclass(frozen=True)
35 | class SpikeSubtractionPlan:
36 |     """
37 |     The information necessary to replace bad data identified in the
38 |     time-phase masking process.
39 |     """
40 | 
41 |     valid_channels: list[int]
42 |     """
43 |     The complement of zapped channel indices.
44 |     """
45 | 
46 |     mask: NDArray
47 |     """
48 |     Binary mask with shape (num_subints, num_bins) where 'True' denotes a
49 |     bad time-phase bin.
50 |     """
51 | 
52 |     replacement_values: NDArray
53 |     """
54 |     Replacement values with the same shape as the original cube.
55 |     data[i, valid_channels, j] should be replaced by
56 |     replacement_values[i, valid_channels, j].
57 |     """
58 | 
59 |     def apply(self, cube: NDArray):
60 |         """
61 |         Apply to a data cube, returning a new cube where bad values have
62 |         been replaced.
63 |         """
64 |         clean_cube = cube.copy()
65 |         for i, j in zip(*np.where(self.mask)):
66 |             clean_cube[i, self.valid_channels, j] = self.replacement_values[
67 |                 i, self.valid_channels, j
68 |             ]
69 |         return clean_cube
70 | 
71 |     def subint_to_bad_phase_bins_mapping(self) -> dict[int, NDArray]:
72 |         """
73 |         Returns a dictionary {subint_index: bad_phase_bin_indices}.
74 |         Useful for replacing bad data in a PSRCHIVE archive.
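
        Illustrative example (hypothetical values): for a mask with rows
        [True, False, True] and [False, False, False], the returned mapping
        would be {0: array([0, 2])}; subints with no bad bins are omitted.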
75 | """ 76 | result = {} 77 | num_subints = self.mask.shape[0] 78 | for isub in range(num_subints): 79 | (bad_bins,) = np.where(self.mask[isub]) 80 | if len(bad_bins): 81 | result[isub] = bad_bins 82 | return result 83 | 84 | 85 | def find_time_phase_spikes( 86 | cube: NDArray, q: float = 4.0, zap_channels: Iterable[int] = () 87 | ) -> tuple[SpikeFindingResult, SpikeSubtractionPlan]: 88 | """ 89 | Compute a data mask based on the cube's time-phase plot (sum of the 90 | data along the frequency axis of the cube). 91 | 92 | Parameters 93 | ---------- 94 | cube: NDArray 95 | The input data cube as a numpy array of shape 96 | (num_subints, num_chans, num_bins) 97 | q: float, optional 98 | Parameter that controls the min and max values that define the 99 | 'inlier' or 'normality' range. Larger values result in fewer outliers. 100 | zap_channels: Iterable[int], optional 101 | Frequency channel indices to exclude from the outlier analysis. 102 | 103 | Returns 104 | ------- 105 | result : SpikeFindingResult 106 | The results of the spike finding process, including a time-phase mask 107 | containing the (time, phase) bins that should be spike-subtracted. 108 | replacement_plan : SpikeSubtractionPlan 109 | The replacement plan for bad data. 110 | """ 111 | num_subints, num_chans, __ = cube.shape 112 | zap_channels, zap_mask = make_in_bounds_zap_indices_and_mask( 113 | zap_channels, num_chans 114 | ) 115 | if len(zap_channels) == num_chans: 116 | raise ValueError("Cannot run spike finding with all channels zapped") 117 | 118 | keep_mask = np.logical_not(zap_mask) 119 | (valid_channels,) = np.where(keep_mask) 120 | 121 | # For the purposes of this masking algorithm, we need to manipulate 122 | # baseline-subtracted data 123 | baselines = np.median(cube, axis=2).reshape(num_subints, num_chans, 1) 124 | subtracted_data = cube - baselines 125 | subints = subtracted_data[:, keep_mask, :].sum(axis=1) 126 | 127 | # Percentiles along time axis 128 | q1, med, q3 = np.percentile(subints, [25, 50, 75], axis=0) 129 | iqr = q3 - q1 130 | vmin = q1 - q * iqr 131 | vmax = q3 + q * iqr 132 | 133 | mask = (subints < vmin) | (subints > vmax) 134 | repvals = baselines + med / len(valid_channels) 135 | 136 | result = SpikeFindingResult( 137 | q=float(q), zap_channels=list(zap_channels), mask=mask 138 | ) 139 | replacement_plan = SpikeSubtractionPlan( 140 | valid_channels=list(valid_channels), 141 | mask=mask, 142 | replacement_values=repvals, 143 | ) 144 | return result, replacement_plan 145 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | ARG PSRCHIVE_VERSION=2024-12-02 4 | 5 | RUN apt-get update && \ 6 | apt-get install -y \ 7 | build-essential \ 8 | git \ 9 | autoconf \ 10 | libtool \ 11 | pkg-config \ 12 | gfortran \ 13 | fftw3-dev \ 14 | libcfitsio-dev \ 15 | python3 \ 16 | python3-dev \ 17 | python3-pip \ 18 | python-is-python3 19 | 20 | # Ubuntu 22.04 ships with pip 22.0.2 which does not work with the modern pyproject.toml format 21 | # Also, PSRCHIVE won't build with numpy 2 yet 22 | RUN pip install --upgrade pip && \ 23 | pip install "numpy<2.0.0" "swig>3" 24 | 25 | # Build PSRCHIVE with Python bindings 26 | RUN git clone --branch ${PSRCHIVE_VERSION} --depth=1 git://git.code.sf.net/p/psrchive/code psrchive && \ 27 | cd psrchive && \ 28 | ./bootstrap && \ 29 | ./configure --enable-shared --enable-static F77=gfortran && \ 30 | make -j8 && \ 31 | make 
install 32 | 33 | ENV LD_LIBRARY_PATH=/usr/local/lib:${LD_LIBRARY_PATH} 34 | ENV PYTHONPATH=/usr/local/lib/python3.10/site-packages:${PYTHONPATH} 35 | 36 | # Check build is successful 37 | RUN python -c "import psrchive" 38 | -------------------------------------------------------------------------------- /docs/report.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v-morello/clfd/7338a80b3e834818ca17b83b8a051ccfebd405e9/docs/report.png -------------------------------------------------------------------------------- /example_data/npy_example.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v-morello/clfd/7338a80b3e834818ca17b83b8a051ccfebd405e9/example_data/npy_example.npy -------------------------------------------------------------------------------- /example_data/psrchive_example.ar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v-morello/clfd/7338a80b3e834818ca17b83b8a051ccfebd405e9/example_data/psrchive_example.ar -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = [ 4 | "setuptools>=42", 5 | "setuptools-scm>=4", 6 | ] 7 | 8 | [project] 9 | authors = [ 10 | {name = "Vincent Morello", email = "vmorello@gmail.com"}, 11 | ] 12 | classifiers = [ 13 | "Programming Language :: Python :: 3.9", 14 | "Programming Language :: Python :: 3.10", 15 | "Programming Language :: Python :: 3.11", 16 | "Programming Language :: Python :: 3.12", 17 | "Programming Language :: Python :: 3.13", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: Unix", 20 | "Topic :: Scientific/Engineering :: Astronomy", 21 | ] 22 | description = "Smart RFI removal algorithms to be used on folded pulsar search and timing data" 23 | dynamic = ["version"] 24 | license = {text = "MIT License"} 25 | name = "clfd" 26 | readme = "README.md" 27 | requires-python = ">=3.9" 28 | 29 | dependencies = [ 30 | "numpy<2.0.0", 31 | "matplotlib", 32 | ] 33 | 34 | [project.optional-dependencies] 35 | dev = [ 36 | "isort", 37 | "flake8", 38 | "black", 39 | "pytest", 40 | "pytest-cov", 41 | "build", 42 | "twine", 43 | ] 44 | 45 | [project.scripts] 46 | clfd = "clfd.apps.cleanup:main" 47 | 48 | [project.urls] 49 | homepage = "https://github.com/v-morello/clfd" 50 | 51 | [tool.setuptools.packages.find] 52 | where = [""] 53 | include = ["clfd", "clfd.apps", "clfd.features"] 54 | 55 | [tool.setuptools_scm] 56 | 57 | [tool.black] 58 | line-length = 79 59 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/v-morello/clfd/7338a80b3e834818ca17b83b8a051ccfebd405e9/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytest 5 | from numpy.typing import NDArray 6 | 7 | from clfd import ArchiveHandler 8 | 9 | 10 | def two_dimensional_boolean_mask_from_text(text: str) -> NDArray: 11 | """ 12 | Self-explanatory. 
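    Each line of the text is one mask row written as a string of '0' and
    '1' characters; the result is a two-dimensional boolean array with one
    row per line.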
13 | """ 14 | result = [] 15 | for row_text in text.strip().split("\n"): 16 | row = list(map(int, row_text)) 17 | result.append(row) 18 | return np.asarray(result, dtype=bool) 19 | 20 | 21 | def two_dimensional_boolean_mask_from_file(path: Path) -> NDArray: 22 | """ 23 | Self-explanatory. 24 | """ 25 | with open(path, "r") as file: 26 | return two_dimensional_boolean_mask_from_text(file.read()) 27 | 28 | 29 | @pytest.fixture(scope="module") 30 | def data_cube() -> NDArray: 31 | """ 32 | Data cube loaded from the example npy file. 33 | """ 34 | path = Path(__file__).parent / ".." / "example_data" / "npy_example.npy" 35 | return np.load(path) 36 | 37 | 38 | @pytest.fixture(scope="module") 39 | def expected_profmask() -> NDArray: 40 | """ 41 | Expected profile mask when running clfd on the test archive with: 42 | - features: (std, ptp, lfamp) 43 | - qmask: 2.0 44 | Has shape (num_subints, num_channels). 45 | """ 46 | return two_dimensional_boolean_mask_from_file( 47 | Path(__file__).parent / "expected_profmask.txt" 48 | ) 49 | 50 | 51 | @pytest.fixture(scope="module") 52 | def expected_tpmask() -> NDArray: 53 | """ 54 | Expected time-phase mask when running clfd on the test archive with 55 | qspike = 2.0. Has shape (num_subints, num_phase_bins). 56 | """ 57 | return two_dimensional_boolean_mask_from_file( 58 | Path(__file__).parent / "expected_tpmask.txt" 59 | ) 60 | 61 | 62 | @pytest.fixture(scope="module") 63 | def archive_path() -> Path: 64 | return ( 65 | Path(__file__).parent / ".." / "example_data" / "psrchive_example.ar" 66 | ) 67 | 68 | 69 | @pytest.fixture(scope="module") 70 | def archive_handler(archive_path: Path) -> ArchiveHandler: 71 | return ArchiveHandler(archive_path) 72 | -------------------------------------------------------------------------------- /tests/expected_profmask.txt: -------------------------------------------------------------------------------- 1 | 11111111111111111100001000000000000000000110000000000000000000000000000000000000000010000000000000011100000110111110000110000000 2 | 11111111111111111100000000000000000000000110000000000000000000000000000000000000000010000000100000111110001110010111001111100000 3 | 11111111111111111100000000000000000000000110000000000000000000000000000000000000000000000000000000011110001000000000000011000000 4 | 11111111111111111100001000000000000000000110000000000000000000000000000000000000000000000000000000011110001100000100001110000000 5 | -------------------------------------------------------------------------------- /tests/expected_tpmask.txt: -------------------------------------------------------------------------------- 1 | 0000000000000000000000100000000000000000000000000000000000110000000000100000000000000000000000000000000000000000000000000000000000000000000010000100000000000000000000000000000000000000000100100000000000000000000000000000000000000000000000000000000000000000 2 | 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000110000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000 3 | 0000000000000000000000000000000000000000000000000010000000000000000000000010000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000010000000000000000000000000000000000000000 4 | 
0000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000 5 | -------------------------------------------------------------------------------- /tests/test_cli_app.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pytest 6 | from numpy.typing import NDArray 7 | 8 | from clfd import Report 9 | from clfd.apps.cleanup import run_program 10 | from clfd.apps.functions import load_zapfile 11 | 12 | from .utils import skip_unless_psrchive_installed 13 | 14 | 15 | def test_cli_app_entrypoint_exists(): 16 | exit_code = subprocess.check_call(["clfd", "--help"]) 17 | assert exit_code == 0 18 | 19 | 20 | def test_load_zapfile_on_non_empty_zapfile( 21 | tmp_path_factory: pytest.TempPathFactory, 22 | ): 23 | channels = [1, 42] 24 | path = tmp_path_factory.mktemp("zapfile") / "zapfile.txt" 25 | with open(path, "w") as file: 26 | file.write("\n".join(map(str, channels))) 27 | assert load_zapfile(path) == channels 28 | 29 | 30 | def test_load_zapfile_on_empty_zapfile( 31 | tmp_path_factory: pytest.TempPathFactory, 32 | ): 33 | path = tmp_path_factory.mktemp("zapfile") / "zapfile.txt" 34 | open(path, "w").close() 35 | assert load_zapfile(path) == [] 36 | 37 | 38 | @skip_unless_psrchive_installed 39 | def test_cli_app_with_spike_removal( 40 | tmp_path_factory: pytest.TempPathFactory, 41 | archive_path: Path, 42 | expected_profmask: NDArray, 43 | expected_tpmask: NDArray, 44 | ): 45 | """ 46 | Basic end-to-end test of the CLI app on a PSRCHIVE archive. 
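    Checks that the cleaned archive, JSON report and PNG plot are written
    to the output directory, and that the masks stored in the report match
    the expected profile and time-phase masks.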
47 | """ 48 | outdir = Path(tmp_path_factory.mktemp("clfd_outdir")) 49 | args = [ 50 | "--despike", 51 | "--qspike", 52 | str(2.0), 53 | "--outdir", 54 | str(outdir), 55 | str(archive_path), 56 | ] 57 | run_program(args) 58 | 59 | expected_archive_path = outdir / "psrchive_example.ar.clfd" 60 | assert expected_archive_path.exists() 61 | 62 | expected_report_path = outdir / "psrchive_example_clfd_report.json" 63 | assert expected_report_path.exists() 64 | 65 | expected_plot_path = outdir / "psrchive_example_clfd_report.png" 66 | assert expected_plot_path.exists() 67 | 68 | report = Report.load(expected_report_path) 69 | assert np.array_equal( 70 | report.profile_masking_result.mask, expected_profmask 71 | ) 72 | assert np.array_equal(report.spike_finding_result.mask, expected_tpmask) 73 | 74 | 75 | @skip_unless_psrchive_installed 76 | def test_cli_app_without_spike_removal( 77 | tmp_path_factory: pytest.TempPathFactory, 78 | archive_path: Path, 79 | expected_profmask: NDArray, 80 | ): 81 | outdir = Path(tmp_path_factory.mktemp("clfd_outdir")) 82 | args = [ 83 | "--outdir", 84 | str(outdir), 85 | str(archive_path), 86 | ] 87 | run_program(args) 88 | 89 | expected_archive_path = outdir / "psrchive_example.ar.clfd" 90 | assert expected_archive_path.exists() 91 | 92 | expected_report_path = outdir / "psrchive_example_clfd_report.json" 93 | assert expected_report_path.exists() 94 | 95 | expected_plot_path = outdir / "psrchive_example_clfd_report.png" 96 | assert expected_plot_path.exists() 97 | 98 | report = Report.load(expected_report_path) 99 | assert np.array_equal( 100 | report.profile_masking_result.mask, expected_profmask 101 | ) 102 | assert report.spike_finding_result is None 103 | 104 | 105 | @skip_unless_psrchive_installed 106 | def test_cli_app_rejects_bad_feature_names(archive_path: Path): 107 | args = [ 108 | "--features", 109 | "std", 110 | "this_feature_name_does_not_exist", 111 | "ptp", 112 | str(archive_path), 113 | ] 114 | with pytest.raises(SystemExit): 115 | run_program(args) 116 | 117 | 118 | @skip_unless_psrchive_installed 119 | def test_cli_app_rejects_non_existent_output_dir(archive_path: Path): 120 | args = [ 121 | "--outdir", 122 | "/non/existent/path", 123 | str(archive_path), 124 | ] 125 | with pytest.raises(SystemExit): 126 | run_program(args) 127 | 128 | 129 | @skip_unless_psrchive_installed 130 | def test_cli_app_reject_duplicate_input_file_names_if_outdir_is_specified( 131 | tmp_path_factory: pytest.TempPathFactory, 132 | archive_path: Path, 133 | ): 134 | outdir = Path(tmp_path_factory.mktemp("clfd_outdir")) 135 | args = [ 136 | "--outdir", 137 | str(outdir), 138 | str(archive_path), 139 | str(Path("/some/other/path") / archive_path.name), 140 | ] 141 | with pytest.raises(ValueError): 142 | run_program(args) 143 | 144 | 145 | @skip_unless_psrchive_installed 146 | def test_cli_app_does_not_crash_if_archive_cannot_be_processed( 147 | tmp_path_factory: pytest.TempPathFactory, 148 | archive_path: Path, 149 | ): 150 | outdir = Path(tmp_path_factory.mktemp("clfd_outdir")) 151 | args = [ 152 | "--outdir", 153 | str(outdir), 154 | str(archive_path), 155 | "/non/existent/archive.ar", 156 | ] 157 | run_program(args) 158 | -------------------------------------------------------------------------------- /tests/test_profile_masking.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable 2 | 3 | import numpy as np 4 | import pytest 5 | from numpy.typing import NDArray 6 | 7 | from clfd import profile_mask 8 
| from clfd.features import available_features
9 | from clfd.serialization import json_dumps, json_loads
10 | 
11 | from .utils import ndarray_eq
12 | 
13 | 
14 | @pytest.mark.parametrize(
15 |     "feature",
16 |     available_features().values(),
17 |     ids=available_features().keys(),
18 | )
19 | def test_feature_returns_ndarray_with_expected_shape(
20 |     data_cube: NDArray, feature: Callable
21 | ):
22 |     result = feature(data_cube)
23 |     assert isinstance(result, np.ndarray)
24 |     assert result.shape == data_cube.shape[:2]
25 | 
26 | 
27 | def test_profile_masking(data_cube: NDArray, expected_profmask: NDArray):
28 |     result = profile_mask(
29 |         data_cube, features=("std", "ptp", "lfamp"), q=2.0, zap_channels=()
30 |     )
31 |     assert np.array_equal(result.mask, expected_profmask)
32 | 
33 | 
34 | @pytest.mark.parametrize(
35 |     "zap_channels", [[0], [127], [17, 3, 93, 42], range(42, 93)], ids=repr
36 | )
37 | def test_profile_masking_with_zapped_channels(
38 |     data_cube: NDArray, zap_channels: Iterable[int]
39 | ):
40 |     result = profile_mask(data_cube, q=1.0e9, zap_channels=zap_channels)
41 |     assert set(result.zap_channels) == set(zap_channels)
42 |     assert result.zap_channels == sorted(result.zap_channels)
43 |     assert np.all(result.mask[:, zap_channels])
44 | 
45 | 
46 | def test_profile_masking_ignores_out_of_range_zap_channels(data_cube: NDArray):
47 |     result = profile_mask(data_cube, q=1.0e9, zap_channels=range(120, 140))
48 | 
49 |     expected_zap_channels = list(range(120, 128))
50 |     assert result.zap_channels == expected_zap_channels
51 | 
52 |     # Reduce over the subint axis (0): a channel counts as fully masked
53 |     # only if it is flagged in every subint
54 |     is_whole_channel_masked = np.logical_and.reduce(result.mask, axis=0)
55 |     assert np.all(is_whole_channel_masked[120:128])
56 |     assert np.all(np.logical_not(is_whole_channel_masked[0:120]))
57 | 
58 | 
59 | def test_profile_masking_serialization_roundtrip(data_cube: NDArray):
60 |     result = profile_mask(
61 |         data_cube, features=("std", "ptp", "lfamp"), q=2.0, zap_channels=()
62 |     )
63 |     serialized = json_dumps(result)
64 |     deserialized = json_loads(serialized)
65 |     assert ndarray_eq(deserialized, result)
66 | 
67 | 
68 | def test_profile_masking_with_all_channels_zapped_raises_value_error(
69 |     data_cube: NDArray,
70 | ):
71 |     num_chan = data_cube.shape[1]
72 |     expected_msg = "Cannot run profile masking with all channels zapped"
73 |     with pytest.raises(ValueError, match=expected_msg):
74 |         profile_mask(data_cube, zap_channels=range(num_chan))
75 | 
--------------------------------------------------------------------------------
/tests/test_psrchive_processing.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import numpy as np
4 | 
5 | from clfd import ArchiveHandler, find_time_phase_spikes, profile_mask
6 | 
7 | from .utils import skip_unless_psrchive_installed
8 | 
9 | 
10 | @skip_unless_psrchive_installed
11 | def test_profile_masked_archive_is_saved_with_expected_weights(
12 |     archive_path: Path, tmp_path: Path
13 | ):
14 |     handler = ArchiveHandler(archive_path)
15 |     cube = handler.data_cube()
16 | 
17 |     result = profile_mask(
18 |         cube, features=["std", "ptp", "lfamp"], q=2.0, zap_channels=range(10)
19 |     )
20 |     handler.apply_profile_mask(result.mask)
21 |     output_path = tmp_path / "archive.ar"
22 |     handler.save(output_path)
23 | 
24 |     handler = ArchiveHandler(output_path)
25 |     assert np.array_equal(handler._archive.get_weights() == 0.0, result.mask)
26 | 
27 | 
28 | @skip_unless_psrchive_installed
29 | def test_apply_spike_subtraction_plan_replaces_bad_data_as_expected(
30 |     archive_path: Path, tmp_path: Path
31 | ):
32 |     handler = ArchiveHandler(archive_path)
33 |     cube = handler.data_cube()
34 | 
35 |     q = 2.0
36 |     zap_channels = range(10)
37 |     result, plan = find_time_phase_spikes(cube, q=q, zap_channels=zap_channels)
38 |     handler.apply_spike_subtraction_plan(plan)
39 | 
40 |     output_path = tmp_path / "archive.ar"
41 |     handler.save(output_path)
42 | 
43 |     handler = ArchiveHandler(output_path)
44 |     cube = handler.data_cube()
45 | 
46 |     for i, j in zip(*np.where(result.mask)):
47 |         assert np.allclose(
48 |             cube[i, plan.valid_channels, j],
49 |             plan.replacement_values[i, plan.valid_channels, j],
50 |         )
51 | 
--------------------------------------------------------------------------------
/tests/test_spike_finding_and_replacement.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from numpy.typing import NDArray
4 | 
5 | from clfd import find_time_phase_spikes
6 | 
7 | 
8 | def test_find_time_phase_spikes_produces_expected_mask(
9 |     data_cube: NDArray, expected_tpmask: NDArray
10 | ):
11 |     result, __ = find_time_phase_spikes(data_cube, q=2.0, zap_channels=())
12 |     num_subints, __, num_bins = data_cube.shape
13 |     assert result.mask.shape == (num_subints, num_bins)
14 |     assert np.array_equal(result.mask, expected_tpmask)
15 | 
16 | 
17 | def test_find_time_phase_spikes_produces_expected_valid_channels(
18 |     data_cube: NDArray,
19 | ):
20 |     zap_channels = range(10, 42)
21 |     __, plan = find_time_phase_spikes(data_cube, zap_channels=zap_channels)
22 | 
23 |     num_chans = data_cube.shape[1]
24 |     assert not set(plan.valid_channels).intersection(zap_channels)
25 |     assert set(plan.valid_channels).union(set(zap_channels)) == set(
26 |         range(num_chans)
27 |     )
28 | 
29 | 
30 | def test_replaced_spike_data_does_not_get_flagged_again(
31 |     data_cube: NDArray,
32 | ):
33 |     """
34 |     Once bad values have been replaced, calling find_time_phase_spikes()
35 |     again **with the same params** should not flag any previously flagged
36 |     time-phase bin again (NOTE: new time-phase bins may get flagged, though).
37 |     """
38 |     q = 2.0
39 |     zap_channels = range(10, 42)
40 |     result, plan = find_time_phase_spikes(
41 |         data_cube, q=q, zap_channels=zap_channels
42 |     )
43 |     clean_cube = plan.apply(data_cube)
44 |     new_result, __ = find_time_phase_spikes(
45 |         clean_cube, q=q, zap_channels=zap_channels
46 |     )
47 |     assert not np.any(new_result.mask & result.mask)
48 | 
49 | 
50 | def test_find_time_phase_spikes_with_all_channels_zapped_raises_value_error(
51 |     data_cube: NDArray,
52 | ):
53 |     num_chan = data_cube.shape[1]
54 |     expected_msg = "Cannot run spike finding with all channels zapped"
55 |     with pytest.raises(ValueError, match=expected_msg):
56 |         find_time_phase_spikes(data_cube, zap_channels=range(num_chan))
57 | 
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from dataclasses import is_dataclass
3 | from typing import Any
4 | 
5 | import numpy as np
6 | import pytest
7 | 
8 | from clfd.serialization import shallow_asdict
9 | 
10 | 
11 | def has_module_available(module_name: str) -> bool:
12 |     try:
13 |         importlib.import_module(module_name)
14 |         return True
15 |     except ImportError:
16 |         return False
17 | 
18 | 
19 | skip_unless_psrchive_installed = pytest.mark.skipif(
20 |     not has_module_available("psrchive"),
21 |     reason="psrchive python bindings must be installed",
22 | )
23 | 
24 | 
25 | def is_container_like(obj) -> bool:
26 |     return isinstance(obj, (list, tuple, dict)) or is_dataclass(obj)
27 | 
28 | 
29 | def ndarray_eq(a: Any, b: Any) -> bool:
30 |     """
31 |     Semi-general equality test that works on ndarrays and containers of
32 |     ndarrays. This function recursively looks into lists, tuples, dicts
33 |     and dataclasses.
34 |     """
35 |     if not is_container_like(a) and not is_container_like(b):
36 |         if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
37 |             return np.array_equal(a, b)
38 |         return a == b
39 | 
40 |     if type(a) is not type(b):
41 |         return False
42 | 
43 |     if is_dataclass(a):
44 |         a = shallow_asdict(a)
45 |         b = shallow_asdict(b)
46 | 
47 |     if isinstance(a, dict):
48 |         # Compare values by key rather than by iteration order, so that
49 |         # dicts with equal keys in different insertion orders compare equal
50 |         return (a.keys() == b.keys()) and all(
51 |             ndarray_eq(a[key], b[key]) for key in a
52 |         )
53 | 
54 |     if isinstance(a, (list, tuple)):
55 |         return (len(a) == len(b)) and all(
56 |             ndarray_eq(x, y) for x, y in zip(a, b)
57 |         )
58 | 
59 |     raise TypeError(f"Unsupported type: {type(a)}")
60 | 
--------------------------------------------------------------------------------
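A minimal usage sketch of the Python API, not part of the repository itself. The calls and signatures below are the ones exercised in tests/test_profile_masking.py and tests/test_spike_finding_and_replacement.py, applied here to the bundled example cube; treat this as an illustrative sketch rather than canonical documentation.

# Illustrative sketch only; mirrors the calls exercised in the test suite.
import numpy as np

from clfd import find_time_phase_spikes, profile_mask
from clfd.serialization import json_dumps, json_loads

# Data cube of shape (num_subints, num_chans, num_bins)
cube = np.load("example_data/npy_example.npy")

# Flag bad (subint, channel) profiles using three features and a Tukey
# parameter q; result.mask has shape (num_subints, num_chans)
result = profile_mask(
    cube, features=("std", "ptp", "lfamp"), q=2.0, zap_channels=()
)

# Find time-phase spikes and apply the replacement plan; the spike mask
# has shape (num_subints, num_bins)
spike_result, plan = find_time_phase_spikes(cube, q=2.0, zap_channels=())
clean_cube = plan.apply(cube)

# Results round-trip through the custom JSON format
restored = json_loads(json_dumps(result))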