├── pyproject.toml
├── mypy.ini
├── .gitignore
├── .dockerignore
├── test-requirements.txt
├── requirements-autotune.txt
├── .pre-commit-config.yaml
├── katsdpingest
│   ├── test
│   │   ├── __init__.py
│   │   ├── test_utils.py
│   │   ├── test_receiver.py
│   │   ├── test_ingest_session.py
│   │   ├── test_ingest_server.py
│   │   └── test_sigproc.py
│   ├── __init__.py
│   ├── ingest_kernels
│   │   ├── merge_flags.mako
│   │   ├── prepare_flags.mako
│   │   ├── prepare.mako
│   │   ├── compress_weights.mako
│   │   ├── count_flags.mako
│   │   ├── accum.mako
│   │   └── postproc.mako
│   ├── utils.py
│   ├── sender.py
│   ├── ingest_server.py
│   └── receiver.py
├── .flake8
├── requirements.txt
├── jenkins-autotune.sh
├── Dockerfile
├── setup.py
├── scripts
│   ├── ingest_autotune.py
│   ├── autotune_mkimage.py
│   └── ingest.py
├── LICENSE
├── Jenkinsfile
└── test
    └── test_sigproc.py
/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "katversion"] 3 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.6 3 | ignore_missing_imports = True 4 | files = katsdpingest, scripts, test 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[cod] 2 | build 3 | dist 4 | temp 5 | cover 6 | .coverage 7 | doc/_build 8 | *.egg-info 9 | .mypy_cache 10 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | temp 4 | cover 5 | .coverage 6 | doc/_build 7 | *.egg-info 8 | *.pyc 9 | __pycache__ 10 | .git 11 | Dockerfile 12 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | -c https://raw.githubusercontent.com/ska-sa/katsdpdockerbase/master/docker-base-build/base-requirements.txt 2 | 3 | async-timeout 4 | asynctest 5 | coverage 6 | nose 7 | -------------------------------------------------------------------------------- /requirements-autotune.txt: -------------------------------------------------------------------------------- 1 | -c https://raw.githubusercontent.com/ska-sa/katsdpdockerbase/master/docker-base-build/base-requirements.txt 2 | 3 | docker==4.2.0 4 | websocket-client==0.56.0 # via docker 5 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pycqa/flake8 3 | rev: 3.9.2 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/pre-commit/mirrors-mypy 7 | rev: v0.780 8 | hooks: 9 | - id: mypy 10 | args: [] 11 | -------------------------------------------------------------------------------- /katsdpingest/test/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit testing for katsdpingest. 2 | 3 | Currently this primarily tests the sigproc library that is part of 4 | katsdpingest. 
Testing of the capture scripts requires some infrastructure to 5 | run and is probably best done by a human.""" 6 | -------------------------------------------------------------------------------- /katsdpingest/__init__.py: -------------------------------------------------------------------------------- 1 | """Katsdpingest library.""" 2 | 3 | # BEGIN VERSION CHECK 4 | # Get package version when locally imported from repo or via -e develop install 5 | try: 6 | import katversion as _katversion 7 | except ImportError: 8 | import time as _time 9 | __version__ = "0.0+unknown.{}".format(_time.strftime('%Y%m%d%H%M')) 10 | else: 11 | __version__ = _katversion.get_version(__path__[0]) # type: ignore 12 | # END VERSION CHECK 13 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | ignore = 4 | # whitespace before ':' - flake8 doesn't handle slices properly 5 | E203, 6 | # multiple spaces after ':' - sometimes nice to align dict values 7 | E241, 8 | # The rest are a subset of flake8 defaults 9 | # missing whitespace around arithmetic operator - sometimes useful for inner ops 10 | E226, 11 | # line break before binary operator - that's what PEP8 recommends! 12 | W503 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -c https://raw.githubusercontent.com/ska-sa/katsdpdockerbase/master/docker-base-build/base-requirements.txt 2 | -c https://raw.githubusercontent.com/ska-sa/katsdpdockerbase/master/docker-base-gpu-build/requirements.txt 3 | 4 | aiokatcp 5 | hiredis # Speeds up katsdptelstate 6 | numpy 7 | pycuda 8 | spead2 9 | 10 | # TODO: eventually switch to using a release of katdal, once the enhanced 11 | # SpectralWindow class has shipped. 
12 | katdal @ git+https://github.com/ska-sa/katdal 13 | katpoint @ git+https://github.com/ska-sa/katpoint 14 | katsdpmodels[aiohttp] @ git+https://github.com/ska-sa/katsdpmodels 15 | katsdpsigproc @ git+https://github.com/ska-sa/katsdpsigproc 16 | katsdpservices[argparse,aiomonitor] @ git+https://github.com/ska-sa/katsdpservices 17 | katsdptelstate[aio] @ git+https://github.com/ska-sa/katsdptelstate 18 | -------------------------------------------------------------------------------- /jenkins-autotune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | if [ "$#" -ne 1 ]; then 4 | echo "Usage: jenkins-autotune.sh <gpu>" 1>&2 5 | exit 1 6 | fi 7 | GPU="$1" 8 | LABEL="${BRANCH_NAME#origin/}" 9 | if [ "$LABEL" = "master" ]; then 10 | LABEL=latest 11 | fi 12 | IMAGE="$DOCKER_REGISTRY/katsdpingest_$GPU:$LABEL" 13 | BASE_IMAGE="$DOCKER_REGISTRY/katsdpingest:$LABEL" 14 | COPY_FROM="$DOCKER_REGISTRY/katsdpingest_$GPU:latest" 15 | install_pinned.py -r requirements-autotune.txt 16 | docker pull "$BASE_IMAGE" 17 | docker pull "$COPY_FROM" 18 | trap "docker rmi $IMAGE" EXIT 19 | scripts/autotune_mkimage.py -H $DOCKER_HOST --tls --copy --copy-from "$COPY_FROM" "$IMAGE" "$BASE_IMAGE" 20 | docker push "$IMAGE" 21 | 22 | if [ -n "$DOCKER_REGISTRY2" ]; then 23 | IMAGE2="$DOCKER_REGISTRY2/katsdpingest_$GPU:$LABEL" 24 | trap "docker rmi $IMAGE2" EXIT 25 | docker tag "$IMAGE" "$IMAGE2" 26 | docker push "$IMAGE2" 27 | fi 28 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG KATSDPDOCKERBASE_REGISTRY=harbor.sdp.kat.ac.za/dpp 2 | 3 | FROM $KATSDPDOCKERBASE_REGISTRY/docker-base-gpu-build as build 4 | 5 | # Enable Python 3 venv 6 | ENV PATH="$PATH_PYTHON3" VIRTUAL_ENV="$VIRTUAL_ENV_PYTHON3" 7 | 8 | # Install Python dependencies 9 | COPY --chown=kat:kat requirements.txt /tmp/install/requirements.txt 10 | RUN install_pinned.py -r /tmp/install/requirements.txt 11 | 12 | # Install the current package 13 | COPY --chown=kat:kat . /tmp/install/katsdpingest 14 | RUN cd /tmp/install/katsdpingest && \ 15 | python ./setup.py clean && pip install --no-deps . 
&& pip check 16 | 17 | ####################################################################### 18 | 19 | FROM $KATSDPDOCKERBASE_REGISTRY/docker-base-gpu-runtime 20 | LABEL maintainer="sdpdev+katsdpingest@ska.ac.za" 21 | 22 | COPY --chown=kat:kat --from=build /home/kat/ve3 /home/kat/ve3 23 | ENV PATH="$PATH_PYTHON3" VIRTUAL_ENV="$VIRTUAL_ENV_PYTHON3" 24 | 25 | # Allow raw packets (for ibverbs raw QPs) 26 | USER root 27 | RUN setcap cap_net_raw+p /usr/local/bin/capambel 28 | USER kat 29 | 30 | EXPOSE 2040 31 | EXPOSE 7148/udp 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | 5 | tests_require = ['nose', 'asynctest', 'async_timeout', 'katsdpsigproc[test]'] 6 | 7 | 8 | setup( 9 | name="katsdpingest", 10 | description="Karoo Array Telescope Data Capture", 11 | author="MeerKAT SDP Team", 12 | author_email="sdpdev+katsdpingest@ska.ac.za", 13 | packages=find_packages(), 14 | package_data={'': ['ingest_kernels/*.mako']}, 15 | include_package_data=True, 16 | scripts=[ 17 | "scripts/ingest.py", 18 | "scripts/ingest_autotune.py" 19 | ], 20 | setup_requires=['katversion'], 21 | install_requires=[ 22 | 'aiokatcp>=0.7.0', # Need 0.7 for auto_strategy 23 | 'aiomonitor', 24 | 'numpy>=1.13.0', # For np.unique with axis (might really need a higher version) 25 | 'spead2>=3.0.1', 26 | 'katsdpsigproc', 27 | 'katsdpservices[argparse,aiomonitor]', 28 | 'katsdptelstate[aio]', 29 | 'katpoint', 30 | 'katdal', 31 | 'katsdpmodels[aiohttp]' 32 | ], 33 | extras_require={ 34 | 'test': tests_require 35 | }, 36 | tests_require=tests_require, 37 | zip_safe=False, 38 | use_katversion=True 39 | ) 40 | -------------------------------------------------------------------------------- /scripts/ingest_autotune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Makes autotuning happen for common configurations and for all discovered 5 | devices, so that a subsequent run will not have to wait for autotuning. 
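For each device it instantiates CBFIngest.create_proc_template for every tuned
channel count and for all combinations of the excise and continuum options; the
script takes no command-line arguments.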
6 | """ 7 | import logging 8 | import sys 9 | 10 | from katsdpingest import ingest_session 11 | from katsdpsigproc import accel 12 | 13 | 14 | def autotune_device(device): 15 | context = device.make_context() 16 | tune_channels = ingest_session.CBFIngest.tune_channels 17 | tune_percentile_sizes = ingest_session.CBFIngest.tune_percentile_sizes 18 | for channels in tune_channels: 19 | for excise in [False, True]: 20 | for continuum in [False, True]: 21 | ingest_session.CBFIngest.create_proc_template( 22 | context, tune_percentile_sizes, channels, excise, continuum) 23 | 24 | 25 | def main(): 26 | logging.basicConfig(level='INFO') 27 | logging.getLogger('katsdpsigproc.tune').setLevel(logging.INFO) 28 | devices = accel.all_devices() 29 | if not devices: 30 | logging.error('No acceleration devices found') 31 | sys.exit(1) 32 | for device in devices: 33 | autotune_device(device) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /katsdpingest/ingest_kernels/merge_flags.mako: -------------------------------------------------------------------------------- 1 | <%include file="/port.mako"/> 2 | <%namespace name="transpose" file="/transpose_base.mako"/> 3 | 4 | <%transpose:transpose_data_class class_name="transpose_flags" type="uchar" block="${block}" vtx="${vtx}" vty="${vty}"/> 5 | <%transpose:transpose_coords_class class_name="transpose_coords" block="${block}" vtx="${vtx}" vty="${vty}"/> 6 | 7 | KERNEL REQD_WORK_GROUP_SIZE(${block}, ${block}, 1) void merge_flags( 8 | GLOBAL uchar * RESTRICT out_flags, 9 | const GLOBAL uchar * RESTRICT in_flags, 10 | const GLOBAL uchar * RESTRICT baseline_flags, 11 | int out_flags_stride, 12 | int in_flags_stride) 13 | { 14 | LOCAL_DECL transpose_flags local_flags; 15 | transpose_coords coords; 16 | transpose_coords_init_simple(&coords); 17 | 18 | // Load input flags into shared memory 19 | <%transpose:transpose_load coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 20 | int addr = ${r} * in_flags_stride + ${c}; 21 | local_flags.arr[${lr}][${lc}] = in_flags[addr]; 22 | 23 | 24 | BARRIER(); 25 | 26 | // Combine with output and baseline flags 27 | <%transpose:transpose_store coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 28 | int addr = ${r} * out_flags_stride + ${c}; 29 | out_flags[addr] |= local_flags.arr[${lr}][${lc}] | baseline_flags[${r}]; 30 | 31 | } 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011-2019, National Research Foundation (SARAO) 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 20 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 24 | OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /Jenkinsfile: -------------------------------------------------------------------------------- 1 | #!groovy 2 | 3 | @Library('katsdpjenkins') _ 4 | katsdp.killOldJobs() 5 | 6 | katsdp.setDependencies([ 7 | 'ska-sa/katsdpsigproc/master', 8 | 'ska-sa/katsdpdockerbase/master', 9 | 'ska-sa/katsdpservices/master', 10 | 'ska-sa/katsdptelstate/master', 11 | 'ska-sa/katsdpmodels/master', 12 | 'ska-sa/katdal/master', 13 | 'ska-sa/katpoint/master']) 14 | 15 | catchError { 16 | katsdp.stagePrepare(timeout: [time: 60, unit: 'MINUTES']) 17 | katsdp.stageNosetestsGpu(cuda: true, opencl: true) 18 | katsdp.stageFlake8() 19 | katsdp.stageMypy() 20 | katsdp.stageMakeDocker(venv: true) 21 | 22 | stage('katsdpingest/autotuning') { 23 | if (katsdp.notYetFailed()) { 24 | katsdp.simpleNode(label: 'cuda-A30', timeout: [time: 60, unit: 'MINUTES']) { 25 | deleteDir() 26 | katsdp.unpackGit() 27 | katsdp.unpackVenv() 28 | katsdp.unpackKatsdpdockerbase() 29 | withCredentials([usernamePassword( 30 | credentialsId: 'harbor-dpp', 31 | usernameVariable: 'HARBOR_USER', 32 | passwordVariable: 'HARBOR_PASS')]) { 33 | sh 'docker login -u "$HARBOR_USER" -p "$HARBOR_PASS" "harbor.sdp.kat.ac.za"' 34 | } 35 | katsdp.virtualenv('venv') { 36 | dir('git') { 37 | lock("katsdpingest-autotune-${env.BRANCH_NAME}") { 38 | sh './jenkins-autotune.sh a30' 39 | } 40 | } 41 | } 42 | } 43 | } 44 | } 45 | } 46 | katsdp.mail('sdpdev+katsdpingest@ska.ac.za') 47 | -------------------------------------------------------------------------------- /katsdpingest/ingest_kernels/prepare_flags.mako: -------------------------------------------------------------------------------- 1 | <%include file="/port.mako"/> 2 | <%namespace name="transpose" file="/transpose_base.mako"/> 3 | 4 | <%transpose:transpose_data_class class_name="transpose_flags" type="uchar" block="${block}" vtx="${vtx}" vty="${vty}"/> 5 | <%transpose:transpose_coords_class class_name="transpose_coords" block="${block}" vtx="${vtx}" vty="${vty}"/> 6 | 7 | KERNEL REQD_WORK_GROUP_SIZE(${block}, ${block}, 1) void prepare_flags( 8 | GLOBAL uchar * RESTRICT flags, 9 | const GLOBAL float2 * RESTRICT vis, 10 | const GLOBAL uchar * RESTRICT channel_mask, 11 | const GLOBAL uint * RESTRICT channel_mask_idx, 12 | int flags_stride, 13 | int vis_stride, 14 | int channel_mask_stride, 15 | uint max_mask, 16 | uchar zero_flag) 17 | { 18 | LOCAL_DECL transpose_flags local_flags; 19 | transpose_coords coords; 20 | transpose_coords_init_simple(&coords); 21 | 22 | // Compute flags into shared memory 23 | <%transpose:transpose_load coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 24 | int idx = 
min(max_mask, channel_mask_idx[${r}]); 25 | local_flags.arr[${lr}][${lc}] = channel_mask[idx * channel_mask_stride + ${c}]; 26 | </%transpose:transpose_load> 27 | 28 | BARRIER(); 29 | 30 | // Write flags back to global memory in channel-major order 31 | <%transpose:transpose_store coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 32 | int addr = ${r} * flags_stride + ${c}; 33 | uchar f = local_flags.arr[${lr}][${lc}]; 34 | float2 v = vis[${r} * vis_stride + ${c}]; 35 | if (v.x == 0 && v.y == 0) 36 | f |= zero_flag; 37 | flags[addr] = f; 38 | </%transpose:transpose_store> 39 | } 40 | -------------------------------------------------------------------------------- /katsdpingest/ingest_kernels/prepare.mako: -------------------------------------------------------------------------------- 1 | <%include file="/port.mako"/> 2 | <%namespace name="transpose" file="/transpose_base.mako"/> 3 | 4 | <%transpose:transpose_data_class class_name="transpose_values" type="float2" block="${block}" vtx="${vtx}" vty="${vty}"/> 5 | <%transpose:transpose_coords_class class_name="transpose_coords" block="${block}" vtx="${vtx}" vty="${vty}"/> 6 | 7 | KERNEL REQD_WORK_GROUP_SIZE(${block}, ${block}, 1) void prepare( 8 | GLOBAL float2 * RESTRICT vis_out, 9 | const GLOBAL int2 * RESTRICT vis_in, 10 | const GLOBAL short * RESTRICT permutation, 11 | int vis_out_stride, 12 | int vis_in_stride, 13 | int baselines, 14 | float scale) 15 | { 16 | LOCAL_DECL transpose_values values; 17 | transpose_coords coords; 18 | transpose_coords_init_simple(&coords); 19 | 20 | /* Load values into shared memory, applying the type conversion and 21 | * scaling. The input array is padded, so no range checks are needed. 22 | */ 23 | <%transpose:transpose_load coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 24 | int2 in_value = vis_in[${r} * vis_in_stride + ${c}]; 25 | float2 scaled_value; 26 | scaled_value.x = (float) in_value.x * scale; 27 | scaled_value.y = (float) in_value.y * scale; 28 | values.arr[${lr}][${lc}] = scaled_value; 29 | </%transpose:transpose_load> 30 | 31 | BARRIER(); 32 | 33 | /* Write value back to memory, applying baseline permutation. 34 | * Due to the permutation, we now have to check for out-of-range 35 | * baseline, but not channel. 36 | */ 37 | <%transpose:transpose_store coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 38 | if (${r} < baselines) 39 | { 40 | int baseline = permutation[${r}]; 41 | if (baseline >= 0) 42 | { 43 | float2 vis = values.arr[${lr}][${lc}]; 44 | vis_out[baseline * vis_out_stride + ${c}] = vis; 45 | } 46 | } 47 | </%transpose:transpose_store> 48 | } 49 | -------------------------------------------------------------------------------- /katsdpingest/ingest_kernels/compress_weights.mako: -------------------------------------------------------------------------------- 1 | <%include file="/port.mako"/> 2 | <%namespace name="wg_reduce" file="/wg_reduce.mako"/> 3 | 4 | ${wg_reduce.define_scratch('float', wgsx, 'scratch_t', allow_shuffle=True)} 5 | ${wg_reduce.define_function('float', wgsx, 'reduce_max', 'scratch_t', wg_reduce.op_fmax, allow_shuffle=True, broadcast=True)} 6 | 7 | /** 8 | * Produce more compact (approximate) representation of weights. On output, 9 | * each weight is represented as a product of a per-channel float32 and a 10 | * per-channel, per-baseline uint8. 11 | * 12 | * This kernel is modelled on hreduce.mako from katsdpsigproc. A workgroup is 13 | * 2D, with each row of a workgroup handling a complete channel. 
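 *
 * A consumer can (approximately) reconstruct each weight as
 *   weight[channel][baseline] ~= weights_channel[channel] * weights_out[channel][baseline]
 * since each uint8 stores its weight scaled so that the per-channel maximum
 * maps to 255.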
14 | */ 15 | KERNEL REQD_WORK_GROUP_SIZE(${wgsx}, ${wgsy}, 1) void compress_weights( 16 | GLOBAL unsigned char * RESTRICT weights_out, 17 | GLOBAL float * RESTRICT weights_channel, 18 | GLOBAL const float * RESTRICT weights_in, 19 | int weights_out_stride, 20 | int weights_in_stride, 21 | int baselines) 22 | { 23 | LOCAL_DECL scratch_t scratch[${wgsy}]; 24 | /* Find the largest value for each channel */ 25 | int channel = get_global_id(1); 26 | int lid = get_local_id(0); 27 | int in_offset = weights_in_stride * channel; 28 | // Set a small lower bound (2^-96), to avoid divide-by-zero issues if all 29 | // weights are zero. 30 | float max_weight = 1.2621774e-29f; 31 | // Compute a per-workitem value 32 | for (int i = lid; i < baselines; i += ${wgsx}) 33 | max_weight = fmax(max_weight, weights_in[in_offset + i]); 34 | // Reduce the per-workitem values 35 | max_weight = reduce_max(max_weight, lid, &scratch[get_local_id(1)]); 36 | float cweight = max_weight * (1.0f / 255.0f); 37 | if (lid == 0) 38 | weights_channel[channel] = cweight; 39 | 40 | /* Scale weights relative to cweight and convert to int */ 41 | float scale = 1.0f / cweight; 42 | int out_offset = weights_out_stride * channel; 43 | for (int i = lid; i < baselines; i += ${wgsx}) 44 | { 45 | float weight = weights_in[in_offset + i]; 46 | weight = weight * scale + 0.5f; 47 | weights_out[out_offset + i] = (unsigned char) weight; 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /katsdpingest/ingest_kernels/count_flags.mako: -------------------------------------------------------------------------------- 1 | <%include file="/port.mako"/> 2 | <%namespace name="wg_reduce" file="/wg_reduce.mako"/> 3 | 4 | ${wg_reduce.define_scratch('unsigned int', wgs, 'scratch_t', allow_shuffle=True)} 5 | ${wg_reduce.define_function('unsigned int', wgs, 'reduce', 'scratch_t', wg_reduce.op_plus, allow_shuffle=True, broadcast=False)} 6 | 7 | #define WGS ${wgs} 8 | #define BITS 8 9 | 10 | /** 11 | * Count number of visibilities with each flag bit, per baseline. 12 | * 13 | * This implementation is far from optimal, but it's also not massively 14 | * performance-critical. Some ideas for future optimisation if necessary: 15 | * 16 | * - Have each workitem load 32 bits at a time instead of 8 (has some 17 | * complications if the number of channels is odd). 18 | * - Have each workgroup handle several baselines and fewer channels. It would 19 | * lower reduction costs, but could also reduce parallelism. 20 | * - Have each workgroup handle only some of the channels, and do a final 21 | * CPU-side reduction (more parallelism). 22 | * - Do the reduction on all the counts jointly, instead of one at a time. 23 | * 24 | * @param[out] counts Output counts, shape (baselines, 8), contiguous 25 | * @param[out] any_counts Count of visibilities with any flag, per baseline 26 | * @param flags Per-visibility input flags, shape (baselines, flags_stride) 27 | * @param flags_stride Stride for @a flags 28 | * @param channels Number of channels over which to do count 29 | * @param channel_start Offset to first channel to count in @a flags 30 | * @param mask Mask ANDed with the flags before counting (used to eliminate the 31 | * pseudo-flag used to mark unflagged data). 
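 *
 * As an illustration: a visibility whose masked flag byte is 0b00000101
 * increments counts[baseline * 8 + 0] and counts[baseline * 8 + 2], and
 * contributes 1 to any_counts[baseline].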
32 | */ 33 | KERNEL REQD_WORK_GROUP_SIZE(WGS, 1, 1) void count_flags( 34 | GLOBAL unsigned int * RESTRICT counts, 35 | GLOBAL unsigned int * RESTRICT any_counts, 36 | const GLOBAL unsigned char * RESTRICT flags, 37 | int flags_stride, 38 | int channels, 39 | int channel_start, 40 | unsigned char mask) 41 | { 42 | LOCAL_DECL scratch_t scratch; 43 | 44 | int lid = get_local_id(0); 45 | int baseline = get_global_id(1); 46 | // Adjust pointer to start of current baseline 47 | flags += baseline * flags_stride + channel_start; 48 | unsigned int sums[BITS] = {}; 49 | unsigned int any = 0; 50 | for (int i = lid; i < channels; i += WGS) 51 | { 52 | unsigned char flag = flags[i] & mask; 53 | any += (flag != 0); 54 | for (int j = 0; j < 8; j++) 55 | { 56 | sums[j] += (flag & 1); 57 | flag >>= 1; 58 | } 59 | } 60 | 61 | // Accumulate across workitems 62 | for (int i = 0; i < BITS; i++) 63 | sums[i] = reduce(sums[i], lid, &scratch); 64 | any = reduce(any, lid, &scratch); 65 | 66 | // Write results 67 | if (lid == 0) 68 | { 69 | counts += baseline * BITS; 70 | for (int i = 0; i < BITS; i++) 71 | counts[i] += sums[i]; 72 | any_counts[baseline] += any; 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /katsdpingest/ingest_kernels/accum.mako: -------------------------------------------------------------------------------- 1 | <%include file="/port.mako"/> 2 | <%namespace name="transpose" file="/transpose_base.mako"/> 3 | 4 | <%transpose:transpose_data_class class_name="transpose_vis" type="float2" block="${block}" vtx="${vtx}" vty="${vty}"/> 5 | <%transpose:transpose_data_class class_name="transpose_weights" type="float" block="${block}" vtx="${vtx}" vty="${vty}"/> 6 | <%transpose:transpose_data_class class_name="transpose_flags" type="unsigned char" block="${block}" vtx="${vtx}" vty="${vty}"/> 7 | <%transpose:transpose_coords_class class_name="transpose_coords" block="${block}" vtx="${vtx}" vty="${vty}"/> 8 | 9 | DEVICE_FN void accum_vis(GLOBAL float2 *out, float2 value, float weight) 10 | { 11 | float2 sum = *out; 12 | sum.x = fma(value.x, weight, sum.x); 13 | sum.y = fma(value.y, weight, sum.y); 14 | *out = sum; 15 | } 16 | 17 | /* 18 | * in_full_stride is for in_vis and in_flags (which are indexed from channel_start), 19 | * while in_kept_stride is for in_weights (which is indexed from 0). 
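 * Concretely, for transposed row r and column c, vis and flags are read at
 * r * in_full_stride + c + channel_start, while weights are read at
 * r * in_kept_stride + c.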
20 | */ 21 | KERNEL REQD_WORK_GROUP_SIZE(${block}, ${block}, 1) void accum( 22 | % for i in range(outputs): 23 | GLOBAL float2 * RESTRICT out_vis${i}, 24 | GLOBAL float * RESTRICT out_weights${i}, 25 | GLOBAL unsigned char * RESTRICT out_flags${i}, 26 | % endfor 27 | const GLOBAL float2 * RESTRICT in_vis, 28 | const GLOBAL float * RESTRICT in_weights, 29 | const GLOBAL unsigned char * RESTRICT in_flags, 30 | int out_stride, 31 | int in_full_stride, 32 | int in_kept_stride, 33 | int channel_start) 34 | { 35 | LOCAL_DECL transpose_vis local_vis; 36 | LOCAL_DECL transpose_weights local_weights; 37 | LOCAL_DECL transpose_flags local_flags; 38 | transpose_coords coords; 39 | 40 | transpose_coords_init_simple(&coords); 41 | 42 | // Load a block of data, for all channels 43 | <%transpose:transpose_load coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 44 | int full_addr = ${r} * in_full_stride + ${c} + channel_start; 45 | int kept_addr = ${r} * in_kept_stride + ${c}; 46 | local_vis.arr[${lr}][${lc}] = in_vis[full_addr]; 47 | local_weights.arr[${lr}][${lc}] = in_weights[kept_addr]; 48 | local_flags.arr[${lr}][${lc}] = in_flags[full_addr]; 49 | </%transpose:transpose_load> 50 | 51 | BARRIER(); 52 | 53 | // Apply flags to weights, and do weighted accumulation 54 | <%transpose:transpose_store coords="coords" block="${block}" vtx="${vtx}" vty="${vty}" args="r, c, lr, lc"> 55 | float2 vis = local_vis.arr[${lr}][${lc}]; 56 | float weight = local_weights.arr[${lr}][${lc}]; 57 | unsigned int flag = local_flags.arr[${lr}][${lc}]; 58 | % if excise: 59 | if (flag != 0) 60 | weight *= 5.42101086e-20f; // 2^-64 61 | else 62 | flag = ${unflagged_bit}; 63 | % endif 64 | int addr = ${r} * out_stride + ${c}; 65 | % for i in range(outputs): 66 | accum_vis(&out_vis${i}[addr], vis, weight); 67 | out_weights${i}[addr] += weight; 68 | out_flags${i}[addr] |= flag; 69 | % endfor 70 | </%transpose:transpose_store> 71 | } 72 | -------------------------------------------------------------------------------- /katsdpingest/ingest_kernels/postproc.mako: -------------------------------------------------------------------------------- 1 | <%include file="/port.mako"/> 2 | 3 | /* If there was partial excision and the sum of the non-excised visibilities is 4 | * zero, the visibility will come out infinitesimal rather than zero due to the 5 | * 2^-64 scaling trick. If it's an auto-correlation, it may get inverted later 6 | * to compute statistical weights, which leads to numerical issues, so we flush 7 | * it to zero. 8 | * 9 | * The minimum non-zero absolute value for the weighted sum is 1/n_accs, while 10 | * the largest possible spurious value is 2^-33 * (m-1)/n_accs, where m is the 11 | * number of input dumps added together (potentially somewhat large for 12 | * continuum) and n_accs is the number of accumulations in the correlator. 13 | * If one needs to support a very large range of n_accs then it should be 14 | * an extra parameter (or scaling by n_accs should be delayed until this 15 | * stage), but 2e-9 should be safe for all reasonable cases for now. 16 | */ 17 | DEVICE_FN float flush_zero(float x) 18 | { 19 | return fabsf(x) < 2e-9f ? 
0.0f : x; 20 | } 21 | 22 | DEVICE_FN float2 flush_zero2(float2 vis) 23 | { 24 | return make_float2(flush_zero(vis.x), flush_zero(vis.y)); 25 | } 26 | 27 | KERNEL REQD_WORK_GROUP_SIZE(${wgsx}, ${wgsy}, 1) void postproc( 28 | GLOBAL float2 * RESTRICT vis, 29 | GLOBAL float * RESTRICT weights, 30 | GLOBAL unsigned char * RESTRICT flags, 31 | % if continuum: 32 | GLOBAL float2 * RESTRICT cont_vis, 33 | GLOBAL float * RESTRICT cont_weights, 34 | GLOBAL unsigned char * RESTRICT cont_flags, 35 | int cont_factor, 36 | % endif 37 | int stride) 38 | { 39 | % if not continuum: 40 | const int cont_factor = 1; 41 | % endif 42 | int baseline = get_global_id(0); 43 | int cont_channel = get_global_id(1); 44 | int channel0 = cont_channel * cont_factor; 45 | 46 | % if continuum: 47 | float2 cv; 48 | cv.x = 0.0f; 49 | cv.y = 0.0f; 50 | float cw = 0.0f; 51 | unsigned char cf = 0; 52 | % endif 53 | #pragma unroll 4 54 | for (int i = 0; i < cont_factor; i++) 55 | { 56 | int channel = channel0 + i; 57 | int addr = channel * stride + baseline; 58 | GLOBAL float2 *vptr = &vis[addr]; 59 | float2 v = *vptr; 60 | GLOBAL float *wptr = &weights[addr]; 61 | float w = *wptr; 62 | GLOBAL unsigned char *fptr = &flags[addr]; 63 | unsigned char f = *fptr; 64 | % if continuum: 65 | cv.x += v.x; 66 | cv.y += v.y; 67 | cw += w; 68 | cf |= f; 69 | % endif 70 | float scale = 1.0f / w; 71 | % if excise: 72 | if (!(f & ${unflagged_bit})) 73 | *wptr = 1.8446744e19f * w; // scale by 2^64, to compensate for previous 2^-64 74 | else 75 | { 76 | *fptr = 0; 77 | v = flush_zero2(v); 78 | } 79 | % endif 80 | v.x *= scale; 81 | v.y *= scale; 82 | *vptr = v; 83 | } 84 | 85 | % if continuum: 86 | float scale = 1.0 / cw; 87 | cv.x *= scale; 88 | cv.y *= scale; 89 | % if excise: 90 | if (!(cf & ${unflagged_bit})) 91 | cw *= 1.8446744e19; // scale by 2^64, to compensate for previous 2^-64 92 | else 93 | { 94 | cf = 0; 95 | cv = flush_zero2(cv); 96 | } 97 | % endif 98 | int cont_addr = cont_channel * stride + baseline; 99 | cont_vis[cont_addr] = cv; 100 | cont_weights[cont_addr] = cw; 101 | cont_flags[cont_addr] = cf; 102 | % endif 103 | } 104 | -------------------------------------------------------------------------------- /katsdpingest/test/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for the utils module.""" 2 | from katsdpingest.utils import Range 3 | from nose.tools import (assert_equal, assert_raises, 4 | assert_true, assert_false, assert_in, assert_not_in) 5 | 6 | 7 | class TestRange: 8 | """Tests for :class:`katsdpingest.sigproc.Range`.""" 9 | def test_init(self): 10 | r = Range(3, 5) 11 | assert_equal(3, r.start) 12 | assert_equal(5, r.stop) 13 | r = Range(10, 10) 14 | assert_equal(10, r.start) 15 | assert_equal(10, r.stop) 16 | 17 | def test_str(self): 18 | assert_equal('3:5', str(Range(3, 5))) 19 | assert_equal('2:2', str(Range(2, 2))) 20 | 21 | def test_repr(self): 22 | assert_equal('Range(3, 5)', repr(Range(3, 5))) 23 | assert_equal('Range(10, 10)', repr(Range(10, 10))) 24 | 25 | def test_contains(self): 26 | r = Range(3, 5) 27 | assert_in(3, r) 28 | assert_in(4, r) 29 | assert_not_in(2, r) 30 | assert_not_in(5, r) 31 | 32 | def test_issubset(self): 33 | assert_true(Range(3, 5).issubset(Range(3, 5))) 34 | assert_true(Range(3, 5).issubset(Range(0, 10))) 35 | # Empty range is always a subset, even if the start value is outside 36 | assert_true(Range(6, 6).issubset(Range(3, 4))) 37 | # Disjoint 38 | assert_false(Range(1, 3).issubset(Range(6, 8))) 39 | assert_false(Range(6, 
8).issubset(Range(1, 3))) 40 | # Partial overlap 41 | assert_false(Range(3, 8).issubset(Range(6, 10))) 42 | assert_false(Range(6, 10).issubset(Range(3, 8))) 43 | # Superset 44 | assert_false(Range(0, 10).issubset(Range(1, 10))) 45 | 46 | def test_issuperset(self): 47 | # It's implemented by issubset, so just a quick test for coverage 48 | assert_true(Range(3, 5).issuperset(Range(3, 5))) 49 | assert_false(Range(3, 5).issuperset(Range(3, 6))) 50 | 51 | def test_isaligned(self): 52 | assert_true(Range(2, 12).isaligned(2)) 53 | assert_false(Range(2, 12).isaligned(4)) 54 | assert_false(Range(5, 11).isaligned(5)) 55 | 56 | def test_alignto(self): 57 | assert_equal(Range(-10, 15), Range(-8, 13).alignto(5)) 58 | # Empty range case 59 | assert_equal(0, len(Range(9, 9).alignto(5))) 60 | 61 | def test_intersection(self): 62 | assert_equal(Range(3, 7), Range(-5, 7).intersection(Range(3, 10))) 63 | assert_equal(0, len(Range(3, 7).intersection(Range(7, 10)))) 64 | 65 | def test_union(self): 66 | # Overlapping 67 | assert_equal(Range(-5, 10), Range(-5, 7).union(Range(3, 10))) 68 | # Disjoint 69 | assert_equal(Range(-5, 10), Range(8, 10).union(Range(-5, 0))) 70 | # First one empty 71 | assert_equal(Range(-5, 10), Range(100, 100).union(Range(-5, 10))) 72 | # Second one empty 73 | assert_equal(Range(-5, 10), Range(-5, 10).union(Range(-10, -10))) 74 | # Both empty 75 | assert_equal(0, len(Range(5, 5).union(Range(10, 10)))) 76 | 77 | def test_len(self): 78 | assert_equal(4, len(Range(3, 7))) 79 | assert_equal(0, len(Range(2, 2))) 80 | 81 | def test_nonzero(self): 82 | assert_true(Range(3, 4)) 83 | assert_false(Range(3, 3)) 84 | 85 | def test_iter(self): 86 | assert_equal([3, 4, 5], list(Range(3, 6))) 87 | 88 | def test_relative_to(self): 89 | assert_equal(Range(3, 5), Range(8, 10).relative_to(Range(5, 10))) 90 | assert_raises(ValueError, Range(8, 10).relative_to, Range(5, 9)) 91 | 92 | def test_split(self): 93 | assert_equal(Range(14, 16), Range(10, 20).split(5, 2)) 94 | assert_equal(Range(0, 0), Range(10, 10).split(5, 2)) 95 | assert_raises(ValueError, Range(10, 20).split, 6, 3) 96 | assert_raises(ValueError, Range(10, 20).split, 5, -2) 97 | assert_raises(ValueError, Range(10, 20).split, 5, 5) 98 | -------------------------------------------------------------------------------- /scripts/autotune_mkimage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Create a derived image from the katsdpingest base image that contains 3 | autotuning results. 
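A typical invocation (mirroring jenkins-autotune.sh; the image names are
placeholders) is:

    autotune_mkimage.py -H $DOCKER_HOST --tls --copy \
        --copy-from registry/katsdpingest_gpu:latest \
        registry/katsdpingest_gpu:label registry/katsdpingest:label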
4 | """ 5 | 6 | import argparse 7 | import sys 8 | import os 9 | import os.path 10 | import tempfile 11 | import shutil 12 | import tarfile 13 | import io 14 | import contextlib 15 | from textwrap import dedent 16 | 17 | import docker 18 | from docker import APIClient 19 | 20 | 21 | DOCKERFILE = dedent('''\ 22 | FROM {base} 23 | COPY --chown=kat:kat tuning.db /home/kat/.cache/katsdpsigproc/ 24 | ''') 25 | 26 | 27 | def get_cache(cli, container_id): 28 | data, _ = cli.get_archive(container_id, '/home/kat/.cache/katsdpsigproc/tuning.db') 29 | tardata = b''.join(data) 30 | return tardata 31 | 32 | 33 | def untar_cache(tardata): 34 | with tarfile.open(fileobj=io.BytesIO(tardata)) as tar: 35 | with contextlib.closing(tar.extractfile('tuning.db')) as f: 36 | return f.read() 37 | 38 | 39 | def tune(cli, base_image, skip, init_tar=None): 40 | """Run a throwaway container to do the autotuning, and extract the result.""" 41 | command = ['ingest_autotune.py'] if not skip else ['/bin/true'] 42 | if init_tar is not None: 43 | command = ['sh', '-c', 44 | 'mkdir -p $HOME/.cache/katsdpsigproc && ' 45 | 'cp /tmp/tuning.db $HOME/.cache/katsdpsigproc && ' + command[0]] 46 | # If we're running inside a Docker container, expose the same devices 47 | # to our child container. 48 | environment = { 49 | 'NVIDIA_VISIBLE_DEVICES': os.environ.get('NVIDIA_VISIBLE_DEVICES', 'all') 50 | } 51 | container = cli.create_container( 52 | image=base_image, 53 | command=command, 54 | environment=environment, 55 | runtime='nvidia') 56 | try: 57 | if container['Warnings']: 58 | print(container['Warnings'], file=sys.stderr) 59 | container_id = container['Id'] 60 | if init_tar is not None: 61 | cli.put_archive(container_id, '/tmp', init_tar) 62 | cli.start(container_id) 63 | try: 64 | for line in cli.logs(container_id, True, True, True): 65 | sys.stdout.buffer.write(line) 66 | result = cli.wait(container_id) 67 | except (Exception, KeyboardInterrupt): 68 | cli.stop(container_id, timeout=2) 69 | raise 70 | if result['StatusCode'] == 0: 71 | return get_cache(cli, container_id) 72 | else: 73 | msg = 'Autotuning failed with status {0[Error]} ({0[StatusCode]})'.format(result) 74 | raise RuntimeError(msg) 75 | finally: 76 | cli.remove_container(container_id) 77 | 78 | 79 | def build(cli, image, base_image, tuning): 80 | tmpdir = tempfile.mkdtemp() 81 | try: 82 | with open(os.path.join(tmpdir, 'Dockerfile'), 'w') as f: 83 | f.write(DOCKERFILE.format(base=base_image)) 84 | with open(os.path.join(tmpdir, 'tuning.db'), 'wb') as f: 85 | f.write(tuning) 86 | for line in cli.build(path=tmpdir, rm=True, tag=image, decode=True): 87 | if 'stream' in line: 88 | sys.stdout.write(line['stream']) 89 | finally: 90 | shutil.rmtree(tmpdir) 91 | 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('image') 96 | parser.add_argument('base_image') 97 | parser.add_argument( 98 | '--copy', action='store_true', 99 | help='Copy old autotuning results from existing image') 100 | parser.add_argument( 101 | '--copy-from', type=str, metavar='IMAGE', 102 | help='Specify alternative image from which to obtain existing results (implies --copy)') 103 | parser.add_argument( 104 | '--skip', action='store_true', 105 | help='Only copy, do not run tuning check afterwards') 106 | parser.add_argument( 107 | '--host', '-H', type=str, default='unix:///var/run/docker.sock', 108 | help='Docker host') 109 | parser.add_argument( 110 | '--tls', action='store_true', 111 | help='Use TLS to connect to Docker daemon') 112 | args = parser.parse_args() 
113 | if args.skip and not args.copy and args.copy_from is None: 114 | parser.error('Cannot use --skip without --copy or --copy-from') 115 | 116 | if args.tls: 117 | tls_config = docker.tls.TLSConfig( 118 | client_cert=(os.path.expanduser('~/.docker/cert.pem'), 119 | os.path.expanduser('~/.docker/key.pem')), 120 | verify=os.path.expanduser('~/.docker/ca.pem')) 121 | cli = APIClient(args.host, tls=tls_config) 122 | else: 123 | cli = APIClient(args.host) 124 | 125 | if args.copy_from is not None: 126 | copy_base = args.copy_from 127 | elif args.copy: 128 | copy_base = args.image 129 | else: 130 | copy_base = None 131 | 132 | init_tar = tune(cli, copy_base, True) if copy_base is not None else None 133 | tuned = untar_cache(tune(cli, args.base_image, args.skip, init_tar)) 134 | build(cli, args.image, args.base_image, tuned) 135 | 136 | 137 | if __name__ == '__main__': 138 | sys.exit(main()) 139 | -------------------------------------------------------------------------------- /katsdpingest/utils.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous ingest utilities""" 2 | 3 | import logging 4 | from typing import TypeVar, Tuple 5 | 6 | import katsdptelstate.aio 7 | import aiokatcp 8 | 9 | 10 | _logger = logging.getLogger(__name__) 11 | 12 | 13 | async def cbf_telstate_view(telstate: katsdptelstate.aio.TelescopeState, 14 | stream_name: str) -> katsdptelstate.aio.TelescopeState: 15 | """Create a telstate view that allows querying properties from a stream. 16 | It supports only baseline-correlation-products and 17 | tied-array-channelised-voltage streams. Properties that don't exist on the 18 | stream are searched on the upstream antenna-channelised-voltage stream, 19 | the instrument of that stream, and finally the 'cbf' namespace. 20 | 21 | Returns 22 | ------- 23 | view 24 | Telstate view that allows stream properties to be searched 25 | """ 26 | prefixes = [] 27 | stream_name = stream_name.replace('.', '_').replace('-', '_') 28 | prefixes.append(stream_name) 29 | # Generate a list of places to look for attributes: 30 | # - the stream itself 31 | # - the upstream antenna-channelised-voltage stream, and its instrument 32 | src = (await telstate.view(stream_name, exclusive=True)['src_streams'])[0] 33 | prefixes.append(src) 34 | instrument = await telstate.view(src, exclusive=True)['instrument_dev_name'] 35 | prefixes.append(instrument) 36 | prefixes.append('cbf') 37 | # Create a telstate view that has exactly the given prefixes (and no root prefix). 38 | for i, prefix in enumerate(reversed(prefixes)): 39 | telstate = telstate.view(prefix, exclusive=(i == 0)) 40 | return telstate 41 | 42 | 43 | class Range: 44 | """Representation of a range of values, as specified by a first and a 45 | past-the-end value. This can be seen as an extended form of `range` or 46 | `slice` (although without support for a non-unit step), where it is easy to 47 | query the start and stop values, along with other convenience methods. 48 | 49 | Ranges can be empty, in which case they still have a `start` and `stop` 50 | value that are equal, but the value itself is irrelevant. 51 | """ 52 | def __init__(self, start: int, stop: int) -> None: 53 | if start > stop: 54 | raise ValueError('start must be <= stop') 55 | self.start = start 56 | self.stop = stop 57 | 58 | @classmethod 59 | def parse(cls, value: str) -> 'Range': 60 | """Convert a string of the form 'A:B' to a :class:`~katsdpingest.utils.Range`, 61 | where A and B are integers. 
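For example:

>>> Range.parse('3:5')
Range(3, 5)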
62 | 63 | This is suitable as an argparse type converter. 64 | """ 65 | fields = value.split(':', 1) 66 | if len(fields) != 2: 67 | raise ValueError('Invalid range format {}'.format(value)) 68 | else: 69 | return Range(int(fields[0]), int(fields[1])) 70 | 71 | def __str__(self) -> str: 72 | return '{}:{}'.format(self.start, self.stop) 73 | 74 | def __repr__(self) -> str: 75 | return 'Range({}, {})'.format(self.start, self.stop) 76 | 77 | def __len__(self) -> int: 78 | return self.stop - self.start 79 | 80 | def __contains__(self, value) -> bool: 81 | return self.start <= value < self.stop 82 | 83 | def __eq__(self, other) -> bool: 84 | if not isinstance(other, Range): 85 | return False 86 | if not self: 87 | return not other 88 | else: 89 | return self.start == other.start and self.stop == other.stop 90 | 91 | def __ne__(self, other) -> bool: 92 | return not (self == other) 93 | 94 | # Can't prevent object from being mutated, but __eq__ is defined, so not 95 | # suitable for hashing. 96 | __hash__ = None # type: ignore # keep mypy happy 97 | 98 | def issubset(self, other) -> bool: 99 | return self.start == self.stop or (other.start <= self.start and self.stop <= other.stop) 100 | 101 | def issuperset(self, other) -> bool: 102 | return other.issubset(self) 103 | 104 | def isaligned(self, alignment) -> bool: 105 | """Whether the start and end of this interval are aligned to multiples 106 | of `alignment`. 107 | """ 108 | return not self or (self.start % alignment == 0 and self.stop % alignment == 0) 109 | 110 | def alignto(self, alignment) -> 'Range': 111 | """Return the smallest range containing self for which 112 | ``r.isaligned()`` is true. 113 | """ 114 | if not self: 115 | return self 116 | else: 117 | return Range(self.start // alignment * alignment, 118 | (self.stop + alignment - 1) // alignment * alignment) 119 | 120 | def intersection(self, other) -> 'Range': 121 | start = max(self.start, other.start) 122 | stop = min(self.stop, other.stop) 123 | if start > stop: 124 | return Range(0, 0) 125 | else: 126 | return Range(start, stop) 127 | 128 | def union(self, other) -> 'Range': 129 | """Return the smallest range containing both ranges.""" 130 | if not self: 131 | return other 132 | if not other: 133 | return self 134 | return Range(min(self.start, other.start), max(self.stop, other.stop)) 135 | 136 | def __iter__(self): 137 | return iter(range(self.start, self.stop)) 138 | 139 | def relative_to(self, other) -> 'Range': 140 | """Return a new range that represents `self` as a range relative to 141 | `other` (i.e. where the start element of `other` is numbered 0). If 142 | `self` is an empty range, an undefined empty range is returned. 143 | 144 | Raises 145 | ------ 146 | ValueError 147 | if `self` is not a subset of `other` 148 | """ 149 | if not self.issubset(other): 150 | raise ValueError('self is not a subset of other') 151 | return Range(self.start - other.start, self.stop - other.start) 152 | 153 | def asslice(self) -> slice: 154 | """Return a slice object representing the same range""" 155 | return slice(self.start, self.stop) 156 | 157 | def astuple(self) -> Tuple[int, int]: 158 | """Return a tuple containing the start and end values""" 159 | return (self.start, self.stop) 160 | 161 | def split(self, chunks: int, chunk_id: int) -> 'Range': 162 | """Return the `chunk_id`-th of `chunks` equally-sized pieces. 163 | 164 | Raises 165 | ------ 166 | ValueError 167 | if chunk_id is not in the range [0, chunks) or the range does not 168 | divide evenly. 
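For example:

>>> Range(10, 20).split(5, 2)
Range(14, 16)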
169 | """ 170 | if not 0 <= chunk_id < chunks: 171 | raise ValueError('chunk_id is out of range') 172 | if len(self) % chunks != 0: 173 | raise ValueError('range {} does not divide into {} chunks'.format(self, chunks)) 174 | chunk_size = len(self) // chunks 175 | return Range(self.start + chunk_id * chunk_size, 176 | self.start + (chunk_id + 1) * chunk_size) 177 | 178 | 179 | _T = TypeVar('_T') 180 | 181 | 182 | class Sensor(aiokatcp.Sensor[_T]): 183 | """Wrapper around :class:`aiokatcp.Sensor` that suppresses redundant updates. 184 | 185 | A value update that doesn't change the value or the status is hidden from 186 | the base class, so that observers using the AUTO strategy don't see 187 | redundant updates. 188 | 189 | It also provides a helper function to increment the value and set an 190 | explicit timestamp. 191 | """ 192 | def set_value(self, value: _T, status: aiokatcp.Sensor.Status = None, 193 | timestamp: float = None) -> None: 194 | if status is None: 195 | status = self.status_func(value) 196 | if value != self.value or status != self.status: 197 | super().set_value(value, status, timestamp) 198 | 199 | def increment(self, delta, timestamp=None): 200 | self.set_value(self.value + delta, timestamp=timestamp) 201 | 202 | 203 | __all__ = ['cbf_telstate_view', 'Range', 'Sensor'] 204 | -------------------------------------------------------------------------------- /test/test_sigproc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Simple test and benchmark of the GPU ingest components""" 4 | 5 | import katsdpsigproc.rfi.device as rfi 6 | import katsdpsigproc.accel as accel 7 | import katsdpingest.sigproc as sp 8 | from katsdpingest.utils import Range 9 | import numpy as np 10 | import argparse 11 | 12 | 13 | def generate_data(vis_in_device, channels, baselines): 14 | vis_in = vis_in_device.empty_like() 15 | vis_in[:] = np.random.normal(scale=32.0, size=(channels, baselines, 2)).astype(np.int32) 16 | return vis_in 17 | 18 | 19 | def create_flagger(context, args): 20 | background = rfi.BackgroundMedianFilterDeviceTemplate( 21 | context, args.width) 22 | noise_est = rfi.NoiseEstMADTDeviceTemplate( 23 | context, args.channels + args.border) 24 | threshold = rfi.ThresholdSumDeviceTemplate(context) 25 | return rfi.FlaggerDeviceTemplate(background, noise_est, threshold) 26 | 27 | 28 | def create_percentile_ranges(antennas): 29 | n_cross = antennas * (antennas - 1) // 2 30 | sections = [ 31 | antennas, # autohh 32 | antennas, # autovv 33 | 2 * antennas, # autohv (each appears as hv and vh) 34 | n_cross, # crosshh 35 | n_cross, # crossvv 36 | 2 * n_cross # crosshv 37 | ] 38 | cuts = np.cumsum([0] + sections) 39 | return [ 40 | (cuts[0], cuts[2]), # autohhvv 41 | (cuts[0], cuts[1]), # autohh 42 | (cuts[1], cuts[2]), # autovv 43 | (cuts[2], cuts[3]), # autohv 44 | (cuts[3], cuts[5]), # crosshhvv 45 | (cuts[3], cuts[4]), # crosshh 46 | (cuts[4], cuts[5]), # crossvv 47 | (cuts[5], cuts[6]) # crosshv 48 | ] 49 | 50 | 51 | def create_template(context, args): 52 | percentile_ranges = create_percentile_ranges(args.mask_antennas) 53 | percentile_sizes = list(set([x[1] - x[0] for x in percentile_ranges])) 54 | return sp.IngestTemplate(context, create_flagger(context, args), percentile_sizes, args.excise) 55 | 56 | 57 | def main(): 58 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 59 | 60 | parser.add_argument_group('Data selection') 61 | parser.add_argument('--antennas', '-a', 
type=int, default=7, 62 | help='total number of antennas',) 63 | parser.add_argument('--mask-antennas', type=int, default=None, 64 | help='number of antennas in antenna mask') 65 | parser.add_argument('--channels', '-c', type=int, default=1024, 66 | help='number of frequency channels') 67 | parser.add_argument('--border', '-B', type=int, default=0, help='extra overlap channels') 68 | 69 | parser.add_argument_group('Parameters') 70 | parser.add_argument('--time-avg', '-T', type=int, default=4, 71 | help='number of input dumps per output dump') 72 | parser.add_argument('--freq-avg', '-F', type=int, default=16, 73 | help='number of input channels per continuum channel') 74 | parser.add_argument('--sd-time-avg', type=int, default=4, 75 | help='number of input dumps per signal display dump') 76 | parser.add_argument('--sd-freq-avg', type=int, default=128, 77 | help='number of input channels for signal display channel') 78 | parser.add_argument('--width', '-w', type=int, default=13, 79 | help='median filter kernel size (must be odd)') 80 | parser.add_argument('--sigmas', type=float, default=11.0, 81 | help='threshold for detecting RFI') 82 | parser.add_argument('--no-excise', dest='excise', action='store_false', 83 | help='disable excision of flagged data') 84 | parser.add_argument('--repeat', '-r', type=int, default=8, 85 | help='number of dumps to process') 86 | parser.add_argument('--no-transfer', '-N', action='store_true', 87 | help='skip data transfers') 88 | 89 | args = parser.parse_args() 90 | channels = args.channels + args.border 91 | channel_range = Range(args.border // 2, args.channels + args.border // 2) 92 | cbf_baselines = args.antennas * (args.antennas + 1) * 2 93 | if args.mask_antennas is None: 94 | args.mask_antennas = args.antennas 95 | baselines = args.mask_antennas * (args.mask_antennas + 1) * 2 96 | 97 | context = accel.create_some_context(True) 98 | command_queue = context.create_command_queue(profile=True) 99 | template = create_template(context, args) 100 | proc = template.instantiate( 101 | command_queue, channels, channel_range, 2 * args.mask_antennas, 102 | cbf_baselines, baselines, 103 | args.freq_avg, args.sd_freq_avg, create_percentile_ranges(args.mask_antennas), 104 | threshold_args={'n_sigma': args.sigmas}) 105 | print("{0} bytes required".format(proc.required_bytes())) 106 | proc.ensure_all_bound() 107 | 108 | permutation = np.random.permutation(baselines).astype(np.int16) 109 | permutation = np.r_[permutation, -np.ones(cbf_baselines - baselines, np.int16)] 110 | # The baseline_inputs and input_auto_baseline arrays aren't consistent 111 | # with the percentile ranges, but that doesn't really matter (although 112 | # it may impact memory access patterns and hence performance). 
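# The loop below pairs each input i with every input j >= (i // 2) * 2, so the
# two polarisations of an antenna appear in both orders (hv and vh) as well as
# with all later inputs; input_auto_baseline[i] records the baseline index at
# which input i is paired with itself.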
113 | baseline_inputs = [] 114 | input_auto_baseline = np.zeros(2 * args.mask_antennas, np.uint16) 115 | for i in range(2 * args.mask_antennas): 116 | for j in range(i // 2 * 2, 2 * args.mask_antennas): 117 | if i == j: 118 | input_auto_baseline[i] = len(baseline_inputs) 119 | baseline_inputs.append((i, j)) 120 | assert len(baseline_inputs) == baselines 121 | baseline_inputs = np.array(baseline_inputs, np.uint16) 122 | proc.buffer('permutation').set(command_queue, permutation) 123 | proc.buffer('input_auto_baseline').set(command_queue, input_auto_baseline) 124 | proc.buffer('baseline_inputs').set(command_queue, baseline_inputs) 125 | 126 | command_queue.finish() 127 | 128 | vis_in_device = proc.buffer('vis_in') 129 | output_names = ['spec_vis', 'spec_weights', 'spec_weights_channel', 'spec_flags', 130 | 'cont_vis', 'cont_weights', 'cont_weights_channel', 'cont_flags'] 131 | output_buffers = [proc.buffer(name) for name in output_names] 132 | output_arrays = [buf.empty_like() for buf in output_buffers] 133 | sd_names = ['sd_cont_vis', 'sd_cont_flags', 'sd_cont_weights', 'timeseries', 'timeseriesabs'] 134 | for i in range(5): 135 | sd_names.append('percentile{0}'.format(i)) 136 | sd_buffers = [proc.buffer(name) for name in sd_names] 137 | sd_arrays = [buf.empty_like() for buf in sd_buffers] 138 | 139 | dumps = [generate_data(vis_in_device, channels, cbf_baselines) 140 | for i in range(max(args.sd_time_avg, args.time_avg))] 141 | # Push data before we start timing, to ensure everything is allocated 142 | for dump in dumps: 143 | proc.buffer('vis_in').set(command_queue, dump) 144 | 145 | start_event = command_queue.enqueue_marker() 146 | proc.start_sum() 147 | for pass_ in range(args.repeat): 148 | if not args.no_transfer: 149 | proc.buffer('vis_in').set_async(command_queue, dumps[pass_ % len(dumps)]) 150 | proc() 151 | if (pass_ + 1) % args.time_avg == 0: 152 | proc.end_sum() 153 | if not args.no_transfer: 154 | for buf, array in zip(output_buffers, output_arrays): 155 | buf.get_async(command_queue, array) 156 | proc.start_sum() 157 | if (pass_ + 1) % args.sd_time_avg == 0: 158 | proc.end_sd_sum() 159 | if not args.no_transfer: 160 | for buf, array in zip(sd_buffers, sd_arrays): 161 | buf.get_async(command_queue, array) 162 | proc.start_sd_sum() 163 | end_event = command_queue.enqueue_marker() 164 | elapsed_ms = end_event.time_since(start_event) * 1000.0 165 | dump_ms = elapsed_ms / args.repeat 166 | print("{0:.3f}ms ({1:.3f}ms per dump)".format(elapsed_ms, dump_ms)) 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /katsdpingest/sender.py: -------------------------------------------------------------------------------- 1 | """Helper classes encapsulating the details of sending SPEAD streams.""" 2 | 3 | import logging 4 | import asyncio 5 | from typing import List, Dict, Sequence, Any # noqa: F401 6 | 7 | import numpy as np 8 | from katsdptelstate.endpoint import Endpoint 9 | import spead2.send.asyncio 10 | 11 | from .utils import Range 12 | 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | class Data: 18 | """Bundles visibilities, flags and weights""" 19 | def __init__(self, 20 | vis: np.ndarray, 21 | flags: np.ndarray, 22 | weights: np.ndarray, 23 | weights_channel: np.ndarray) -> None: 24 | self.vis = vis 25 | self.flags = flags 26 | self.weights = weights 27 | self.weights_channel = weights_channel 28 | 29 | def __getitem__(self, idx) -> 'Data': 30 | """Do numpy slicing on all fields 
at once""" 31 | return Data(self.vis[idx], self.flags[idx], 32 | self.weights[idx], self.weights_channel[idx]) 33 | 34 | @property 35 | def nbytes(self) -> int: 36 | return (self.vis.nbytes + self.flags.nbytes 37 | + self.weights.nbytes + self.weights_channel.nbytes) 38 | 39 | 40 | async def async_send_heap(stream: spead2.send.asyncio.UdpStream, 41 | heap: spead2.send.Heap) -> None: 42 | """Send a heap on a stream and wait for it to complete, but log and 43 | suppress exceptions.""" 44 | try: 45 | await stream.async_send_heap(heap) 46 | except Exception: 47 | _logger.warn("Error sending heap", exc_info=True) 48 | 49 | 50 | class VisSender: 51 | """A single output SPEAD stream of L0 visibility data. 52 | 53 | Parameters 54 | ---------- 55 | thread_pool : `spead2.ThreadPool` 56 | Thread pool servicing the stream 57 | endpoint : `katsdptelstate.endpoint.Endpoint` 58 | Stream endpoint 59 | interface_address : str 60 | IP address of network interface to use, or ``None`` 61 | flavour : `spead2.Flavour` 62 | SPEAD flavour to use on `stream` 63 | int_time : float 64 | Time between dumps, in seconds of wall clock time (which may be 65 | different to data timestamp time if ``--clock-ratio`` is used). 66 | channel_range : :class:`katsdpingest.utils.Range` 67 | Range of channel numbers to be placed into this stream (of those passed to :meth:`send`) 68 | channel0 : int 69 | Index of first channel, within the full bandwidth of the L0 output 70 | all_channels : int 71 | Number of channels in the full L0 output 72 | baselines : int 73 | number of baselines in output 74 | """ 75 | def __init__(self, thread_pool: spead2.ThreadPool, 76 | endpoint: Endpoint, interface_address: str, 77 | flavour: spead2.Flavour, 78 | int_time: float, channel_range: Range, 79 | channel0: int, all_channels: int, baselines: int) -> None: 80 | channels = len(channel_range) 81 | item_size = np.dtype(np.complex64).itemsize + 2 * np.dtype(np.uint8).itemsize 82 | dump_size = channels * baselines * item_size 83 | dump_size += channels * np.dtype(np.float32).itemsize 84 | # Add a guess for SPEAD protocol overhead (including descriptors). This just needs 85 | # to be conservative, to make sure we don't try to send too slow. 86 | dump_size += 2048 87 | # Send slightly faster to allow for other network overheads (e.g. overhead per 88 | # packet, which is a fraction of total size) and to allow us to catch 89 | # up if we temporarily fall behind the rate. 
90 | rate = dump_size / int_time * 1.05 if int_time else 0.0 91 | kwargs = {} # type: Dict[str, Any] 92 | if interface_address is not None: 93 | kwargs['interface_address'] = interface_address 94 | kwargs['ttl'] = 1 95 | self._stream = spead2.send.asyncio.UdpStream( 96 | thread_pool, [(endpoint.host, endpoint.port)], 97 | spead2.send.StreamConfig(max_packet_size=8872, rate=rate), **kwargs) 98 | self._stream.set_cnt_sequence(channel0, all_channels) 99 | self._ig = spead2.send.ItemGroup(descriptor_frequency=1, flavour=flavour) 100 | self._channel_range = channel_range 101 | self._channel0 = channel0 102 | self._ig.add_item(id=None, name='correlator_data', 103 | description="Visibilities", 104 | shape=(channels, baselines), dtype=np.complex64) 105 | self._ig.add_item(id=None, name='flags', 106 | description="Flags for visibilities", 107 | shape=(channels, baselines), dtype=np.uint8) 108 | self._ig.add_item(id=None, name='weights', 109 | description="Detailed weights, to be scaled by weights_channel", 110 | shape=(channels, baselines), dtype=np.uint8) 111 | self._ig.add_item(id=None, name='weights_channel', 112 | description="Coarse (per-channel) weights", 113 | shape=(channels,), dtype=np.float32) 114 | self._ig.add_item(id=None, name='timestamp', 115 | description="Seconds since CBF sync time", 116 | shape=(), dtype=None, format=[('f', 64)]) 117 | self._ig.add_item(id=None, name='dump_index', 118 | description='Index in time', 119 | shape=(), dtype=None, format=[('u', 64)]) 120 | self._ig.add_item(id=0x4103, name='frequency', 121 | description="Channel index of first channel in the heap", 122 | shape=(), dtype=np.uint32) 123 | 124 | async def start(self): 125 | """Send a start packet to the stream.""" 126 | await async_send_heap(self._stream, self._ig.get_start()) 127 | 128 | async def stop(self): 129 | """Send a stop packet to the stream. To ensure that it won't be lost 130 | on the sending side, the stream is first flushed, then the stop 131 | heap is sent and waited for.""" 132 | await self._stream.async_flush() 133 | await self._stream.async_send_heap(self._ig.get_end()) 134 | 135 | async def send(self, data, idx, ts_rel): 136 | """Asynchronously send visibilities to the receiver""" 137 | data = data[self._channel_range.asslice()] 138 | self._ig['correlator_data'].value = data.vis 139 | self._ig['flags'].value = data.flags 140 | self._ig['weights'].value = data.weights 141 | self._ig['weights_channel'].value = data.weights_channel 142 | self._ig['timestamp'].value = ts_rel 143 | self._ig['dump_index'].value = idx 144 | self._ig['frequency'].value = self._channel0 145 | await async_send_heap(self._stream, self._ig.get_heap()) 146 | 147 | 148 | class VisSenderSet: 149 | """Manages a collection of :class:`VisSender` objects, and provides similar 150 | functions that work collectively on all the streams. 
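    Data passed to :meth:`send` is split by channel, with each constituent
    :class:`VisSender` transmitting a contiguous sub-range of the channels.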
151 | """ 152 | def __init__(self, 153 | thread_pool: spead2.ThreadPool, 154 | endpoints: Sequence[Endpoint], 155 | interface_address: str, 156 | flavour: spead2.Flavour, 157 | int_time: float, 158 | channel_range: Range, 159 | channel0: int, all_channels: int, baselines: int) -> None: 160 | channels = len(channel_range) 161 | n = len(endpoints) 162 | if channels % n != 0: 163 | raise ValueError('Number of channels not evenly divisible by number of endpoints') 164 | sub_channels = channels // n 165 | self.sub_channels = sub_channels 166 | self._senders = [] # type: List[VisSender] 167 | for i in range(n): 168 | a = channel_range.start + i * sub_channels 169 | b = a + sub_channels 170 | self._senders.append( 171 | VisSender(thread_pool, endpoints[i], interface_address, flavour, int_time, 172 | Range(a, b), channel0 + i * sub_channels, all_channels, baselines)) 173 | 174 | @property 175 | def size(self) -> int: 176 | return len(self._senders) 177 | 178 | async def start(self) -> None: 179 | """Send a start heap to all streams.""" 180 | await asyncio.gather(*(sender.start() for sender in self._senders)) 181 | 182 | async def stop(self) -> None: 183 | """Send a stop heap to all streams.""" 184 | await asyncio.gather(*(sender.stop() for sender in self._senders)) 185 | 186 | async def send(self, data: Data, idx: int, ts_rel: float) -> None: 187 | """Send a data heap to all streams, splitting the data between them.""" 188 | await asyncio.gather(*(sender.send(data, idx, ts_rel) 189 | for sender in self._senders)) 190 | -------------------------------------------------------------------------------- /katsdpingest/test/test_receiver.py: -------------------------------------------------------------------------------- 1 | """Tests for receiver module""" 2 | 3 | from unittest import mock 4 | import asyncio 5 | from typing import Dict, Tuple # noqa: F401 6 | 7 | import numpy as np 8 | import spead2 9 | import spead2.send 10 | import spead2.recv.asyncio 11 | import asynctest 12 | import async_timeout 13 | 14 | from katsdpingest.receiver import Receiver 15 | from katsdpingest.sigproc import Range 16 | from katsdpingest.test.test_ingest_session import fake_cbf_attr 17 | import katsdptelstate.endpoint 18 | from katsdptelstate.endpoint import Endpoint 19 | from nose.tools import assert_equal, assert_is_none, assert_raises 20 | 21 | 22 | class TestReceiver(asynctest.TestCase): 23 | def setUp(self): 24 | self._streams = {} # Dict[Endpoint, spead2.send.InprocStream] 25 | 26 | def add_udp_reader(rx, multicast_group, port, *args, **kwargs): 27 | endpoint = Endpoint(multicast_group, port) 28 | tx = self._streams[endpoint] 29 | rx.add_inproc_reader(tx.queue) 30 | 31 | patcher = mock.patch.object( 32 | spead2.recv.asyncio.Stream, 'add_udp_reader', add_udp_reader) 33 | patcher.start() 34 | self.addCleanup(patcher.stop) 35 | 36 | self.n_streams = 2 37 | endpoints = katsdptelstate.endpoint.endpoint_list_parser(7148)( 38 | '239.0.0.1+{}'.format(self.n_streams - 1)) 39 | self.n_xengs = 4 40 | sensors = mock.MagicMock() 41 | self.cbf_attr = fake_cbf_attr(4, self.n_xengs) 42 | self.n_chans = self.cbf_attr['n_chans'] 43 | self.n_bls = len(self.cbf_attr['bls_ordering']) 44 | tx_thread_pool = spead2.ThreadPool() 45 | self.tx = [spead2.send.InprocStream(tx_thread_pool, [spead2.InprocQueue()]) 46 | for endpoint in endpoints] 47 | self._streams = dict(zip(endpoints, self.tx)) 48 | for tx in self.tx: 49 | # asyncio.iscoroutinefunction doesn't like pybind11 functions, so 50 | # we have to hide it inside a lambda. 
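            # Stopping the inproc queue at cleanup also signals end-of-stream
            # to any receiver still subscribed to it.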
51 | self.addCleanup(lambda: tx.queues[0].stop()) 52 | self.rx = Receiver(endpoints, '127.0.0.1', False, self.n_streams, 32 * 1024**2, 53 | Range(0, self.n_chans), self.n_chans, 54 | sensors, self.cbf_attr, active_frames=3) 55 | self.tx_ig = [spead2.send.ItemGroup() for tx in self.tx] 56 | for i, ig in enumerate(self.tx_ig): 57 | ig.add_item(0x1600, 'timestamp', 58 | 'Timestamp of start of this integration. ' 59 | 'uint counting multiples of ADC samples since last sync ' 60 | '(sync_time, id=0x1027). Divide this number by timestamp_scale ' 61 | '(id=0x1046) to get back to seconds since last sync when this ' 62 | 'integration was actually started. Note that the receiver will need ' 63 | 'to figure out the centre timestamp of the accumulation ' 64 | '(eg, by adding half of int_time, id 0x1016).', 65 | (), None, format=[('u', 48)]) 66 | ig.add_item(0x4103, 'frequency', 67 | 'Identifies the first channel in the band of frequencies ' 68 | 'in the SPEAD heap. Can be used to reconstruct the full spectrum.', 69 | (), format=[('u', 48)]) 70 | ig.add_item(0x1800, 'xeng_raw', 71 | 'Raw data stream from all the X-engines in the system. ' 72 | 'For KAT-7, this item represents a full spectrum ' 73 | '(all frequency channels) assembled from lowest frequency ' 74 | 'to highest frequency. Each frequency channel contains the data ' 75 | 'for all baselines (n_bls given by SPEAD Id=0x1008). ' 76 | 'Each value is a complex number - ' 77 | 'two (real and imaginary) signed integers.', 78 | (self.n_chans // self.n_xengs, self.n_bls, 2), np.dtype('>i4')) 79 | for ig, tx in zip(self.tx_ig, self.tx): 80 | tx.send_heap(ig.get_heap()) 81 | 82 | async def test_stop(self): 83 | """The receiver must stop once all streams stop""" 84 | data_future = self.loop.create_task(self.rx.get()) 85 | for ig, tx in zip(self.tx_ig, self.tx): 86 | tx.send_heap(ig.get_end()) 87 | # Check that we get the end-of-stream notification; using a timeout 88 | # to ensure that we don't hang if the test fails. 
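        # Receiver.get() raises spead2.Stopped once every constituent stream
        # has seen its end-of-stream heap.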
89 | with assert_raises(spead2.Stopped): 90 | with async_timeout.timeout(30): 91 | await data_future 92 | 93 | def _make_data(self, n_frames): 94 | """Generates made-up timestamps and correlator data 95 | 96 | Parameters 97 | ---------- 98 | n_frames : int 99 | Number of frames to generate 100 | 101 | Returns 102 | ------- 103 | xeng_raw : np.ndarray 104 | 5D array of integer correlator data, indexed by time, stream, 105 | channel, baseline, and real/complex 106 | indices : np.ndarray 107 | 1D array of input dump indices 108 | timestamps : np.ndarray 109 | 1D array of timestamps 110 | """ 111 | xeng_raw = np.random.uniform( 112 | -1000, 1000, 113 | size=(n_frames, self.n_xengs, self.n_chans // self.n_xengs, self.n_bls, 2)) 114 | xeng_raw = xeng_raw.astype('>i4') 115 | interval = self.cbf_attr['ticks_between_spectra'] * self.cbf_attr['n_accs'] 116 | indices = np.arange(n_frames, dtype=np.uint64) 117 | timestamps = indices * interval + 1234567890123 118 | return xeng_raw, indices, timestamps 119 | 120 | async def _send_in_order(self, xeng_raw, timestamps): 121 | for t in range(len(xeng_raw)): 122 | for i in range(self.n_xengs): 123 | stream_idx = i * self.n_streams // self.n_xengs 124 | self.tx_ig[stream_idx]['timestamp'].value = timestamps[t] 125 | self.tx_ig[stream_idx]['frequency'].value = i * self.n_chans // self.n_xengs 126 | self.tx_ig[stream_idx]['xeng_raw'].value = xeng_raw[t, i] 127 | self.tx[stream_idx].send_heap(self.tx_ig[stream_idx].get_heap()) 128 | await asyncio.sleep(0.02) 129 | for i in range(self.n_streams): 130 | self.tx[i].send_heap(self.tx_ig[i].get_end()) 131 | 132 | async def test_in_order(self): 133 | """Test normal case with data arriving in the expected order""" 134 | n_frames = 10 135 | xeng_raw, indices, timestamps = self._make_data(n_frames) 136 | send_future = asyncio.ensure_future(self._send_in_order(xeng_raw, timestamps)) 137 | for t in range(n_frames): 138 | frame = await asyncio.wait_for(self.rx.get(), 3) 139 | assert_equal(indices[t], frame.idx) 140 | assert_equal(timestamps[t], frame.timestamp) 141 | assert_equal(self.n_xengs, len(frame.items)) 142 | for i in range(self.n_xengs): 143 | np.testing.assert_equal(xeng_raw[t, i], frame.items[i]) 144 | with assert_raises(spead2.Stopped): 145 | with async_timeout.timeout(3): 146 | await self.rx.get() 147 | await send_future 148 | 149 | async def _send_out_of_order(self, xeng_raw, timestamps): 150 | order = [ 151 | # Send parts of frames 0, 1 152 | (0, 0), (0, 1), (0, 3), 153 | (1, 1), (1, 3), (1, 2), 154 | # Finish frame 1, start frame 2 155 | (1, 0), (2, 2), 156 | # Finish frame 0; frames 0, 1 should flush 157 | (0, 2), 158 | # Jump ahead by more than the window; frames 2-4 should be flushed/dropped 159 | (7, 0), 160 | # Finish off frame 2; should be discarded 161 | (2, 0), (2, 1), (2, 3), 162 | # Fill in a frame that's not at the start of the window 163 | (6, 0), (6, 1), (6, 2), (6, 3), 164 | # Force the window to advance, flushing 6 165 | (9, 0), (9, 2), 166 | # Fill in frame that's not at the start of the window; it should flush 167 | # when the stream stops 168 | (8, 0), (9, 3), (8, 1), (8, 3), (8, 2) 169 | ] 170 | for (t, i) in order: 171 | stream_idx = i * self.n_streams // self.n_xengs 172 | self.tx_ig[stream_idx]['timestamp'].value = timestamps[t] 173 | self.tx_ig[stream_idx]['frequency'].value = i * self.n_chans // self.n_xengs 174 | self.tx_ig[stream_idx]['xeng_raw'].value = xeng_raw[t, i] 175 | self.tx[stream_idx].send_heap(self.tx_ig[stream_idx].get_heap()) 176 | # Longish sleep to ensure the 
ordering is respected
177 |             await asyncio.sleep(0.02)
178 |         for i in range(self.n_streams):
179 |             self.tx[i].send_heap(self.tx_ig[i].get_end())
180 | 
181 |     async def test_out_of_order(self):
182 |         """Test various edge behaviour for out-of-order data"""
183 |         n_frames = 10
184 |         xeng_raw, indices, timestamps = self._make_data(n_frames)
185 |         send_future = self.loop.create_task(self._send_out_of_order(xeng_raw, timestamps))
186 |         try:
187 |             for t, missing in [(0, []), (1, []), (2, [0, 1, 3]), (6, []),
188 |                                (7, [1, 2, 3]), (8, []), (9, [1])]:
189 |                 with async_timeout.timeout(3):
190 |                     frame = await self.rx.get()
191 |                 assert_equal(indices[t], frame.idx)
192 |                 assert_equal(timestamps[t], frame.timestamp)
193 |                 assert_equal(self.n_xengs, len(frame.items))
194 |                 for i in range(self.n_xengs):
195 |                     if i in missing:
196 |                         assert_is_none(frame.items[i])
197 |                     else:
198 |                         np.testing.assert_equal(xeng_raw[t, i], frame.items[i])
199 |             with assert_raises(spead2.Stopped):
200 |                 with async_timeout.timeout(3):
201 |                     await self.rx.get()
202 |         finally:
203 |             await send_future
204 | --------------------------------------------------------------------------------
/scripts/ingest.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Capture utility for a relatively generic packetised correlator data output stream.
4 | 
5 | import logging
6 | import signal
7 | import asyncio
8 | import argparse
9 | from typing import List, Callable, TypeVar
10 | 
11 | import katsdpservices
12 | from katsdpsigproc import accel
13 | from katsdptelstate import endpoint
14 | import katsdptelstate.aio.redis
15 | import katsdpmodels.fetch.aiohttp
16 | 
17 | from katsdpingest.ingest_session import ChannelRanges, SystemAttrs
18 | from katsdpingest.utils import Range, cbf_telstate_view
19 | from katsdpingest.ingest_server import IngestDeviceServer
20 | 
21 | 
22 | logger = logging.getLogger("katsdpingest.ingest")
23 | 
24 | 
25 | _T = TypeVar('_T')
26 | 
27 | 
28 | def comma_list(type_: Callable[..., _T]) -> Callable[[str], List[_T]]:
29 |     """Return a function which splits a string on commas and converts each element to
30 |     `type_`."""
31 | 
32 |     def convert(arg: str) -> List[_T]:
33 |         return [type_(x) for x in arg.split(',')]
34 |     return convert
35 | 
36 | 
37 | def parse_args() -> argparse.Namespace:
38 |     parser = katsdpservices.ArgumentParser()
39 |     parser.add_argument(
40 |         '--sdisp-spead', type=endpoint.endpoint_list_parser(7149),
41 |         default=[], metavar='ENDPOINT',
42 |         help=('signal display destination. Either single endpoint or comma-separated list. '
43 |               '[default=%(default)s]'))
44 |     parser.add_argument(
45 |         '--sdisp-interface', metavar='INTERFACE',
46 |         help='interface on which to send signal display data [default=auto]')
47 |     parser.add_argument(
48 |         '--cbf-spead', type=endpoint.endpoint_list_parser(7148),
49 |         default=':7148', metavar='ENDPOINTS',
50 |         help=('endpoints to listen for CBF SPEAD stream (including multicast IPs). '
51 |               '<ip>[+<count>][:port]. [default=%(default)s]'))
52 |     parser.add_argument(
53 |         '--cbf-interface', metavar='INTERFACE',
54 |         help='interface to subscribe to for CBF SPEAD data. 
[default=auto]') 55 | parser.add_argument( 56 | '--cbf-ibv', action='store_true', 57 | help='use ibverbs acceleration for CBF SPEAD data [default=no].') 58 | parser.add_argument( 59 | '--cbf-name', 60 | help='name of the baseline correlation products stream') 61 | parser.add_argument( 62 | '--l0-spectral-spead', type=endpoint.endpoint_list_parser(7200), metavar='ENDPOINTS', 63 | help='destination for spectral L0 output. [default=do not send]') 64 | parser.add_argument( 65 | '--l0-spectral-interface', metavar='INTERFACE', 66 | help='interface on which to send spectral L0 output. [default=auto]') 67 | parser.add_argument( 68 | '--l0-spectral-name', default='sdp_l0', metavar='NAME', 69 | help='telstate name of the spectral output stream') 70 | parser.add_argument( 71 | '--l0-continuum-spead', type=endpoint.endpoint_list_parser(7201), metavar='ENDPOINTS', 72 | help='destination for continuum L0 output. [default=do not send]') 73 | parser.add_argument( 74 | '--l0-continuum-interface', metavar='INTERFACE', 75 | help='interface on which to send continuum L0 output. [default=auto]') 76 | parser.add_argument( 77 | '--l0-continuum-name', default='sdp_l0_continuum', metavar='NAME', 78 | help='telstate name of the continuum output stream') 79 | parser.add_argument( 80 | '--output-int-time', default=2.0, type=float, 81 | help='seconds between output dumps (will be quantised). [default=%(default)s]') 82 | parser.add_argument( 83 | '--sd-int-time', default=2.0, type=float, 84 | help='seconds between signal display updates (will be quantised). [default=%(default)s]') 85 | parser.add_argument( 86 | '--antenna-mask', default=None, type=comma_list(str), 87 | help='comma-separated list of antennas to keep. [default=all]') 88 | parser.add_argument( 89 | '--output-channels', type=Range.parse, 90 | help='output spectral channels, in format A:B [default=all]') 91 | parser.add_argument( 92 | '--sd-output-channels', type=Range.parse, 93 | help='signal display channels, in format A:B [default=all]') 94 | parser.add_argument( 95 | '--continuum-factor', default=16, type=int, 96 | help='factor by which to reduce number of channels. [default=%(default)s]') 97 | parser.add_argument( 98 | '--sd-continuum-factor', default=128, type=int, 99 | help=('factor by which to reduce number of channels for signal display. ' 100 | '[default=%(default)s]')) 101 | parser.add_argument( 102 | '--guard-channels', default=64, type=int, 103 | help='extra channels to use for RFI detection. [default=%(default)s]') 104 | parser.add_argument( 105 | '--input-streams', default=1, type=int, 106 | help='maximum separate streams for receive. [default=%(default)s]') 107 | parser.add_argument( 108 | '--input-buffer', default=64 * 1024**2, type=int, 109 | help='network buffer size for input. [default=%(default)s]') 110 | parser.add_argument( 111 | '--sd-spead-rate', type=float, default=1000000000, 112 | help='rate (bits per second) to transmit signal display output. 
[default=%(default)s]') 113 | parser.add_argument( 114 | '--no-excise', dest='excise', action='store_false', 115 | help='disable excision of flagged data [default=no]') 116 | parser.add_argument( 117 | '--use-data-suspect', action='store_true', 118 | help=('use the CAM-provided input-data-suspect and channel-data-suspect ' 119 | 'sensors to flag data [default=no]')) 120 | parser.add_argument( 121 | '--servers', type=int, default=1, 122 | help='number of parallel servers producing the output [default=%(default)s]') 123 | parser.add_argument( 124 | '--server-id', type=int, default=1, 125 | help='index of this server amongst parallel servers (1-based) [default=%(default)s]') 126 | parser.add_aiomonitor_arguments() 127 | parser.add_argument( 128 | '--clock-ratio', type=float, default=1.0, 129 | help='Scale factor for transmission rate, smaller is faster [default=%(default)s]') 130 | parser.add_argument( 131 | '-p', '--port', type=int, default=2040, metavar='N', 132 | help='katcp host port. [default=%(default)s]') 133 | parser.add_argument( 134 | '-a', '--host', type=str, default="", metavar='HOST', 135 | help='katcp host address. [default=all hosts]') 136 | parser.add_argument( 137 | '-l', '--log-level', type=str, default=None, metavar='LEVEL', 138 | help='log level to use') 139 | args = parser.parse_args() 140 | if args.telstate is None: 141 | parser.error('argument --telstate is required') 142 | if args.cbf_ibv and args.cbf_interface is None: 143 | parser.error('--cbf-ibv requires --cbf-interface') 144 | if args.cbf_name is None: 145 | parser.error('--cbf-name is required') 146 | if not 1 <= args.server_id <= args.servers: 147 | parser.error('--server-id is out of range') 148 | if args.l0_spectral_spead is None and args.l0_continuum_spead is None: 149 | parser.error('at least one of --l0-spectral-spead and --l0-continuum-spead must be given') 150 | return args 151 | 152 | 153 | async def on_shutdown(server: IngestDeviceServer) -> None: 154 | # Disable the signal handlers, to avoid being unable to kill if there 155 | # is an exception in the shutdown path. 156 | for sig in [signal.SIGINT, signal.SIGTERM]: 157 | asyncio.get_event_loop().remove_signal_handler(sig) 158 | logger.info("Shutting down katsdpingest server...") 159 | await server.handle_interrupt() 160 | server.halt() 161 | 162 | 163 | async def get_async_telstate(endpoint: katsdptelstate.endpoint.Endpoint): 164 | backend = await katsdptelstate.aio.redis.RedisBackend.from_url( 165 | f'redis://{endpoint.host}:{endpoint.port}' 166 | ) 167 | return katsdptelstate.aio.TelescopeState(backend) 168 | 169 | 170 | async def main() -> None: 171 | katsdpservices.setup_logging() 172 | katsdpservices.setup_restart() 173 | args = parse_args() 174 | if args.log_level is not None: 175 | logging.root.setLevel(args.log_level.upper()) 176 | 177 | loop = asyncio.get_event_loop() 178 | telstate = await get_async_telstate(args.telstate_endpoint) 179 | telstate_cbf = await cbf_telstate_view(telstate, args.cbf_name) 180 | async with katsdpmodels.fetch.aiohttp.TelescopeStateFetcher(telstate) as fetcher: 181 | system_attrs = await SystemAttrs.create(fetcher, telstate_cbf, args.antenna_mask) 182 | cbf_channels = system_attrs.cbf_attr['n_chans'] 183 | if args.output_channels is None: 184 | args.output_channels = Range(0, cbf_channels) 185 | if args.sd_output_channels is None: 186 | args.sd_output_channels = Range(0, cbf_channels) 187 | # If no continuum product is selected, set continuum factor to 1 since 188 | # that effectively disables the alignment checks. 
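    # Rough illustration (numbers are assumed, not derived from ChannelRanges
    # internals): with --servers=4, 4096 CBF channels and --continuum-factor=16,
    # each server nominally covers 1024 spectral channels, i.e. 64 continuum
    # channels, extended by --guard-channels on either side for RFI detection.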
189 | continuum_factor = args.continuum_factor if args.l0_continuum_spead else 1 190 | # TODO: determine an appropriate value for guard 191 | channel_ranges = ChannelRanges( 192 | args.servers, args.server_id - 1, 193 | cbf_channels, continuum_factor, args.sd_continuum_factor, 194 | len(args.cbf_spead), args.guard_channels, args.output_channels, args.sd_output_channels) 195 | context = accel.create_some_context(interactive=False) 196 | server = IngestDeviceServer(args, telstate_cbf, channel_ranges, system_attrs, context, 197 | args.host, args.port) 198 | 199 | loop.add_signal_handler(signal.SIGINT, lambda: loop.create_task(on_shutdown(server))) 200 | loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(on_shutdown(server))) 201 | await server.start() 202 | logger.info("Started katsdpingest server.") 203 | with katsdpservices.start_aiomonitor(loop, args, locals()): 204 | await server.join() 205 | telstate.backend.close() 206 | await telstate.backend.wait_closed() 207 | logger.info("Shutdown complete") 208 | 209 | 210 | if __name__ == '__main__': 211 | loop = asyncio.get_event_loop() 212 | try: 213 | loop.run_until_complete(main()) 214 | finally: 215 | loop.run_until_complete(loop.shutdown_asyncgens()) 216 | loop.close() 217 | -------------------------------------------------------------------------------- /katsdpingest/ingest_server.py: -------------------------------------------------------------------------------- 1 | """katcp server for ingest.""" 2 | 3 | import time 4 | import logging 5 | import argparse 6 | from typing import List, Dict, Mapping, Any, cast # noqa: F401 7 | 8 | import aiokatcp 9 | from aiokatcp import FailReply, SensorSampler 10 | import katsdptelstate.aio 11 | from katsdptelstate.endpoint import endpoint_parser 12 | 13 | import katsdpingest 14 | from .ingest_session import CBFIngest, Status, DeviceStatus, ChannelRanges, SystemAttrs 15 | from . import receiver 16 | from .utils import Sensor 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def _warn_if_positive(value: float) -> aiokatcp.Sensor.Status: 23 | """Status function for sensors that count problems""" 24 | return Sensor.Status.WARN if value > 0 else Sensor.Status.NOMINAL 25 | 26 | 27 | def _device_status_status(value: DeviceStatus) -> aiokatcp.Sensor.Status: 28 | """Sets katcp status for device-status sensor from value""" 29 | if value == DeviceStatus.OK: 30 | return Sensor.Status.NOMINAL 31 | elif value == DeviceStatus.DEGRADED: 32 | return Sensor.Status.WARN 33 | else: 34 | return Sensor.Status.ERROR 35 | 36 | 37 | class IngestDeviceServer(aiokatcp.DeviceServer): 38 | """Serves the ingest katcp interface. 39 | Top level holder of the ingest session. 40 | 41 | Parameters 42 | ---------- 43 | user_args : :class:`argparse.Namespace` 44 | Command-line arguments 45 | telstate_cbf : :class:`katsdptelstate.aio.TelescopeState` 46 | Asynchronous client for the telescope state, as returned by 47 | :func:`katsdpingest.utils.cbf_telstate_view`. 48 | channel_ranges : :class:`katsdpingest.ingest_session.ChannelRanges` 49 | Ranges of channels for various parts of the pipeline 50 | system_attrs : :class:`katsdpingest.ingest_session.SystemAttrs` 51 | System configuration, as returned by 52 | :meth:`katsdpingest.ingest_session.SystemAttrs.create`. 
53 | context : :class:`katsdpsigproc.cuda.Context` or :class:`katsdpsigproc.opencl.Context` 54 | Context in which to compile device code and allocate resources 55 | args, kwargs 56 | Passed to :class:`aiokatcp.DeviceServer` 57 | """ 58 | 59 | VERSION = "sdp-ingest-0.2" 60 | BUILD_STATE = 'katsdpingest-' + katsdpingest.__version__ 61 | 62 | def __init__( 63 | self, 64 | user_args: argparse.Namespace, 65 | telstate_cbf: katsdptelstate.aio.TelescopeState, 66 | channel_ranges: ChannelRanges, 67 | system_attrs: SystemAttrs, 68 | context, *args, **kwargs) -> None: 69 | super().__init__(*args, **kwargs) 70 | self._stopping = False 71 | 72 | def counter(name: str, description: str, *, 73 | event_rate: bool = False, 74 | warn_if_positive: bool = False, 75 | **kwargs: Any) -> Sensor: 76 | if event_rate: 77 | kwargs['auto_strategy'] = SensorSampler.Strategy.EVENT_RATE 78 | kwargs['auto_strategy_parameters'] = (0.05, 10.0) 79 | if warn_if_positive: 80 | kwargs['status_func'] = _warn_if_positive 81 | return Sensor(int, name, description + ' (prometheus: counter)', 82 | initial_status=Sensor.Status.NOMINAL, **kwargs) 83 | 84 | sensors = [ 85 | Sensor(int, "output-n-ants", 86 | "Number of antennas in L0 stream (prometheus: gauge)"), 87 | Sensor(int, "output-n-inputs", 88 | "Number of single-pol signals in L0 stream (prometheus: gauge)"), 89 | Sensor(int, "output-n-bls", 90 | "Number of baseline products in L0 stream (prometheus: gauge)"), 91 | Sensor(int, "output-n-chans", 92 | "Number of channels this server contributes to L0 spectral stream " 93 | "(prometheus: gauge)"), 94 | Sensor(float, "output-int-time", 95 | "Integration time of L0 stream (prometheus: gauge)", "s"), 96 | Sensor(bool, "capture-active", 97 | "Is there a currently active capture session (prometheus: gauge)", 98 | default=False, initial_status=Sensor.Status.NOMINAL), 99 | Sensor(Status, "status", 100 | "The current status of the capture session.", 101 | default=Status.INIT, initial_status=Sensor.Status.NOMINAL), 102 | Sensor(float, "last-dump-timestamp", 103 | "Timestamp of most recently received correlator dump in Unix seconds " 104 | "(prometheus: gauge)", "s", 105 | default=0.0, initial_status=Sensor.Status.NOMINAL), 106 | Sensor(DeviceStatus, "device-status", 107 | "Health status", 108 | default=DeviceStatus.OK, initial_status=Sensor.Status.NOMINAL, 109 | status_func=_device_status_status), 110 | counter("input-bytes-total", 111 | "Number of payload bytes received from CBF in this session", 112 | event_rate=True), 113 | counter("input-heaps-total", 114 | "Number of payload heaps received from CBF in this session", 115 | event_rate=True), 116 | counter("input-dumps-total", 117 | "Number of CBF dumps received in this session", 118 | event_rate=True), 119 | counter("input-metadata-heaps-total", 120 | "Number of heaps that do not contain payload in this session", 121 | event_rate=True), 122 | counter("output-bytes-total", 123 | "Number of payload bytes sent on L0 in this session"), 124 | counter("output-heaps-total", 125 | "Number of payload heaps sent on L0 in this session"), 126 | counter("output-dumps-total", 127 | "Number of payload dumps sent on L0 in this session"), 128 | counter("output-vis-total", 129 | "Number of spectral visibilities computed for signal displays in this session"), 130 | counter("output-flagged-total", 131 | "Number of flagged visibilities (out of output-vis-total)"), 132 | Sensor(bool, "descriptors-received", 133 | "Whether the SPEAD descriptors have been received " 134 | " (prometheus: gauge)", 135 | 
initial_status=Sensor.Status.NOMINAL)
136 |         ]  # type: List[Sensor]
137 |         for key, value in receiver.REJECT_HEAP_TYPES.items():
138 |             sensors.append(counter(
139 |                 "input-" + key + "-heaps-total",
140 |                 "Number of heaps rejected because {}".format(value),
141 |                 event_rate=True, warn_if_positive=True))
142 |         for sensor in sensors:
143 |             self.sensors.add(sensor)
144 | 
145 |         # create the device resources
146 |         self.cbf_ingest = CBFIngest(
147 |             user_args, system_attrs, channel_ranges, context,
148 |             cast(Mapping[str, Sensor], self.sensors), telstate_cbf)
149 |         # add default or user-specified endpoints
150 |         for sdisp_endpoint in user_args.sdisp_spead:
151 |             self.cbf_ingest.add_sdisp_ip(sdisp_endpoint)
152 | 
153 |     async def start(self) -> None:
154 |         await self.cbf_ingest.populate_telstate()
155 |         await super().start()
156 | 
157 |     async def request_enable_debug(self, ctx) -> str:
158 |         """Enable debugging of the ingest process."""
159 |         self.cbf_ingest.enable_debug(True)
160 |         return "Debug logging enabled."
161 | 
162 |     async def request_disable_debug(self, ctx) -> str:
163 |         """Disable debugging of the ingest process."""
164 |         self.cbf_ingest.enable_debug(False)
165 |         return "Debug logging disabled."
166 | 
167 |     async def request_internal_log_level(self, ctx,
168 |                                          component: str = None, level: str = None) -> None:
169 |         """
170 |         Set the log level of an internal component to the specified value.
171 |         ?internal-log-level [<component> [<level>]]
172 | 
173 |         If <level> is omitted, the current log level is shown. If <component>
174 |         is also omitted, the current log levels of all loggers are shown.
175 |         """
176 |         # manager isn't part of the documented API, so mypy/typeshed doesn't know about it
177 |         manager = logging.Logger.manager  # type: ignore
178 |         # Filter out placeholders
179 |         loggers = {name: logger for (name, logger) in manager.loggerDict.items()
180 |                    if isinstance(logger, logging.Logger)}
181 |         loggers[''] = logging.getLogger()  # Not kept in loggerDict
182 |         if component is None:
183 |             ctx.informs((name, logging.getLevelName(logger.level))
184 |                         for name, logger in sorted(loggers.items()))
185 |         elif level is None:
186 |             logger = loggers.get(component)
187 |             if logger is None:
188 |                 raise FailReply('Unknown logger component {}'.format(component))
189 |             else:
190 |                 ctx.informs([(component, logging.getLevelName(logger.level))])
191 |         else:
192 |             level = level.upper()
193 |             logger = logging.getLogger(component)
194 |             try:
195 |                 logger.setLevel(level)
196 |             except ValueError:
197 |                 raise FailReply("Unknown log level specified {}".format(level))
198 | 
199 |     async def request_capture_init(self, ctx, capture_block_id: str) -> str:
200 |         """Spawns ingest session to capture suitable data to produce
201 |         the L0 output stream."""
202 |         if self.cbf_ingest.capturing:
203 |             raise FailReply(
204 |                 "Existing capture session found. "
205 |                 "If you really want to init, stop the current capture using capture-done.")
206 |         if self._stopping:
207 |             raise FailReply("Cannot start a capture session while ingest is shutting down")
208 | 
209 |         self.cbf_ingest.start(capture_block_id)
210 | 
211 |         self.sensors["capture-active"].value = True
212 |         smsg = "Capture initialised at %s" % time.ctime()
213 |         logger.info(smsg)
214 |         return smsg
215 | 
216 |     async def request_drop_sdisp_ip(self, ctx, ip: str) -> str:
217 |         """Drop an IP address from the internal list of signal display data recipients."""
218 |         try:
219 |             await self.cbf_ingest.drop_sdisp_ip(ip)
220 |         except KeyError:
221 |             raise FailReply(
222 |                 "The IP address specified ({}) does not exist "
223 |                 "in the current list of recipients.".format(ip))
224 |         else:
225 |             return "The IP address has been dropped as a signal display recipient"
226 | 
227 |     async def request_add_sdisp_ip(self, ctx, ip: str) -> str:
228 |         """Add the supplied ip and port (ip[:port]) to the list of signal
229 |         display data recipients. If no port is supplied, the default of 7149
230 |         is used."""
231 |         endpoint = endpoint_parser(7149)(ip)
232 |         try:
233 |             self.cbf_ingest.add_sdisp_ip(endpoint)
234 |         except ValueError:
235 |             return "The supplied IP is already in the active list of recipients."
236 |         else:
237 |             return "Added {} to list of signal display data recipients.".format(endpoint)
238 | 
239 |     async def handle_interrupt(self) -> None:
240 |         """Attempt a graceful resolution to external
241 |         interrupts. Basically calls capture-done."""
242 |         self._stopping = True  # Prevent a capture-init during shutdown
243 |         if self.cbf_ingest.capturing:
244 |             await self.request_capture_done(None)
245 |         self.cbf_ingest.close()
246 | 
247 |     async def request_capture_done(self, ctx) -> str:
248 |         """Stops the current capture."""
249 |         if not self.cbf_ingest.capturing:
250 |             raise FailReply("No existing capture session.")
251 | 
252 |         stopped = await self.cbf_ingest.stop()
253 | 
254 |         # In the case of concurrent connections, we need to ensure that we
255 |         # were the one that actually did the stop, as another connection may
256 |         # have raced us to stop and then started a new session.
257 |         if stopped:
258 |             self.sensors["capture-active"].value = False
259 |             # Error states were associated with the session, which is now dead.
260 | self.sensors["device-status"].value = DeviceStatus.OK 261 | logger.info("capture complete") 262 | return "capture complete" 263 | -------------------------------------------------------------------------------- /katsdpingest/test/test_ingest_session.py: -------------------------------------------------------------------------------- 1 | """Tests for the ingest_session module""" 2 | 3 | import logging 4 | import re 5 | from unittest import mock 6 | 7 | import numpy as np 8 | from nose.tools import ( 9 | assert_equal, assert_is, assert_regex, assert_is_none, assert_is_not_none, 10 | assert_logs, assert_raises, assert_true 11 | ) 12 | import asynctest 13 | from katsdpsigproc.test.test_accel import device_test 14 | import katsdptelstate.aio 15 | import katsdpmodels.fetch.aiohttp 16 | import katsdpmodels.rfi_mask 17 | import katsdpmodels.band_mask 18 | import katpoint 19 | 20 | from katsdpingest import ingest_session 21 | from katsdpingest.utils import Range, cbf_telstate_view 22 | 23 | 24 | def fake_cbf_attr(n_antennas, n_xengs=4): 25 | cbf_attr = dict( 26 | scale_factor_timestamp=1712000000.0, 27 | n_chans=4096, 28 | n_chans_per_substream=1024, 29 | n_accs=408 * 256, 30 | sync_time=1400000000.0 31 | ) 32 | cbf_attr['bandwidth'] = cbf_attr['scale_factor_timestamp'] / 2 33 | cbf_attr['center_freq'] = cbf_attr['bandwidth'] * 3 / 2 # Reasonable for L band 34 | cbf_attr['ticks_between_spectra'] = 2 * cbf_attr['n_chans'] 35 | cbf_attr['n_chans_per_substream'] = cbf_attr['n_chans'] // n_xengs 36 | cbf_attr['int_time'] = (cbf_attr['n_accs'] * cbf_attr['ticks_between_spectra'] 37 | / cbf_attr['scale_factor_timestamp']) 38 | bls_ordering = [] 39 | input_labels = [] 40 | antennas = ['m{:03}'.format(90 + i) for i in range(n_antennas)] 41 | for ib, b in enumerate(antennas): 42 | for a in antennas[:ib+1]: 43 | bls_ordering.append((a + 'h', b + 'h')) 44 | bls_ordering.append((a + 'v', b + 'v')) 45 | bls_ordering.append((a + 'h', b + 'v')) 46 | bls_ordering.append((a + 'v', b + 'h')) 47 | input_labels.append(b + 'h') 48 | input_labels.append(b + 'v') 49 | cbf_attr['bls_ordering'] = np.array(bls_ordering) 50 | cbf_attr['input_labels'] = input_labels 51 | return cbf_attr 52 | 53 | 54 | class TestSystemAttrs(asynctest.TestCase): 55 | async def setUp(self): 56 | self.values = { 57 | 'i0_bandwidth': 856000000.0, 58 | 'i0_sync_time': 1234567890.0, 59 | 'i0_scale_factor_timestamp': 1712000000.0, 60 | 'i0_antenna_channelised_voltage_instrument_dev_name': 'i0', 61 | 'i0_antenna_channelised_voltage_n_chans': 262144, # Different, to check precedence 62 | 'i0_antenna_channelised_voltage_center_freq': 1284000000.0, 63 | 'i0_antenna_channelised_voltage_ticks_between_spectra': 8192, 64 | 'i0_antenna_channelised_voltage_input_labels': ['m001h', 'm001v'], 65 | 'i1_baseline_correlation_products_src_streams': ['i0_antenna_channelised_voltage'], 66 | 'i1_baseline_correlation_products_instrument_dev_name': 'i1', 67 | 'i1_baseline_correlation_products_int_time': 0.499, 68 | 'i1_baseline_correlation_products_n_chans': 4096, 69 | 'i1_baseline_correlation_products_n_chans_per_substream': 256, 70 | 'i1_baseline_correlation_products_n_accs': 104448, 71 | 'i1_baseline_correlation_products_bls_ordering': [('m001h', 'm001h')], 72 | 'm001_observer': 'm001, -30:42:39.8, 21:26:38.0, 1035.0, 13.5, 1.126 -171.761 1.0605 5868.979 5869.998, -0:42:08.0 0 0:01:44.0 0:01:11.9 -0:00:14.0 -0:00:21.0 -0:36:13.1 0:01:36.2, 1.14', # noqa: E501 73 | 'sdp_model_base_url': 'https://test.invalid/models/', 74 | 'model_rfi_mask_fixed': 
'rfi_mask/fixed/dummy.h5', 75 | 'i0_antenna_channelised_voltage_model_band_mask_fixed': 'band_mask/fixed/dummy.h5' 76 | } 77 | self.telstate = katsdptelstate.aio.TelescopeState() 78 | for key, value in self.values.items(): 79 | await self.telstate.set(key, value) 80 | self.expected_cbf_attr = { 81 | 'n_chans': 4096, 82 | 'n_chans_per_substream': 256, 83 | 'n_accs': 104448, 84 | 'bls_ordering': [('m001h', 'm001h')], 85 | 'input_labels': ['m001h', 'm001v'], 86 | 'bandwidth': 856000000.0, 87 | 'center_freq': 1284000000.0, 88 | 'sync_time': 1234567890.0, 89 | 'int_time': 0.499, 90 | 'scale_factor_timestamp': 1712000000.0, 91 | 'ticks_between_spectra': 8192 92 | } 93 | 94 | async def test(self): 95 | telstate_cbf = await cbf_telstate_view(self.telstate, 96 | 'i1_baseline_correlation_products') 97 | with asynctest.patch('katsdpmodels.fetch.aiohttp.Fetcher.get', autospec=True) as fetch: 98 | async with katsdpmodels.fetch.aiohttp.TelescopeStateFetcher(self.telstate) as fetcher: 99 | attrs = await ingest_session.SystemAttrs.create( 100 | fetcher, 101 | telstate_cbf, 102 | ['m000', 'm001'] 103 | ) 104 | assert_equal(self.expected_cbf_attr, attrs.cbf_attr) 105 | assert_equal([katpoint.Antenna(self.values['m001_observer'])], 106 | attrs.antennas) 107 | assert_is_not_none(attrs.rfi_mask_model) 108 | assert_is_not_none(attrs.band_mask_model) 109 | assert_equal(fetch.mock_calls, [ 110 | mock.call(mock.ANY, 'https://test.invalid/models/rfi_mask/fixed/dummy.h5', 111 | katsdpmodels.rfi_mask.RFIMask), 112 | mock.call(mock.ANY, 'https://test.invalid/models/band_mask/fixed/dummy.h5', 113 | katsdpmodels.band_mask.BandMask) 114 | ]) 115 | 116 | async def test_no_models(self): 117 | await self.telstate.delete('sdp_model_base_url') 118 | telstate_cbf = await cbf_telstate_view(self.telstate, 119 | 'i1_baseline_correlation_products') 120 | async with katsdpmodels.fetch.aiohttp.TelescopeStateFetcher(self.telstate) as fetcher: 121 | with assert_logs(level=logging.WARNING) as cm: 122 | attrs = await ingest_session.SystemAttrs.create( 123 | fetcher, 124 | telstate_cbf, 125 | ['m000', 'm001'] 126 | ) 127 | assert_regex(cm.output[0], re.compile('.*Failed to load rfi_mask model.*', re.M)) 128 | assert_regex(cm.output[1], re.compile('.*Failed to load band_mask model.*', re.M)) 129 | assert_is_none(attrs.rfi_mask_model) 130 | assert_is_none(attrs.band_mask_model) 131 | 132 | 133 | class TestTimeAverage(asynctest.TestCase): 134 | def test_constructor(self): 135 | avg = ingest_session._TimeAverage(3, asynctest.CoroutineMock(name='flush')) 136 | assert_equal(3, avg.ratio) 137 | assert_is(None, avg._start_idx) 138 | 139 | async def test_add_index(self): 140 | avg = ingest_session._TimeAverage(3, asynctest.CoroutineMock(name='flush')) 141 | await avg.add_index(0) 142 | await avg.add_index(2) 143 | await avg.add_index(1) # Test time reordering 144 | assert not avg.flush.called 145 | 146 | await avg.add_index(3) # Skip first frame in the group 147 | avg.flush.assert_called_once_with(0) 148 | avg.flush.reset_mock() 149 | assert_equal(3, avg._start_idx) 150 | 151 | await avg.add_index(12) # Skip some whole groups 152 | avg.flush.assert_called_once_with(1) 153 | avg.flush.reset_mock() 154 | assert_equal(12, avg._start_idx) 155 | 156 | await avg.finish() 157 | avg.flush.assert_called_once_with(4) 158 | assert_is(None, avg._start_idx) 159 | 160 | 161 | def test_split_array(): 162 | """Test _split_array""" 163 | c64 = (np.random.uniform(size=(4, 7)) 164 | + 1j * np.random.uniform(size=(4, 7))).astype(np.complex64) 165 | # Create a view 
which is discontiguous 166 | src = c64[:3, :5].T 167 | actual = ingest_session._split_array(src, np.float32) 168 | expected = np.zeros((5, 3, 2), np.float32) 169 | for i in range(5): 170 | for j in range(3): 171 | expected[i, j, 0] = src[i, j].real 172 | expected[i, j, 1] = src[i, j].imag 173 | np.testing.assert_equal(actual, expected) 174 | 175 | 176 | def test_fast_unique(): 177 | """Test _fast_unique.""" 178 | a = np.array([ 179 | [0, 1, 0, 0, 1, 0, 0, 1], 180 | [0, 1, 0, 0, 1, 0, 1, 1], 181 | [0, 1, 0, 0, 1, 0, 0, 1], 182 | [0, 0, 0, 0, 0, 0, 0, 0], 183 | [0, 0, 0, 0, 0, 0, 0, 0], 184 | [0, 1, 0, 0, 1, 0, 0, 1], 185 | [0, 1, 0, 0, 1, 0, 1, 1] 186 | ], dtype=bool) 187 | comp, indices = ingest_session._fast_unique(a) 188 | assert_equal(comp.shape, (3, 8)) 189 | assert_equal(indices.shape, (7,)) 190 | assert_true(np.all(0 <= indices)) 191 | assert_true(np.all(indices < len(comp))) 192 | np.testing.assert_equal(comp[indices], a) 193 | 194 | 195 | class TestTelstateReceiver(asynctest.TestCase): 196 | def setUp(self): 197 | self.telstate = katsdptelstate.aio.TelescopeState() 198 | 199 | async def test_first_timestamp(self): 200 | # We don't want to bother setting up a valid Receiver base class, we 201 | # just want to test the subclass, so we mock in a different base. 202 | class DummyBase: 203 | def __init__(self, cbf_attr): 204 | self.cbf_attr = cbf_attr 205 | 206 | patcher = mock.patch.object(ingest_session.TelstateReceiver, '__bases__', (DummyBase,)) 207 | with patcher: 208 | patcher.is_local = True # otherwise mock tries to delete __bases__ 209 | cbf_attr = {'scale_factor_timestamp': 4.0} 210 | receiver = ingest_session.TelstateReceiver(cbf_attr=cbf_attr, 211 | telstates=[self.telstate], 212 | l0_int_time=3.0) 213 | # Set first value 214 | assert_equal(12345, await receiver._first_timestamp(12345)) 215 | # Try a different value, first value must stick 216 | assert_equal(12345, await receiver._first_timestamp(54321)) 217 | # Set same value 218 | assert_equal(12345, await receiver._first_timestamp(12345)) 219 | # Check the telstate keys 220 | assert_equal(12345, await self.telstate['first_timestamp_adc']) 221 | assert_equal(3087.75, await self.telstate['first_timestamp']) 222 | 223 | 224 | class TestSensorHistory: 225 | def setUp(self): 226 | self.sh = ingest_session.SensorHistory('test') 227 | 228 | def test_simple(self) -> None: 229 | self.sh.add(4.0, 'hello') 230 | self.sh.add(6.0, 'world') 231 | assert_equal(self.sh.get(4.0), 'hello') 232 | assert_equal(self.sh.get(5.0), 'hello') 233 | assert_equal(self.sh.get(6.0), 'world') 234 | assert_equal(self.sh.get(7.0), 'world') 235 | assert_equal(len(self.sh._data), 1, 'old data was not pruned') 236 | 237 | def test_query_empty(self) -> None: 238 | assert_is_none(self.sh.get(4.0)) 239 | assert_equal(self.sh.get(5.0, 'default'), 'default') 240 | 241 | def test_query_before_first(self) -> None: 242 | self.sh.add(5.0, 'hello') 243 | assert_is_none(self.sh.get(4.0)) 244 | 245 | def test_add_before_query(self) -> None: 246 | self.sh.get(5.0) 247 | with assert_logs(ingest_session.logger, logging.WARNING): 248 | self.sh.add(4.0, 'oops') 249 | assert_equal(self.sh.get(5.0), 'oops') 250 | 251 | def test_add_out_of_order(self) -> None: 252 | self.sh.add(5.0, 'first') 253 | with assert_logs(ingest_session.logger, logging.WARNING): 254 | self.sh.add(4.0, 'second') 255 | assert_is_none(self.sh.get(4)) 256 | 257 | def test_replace_latest(self) -> None: 258 | self.sh.add(5.0, 'first') 259 | self.sh.add(5.0, 'second') 260 | assert_equal(self.sh.get(5.0), 
'second') 261 | 262 | def test_query_out_of_order(self) -> None: 263 | self.sh.get(5.0) 264 | with assert_raises(ValueError): 265 | self.sh.get(4.0) 266 | 267 | 268 | class TestCBFIngest: 269 | @device_test 270 | def test_create_proc(self, context, queue): 271 | """Test that an ingest processor can be created on the device""" 272 | template = ingest_session.CBFIngest.create_proc_template(context, [4, 12], 4096, True, True) 273 | template.instantiate( 274 | queue, 1024, Range(96, 1024 - 96), Range(96, 1024 - 96), 544, 512, 1, 275 | 8, 16, [(0, 4), (500, 512)], 276 | threshold_args={'n_sigma': 11.0}) 277 | 278 | def test_tune_next(self): 279 | assert_equal(2, ingest_session.CBFIngest._tune_next(0, [2, 4, 8, 16])) 280 | assert_equal(8, ingest_session.CBFIngest._tune_next(5, [2, 4, 8, 16])) 281 | assert_equal(8, ingest_session.CBFIngest._tune_next(8, [2, 4, 8, 16])) 282 | assert_equal(21, ingest_session.CBFIngest._tune_next(21, [2, 4, 8, 16])) 283 | 284 | def test_baseline_permutation(self): 285 | orig_ordering = np.array([ 286 | ['m000v', 'm000v'], 287 | ['m000h', 'm000v'], 288 | ['m000h', 'm000h'], 289 | ['m000v', 'm000h'], 290 | ['m000v', 'm001v'], 291 | ['m000v', 'm001h'], 292 | ['m000h', 'm001v'], 293 | ['m000h', 'm001h'], 294 | ['m001h', 'm001v'], 295 | ['m001v', 'm001h'], 296 | ['m001h', 'm001h'], 297 | ['m001v', 'm001v']]) 298 | expected_ordering = np.array([ 299 | ['m000h', 'm000h'], 300 | ['m001h', 'm001h'], 301 | ['m000v', 'm000v'], 302 | ['m001v', 'm001v'], 303 | ['m000h', 'm000v'], 304 | ['m001h', 'm001v'], 305 | ['m000v', 'm000h'], 306 | ['m001v', 'm001h'], 307 | ['m000h', 'm001h'], 308 | ['m000v', 'm001v'], 309 | ['m000h', 'm001v'], 310 | ['m000v', 'm001h']]) 311 | 312 | bls = ingest_session.BaselineOrdering(orig_ordering) 313 | np.testing.assert_equal(expected_ordering, bls.sdp_bls_ordering) 314 | np.testing.assert_equal([2, 4, 0, 6, 9, 11, 10, 8, 5, 7, 1, 3], bls.permutation) 315 | 316 | def test_baseline_permutation_masked(self): 317 | orig_ordering = np.array([ 318 | ['m000v', 'm000v'], 319 | ['m000h', 'm000v'], 320 | ['m000h', 'm000h'], 321 | ['m000v', 'm000h'], 322 | ['m000v', 'm001v'], 323 | ['m000v', 'm001h'], 324 | ['m000h', 'm001v'], 325 | ['m000h', 'm001h'], 326 | ['m001h', 'm001v'], 327 | ['m001v', 'm001h'], 328 | ['m001h', 'm001h'], 329 | ['m001v', 'm001v']]) 330 | expected_ordering = np.array([ 331 | ['m001h', 'm001h'], 332 | ['m001v', 'm001v'], 333 | ['m001h', 'm001v'], 334 | ['m001v', 'm001h']]) 335 | antenna_mask = set(['m001']) 336 | 337 | bls = ingest_session.BaselineOrdering(orig_ordering, antenna_mask) 338 | np.testing.assert_equal(expected_ordering, bls.sdp_bls_ordering) 339 | np.testing.assert_equal([-1, -1, -1, -1, -1, -1, -1, -1, 2, 3, 0, 1], bls.permutation) 340 | -------------------------------------------------------------------------------- /katsdpingest/receiver.py: -------------------------------------------------------------------------------- 1 | """Receives from multiple SPEAD streams and combines heaps into frames.""" 2 | 3 | import logging 4 | from collections import deque 5 | import asyncio 6 | import typing # noqa: F401 7 | from typing import List, Sequence, Mapping, Any, Optional, Union # noqa: F401 8 | 9 | import spead2 10 | import spead2.recv 11 | import spead2.recv.asyncio 12 | from aiokatcp import Sensor 13 | 14 | import numpy as np 15 | from katsdptelstate.endpoint import endpoints_to_str, Endpoint 16 | 17 | from .utils import Range 18 | 19 | 20 | _logger = logging.getLogger(__name__) 21 | 22 | 23 | REJECT_HEAP_TYPES = { 24 | 
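    # Each key becomes part of an input-<key>-heaps-total sensor name; the
    # value is the human-readable reason used in the sensor description.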
'incomplete': 'heap was incomplete',
25 |     'no-descriptor': 'descriptors not yet received',
26 |     'bad-timestamp': 'timestamp not aligned to integration boundary',
27 |     'too-old': 'timestamp is prior to the start time',
28 |     'bad-channel': 'channel offset is not aligned to the substreams',
29 |     'missing': 'expected heap was not received',
30 |     'bad-heap': 'heap items are missing, wrong shape, etc.'
31 | }
32 | 
33 | 
34 | class Frame:
35 |     """A group of xeng_raw data with a common timestamp"""
36 |     def __init__(self, idx: int, timestamp: int, n_xengs: int) -> None:
37 |         self.idx = idx
38 |         self.timestamp = timestamp
39 |         self.items = [None] * n_xengs  # type: List[Optional[np.ndarray]]
40 | 
41 |     def ready(self) -> bool:
42 |         return all(item is not None for item in self.items)
43 | 
44 |     def empty(self) -> bool:
45 |         return all(item is None for item in self.items)
46 | 
47 |     @property
48 |     def nbytes(self) -> int:
49 |         return sum([(item.nbytes if item is not None else 0) for item in self.items])
50 | 
51 | 
52 | class Receiver:
53 |     """Class that receives from multiple SPEAD streams and combines heaps into
54 |     frames.
55 | 
56 |     Parameters
57 |     ----------
58 |     endpoints : list of :class:`katsdptelstate.Endpoint`
59 |         Endpoints for SPEAD streams. These must be listed in order
60 |         of increasing channel number.
61 |     interface_address : str
62 |         Address of interface to subscribe to for endpoints
63 |     ibv : bool
64 |         If true, use ibverbs for acceleration
65 |     max_streams : int
66 |         Maximum number of separate streams to use. The endpoints are spread
67 |         across the streams, with a thread per stream.
68 |     buffer_size : int
69 |         Buffer size. It is split across the streams.
70 |     channel_range : :class:`katsdpingest.utils.Range`
71 |         Channels to capture. These must be aligned to the stream boundaries.
72 |     cbf_channels : int
73 |         Total number of channels represented by `endpoints`.
74 |     sensors : dict
75 |         Dictionary mapping sensor names to sensor objects
76 |     cbf_attr : dict
77 |         Dictionary mapping CBF attribute names to values
78 |     active_frames : int, optional
79 |         Maximum number of incomplete frames to keep at one time
80 | 
81 |     Attributes
82 |     ----------
83 |     cbf_attr : dict
84 |         Dictionary mapping CBF attribute names to values
85 |     active_frames : int
86 |         Value of `active_frames` passed to constructor
87 |     interval : int
88 |         Timestamp change between successive frames.
89 |     timestamp_base : Optional[int]
90 |         Timestamp associated with the frame with index 0. It is initially
91 |         ``None``, and is set when the first dump is received. The raw
92 |         timestamp of any other frame can be computed as
93 |         ``timestamp_base + idx * interval``.
94 |     _frames : :class:`deque`
95 |         Deque of :class:`Frame` objects representing incomplete frames. After
96 |         initialization, it always contains exactly `active_frames`
97 |         elements, with timestamps separated by the inter-dump interval.
98 |     _frames_complete : :class:`asyncio.Queue`
99 |         Queue of complete frames of type :class:`Frame`. It may also contain
100 |         integers, which are the indices of finished streams.
101 |     _running : int
102 |         Number of streams still running
103 |     _futures : list of :class:`asyncio.Future`
104 |         Futures associated with each call to :meth:`_read_stream`
105 |     _streams : list of :class:`spead2.recv.asyncio.Stream`
106 |         Individual SPEAD streams
107 |     _stopping : bool
108 |         Set to true by stop(). Note that some streams may still be running
109 |         (:attr:`_running` > 0) at the same time.
110 | """ 111 | def __init__( 112 | self, 113 | endpoints: List[Endpoint], 114 | interface_address: str, ibv: bool, 115 | max_streams: int, buffer_size: int, 116 | channel_range: Range, cbf_channels: int, 117 | sensors: Mapping[str, Sensor], 118 | cbf_attr: Mapping[str, Any], 119 | active_frames: int = 1) -> None: 120 | # Determine the endpoints to actually use 121 | if cbf_channels % len(endpoints): 122 | raise ValueError('cbf_channels not divisible by the number of endpoints') 123 | self._endpoint_channels = cbf_channels // len(endpoints) 124 | if not channel_range.isaligned(self._endpoint_channels): 125 | raise ValueError('channel_range is not aligned to the stream boundaries') 126 | if self._endpoint_channels % cbf_attr['n_chans_per_substream'] != 0: 127 | raise ValueError('Number of channels in substream does not divide ' 128 | 'into per-endpoint channels') 129 | use_endpoints = endpoints[channel_range.start // self._endpoint_channels : 130 | channel_range.stop // self._endpoint_channels] 131 | 132 | self.cbf_attr = cbf_attr 133 | self.active_frames = active_frames 134 | self.channel_range = channel_range 135 | self.cbf_channels = cbf_channels 136 | self._interface_address = interface_address 137 | self._ibv = ibv 138 | self._streams = [] # type: List[spead2.recv.asyncio.Stream] 139 | self._frames = deque() # type: typing.Deque[Frame] 140 | self._frames_complete = asyncio.Queue(maxsize=1) # type: asyncio.Queue[Union[Frame, int]] 141 | self._futures = [] # type: List[Optional[asyncio.Future]] 142 | self._stopping = False 143 | self.interval = cbf_attr['ticks_between_spectra'] * cbf_attr['n_accs'] 144 | self.timestamp_base = None 145 | self._ig_cbf = spead2.ItemGroup() 146 | 147 | self._input_bytes = sensors['input-bytes-total'] 148 | self._input_bytes.value = 0 149 | self._input_heaps = sensors['input-heaps-total'] 150 | self._input_heaps.value = 0 151 | self._input_dumps = sensors['input-dumps-total'] 152 | self._input_dumps.value = 0 153 | self._descriptors_received = sensors['descriptors-received'] 154 | self._descriptors_received.value = False 155 | self._metadata_heaps = sensors['input-metadata-heaps-total'] 156 | self._metadata_heaps.value = 0 157 | self._reject_heaps = { 158 | name: sensors['input-' + name + '-heaps-total'] for name in REJECT_HEAP_TYPES 159 | } 160 | for sensor in self._reject_heaps.values(): 161 | sensor.value = 0 162 | 163 | n_streams = min(max_streams, len(use_endpoints)) 164 | stream_buffer_size = buffer_size // n_streams 165 | for i in range(n_streams): 166 | first = len(use_endpoints) * i // n_streams 167 | last = len(use_endpoints) * (i + 1) // n_streams 168 | self._streams.append(self._make_stream(use_endpoints[first:last], 169 | stream_buffer_size)) 170 | self._futures.append(asyncio.get_event_loop().create_task( 171 | self._read_stream(self._streams[-1], i, last - first))) 172 | self._running = n_streams 173 | 174 | def stop(self) -> None: 175 | """Stop all the individual streams.""" 176 | self._stopping = True 177 | for stream in self._streams: 178 | if stream is not None: 179 | stream.stop() 180 | 181 | async def join(self) -> None: 182 | """Wait for all the individual streams to stop. This must not 183 | be called concurrently with :meth:`get`. 184 | 185 | This is a coroutine. 
186 | """ 187 | while self._running > 0: 188 | frame = await self._frames_complete.get() 189 | if isinstance(frame, int): 190 | future = self._futures[frame] 191 | assert future is not None 192 | await future 193 | self._futures[frame] = None 194 | self._running -= 1 195 | 196 | def _pop_frame(self, replace=True) -> Optional[Frame]: 197 | """Remove the oldest element of :attr:`_frames`. 198 | 199 | Replace it with a new frame at the other end (unless `replace` is 200 | false), warn if it is incomplete, and update the missing heaps 201 | counter. 202 | 203 | Returns 204 | ------- 205 | frame 206 | The popped frame, or ``None`` if it was empty 207 | """ 208 | xengs = len(self._frames[-1].items) 209 | next_idx = self._frames[-1].idx + 1 210 | next_timestamp = self._frames[-1].timestamp + self.interval 211 | frame = self._frames.popleft() 212 | if replace: 213 | self._frames.append(Frame(next_idx, next_timestamp, xengs)) 214 | actual = sum(item is not None for item in frame.items) 215 | self._reject_heaps['missing'].value += xengs - actual 216 | if actual == 0: 217 | _logger.debug('Frame with timestamp %d is empty, discarding', frame.timestamp) 218 | return None 219 | else: 220 | _logger.debug('Frame with timestamp %d is %d/%d complete', 221 | frame.timestamp, actual, xengs) 222 | return frame 223 | 224 | async def _put_frame(self, frame: Frame) -> None: 225 | """Put a frame onto :attr:`_frames_complete` and update the sensor.""" 226 | self._input_dumps.value += 1 227 | await self._frames_complete.put(frame) 228 | 229 | def _add_readers(self, stream: spead2.recv.asyncio.Stream, 230 | endpoints: Sequence[Endpoint], 231 | buffer_size: int) -> None: 232 | """Subscribe a stream to a list of endpoints.""" 233 | ifaddr = self._interface_address 234 | if self._ibv: 235 | if ifaddr is None: 236 | raise ValueError('Cannot use ibverbs without an interface address') 237 | endpoint_tuples = [(endpoint.host, endpoint.port) for endpoint in endpoints] 238 | stream.add_udp_ibv_reader( 239 | spead2.recv.UdpIbvConfig( 240 | endpoints=endpoint_tuples, 241 | interface_address=ifaddr, 242 | buffer_size=buffer_size 243 | ) 244 | ) 245 | else: 246 | for endpoint in endpoints: 247 | if ifaddr is None: 248 | stream.add_udp_reader(endpoint.port, bind_hostname=endpoint.host) 249 | else: 250 | stream.add_udp_reader(endpoint.host, endpoint.port, 251 | interface_address=ifaddr) 252 | _logger.info( 253 | "CBF SPEAD stream reception on %s via %s%s", 254 | endpoints_to_str(endpoints), 255 | ifaddr if ifaddr is not None else 'default interface', 256 | ' with ibv' if self._ibv else '') 257 | 258 | def _make_stream(self, endpoints: Sequence[Endpoint], 259 | buffer_size: int) -> spead2.recv.asyncio.Stream: 260 | """Prepare a stream, which may combine multiple endpoints.""" 261 | # Figure out how many heaps will have the same timestamp, and set 262 | # up the stream. 263 | heap_channels = self.cbf_attr['n_chans_per_substream'] 264 | stream_channels = len(endpoints) * self._endpoint_channels 265 | baselines = len(self.cbf_attr['bls_ordering']) 266 | heap_data_size = np.dtype(np.complex64).itemsize * heap_channels * baselines 267 | stream_xengs = stream_channels // heap_channels 268 | # It's possible for a heap from each X engine and a descriptor heap 269 | # per endpoint to all arrive at once. 270 | ring_heaps = stream_xengs + len(endpoints) 271 | # Additionally, reordering in the network can cause the end of one dump 272 | # to overlap with the start of the next, for which we need to allow for 273 | # an extra stream_xengs. 
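        # Worked example (illustrative numbers only): with two endpoints of
        # 2048 channels each and 1024 channels per substream, stream_xengs = 4,
        # so ring_heaps = 4 + 2 = 6 and max_heaps = 6 + 4 = 10 below.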
274 | max_heaps = ring_heaps + stream_xengs 275 | # We need space in the memory pool for: 276 | # - live heaps (max_heaps, plus a newly incoming heap) 277 | # - ringbuffer heaps 278 | # - per X-engine: 279 | # - heap that has just been popped from the ringbuffer (1) 280 | # - active frames 281 | # - complete frames queue (1) 282 | # - frame being processed by ingest_session (which could be several, depending on 283 | # latency of the pipeline, but assume 4 to be on the safe side) 284 | memory_pool_heaps = ring_heaps + max_heaps + stream_xengs * (self.active_frames + 6) 285 | memory_pool = spead2.MemoryPool(16384, heap_data_size + 512, 286 | memory_pool_heaps, memory_pool_heaps) 287 | stream = spead2.recv.asyncio.Stream( 288 | spead2.ThreadPool(), 289 | spead2.recv.StreamConfig( 290 | max_heaps=max_heaps, 291 | memory_allocator=memory_pool, 292 | memcpy=spead2.MEMCPY_NONTEMPORAL, 293 | stop_on_stop_item=False 294 | ), 295 | spead2.recv.RingStreamConfig( 296 | heaps=ring_heaps, 297 | contiguous_only=False 298 | ) 299 | ) 300 | self._add_readers(stream, endpoints, buffer_size) 301 | return stream 302 | 303 | async def _first_timestamp(self, candidate: int) -> int: 304 | """Get raw ADC timestamp of the first frame across all ingests. 305 | 306 | This is called when the first valid dump is received for this 307 | receiver, and returns the raw timestamp of the first valid dump 308 | across all receivers. Note that the return value may be greater 309 | than `candidate` if another receiver received a heap first but with 310 | a larger timestamp. 311 | 312 | In the base implementation, it simply returns `candidate`. Subclasses 313 | may override this to implement inter-receiver communication. 314 | """ 315 | return candidate 316 | 317 | async def _read_stream(self, stream: spead2.recv.asyncio.Stream, 318 | stream_idx: int, n_endpoints: int) -> None: 319 | """Co-routine that sucks data from a single stream and populates 320 | :attr:`_frames_complete`.""" 321 | try: 322 | heap_channels = self.cbf_attr['n_chans_per_substream'] 323 | xengs = len(self.channel_range) // heap_channels 324 | prev_ts = None 325 | ts_wrap_offset = 0 # Value added to compensate for CBF timestamp wrapping 326 | ts_wrap_period = 2**48 327 | n_stop = 0 328 | 329 | async def process_heap(heap): 330 | """Process one heap and return a classification for it. 331 | 332 | The classification is one of: 333 | - None (normal) 334 | - 'stop' 335 | - 'metadata' 336 | - a key from REJECT_HEAP_TYPES 337 | """ 338 | nonlocal prev_ts, ts_wrap_offset, n_stop 339 | 340 | heap_type = None 341 | data_ts = None 342 | data_item = None 343 | channel0 = None 344 | 345 | if heap.is_end_of_stream(): 346 | self._metadata_heaps.value += 1 347 | n_stop += 1 348 | _logger.debug("%d/%d endpoints stopped on stream %d", 349 | n_stop, n_endpoints, stream_idx) 350 | return 'stop' 351 | elif isinstance(heap, spead2.recv.IncompleteHeap): 352 | heap_type = 'incomplete' 353 | _logger.debug('dropped incomplete heap %d (%d/%d bytes of payload)', 354 | heap.cnt, heap.received_length, heap.heap_length) 355 | # Attempt to extract the timestamp. We can't use 356 | # self._ig_cbf.update because that requires a complete 357 | # heap, so this emulates some of its functionality. 
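To illustrate the extension point in `_first_timestamp` above: a subclass could agree on a first timestamp through a shared key-value store. The sketch below is hypothetical; the class name, the telstate key, and the synchronisation strategy are all invented here, and the repository's actual `TelstateReceiver` may work differently. It assumes katsdptelstate's immutable keys, where a conflicting `add` raises `ImmutableKeyError`, so the first writer wins and everyone else adopts its value:

```python
# Hypothetical sketch only: synchronise the first timestamp across receivers
# via a shared katsdptelstate. First writer wins; later writers adopt it.
import katsdptelstate


class SynchronisedReceiver(Receiver):
    def __init__(self, *args, telstate: katsdptelstate.TelescopeState,
                 key: str = 'first_timestamp', **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._telstate = telstate
        self._key = key

    async def _first_timestamp(self, candidate: int) -> int:
        try:
            # An immutable add fails if another receiver already stored a
            # different value under this key.
            self._telstate.add(self._key, candidate, immutable=True)
            return candidate
        except katsdptelstate.ImmutableKeyError:
            return self._telstate[self._key]
```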
358 | try: 359 | item = self._ig_cbf['timestamp'] 360 | except KeyError: 361 | pass # We don't have the descriptor for it yet 362 | else: 363 | for raw_item in heap.get_items(): 364 | if raw_item.id == item.id: 365 | try: 366 | item.set_from_raw(raw_item) 367 | item.version += 1 368 | except ValueError: 369 | _logger.warning('Exception updating item from heap', 370 | exc_info=True) 371 | return 'bad-heap' 372 | data_ts = item.value 373 | break 374 | # Note: no return here. We carry on to process the timestamp 375 | elif not self._descriptors_received.value and not heap.get_descriptors(): 376 | _logger.debug('Received non-descriptor heap before descriptors') 377 | return 'no-descriptor' 378 | else: 379 | try: 380 | # We suppress the conversion to little endian. The data 381 | # gets copied later anyway and numpy will do the endian 382 | # swapping then without an extraneous copy. 383 | updated = self._ig_cbf.update(heap, new_order='|') 384 | except ValueError: 385 | _logger.warning('Exception updating item group from heap', exc_info=True) 386 | return 'bad-heap' 387 | # The _ig_cbf is shared between streams, so we need to use the values 388 | # before next yielding. 389 | if 'timestamp' in updated: 390 | data_ts = updated['timestamp'].value 391 | if 'xeng_raw' in updated: 392 | data_item = updated['xeng_raw'].value 393 | if 'frequency' in updated: 394 | channel0 = updated['frequency'].value 395 | if not self._descriptors_received.value and 'xeng_raw' in self._ig_cbf: 396 | # This heap added the descriptors 397 | self._descriptors_received.value = True 398 | 399 | if data_ts is None: 400 | _logger.debug("Heap without timestamp received on stream %d", stream_idx) 401 | return heap_type or 'metadata' 402 | 403 | # Process the timestamp, even if this is an incomplete heap, so 404 | # that we age out partial frames timeously. 405 | data_ts += ts_wrap_offset 406 | if prev_ts is not None and data_ts < prev_ts - ts_wrap_period // 2: 407 | # This happens either because packets ended up out-of-order, 408 | # or because the CBF timestamp wrapped. Out-of-order should 409 | # jump backwards a tiny amount while wraps should jump back by 410 | # close to ts_wrap_period. 411 | ts_wrap_offset += ts_wrap_period 412 | data_ts += ts_wrap_period 413 | _logger.warning('Data timestamps wrapped') 414 | elif prev_ts is not None and data_ts > prev_ts + ts_wrap_period // 2: 415 | # This happens if we wrapped, then received another heap 416 | # (probably from a different X engine) from before the 417 | # wrap. We need to undo the wrap. 
418 | ts_wrap_offset -= ts_wrap_period 419 | data_ts -= ts_wrap_period 420 | _logger.warning('Data timestamps reverse wrapped') 421 | _logger.debug('Received heap with timestamp %d on stream %d, channel %s', 422 | data_ts, stream_idx, channel0) 423 | prev_ts = data_ts 424 | if not self._frames: 425 | self.timestamp_base = await self._first_timestamp(data_ts) 426 | for i in range(self.active_frames): 427 | self._frames.append( 428 | Frame(i, self.timestamp_base + self.interval * i, xengs)) 429 | ts0 = self._frames[0].timestamp 430 | if data_ts < ts0: 431 | _logger.debug('Timestamp %d is too far in the past, discarding ' 432 | '(channel %s)', data_ts, channel0) 433 | return heap_type or 'too-old' 434 | elif (data_ts - ts0) % self.interval != 0: 435 | _logger.debug('Timestamp %d does not conform to %d + %dn, ' 436 | 'discarding (channel %s)', 437 | data_ts, ts0, self.interval, channel0) 438 | return heap_type or 'bad-timestamp' 439 | while data_ts >= ts0 + self.interval * self.active_frames: 440 | frame = self._pop_frame() 441 | if frame: 442 | await self._put_frame(frame) 443 | del frame # Free it up, particularly if discarded 444 | ts0 = self._frames[0].timestamp 445 | 446 | if heap_type == 'incomplete': 447 | return heap_type 448 | 449 | # From here on we expect we have proper data 450 | if data_item is None: 451 | _logger.warning("CBF heap without xeng_raw received on stream %d", stream_idx) 452 | return 'bad-heap' 453 | if channel0 is None: 454 | _logger.warning("CBF heap without frequency received on stream %d", stream_idx) 455 | return 'bad-heap' 456 | heap_channel_range = Range(channel0, channel0 + heap_channels) 457 | if not (heap_channel_range.isaligned(heap_channels) 458 | and heap_channel_range.issubset(self.channel_range)): 459 | _logger.debug("CBF heap with invalid channel %d on stream %d", 460 | channel0, stream_idx) 461 | return 'bad-channel' 462 | xeng_idx = (channel0 - self.channel_range.start) // heap_channels 463 | frame_idx = (data_ts - ts0) // self.interval 464 | self._frames[frame_idx].items[xeng_idx] = data_item 465 | self._input_bytes.value += data_item.nbytes 466 | self._input_heaps.value += 1 467 | return heap_type 468 | 469 | async for heap in stream: 470 | heap_type = await process_heap(heap) 471 | if heap_type == 'stop': 472 | if n_stop == n_endpoints: 473 | stream.stop() 474 | break 475 | elif heap_type == 'metadata': 476 | self._metadata_heaps.value += 1 477 | elif heap_type in REJECT_HEAP_TYPES: 478 | # Don't warn about incomplete heaps if we've already been 479 | # asked to stop. There may be some heaps still in the 480 | # network at the time we were asked to stop. 481 | if heap_type != 'incomplete' or not self._stopping: 482 | self._reject_heaps[heap_type].value += 1 483 | else: 484 | assert heap_type is None 485 | finally: 486 | await self._frames_complete.put(stream_idx) 487 | 488 | async def get(self) -> Frame: 489 | """Return the next frame. 490 | 491 | This is a coroutine. 
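Returning briefly to the timestamp handling in `_read_stream`: the unwrap logic is easy to sanity-check in isolation. This standalone replay of the same arithmetic exercises both branches with a 48-bit counter:

```python
# Standalone replay of the timestamp unwrap logic in _read_stream above.
TS_WRAP_PERIOD = 2 ** 48

def unwrap(raw_ts, prev_ts, wrap_offset):
    """Return (adjusted_ts, new_wrap_offset)."""
    ts = raw_ts + wrap_offset
    if prev_ts is not None and ts < prev_ts - TS_WRAP_PERIOD // 2:
        wrap_offset += TS_WRAP_PERIOD      # counter wrapped past zero
        ts += TS_WRAP_PERIOD
    elif prev_ts is not None and ts > prev_ts + TS_WRAP_PERIOD // 2:
        wrap_offset -= TS_WRAP_PERIOD      # straggler from before the wrap
        ts -= TS_WRAP_PERIOD
    return ts, wrap_offset

ts, off = unwrap(100, 2 ** 48 - 50, 0)     # small raw value => wrap detected
assert (ts, off) == (2 ** 48 + 100, 2 ** 48)
ts, off = unwrap(2 ** 48 - 50, ts, off)    # late heap from before the wrap
assert (ts, off) == (2 ** 48 - 50, 0)
```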
492 | 493 | Raises 494 | ------ 495 | spead2.Stopped 496 | if all the streams have stopped 497 | """ 498 | while self._running > 0: 499 | frame = await self._frames_complete.get() 500 | if isinstance(frame, int): 501 | # It's actually the index of a finished stream 502 | self._streams[frame].stop() # In case the co-routine exited with an exception 503 | future = self._futures[frame] 504 | assert future is not None 505 | await future 506 | self._futures[frame] = None 507 | self._running -= 1 508 | else: 509 | return frame 510 | # Check for frames still in the queue 511 | while self._frames: 512 | tail_frame = self._pop_frame(replace=False) 513 | if tail_frame: 514 | return tail_frame 515 | raise spead2.Stopped('End of streams') 516 | -------------------------------------------------------------------------------- /katsdpingest/test/test_ingest_server.py: -------------------------------------------------------------------------------- 1 | """Tests for :mod:`katsdpingest.ingest_server`.""" 2 | 3 | import argparse 4 | import logging 5 | import asyncio 6 | import copy 7 | import concurrent.futures 8 | from unittest import mock 9 | from typing import List, Dict, Any 10 | 11 | import asynctest 12 | import async_timeout 13 | import numpy as np 14 | from nose.tools import (assert_in, assert_is_not_none, assert_is_instance, 15 | assert_true, assert_equal, assert_almost_equal, assert_raises_regex) 16 | 17 | import spead2 18 | import spead2.recv 19 | import spead2.send 20 | import aiokatcp 21 | import katsdptelstate.aio.memory 22 | from katsdptelstate.endpoint import Endpoint 23 | from katsdpsigproc.test.test_accel import device_test 24 | from katdal.flags import CAM, STATIC 25 | import katsdpmodels.rfi_mask 26 | import katsdpmodels.band_mask 27 | import katpoint 28 | import astropy.table 29 | import astropy.units as u 30 | 31 | from katsdpingest.utils import Range, cbf_telstate_view 32 | from katsdpingest.ingest_server import IngestDeviceServer 33 | from katsdpingest.ingest_session import ChannelRanges, BaselineOrdering, SystemAttrs 34 | from katsdpingest.test.test_ingest_session import fake_cbf_attr 35 | from katsdpingest.receiver import Frame 36 | from katsdpingest.sender import Data 37 | 38 | 39 | class MockReceiver: 40 | """Replacement for :class:`katsdpingest.receiver.Receiver`. 41 | 42 | It has a predefined list of frames and yields them with no delay. However, 43 | one can request a pause prior to a particular frame. 44 | 45 | Parameters 46 | ---------- 47 | data : ndarray 48 | 3D array of visibilities indexed by time, frequency and baseline. 49 | The array contains data for the entire CBF channel range. 
50 |     timestamps : array-like
51 |         1D array of CBF timestamps
52 |     """
53 |     def __init__(self, data, timestamps,
54 |                  endpoints, interface_address, ibv,
55 |                  max_streams, buffer_size,
56 |                  channel_range, cbf_channels, sensors,
57 |                  cbf_attr, active_frames=2, telstates=None,
58 |                  l0_int_time=None, pauses=None):
59 |         assert data.shape[0] == len(timestamps)
60 |         self._next_frame = 0
61 |         self._data = data
62 |         self._timestamps = timestamps
63 |         self._stop_event = asyncio.Event()
64 |         self._channel_range = channel_range
65 |         self._substreams = len(channel_range) // cbf_attr['n_chans_per_substream']
66 |         self._pauses = {} if pauses is None else pauses
67 |         # Set values to match Receiver
68 |         self.cbf_attr = cbf_attr
69 |         self.interval = cbf_attr['ticks_between_spectra'] * cbf_attr['n_accs']
70 |         self.timestamp_base = timestamps[0]
71 | 
72 |     def stop(self):
73 |         self._stop_event.set()
74 | 
75 |     async def join(self):
76 |         await self._stop_event.wait()
77 | 
78 |     async def get(self):
79 |         event = self._pauses.get(self._next_frame)
80 |         if event is None:
81 |             event = asyncio.sleep(0)
82 |         await event
83 |         # Signal the end of the stream once the canned frames are
84 |         # exhausted, just as the real Receiver does.
85 |         if self._next_frame >= len(self._data):
86 |             raise spead2.Stopped('end of frame list')
87 |         frame = Frame(self._next_frame, self._timestamps[self._next_frame], self._substreams)
88 |         item_channels = len(self._channel_range) // self._substreams
89 |         for i in range(self._substreams):
90 |             start = self._channel_range.start + i * item_channels
91 |             stop = start + item_channels
92 |             frame.items[i] = self._data[self._next_frame, start:stop, ...]
93 |         self._next_frame += 1
94 |         return frame
95 | 
96 | 
97 | class DeepCopyMock(mock.MagicMock):
98 |     """Mock that takes deep copies of its arguments when called."""
99 | 
100 |     def __call__(self, *args, **kwargs):
101 |         return super().__call__(*copy.deepcopy(args), **copy.deepcopy(kwargs))
102 | 
103 | 
104 | def decode_heap_ig(heap):
105 |     ig = spead2.ItemGroup()
106 |     assert_is_not_none(heap)
107 |     ig.update(heap)
108 |     return ig
109 | 
110 | 
111 | def get_heaps(tx):
112 |     rx = spead2.recv.Stream(
113 |         spead2.ThreadPool(),
114 |         spead2.recv.StreamConfig(stop_on_stop_item=False)
115 |     )
116 |     tx.queues[0].stop()
117 |     rx.add_inproc_reader(tx.queues[0])
118 |     return list(rx)
119 | 
120 | 
121 | class TestIngestDeviceServer(asynctest.TestCase):
122 |     """Tests for :class:`katsdpingest.ingest_server.IngestDeviceServer`.
123 | 
124 |     This does not test all the intricacies of flagging, timeseries masking,
125 |     lost data and so on. It is intended to check that the katcp commands
126 |     function and that the correct channels are sent to the correct places.
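A note on `DeepCopyMock` above: `tx.send` is handed numpy arrays that the caller may go on to reuse, and a plain `MagicMock` records only references, so later mutations would silently rewrite the recorded call arguments (the buffer-reuse motivation is an inference here, not stated in the code). A minimal, self-contained demonstration:

```python
# Why deep copies matter when asserting on recorded call arguments.
import copy
from unittest import mock

import numpy as np


class DeepCopyMock(mock.MagicMock):          # same as the class above
    def __call__(self, *args, **kwargs):
        return super().__call__(*copy.deepcopy(args), **copy.deepcopy(kwargs))


plain = mock.MagicMock()
buf = np.zeros(3)
plain(buf)
buf[:] = 1                               # caller reuses its buffer
assert plain.call_args[0][0][0] == 1     # the recorded argument mutated too

deep = DeepCopyMock()
buf = np.zeros(3)
deep(buf)
buf[:] = 1
assert deep.call_args[0][0][0] == 0      # snapshot taken at call time
```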
127 | """ 128 | 129 | def _patch(self, *args, **kwargs): 130 | patcher = mock.patch(*args, **kwargs) 131 | mock_obj = patcher.start() 132 | self.addCleanup(patcher.stop) 133 | return mock_obj 134 | 135 | def _get_tx(self, thread_pool, endpoints, interface_address, flavour, 136 | int_time, channel_range, channel0, all_channels, baselines): 137 | if endpoints == self.user_args.l0_spectral_spead[1:2]: 138 | return self._tx['spectral'] 139 | elif endpoints == self.user_args.l0_continuum_spead[1:2]: 140 | return self._tx['continuum'] 141 | else: 142 | raise KeyError('VisSenderSet created with unrecognised endpoints') 143 | 144 | def _get_sd_tx(self, thread_pool, endpoints, config): 145 | assert_equal(len(endpoints), 1) 146 | tx = spead2.send.asyncio.InprocStream(thread_pool, [spead2.InprocQueue()]) 147 | self._sd_tx[Endpoint(*endpoints[0])] = tx 148 | return tx 149 | 150 | def _create_data(self): 151 | start_ts = 100000000 152 | interval = self.cbf_attr['n_accs'] * self.cbf_attr['ticks_between_spectra'] 153 | n_dumps = 19 154 | shape = (n_dumps, self.cbf_attr['n_chans'], len(self.cbf_attr['bls_ordering']), 2) 155 | rs = np.random.RandomState(seed=1) 156 | data = (rs.standard_normal(shape) * 1000).astype(np.int32) 157 | # Make autocorrelations real, and also set a fixed value. This gives 158 | # all visibilities the same weight, making it easier to compute the 159 | # expected values. 160 | for i, (a, b) in enumerate(self.cbf_attr['bls_ordering']): 161 | if a == b: 162 | data[:, :, i, 0] = 1000 163 | data[:, :, i, 1] = 0 164 | timestamps = (np.arange(n_dumps) * interval + start_ts).astype(np.uint64) 165 | return data, timestamps 166 | 167 | def fake_channel_mask(self) -> np.ndarray: 168 | channel_mask = np.zeros((self.cbf_attr['n_chans']), np.bool_) 169 | channel_mask[704] = True 170 | channel_mask[750:800] = True 171 | channel_mask[900] = True 172 | return channel_mask 173 | 174 | def fake_rfi_mask_model(self) -> katsdpmodels.rfi_mask.RFIMask: 175 | # Channels 852:857 and 1024 176 | ranges = astropy.table.QTable( 177 | [[1034e6, 1070.0e6] * u.Hz, 178 | [1034.95e6, 1070.0e6] * u.Hz, 179 | [1500, np.inf] * u.m], 180 | names=('min_frequency', 'max_frequency', 'max_baseline') 181 | ) 182 | return katsdpmodels.rfi_mask.RFIMaskRanges(ranges, False) 183 | 184 | def fake_band_mask_model(self) -> katsdpmodels.band_mask.BandMask: 185 | # Channels 820:840 186 | ranges = astropy.table.Table( 187 | [[0.2001], [0.2049]], names=('min_fraction', 'max_fraction') 188 | ) 189 | return katsdpmodels.band_mask.BandMaskRanges(ranges) 190 | 191 | def fake_channel_data_suspect(self): 192 | bad = np.zeros(self.cbf_attr['n_chans'], np.bool_) 193 | bad[300] = True 194 | bad[650:750] = True 195 | return bad 196 | 197 | @device_test 198 | async def setUp(self, context, command_queue) -> None: 199 | done_future = asyncio.Future() # type: asyncio.Future[None] 200 | done_future.set_result(None) 201 | self._patchers = [] # type: List[Any] 202 | self._telstate = katsdptelstate.aio.TelescopeState() 203 | n_xengs = 16 204 | self.user_args = user_args = argparse.Namespace( 205 | sdisp_spead=[Endpoint('127.0.0.2', 7149)], 206 | sdisp_interface=None, 207 | cbf_spead=[Endpoint('239.102.250.{}'.format(i), 7148) for i in range(n_xengs)], 208 | cbf_interface='dummyif1', 209 | cbf_ibv=False, 210 | cbf_name='i0_baseline_correlation_products', 211 | l0_spectral_spead=[Endpoint('239.102.251.{}'.format(i), 7148) for i in range(4)], 212 | l0_spectral_interface='dummyif2', 213 | l0_spectral_name='sdp_l0', 214 | 
l0_continuum_spead=[Endpoint('239.102.252.{}'.format(i), 7148) for i in range(4)], 215 | l0_continuum_interface='dummyif3', 216 | l0_continuum_name='sdp_l0_continuum', 217 | output_int_time=4.0, 218 | sd_int_time=4.0, 219 | antenna_mask=['m090', 'm091', 'm093'], 220 | output_channels=Range(464, 1744), 221 | sd_output_channels=Range(640, 1664), 222 | continuum_factor=16, 223 | sd_continuum_factor=128, 224 | guard_channels=64, 225 | input_streams=2, 226 | input_buffer=32*1024**2, 227 | sd_spead_rate=1000000000.0, 228 | excise=False, 229 | use_data_suspect=True, 230 | servers=4, 231 | server_id=2, 232 | clock_ratio=1.0, 233 | host='127.0.0.1', 234 | port=7147, 235 | name='sdp.ingest.1' 236 | ) 237 | self.loop.set_default_executor(concurrent.futures.ThreadPoolExecutor(max_workers=8)) 238 | self.cbf_attr = fake_cbf_attr(4, n_xengs=n_xengs) 239 | # Put them in at the beginning of time, to ensure they apply to every dump 240 | await self._telstate.set('i0_baseline_correlation_products_src_streams', 241 | ['i0_antenna_channelised_voltage']) 242 | await self._telstate.set('i0_antenna_channelised_voltage_instrument_dev_name', 'i0') 243 | await self._telstate.add('i0_antenna_channelised_voltage_channel_mask', 244 | self.fake_channel_mask(), ts=0) 245 | await self._telstate.add('m090_data_suspect', False, ts=0) 246 | await self._telstate.add('m091_data_suspect', True, ts=0) 247 | input_data_suspect = np.zeros(len(self.cbf_attr['input_labels']), np.bool_) 248 | input_data_suspect[1] = True # Corresponds to m090v 249 | await self._telstate.add('i0_antenna_channelised_voltage_input_data_suspect', 250 | input_data_suspect, ts=0) 251 | await self._telstate.add('i0_baseline_correlation_products_channel_data_suspect', 252 | self.fake_channel_data_suspect(), ts=0) 253 | # These correspond to three core and one outlying MeerKAT antennas, 254 | # so that baselines to m093 are long while the others are short. 
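The "outlying" claim in the comment above can be spot-checked from the east/north/up offsets (in metres) embedded in the antenna description strings. This matters later because the fake RFI mask model defined earlier only applies to baselines up to 1500 m:

```python
# Rough baseline-length check from the ENU offsets in the descriptions above.
import numpy as np

m090 = np.array([-8.258, -207.289, 1.2075])        # core antenna
m093 = np.array([-1440.6235, -2503.7705, 14.288])  # the outlier
print(np.linalg.norm(m093 - m090))  # ~2707 m, well beyond the 1500 m RFI cutoff
```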
255 | antennas = [ 256 | katpoint.Antenna('m090, -30:42:39.8, 21:26:38.0, 1035.0, 13.5, -8.258 -207.289 1.2075 5874.184 5875.444, -0:00:39.7 0 -0:04:04.4 -0:04:53.0 0:00:57.8 -0:00:13.9 0:13:45.2 0:00:59.8, 1.14'), # noqa: E501 257 | katpoint.Antenna('m091, -30:42:39.8, 21:26:38.0, 1035.0, 13.5, 1.126 -171.761 1.0605 5868.979 5869.998, -0:42:08.0 0 0:01:44.0 0:01:11.9 -0:00:14.0 -0:00:21.0 -0:36:13.1 0:01:36.2, 1.14'), # noqa: E501 258 | katpoint.Antenna('m002, -30:42:39.8, 21:26:38.0, 1035.0, 13.5, -32.1085 -224.2365 1.248 5871.207 5872.205, 0:40:20.2 0 -0:02:41.9 -0:03:46.8 0:00:09.4 -0:00:01.1 0:03:04.7, 1.14'), # noqa: E501 259 | katpoint.Antenna('m093, -30:42:39.8, 21:26:38.0, 1035.0, 13.5, -1440.6235 -2503.7705 14.288 5932.94 5934.732, -0:15:23.0 0 0:00:04.6 -0:03:30.4 0:01:12.2 0:00:37.5 0:00:15.6 0:01:11.8, 1.14') # noqa: E501 260 | ] 261 | self._telstate_cbf = await cbf_telstate_view(self._telstate, 262 | 'i0_baseline_correlation_products') 263 | self.system_attrs = SystemAttrs( 264 | self.cbf_attr, self.fake_rfi_mask_model(), self.fake_band_mask_model(), 265 | antennas) 266 | self.channel_ranges = ChannelRanges( 267 | user_args.servers, user_args.server_id - 1, 268 | self.cbf_attr['n_chans'], user_args.continuum_factor, user_args.sd_continuum_factor, 269 | len(user_args.cbf_spead), 64, 270 | user_args.output_channels, user_args.sd_output_channels) 271 | 272 | self._data, self._timestamps = self._create_data() 273 | self._pauses = None 274 | self._Receiver = self._patch( 275 | 'katsdpingest.ingest_session.TelstateReceiver', 276 | side_effect=lambda *args, **kwargs: 277 | MockReceiver(self._data, self._timestamps, *args, # type: ignore 278 | pauses=self._pauses, **kwargs)) 279 | self._tx = {'continuum': mock.MagicMock(), 'spectral': mock.MagicMock()} 280 | for tx in self._tx.values(): 281 | tx.start.return_value = done_future 282 | tx.stop.return_value = done_future 283 | tx.send = DeepCopyMock() 284 | tx.send.return_value = done_future 285 | tx.sub_channels = len(self.channel_ranges.output) 286 | self._tx['continuum'].sub_channels //= self.channel_ranges.cont_factor 287 | self._VisSenderSet = self._patch( 288 | 'katsdpingest.sender.VisSenderSet', side_effect=self._get_tx) 289 | self._sd_tx: Dict[Endpoint, spead2.send.asyncio.InprocStream] = {} 290 | self._UdpStream = self._patch('spead2.send.asyncio.UdpStream', 291 | side_effect=self._get_sd_tx) 292 | self._patch('katsdpservices.get_interface_address', 293 | side_effect=lambda interface: '127.0.0.' + interface[-1] if interface else None) 294 | self._server = IngestDeviceServer( 295 | user_args, self._telstate_cbf, self.channel_ranges, self.system_attrs, context, 296 | host=user_args.host, port=user_args.port) 297 | await self._server.start() 298 | self.addCleanup(self._server.stop) 299 | self._client = await aiokatcp.Client.connect(user_args.host, user_args.port) 300 | self.addCleanup(self._client.wait_closed) 301 | self.addCleanup(self._client.close) 302 | 303 | async def make_request(self, name: str, *args) -> List[aiokatcp.Message]: 304 | """Issue a request to the server, timing out if it takes too long. 
305 | 306 | Parameters 307 | ---------- 308 | name : str 309 | Request name 310 | args : list 311 | Arguments to the request 312 | 313 | Returns 314 | ------- 315 | informs : list 316 | Informs returned with the reply 317 | """ 318 | with async_timeout.timeout(15): 319 | reply, informs = await self._client.request(name, *args) 320 | return informs 321 | 322 | async def assert_request_fails(self, msg_re, name, *args): 323 | """Assert that a request fails, and test the error message against 324 | a regular expression.""" 325 | with assert_raises_regex(aiokatcp.FailReply, msg_re): 326 | with async_timeout.timeout(15): 327 | await self._client.request(name, *args) 328 | 329 | async def _get_expected(self): 330 | """Return expected visibilities, flags and timestamps. 331 | 332 | The timestamps are in seconds since the sync time. The full CBF channel 333 | range is returned. 334 | """ 335 | # Convert to complex64 from pairs of real and imag int 336 | vis = (self._data[..., 0] + self._data[..., 1] * 1j).astype(np.complex64) 337 | # Scaling 338 | vis /= self.cbf_attr['n_accs'] 339 | # Time averaging 340 | time_ratio = int(np.round(await self._telstate['sdp_l0_int_time'] 341 | / self.cbf_attr['int_time'])) 342 | batch_edges = np.arange(0, vis.shape[0], time_ratio) 343 | batch_sizes = np.minimum(batch_edges + time_ratio, vis.shape[0]) - batch_edges 344 | vis = np.add.reduceat(vis, batch_edges, axis=0) 345 | vis /= batch_sizes[:, np.newaxis, np.newaxis] 346 | timestamps = self._timestamps[::time_ratio] / self.cbf_attr['scale_factor_timestamp'] \ 347 | + 0.5 * (await self._telstate['sdp_l0_int_time']) 348 | # Baseline permutation 349 | bls = BaselineOrdering(self.cbf_attr['bls_ordering'], self.user_args.antenna_mask) 350 | inv_permutation = np.empty(len(bls.sdp_bls_ordering), np.int) 351 | for i, p in enumerate(bls.permutation): 352 | if p != -1: 353 | inv_permutation[p] = i 354 | vis = vis[..., inv_permutation] 355 | # Sanity check that we've constructed inv_permutation correctly 356 | np.testing.assert_array_equal( 357 | await self._telstate['sdp_l0_bls_ordering'], 358 | self.cbf_attr['bls_ordering'][inv_permutation]) 359 | flags = np.empty(vis.shape, np.uint8) 360 | channel_mask = self.fake_channel_mask() 361 | channel_mask[820:840] = True # Merge in band mask 362 | channel_data_suspect = self.fake_channel_data_suspect()[np.newaxis, :, np.newaxis] 363 | flags[:] = channel_data_suspect * np.uint8(CAM) 364 | for i, (a, b) in enumerate(bls.sdp_bls_ordering): 365 | if a.startswith('m091') or b.startswith('m091'): 366 | # data suspect sensor is True 367 | flags[:, :, i] |= CAM 368 | if a == 'm090v' or b == 'm090v': 369 | # input_data_suspect is True 370 | flags[:, :, i] |= CAM 371 | flags[:, :, i] |= channel_mask * np.uint8(STATIC) 372 | if a[:-1] != b[:-1]: 373 | # RFI model, which doesn't apply to auto-correlations 374 | flags[:, 1024, i] |= np.uint8(STATIC) 375 | if a.startswith('m093') == b.startswith('m093'): 376 | # Short baseline 377 | flags[:, 852:857, i] |= np.uint8(STATIC) 378 | return vis, flags, timestamps 379 | 380 | def _channel_average(self, vis, factor): 381 | return np.add.reduceat(vis, np.arange(0, vis.shape[1], factor), axis=1) / factor 382 | 383 | def _channel_average_flags(self, flags, factor): 384 | return np.bitwise_or.reduceat(flags, np.arange(0, flags.shape[1], factor), axis=1) 385 | 386 | def _check_output(self, tx, expected_vis, expected_flags, expected_ts, send_slice): 387 | """Checks that the visibilities and timestamps are correct.""" 388 | 
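One step in `_get_expected` above deserves unpacking: `bls.permutation` maps an input baseline index to an output index (or -1 to drop it), and the test inverts it so that the reference data can be reordered the same way ingest does. In miniature, with an invented mapping (the body of `_check_output` resumes below):

```python
# Inverse permutation in miniature (invented mapping).
import numpy as np

permutation = np.array([2, -1, 0, 1])   # input baseline -> output (or drop)
inv_permutation = np.empty(3, int)
for i, p in enumerate(permutation):
    if p != -1:
        inv_permutation[p] = i
data = np.array([[10, 11, 12, 13]])     # one dump, four input baselines
np.testing.assert_array_equal(data[:, inv_permutation], [[12, 13, 10]])
```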
tx.start.assert_called_once_with() 389 | tx.stop.assert_called_once_with() 390 | calls = tx.send.mock_calls 391 | assert_equal(len(expected_vis), len(calls)) 392 | for i, (vis, flags, ts, call) in enumerate( 393 | zip(expected_vis, expected_flags, expected_ts, calls)): 394 | data, idx, ts_rel = call[1] 395 | assert_is_instance(data, Data) 396 | np.testing.assert_allclose(vis, data.vis[send_slice], rtol=1e-5, atol=1e-6) 397 | np.testing.assert_array_equal(flags, data.flags[send_slice]) 398 | assert_equal(i, idx) 399 | assert_almost_equal(ts, ts_rel) 400 | 401 | async def test_init_telstate(self): 402 | """Test the output metadata in telstate""" 403 | async def get_ts(key): 404 | return await self._telstate[prefix + '_' + key] 405 | 406 | bls_ordering = [] 407 | for a in self.user_args.antenna_mask: 408 | for b in self.user_args.antenna_mask: 409 | if a <= b: 410 | for ap in 'hv': 411 | for bp in 'hv': 412 | bls_ordering.append([a + ap, b + bp]) 413 | bls_ordering.sort() 414 | for prefix in ['sdp_l0', 'sdp_l0_continuum']: 415 | factor = 1 if prefix == 'sdp_l0' else self.user_args.continuum_factor 416 | assert_equal(1280 // factor, await get_ts('n_chans')) 417 | assert_equal((await get_ts('n_chans')) // 4, await get_ts('n_chans_per_substream')) 418 | assert_equal(len(bls_ordering), await get_ts('n_bls')) 419 | assert_equal(bls_ordering, sorted((await get_ts('bls_ordering')).tolist())) 420 | assert_equal(self.cbf_attr['sync_time'], await get_ts('sync_time')) 421 | assert_equal(267500000.0, await get_ts('bandwidth')) 422 | assert_equal(8 * self.cbf_attr['int_time'], await get_ts('int_time')) 423 | assert_equal((464, 1744), await get_ts('channel_range')) 424 | assert_equal(1086718750.0, await self._telstate['sdp_l0_center_freq']) 425 | # Offset by 7.5 channels to identify the centre of a continuum channel 426 | assert_equal(1088286132.8125, await self._telstate['sdp_l0_continuum_center_freq']) 427 | 428 | async def test_capture(self): 429 | """Test the core data capture process.""" 430 | await self.make_request('capture-init', 'cb1') 431 | await self.make_request('capture-done') 432 | l0_flavour = spead2.Flavour(4, 64, 48) 433 | l0_int_time = 8 * self.cbf_attr['int_time'] 434 | expected_vis, expected_flags, expected_ts = await self._get_expected() 435 | expected_output_vis = expected_vis[:, self.channel_ranges.output.asslice(), :] 436 | expected_output_flags = expected_flags[:, self.channel_ranges.output.asslice(), :] 437 | 438 | # This server sends channels 784:1104 to L0 and 896:1152 to sdisp. 439 | # Aligning to the sd_continuum_factor (128) gives computed = 768:1152. 
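Those two comment lines compress a fair amount of arithmetic. The sketch below reproduces the numbers from the test's own parameters (it is an illustration of the arithmetic, not the `ChannelRanges` implementation, and the 768:1152 alignment figure is taken from the comment above); the assertions resume after it:

```python
# Where Range(784, 1104) and the send range (16, 336) come from.
servers, server_id = 4, 2              # this server has 0-based index 1
output = range(464, 1744)              # all-server L0 output channels
per_server = len(output) // servers    # 320 channels per server
start = output.start + per_server * (server_id - 1)
assert (start, start + per_server) == (784, 1104)

computed_start = 768                   # after sd_continuum_factor alignment
assert (784 - computed_start, 1104 - computed_start) == (16, 336)
```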
440 | assert_equal(Range(784, 1104), self.channel_ranges.output) 441 | assert_equal(Range(896, 1152), self.channel_ranges.sd_output) 442 | send_range = Range(16, 336) 443 | self._VisSenderSet.assert_any_call( 444 | mock.ANY, self.user_args.l0_spectral_spead[1:2], '127.0.0.2', 445 | l0_flavour, l0_int_time, send_range, 320, 1280, 24) 446 | self._check_output(self._tx['spectral'], expected_output_vis, expected_output_flags, 447 | expected_ts, send_range.asslice()) 448 | self._tx['spectral'].stop.assert_called_once_with() 449 | 450 | send_range = Range(1, 21) 451 | self._VisSenderSet.assert_any_call( 452 | mock.ANY, self.user_args.l0_continuum_spead[1:2], '127.0.0.3', 453 | l0_flavour, l0_int_time, send_range, 20, 80, 24) 454 | self._check_output( 455 | self._tx['continuum'], 456 | self._channel_average(expected_output_vis, self.user_args.continuum_factor), 457 | self._channel_average_flags(expected_output_flags, self.user_args.continuum_factor), 458 | expected_ts, send_range.asslice()) 459 | 460 | assert_equal([Endpoint('127.0.0.2', 7149)], list(self._sd_tx.keys())) 461 | sd_tx = self._sd_tx[Endpoint('127.0.0.2', 7149)] 462 | expected_sd_vis = self._channel_average( 463 | expected_vis[:, self.channel_ranges.sd_output.asslice(), :], 464 | self.user_args.sd_continuum_factor) 465 | expected_sd_flags = self._channel_average_flags( 466 | expected_flags[:, self.channel_ranges.sd_output.asslice(), :], 467 | self.user_args.sd_continuum_factor) 468 | heaps = get_heaps(sd_tx) 469 | # First heap should be start-of-stream marker 470 | assert_true(heaps[0].is_start_of_stream()) 471 | # Following heaps should contain averaged visibility data 472 | assert_equal(len(expected_sd_vis), len(heaps) - 2) 473 | for i, heap in enumerate(heaps[1:-1]): 474 | ig = decode_heap_ig(heap) 475 | vis = ig['sd_blmxdata'].value 476 | # Signal displays take complex values as pairs of floats; reconstitute them. 
477 | vis = vis[..., 0] + 1j * vis[..., 1] 478 | flags = ig['sd_blmxflags'].value 479 | np.testing.assert_allclose(expected_sd_vis[i], vis, rtol=1e-5, atol=1e-6) 480 | np.testing.assert_array_equal(expected_sd_flags[i], flags) 481 | # Final call must send a stop 482 | assert_true(heaps[-1].is_end_of_stream()) 483 | 484 | async def test_done_when_not_capturing(self): 485 | """Calling capture-stop when not capturing fails""" 486 | await self.assert_request_fails(r'No existing capture session', 'capture-done') 487 | 488 | async def test_init_when_capturing(self): 489 | """Calling capture-init when capturing fails""" 490 | await self.make_request('capture-init', 'cb1') 491 | await self.assert_request_fails(r'Existing capture session found', 'capture-init', 'cb2') 492 | await self.make_request('capture-done') 493 | 494 | async def test_enable_disable_debug(self): 495 | """?enable-debug and ?disable-debug change the log level of session logger""" 496 | assert_equal(logging.NOTSET, logging.getLogger('katsdpingest.ingest_session').level) 497 | await self.make_request('enable-debug') 498 | assert_equal(logging.DEBUG, logging.getLogger('katsdpingest.ingest_session').level) 499 | await self.make_request('disable-debug') 500 | assert_equal(logging.NOTSET, logging.getLogger('katsdpingest.ingest_session').level) 501 | 502 | async def test_add_sdisp_ip(self): 503 | """Add additional addresses with add-sdisp-ip.""" 504 | await self.make_request('add-sdisp-ip', '127.0.0.3:8000') 505 | await self.make_request('add-sdisp-ip', '127.0.0.4') 506 | # A duplicate 507 | await self.make_request('add-sdisp-ip', '127.0.0.3:8001') 508 | await self.make_request('capture-init', 'cb1') 509 | await self.make_request('capture-done') 510 | assert_equal({Endpoint('127.0.0.2', 7149), 511 | Endpoint('127.0.0.3', 8000), 512 | Endpoint('127.0.0.4', 7149)}, 513 | self._sd_tx.keys()) 514 | # We won't check the contents, since that is tested elsewhere. Just 515 | # check that all the streams got the expected number of heaps. 516 | for tx in self._sd_tx.values(): 517 | assert_equal(5, len(get_heaps(tx))) 518 | 519 | async def test_drop_sdisp_ip_not_capturing(self): 520 | """Dropping a sdisp IP when not capturing sends no data at all.""" 521 | await self.make_request('drop-sdisp-ip', '127.0.0.2') 522 | await self.make_request('capture-init', 'cb1') 523 | await self.make_request('capture-done') 524 | sd_tx = self._sd_tx[Endpoint('127.0.0.2', 7149)] 525 | assert_equal([], get_heaps(sd_tx)) 526 | 527 | async def test_drop_sdisp_ip_capturing(self): 528 | """Dropping a sdisp IP when capturing sends a stop heap.""" 529 | self._pauses = {10: asyncio.Future()} 530 | await self.make_request('capture-init', 'cb1') 531 | sd_tx = self._sd_tx[Endpoint('127.0.0.2', 7149)] 532 | # Ensure the pause point gets reached, and wait for 533 | # the signal display data to be sent. 
534 | sd_rx = spead2.recv.asyncio.Stream( 535 | spead2.ThreadPool(), 536 | spead2.recv.StreamConfig(stop_on_stop_item=False) 537 | ) 538 | sd_rx.add_inproc_reader(sd_tx.queues[0]) 539 | heaps = [] 540 | with async_timeout.timeout(10): 541 | for i in range(2): 542 | heaps.append(await sd_rx.get()) 543 | await self.make_request('drop-sdisp-ip', '127.0.0.2') 544 | self._pauses[10].set_result(None) 545 | await self.make_request('capture-done') 546 | sd_tx.queues[0].stop() 547 | while True: 548 | try: 549 | heaps.append(await sd_rx.get()) 550 | except spead2.Stopped: 551 | break 552 | assert_equal(3, len(heaps)) # start, one data, and stop heaps 553 | assert_true(heaps[0].is_start_of_stream()) 554 | ig = decode_heap_ig(heaps[1]) 555 | assert_in('sd_blmxdata', ig) 556 | assert_true(heaps[2].is_end_of_stream()) 557 | 558 | async def test_drop_sdisp_ip_missing(self): 559 | """Dropping an unregistered IP address fails""" 560 | await self.assert_request_fails('does not exist', 'drop-sdisp-ip', '127.0.0.3') 561 | 562 | async def test_internal_log_level_query_all(self): 563 | """Test internal-log-level query with no parameters""" 564 | informs = await self.make_request('internal-log-level') 565 | levels = {} 566 | for inform in informs: 567 | levels[inform.arguments[0]] = inform.arguments[1] 568 | # Check that some known logger appears in the list 569 | assert_in(b'aiokatcp.connection', levels) 570 | assert_equal(b'NOTSET', levels[b'aiokatcp.connection']) 571 | 572 | async def test_internal_log_level_query_one(self): 573 | """Test internal-log-level query with one parameter""" 574 | informs = await self.make_request('internal-log-level', 'aiokatcp.connection') 575 | assert_equal(1, len(informs)) 576 | assert_equal(aiokatcp.Message.inform('internal-log-level', b'aiokatcp.connection', 577 | b'NOTSET', mid=informs[0].mid), 578 | informs[0]) 579 | 580 | async def test_internal_log_level_query_one_missing(self): 581 | """Querying internal-log-level with a non-existent logger fails""" 582 | await self.assert_request_fails('Unknown logger', 'internal-log-level', 'notalogger') 583 | 584 | async def test_internal_log_level_set(self): 585 | """Set a logger level via internal-log-level""" 586 | await self.make_request('internal-log-level', 'katcp.server', 'INFO') 587 | assert_equal(logging.INFO, logging.getLogger('katcp.server').level) 588 | await self.make_request('internal-log-level', 'katcp.server', 'NOTSET') 589 | assert_equal(logging.NOTSET, logging.getLogger('katcp.server').level) 590 | await self.assert_request_fails( 591 | 'Unknown log level', 'internal-log-level', 'katcp.server', 'DUMMY') 592 | -------------------------------------------------------------------------------- /katsdpingest/test/test_sigproc.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """Tests for the sigproc module.""" 3 | 4 | from unittest import mock 5 | 6 | import numpy as np 7 | from katsdpsigproc import tune 8 | import katsdpsigproc.rfi.device as rfi 9 | import katsdpsigproc.rfi.host as rfi_host 10 | from katsdpsigproc.test.test_accel import device_test, force_autotune 11 | from katdal.flags import INGEST_RFI, CAL_RFI, CAM 12 | from nose.tools import assert_equal, assert_raises 13 | 14 | from katsdpingest import sigproc 15 | from katsdpingest.utils import Range 16 | 17 | 18 | UNFLAGGED_BIT = 128 19 | FLAG_SCALE = np.float32(2) ** -64 20 | FLAG_SCALE_INV = np.float32(2) ** 64 21 | 22 | 23 | def random_vis(rs, shape): 24 | """Generate random visibilities with mean 0 and 
standard deviation 1.""" 25 | return (rs.standard_normal(shape) + rs.standard_normal(shape) * 1j).astype(np.complex64) 26 | 27 | 28 | def random_flags(rs, shape, bits, p): 29 | """Generate random array of flag bits. 30 | 31 | Parameters 32 | ---------- 33 | rs : :class:`numpy.random.RandomState` 34 | Random generator 35 | shape : tuple 36 | Shape of the output array 37 | bits : int 38 | Number of bits in each flag word that are candidates 39 | p : float 40 | Probability of each individual bit being set 41 | """ 42 | flags = np.zeros(shape, np.uint8) 43 | for i in range(bits): 44 | flags |= rs.choice([1 << i, 0], shape, p=[p, 1 - p]).astype(np.uint8) 45 | return flags 46 | 47 | 48 | class TestPrepare: 49 | """Test :class:`katsdpingest.sigproc.Prepare`""" 50 | 51 | @device_test 52 | def test_prepare(self, context, queue): 53 | """Basic test of data preparation""" 54 | channels = 73 55 | in_baselines = 99 56 | out_baselines = 91 57 | n_accs = 11 58 | 59 | rs = np.random.RandomState(seed=1) 60 | vis_in = rs.random_integers(-1000, 1000, (channels, in_baselines, 2)).astype(np.int32) 61 | permutation = rs.permutation(in_baselines).astype(np.int16) 62 | permutation[permutation >= out_baselines] = -1 63 | 64 | template = sigproc.PrepareTemplate(context) 65 | prepare = template.instantiate(queue, channels, in_baselines, out_baselines) 66 | prepare.ensure_all_bound() 67 | prepare.buffer('vis_in').set(queue, vis_in) 68 | prepare.buffer('permutation').set(queue, permutation) 69 | prepare.n_accs = n_accs 70 | prepare() 71 | vis_out = prepare.buffer('vis_out').get(queue) 72 | 73 | assert_equal((out_baselines, channels), vis_out.shape) 74 | expected_vis = np.zeros_like(vis_out) 75 | scale = np.float32(1 / n_accs) 76 | for i in range(channels): 77 | for j in range(in_baselines): 78 | value = (vis_in[i, j, 0] + 1j * vis_in[i, j, 1]) * scale 79 | row = permutation[j] 80 | if row >= 0: 81 | expected_vis[row, i] = value 82 | np.testing.assert_equal(expected_vis, vis_out) 83 | 84 | @device_test 85 | @force_autotune 86 | def test_autotune(self, context, queue): 87 | sigproc.PrepareTemplate(context) 88 | 89 | 90 | class TestPrepareFlags: 91 | """Test :class:`katsdpingest.sigproc.PrepareFlags`""" 92 | 93 | @device_test 94 | def test_random(self, context, queue): 95 | """Basic test using random data""" 96 | channels = 643 97 | baselines = 497 98 | masks = 17 99 | rs = np.random.RandomState(seed=1) 100 | vis = random_vis(rs, (channels, baselines)) 101 | # Create some zero visibilities to ensure they're flagged 102 | vis[rs.rand(channels, baselines) < 0.3] = 0 103 | channel_mask = random_flags(rs, (masks, channels), 7, 0.1) 104 | channel_mask_idx = rs.randint(0, masks, baselines).astype(np.uint32) 105 | 106 | template = sigproc.PrepareFlagsTemplate(context) 107 | fn = template.instantiate(queue, channels, baselines, masks, 2**7) 108 | fn.ensure_all_bound() 109 | fn.buffer('vis').set(queue, vis) 110 | fn.buffer('channel_mask').set(queue, channel_mask) 111 | fn.buffer('channel_mask_idx').set(queue, channel_mask_idx) 112 | fn() 113 | flags = fn.buffer('flags').get(queue) 114 | 115 | expected = channel_mask[channel_mask_idx, :].T 116 | expected = expected | np.where(vis == 0, 2**7, 0) 117 | np.testing.assert_equal(expected, flags) 118 | 119 | 120 | class TestMergeFlags: 121 | """Test :class:`katsdpingest.sigproc.MergeFlags`""" 122 | 123 | @device_test 124 | def test_random(self, context, queue): 125 | """Basic test using random data""" 126 | channels = 643 127 | baselines = 497 128 | rs = np.random.RandomState(seed=1) 
129 | flags_in = random_flags(rs, (channels, baselines), 8, 0.1) 130 | flags_out = random_flags(rs, (baselines, channels), 8, 0.1) 131 | baseline_flags = random_flags(rs, (baselines,), 8, 0.2) 132 | 133 | template = sigproc.MergeFlagsTemplate(context) 134 | fn = template.instantiate(queue, channels, baselines) 135 | fn.ensure_all_bound() 136 | fn.buffer('flags_in').set(queue, flags_in) 137 | fn.buffer('flags_out').set(queue, flags_out) 138 | fn.buffer('baseline_flags').set(queue, baseline_flags) 139 | fn() 140 | output = fn.buffer('flags_out').get(queue) 141 | 142 | expected = flags_out | flags_in.T | baseline_flags[:, np.newaxis] 143 | np.testing.assert_equal(expected, output) 144 | 145 | 146 | class TestCountFlags: 147 | """Test :class:`katsdpingest.sigproc.CountFlags`""" 148 | 149 | @device_test 150 | def test_random(self, context, queue): 151 | """Basic test using random data""" 152 | channels = 1243 153 | channel_range = Range(64, 1235) 154 | baselines = 97 155 | mask = 255 - (1 << 6) 156 | 157 | rs = np.random.RandomState(seed=1) 158 | flags = rs.randint(0, 256, size=(baselines, channels)).astype(np.uint8) 159 | orig_counts = rs.randint(0, 10000, size=(baselines, 8)).astype(np.uint32) 160 | orig_any_counts = rs.randint(0, 10000, size=baselines).astype(np.uint32) 161 | 162 | template = sigproc.CountFlagsTemplate(context) 163 | fn = template.instantiate(queue, channels, channel_range, baselines, mask) 164 | fn.ensure_all_bound() 165 | fn.buffer('flags').set(queue, flags) 166 | fn.buffer('counts').set(queue, orig_counts) 167 | fn.buffer('any_counts').set(queue, orig_any_counts) 168 | fn() 169 | counts = fn.buffer('counts').get(queue) 170 | any_counts = fn.buffer('any_counts').get(queue) 171 | 172 | assert_equal((baselines, 8), counts.shape) 173 | expected = orig_counts[:] 174 | combined_flags = flags & mask 175 | included_flags = combined_flags[:, channel_range.asslice()] 176 | for i in range(8): 177 | expected[:, i] += np.count_nonzero(included_flags & (1 << i), axis=1).astype(np.uint32) 178 | np.testing.assert_equal(expected, counts) 179 | 180 | assert_equal((baselines,), any_counts.shape) 181 | expected = orig_any_counts[:] 182 | expected += np.count_nonzero(included_flags, axis=1).astype(np.uint32) 183 | np.testing.assert_equal(expected, any_counts) 184 | 185 | @device_test 186 | @force_autotune 187 | def test_autotune(self, context, queue): 188 | sigproc.CountFlagsTemplate(context) 189 | 190 | 191 | class TestAccum: 192 | """Test :class:`katsdpingest.sigproc.Accum`""" 193 | 194 | def _test_small(self, context, queue, excise, expected): 195 | """Run a small hand-coded test case.""" 196 | unflagged = UNFLAGGED_BIT if excise else 0 197 | # Host copies of arrays 198 | host = { 199 | 'vis_in': np.array([[1+2j, 2+5j, 3-3j, 2+1j, 4]], dtype=np.complex64), 200 | 'weights_in': np.array([[2.0, 4.0, 3.0]], dtype=np.float32), 201 | 'flags_in': np.array([[5, 0, 10, 0, 4]], dtype=np.uint8), 202 | 'vis_out0': np.array([[7-3j, 0+0j, 0+5j]], dtype=np.complex64).T, 203 | 'weights_out0': np.array([[1.5, 0.0, 4.5]], dtype=np.float32).T, 204 | 'flags_out0': np.array([[1 | unflagged, 9, unflagged]], dtype=np.uint8).T 205 | } 206 | 207 | template = sigproc.AccumTemplate(context, 1, UNFLAGGED_BIT, excise) 208 | fn = template.instantiate(queue, 5, Range(1, 4), 1) 209 | fn.ensure_all_bound() 210 | for name, value in host.items(): 211 | fn.buffer(name).set(queue, value) 212 | fn() 213 | for name, value in expected.items(): 214 | actual = fn.buffer(name).get(queue) 215 | np.testing.assert_equal(value, actual, 
err_msg=name + " does not match") 216 | 217 | @device_test 218 | def test_small_excise(self, context, queue): 219 | """Hand-coded test data, to test various cases, with excision""" 220 | expected = { 221 | 'vis_out0': np.array([[11+7j, (12-12j) * FLAG_SCALE, 6+8j]], dtype=np.complex64).T, 222 | 'weights_out0': np.array([[3.5, 4.0 * FLAG_SCALE, 7.5]], dtype=np.float32).T, 223 | 'flags_out0': np.array([[1 | UNFLAGGED_BIT, 11, UNFLAGGED_BIT]], dtype=np.uint8).T 224 | } 225 | self._test_small(context, queue, True, expected) 226 | 227 | @device_test 228 | def test_small_no_excise(self, context, queue): 229 | expected = { 230 | 'vis_out0': np.array([[11+7j, 12-12j, 6+8j]], dtype=np.complex64).T, 231 | 'weights_out0': np.array([[3.5, 4.0, 7.5]], dtype=np.float32).T, 232 | 'flags_out0': np.array([[1, 11, 0]], dtype=np.uint8).T 233 | } 234 | self._test_small(context, queue, False, expected) 235 | 236 | def _test_big(self, context, queue, excise): 237 | channels = 203 238 | baselines = 171 239 | channel_range = Range(7, 198) 240 | kept_channels = len(channel_range) 241 | outputs = 2 242 | rs = np.random.RandomState(1) 243 | 244 | vis_in = random_vis(rs, (baselines, channels)) 245 | weights_in = rs.uniform(size=(baselines, kept_channels)).astype(np.float32) 246 | flags_in = random_flags(rs, (baselines, channels), 7, p=0.2) 247 | vis_out = [] 248 | weights_out = [] 249 | flags_out = [] 250 | for i in range(outputs): 251 | vis_out.append(random_vis(rs, (kept_channels, baselines))) 252 | weights_out.append(rs.uniform(size=(kept_channels, baselines)).astype(np.float32)) 253 | flags_out.append(random_flags(rs, (kept_channels, baselines), 8, p=0.02)) 254 | # Where the unflagged bit is not set, we expect the current 255 | # accumulation to be downweighted by FLAG_SCALE. 
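That comment is the heart of the excision scheme: flagged samples are accumulated with weights scaled by `FLAG_SCALE` (2**-64), so they are negligible unless every contributing sample was flagged, and the scaling is undone via `FLAG_SCALE_INV` at finalisation. A toy weighted average shows the effect (the `if excise:` block resumes below):

```python
# Toy demonstration of FLAG_SCALE-based excision in a weighted average.
import numpy as np

FLAG_SCALE = np.float32(2) ** -64
w_good = np.float32(3.0)
w_bad = np.float32(4.0) * FLAG_SCALE          # flagged sample, downweighted
avg = (w_good * 5.0 + w_bad * 99.0) / (w_good + w_bad)
assert abs(avg - 5.0) < 1e-4                  # the flagged value is drowned out
```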
256 | if excise: 257 | scale = np.where(flags_out[-1] & UNFLAGGED_BIT, 1, FLAG_SCALE) 258 | vis_out[-1] *= scale 259 | weights_out[-1] *= scale 260 | 261 | template = sigproc.AccumTemplate(context, outputs, UNFLAGGED_BIT, excise) 262 | fn = template.instantiate(queue, channels, channel_range, baselines) 263 | fn.ensure_all_bound() 264 | for (name, value) in [('vis_in', vis_in), ('weights_in', weights_in), 265 | ('flags_in', flags_in)]: 266 | fn.buffer(name).set(queue, value) 267 | for (name, value) in [('vis_out', vis_out), ('weights_out', weights_out), 268 | ('flags_out', flags_out)]: 269 | for i in range(outputs): 270 | fn.buffer(name + str(i)).set(queue, value[i]) 271 | fn() 272 | 273 | # Perform the operation on the host 274 | kept_vis = vis_in[:, channel_range.start : channel_range.stop] 275 | kept_flags = flags_in[:, channel_range.start : channel_range.stop] 276 | if excise: 277 | flagged_weights = weights_in * ((kept_flags == 0) + FLAG_SCALE) 278 | # unflagged inputs need the UNFLAGGED_BIT set 279 | kept_flags = kept_flags | np.where(kept_flags, 0, UNFLAGGED_BIT).astype(np.uint8) 280 | else: 281 | flagged_weights = weights_in 282 | for i in range(outputs): 283 | # The GPU uses an FMA - simulate it by computing in double precision first 284 | vis_out[i][:] = vis_out[i] + (kept_vis.astype(np.complex128) * flagged_weights).T 285 | weights_out[i] += flagged_weights.T 286 | flags_out[i] |= kept_flags.T 287 | 288 | # Verify results 289 | for (name, value) in [('vis_out', vis_out), ('weights_out', weights_out), 290 | ('flags_out', flags_out)]: 291 | for i in range(outputs): 292 | actual = fn.buffer(name + str(i)).get(queue) 293 | np.testing.assert_allclose(value[i], actual, 1e-5) 294 | 295 | @device_test 296 | def test_big_excise(self, context, queue): 297 | """Test with large random data against a simple CPU version (with excision)""" 298 | self._test_big(context, queue, True) 299 | 300 | @device_test 301 | def test_big_no_excise(self, context, queue): 302 | """Test with large random data against a simple CPU version (no excision)""" 303 | self._test_big(context, queue, False) 304 | 305 | @device_test 306 | @force_autotune 307 | def test_autotune(self, context, queue): 308 | sigproc.AccumTemplate(context, 2, 1, False) 309 | sigproc.AccumTemplate(context, 2, 1, True) 310 | 311 | 312 | class TestPostproc: 313 | """Tests for :class:`katsdpingest.sigproc.Postproc`""" 314 | 315 | def test_bad_cont_factor(self): 316 | """Test with a continuum factor that does not divide into the channel count""" 317 | template = mock.sentinel.template 318 | template.continuum = True 319 | mock.sentinel.command_queue.context = mock.sentinel.context 320 | assert_raises(ValueError, sigproc.Postproc, template, mock.sentinel.command_queue, 12, 8, 8) 321 | 322 | def _test_postproc(self, context, queue, excise, continuum): 323 | channels = 1024 324 | baselines = 512 325 | cont_factor = 16 326 | rs = np.random.RandomState(1) 327 | vis_in = random_vis(rs, (channels, baselines)) 328 | weights_in = rs.uniform(0.5, 2.0, (channels, baselines)).astype(np.float32) 329 | flags_in = random_flags(rs, (channels, baselines), 8, 0.2) 330 | # Ensure that we test the case of none flagged and all flagged when 331 | # doing continuum reduction 332 | flags_in[:, 123] = 1 333 | flags_in[:, 234] = UNFLAGGED_BIT 334 | # Test the flush-to-zero behaviour 335 | flags_in[:cont_factor, 1] &= ~np.uint8(UNFLAGGED_BIT) 336 | flags_in[1, 1] = UNFLAGGED_BIT 337 | weights_in[1, 1] = 2.0 338 | vis_in[1, 1] = 1e-10 + 1e-10j 339 | if excise: 340 | # 
Where UNFLAGGED_BIT is not set, weights should be much smaller 341 | scale = np.where(flags_in & UNFLAGGED_BIT, 1, FLAG_SCALE) 342 | vis_in *= scale 343 | weights_in *= scale 344 | 345 | template = sigproc.PostprocTemplate(context, UNFLAGGED_BIT, excise, continuum) 346 | fn = sigproc.Postproc(template, queue, channels, baselines, cont_factor) 347 | fn.ensure_all_bound() 348 | fn.buffer('vis').set(queue, vis_in) 349 | fn.buffer('weights').set(queue, weights_in) 350 | fn.buffer('flags').set(queue, flags_in) 351 | fn() 352 | 353 | # Compute expected spectral values 354 | expected_vis = vis_in / weights_in 355 | 356 | # Compute expected continuum values. This is done even if continuum is 357 | # disabled, just to keep the code simple. 358 | indices = list(range(0, channels, cont_factor)) 359 | cont_weights = np.add.reduceat(weights_in, indices, axis=0) 360 | cont_vis = np.add.reduceat(vis_in, indices, axis=0) / cont_weights 361 | cont_flags = np.bitwise_or.reduceat(flags_in, indices, axis=0) 362 | 363 | if excise: 364 | # Flagged visibilities have their weights re-scaled 365 | expected_weights = weights_in * np.where(flags_in & UNFLAGGED_BIT, 1, FLAG_SCALE_INV) 366 | cont_weights *= np.where(cont_flags & UNFLAGGED_BIT, 1, FLAG_SCALE_INV) 367 | # UNFLAGGED_BIT gets cleared 368 | cont_flags = np.where(cont_flags & UNFLAGGED_BIT, 0, cont_flags) 369 | expected_flags = np.where(flags_in & UNFLAGGED_BIT, 0, flags_in) 370 | # Gets flushed to zero 371 | expected_vis[1, 1] = 0 372 | cont_vis[0, 1] = 0 373 | else: 374 | expected_weights = weights_in 375 | expected_flags = flags_in 376 | 377 | # Verify results 378 | np.testing.assert_allclose(expected_vis, fn.buffer('vis').get(queue), rtol=1e-5) 379 | np.testing.assert_allclose(expected_weights, fn.buffer('weights').get(queue), rtol=1e-5) 380 | np.testing.assert_equal(expected_flags, fn.buffer('flags').get(queue)) 381 | if continuum: 382 | np.testing.assert_allclose(cont_vis, fn.buffer('cont_vis').get(queue), rtol=1e-5) 383 | np.testing.assert_allclose(cont_weights, 384 | fn.buffer('cont_weights').get(queue), 385 | rtol=1e-5) 386 | np.testing.assert_equal(cont_flags, fn.buffer('cont_flags').get(queue)) 387 | 388 | @device_test 389 | def test_postproc(self, context, queue): 390 | """Test with random data against a CPU implementation (with excision)""" 391 | self._test_postproc(context, queue, True, True) 392 | 393 | @device_test 394 | def test_postproc_no_excise(self, context, queue): 395 | """Test with random data against a CPU implementation (no excision)""" 396 | self._test_postproc(context, queue, False, True) 397 | 398 | @device_test 399 | def test_postproc_no_continuum(self, context, queue): 400 | """Test with random data against a CPU implementation (no continuum)""" 401 | self._test_postproc(context, queue, True, False) 402 | 403 | @device_test 404 | @force_autotune 405 | def test_autotune(self, context, queue): 406 | sigproc.PostprocTemplate(context, 128, False, True) 407 | sigproc.PostprocTemplate(context, 128, True, False) 408 | sigproc.PostprocTemplate(context, 128, True, True) 409 | 410 | 411 | class TestCompressWeights: 412 | """Tests for :class:`katsdpingest.sigproc.CompressWeights`""" 413 | @device_test 414 | def test_simple(self, context, queue): 415 | """Test with random data against a CPU implementation""" 416 | channels = 123 417 | baselines = 235 418 | rs = np.random.RandomState(1) 419 | weights_in = rs.uniform(0.01, 1000.0, (channels, baselines)).astype(np.float32) 420 | 421 | template = sigproc.CompressWeightsTemplate(context) 422 | 
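`CompressWeights` stores each weight as a uint8 multiplier of a per-channel float32 scale. A miniature version of the reference computation used in this test, with invented inputs (the template instantiation continues below):

```python
# Weight compression in miniature: scale = channel max / 255, uint8 payload.
import numpy as np

weights_in = np.array([[0.8, 2.0, 255.0]], np.float32)   # one channel, 3 baselines
weights_channel = weights_in.max(axis=1) * np.float32(1.0 / 255.0)   # scale: 1.0
weights_out = np.round(weights_in / weights_channel[:, np.newaxis]).astype(np.uint8)
np.testing.assert_array_equal(weights_out, [[1, 2, 255]])
```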
fn = template.instantiate(queue, channels, baselines) 423 | fn.ensure_all_bound() 424 | fn.buffer('weights_in').set(queue, weights_in) 425 | fn.buffer('weights_out').zero(queue) 426 | fn.buffer('weights_channel').zero(queue) 427 | fn() 428 | 429 | expected_channel = np.max(weights_in, axis=1) * np.float32(1.0 / 255.0) 430 | scale = np.reciprocal(expected_channel)[..., np.newaxis] 431 | expected_out = np.round(weights_in * scale).astype(np.uint8) 432 | np.testing.assert_allclose(expected_channel, 433 | fn.buffer('weights_channel').get(queue), 434 | rtol=1e-5) 435 | np.testing.assert_equal(expected_out, fn.buffer('weights_out').get(queue)) 436 | 437 | @device_test 438 | @force_autotune 439 | def test_autotune(self, context, queue): 440 | sigproc.CompressWeightsTemplate(context) 441 | 442 | 443 | class TestIngestOperation: 444 | flag_value = INGEST_RFI 445 | unflagged_bit = CAL_RFI 446 | 447 | @mock.patch('katsdpsigproc.tune.autotuner_impl', new=tune.stub_autotuner) 448 | @mock.patch('katsdpsigproc.accel.build', spec=True) 449 | def test_descriptions(self, *args): 450 | channels = 128 451 | channel_range = Range(16, 96) 452 | count_flags_channel_range = Range(8, 104) 453 | cbf_baselines = 220 454 | baselines = 192 455 | masks = 3 456 | 457 | context = mock.Mock() 458 | command_queue = mock.Mock() 459 | background_template = rfi.BackgroundMedianFilterDeviceTemplate( 460 | context, width=13) 461 | noise_est_template = rfi.NoiseEstMADTDeviceTemplate( 462 | context, 10240) 463 | threshold_template = rfi.ThresholdSimpleDeviceTemplate( 464 | context, transposed=True, flag_value=self.flag_value) 465 | flagger_template = rfi.FlaggerDeviceTemplate( 466 | background_template, noise_est_template, threshold_template) 467 | template = sigproc.IngestTemplate(context, flagger_template, [8, 12], True, True) 468 | fn = template.instantiate( 469 | command_queue, channels, channel_range, count_flags_channel_range, 470 | cbf_baselines, baselines, masks, 471 | 8, 16, [(0, 8), (10, 22)], 472 | threshold_args={'n_sigma': 11.0}) 473 | 474 | expected = [ 475 | ('ingest', {'class': 'katsdpingest.sigproc.IngestOperation'}), 476 | ('ingest:prepare', { 477 | 'channels': 128, 'class': 'katsdpingest.sigproc.Prepare', 478 | 'in_baselines': 220, 'out_baselines': 192, 'n_accs': 1 479 | }), 480 | ('ingest:prepare_flags', { 481 | 'baselines': 192, 'channels': 128, 482 | 'class': 'katsdpingest.sigproc.PrepareFlags', 483 | 'masks': 3, 'zero_flag': CAM 484 | }), 485 | ('ingest:init_weights', { 486 | 'class': 'katsdpsigproc.fill.Fill', 'shape': (192, 80), 487 | 'dtype': np.float32, 'ctype': 'float', 'value': 0.0 488 | }), 489 | ('ingest:zero_spec', { 490 | 'channels': 80, 'baselines': 192, 'class': 'katsdpingest.sigproc.Zero' 491 | }), 492 | ('ingest:zero_sd_spec', { 493 | 'channels': 80, 'baselines': 192, 'class': 'katsdpingest.sigproc.Zero' 494 | }), 495 | ('ingest:transpose_vis', { 496 | 'class': 'katsdpsigproc.transpose.Transpose', 497 | 'ctype': 'float2', 'dtype': 'complex64', 'shape': (192, 128) 498 | }), 499 | ('ingest:flagger', {'class': 'katsdpsigproc.rfi.device.FlaggerDevice'}), 500 | ('ingest:flagger:background', { 501 | 'baselines': 192, 'channels': 128, 502 | 'class': 'katsdpsigproc.rfi.device.BackgroundMedianFilterDevice', 503 | 'use_flags': 'NONE', 'width': 13 504 | }), 505 | ('ingest:flagger:transpose_deviations', { 506 | 'class': 'katsdpsigproc.transpose.Transpose', 507 | 'ctype': 'float', 'dtype': 'float32', 'shape': (128, 192) 508 | }), 509 | ('ingest:flagger:noise_est', { 510 | 'baselines': 192, 'channels': 128, 
511 |                 'class': 'katsdpsigproc.rfi.device.NoiseEstMADTDevice',
512 |                 'max_channels': 10240
513 |             }),
514 |             ('ingest:flagger:threshold', {
515 |                 'baselines': 192, 'channels': 128,
516 |                 'class': 'katsdpsigproc.rfi.device.ThresholdSimpleDevice',
517 |                 'flag_value': 16, 'n_sigma': 11.0, 'transposed': True
518 |             }),
519 |             ('ingest:flagger:transpose_flags', {
520 |                 'class': 'katsdpsigproc.transpose.Transpose',
521 |                 'ctype': 'unsigned char', 'dtype': 'uint8', 'shape': (192, 128)
522 |             }),
523 |             ('ingest:merge_flags', {
524 |                 'channels': 128, 'baselines': 192,
525 |                 'class': 'katsdpingest.sigproc.MergeFlags'
526 |             }),
527 |             ('ingest:count_flags', {
528 |                 'channels': 128, 'baselines': 192, 'channel_range': (8, 104),
529 |                 'class': 'katsdpingest.sigproc.CountFlags', 'mask': 191
530 |             }),
531 |             ('ingest:accum', {
532 |                 'baselines': 192, 'channel_range': (16, 96), 'channels': 128,
533 |                 'class': 'katsdpingest.sigproc.Accum', 'excise': True,
534 |                 'outputs': 2, 'unflagged_bit': 64
535 |             }),
536 |             ('ingest:finalise', {'class': 'katsdpingest.sigproc.Finalise'}),
537 |             ('ingest:finalise:postproc', {
538 |                 'baselines': 192, 'channels': 80, 'class': 'katsdpingest.sigproc.Postproc',
539 |                 'cont_factor': 8, 'continuum': True, 'excise': True, 'unflagged_bit': 64
540 |             }),
541 |             ('ingest:finalise:compress_weights_spec', {
542 |                 'baselines': 192, 'channels': 80, 'class': 'katsdpingest.sigproc.CompressWeights'
543 |             }),
544 |             ('ingest:finalise:compress_weights_cont', {
545 |                 'baselines': 192, 'channels': 10, 'class': 'katsdpingest.sigproc.CompressWeights'
546 |             }),
547 |             ('ingest:sd_finalise', {'class': 'katsdpingest.sigproc.Finalise'}),
548 |             ('ingest:sd_finalise:postproc', {
549 |                 'baselines': 192, 'channels': 80,
550 |                 'class': 'katsdpingest.sigproc.Postproc', 'cont_factor': 16,
551 |                 'continuum': True, 'excise': True, 'unflagged_bit': 64
552 |             }),
553 |             ('ingest:sd_finalise:compress_weights_spec', {
554 |                 'baselines': 192, 'channels': 80, 'class': 'katsdpingest.sigproc.CompressWeights'
555 |             }),
556 |             ('ingest:sd_finalise:compress_weights_cont', {
557 |                 'baselines': 192, 'channels': 5, 'class': 'katsdpingest.sigproc.CompressWeights'
558 |             }),
559 |             ('ingest:timeseries', {
560 |                 'class': 'katsdpsigproc.maskedsum.MaskedSum', 'shape': (80, 192),
561 |                 'use_amplitudes': False
562 |             }),
563 |             ('ingest:timeseriesabs', {
564 |                 'class': 'katsdpsigproc.maskedsum.MaskedSum', 'shape': (80, 192),
565 |                 'use_amplitudes': True
566 |             }),
567 |             ('ingest:percentile0', {
568 |                 'class': 'katsdpsigproc.percentile.Percentile5', 'column_range': (0, 8),
569 |                 'is_amplitude': False, 'max_columns': 8, 'shape': (80, 192)
570 |             }),
571 |             ('ingest:percentile0_flags', {
572 |                 'class': 'katsdpsigproc.reduce.HReduce', 'column_range': (0, 8),
573 |                 'ctype': 'unsigned char', 'dtype': np.uint8,
574 |                 'extra_code': '', 'identity': '0', 'op': 'a | b', 'shape': (80, 192)
575 |             }),
576 |             ('ingest:percentile1', {
577 |                 'class': 'katsdpsigproc.percentile.Percentile5', 'column_range': (10, 22),
578 |                 'is_amplitude': False, 'max_columns': 12, 'shape': (80, 192)
579 |             }),
580 |             ('ingest:percentile1_flags', {
581 |                 'class': 'katsdpsigproc.reduce.HReduce', 'column_range': (10, 22),
582 |                 'ctype': 'unsigned char', 'dtype': np.uint8,
583 |                 'extra_code': '', 'identity': '0', 'op': 'a | b', 'shape': (80, 192)
584 |             })
585 |         ]
586 |         self.maxDiff = None
587 |         assert_equal(expected, fn.descriptions())
588 |
589 |     def finalise_host(self, vis, flags, weights, excise):
590 |         """Performs the final steps of run_host_basic, for either the continuum or spectral
591 |         product. The inputs are modified in-place.
592 |         """
593 |         vis /= weights
594 |         if excise:
595 |             weights *= np.where(flags & self.unflagged_bit, 1, np.float32(2**64))
596 |             flags = np.where(flags & self.unflagged_bit, 0, flags)
597 |         weights_channel = np.max(weights, axis=1) * np.float32(1.0 / 255.0)
598 |         inv_weights_channel = np.float32(1.0) / weights_channel
599 |         weights = (weights * inv_weights_channel[..., np.newaxis]).astype(np.uint8)
600 |         return vis, flags, weights, weights_channel
601 |
602 |     def run_host_basic(self, vis, channel_mask, channel_mask_idx, baseline_flags,
603 |                        n_accs, permutation,
604 |                        cont_factor, channel_range, count_flags_channel_range, n_sigma, excise):
605 |         """Simple CPU implementation. All inputs and outputs are channel-major.
606 |         There is no support for separate cadences for main and signal display
607 |         products; instead, call the function twice with different time slices.
608 |         No signal display calculations are performed, with the exception of
609 |         flag counting.
610 |
611 |         Parameters
612 |         ----------
613 |         vis : array-like
614 |             Input dump visibilities (first axis being time)
615 |         channel_mask : array-like
616 |             Baseline-dependent per-channel flags (indexed by time, baseline group and
617 |             channel)
618 |         channel_mask_idx : array-like
619 |             Baseline group for each baseline to index channel_mask
620 |         baseline_flags : array-like
621 |             Input per-baseline flags (indexed by time and post-permutation baseline)
622 |         n_accs : int
623 |             Number of correlations accumulated in `vis`
624 |         permutation : sequence
625 |             Maps input baseline numbers to output numbers (with -1 indicating discard)
626 |         cont_factor : int
627 |             Number of spectral channels per continuum channel
628 |         channel_range : :class:`katsdpingest.utils.Range`
629 |             Range of channels to retain in the output
630 |         count_flags_channel_range : :class:`katsdpingest.utils.Range`
631 |             Range of channels for which to count flags. May be ``None`` if flag
632 |             counting is not required.
633 |         n_sigma : float
634 |             Significance level for flagger
635 |         excise : bool
636 |             Excise flagged data
637 |
638 |         Returns
639 |         -------
640 |         dictionary, with the following keys:
641 |
642 |         - spec_vis, spec_weights, spec_flags
643 |         - cont_vis, cont_weights, cont_flags (and flag_counts, flag_any_counts if requested)
644 |         """
645 |         background = rfi_host.BackgroundMedianFilterHost(width=13)
646 |         noise_est = rfi_host.NoiseEstMADHost()
647 |         threshold = rfi_host.ThresholdSimpleHost(n_sigma=n_sigma, flag_value=self.flag_value)
648 |         flagger = rfi_host.FlaggerHost(background, noise_est, threshold)
649 |
650 |         vis = np.asarray(vis).astype(np.float32)
651 |         # Scale by 1/n_accs and combine real and imaginary parts into complex values
652 |         vis = vis[..., 0] + vis[..., 1] * 1j
653 |         vis *= np.float32(1.0 / n_accs)
654 |         # Baseline permutation
655 |         new_baselines = np.sum(np.asarray(permutation) != -1)
656 |         new_vis = np.empty(vis.shape[:-1] + (new_baselines,), np.complex64)
657 |         weights = np.empty(new_vis.shape, np.float32)
658 |         for old_idx, new_idx in enumerate(permutation):
659 |             if new_idx != -1:
660 |                 new_vis[..., new_idx] = vis[..., old_idx]
661 |         vis = new_vis
662 |         # Compute initial weights
663 |         weights[:] = np.float32(n_accs)
664 |         # Compute flags
665 |         flags = channel_mask[:, channel_mask_idx, :].transpose(0, 2, 1)
666 |         for i in range(len(vis)):
667 |             flags[i, ...] |= flagger(vis[i, ...])
668 |             flags[i, ...] |= baseline_flags[i, np.newaxis, :]
669 |         # Apply flags to weights
670 |         if excise:
671 |             weights *= (flags == 0).astype(np.float32) + 2**-64
672 |             # Mark unflagged visibilities
673 |             flags |= np.where(flags == 0, self.unflagged_bit, 0).astype(np.uint8)
674 |         # Count flags
675 |         if count_flags_channel_range is not None:
676 |             flag_counts = np.empty((new_baselines, 8), np.uint32)
677 |             flag_any_counts = np.empty((new_baselines,), np.uint32)
678 |             flags_to_count = flags[:, count_flags_channel_range.asslice(), :] & ~self.unflagged_bit
679 |             for i in range(8):
680 |                 flag_counts[:, i] = np.count_nonzero(flags_to_count & (1 << i), axis=(0, 1))
681 |             flag_any_counts[:] = np.count_nonzero(flags_to_count, axis=(0, 1))
682 |
683 |         # Time accumulation
684 |         vis = np.sum(vis * weights, axis=0)
685 |         weights = np.sum(weights, axis=0)
686 |         flags = np.bitwise_or.reduce(flags, axis=0)
687 |
688 |         # Clip to the channel range
689 |         rng = channel_range.asslice()
690 |         vis = vis[rng, ...]
691 |         weights = weights[rng, ...]
692 |         flags = flags[rng, ...]
693 |
694 |         # Continuum accumulation
695 |         indices = list(range(0, vis.shape[0], cont_factor))
696 |         cont_vis = np.add.reduceat(vis, indices, axis=0)
697 |         cont_weights = np.add.reduceat(weights, indices, axis=0)
698 |         cont_flags = np.bitwise_or.reduceat(flags, indices, axis=0)
699 |
700 |         # Finalisation
701 |         spec_vis, spec_flags, spec_weights, spec_weights_channel = \
702 |             self.finalise_host(vis, flags, weights, excise)
703 |         cont_vis, cont_flags, cont_weights, cont_weights_channel = \
704 |             self.finalise_host(cont_vis, cont_flags, cont_weights, excise)
705 |         ans = {
706 |             'spec_vis': spec_vis,
707 |             'spec_flags': spec_flags,
708 |             'spec_weights': spec_weights,
709 |             'spec_weights_channel': spec_weights_channel,
710 |             'cont_vis': cont_vis,
711 |             'cont_flags': cont_flags,
712 |             'cont_weights': cont_weights,
713 |             'cont_weights_channel': cont_weights_channel
714 |         }
715 |         if count_flags_channel_range is not None:
716 |             ans['flag_counts'] = flag_counts
717 |             ans['flag_any_counts'] = flag_any_counts
718 |         return ans
719 |
720 |     def run_host(
721 |             self, vis, channel_mask, channel_mask_idx, baseline_flags,
722 |             n_vis, n_sd_vis, n_accs, permutation,
723 |             cont_factor, sd_cont_factor, channel_range, count_flags_channel_range,
724 |             n_sigma, excise, timeseries_weights, percentile_ranges):
725 |         """Simple CPU implementation. All inputs and outputs are channel-major.
726 |         Separate cadences for main and signal display products are supported by
727 |         computing each from its own time slice (the first ``n_vis`` or ``n_sd_vis`` dumps).
728 |
729 |         Parameters
730 |         ----------
731 |         vis : array-like
732 |             Input dump visibilities (indexed by time, channel, baseline)
733 |         channel_mask : array-like
734 |             Baseline-dependent per-channel flags (indexed by time, baseline group and
735 |             channel)
736 |         channel_mask_idx : array-like
737 |             Baseline group for each baseline to index channel_mask
738 |         baseline_flags : array-like
739 |             Input per-baseline flags (indexed by time and post-permutation baseline)
740 |         n_vis : int
741 |             Number of dumps to use for main calculations
742 |         n_sd_vis : int
743 |             Number of dumps to use for signal display calculations
744 |         n_accs : int
745 |             Number of visibilities accumulated in the correlator
746 |         permutation : sequence
747 |             Maps input baseline numbers to output numbers (with -1 indicating discard)
748 |         cont_factor : int
749 |             Number of spectral channels per continuum channel
750 |         sd_cont_factor : int
751 |             Number of spectral channels per continuum channel, for signal displays
752 |         channel_range : :class:`katsdpingest.utils.Range`
753 |             Range of channels to retain in the output
754 |         count_flags_channel_range : :class:`katsdpingest.utils.Range`
755 |             Range of channels for which to count flags in the signal display products
756 |         n_sigma : float
757 |             Significance level for flagger
758 |         excise : bool
759 |             Excise flagged data
760 |         timeseries_weights : 1D array of float
761 |             Weights for masked timeseries averaging
762 |         percentile_ranges : list of 2-tuples of int
763 |             Range of baselines (after permutation) for each percentile product
764 |
765 |         Returns
766 |         -------
767 |         dictionary, with the following keys:
768 |
769 |         - spec_vis, spec_weights, spec_flags
770 |         - cont_vis, cont_weights, cont_flags
771 |         - sd_spec_vis, sd_spec_weights, sd_spec_flags
772 |         - sd_cont_vis, sd_cont_weights, sd_cont_flags
773 |         - sd_flag_counts, sd_flag_any_counts
774 |         - timeseries, timeseriesabs
775 |         - percentileN (where N is a non-negative integer)
776 |         """
777 |         expected = self.run_host_basic(
778 |             vis[:n_vis], channel_mask[:n_vis], channel_mask_idx, baseline_flags[:n_vis],
779 |             n_accs, permutation,
780 |             cont_factor, channel_range, None, n_sigma, excise)
781 |         sd_expected = self.run_host_basic(
782 |             vis[:n_sd_vis], channel_mask[:n_sd_vis], channel_mask_idx, baseline_flags[:n_sd_vis],
783 |             n_accs, permutation,
784 |             sd_cont_factor, channel_range, count_flags_channel_range, n_sigma, excise)
785 |         for (name, value) in sd_expected.items():
786 |             expected['sd_' + name] = value
787 |
788 |         # Time series
789 |         expected['timeseries'] = \
790 |             np.sum(expected['sd_spec_vis'] * timeseries_weights[..., np.newaxis], axis=0)
791 |         expected['timeseriesabs'] = \
792 |             np.sum(np.abs(expected['sd_spec_vis']) * timeseries_weights[..., np.newaxis], axis=0)
793 |
794 |         # Percentiles
795 |         for i, (start, end) in enumerate(percentile_ranges):
796 |             if start != end:
797 |                 percentile = np.percentile(
798 |                     np.abs(expected['sd_spec_vis'][..., start:end]),
799 |                     [0, 100, 25, 75, 50], axis=1, interpolation='lower')
800 |                 flags = np.bitwise_or.reduce(
801 |                     expected['sd_spec_flags'][..., start:end], axis=1)
802 |             else:
803 |                 percentile = \
804 |                     np.tile(np.nan, (5, expected['sd_spec_vis'].shape[0])).astype(np.float32)
805 |                 flags = np.zeros(expected['sd_spec_flags'].shape[0], np.uint8)
806 |             expected['percentile{0}'.format(i)] = percentile
807 |             expected['percentile{0}_flags'.format(i)] = flags
808 |
809 |         return expected
810 |
811 |     def _make_flagger_template(self, context):
812 |         background_template = rfi.BackgroundMedianFilterDeviceTemplate(
813 |             context, width=13)
814 |         noise_est_template = rfi.NoiseEstMADTDeviceTemplate(
815 |             context, 10240)
816 |         threshold_template = rfi.ThresholdSimpleDeviceTemplate(
817 |             context, transposed=True, flag_value=self.flag_value)
818 |         flagger_template = rfi.FlaggerDeviceTemplate(
819 |             background_template, noise_est_template, threshold_template)
820 |         return flagger_template
821 |
822 |     def _test_random(self, context, queue, excise, continuum):
823 |         """Test with random data against a CPU implementation"""
824 |         channels = 128
825 |         channel_range = Range(16, 96)
826 |         count_flags_channel_range = Range(8, 104)
827 |         kept_channels = len(channel_range)
828 |         cbf_baselines = 220
829 |         baselines = 192
830 |         masks = 3
831 |         cont_factor = 4
832 |         sd_cont_factor = 8
833 |         n_accs = 64
834 |         dumps = 4
835 |         sd_dumps = 3   # Must currently be <= dumps, but could easily be fixed
836 |         percentile_ranges = [(0, 10), (32, 40), (0, 0), (180, 192)]
837 |         # Use a very low significance so that there will still be about 50%
838 |         # flags after averaging
839 |         n_sigma = -1.0
840 |
841 |         rs = np.random.RandomState(seed=1)
842 |         vis_in = \
843 |             rs.randint(-1000, 1001, (dumps, channels, cbf_baselines, 2)).astype(np.int32)
844 |         permutation = rs.permutation(cbf_baselines).astype(np.int16)
845 |         permutation[permutation >= baselines] = -1
846 |         timeseries_weights = rs.randint(0, 2, kept_channels).astype(np.float32)
847 |         timeseries_weights /= np.sum(timeseries_weights)
848 |         channel_mask = random_flags(rs, (dumps, masks, channels), 2, p=0.05)
849 |         channel_mask_idx = rs.randint(0, masks, baselines).astype(np.uint32)
850 |         baseline_flags = random_flags(rs, (dumps, baselines), 2, p=0.05)
851 |
852 |         flagger_template = self._make_flagger_template(context)
853 |         template = sigproc.IngestTemplate(context, flagger_template, [0, 8, 12], excise, continuum)
854 |         fn = template.instantiate(
855 |             queue, channels, channel_range, count_flags_channel_range,
856 |             cbf_baselines, baselines, masks,
857 |             cont_factor, sd_cont_factor, percentile_ranges,
858 |             threshold_args={'n_sigma': n_sigma})
859 |         fn.ensure_all_bound()
860 |         fn.n_accs = n_accs
861 |         fn.buffer('permutation').set(queue, permutation)
862 |         fn.buffer('timeseries_weights').set(queue, timeseries_weights)
863 |
864 |         data_keys = ['spec_vis', 'spec_weights', 'spec_weights_channel', 'spec_flags']
865 |         if continuum:
866 |             data_keys.extend(['cont_vis', 'cont_weights', 'cont_weights_channel', 'cont_flags'])
867 |         sd_keys = ['sd_spec_vis', 'sd_spec_weights', 'sd_spec_flags',
868 |                    'sd_cont_vis', 'sd_cont_weights', 'sd_cont_flags',
869 |                    'timeseries', 'timeseriesabs', 'sd_flag_counts', 'sd_flag_any_counts']
870 |         for i in range(len(percentile_ranges)):
871 |             sd_keys.append('percentile{0}'.format(i))
872 |             sd_keys.append('percentile{0}_flags'.format(i))
873 |         for name in data_keys + sd_keys:
874 |             fn.buffer(name).zero(queue)
875 |
876 |         actual = {}
877 |         fn.start_sum()
878 |         fn.start_sd_sum()
879 |         fn.buffer('channel_mask_idx').set(queue, channel_mask_idx)
880 |         for i in range(max(dumps, sd_dumps)):
881 |             fn.buffer('vis_in').set(queue, vis_in[i])
882 |             fn.buffer('channel_mask').set(queue, channel_mask[i])
883 |             fn.buffer('baseline_flags').set(queue, baseline_flags[i])
884 |             fn()
885 |             if i + 1 == dumps:
886 |                 fn.end_sum()
887 |                 for name in data_keys:
888 |                     actual[name] = fn.buffer(name).get(queue)
889 |             if i + 1 == sd_dumps:
890 |                 fn.end_sd_sum()
891 |                 for name in sd_keys:
892 |                     actual[name] = fn.buffer(name).get(queue)
893 |
894 |         expected = self.run_host(
895 |             vis_in, channel_mask, channel_mask_idx, baseline_flags,
896 |             dumps, sd_dumps, n_accs, permutation,
897 |             cont_factor, sd_cont_factor, channel_range, count_flags_channel_range,
898 |             n_sigma, excise, timeseries_weights, percentile_ranges)
899 |
900 |         for name in data_keys + sd_keys:
901 |             err_msg = '{0} is not equal'.format(name)
902 |             if expected[name].dtype in (np.dtype(np.float32), np.dtype(np.complex64)):
903 |                 np.testing.assert_allclose(expected[name], actual[name],
904 |                                            rtol=1e-5, atol=1e-5, err_msg=err_msg)
905 |             elif name.endswith('_weights'):
906 |                 # Integer parts of weights can end up slightly different due to rounding
907 |                 np.testing.assert_allclose(expected[name], actual[name], atol=1, err_msg=err_msg)
908 |             else:
909 |                 np.testing.assert_equal(expected[name], actual[name], err_msg=err_msg)
910 |
911 |     @device_test
912 |     def test_random_excise(self, context, queue):
913 |         """Test with random data against a CPU implementation (with excision)"""
914 |         self._test_random(context, queue, True, True)
915 |
916 |     @device_test
917 |     def test_random_no_excise(self, context, queue):
918 |         """Test with random data against a CPU implementation (without excision)"""
919 |         self._test_random(context, queue, False, True)
920 |
921 |     @device_test
922 |     def test_random_no_continuum(self, context, queue):
923 |         """Test with random data against a CPU implementation (without continuum averaging)"""
924 |         self._test_random(context, queue, True, False)
925 |
926 |     @device_test
927 |     def test_zero_antenna(self, context, queue):
928 |         """If all data for an antenna is zero, it must not cause NaNs in the output."""
929 |         channels = 4
930 |         dumps = 2
931 |
932 |         flagger_template = self._make_flagger_template(context)
933 |         template = sigproc.IngestTemplate(context, flagger_template, [0, 2, 4], True, True)
934 |         fn = template.instantiate(
935 |             queue, channels, Range(0, channels), Range(0, channels),
936 |             4, 4, 1, 2, 2, [(0, 1), (1, 2), (2, 4)],
937 |             threshold_args={'n_sigma': 3.0})
938 |         fn.ensure_all_bound()
939 |         fn.n_accs = 1
940 |         fn.buffer('permutation').set(queue, np.array([0, 1, 2, 3], dtype=np.int16))
941 |         fn.buffer('timeseries_weights').set(queue, np.full(channels, 1 / channels, np.float32))
942 |
943 |         fn.start_sum()
944 |         fn.start_sd_sum()
945 |         for i in range(dumps):
946 |             fn.buffer('vis_in').zero(queue)
947 |             fn.buffer('channel_mask').zero(queue)
948 |             fn.buffer('channel_mask_idx').zero(queue)
949 |             fn.buffer('baseline_flags').zero(queue)
950 |             fn()
951 |         fn.end_sum()
952 |         fn.end_sd_sum()
953 |
954 |         spec_vis = fn.buffer('spec_vis').get(queue)
955 |         spec_flags = fn.buffer('spec_flags').get(queue)
956 |         cont_vis = fn.buffer('cont_vis').get(queue)
957 |         cont_flags = fn.buffer('cont_flags').get(queue)
958 |         np.testing.assert_equal(0 + 0j, spec_vis)
959 |         np.testing.assert_equal(CAM, spec_flags)
960 |         np.testing.assert_equal(0 + 0j, cont_vis)
961 |         np.testing.assert_equal(CAM, cont_flags)
962 |
--------------------------------------------------------------------------------