├── .coveragerc ├── .flake8 ├── .github └── workflows │ ├── build-and-publish-docs.yml │ ├── build-and-run-tests.yml │ ├── publish-to-pypi.yml │ ├── run-examples.yml │ └── run-linting.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE.md ├── NOTICE.md ├── README.md ├── examples ├── courses │ ├── Primer 1 │ │ ├── Primer1.ipynb │ │ ├── spec_csi_many_sources.yaml │ │ ├── spec_nai_few_sources.yaml │ │ └── spec_nai_many_sources.yaml │ └── Primer 2 │ │ ├── Primer 2.ipynb │ │ └── seed_configs │ │ ├── advanced.yaml │ │ ├── basic.yaml │ │ └── problem.yaml ├── data │ ├── conversion │ │ ├── aipt_to_ss.py │ │ ├── pcf_to_ss.py │ │ └── topcoder_to_ss.py │ ├── difficulty_score.py │ ├── preprocessing │ │ └── energy_calibration.py │ └── synthesis │ │ ├── mix_seeds.py │ │ ├── synthesize_passbys.py │ │ ├── synthesize_seeds_advanced.py │ │ ├── synthesize_seeds_basic.py │ │ ├── synthesize_seeds_custom.py │ │ └── synthesize_static.py ├── modeling │ ├── anomaly_detection.py │ ├── arad.py │ ├── arad_latent_prediction.py │ ├── classifier_comparison.py │ ├── custom_loss_and_metric.py │ ├── label_proportion_estimation.py │ └── neural_network_classifier.py ├── run_examples.py └── visualization │ ├── confusion_matrix.py │ ├── distance_matrix.py │ ├── plot_sampleset_compare_to.py │ └── plot_spectra.py ├── pdoc ├── config.mako └── requirements.txt ├── pyproject.toml ├── riid ├── __init__.py ├── anomaly.py ├── data │ ├── __init__.py │ ├── converters │ │ ├── __init__.py │ │ ├── aipt.py │ │ └── topcoder.py │ ├── labeling.py │ ├── sampleset.py │ └── synthetic │ │ ├── __init__.py │ │ ├── base.py │ │ ├── passby.py │ │ ├── seed.py │ │ └── static.py ├── gadras │ ├── __init__.py │ ├── api.py │ ├── api_schema.json │ └── pcf.py ├── losses │ ├── __init__.py │ └── sparsemax.py ├── metrics.py ├── models │ ├── __init__.py │ ├── base.py │ ├── bayes.py │ ├── layers.py │ └── neural_nets │ │ ├── __init__.py │ │ ├── arad.py │ │ ├── basic.py │ │ └── lpe.py └── visualize.py ├── run_tests.ps1 ├── run_tests.sh └── tests ├── anomaly_tests.py ├── config_tests.py ├── data_tests.py ├── gadras_tests.py ├── model_tests.py ├── sampleset_tests.py ├── seedmixer_tests.py ├── staticsynth_tests.py └── visualize_tests.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | */site-packages/* 4 | */distutils/* 5 | tests/* 6 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | inline-quotes = double 4 | avoid-escape = False 5 | -------------------------------------------------------------------------------- /.github/workflows/build-and-publish-docs.yml: -------------------------------------------------------------------------------- 1 | name: Build docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | permissions: 7 | contents: read 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.10" 16 | cache: "pip" 17 | cache-dependency-path: "**/pyproject.toml" 18 | - run: pip install -e . 
19 | - run: pip install -r pdoc/requirements.txt 20 | - run: pdoc riid -o docs/ --html --template-dir pdoc 21 | - run: echo '' > docs/index.html 22 | - uses: actions/upload-pages-artifact@v3 23 | with: 24 | path: docs/ 25 | deploy: 26 | needs: build 27 | runs-on: ubuntu-latest 28 | permissions: 29 | pages: write 30 | id-token: write 31 | environment: 32 | name: github-pages 33 | url: ${{ steps.deployment.outputs.page_url }} 34 | steps: 35 | - name: Deploy to GitHub Pages 36 | id: deployment 37 | uses: actions/deploy-pages@v4 38 | -------------------------------------------------------------------------------- /.github/workflows/build-and-run-tests.yml: -------------------------------------------------------------------------------- 1 | name: Run unit tests 2 | on: 3 | push: 4 | branches: 5 | - '**' 6 | schedule: 7 | - cron: '0 0 * * 1' # Trigger every Monday at midnight (UTC) 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 10 | cancel-in-progress: true 11 | jobs: 12 | build: 13 | strategy: 14 | matrix: 15 | python-version: ["3.10", "3.11", "3.12"] 16 | os: [ubuntu-latest, windows-latest, macos-latest] 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v4 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | cache: "pip" 26 | cache-dependency-path: "**/pyproject.toml" 27 | - name: Install HDF5 (macOS only) 28 | if: runner.os == 'macOS' 29 | run: | 30 | brew install hdf5 31 | brew install c-blosc2 32 | - name: Set HDF5_DIR environment variable (macOS only) 33 | if: runner.os == 'macOS' 34 | run: | 35 | echo "HDF5_DIR=$(brew --prefix hdf5)" >> $GITHUB_ENV 36 | echo "BLOSC2_DIR=$(brew --prefix c-blosc2)" >> $GITHUB_ENV 37 | - name: Install dependencies 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install ".[dev]" 41 | - name: Run unit tests 42 | run: | 43 | sh run_tests.sh 44 | if: github.event_name == 'push' || (github.event_name == 'schedule' && github.ref == 'refs/heads/main') 45 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | build-n-publish: 7 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI 8 | runs-on: ubuntu-latest 9 | environment: 10 | name: pypi 11 | url: https://pypi.org/p/riid 12 | permissions: 13 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v4 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: "3.10" 21 | - name: Install pypa/build 22 | run: >- 23 | python -m 24 | pip install 25 | build 26 | --user 27 | - name: Build a binary wheel and a source tarball 28 | run: >- 29 | python -m 30 | build 31 | --sdist 32 | --wheel 33 | --outdir dist/ 34 | - name: Publish distribution 📦 to Test PyPI 35 | uses: pypa/gh-action-pypi-publish@release/v1 36 | with: 37 | repository-url: https://test.pypi.org/legacy/ 38 | - name: Publish distribution 📦 to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | -------------------------------------------------------------------------------- /.github/workflows/run-examples.yml: 
-------------------------------------------------------------------------------- 1 | name: Run examples 2 | on: [push] 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 5 | cancel-in-progress: true 6 | jobs: 7 | build: 8 | strategy: 9 | matrix: 10 | python-version: ["3.10", "3.11", "3.12"] 11 | os: [ubuntu-latest, windows-latest, macos-latest] 12 | runs-on: ${{ matrix.os }} 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@v4 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | cache: "pip" 21 | cache-dependency-path: "**/pyproject.toml" 22 | - name: Install HDF5 (macOS only) 23 | if: runner.os == 'macOS' 24 | run: | 25 | brew install hdf5 26 | brew install c-blosc2 27 | - name: Set HDF5_DIR environment variable (macOS only) 28 | if: runner.os == 'macOS' 29 | run: | 30 | echo "HDF5_DIR=$(brew --prefix hdf5)" >> $GITHUB_ENV 31 | echo "BLOSC2_DIR=$(brew --prefix c-blosc2)" >> $GITHUB_ENV 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip setuptools wheel 35 | pip install -e ".[dev]" 36 | - name: Run examples 37 | run: | 38 | python examples/run_examples.py 39 | -------------------------------------------------------------------------------- /.github/workflows/run-linting.yml: -------------------------------------------------------------------------------- 1 | name: Run linting 2 | on: [push] 3 | concurrency: 4 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} 5 | cancel-in-progress: true 6 | jobs: 7 | linting: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - name: Set up Python 12 | uses: actions/setup-python@v4 13 | with: 14 | python-version: "3.10" 15 | cache: "pip" 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip setuptools wheel 19 | pip install flake8 flake8-quotes 20 | - name: Run Flake8 linter 21 | run: | 22 | flake8 riid examples tests 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project-specific 2 | *.zip 3 | *.p 4 | *.png 5 | *.jpg 6 | *.jpeg 7 | *.pdf 8 | *.smpl 9 | *.h5 10 | *.pcf 11 | *.onnx 12 | *.tflite 13 | *.hdf 14 | examples/**/*.json 15 | */__pycache__ 16 | _scripts/ 17 | debug/ 18 | *.whl 19 | docs/ 20 | 21 | # Mac-specific 22 | *.DS_Store 23 | ._* 24 | .ipynb_checkpoints 25 | 26 | # VS code 27 | .vscode 28 | 29 | # History files 30 | .Rhistory 31 | .Rapp.history 32 | 33 | # Session Data files 34 | .RData 35 | 36 | # Example code in package build process 37 | *-Ex.R 38 | 39 | # Output files from R CMD build 40 | /*.tar.gz 41 | 42 | # Output files from R CMD check 43 | /*.Rcheck/ 44 | 45 | # RStudio files 46 | .Rproj.user/ 47 | 48 | # produced vignettes 49 | vignettes/*.html 50 | vignettes/*.pdf 51 | 52 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3 53 | .httr-oauth 54 | 55 | # knitr and R markdown default cache directories 56 | /*_cache/ 57 | /cache/ 58 | 59 | # Temporary files created by R markdown 60 | *.utf8.md 61 | *.knit.md 62 | 63 | # Emacs 64 | *~ 65 | 66 | # Byte-compiled / optimized / DLL files 67 | __pycache__/ 68 | *.py[cod] 69 | *$py.class 70 | 71 | # C extensions 72 | *.so 73 | 74 | # Distribution / packaging 75 | .Python 76 | build/ 77 | develop-eggs/ 78 | dist/ 79 | downloads/ 80 | eggs/ 81 
| .eggs/ 82 | lib/ 83 | lib64/ 84 | parts/ 85 | sdist/ 86 | var/ 87 | wheels/ 88 | *.egg-info/ 89 | .installed.cfg 90 | *.egg 91 | MANIFEST 92 | 93 | # PyInstaller 94 | # Usually these files are written by a python script from a template 95 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 96 | *.manifest 97 | *.spec 98 | 99 | # Installer logs 100 | pip-log.txt 101 | pip-delete-this-directory.txt 102 | 103 | # Unit test / coverage reports 104 | htmlcov/ 105 | .tox/ 106 | .coverage 107 | .coverage.* 108 | .cache 109 | nosetests.xml 110 | coverage.xml 111 | *.cover 112 | .hypothesis/ 113 | 114 | # Translations 115 | *.mo 116 | *.pot 117 | 118 | # Django stuff: 119 | *.log 120 | .static_storage/ 121 | .media/ 122 | local_settings.py 123 | 124 | # Flask stuff: 125 | instance/ 126 | .webassets-cache 127 | 128 | # Scrapy stuff: 129 | .scrapy 130 | 131 | # Sphinx documentation 132 | docs/_build/ 133 | 134 | # PyBuilder 135 | target/ 136 | 137 | # Jupyter Notebook 138 | .ipynb_checkpoints 139 | 140 | # pyenv 141 | .python-version 142 | 143 | # celery beat schedule file 144 | celerybeat-schedule 145 | 146 | # SageMath parsed files 147 | *.sage.py 148 | 149 | # Environments 150 | .env 151 | .venv 152 | env/ 153 | venv/ 154 | ENV/ 155 | env.bak/ 156 | venv.bak/ 157 | venv_*/ 158 | *_venv/ 159 | 160 | # Spyder project settings 161 | .spyderproject 162 | .spyproject 163 | 164 | # Rope project settings 165 | .ropeproject 166 | 167 | # mkdocs documentation 168 | /site 169 | 170 | # mypy 171 | .mypy_cache/ 172 | 173 | # files being used for development only 174 | sandbox.py 175 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | [INSERT CONTACT METHOD]. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. 
This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available 126 | at [https://www.contributor-covenant.org/translations][translations]. 127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | 134 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
12 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | -------------------------------------------------------------------------------- /NOTICE.md: -------------------------------------------------------------------------------- 1 | This source code is part of the PyRIID project and is licensed under the BSD-style license. 2 | This project also contains code covered under the Apache-2.0 license based on TensorFlow Addons functions which can be found in `riid/losses/sparsemax.py`. 3 | 4 | The following is a list of the relevant copyright and license information. 5 | 6 | --- 7 | 8 | Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 9 | Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights in this software. 10 | This source code is licensed under the BSD-style license found [here](https://github.com/sandialabs/PyRIID/blob/main/LICENSE.md). 11 | 12 | --- 13 | 14 | Copyright 2016 The TensorFlow Authors. All Rights Reserved. 15 | 16 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 17 | You may obtain a copy of the License at 18 | 19 | http://www.apache.org/licenses/LICENSE-2.0 20 | 21 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | See the License for the specific language governing permissions and limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | PyRIID 3 |
4 | 5 | ![Python Version from PEP 621 TOML](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Fsandialabs%2FPyRIID%2Frefs%2Fheads%2Fmain%2Fpyproject.toml) 6 | ![PyPI](https://badge.fury.io/py/riid.svg) 7 | 8 | PyRIID is a Python package providing modeling and data synthesis utilities for machine learning-based research and development of radioisotope-related detection, identification, and quantification. 9 | 10 | ## Installation 11 | 12 | Requirements: 13 | 14 | - Python version: 3.10 to 3.12 15 | - Note: we recommend the highest Python version you can manage; anecdotally, we have noticed that everything just tends to get faster. 16 | - Operating systems: Windows, Mac, or Ubuntu 17 | 18 | Tests and examples are run via GitHub Actions on many combinations of Python version and operating system. 19 | You can verify support for your platform by checking the workflow files. 20 | 21 | ### For Use 22 | 23 | To use the latest version on PyPI, run: 24 | 25 | ```sh 26 | pip install riid 27 | ``` 28 | 29 | Note that changes are slower to appear on PyPI, so for the latest features, run: 30 | 31 | ```sh 32 | pip install git+https://github.com/sandialabs/pyriid.git@main 33 | ``` 34 | 35 | ### For Development 36 | 37 | If you are developing PyRIID, clone this repository and run: 38 | 39 | ```sh 40 | pip install -e ".[dev]" 41 | ``` 42 | 43 | If you encounter Pylance issues, try: 44 | 45 | ```sh 46 | pip install -e ".[dev]" --config-settings editable_mode=compat 47 | ``` 48 | 49 | ## Examples 50 | 51 | Examples for how to use this package can be found [here](https://github.com/sandialabs/PyRIID/blob/main/examples). 52 | 53 | ## Tests 54 | 55 | Unit tests for this package can be found [here](https://github.com/sandialabs/PyRIID/blob/main/tests). 56 | 57 | Run all unit tests with the following: 58 | 59 | ```sh 60 | python -m unittest tests/*.py -v 61 | ``` 62 | 63 | You can also run one of the `run_tests.*` scripts, whichever is appropriate for your platform. 64 | 65 | ## Docs 66 | 67 | API documentation can be found [here](https://sandialabs.github.io/PyRIID). 68 | 69 | Docs can be built locally with the following: 70 | 71 | ```sh 72 | pip install -r pdoc/requirements.txt 73 | pdoc riid -o docs/ --html --template-dir pdoc 74 | ``` 75 | 76 | ## Contributing 77 | 78 | Pull requests are welcome. 79 | For major changes, please open an issue first to discuss what you would like to change. 80 | 81 | Please make sure to update tests as appropriate and adhere to our [code of conduct](https://github.com/sandialabs/PyRIID/blob/main/CODE_OF_CONDUCT.md). 82 | 83 | ## Contacts 84 | 85 | Maintainers and authors can be found [here](https://github.com/sandialabs/PyRIID/blob/main/pyproject.toml). 86 | 87 | ## Copyright 88 | 89 | Full copyright details can be found [here](https://github.com/sandialabs/PyRIID/blob/main/NOTICE.md). 90 | 91 | ## Acknowledgements 92 | 93 | **Thank you** to the U.S. Department of Energy, National Nuclear Security Administration, 94 | Office of Defense Nuclear Nonproliferation Research and Development (DNN R&D) for funding that has led to versions `2.0` and `2.1`. 95 | 96 | Additionally, **thank you** to the following individuals who have provided invaluable subject-matter expertise: 97 | 98 | - Paul Thelen (also an author) 99 | - Ben Maestas 100 | - Greg Thoreson 101 | - Michael Enghauser 102 | - Elliott Leonard 103 | 104 | ## Citing 105 | 106 | When citing PyRIID, please reference the U.S. 
Department of Energy Office of Scientific and Technical Information (OSTI) record here: 107 | [10.11578/dc.20221017.2](https://doi.org/10.11578/dc.20221017.2) 108 | 109 | ## Related Reports, Publications, and Projects 110 | 111 | 1. Alan Van Omen, *"A Semi-Supervised Model for Multi-Label Radioisotope Classification and Out-of-Distribution Detection."* Diss. 2023. doi: [10.7302/7200](https://doi.org/10.7302/7200). 112 | 2. Tyler Morrow, *"Questionnaire for Radioisotope Identification and Estimation from Gamma Spectra using PyRIID v2."* United States: N. p., 2023. Web. doi: [10.2172/2229893](https://doi.org/10.2172/2229893). 113 | 3. Aaron Fjeldsted, Tyler Morrow, and Douglas Wolfe, *"Identifying Signal-to-Noise Ratios Representative of Gamma Detector Response in Realistic Scenarios,"* 2023 IEEE Nuclear Science Symposium, Medical Imaging Conference and International Symposium on Room-Temperature Semiconductor Detectors (NSS MIC RTSD), Vancouver, BC, Canada, 2023. doi: [10.1109/NSSMICRTSD49126.2023.10337860](https://doi.org/10.1109/NSSMICRTSD49126.2023.10337860). 114 | 4. Alan Van Omen and Tyler Morrow, *"A Semi-supervised Learning Method to Produce Explainable Radioisotope Proportion Estimates for NaI-based Synthetic and Measured Gamma Spectra."* United States: N. p., 2024. Web. doi: [10.2172/2335904](https://doi.org/10.2172/2335904). 115 | - [Code, data, and best model](https://zenodo.org/doi/10.5281/zenodo.10223445) 116 | 5. Alan Van Omen and Tyler Morrow, *"Controlling Radioisotope Proportions When Randomly Sampling from Dirichlet Distributions in PyRIID."* United States: N. p., 2024. Web. doi: [10.2172/2335905](https://doi.org/10.2172/2335905). 117 | 6. Alan Van Omen, Tyler Morrow, et al., *"Multilabel Proportion Prediction and Out-of-distribution Detection on Gamma Spectra of Short-lived Fission Products."* Annals of Nuclear Energy 208 (2024): 110777. doi: [10.1016/j.anucene.2024.110777](https://doi.org/10.1016/j.anucene.2024.110777). 118 | - [Code, data, and best models](https://zenodo.org/doi/10.5281/zenodo.12796964) 119 | 7. Aaron Fjeldsted, Tyler Morrow, et al., *"A Novel Methodology for Gamma-Ray Spectra Dataset Procurement over Varying Standoff Distances and Source Activities,"* Nuclear Instruments and Methods in Physics Research Section A (2024): 169681. doi: [10.1016/j.nima.2024.169681](https://doi.org/10.1016/j.nima.2024.169681). 
120 | -------------------------------------------------------------------------------- /examples/courses/Primer 1/spec_csi_many_sources.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | gamma_detector: 3 | name: Generic\\CsI\\2x4x16 4 | parameters: 5 | distance_cm: 1000 6 | height_cm: 45 7 | dead_time_per_pulse: 5 8 | latitude_deg: 35.0 9 | longitude_deg: 253.4 10 | elevation_m: 1620 11 | sources: 12 | - isotope: Am241 13 | configurations: 14 | - Am241,100uC 15 | - Am241,100uC {13,10} 16 | - Am241,100uC {13,30} 17 | - Am241,100uC {26,5} 18 | - Am241,100uC {26,20} 19 | - Am241,100uC {26,50} 20 | - Am241,100uC {50,2} 21 | - isotope: Ba133 22 | configurations: 23 | - Ba133,100uC 24 | - Ba133,100uC {10,20} 25 | - Ba133,100uC {26,10} 26 | - Ba133,100uC {26,30} 27 | - Ba133,100uC {74,20} 28 | - Ba133,100uC {50,30} 29 | - isotope: Bi207 30 | configurations: 31 | - Bi207,100uC 32 | - Bi207,100uC {10,30} 33 | - Bi207,100uC {26,10} 34 | - isotope: Cf249 35 | configurations: 36 | - Cf249,100uC 37 | - isotope: Co57 38 | configurations: 39 | - Co57,100uC 40 | - Co57,100uC {13,10} 41 | - isotope: Co60 42 | configurations: 43 | - Co60,100uC 44 | - Co60,100uC {10,10} 45 | - Co60,100uC {10,30} 46 | - Co60,100uC {26,20} 47 | - Co60,100uC {26,40} 48 | - Co60,100uC {82,30} 49 | - Co60,100uC {82,60} 50 | - isotope: Cosmic 51 | configurations: 52 | - Cosmic 53 | - isotope: Cs137 54 | configurations: 55 | - Cs137,100uC 56 | - Cs137,100uC {6,2} 57 | - Cs137,100uC {26,10} 58 | - Cs137,100uC {82,10} 59 | - Cs137,100uC {13,240;26,1} 60 | - Cs137InPine 61 | - Cs137InLead 62 | - Cs137InGradedShield1 63 | - Cs137InGradedShield2 64 | - isotope: Eu152 65 | configurations: 66 | - Eu152,100uC 67 | - Eu152,100uC {10,10} 68 | - Eu152,100uC {10,30} 69 | - Eu152,100uC {30,50} 70 | - Eu152,100uC {74,20} 71 | - isotope: Eu154 72 | configurations: 73 | - Eu154,100uC 74 | - Eu154,100uC {10,10} 75 | - Eu154,100uC {74,20} 76 | - isotope: F18 77 | configurations: 78 | - F18,100uC 79 | - F18,100uC {10,20} 80 | - F18,100uC {10,50} 81 | - F18,100uC {26,30} 82 | - isotope: Ga67 83 | configurations: 84 | - Ga67,100uC 85 | - Ga67,100uC {6,10} 86 | - Ga67,100uC {6,20} 87 | - Ga67,100uC {10,30} 88 | - Ga67,100uC {50,16} 89 | - Ga67,100uC {82,20} 90 | - isotope: Ho166m 91 | configurations: 92 | - Ho166m,100uC 93 | - Ho166m,100uC {10,20} 94 | - Ho166m,100uC {26,20} 95 | - Ho166m,100uC {74,30} 96 | - isotope: I123 97 | configurations: 98 | - I123,100uC 99 | - I123,100uC {10,30} 100 | - isotope: I131 101 | configurations: 102 | - I131,100uC 103 | - I131,100uC {10,10} 104 | - I131,100uC {10,30} 105 | - I131,100uC {16,50} 106 | - I131,100uC {20,20} 107 | - I131,100uC {82,10} 108 | - I131,100uC {10,20;50,5} 109 | - isotope: In111 110 | configurations: 111 | - In111,100uC 112 | - In111,100uC {10,20} 113 | - In111,100uC {50,20} 114 | - isotope: Ir192 115 | configurations: 116 | - Ir192,100uC 117 | - Ir192,100uC {10,20} 118 | - Ir192,100uC {26,40} 119 | - Ir192,100uC {26,100} 120 | - Ir192,100uC {82,30} 121 | - Ir192,100uC {82,160} 122 | - Ir192Shielded1 123 | - Ir192Shielded2 124 | - Ir192Shielded3 125 | - isotope: K40 126 | configurations: 127 | - PotassiumInSoil 128 | - K40,100uC 129 | - K40inCargo1 130 | - K40inCargo2 131 | - K40inCargo3 132 | - K40inCargo4 133 | - isotope: Mo99 134 | configurations: 135 | - Mo99,100uC 136 | - Mo99,100uC {26,20} 137 | - Mo99,100uC {50,40} 138 | - isotope: Na22 139 | configurations: 140 | - Na22,100uC 141 | - Na22,100uC {10,10} 142 | - Na22,100uC {10,30} 143 | - 
Na22,100uC {74,20} 144 | - isotope: Np237 145 | configurations: 146 | - 1gNp237,1kC 147 | - 1kgNp237 148 | - Np237inFe1 149 | - Np237inFe4 150 | - Np237Shielded2 151 | - isotope: Pu238 152 | configurations: 153 | - Pu238,100uC 154 | - Pu238,100uC {10,5} 155 | - Pu238,100uC {26,10} 156 | - isotope: Pu239 157 | configurations: 158 | - 1kgPu239 159 | - 1kgPu239,1C {40,4} 160 | - 1kgPu239inFe 161 | - 1kgPu239inPine 162 | - 1kgPu239InFeAndPine 163 | - 1kgPu239inW 164 | - isotope: Ra226 165 | configurations: 166 | - UraniumInSoil 167 | - Ra226,100uC 168 | - Ra226inCargo1 169 | - Ra226inCargo2 170 | - Ra226inCargo3 171 | - Ra226inSilica1 172 | - Ra226inSilica2 173 | - Ra226,100uC {60,100} 174 | - isotope: Sr90 175 | configurations: 176 | - Sr90InPoly1,1C 177 | - Sr90InPoly10,1C 178 | - Sr90InFe,1C 179 | - Sr90InSn,1C 180 | - isotope: Tc99m 181 | configurations: 182 | - Tc99m,100uC 183 | - Tc99m,100uC {7,10} 184 | - Tc99m,100uC {10,20} 185 | - Tc99m,100uC {13,30} 186 | - Tc99m,100uC {26,30} 187 | - isotope: Th232 188 | configurations: 189 | - ThoriumInSoil 190 | - Th232,100uC 191 | - Th232inCargo1 192 | - Th232inCargo2 193 | - Th232inCargo3 194 | - Th232inSilica1 195 | - Th232inSilica2 196 | - Th232,100uC {60,100} 197 | - ThPlate 198 | - ThPlate+Thxray,10uC 199 | - isotope: Tl201 200 | configurations: 201 | - Tl201,100uC 202 | - Tl201,100uC {8,40} 203 | - Tl201,100uC {10,10} 204 | - Tl201,100uC {10,30} 205 | - Tl201,100uC {26,10} 206 | - isotope: U232 207 | configurations: 208 | - U232,100uC 209 | - U232,100uC {10,10} 210 | - U232,100uC {10,30} 211 | - U232,100uC {13,50} 212 | - U232,100uC {26,30} 213 | - U232,100uC {26,60} 214 | - U232inLeadAndPine 215 | - U232inLeadAndFe 216 | - U232inPineAndFe 217 | - U232,100uC {82,30} 218 | - isotope: U233 219 | configurations: 220 | - 1kgU233At1yr 221 | - 1kgU233InFeAt1yr 222 | - 1kgU233At50yr 223 | - 1kgU233InFeAt50yr 224 | - isotope: U235 225 | configurations: 226 | - 1kgU235 227 | - 1kgU235inFe 228 | - 1kgU235inPine 229 | - 1kgU235inPineAndFe 230 | - isotope: U238 231 | configurations: 232 | - 1KGU238 233 | - 1KGU238Clad 234 | - 1KGU238inPine 235 | - 1KGU238inPine2 236 | - 1KGU238inPine3 237 | - 1KGU238inFe 238 | - 1KGU238inPineAndFe 239 | - 1KGU238inW 240 | - U238Fiesta 241 | - fiestaware 242 | - DUOxide 243 | - isotope: Y88 244 | configurations: 245 | - Y88,100uC 246 | - Y88,100uC {10,50} 247 | - Y88,100uC {26,30} 248 | - Y88,100uC {80,50} 249 | ... 250 | -------------------------------------------------------------------------------- /examples/courses/Primer 1/spec_nai_few_sources.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | gamma_detector: 3 | name: Generic\\NaI\\2x4x16 4 | parameters: 5 | distance_cm: 1000 6 | height_cm: 45 7 | dead_time_per_pulse: 5 8 | latitude_deg: 35.0 9 | longitude_deg: 253.4 10 | elevation_m: 1620 11 | sources: 12 | - isotope: Am241 13 | configurations: 14 | - Am241,100uC 15 | - isotope: Ba133 16 | configurations: 17 | - Ba133,100uC 18 | - isotope: Cs137 19 | configurations: 20 | - Cs137,100uC 21 | - isotope: K40 22 | configurations: 23 | - PotassiumInSoil 24 | - isotope: Ra226 25 | configurations: 26 | - UraniumInSoil 27 | - isotope: Th232 28 | configurations: 29 | - ThoriumInSoil 30 | ... 
31 | -------------------------------------------------------------------------------- /examples/courses/Primer 1/spec_nai_many_sources.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | gamma_detector: 3 | name: Generic\\NaI\\2x4x16 4 | parameters: 5 | distance_cm: 1000 6 | height_cm: 45 7 | dead_time_per_pulse: 5 8 | latitude_deg: 35.0 9 | longitude_deg: 253.4 10 | elevation_m: 1620 11 | sources: 12 | - isotope: Am241 13 | configurations: 14 | - Am241,100uC 15 | - Am241,100uC {13,10} 16 | - Am241,100uC {13,30} 17 | - Am241,100uC {26,5} 18 | - Am241,100uC {26,20} 19 | - Am241,100uC {26,50} 20 | - Am241,100uC {50,2} 21 | - isotope: Ba133 22 | configurations: 23 | - Ba133,100uC 24 | - Ba133,100uC {10,20} 25 | - Ba133,100uC {26,10} 26 | - Ba133,100uC {26,30} 27 | - Ba133,100uC {74,20} 28 | - Ba133,100uC {50,30} 29 | - isotope: Bi207 30 | configurations: 31 | - Bi207,100uC 32 | - Bi207,100uC {10,30} 33 | - Bi207,100uC {26,10} 34 | - isotope: Cf249 35 | configurations: 36 | - Cf249,100uC 37 | - isotope: Co57 38 | configurations: 39 | - Co57,100uC 40 | - Co57,100uC {13,10} 41 | - isotope: Co60 42 | configurations: 43 | - Co60,100uC 44 | - Co60,100uC {10,10} 45 | - Co60,100uC {10,30} 46 | - Co60,100uC {26,20} 47 | - Co60,100uC {26,40} 48 | - Co60,100uC {82,30} 49 | - Co60,100uC {82,60} 50 | - isotope: Cosmic 51 | configurations: 52 | - Cosmic 53 | - isotope: Cs137 54 | configurations: 55 | - Cs137,100uC 56 | - Cs137,100uC {6,2} 57 | - Cs137,100uC {26,10} 58 | - Cs137,100uC {82,10} 59 | - Cs137,100uC {13,240;26,1} 60 | - Cs137InPine 61 | - Cs137InLead 62 | - Cs137InGradedShield1 63 | - Cs137InGradedShield2 64 | - isotope: Eu152 65 | configurations: 66 | - Eu152,100uC 67 | - Eu152,100uC {10,10} 68 | - Eu152,100uC {10,30} 69 | - Eu152,100uC {30,50} 70 | - Eu152,100uC {74,20} 71 | - isotope: Eu154 72 | configurations: 73 | - Eu154,100uC 74 | - Eu154,100uC {10,10} 75 | - Eu154,100uC {74,20} 76 | - isotope: F18 77 | configurations: 78 | - F18,100uC 79 | - F18,100uC {10,20} 80 | - F18,100uC {10,50} 81 | - F18,100uC {26,30} 82 | - isotope: Ga67 83 | configurations: 84 | - Ga67,100uC 85 | - Ga67,100uC {6,10} 86 | - Ga67,100uC {6,20} 87 | - Ga67,100uC {10,30} 88 | - Ga67,100uC {50,16} 89 | - Ga67,100uC {82,20} 90 | - isotope: Ho166m 91 | configurations: 92 | - Ho166m,100uC 93 | - Ho166m,100uC {10,20} 94 | - Ho166m,100uC {26,20} 95 | - Ho166m,100uC {74,30} 96 | - isotope: I123 97 | configurations: 98 | - I123,100uC 99 | - I123,100uC {10,30} 100 | - isotope: I131 101 | configurations: 102 | - I131,100uC 103 | - I131,100uC {10,10} 104 | - I131,100uC {10,30} 105 | - I131,100uC {16,50} 106 | - I131,100uC {20,20} 107 | - I131,100uC {82,10} 108 | - I131,100uC {10,20;50,5} 109 | - isotope: In111 110 | configurations: 111 | - In111,100uC 112 | - In111,100uC {10,20} 113 | - In111,100uC {50,20} 114 | - isotope: Ir192 115 | configurations: 116 | - Ir192,100uC 117 | - Ir192,100uC {10,20} 118 | - Ir192,100uC {26,40} 119 | - Ir192,100uC {26,100} 120 | - Ir192,100uC {82,30} 121 | - Ir192,100uC {82,160} 122 | - Ir192Shielded1 123 | - Ir192Shielded2 124 | - Ir192Shielded3 125 | - isotope: K40 126 | configurations: 127 | - PotassiumInSoil 128 | - K40,100uC 129 | - K40inCargo1 130 | - K40inCargo2 131 | - K40inCargo3 132 | - K40inCargo4 133 | - isotope: Mo99 134 | configurations: 135 | - Mo99,100uC 136 | - Mo99,100uC {26,20} 137 | - Mo99,100uC {50,40} 138 | - isotope: Na22 139 | configurations: 140 | - Na22,100uC 141 | - Na22,100uC {10,10} 142 | - Na22,100uC {10,30} 143 | - 
Na22,100uC {74,20} 144 | - isotope: Np237 145 | configurations: 146 | - 1gNp237,1kC 147 | - 1kgNp237 148 | - Np237inFe1 149 | - Np237inFe4 150 | - Np237Shielded2 151 | - isotope: Pu238 152 | configurations: 153 | - Pu238,100uC 154 | - Pu238,100uC {10,5} 155 | - Pu238,100uC {26,10} 156 | - isotope: Pu239 157 | configurations: 158 | - 1kgPu239 159 | - 1kgPu239,1C {40,4} 160 | - 1kgPu239inFe 161 | - 1kgPu239inPine 162 | - 1kgPu239InFeAndPine 163 | - 1kgPu239inW 164 | - isotope: Ra226 165 | configurations: 166 | - UraniumInSoil 167 | - Ra226,100uC 168 | - Ra226inCargo1 169 | - Ra226inCargo2 170 | - Ra226inCargo3 171 | - Ra226inSilica1 172 | - Ra226inSilica2 173 | - Ra226,100uC {60,100} 174 | - isotope: Sr90 175 | configurations: 176 | - Sr90InPoly1,1C 177 | - Sr90InPoly10,1C 178 | - Sr90InFe,1C 179 | - Sr90InSn,1C 180 | - isotope: Tc99m 181 | configurations: 182 | - Tc99m,100uC 183 | - Tc99m,100uC {7,10} 184 | - Tc99m,100uC {10,20} 185 | - Tc99m,100uC {13,30} 186 | - Tc99m,100uC {26,30} 187 | - isotope: Th232 188 | configurations: 189 | - ThoriumInSoil 190 | - Th232,100uC 191 | - Th232inCargo1 192 | - Th232inCargo2 193 | - Th232inCargo3 194 | - Th232inSilica1 195 | - Th232inSilica2 196 | - Th232,100uC {60,100} 197 | - ThPlate 198 | - ThPlate+Thxray,10uC 199 | - isotope: Tl201 200 | configurations: 201 | - Tl201,100uC 202 | - Tl201,100uC {8,40} 203 | - Tl201,100uC {10,10} 204 | - Tl201,100uC {10,30} 205 | - Tl201,100uC {26,10} 206 | - isotope: U232 207 | configurations: 208 | - U232,100uC 209 | - U232,100uC {10,10} 210 | - U232,100uC {10,30} 211 | - U232,100uC {13,50} 212 | - U232,100uC {26,30} 213 | - U232,100uC {26,60} 214 | - U232inLeadAndPine 215 | - U232inLeadAndFe 216 | - U232inPineAndFe 217 | - U232,100uC {82,30} 218 | - isotope: U233 219 | configurations: 220 | - 1kgU233At1yr 221 | - 1kgU233InFeAt1yr 222 | - 1kgU233At50yr 223 | - 1kgU233InFeAt50yr 224 | - isotope: U235 225 | configurations: 226 | - 1kgU235 227 | - 1kgU235inFe 228 | - 1kgU235inPine 229 | - 1kgU235inPineAndFe 230 | - isotope: U238 231 | configurations: 232 | - 1KGU238 233 | - 1KGU238Clad 234 | - 1KGU238inPine 235 | - 1KGU238inPine2 236 | - 1KGU238inPine3 237 | - 1KGU238inFe 238 | - 1KGU238inPineAndFe 239 | - 1KGU238inW 240 | - U238Fiesta 241 | - fiestaware 242 | - DUOxide 243 | - isotope: Y88 244 | configurations: 245 | - Y88,100uC 246 | - Y88,100uC {10,50} 247 | - Y88,100uC {26,30} 248 | - Y88,100uC {80,50} 249 | ... 250 | -------------------------------------------------------------------------------- /examples/courses/Primer 2/seed_configs/advanced.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | gamma_detector: 3 | name: tbd 4 | parameters: 5 | distance_cm: 6 | - 10 7 | - 100 8 | - 1000 9 | height_cm: 100 10 | dead_time_per_pulse: 10 11 | latitude_deg: 35.0 12 | longitude_deg: 253.4 13 | elevation_m: 1620 14 | sources: 15 | - isotope: Cs137 16 | configurations: 17 | - Cs137,100uCi 18 | - name: Cs137 19 | activity: 20 | - 1 21 | - 0.5 22 | activity_units: Ci 23 | shielding_atomic_number: 24 | min: 10 25 | max: 40.0 26 | dist: uniform 27 | num_samples: 5 28 | shielding_aerial_density: 29 | mean: 120 30 | std: 2 31 | num_samples: 5 32 | - isotope: Cosmic 33 | configurations: 34 | - Cosmic 35 | - isotope: K40 36 | configurations: 37 | - PotassiumInSoil 38 | - isotope: Ra226 39 | configurations: 40 | - UraniumInSoil 41 | - isotope: Th232 42 | configurations: 43 | - ThoriumInSoil 44 | ... 
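The `advanced.yaml` config above exercises PyRIID's config expansion features: a scalar parameter may be replaced by a list of values, a `min`/`max`/`dist` block, or a `mean`/`std` sampler, each of which is expanded into multiple detector/source configurations. A minimal sketch of consuming such a config, patterned on `examples/data/synthesis/synthesize_seeds_advanced.py` later in this repo (the relative config path here is an assumption, and seed synthesis requires a local GADRAS installation, which is why the repo's own examples wrap generation in try/except):

```python
# Sketch patterned on synthesize_seeds_advanced.py; the config path below is
# an assumption, and SeedSynthesizer requires a local GADRAS install (Windows).
import yaml

from riid import SeedSynthesizer

with open("examples/courses/Primer 2/seed_configs/advanced.yaml") as f:
    config = yaml.safe_load(f)

try:
    seeds_ss = SeedSynthesizer().generate(config, verbose=True)
    fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg()  # separate fg from bg
    fg_seeds_ss.to_hdf("./fg_seeds.h5")
    bg_seeds_ss.to_hdf("./bg_seeds.h5")
except FileNotFoundError:
    pass  # Happens when GADRAS is not available (i.e., not on Windows)
```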
-------------------------------------------------------------------------------- /examples/courses/Primer 2/seed_configs/basic.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | gamma_detector: 3 | name: Generic\\NaI\\3x3\\Front\\MidScat 4 | parameters: 5 | distance_cm: 100 6 | height_cm: 10 7 | dead_time_per_pulse: 5 8 | latitude_deg: 35.0 9 | longitude_deg: 253.4 10 | elevation_m: 1620 11 | sources: 12 | - isotope: Am241 13 | configurations: 14 | - Am241,100uC 15 | - isotope: Ba133 16 | configurations: 17 | - Ba133,100uC 18 | - isotope: Cs137 19 | configurations: 20 | - Cs137,100uC 21 | - isotope: Cosmic 22 | configurations: 23 | - Cosmic 24 | - isotope: K40 25 | configurations: 26 | - PotassiumInSoil 27 | - isotope: Ra226 28 | configurations: 29 | - UraniumInSoil 30 | - isotope: Th232 31 | configurations: 32 | - ThoriumInSoil 33 | ... -------------------------------------------------------------------------------- /examples/courses/Primer 2/seed_configs/problem.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | gamma_detector: 3 | name: Generic\\NaI\\3x3\\Front\\LowScat 4 | parameters: 5 | distance_cm: 1000 6 | height_cm: 0 7 | dead_time_per_pulse: 10 8 | latitude_deg: 35.0 9 | longitude_deg: 253.4 10 | elevation_m: 1620 11 | sources: 12 | - isotope: Am241 13 | configurations: 14 | - Am241,100uC 15 | - Am241,100uC {13,10} 16 | - Am241,100uC {13,30} 17 | - Am241,100uC {26,5} 18 | - Am241,100uC {26,20} 19 | - Am241,100uC {26,50} 20 | - Am241,100uC {50,2} 21 | - isotope: Ba133 22 | configurations: 23 | - Ba133,100uC 24 | - Ba133,100uC {10,20} 25 | - Ba133,100uC {26,10} 26 | - Ba133,100uC {26,30} 27 | - Ba133,100uC {74,20} 28 | - Ba133,100uC {50,30} 29 | - isotope: Bi207 30 | configurations: 31 | - Bi207,100uC 32 | - Bi207,100uC {10,30} 33 | - Bi207,100uC {26,10} 34 | - isotope: Cf249 35 | configurations: 36 | - Cf249,100uC 37 | - isotope: Co57 38 | configurations: 39 | - Co57,100uC 40 | - Co57,100uC {13,10} 41 | - isotope: Co60 42 | configurations: 43 | - Co60,100uC 44 | - Co60,100uC {10,10} 45 | - Co60,100uC {10,30} 46 | - Co60,100uC {26,20} 47 | - Co60,100uC {26,40} 48 | - Co60,100uC {82,30} 49 | - Co60,100uC {82,60} 50 | - isotope: Cosmic 51 | configurations: 52 | - Cosmic 53 | - isotope: Cs137 54 | configurations: 55 | - Cs137,100uC 56 | - Cs137,100uC {6,2} 57 | - Cs137,100uC {26,10} 58 | - Cs137,100uC {82,10} 59 | - Cs137,100uC {13,240;26,1} 60 | - Cs137InPine 61 | - Cs137InLead 62 | - Cs137InGradedShield1 63 | - Cs137InGradedShield2 64 | - isotope: Eu152 65 | configurations: 66 | - Eu152,100uC 67 | - Eu152,100uC {10,10} 68 | - Eu152,100uC {10,30} 69 | - Eu152,100uC {30,50} 70 | - Eu152,100uC {74,20} 71 | - isotope: Eu154 72 | configurations: 73 | - Eu154,100uC 74 | - Eu154,100uC {10,10} 75 | - Eu154,100uC {74,20} 76 | - isotope: F18 77 | configurations: 78 | - F18,100uC 79 | - F18,100uC {10,20} 80 | - F18,100uC {10,50} 81 | - F18,100uC {26,30} 82 | - isotope: Ga67 83 | configurations: 84 | - Ga67,100uC 85 | - Ga67,100uC {6,10} 86 | - Ga67,100uC {6,20} 87 | - Ga67,100uC {10,30} 88 | - Ga67,100uC {50,16} 89 | - Ga67,100uC {82,20} 90 | - isotope: Ho166m 91 | configurations: 92 | - Ho166m,100uC 93 | - Ho166m,100uC {10,20} 94 | - Ho166m,100uC {26,20} 95 | - Ho166m,100uC {74,30} 96 | - isotope: I123 97 | configurations: 98 | - I123,100uC 99 | - I123,100uC {10,30} 100 | - isotope: I131 101 | configurations: 102 | - I131,100uC 103 | - I131,100uC {10,10} 104 | - I131,100uC {10,30} 105 | - 
I131,100uC {16,50} 106 | - I131,100uC {20,20} 107 | - I131,100uC {82,10} 108 | - I131,100uC {10,20;50,5} 109 | - isotope: In111 110 | configurations: 111 | - In111,100uC 112 | - In111,100uC {10,20} 113 | - In111,100uC {50,20} 114 | - isotope: Ir192 115 | configurations: 116 | - Ir192,100uC 117 | - Ir192,100uC {10,20} 118 | - Ir192,100uC {26,40} 119 | - Ir192,100uC {26,100} 120 | - Ir192,100uC {82,30} 121 | - Ir192,100uC {82,160} 122 | - Ir192Shielded1 123 | - Ir192Shielded2 124 | - Ir192Shielded3 125 | - isotope: K40 126 | configurations: 127 | - PotassiumInSoil 128 | - K40,100uC 129 | - K40inCargo1 130 | - K40inCargo2 131 | - K40inCargo3 132 | - K40inCargo4 133 | - isotope: Mo99 134 | configurations: 135 | - Mo99,100uC 136 | - Mo99,100uC {26,20} 137 | - Mo99,100uC {50,40} 138 | - isotope: Na22 139 | configurations: 140 | - Na22,100uC 141 | - Na22,100uC {10,10} 142 | - Na22,100uC {10,30} 143 | - Na22,100uC {74,20} 144 | - isotope: Ra226 145 | configurations: 146 | - UraniumInSoil 147 | - Ra226,100uC 148 | - Ra226inCargo1 149 | - Ra226inCargo2 150 | - Ra226inCargo3 151 | - Ra226inSilica1 152 | - Ra226inSilica2 153 | - Ra226,100uC {60,100} 154 | - isotope: Sr90 155 | configurations: 156 | - Sr90InPoly1,1C 157 | - Sr90InPoly10,1C 158 | - Sr90InFe,1C 159 | - Sr90InSn,1C 160 | - isotope: Tc99m 161 | configurations: 162 | - Tc99m,100uC 163 | - Tc99m,100uC {7,10} 164 | - Tc99m,100uC {10,20} 165 | - Tc99m,100uC {13,30} 166 | - Tc99m,100uC {26,30} 167 | - isotope: Th232 168 | configurations: 169 | - ThoriumInSoil 170 | - Th232,100uC 171 | - Th232inCargo1 172 | - Th232inCargo2 173 | - Th232inCargo3 174 | - Th232inSilica1 175 | - Th232inSilica2 176 | - Th232,100uC {60,100} 177 | - ThPlate 178 | - ThPlate+Thxray,10uC 179 | - isotope: Tl201 180 | configurations: 181 | - Tl201,100uC 182 | - Tl201,100uC {8,40} 183 | - Tl201,100uC {10,10} 184 | - Tl201,100uC {10,30} 185 | - Tl201,100uC {26,10} 186 | - isotope: Y88 187 | configurations: 188 | - Y88,100uC 189 | - Y88,100uC {10,50} 190 | - Y88,100uC {26,30} 191 | - Y88,100uC {80,50} 192 | ... -------------------------------------------------------------------------------- /examples/data/conversion/aipt_to_ss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to convert AIPT data files to `SampleSet`s 5 | and save as HDF5 files. 6 | """ 7 | from riid.data.converters import convert_directory 8 | from riid.data.converters.aipt import convert_and_save 9 | 10 | 11 | if __name__ == "__main__": 12 | # Change the following to a valid path on your computer 13 | DIRECTORY_WITH_FILES_TO_CONVERT = "./data" 14 | 15 | convert_directory( 16 | DIRECTORY_WITH_FILES_TO_CONVERT, 17 | convert_and_save, 18 | file_ext="open", 19 | output_dir=DIRECTORY_WITH_FILES_TO_CONVERT + "/converted", 20 | ) 21 | -------------------------------------------------------------------------------- /examples/data/conversion/pcf_to_ss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """This example demonstrates how to convert PCFs to `SampleSet`s 5 | and save as HDF5 files. 6 | 7 | This example shows how the utilities originally built for converting Topcoder and AIPT data 8 | can be easily repurposed to bulk convert any type of file. 9 | PyRIID just so happens to already have PCF reading available (via `read_pcf()`), 10 | but a custom file format is easily handled by implementing your own `convert_and_save()` function. 11 | """ 12 | import os 13 | from pathlib import Path 14 | 15 | from riid import SAMPLESET_HDF_FILE_EXTENSION, read_pcf 16 | from riid.data.converters import (_validate_and_create_output_dir, 17 | convert_directory) 18 | 19 | 20 | def convert_and_save(input_file_path: str, output_dir: str = None, 21 | skip_existing: bool = True, **kwargs): 22 | input_path = Path(input_file_path) 23 | if not output_dir: 24 | output_dir = input_path.parent 25 | _validate_and_create_output_dir(output_dir) 26 | output_file_path = os.path.join(output_dir, input_path.stem + SAMPLESET_HDF_FILE_EXTENSION) 27 | if skip_existing and os.path.exists(output_file_path): 28 | return 29 | 30 | output_file_path = os.path.join( 31 | input_path.parent, 32 | input_path.stem + SAMPLESET_HDF_FILE_EXTENSION 33 | ) 34 | ss = read_pcf(input_file_path) 35 | ss.to_hdf(output_file_path) 36 | 37 | 38 | if __name__ == "__main__": 39 | # Change the following to a valid path on your computer 40 | DIRECTORY_WITH_FILES_TO_CONVERT = "./data" 41 | 42 | convert_directory( 43 | DIRECTORY_WITH_FILES_TO_CONVERT, 44 | convert_and_save, 45 | file_ext="pcf", 46 | output_dir=DIRECTORY_WITH_FILES_TO_CONVERT + "/converted", 47 | ) 48 | -------------------------------------------------------------------------------- /examples/data/conversion/topcoder_to_ss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to convert TopCoder data files to `SampleSet`s 5 | and save as HDF5 files. 6 | """ 7 | from riid.data.converters import convert_directory 8 | from riid.data.converters.topcoder import convert_and_save 9 | 10 | 11 | if __name__ == "__main__": 12 | # Change the following to a valid path on your computer 13 | DIRECTORY_WITH_FILES_TO_CONVERT = "./data" 14 | 15 | convert_directory( 16 | DIRECTORY_WITH_FILES_TO_CONVERT, 17 | convert_and_save, 18 | file_ext="csv", 19 | output_dir=DIRECTORY_WITH_FILES_TO_CONVERT + "/converted", 20 | sample_interval=1.0, 21 | pm_chunksize=50, 22 | ) 23 | -------------------------------------------------------------------------------- /examples/data/difficulty_score.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """This example demonstrates how to compute the difficulty of a given SampleSet.""" 5 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 6 | 7 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 8 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ 9 | .generate(1) 10 | 11 | static_synth = StaticSynthesizer( 12 | samples_per_seed=500, 13 | snr_function="uniform", 14 | ) 15 | easy_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss) 16 | 17 | static_synth.snr_function = "log10" 18 | medium_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss) 19 | 20 | static_synth.snr_function_args = (.00001, .1) 21 | hard_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss) 22 | 23 | easy_score = easy_ss.difficulty_score 24 | print(f"Difficulty score for Uniform: {easy_score:.5f}") 25 | medium_score = medium_ss.difficulty_score 26 | print(f"Difficulty score for Log10: {medium_score:.5f}") 27 | hard_score = hard_ss.difficulty_score 28 | print(f"Difficulty score for Log10 Low Signal: {hard_score:.5f}") 29 | -------------------------------------------------------------------------------- /examples/data/preprocessing/energy_calibration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates energy calibration 5 | configuration for a SampleSet.""" 6 | import sys 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 12 | 13 | SYNTHETIC_DATA_CONFIG = { 14 | "samples_per_seed": 10, 15 | "bg_cps": 10, 16 | "snr_function": "uniform", 17 | "snr_function_args": (1, 100), 18 | "live_time_function": "uniform", 19 | "live_time_function_args": (0.25, 10), 20 | "apply_poisson_noise": True, 21 | "return_fg": True, 22 | "return_gross": True 23 | } 24 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 25 | 26 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ 27 | .generate(1) 28 | 29 | fg_ss, ss = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\ 30 | .generate(fg_seeds_ss, mixed_bg_seed_ss) 31 | ss.ecal_low_e = 0 32 | ss.ecal_order_0 = 0 33 | ss.ecal_order_1 = 3100 34 | ss.ecal_order_2 = 0 35 | ss.ecal_order_3 = 0 36 | 37 | index = ss[ss.get_labels() == "Cs137"] 38 | index = np.sum(index.spectra, axis=0).values.astype(int).argmax() 39 | plt.subplot(2, 1, 1) 40 | plt.step( 41 | np.arange(ss.n_channels), 42 | ss.spectra.values[index, :], 43 | where="mid" 44 | ) 45 | plt.yscale("log") 46 | plt.xlim(0, ss.n_channels) 47 | plt.title(ss.get_labels()[index]) 48 | plt.xlabel("Channels") 49 | 50 | energy_bins = np.linspace(0, 3100, 512) 51 | channel_energies = ss.get_channel_energies(sample_index=0, 52 | fractional_energy_bins=energy_bins) 53 | 54 | plt.subplot(2, 1, 2) 55 | 56 | plt.step(channel_energies, ss.spectra.values[index, :], where="mid") 57 | plt.fill_between( 58 | energy_bins, 59 | 0, 60 | ss.spectra.values[index, :], 61 | alpha=0.3, 62 | step="mid" 63 | ) 64 | 65 | plt.yscale("log") 66 | plt.title(ss.get_labels()[index]) 67 | plt.xlabel("Energy (keV)") 68 | plt.vlines(661, 1e-6, 1e8) 69 | plt.ylim(bottom=0.8, top=ss.spectra.values[index, :].max()*1.5) 70 | plt.xlim(0, energy_bins[-1]) 71 | plt.subplots_adjust(hspace=1) 72 | if len(sys.argv) == 1: 73 | plt.show() 74 | 
-------------------------------------------------------------------------------- /examples/data/synthesis/mix_seeds.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to randomly mix seeds together.""" 5 | import numpy as np 6 | 7 | from riid import SeedMixer, get_dummy_seeds 8 | 9 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 10 | 11 | rng = np.random.default_rng(3) 12 | mixed_fg_seeds_ss = SeedMixer(fg_seeds_ss, mixture_size=2, rng=rng)\ 13 | .generate(n_samples=10) 14 | mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng)\ 15 | .generate(n_samples=10) 16 | 17 | print(mixed_fg_seeds_ss) 18 | print(mixed_bg_seeds_ss) 19 | -------------------------------------------------------------------------------- /examples/data/synthesis/synthesize_passbys.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """Example of generating synthetic passby gamma spectra from seeds.""" 5 | import sys 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | from riid import PassbySynthesizer, get_dummy_seeds 11 | 12 | if len(sys.argv) == 2: 13 | import matplotlib 14 | matplotlib.use("Agg") 15 | 16 | 17 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 18 | pbs = PassbySynthesizer( 19 | sample_interval=0.5, 20 | fwhm_function_args=(5, 5), 21 | snr_function_args=(100, 100), 22 | dwell_time_function_args=(5, 5), 23 | events_per_seed=1, 24 | return_fg=False, 25 | return_gross=True, 26 | ) 27 | 28 | events = pbs.generate(fg_seeds_ss, bg_seeds_ss) 29 | _, gross_passbys = list(zip(*events)) 30 | passby_ss = gross_passbys[0] 31 | passby_ss.concat(gross_passbys[1:]) 32 | 33 | passby_ss.sources.drop(bg_seeds_ss.sources.columns, axis=1, inplace=True) 34 | passby_ss.normalize_sources() 35 | passby_ss.normalize(p=2) 36 | 37 | plt.imshow(passby_ss.spectra.values, aspect="auto") 38 | plt.ylabel("Time") 39 | plt.xlabel("Channel") 40 | plt.show() 41 | 42 | offset = 0 43 | breaks = [] 44 | labels = passby_ss.get_labels() 45 | distinct_labels = set(labels) 46 | # Iterate over set of unique labels 47 | for iso in distinct_labels: 48 | sub = passby_ss[labels == iso] 49 | plt.plot( 50 | pbs.sample_interval * (np.arange(sub.n_samples) + offset), 51 | sub.info.snr, 52 | label=f"{iso}", 53 | lw=1 54 | ) 55 | offset += sub.n_samples 56 | breaks.append(offset*pbs.sample_interval) 57 | 58 | plt.vlines(breaks, 0, 1e4, linestyle="--", color="grey") 59 | 60 | plt.yscale("log") 61 | plt.ylim([.1, 1e4]) 62 | plt.xlim(0, breaks[-1]) 63 | plt.ylabel("SNR") 64 | plt.xlabel("Time") 65 | plt.legend(frameon=False) 66 | plt.show() 67 | -------------------------------------------------------------------------------- /examples/data/synthesis/synthesize_seeds_advanced.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """This example demonstrates how to generate synthetic seeds from GADRAS using PyRIID's 5 | configuration expansion features.""" 6 | import yaml 7 | 8 | from riid import SeedSynthesizer 9 | 10 | seed_synth_config = """ 11 | --- 12 | gamma_detector: 13 | name: Generic\\NaI\\3x3\\Front\\MidScat 14 | parameters: 15 | distance_cm: 16 | - 10 17 | - 100 18 | - 1000 19 | height_cm: 100 20 | dead_time_per_pulse: 5 21 | latitude_deg: 35.0 22 | longitude_deg: 253.4 23 | elevation_m: 1620 24 | sources: 25 | - isotope: Cs137 26 | configurations: 27 | - Cs137,100uCi 28 | - name: Cs137 29 | activity: 30 | - 1 31 | - 0.5 32 | activity_units: Ci 33 | shielding_atomic_number: 34 | min: 10 35 | max: 40.0 36 | dist: uniform 37 | num_samples: 5 38 | shielding_aerial_density: 39 | mean: 120 40 | std: 2 41 | num_samples: 5 42 | - isotope: Cosmic 43 | configurations: 44 | - Cosmic 45 | - isotope: K40 46 | configurations: 47 | - PotassiumInSoil 48 | - isotope: Ra226 49 | configurations: 50 | - UraniumInSoil 51 | - isotope: Th232 52 | configurations: 53 | - ThoriumInSoil 54 | ... 55 | """ 56 | seed_synth_config = yaml.safe_load(seed_synth_config) 57 | 58 | try: 59 | seeds_ss = SeedSynthesizer().generate( 60 | seed_synth_config, 61 | verbose=True 62 | ) 63 | print(seeds_ss) 64 | 65 | # At this point, you could save out the seeds via: 66 | seeds_ss.to_hdf("seeds.h5") 67 | 68 | # or start separating your backgrounds from foreground for use with the StaticSynthesizer 69 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg() 70 | 71 | print(fg_seeds_ss) 72 | print(bg_seeds_ss) 73 | 74 | fg_seeds_ss.to_hdf("./fg_seeds.h5") 75 | bg_seeds_ss.to_hdf("./bg_seeds.h5") 76 | except FileNotFoundError: 77 | pass # Happens when not on Windows 78 | -------------------------------------------------------------------------------- /examples/data/synthesis/synthesize_seeds_basic.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to generate synthetic seeds from GADRAS.""" 5 | import yaml 6 | 7 | from riid import SeedSynthesizer 8 | 9 | seed_synth_config = """ 10 | --- 11 | gamma_detector: 12 | name: Generic\\NaI\\3x3\\Front\\MidScat 13 | parameters: 14 | distance_cm: 100 15 | height_cm: 10 16 | dead_time_per_pulse: 5 17 | latitude_deg: 35.0 18 | longitude_deg: 253.4 19 | elevation_m: 1620 20 | sources: 21 | - isotope: Am241 22 | configurations: 23 | - Am241,100uC 24 | - isotope: Ba133 25 | configurations: 26 | - Ba133,100uC 27 | - isotope: Cs137 28 | configurations: 29 | - Cs137,100uC 30 | - isotope: Cosmic 31 | configurations: 32 | - Cosmic 33 | - isotope: K40 34 | configurations: 35 | - PotassiumInSoil 36 | - isotope: Ra226 37 | configurations: 38 | - UraniumInSoil 39 | - isotope: Th232 40 | configurations: 41 | - ThoriumInSoil 42 | ... 
43 | """ 44 | seed_synth_config = yaml.safe_load(seed_synth_config) 45 | 46 | try: 47 | seeds_ss = SeedSynthesizer().generate( 48 | seed_synth_config, 49 | verbose=True 50 | ) 51 | print(seeds_ss) 52 | 53 | # At this point, you could save out the seeds via: 54 | seeds_ss.to_hdf("seeds.h5") 55 | 56 | # or start separating your backgrounds from foreground for use with the StaticSynthesizer 57 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg() 58 | 59 | print(fg_seeds_ss) 60 | print(bg_seeds_ss) 61 | 62 | fg_seeds_ss.to_hdf("./fg_seeds.h5") 63 | bg_seeds_ss.to_hdf("./bg_seeds.h5") 64 | except FileNotFoundError: 65 | pass # Happens when not on Windows 66 | -------------------------------------------------------------------------------- /examples/data/synthesis/synthesize_seeds_custom.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates additional ways to generate synthetic seeds from GADRAS: 5 | - Example 1: inject everything in a folder ending in .gam 6 | - Example 2: build and inject point sources comprised of multiple radioisotopes 7 | """ 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import yaml 13 | 14 | from riid import SeedSynthesizer 15 | from riid.gadras.api import GADRAS_INSTALL_PATH 16 | 17 | 18 | def convert_df_row_to_inject_string(row): 19 | """Converts a row of the DataFrame to a proper GADRAS inject string""" 20 | isotopes = row.index.values 21 | activities = row.values 22 | isotopes_and_activities = [ 23 | f"{iso},{round(act, 4)}uCi" for iso, act in zip(isotopes, activities) 24 | ] # Note the "uCi" specified here. You may need to change it. 25 | inject_string = " + ".join(isotopes_and_activities) 26 | return inject_string 27 | 28 | 29 | seed_synth_config = """ 30 | --- 31 | gamma_detector: 32 | name: Generic\\NaI\\3x3\\Front\\MidScat 33 | parameters: 34 | distance_cm: 100 35 | height_cm: 10 36 | dead_time_per_pulse: 5 37 | latitude_deg: 35.0 38 | longitude_deg: 253.4 39 | elevation_m: 1620 40 | sources: 41 | - isotope: U235 42 | configurations: null 43 | ... 
44 | """ 45 | seed_synth_config = yaml.safe_load(seed_synth_config) 46 | 47 | try: 48 | seed_synth = SeedSynthesizer() 49 | 50 | # Example 1 51 | # Change "Continuum" to your own source directory 52 | gam_dir = Path(GADRAS_INSTALL_PATH).joinpath("Source/Continuum") 53 | gam_filenames = [x.stem for x in gam_dir.glob("*.gam")] 54 | seed_synth_config["sources"][0]["configurations"] = gam_filenames 55 | seeds = seed_synth.generate(seed_synth_config) 56 | seeds.to_hdf("seeds_from_gams.h5") 57 | 58 | # Example 2 59 | # For the following DataFrame, columns are isotopes, rows are samples, and cells are activity 60 | df = pd.DataFrame( 61 | np.random.rand(25, 5), # Reminder: ensure all activity values are in the same units 62 | columns=["Am241", "Ba133", "Co60", "Cs137", "U235"], 63 | ) 64 | configurations = df.apply(convert_df_row_to_inject_string, axis=1).to_list() 65 | seed_synth_config["sources"][0]["configurations"] = configurations 66 | 67 | seeds = seed_synth.generate(seed_synth_config) 68 | seeds.to_hdf("seeds_from_df.h5") 69 | except FileNotFoundError: 70 | pass # Happens when not on Windows 71 | -------------------------------------------------------------------------------- /examples/data/synthesis/synthesize_static.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to generate synthetic gamma spectra from seeds.""" 5 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 6 | 7 | SYNTHETIC_DATA_CONFIG = { 8 | "samples_per_seed": 10000, 9 | "bg_cps": 10, 10 | "snr_function": "uniform", 11 | "snr_function_args": (1, 100), 12 | "live_time_function": "uniform", 13 | "live_time_function_args": (0.25, 10), 14 | "apply_poisson_noise": True, 15 | "return_fg": True, 16 | "return_gross": True, 17 | } 18 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 19 | 20 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ 21 | .generate(1) 22 | 23 | static_synth = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG) 24 | fg_ss, gross_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss) 25 | """ | | 26 | | | 27 | | |> gross samples 28 | |> source-only samples 29 | """ 30 | print(fg_ss) 31 | print(gross_ss) 32 | -------------------------------------------------------------------------------- /examples/modeling/anomaly_detection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to obtain events using an anomaly detection 5 | algorithm. 
6 | """ 7 | import sys 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from matplotlib import cm 12 | 13 | from riid import PassbySynthesizer, SeedMixer, get_dummy_seeds 14 | from riid.anomaly import PoissonNChannelEventDetector 15 | 16 | if len(sys.argv) == 2: 17 | import matplotlib 18 | matplotlib.use("Agg") 19 | 20 | SAMPLE_INTERVAL = 0.5 21 | BG_RATE = 300 22 | EXPECTED_BG_COUNTS = SAMPLE_INTERVAL * BG_RATE 23 | SHORT_TERM_DURATION = 1.5 24 | POST_EVENT_DURATION = 1.5 25 | N_POST_EVENT_SAMPLES = (POST_EVENT_DURATION + SAMPLE_INTERVAL) / SAMPLE_INTERVAL 26 | 27 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 28 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ 29 | .generate(1) 30 | 31 | ed = PoissonNChannelEventDetector( 32 | long_term_duration=600, 33 | short_term_duration=SHORT_TERM_DURATION, 34 | pre_event_duration=5, 35 | max_event_duration=120, 36 | post_event_duration=POST_EVENT_DURATION, 37 | tolerable_false_alarms_per_day=2, 38 | anomaly_threshold_update_interval=60, 39 | ) 40 | count_rate_history = [] 41 | 42 | # Fill background buffer first 43 | print("Filling background") 44 | measurement_id = 0 45 | expected_bg_measurement = mixed_bg_seed_ss.spectra.iloc[0] * EXPECTED_BG_COUNTS 46 | while ed.background_percent_complete < 100: 47 | noisy_bg_measurement = np.random.poisson(expected_bg_measurement) 48 | count_rate = noisy_bg_measurement.sum() / SAMPLE_INTERVAL 49 | count_rate_history.append(count_rate) 50 | _ = ed.add_measurement( 51 | measurement_id, 52 | noisy_bg_measurement, 53 | SAMPLE_INTERVAL 54 | ) 55 | measurement_id += 1 56 | 57 | # Now let's see how many false alarms we get. 58 | # You may want to make this duration higher in order to get a better statistic. 59 | print("Checking false alarm rate") 60 | FALSE_ALARM_CHECK_DURATION = 3 * 60 * 60 61 | false_alarm_check_range = range( 62 | measurement_id, 63 | int(FALSE_ALARM_CHECK_DURATION / SAMPLE_INTERVAL) + 1 64 | ) 65 | false_alarms = 0 66 | for measurement_id in false_alarm_check_range: 67 | noisy_bg_measurement = np.random.poisson(expected_bg_measurement) 68 | count_rate = noisy_bg_measurement.sum() / SAMPLE_INTERVAL 69 | count_rate_history.append(count_rate) 70 | event_result = ed.add_measurement( 71 | measurement_id, 72 | noisy_bg_measurement, 73 | SAMPLE_INTERVAL 74 | ) 75 | if event_result: 76 | false_alarms += 1 77 | measurement_id += 1 78 | false_alarm_rate = 60 * 60 * false_alarms / FALSE_ALARM_CHECK_DURATION 79 | print(f"False alarm rate: {false_alarm_rate:.2f}/hour") 80 | 81 | # Now let's make a passby 82 | print("Generating pass-by") 83 | events = PassbySynthesizer( 84 | fwhm_function_args=(1,), 85 | snr_function_args=(30, 30), 86 | dwell_time_function_args=(1, 1), 87 | events_per_seed=1, 88 | sample_interval=SAMPLE_INTERVAL, 89 | bg_cps=BG_RATE, 90 | return_fg=False, 91 | return_gross=True, 92 | ).generate(fg_seeds_ss, mixed_bg_seed_ss) 93 | _, gross_events = list(zip(*events)) 94 | passby_ss = gross_events[0] 95 | 96 | print("Passing by...") 97 | passby_begin_idx = measurement_id 98 | passby_end_idx = passby_ss.n_samples + measurement_id 99 | passby_range = list(range(passby_begin_idx, passby_end_idx)) 100 | for i, measurement_id in enumerate(passby_range): 101 | gross_spectrum = passby_ss.spectra.iloc[i].values 102 | count_rate = gross_spectrum.sum() / SAMPLE_INTERVAL 103 | count_rate_history.append(count_rate) 104 | event_result = ed.add_measurement( 105 | measurement_id=measurement_id, 106 | measurement=gross_spectrum, 107 | duration=SAMPLE_INTERVAL, 
108 | ) 109 | if event_result: 110 | break 111 | 112 | # A little extra background to close out any pending event 113 | if ed.event_in_progress: 114 | while not event_result: 115 | measurement_id += 1 116 | noisy_bg_measurement = np.random.poisson(expected_bg_measurement) 117 | count_rate = noisy_bg_measurement.sum() / SAMPLE_INTERVAL 118 | count_rate_history.append(count_rate) 119 | event_result = ed.add_measurement( 120 | measurement_id, 121 | noisy_bg_measurement, 122 | SAMPLE_INTERVAL 123 | ) 124 | 125 | count_rate_history = np.array(count_rate_history) 126 | if event_result: 127 | event_measurement, event_bg_measurement, event_duration, event_measurement_ids = event_result 128 | print(f"Event Duration: {event_duration}") 129 | event_begin, event_end = event_measurement_ids[0], event_measurement_ids[-1] 130 | start_idx = int(passby_begin_idx - 30 / SAMPLE_INTERVAL) # include some lead up to event 131 | y = count_rate_history[start_idx:measurement_id] 132 | x = np.array(range(start_idx, measurement_id)) * SAMPLE_INTERVAL 133 | fig, ax = plt.subplots() 134 | ax.plot(x, y) 135 | ax.axvspan( 136 | xmin=event_begin * SAMPLE_INTERVAL, 137 | xmax=event_end * SAMPLE_INTERVAL, 138 | facecolor=cm.tab10(1), 139 | alpha=0.35, 140 | ) 141 | ax.set_ylabel("Count rate (cps)") 142 | ax.set_xlabel("Time (sec)") 143 | plt.show() 144 | else: 145 | print("Pass-by did not produce an event") 146 | -------------------------------------------------------------------------------- /examples/modeling/arad.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to use the PyRIID implementations of ARAD. 5 | """ 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 10 | from riid.models import ARADv1, ARADv2 11 | 12 | # Config 13 | rng = np.random.default_rng(42) 14 | OOD_QUANTILE = 0.99 15 | VERBOSE = False 16 | # Some of the following parameters are set low because this example runs on GitHub Actions and 17 | # we don't want it taking a bunch of time. 18 | # When running this locally, change the values per their corresponding comment, otherwise 19 | # the results likely will not be meaningful.
20 | EPOCHS = 5  # Change this to 20+ 21 | N_MIXTURES = 50  # Change this to 1000+ 22 | TRAIN_SAMPLES_PER_SEED = 5  # Change this to 20+ 23 | TEST_SAMPLES_PER_SEED = 5 24 | 25 | # Generate training data 26 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=128, rng=rng).split_fg_and_bg() 27 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng).generate(N_MIXTURES) 28 | static_synth = StaticSynthesizer( 29 | samples_per_seed=TRAIN_SAMPLES_PER_SEED, 30 | snr_function_args=(0, 0), 31 | return_fg=False, 32 | return_gross=True, 33 | rng=rng, 34 | ) 35 | _, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss) 36 | gross_train_ss.normalize() 37 | 38 | # Generate test data 39 | static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED 40 | _, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss) 41 | test_ss.normalize() 42 | 43 | # Train the models 44 | results = {} 45 | models = [ARADv1, ARADv2] 46 | for model_class in models: 47 | arad = model_class() 48 | model_name = arad.__class__.__name__ 49 | 50 | print(f"Training and testing {model_name}...") 51 | arad.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE) 52 | arad.predict(gross_train_ss) 53 | ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE) 54 | 55 | reconstructions = arad.predict(test_ss, verbose=VERBOSE) 56 | ood = test_ss.info.recon_error.values > ood_threshold 57 | false_positive_rate = ood.mean() 58 | mean_recon_error = test_ss.info.recon_error.values.mean() 59 | 60 | results[model_name] = { 61 | "ood_threshold": f"{ood_threshold:.4f}", 62 | "mean_recon_error": mean_recon_error, 63 | "false_positive_rate": false_positive_rate, 64 | } 65 | 66 | print(f"Target False Positive Rate: {1-OOD_QUANTILE:.4f}") 67 | print(pd.DataFrame.from_dict(results)) 68 | -------------------------------------------------------------------------------- /examples/modeling/arad_latent_prediction.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to train a regressor or classifier branch 5 | from an ARAD latent space. 6 | """ 7 | import numpy as np 8 | from keras.api.metrics import Accuracy, CategoricalCrossentropy 9 | from sklearn.metrics import f1_score, mean_squared_error 10 | 11 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 12 | from riid.models import ARADLatentPredictor, ARADv2 13 | 14 | # Config 15 | rng = np.random.default_rng(42) 16 | VERBOSE = False 17 | # Some of the following parameters are set low because this example runs on GitHub Actions and 18 | # we don't want it taking a bunch of time. 19 | # When running this locally, change the values per their corresponding comment, otherwise 20 | # the results likely will not be meaningful.
21 | EPOCHS = 5  # Change this to 20+ 22 | N_MIXTURES = 50  # Change this to 1000+ 23 | TRAIN_SAMPLES_PER_SEED = 5  # Change this to 20+ 24 | TEST_SAMPLES_PER_SEED = 5 25 | 26 | # Generate training data 27 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=128, rng=rng).split_fg_and_bg() 28 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng).generate(N_MIXTURES) 29 | static_synth = StaticSynthesizer( 30 | samples_per_seed=TRAIN_SAMPLES_PER_SEED, 31 | snr_function_args=(0, 0), 32 | return_fg=False, 33 | return_gross=True, 34 | rng=rng, 35 | ) 36 | _, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss) 37 | gross_train_ss.normalize() 38 | 39 | # Generate test data 40 | static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED 41 | _, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss) 42 | test_ss.normalize() 43 | 44 | # Train ARAD model 45 | print("Training ARAD") 46 | arad_v2 = ARADv2() 47 | arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE) 48 | 49 | # Train regressor to predict live time 50 | print("Training Regressor") 51 | arad_regressor = ARADLatentPredictor() 52 | _ = arad_regressor.fit( 53 | arad_v2.model, 54 | gross_train_ss, 55 | target_info_columns=["live_time"], 56 | epochs=10, 57 | batch_size=5, 58 | verbose=VERBOSE, 59 | ) 60 | regression_predictions = arad_regressor.predict(test_ss) 61 | regression_score = mean_squared_error(test_ss.info.live_time, regression_predictions) 62 | print("Regressor MSE: {:.3f}".format(regression_score)) 63 | 64 | # Train classifier to predict isotope 65 | print("Training Classifier") 66 | arad_classifier = ARADLatentPredictor( 67 | loss="categorical_crossentropy", 68 | metrics=[Accuracy(), CategoricalCrossentropy()], 69 | final_activation="softmax" 70 | ) 71 | arad_classifier.fit( 72 | arad_v2.model, 73 | gross_train_ss, 74 | target_level="Isotope", 75 | epochs=10, 76 | batch_size=5, 77 | verbose=VERBOSE, 78 | ) 79 | arad_classifier.predict(test_ss) 80 | classification_score = f1_score(test_ss.get_labels(), test_ss.get_predictions(), average="micro") 81 | print("Classification F1 Score: {:.3f}".format(classification_score)) 82 | -------------------------------------------------------------------------------- /examples/modeling/classifier_comparison.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates a comparison between Poisson Bayes and MLP classifiers.""" 5 | import sys 6 | 7 | import matplotlib.pyplot as plt 8 | from sklearn.metrics import f1_score 9 | 10 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 11 | from riid.metrics import precision_recall_curve 12 | from riid.models import MLPClassifier, PoissonBayesClassifier 13 | from riid.visualize import plot_precision_recall 14 | 15 | if len(sys.argv) == 2: 16 | import matplotlib 17 | matplotlib.use("Agg") 18 | 19 | # Generate some training data 20 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=64).split_fg_and_bg() 21 | 22 | 23 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(10) 24 | 25 | static_synth = StaticSynthesizer( 26 | samples_per_seed=100, 27 | live_time_function_args=(1, 10), 28 | snr_function="log10", 29 | snr_function_args=(1, 20), 30 | return_fg=True, 31 | return_gross=True, 32 | ) 33 | train_fg_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss, verbose=False) 34 | train_fg_ss.normalize() 35 | 36 | model_nn = MLPClassifier() 37 | model_nn.fit(train_fg_ss, epochs=10, patience=5) 38 | 39 | # Create PB model 40 | model_pb = PoissonBayesClassifier() 41 | model_pb.fit(fg_seeds_ss) 42 | 43 | # Generate some test data 44 | static_synth.samples_per_seed = 50 45 | test_fg_ss, test_gross_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss, 46 | verbose=False) 47 | test_bg_ss = test_gross_ss - test_fg_ss 48 | test_fg_ss.normalize() 49 | test_gross_ss.sources.drop(bg_seeds_ss.sources.columns, axis=1, inplace=True) 50 | test_gross_ss.normalize_sources() 51 | 52 | # Plot 53 | fig, axs = plt.subplots(ncols=2, figsize=(10, 5)) 54 | results = {} 55 | for model, tag, ax in zip([model_nn, model_pb], ["NN", "PB"], axs): 56 | if tag == "NN": 57 | labels = test_fg_ss.get_labels() 58 | model.predict(test_fg_ss) 59 | predictions = test_fg_ss.get_predictions() 60 | precision, recall, thresholds = precision_recall_curve(test_fg_ss) 61 | elif tag == "PB": 62 | labels = test_gross_ss.get_labels() 63 | model.predict(test_gross_ss, test_bg_ss) 64 | predictions = test_gross_ss.get_predictions() 65 | precision, recall, thresholds = precision_recall_curve(test_gross_ss) 66 | else: 67 | raise ValueError() 68 | 69 | score = f1_score(labels, predictions, average="weighted") 70 | print(f"{tag} F1-score: {score:.3f}") 71 | 72 | plot_precision_recall( 73 | precision=precision, 74 | recall=recall, 75 | title=f"{tag}\nPrecision VS Recall", 76 | fig_ax=(fig, ax), 77 | show=False, 78 | ) 79 | 80 | plt.show() 81 | -------------------------------------------------------------------------------- /examples/modeling/custom_loss_and_metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """This example demonstrates custom loss functions and metrics.""" 5 | import tensorflow as tf 6 | 7 | from riid.losses import negative_log_f1 8 | from riid.metrics import multi_f1 9 | 10 | y_true = tf.constant([.524, .175, .1, .1, 0, .1]) 11 | y_pred = tf.constant([.2, .2, .2, .1, .2, .1]) 12 | f1 = multi_f1(y_true, y_pred) 13 | loss = negative_log_f1(y_true, y_pred) 14 | print(f"F1 Score: {f1:.3f}") 15 | print(f"Loss: {loss:.3f}") 16 | -------------------------------------------------------------------------------- /examples/modeling/label_proportion_estimation.py: -------------------------------------------------------------------------------- 1 | """This example demonstrates how to train the Label Proportion 2 | Estimator with a semi-supervised loss function.""" 3 | import os 4 | 5 | from sklearn.metrics import mean_absolute_error 6 | 7 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 8 | from riid.models import LabelProportionEstimator 9 | 10 | # Generate some mixture training data. 11 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 12 | mixed_bg_seeds_ss = SeedMixer( 13 | bg_seeds_ss, 14 | mixture_size=3, 15 | dirichlet_alpha=2 16 | ).generate(100) 17 | 18 | static_syn = StaticSynthesizer( 19 | samples_per_seed=100, 20 | bg_cps=300.0, 21 | live_time_function_args=(60, 600), 22 | snr_function_args=(0, 0), 23 | return_fg=False, 24 | return_gross=True, 25 | ) 26 | 27 | _, bg_ss = static_syn.generate(fg_seeds_ss[0], mixed_bg_seeds_ss) 28 | bg_ss.drop_sources_columns_with_all_zeros() 29 | bg_ss.normalize() 30 | 31 | # Create the model 32 | model = LabelProportionEstimator( 33 | hidden_layers=(64,), 34 | # The supervised loss can either be "sparsemax" 35 | # or "categorical_crossentropy". 36 | sup_loss="categorical_crossentropy", 37 | # The unsupervised loss be "poisson_nll", "normal_nll", 38 | # "sse", or "weighted_sse". 39 | unsup_loss="poisson_nll", 40 | # This controls the tradeoff between the sup 41 | # and unsup losses., 42 | beta=1e-4, 43 | optimizer="RMSprop", 44 | learning_rate=1e-2, 45 | hidden_layer_activation="relu", 46 | dropout=0.05, 47 | ) 48 | 49 | # Train the model. 50 | model.fit( 51 | bg_seeds_ss, 52 | bg_ss, 53 | batch_size=10, 54 | epochs=2, 55 | validation_split=0.2, 56 | bg_cps=300 57 | ) 58 | 59 | # Generate some test data. 60 | static_syn.samples_per_seed = 50 61 | _, test_bg_ss = static_syn.generate(fg_seeds_ss[0], mixed_bg_seeds_ss) 62 | test_bg_ss.normalize(p=1) 63 | test_bg_ss.drop_sources_columns_with_all_zeros() 64 | 65 | model.predict(test_bg_ss) 66 | 67 | test_meas = mean_absolute_error( 68 | test_bg_ss.sources.values, 69 | test_bg_ss.prediction_probas.values 70 | ) 71 | print(f"Mean Test MAE: {test_meas.mean():.3f}") 72 | 73 | # Save model 74 | model_path = "./model.json" 75 | model.save(model_path, overwrite=True) 76 | 77 | loaded_model = LabelProportionEstimator() 78 | loaded_model.load(model_path) 79 | 80 | loaded_model.predict(test_bg_ss) 81 | test_maes = mean_absolute_error( 82 | test_bg_ss.sources.values, 83 | test_bg_ss.prediction_probas.values 84 | ) 85 | 86 | print(f"Mean Test MAE: {test_maes.mean():.3f}") 87 | 88 | # Clean up model file - remove this if you want to keep the model 89 | os.remove(model_path) 90 | -------------------------------------------------------------------------------- /examples/modeling/neural_network_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 
2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to use the MLP classifier.""" 5 | import numpy as np 6 | from sklearn.metrics import f1_score 7 | 8 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 9 | from riid.models import MLPClassifier 10 | 11 | # Generate some training data 12 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 13 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1) 14 | 15 | static_synth = StaticSynthesizer( 16 | samples_per_seed=100, 17 | snr_function="log10", 18 | return_fg=False, 19 | return_gross=True, 20 | ) 21 | _, train_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss) 22 | train_ss.normalize() 23 | 24 | model = MLPClassifier() 25 | model.fit(train_ss, epochs=10, patience=5) 26 | 27 | # Generate some test data 28 | static_synth.samples_per_seed = 50 29 | _, test_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss) 30 | test_ss.normalize() 31 | 32 | # Predict 33 | model.predict(test_ss) 34 | 35 | score = f1_score(test_ss.get_labels(), test_ss.get_predictions(), average="micro") 36 | print("F1 Score: {:.3f}".format(score)) 37 | 38 | # Get confidences 39 | confidences = test_ss.get_confidences( 40 | fg_seeds_ss, 41 | bg_seed_ss=mixed_bg_seed_ss, 42 | bg_cps=300 43 | ) 44 | print(f"Avg Confidence: {np.mean(confidences):.3f}") 45 | -------------------------------------------------------------------------------- /examples/run_examples.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """Runs all Python files in the subfolders and reports whether they are successful.""" 5 | import os 6 | import subprocess 7 | import sys 8 | from pathlib import Path 9 | 10 | import pandas as pd 11 | from tabulate import tabulate 12 | 13 | DIRS_TO_CHECK = ["data", "modeling", "visualization"] 14 | FILENAME_KEY = "File" 15 | RESULT_KEY = "Result" 16 | SUCCESS_STR = "Success" 17 | FAILURE_STR = "Fail" 18 | 19 | original_wdir = os.getcwd() 20 | example_dir = Path(__file__).parent 21 | os.chdir(example_dir) 22 | 23 | files_to_run = [] 24 | for d in DIRS_TO_CHECK: 25 | file_paths = list(Path(d).rglob("*.py")) 26 | files_to_run.extend(file_paths) 27 | files_to_run = sorted(files_to_run) 28 | 29 | results = {} 30 | n_tests = len(files_to_run) 31 | for i, f in enumerate(files_to_run, start=1): 32 | print(f"Running example {i}/{n_tests}") 33 | return_code = 0 34 | output = None 35 | try: 36 | output = subprocess.check_output(f"python {f} hide", 37 | stderr=subprocess.STDOUT, 38 | shell=True) 39 | 40 | except subprocess.CalledProcessError as e: 41 | if (bytes("Error", "utf-8")) in e.output: 42 | print(e.output.decode()) 43 | return_code = e.returncode 44 | 45 | results[i] = { 46 | FILENAME_KEY: os.path.relpath(f, example_dir), 47 | RESULT_KEY: SUCCESS_STR if not return_code else FAILURE_STR 48 | } 49 | os.chdir(original_wdir) 50 | 51 | df = pd.DataFrame.from_dict(results, orient="index") 52 | tabulated_df = tabulate(df, headers="keys", tablefmt="psql") 53 | print(tabulated_df) 54 | 55 | all_succeeded = all([x[RESULT_KEY] == SUCCESS_STR for x in results.values()]) 56 | sys.exit(0 if all_succeeded else 1) 57 | -------------------------------------------------------------------------------- /examples/visualization/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """This example demonstrates how to obtain confusion matrices.""" 5 | import sys 6 | 7 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 8 | from riid.models import MLPClassifier 9 | from riid.visualize import confusion_matrix 10 | 11 | if len(sys.argv) == 2: 12 | import matplotlib 13 | matplotlib.use("Agg") 14 | 15 | SYNTHETIC_DATA_CONFIG = { 16 | "snr_function": "log10", 17 | "snr_function_args": (.01, 10), 18 | } 19 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 20 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ 21 | .generate(1) 22 | 23 | train_ss, _ = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\ 24 | .generate(fg_seeds_ss, mixed_bg_seed_ss) 25 | train_ss.normalize() 26 | 27 | model = MLPClassifier() 28 | model.fit(train_ss, verbose=0, epochs=50) 29 | 30 | # Generate some test data 31 | SYNTHETIC_DATA_CONFIG = { 32 | "snr_function": "log10", 33 | "snr_function_args": (.01, 10), 34 | "samples_per_seed": 50, 35 | } 36 | 37 | test_ss, _ = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\ 38 | .generate(fg_seeds_ss, mixed_bg_seed_ss) 39 | test_ss.normalize() 40 | 41 | # Predict and evaluate 42 | model.predict(test_ss) 43 | 44 | fig, ax = confusion_matrix(test_ss, show=True) 45 | -------------------------------------------------------------------------------- /examples/visualization/distance_matrix.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to generate a distance matrix that 5 | compares every pair of spectra in a SampleSet. 6 | """ 7 | import sys 8 | 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | 12 | from riid import get_dummy_seeds 13 | 14 | if len(sys.argv) == 2: 15 | import matplotlib 16 | matplotlib.use("Agg") 17 | 18 | seeds_ss = get_dummy_seeds(n_channels=16) 19 | distance_df = seeds_ss.get_spectral_distance_matrix() 20 | 21 | sns.set(rc={"figure.figsize": (10, 7)}) 22 | ax = sns.heatmap( 23 | distance_df, 24 | cbar_kws={"label": "Jensen-Shannon Distance"}, 25 | vmin=0, 26 | vmax=1 27 | ) 28 | _ = ax.set_title("Comparing All Seed Pairs Using Jensen-Shannon Distance") 29 | fig = ax.get_figure() 30 | fig.tight_layout() 31 | plt.show() 32 | -------------------------------------------------------------------------------- /examples/visualization/plot_sampleset_compare_to.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """This example demonstrates how to compare sample sets.""" 5 | import sys 6 | 7 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 8 | from riid.visualize import plot_ss_comparison 9 | 10 | if len(sys.argv) == 2: 11 | import matplotlib 12 | matplotlib.use("Agg") 13 | 14 | SYNTHETIC_DATA_CONFIG = { 15 | "samples_per_seed": 100, 16 | "bg_cps": 100, 17 | "snr_function": "uniform", 18 | "snr_function_args": (1, 100), 19 | "live_time_function": "uniform", 20 | "live_time_function_args": (0.25, 10), 21 | "apply_poisson_noise": True, 22 | "return_fg": False, 23 | "return_gross": True, 24 | } 25 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds()\ 26 | .split_fg_and_bg() 27 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ 28 | .generate(1) 29 | 30 | _, gross_ss1 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\ 31 | .generate(fg_seeds_ss, mixed_bg_seed_ss) 32 | _, gross_ss2 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\ 33 | .generate(fg_seeds_ss, mixed_bg_seed_ss) 34 | 35 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2, 36 | density=False) 37 | plot_ss_comparison(ss1_stats, 38 | ss2_stats, 39 | col_comparisons, 40 | "live_time", 41 | show=True) 42 | 43 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2, 44 | density=True) 45 | plot_ss_comparison(ss1_stats, 46 | ss2_stats, 47 | col_comparisons, 48 | "total_counts", 49 | show=True) 50 | -------------------------------------------------------------------------------- /examples/visualization/plot_spectra.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This example demonstrates how to plot gamma spectra.""" 5 | import sys 6 | 7 | from riid import get_dummy_seeds 8 | from riid.visualize import plot_spectra 9 | 10 | if len(sys.argv) == 2: 11 | import matplotlib 12 | matplotlib.use("Agg") 13 | 14 | seeds_ss = get_dummy_seeds() 15 | 16 | plot_spectra(seeds_ss, ylim=(None, None), in_energy=True) 17 | -------------------------------------------------------------------------------- /pdoc/config.mako: -------------------------------------------------------------------------------- 1 | <%! 2 | # Template configuration. Copy over in your template directory 3 | # (used with `--template-dir`) and adapt as necessary. 4 | # Note, defaults are loaded from this distribution file, so your 5 | # config.mako only needs to contain values you want overridden. 6 | # You can also run pdoc with `--config KEY=VALUE` to override 7 | # individual values. 8 | html_lang = 'en' 9 | 10 | show_inherited_members = False 11 | 12 | extract_module_toc_into_sidebar = True 13 | 14 | list_class_variables_in_index = True 15 | 16 | sort_identifiers = True 17 | 18 | show_type_annotations = True 19 | 20 | # Show collapsed source code block next to each item. 21 | # Disabling this can improve rendering speed of large modules. 22 | show_source_code = True 23 | 24 | # If set, format links to objects in online source code repository 25 | # according to this template. Supported keywords for interpolation 26 | # are: commit, path, start_line, end_line. 27 | git_link_template = 'https://github.com/sandialabs/pyriid/blob/{commit}/{path}#L{start_line}-L{end_line}' 28 | 29 | # A prefix to use for every HTML hyperlink in the generated documentation. 
30 | # No prefix results in all links being relative. 31 | link_prefix = '' 32 | 33 | # Enable syntax highlighting for code/source blocks by including Highlight.js 34 | syntax_highlighting = True 35 | # Set the style keyword such as 'atom-one-light' or 'github-gist' 36 | # Options: https://github.com/highlightjs/highlight.js/tree/master/src/styles 37 | # Demo: https://highlightjs.org/static/demo/ 38 | hljs_style = 'github' 39 | 40 | # If set, insert Google Analytics tracking code. Value is GA 41 | # tracking id (UA-XXXXXX-Y). 42 | google_analytics = '' 43 | 44 | # If set, insert Google Custom Search search bar widget above the sidebar index. 45 | # The whitespace-separated tokens represent arbitrary extra queries (at least one 46 | # must match) passed to regular Google search. Example: 47 | #google_search_query = 'inurl:github.com/USER/PROJECT site:PROJECT.github.io site:PROJECT.website' 48 | google_search_query = '' 49 | 50 | # Enable offline search using Lunr.js. For explanation of 'fuzziness' parameter, which is 51 | # added to every query word, see: https://lunrjs.com/guides/searching.html#fuzzy-matches 52 | # If 'index_docstrings' is False, a shorter index is built, indexing only 53 | # the full object reference names. 54 | lunr_search = {'fuzziness': 1, 'index_docstrings': True} 55 | 56 | # If set, render LaTeX math syntax within \(...\) (inline equations), 57 | # or within \[...\] or $$...$$ or `.. math::` (block equations) 58 | # as nicely-formatted math formulas using MathJax. 59 | # Note: in Python docstrings, either all backslashes need to be escaped (\\) 60 | # or you need to use raw r-strings. 61 | latex_math = True 62 | %> -------------------------------------------------------------------------------- /pdoc/requirements.txt: -------------------------------------------------------------------------------- 1 | pdoc3==0.10.0 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 68", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.setuptools.packages.find] 6 | include = ["riid*"] 7 | namespaces = false 8 | 9 | [tool.setuptools.package-data] 10 | "riid.gadras" = ["*.json"] 11 | 12 | [project] 13 | name = "riid" 14 | description = "Machine learning-based models and utilities for radioisotope identification" 15 | version = "2.2.0" 16 | maintainers = [ 17 | {name="Tyler Morrow", email="tmorro@sandia.gov"}, 18 | ] 19 | authors = [ 20 | {name="Tyler Morrow"}, 21 | {name="Nathan Price"}, 22 | {name="Travis McGuire"}, 23 | {name="Tyler Ganter"}, 24 | {name="Aislinn Handley"}, 25 | {name="Paul Thelen"}, 26 | {name="Alan Van Omen"}, 27 | {name="Leon Ross"}, 28 | {name="Alyshia Bustos"}, 29 | ] 30 | readme = "README.md" 31 | license = {file = "LICENSE.md"} 32 | classifiers = [ 33 | 'Development Status :: 5 - Production/Stable', 34 | 'Intended Audience :: Developers', 35 | 'Intended Audience :: Education', 36 | 'Intended Audience :: Science/Research', 37 | 'License :: OSI Approved :: BSD License', 38 | 'Topic :: Scientific/Engineering', 39 | 'Topic :: Scientific/Engineering :: Mathematics', 40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 41 | 'Topic :: Software Development', 42 | 'Topic :: Software Development :: Libraries', 43 | 'Topic :: Software Development :: Libraries :: Python Modules', 44 | 'Programming Language :: Python', 45 | 'Programming Language :: Python :: 3', 46 
| 'Programming Language :: Python :: 3.10', 47 | 'Programming Language :: Python :: 3.11', 48 | 'Programming Language :: Python :: 3.12', 49 | ] 50 | keywords = ["pyriid", "riid", "machine learning", "radioisotope identification", "gamma spectrum"] 51 | 52 | requires-python = ">=3.10,<3.13" 53 | dependencies = [ 54 | "jsonschema ==4.23.*", # 3.8 - 3.13 55 | "matplotlib ==3.9.*", # 3.9 - 3.12 56 | "numpy ==1.26.*", # 3.9 - 3.12, also to be limited by onnx 1.16.2 57 | "pandas ==2.2.*", # >= 3.9 58 | "pythonnet ==3.0.3; platform_system == 'Windows'", # 3.7 - 3.12 59 | "pyyaml ==6.0.*", # >= 3.6 60 | "tables ==3.10.*", # >= 3.9 61 | "scikit-learn ==1.5.*", # 3.9 - 3.12 62 | "scipy ==1.13.*", # >= 3.10 63 | "seaborn ==0.13.*", # >= 3.8 64 | "keras ==3.8.0", 65 | "tensorflow ==2.16.*", # 3.9 - 3.12 66 | "tensorflow-model-optimization ==0.8.*", # 3.7 - 3.12 67 | "onnx ==1.16.1", # 3.7 - 3.10 68 | "tf2onnx ==1.16.1", # 3.7 - 3.10 69 | "tqdm ==4.66.*", # >= 3.7 70 | "typeguard ==4.3.*", # 3.9 - 3.12 71 | ] 72 | 73 | [project.optional-dependencies] 74 | dev = [ 75 | "coverage", 76 | "ipykernel", 77 | "flake8", 78 | "flake8-quotes", 79 | "tabulate", 80 | ] 81 | 82 | [project.urls] 83 | Documentation = "https://sandialabs.github.io/PyRIID" 84 | Repository = "https://github.com/sandialabs/PyRIID" 85 | -------------------------------------------------------------------------------- /riid/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """ 5 | .. include:: ../README.md 6 | """ 7 | import logging 8 | import os 9 | import sys 10 | from importlib.metadata import version 11 | 12 | from riid.data.sampleset import (SampleSet, SpectraState, SpectraType, 13 | read_hdf, read_json, read_pcf) 14 | from riid.data.synthetic.passby import PassbySynthesizer 15 | from riid.data.synthetic.seed import (SeedMixer, SeedSynthesizer, 16 | get_dummy_seeds) 17 | from riid.data.synthetic.static import StaticSynthesizer 18 | 19 | HANDLER = logging.StreamHandler(sys.stdout) 20 | logging.root.addHandler(HANDLER) 21 | logging.root.setLevel(logging.DEBUG) 22 | MPL_LOGGER = logging.getLogger("matplotlib") 23 | MPL_LOGGER.setLevel(logging.WARNING) 24 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" 25 | 26 | SAMPLESET_HDF_FILE_EXTENSION = ".h5" 27 | SAMPLESET_JSON_FILE_EXTENSION = ".json" 28 | PCF_FILE_EXTENSION = ".pcf" 29 | ONNX_MODEL_FILE_EXTENSION = ".onnx" 30 | TFLITE_MODEL_FILE_EXTENSION = ".tflite" 31 | RIID = "riid" 32 | 33 | __version__ = version(RIID) 34 | 35 | __pdoc__ = { 36 | "riid.data.synthetic.seed.SeedMixer.__call__": True, 37 | "riid.data.synthetic.passby.PassbySynthesizer._generate_single_passby": True, 38 | "riid.data.sampleset.SampleSet._channels_to_energies": True, 39 | } 40 | 41 | __all__ = ["SampleSet", "SpectraState", "SpectraType", 42 | "read_hdf", "read_json", "read_pcf", "get_dummy_seeds", 43 | "PassbySynthesizer", "SeedSynthesizer", "StaticSynthesizer", "SeedMixer"] 44 | -------------------------------------------------------------------------------- /riid/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. 
Government retains certain rights in this software. 4 | """This sub-package contains all utilities for synthesizing, reading, writing, and converting data. 5 | """ 6 | import numpy as np 7 | 8 | 9 | def get_expected_spectra(seeds: np.ndarray, expected_counts: np.ndarray) -> np.ndarray: 10 | """Multiply a 1-D array of expected counts by either a 1-D array or 2-D 11 | matrix of seed spectra. 12 | 13 | The dimension(s) of the seed array(s), `seeds`, is expanded to be `(m, n, 1)` where: 14 | 15 | - m = # of seeds 16 | - n = # of channels 17 | 18 | and the final dimension is added in order to facilitate proper broadcasting. 19 | The dimension of the `expected_counts` must be 1, but the length `p` can be 20 | any positive number. 21 | 22 | The resulting expected spectra will be of shape `(m x p, n)`. 23 | This represents the same number of channels `n`, but each expected count 24 | value, of which there are `p`, will be multiplied through each seed spectrum, 25 | of which there are `m`. 26 | All expected spectra matrices for each seed are then concatenated together 27 | (stacked), eliminating the 3rd dimension. 28 | """ 29 | if expected_counts.ndim != 1: 30 | raise ValueError("Expected counts array must be 1-D.") 31 | if expected_counts.shape[0] == 0: 32 | raise ValueError("Expected counts array cannot be empty.") 33 | if seeds.ndim > 2: 34 | raise InvalidSeedError("Seeds array must be 1-D or 2-D.") 35 | 36 | expected_spectra = np.concatenate( 37 | seeds * expected_counts[:, np.newaxis, np.newaxis] 38 | ) 39 | 40 | return expected_spectra 41 | 42 | 43 | class InvalidSeedError(Exception): 44 | """Seed spectra data structure is not 1- or 2-dimensional.""" 45 | pass 46 | -------------------------------------------------------------------------------- /riid/data/converters/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module contains utilities for converting known datasets into `SampleSet`s.""" 5 | import glob 6 | from pathlib import Path 7 | from typing import Callable 8 | 9 | from joblib import Parallel, delayed 10 | 11 | 12 | def _validate_and_create_output_dir(output_dir: str): 13 | output_dir_path = Path(output_dir) 14 | output_dir_path.mkdir(exist_ok=True) 15 | if not output_dir_path.is_dir(): 16 | raise ValueError("`output_dir` already exists but is not a directory.") 17 | 18 | 19 | def convert_directory(input_dir_path: str, conversion_func: Callable, file_ext: str, 20 | n_jobs: int = 8, **kwargs): 21 | """Convert and save every file in a specified directory. 22 | 23 | Conversion functions can be found in sub-modules: 24 | 25 | - AIPT: `riid.data.converters.aipt.convert_and_save()` 26 | - TopCoder: `riid.data.converters.topcoder.convert_and_save()` 27 | 28 | Due to usage of parallel processing, be sure to run this function as follows: 29 | 30 | ```python 31 | if __name__ == "__main__": 32 | convert_directory(...) 33 | ``` 34 | 35 | Tip: for max utilization, consider setting `n_jobs` to `multiprocessing.cpu_count()`.
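For instance, a hypothetical conversion of a directory of TopCoder CSV files might look like the following sketch (the input directory path is illustrative, and `sample_interval` is forwarded through `**kwargs` to `topcoder_file_to_ss()`):

```python
import multiprocessing

from riid.data.converters import convert_directory
from riid.data.converters.topcoder import convert_and_save

if __name__ == "__main__":
    convert_directory(
        "./topcoder_csvs",  # hypothetical directory of raw list-mode CSV files
        convert_and_save,   # conversion function from the TopCoder sub-module
        "csv",              # file extension to read in for conversion
        n_jobs=multiprocessing.cpu_count(),
        sample_interval=1.0,  # required by topcoder_file_to_ss()
    )
```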
36 | 37 | Args: 38 | input_dir_path: directory path containing the input files 39 | conversion_func: function used to convert a data file to a `SampleSet` 40 | file_ext: file extension to read in for conversion 41 | n_jobs: `joblib.Parallel` parameter to set the # of jobs 42 | kwargs: additional keyword args passed to conversion_func 43 | """ 44 | input_path = Path(input_dir_path) 45 | if not input_path.exists() or not input_path.is_dir(): 46 | print(f"No directory at provided input path: '{input_dir_path}'") 47 | return 48 | 49 | input_file_paths = sorted(glob.glob(f"{input_dir_path}/*.{file_ext}")) 50 | 51 | Parallel(n_jobs, verbose=10)( 52 | delayed(conversion_func)(path, **kwargs) for path in input_file_paths 53 | ) 54 | -------------------------------------------------------------------------------- /riid/data/converters/aipt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module provides tools for handling data related to the multi-lab Algorithm Improvement 5 | Program Team (AIPT). 6 | """ 7 | import os 8 | from pathlib import Path 9 | from typing import List 10 | 11 | import pandas as pd 12 | 13 | from riid import SAMPLESET_HDF_FILE_EXTENSION, SampleSet 14 | from riid.data.converters import _validate_and_create_output_dir 15 | 16 | ELEMENT_IDS_PER_FILE = [0, 1, 2, 3] 17 | DEFAULT_ECAL = [ 18 | -3.00000, 19 | 3010.00000, 20 | 150.00000, 21 | 0.00000, 22 | 0.00000, 23 | ] 24 | 25 | 26 | def _element_to_ss(data_df: pd.DataFrame, eid: int, description) -> SampleSet: 27 | if eid not in ELEMENT_IDS_PER_FILE: 28 | msg = ( 29 | f"Element #{eid} is invalid. " 30 | f"The available options are: {ELEMENT_IDS_PER_FILE}" 31 | ) 32 | raise ValueError(msg) 33 | 34 | ss = SampleSet() 35 | ss.spectra = data_df[f"spectrum-channels{eid}"]\ 36 | .str\ 37 | .split(",", expand=True)\ 38 | .astype(int) 39 | ss.info.live_time = data_df[f"spectrum-lt{eid}"] / 1000 40 | ss.info.real_time = data_df[f"spectrum-rt{eid}"] / 1000 41 | ss.info.total_counts = data_df[f"gc{eid}"] 42 | ss.info.neutron_counts = data_df["nc0"] 43 | ss.info.ecal_order_0 = DEFAULT_ECAL[0] 44 | ss.info.ecal_order_1 = DEFAULT_ECAL[1] 45 | ss.info.ecal_order_2 = DEFAULT_ECAL[2] 46 | ss.info.ecal_order_3 = DEFAULT_ECAL[3] 47 | ss.info.ecal_low_e = DEFAULT_ECAL[4] 48 | ss.info.timestamp = data_df["utc-time"] 49 | ss.info.description = description 50 | ss.info["latitude"] = data_df["latitude"] 51 | ss.info["longitude"] = data_df["longitude"] 52 | ss.info["is_in_zone"] = data_df["is-in-zone"] 53 | ss.info["is_closest_approach"] = data_df["is-closest-approach"] 54 | ss.info["is_source_present"] = data_df["is-source-present"] 55 | 56 | detector_name = f"{data_df['detector'].unique()[0]}.{eid}" 57 | ss.detector_info = { 58 | "name": detector_name 59 | } 60 | 61 | return ss 62 | 63 | 64 | def aipt_file_to_ss_list(file_path: str) -> List[SampleSet]: 65 | """Process an AIPT CSV file into a list of SampleSets. 66 | 67 | Each file contains a series of spectra for multiple detectors running simultaneously. 68 | As such, each `SampleSet` in the list returned by this function represents the data 69 | collected by each detector. 70 | Each row of each `SampleSet` represents a measurement from each detector at the same 71 | moment in time.
72 | As such, after calling this function, you might consider summing all four spectra at 73 | each timestep into a single spectrum as another processing step. 74 | 75 | Args: 76 | file_path: file path of the CSV file 77 | 78 | Returns: 79 | List of `SampleSet`s each containing a series of spectra for a single run 80 | """ 81 | data_df = pd.read_csv(file_path, header=0, sep="\t") 82 | base_description = os.path.splitext(os.path.basename(file_path))[0] 83 | 84 | ss_list = [] 85 | for eid in ELEMENT_IDS_PER_FILE: 86 | description = f"{base_description}_{eid}" 87 | ss = _element_to_ss(data_df, eid, description) 88 | ss_list.append(ss) 89 | 90 | return ss_list 91 | 92 | 93 | def convert_and_save(input_file_path: str, output_dir: str = None, 94 | skip_existing: bool = True, **kwargs): 95 | """Convert AIPT file to SampleSet and save as HDF. 96 | 97 | Output file will have same name but appended with a detector identifier 98 | and having a different extension. 99 | 100 | Args: 101 | input_file_path: file path of the CSV file 102 | output_dir: alternative directory in which to save HDF files 103 | (defaults to `input_file_path` parent if not provided) 104 | skip_existing: whether to skip conversion if the file already exists 105 | kwargs: keyword args passed to `aipt_file_to_ss_list()` (not currently used) 106 | """ 107 | input_path = Path(input_file_path) 108 | if not output_dir: 109 | output_dir = input_path.parent 110 | _validate_and_create_output_dir(output_dir) 111 | output_file_paths = [ 112 | os.path.join(output_dir, input_path.stem + f"-{i}{SAMPLESET_HDF_FILE_EXTENSION}") 113 | for i in ELEMENT_IDS_PER_FILE 114 | ] 115 | all_output_files_exist = all([os.path.exists(p) for p in output_file_paths]) 116 | if skip_existing and all_output_files_exist: 117 | return 118 | 119 | ss_list = aipt_file_to_ss_list(input_file_path, **kwargs) 120 | for output_file_path, ss in zip(output_file_paths, ss_list): 121 | ss.to_hdf(output_file_path) 122 | -------------------------------------------------------------------------------- /riid/data/converters/topcoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module provides tools for handling data related to the 5 | "Detecting Radiological Threats in Urban Areas" TopCoder Challenge. 
6 | https://doi.org/10.1038/s41597-020-00672-2 7 | """ 8 | import csv 9 | import logging 10 | import os 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | from riid import SAMPLESET_HDF_FILE_EXTENSION, SampleSet 17 | from riid.data.converters import _validate_and_create_output_dir 18 | from riid.data.labeling import label_to_index_element 19 | 20 | SOURCE_ID_TO_LABEL = { 21 | 0: "Background", 22 | 1: "HEU", 23 | 2: "WGPu", 24 | 3: "I131", 25 | 4: "Co60", 26 | 5: "Tc99m", 27 | 6: "HEU + Tc99m", 28 | } 29 | DISTINCT_SOURCES = list(SOURCE_ID_TO_LABEL.values()) 30 | 31 | 32 | def _get_answers(answer_file_path: str): 33 | answers = {} 34 | with open(answer_file_path) as csvfile: 35 | reader = csv.reader(csvfile, delimiter=",") 36 | _ = next(reader)  # skip header 37 | # timestamp = 0  # in milliseconds 38 | for row in reader: 39 | run_id = row[0] 40 | source_id = int(row[1]) 41 | source_time_secs = float(row[2]) 42 | answers[run_id] = { 43 | "source_id": source_id, 44 | "source_time_secs": source_time_secs 45 | } 46 | return answers 47 | 48 | 49 | def topcoder_file_to_ss(file_path: str, sample_interval: float, n_bins: int = 1024, 50 | max_energy_kev: int = 4000, answers_path: str = None) -> SampleSet: 51 | """Convert a TopCoder CSV file of list-mode data into a SampleSet. 52 | 53 | Args: 54 | file_path: file path of the CSV file 55 | sample_interval: integration time (referred to as "live time" later) to use in seconds. 56 | Warning: the final sample in the set will likely be truncated, i.e., the count rate 57 | will appear low because the live time represented is too large. 58 | Consider ignoring the last sample. 59 | n_bins: desired number of bins in the resulting spectra. 60 | Bins will be uniformly spaced from 0 to `max_energy_kev`. 61 | max_energy_kev: desired maximum of the energy range represented in the resulting spectra. 62 | Intuition (assuming a fixed number of bins): a higher max energy value "compresses" the 63 | spectral information; a lower max energy value spreads out the spectral information and 64 | counts are potentially lost off the high energy end of the spectrum. 65 | answers_path: path to the answer key for the data. If provided, this will fill out the 66 | `SampleSet.sources` DataFrame.
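Note: events are binned as `channel = int(n_bins * energy // max_energy_kev)`; with the defaults above, for example, a 661.7 keV event falls in channel `int(1024 * 661.7 // 4000) = 169`.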
67 | 68 | Returns: 69 | `SampleSet` containing the series of spectra for a single run 70 | """ 71 | file_name_with_dir = os.path.splitext(file_path)[0] 72 | file_name = os.path.basename(file_name_with_dir) 73 | slice_duration_ms = sample_interval * 1000  # in milliseconds 74 | events = [] 75 | with open(file_path) as csvfile: 76 | reader = csv.reader(csvfile, delimiter=",") 77 | timestamp = 0  # in milliseconds 78 | for row in reader: 79 | timestamp += int(row[0]) / 1000  # microseconds to milliseconds 80 | energy = float(row[1]) 81 | if energy > max_energy_kev: 82 | msg = ( 83 | f"Encountered energy ({energy:.2f} keV) greater than " 84 | f"specified max energy ({max_energy_kev:.2f})" 85 | ) 86 | logging.warning(msg) 87 | channel = int(n_bins * energy // max_energy_kev)  # energy to bin 88 | events.append((timestamp, channel, 1)) 89 | 90 | events_df = pd.DataFrame( 91 | events, 92 | columns=["timestamp", "channel", "counts"] 93 | ) 94 | # Organize events into time intervals 95 | event_time_intervals = pd.cut( 96 | events_df["timestamp"], 97 | np.arange( 98 | start=0, 99 | stop=events_df["timestamp"].max()+slice_duration_ms, 100 | step=slice_duration_ms 101 | ) 102 | ) 103 | # Group events by time intervals 104 | event_time_groups = events_df.groupby(event_time_intervals) 105 | # Within time intervals, sum counts by channel 106 | result = event_time_groups.apply( 107 | lambda x: x.groupby("channel").sum().loc[:, ["counts"]] 108 | ) 109 | # Create new dataframe where: row = time interval, column = channel 110 | spectra_df = result.unstack(level=-1, fill_value=0) 111 | spectra_df = spectra_df["counts"] 112 | # Add in missing columns as needed 113 | col_list = np.arange(0, n_bins) 114 | cols_to_add = np.setdiff1d(col_list, spectra_df.columns) 115 | missing_df = pd.DataFrame( 116 | 0, 117 | columns=cols_to_add, 118 | index=spectra_df.index 119 | ) 120 | combined_unsorted_df = pd.concat([spectra_df, missing_df], axis=1) 121 | combined_df = combined_unsorted_df.reindex( 122 | sorted(combined_unsorted_df.columns), 123 | axis=1 124 | ) 125 | spectra_df = combined_df.astype(int) 126 | spectra_df.reset_index(inplace=True, drop=True) 127 | 128 | # SampleSet creation 129 | ss = SampleSet() 130 | ss.measured_or_synthetic = "synthetic" 131 | ss.detector_info = { 132 | "name": "2\"x4\"x16\" NaI(Tl)", 133 | "height_cm": 100, 134 | "fwhm_at_661_kev": 0.075, 135 | } 136 | ss.spectra = spectra_df 137 | ss.info.total_counts = spectra_df.sum(axis=1) 138 | ss.info.live_time = sample_interval 139 | ss.info.description = file_name 140 | ss.info.ecal_order_0 = 0 141 | ss.info.ecal_order_1 = max_energy_kev 142 | ss.info.ecal_order_2 = 0 143 | ss.info.ecal_order_3 = 0 144 | ss.info.ecal_low_e = 0 145 | 146 | if answers_path: 147 | answers = _get_answers(answers_path) 148 | run_id = file_name.split("runID-")[-1].split(".")[0] 149 | ss.info.timestamp = answers[run_id]["source_time_secs"] 150 | source_id = answers[run_id]["source_id"] 151 | sources_mi = pd.MultiIndex.from_tuples([ 152 | label_to_index_element(x, label_level="Seed") 153 | for x in DISTINCT_SOURCES 154 | ], names=SampleSet.SOURCES_MULTI_INDEX_NAMES) 155 | sources_df = pd.DataFrame( 156 | np.zeros((ss.n_samples, len(DISTINCT_SOURCES))), 157 | columns=sources_mi, 158 | ) 159 | sources_df.sort_index(axis=1, inplace=True) 160 | source = label_to_index_element(SOURCE_ID_TO_LABEL[source_id], label_level="Seed") 161 | sources_df[source] = 1.0 162 | ss.sources = sources_df 163 | 164 | return ss 165 | 166 | 167 | def convert_and_save(input_file_path: str, output_dir: str
= None, 168 | skip_existing: bool = True, **kwargs): 169 | """Convert a TopCoder file to a `SampleSet` and save as HDF. 170 | 171 | The output file will have the same name with a different extension. 172 | 173 | Args: 174 | input_file_path: file path of the CSV file 175 | output_dir: alternative directory in which to save HDF files 176 | (defaults to `input_file_path` parent if not provided) 177 | skip_existing: whether to skip conversion if the file already exists 178 | kwargs: keyword args passed to `topcoder_file_to_ss()` 179 | """ 180 | input_path = Path(input_file_path) 181 | if not output_dir: 182 | output_dir = input_path.parent 183 | _validate_and_create_output_dir(output_dir) 184 | output_file_path = os.path.join(output_dir, input_path.stem + SAMPLESET_HDF_FILE_EXTENSION) 185 | if skip_existing and os.path.exists(output_file_path): 186 | return 187 | 188 | ss = topcoder_file_to_ss(input_file_path, **kwargs) 189 | ss.to_hdf(output_file_path) 190 | -------------------------------------------------------------------------------- /riid/data/labeling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module contains utility functions for managing ground truth information.""" 5 | import logging 6 | import re 7 | 8 | BACKGROUND_LABEL = "Background" 9 | NO_SEED = "Unknown" 10 | NO_ISOTOPE = "Unknown" 11 | NO_CATEGORY = "Uncategorized" 12 | CATEGORY_ISOTOPES = { 13 | "Fission Product": { 14 | "severity": 3, 15 | "isotopes": [ 16 | "Ag112", 17 | "As78", 18 | "Ba139", 19 | "Ce143", 20 | "I132", 21 | "I133", 22 | "I134", 23 | "I135", 24 | "Kr85m", 25 | "Kr87", 26 | "La140", 27 | "La142", 28 | "Nd149", 29 | "Pm150", 30 | "Rh105", 31 | "Ru105", 32 | "Sb115", 33 | "Sb129", 34 | "Sr91", 35 | "Sr92", 36 | "Te132", 37 | "Y93", 38 | "Y91m", 39 | "Zr95", 40 | ] 41 | }, 42 | "Industrial": { 43 | "severity": 2, 44 | "isotopes": [ 45 | "Am241", 46 | "Ba133", 47 | "Ba140", 48 | "Bi207", 49 | "Cf249", 50 | "Cf250", 51 | "Cf251", 52 | "Cf252", 53 | "Cm244", 54 | "Co57", 55 | "Co60", 56 | "Cs137", 57 | "Eu152", 58 | "Eu154", 59 | "H3", 60 | "Ho166m", 61 | "Ir192", 62 | "Na22", 63 | "P32", 64 | "P33", 65 | "Po210", 66 | "Se75", 67 | "Sr90", 68 | "Tc99", 69 | "Y88", 70 | ], 71 | }, 72 | "Medical": { 73 | "severity": 3, 74 | "isotopes": [ 75 | "F18", 76 | "Ga67", 77 | "Ga68", 78 | "Ge68", 79 | "I123", 80 | "I124", 81 | "I125", 82 | "I129", 83 | "I131", 84 | "In111", 85 | "Lu177m", 86 | "Mo99", 87 | "Pd103", 88 | "Ra223", 89 | "Rb82", 90 | "Sm153", 91 | "Tc99m", 92 | "Tl201", 93 | "Xe133", 94 | ], 95 | }, 96 | "NORM": { 97 | "severity": 1, 98 | "isotopes": [ 99 | "Cosmic", 100 | "K40", 101 | "Pb210", 102 | "Ra226", 103 | "Th232", 104 | ], 105 | }, 106 | "SNM": { 107 | "severity": 4, 108 | "isotopes": [ 109 | "Np237", 110 | "Pu238", 111 | "Pu239", 112 | "U232", 113 | "U233", 114 | "U235", 115 | "U237", 116 | # 16,000 years from now, once the chemically separated uranium has equilibrated 117 | # with Ra226, then we will need to reconsider U238's categorization.
118 | "U238", 119 | ], 120 | }, 121 | } 122 | ISOTOPES = sum( 123 | [c["isotopes"] for c in CATEGORY_ISOTOPES.values()], 124 | [] 125 | ) # Concatenating the lists of isotopes into one list 126 | SEED_TO_ISOTOPE_SPECIAL_CASES = { 127 | "ThPlate": "Th232", 128 | "ThPlate+Thxray,10uC": "Th232", 129 | "fiestaware": "U238", 130 | "Uxray,100uC": "U238", 131 | "DUOxide": "U238", 132 | "ShieldedDU": "U238", 133 | "modified_berpball": "Pu239", 134 | "10 yr WGPu in Fe": "Pu239", 135 | "1gPuWG_0.5yr,3{an=10,ad=5}": "Pu239", 136 | "pu239_1yr": "Pu239", 137 | "pu239_5yr": "Pu239", 138 | "pu239_10yr": "Pu239", 139 | "pu239_25yr": "Pu239", 140 | "pu239_50yr": "Pu239", 141 | "1kg HEU + 800uCi Cs137": "U235", 142 | "WGPu + Cs137": "Pu239", 143 | "HEU": "U235", 144 | "HEU + Tc99m": "U235", 145 | "DU": "U238", 146 | "WGPu": "Pu239", 147 | "PuWG": "Pu239", 148 | "RTG": "Pu238", 149 | "PotassiumInSoil": "K40", 150 | "UraniumInSoil": "Ra226", 151 | "ThoriumInSoil": "Th232", 152 | "Cosmic": "Cosmic", 153 | } 154 | 155 | 156 | def _find_isotope(seed: str, verbose=True): 157 | """Attempt to find the category for the given seed. 158 | 159 | Args: 160 | seed: string containing the isotope name 161 | verbose: whether log warnings 162 | 163 | Returns: 164 | Isotope if found, otherwise NO_ISOTOPE 165 | """ 166 | if seed.lower() == BACKGROUND_LABEL.lower(): 167 | return BACKGROUND_LABEL 168 | if seed == NO_ISOTOPE: 169 | return NO_ISOTOPE 170 | 171 | isotopes = [] 172 | for i in ISOTOPES: 173 | if i in seed: 174 | isotopes.append(i) 175 | 176 | n_isotopes = len(isotopes) 177 | if n_isotopes > 1: 178 | # Use the longest matching isotope (handles sources strings for things like Tc99 vs Tc99m) 179 | chosen_match = max(isotopes) 180 | if verbose: 181 | logging.warning(( 182 | f"Found multiple isotopes whose names are subsets of '{seed}';" 183 | f" '{chosen_match}' was chosen." 184 | )) 185 | return chosen_match 186 | elif n_isotopes == 0: 187 | return NO_ISOTOPE 188 | else: 189 | return isotopes[0] 190 | 191 | 192 | def _find_category(isotope: str): 193 | """Attempt to find the category for the given isotope. 194 | 195 | Args: 196 | isotope: string containing the isotope name 197 | 198 | Returns: 199 | Category if found, otherwise NO_CATEGORY 200 | """ 201 | if isotope.lower() == BACKGROUND_LABEL.lower(): 202 | return BACKGROUND_LABEL 203 | if isotope == NO_CATEGORY: 204 | return NO_CATEGORY 205 | 206 | categories = [] 207 | for c, v in CATEGORY_ISOTOPES.items(): 208 | c_severity = v["severity"] 209 | c_isotopes = v["isotopes"] 210 | for i in c_isotopes: 211 | if i in isotope: 212 | categories.append((c, c_severity)) 213 | 214 | n_categories = len(categories) 215 | if n_categories > 1: 216 | return max(categories, key=lambda x: x[1])[0] 217 | elif n_categories == 0: 218 | return NO_CATEGORY 219 | else: 220 | return categories[0][0] 221 | 222 | 223 | def label_to_index_element(label_val: str, label_level="Isotope", verbose=False) -> tuple: 224 | """Try to map a label to a tuple for use in `DataFrame` `MultiIndex` columns. 
225 | 226 | Depending on the level of the label value, you will get a different tuple: 227 | 228 | | Label Level | Resulting Tuple | 229 | |:-------------|:-------------------------| 230 | |Seed |(Category, Isotope, Seed) | 231 | |Isotope |(Category, Isotope) | 232 | |Category |(Category,) | 233 | 234 | Args: 235 | label_val: part of the label (Category, Isotope, Seed) 236 | from which to map the other two label values, if possible 237 | label_level: level of the part of the label provided, e.g., 238 | "Category", "Isotope", or "Seed" 239 | 240 | Returns: 241 | Tuple containing the Category, Isotope, and/or Seed values identified 242 | from the given label value. 243 | """ 244 | 245 | old_label = label_val.strip() 246 | # Some files use 'background', others use 'Background'. 247 | if old_label.lower() == BACKGROUND_LABEL.lower(): 248 | old_label = BACKGROUND_LABEL 249 | 250 | if label_level == "Category": 251 | return (old_label,) 252 | 253 | if label_level == "Isotope": 254 | category = _find_category(old_label) 255 | return (category, old_label) 256 | 257 | if label_level == "Seed": 258 | special_cases = SEED_TO_ISOTOPE_SPECIAL_CASES.keys() 259 | exact_match = old_label in special_cases 260 | first_partial_match = None 261 | if not exact_match: 262 | first_partial_match = next((x for x in special_cases if x in old_label), None) 263 | if exact_match: 264 | isotope = SEED_TO_ISOTOPE_SPECIAL_CASES[old_label] 265 | elif first_partial_match: 266 | isotope = SEED_TO_ISOTOPE_SPECIAL_CASES[first_partial_match] 267 | else: 268 | isotope = _find_isotope(old_label, verbose) 269 | category = _find_category(isotope) 270 | return (category, isotope, old_label) 271 | 272 | 273 | def isotope_name_is_valid(isotope: str): 274 | """Validate whether the given string contains a properly formatted radioisotope name. 275 | 276 | Note that this function does NOT look up a string to determine if the string corresponds 277 | to a radioisotope that actually exists; it just checks the format. 278 | 279 | The regular expression used by this function looks for the following (in order): 280 | 281 | - 1 capital letter 282 | - 0 to 1 lowercase letters 283 | - 1 to 3 numbers 284 | - an optional "m" for metastable 285 | 286 | Examples of properly formatted isotope names: 287 | 288 | - Y88 289 | - Ba133 290 | - Ho166m 291 | 292 | Args: 293 | isotope: string containing the isotope name 294 | 295 | Returns: 296 | Bool representing whether the name string is valid 297 | """ 298 | validator = re.compile(r"^[A-Z]{1}[a-z]{0,1}[0-9]{1,3}m?$") 299 | other_valid_names = ["fiestaware"] 300 | match = validator.match(isotope) 301 | is_valid = match is not None or \ 302 | isotope.lower() in other_valid_names 303 | return is_valid 304 | -------------------------------------------------------------------------------- /riid/data/synthetic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains utilities for synthesizing gamma spectra.""" 5 | # The following imports are left to not break previous imports; remove in v3 6 | from riid.data.synthetic.base import Synthesizer, get_distribution_values 7 | from riid.data.synthetic.seed import get_dummy_seeds 8 | 9 | __all__ = ["get_dummy_seeds", "Synthesizer", "get_distribution_values"] 10 | -------------------------------------------------------------------------------- /riid/gadras/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module wraps various GADRAS features by calling into the DLLs of a GADRAS installation.""" 5 | -------------------------------------------------------------------------------- /riid/gadras/api_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "$id": "GADRAS_API_19.2.3_Seed_Synthesis_Schema", 3 | "$schema": "https://json-schema.org/draft/2020-12/schema", 4 | "title": "", 5 | "type": "object", 6 | "required": [ 7 | "gamma_detector", 8 | "sources" 9 | ], 10 | "properties": { 11 | "random_seed": { 12 | "description": "The numpy random seed.", 13 | "type": "integer" 14 | }, 15 | "gamma_detector": { 16 | "title": "Gamma Detector", 17 | "description": "The information about your detector.", 18 | "type": "object", 19 | "properties": { 20 | "name": { 21 | "description": "The directory containing your detector's DRF (.dat file) relative to the GADRAS Detector directory.", 22 | "type": "string" 23 | }, 24 | "parameters": { 25 | "type": "object", 26 | "additionalProperties": true, 27 | "properties": { 28 | "distance_cm": { 29 | "description": "Distance between detector and sourcecentimeters", 30 | "$ref": "#/$defs/detector_properties_types" 31 | }, 32 | "height_cm": { 33 | "description": "Detector height off ground, in centimeters", 34 | "$ref": "#/$defs/detector_properties_types" 35 | }, 36 | "dead_time_per_pulse": { 37 | "description": "Detector dead time, in microseconds", 38 | "$ref": "#/$defs/detector_properties_types" 39 | }, 40 | "latitude_deg": { 41 | "description": "Latitude, in degrees", 42 | "$ref": "#/$defs/detector_properties_types" 43 | }, 44 | "longitude_deg": { 45 | "description": "Longitude, in degrees", 46 | "$ref": "#/$defs/detector_properties_types" 47 | }, 48 | "elevation_m": { 49 | "description": "Elevation, in meters", 50 | "$ref": "#/$defs/detector_properties_types" 51 | } 52 | }, 53 | "required": [ 54 | "distance_cm", 55 | "height_cm", 56 | "dead_time_per_pulse", 57 | "latitude_deg", 58 | "longitude_deg", 59 | "elevation_m" 60 | ] 61 | } 62 | }, 63 | "required": [ 64 | "name", 65 | "parameters" 66 | ] 67 | }, 68 | "sources": { 69 | "title": "Sources", 70 | "description": "The list of sources to obtain via inject(s).", 71 | "type": "array", 72 | "items": { 73 | "type": "object", 74 | "properties": { 75 | "isotope": { 76 | "type": "string" 77 | }, 78 | "configurations": { 79 | "type": "array", 80 | "items": { 81 | "anyOf": [ 82 | { 83 | "type": "string" 84 | }, 85 | { 86 | "$ref": "#/$defs/source_config_type" 87 | } 88 | ] 89 | } 90 | } 91 | }, 92 | "additionalProperties": false 93 | } 94 | } 95 | }, 96 | "$defs": { 97 | "detector_properties_types": { 98 | "anyOf": [ 99 | { 100 | "type": "number" 101 | }, 102 | { 103 | "type": "array", 104 | "items": { 
105 | "anyOf": [ 106 | { 107 | "type": "number" 108 | }, 109 | { 110 | "$ref": "#/$defs/sample_range" 111 | }, 112 | { 113 | "$ref": "#/$defs/sample_norm" 114 | } 115 | ] 116 | } 117 | }, 118 | { 119 | "$ref": "#/$defs/sample_range" 120 | }, 121 | { 122 | "$ref": "#/$defs/sample_norm" 123 | } 124 | ] 125 | }, 126 | "sample_range": { 127 | "type": "object", 128 | "required": [ 129 | "min", 130 | "max", 131 | "dist", 132 | "num_samples" 133 | ], 134 | "properties": { 135 | "min": { 136 | "type": "number", 137 | "description": "Minimum value of the range (inclusive)." 138 | }, 139 | "max": { 140 | "type": "number", 141 | "description": "Maximum value of the range (inclusive)." 142 | }, 143 | "dist": { 144 | "type": "string", 145 | "description": "The distribution from which to draw samples.", 146 | "enum": [ 147 | "uniform", 148 | "log10" 149 | ] 150 | }, 151 | "num_samples": { 152 | "type": "number", 153 | "description": "Number of samples to draw" 154 | } 155 | }, 156 | "additionalProperties": false 157 | }, 158 | "sample_norm": { 159 | "type": "object", 160 | "required": [ 161 | "mean", 162 | "std", 163 | "num_samples" 164 | ], 165 | "properties": { 166 | "mean": { 167 | "type": "number", 168 | "description": "Mean of the normal distribution from which to draw samples." 169 | }, 170 | "std": { 171 | "type": "number", 172 | "description": "Standard deviation of the normal distribution from which to draw samples." 173 | }, 174 | "num_samples": { 175 | "type": "number", 176 | "description": "Number of samples to draw" 177 | } 178 | }, 179 | "additionalProperties": false 180 | }, 181 | "source_config_type": { 182 | "type": "object", 183 | "required": [ 184 | "name" 185 | ], 186 | "properties": { 187 | "name": { 188 | "type": "string" 189 | }, 190 | "activity": { 191 | "anyOf": [ 192 | { 193 | "type": "number" 194 | }, 195 | { 196 | "type": "array", 197 | "items": { 198 | "type": "number" 199 | } 200 | }, 201 | { 202 | "$ref": "#/$defs/sample_range" 203 | }, 204 | { 205 | "$ref": "#/$defs/sample_norm" 206 | } 207 | ] 208 | }, 209 | "activity_units": { 210 | "type": "string", 211 | "enum": [ 212 | "Ci", 213 | "uCi", 214 | "Bq" 215 | ] 216 | }, 217 | "shielding_atomic_number": { 218 | "anyOf": [ 219 | { 220 | "type": "number" 221 | }, 222 | { 223 | "type": "array", 224 | "items": { 225 | "type": "number" 226 | } 227 | }, 228 | { 229 | "$ref": "#/$defs/sample_range" 230 | }, 231 | { 232 | "$ref": "#/$defs/sample_norm" 233 | } 234 | ] 235 | }, 236 | "shielding_aerial_density": { 237 | "anyOf": [ 238 | { 239 | "type": "number" 240 | }, 241 | { 242 | "type": "array", 243 | "items": { 244 | "type": "number" 245 | } 246 | }, 247 | { 248 | "$ref": "#/$defs/sample_range" 249 | }, 250 | { 251 | "$ref": "#/$defs/sample_norm" 252 | } 253 | ] 254 | } 255 | }, 256 | "additionalProperties": false 257 | } 258 | } 259 | } -------------------------------------------------------------------------------- /riid/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module contains custom loss functions.""" 5 | import numpy as np 6 | import tensorflow as tf 7 | from keras.api import ops 8 | 9 | 10 | def negative_log_f1(y_true: np.ndarray, y_pred: np.ndarray): 11 | """Calculate negative log F1 score. 
12 | 13 | Args: 14 | y_true: list of ground truth 15 | y_pred: list of predictions to compare against the ground truth 16 | 17 | Returns: 18 | Custom loss score on a log scale 19 | """ 20 | diff = y_true - y_pred 21 | negs = ops.clip(diff, -1.0, 0.0) 22 | false_positive = -ops.sum(negs, axis=-1) 23 | true_positive = 1.0 - false_positive 24 | lower_clip = 1e-20 25 | true_positive = ops.clip(true_positive, lower_clip, 1.0) 26 | 27 | return -ops.mean(ops.log(true_positive)) 28 | 29 | 30 | def negative_f1(y_true, y_pred): 31 | """Calculate negative F1 score. 32 | 33 | Args: 34 | y_true: list of ground truth 35 | y_pred: list of predictions to compare against the ground truth 36 | 37 | Returns: 38 | Custom loss score 39 | """ 40 | diff = y_true - y_pred 41 | negs = ops.clip(diff, -1.0, 0.0) 42 | false_positive = -ops.sum(negs, axis=-1) 43 | true_positive = 1.0 - false_positive 44 | lower_clip = 1e-20 45 | true_positive = ops.clip(true_positive, lower_clip, 1.0) 46 | 47 | return -ops.mean(true_positive) 48 | 49 | 50 | def build_keras_semisupervised_loss_func(supervised_loss_func, 51 | unsupervised_loss_func, 52 | dictionary, beta, 53 | activation, n_labels, 54 | normalize: bool = False, 55 | normalize_scaler: float = 1.0, 56 | normalize_func=tf.math.tanh): 57 | @tf.keras.utils.register_keras_serializable(package="Addons") 58 | def _semisupervised_loss_func(data, y_pred): 59 | """ 60 | Args: 61 | data: Contains true labels and input features (spectra) 62 | y_pred: Model output (unactivated logits) 63 | """ 64 | y_true = data[:, :n_labels] 65 | spectra = data[:, n_labels:] 66 | logits = y_pred 67 | lpes = activation(y_pred) 68 | 69 | sup_losses = supervised_loss_func(y_true, logits) 70 | unsup_losses = reconstruction_error(spectra, lpes, dictionary, 71 | unsupervised_loss_func) 72 | if normalize: 73 | sup_losses = normalize_func(normalize_scaler * sup_losses) 74 | 75 | semisup_losses = (1 - beta) * sup_losses + beta * unsup_losses 76 | 77 | return semisup_losses 78 | 79 | return _semisupervised_loss_func 80 | 81 | 82 | def sse_diff(spectra, reconstructed_spectra): 83 | """Compute the sum of squares error. 84 | 85 | TODO: refactor to assume spectral inputs are in the same form 86 | 87 | Args: 88 | spectra: spectral samples, assumed to be in counts 89 | reconstructed_spectra: reconstructed spectra created using a 90 | dictionary with label proportion estimates 91 | """ 92 | total_counts = tf.reduce_sum(spectra, axis=1) 93 | scaled_reconstructed_spectra = tf.multiply( 94 | reconstructed_spectra, 95 | tf.reshape(total_counts, (-1, 1)) 96 | ) 97 | 98 | diff = spectra - scaled_reconstructed_spectra 99 | norm_diff = tf.norm(diff, axis=-1) 100 | squared_norm_diff = tf.square(norm_diff) 101 | return squared_norm_diff 102 | 103 | 104 | def poisson_nll_diff(spectra, reconstructed_spectra, eps=1e-8): 105 | """Compute the Poisson Negative Log-Likelihood. 
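In terms of observed counts `k` and expected rate `lambda`, each channel contributes `lambda - k * log(lambda) + log(k!)`; the `log(k!)` term is handled by `tf.nn.log_poisson_loss` with `compute_full_loss=True` (via Stirling's approximation).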
106 | 107 | TODO: refactor to assume spectral inputs are in the same form 108 | 109 | Args: 110 | spectra: spectral samples, assumed to be in counts 111 | reconstructed_spectra: reconstructed spectra created using a 112 | dictionary with label proportion estimates 113 | """ 114 | total_counts = tf.reduce_sum(spectra, axis=-1) 115 | scaled_reconstructed_spectra = tf.multiply( 116 | reconstructed_spectra, 117 | tf.reshape(total_counts, (-1, 1)) 118 | ) 119 | log_reconstructed_spectra = tf.math.log(scaled_reconstructed_spectra + eps) 120 | diff = tf.nn.log_poisson_loss( 121 | spectra, 122 | log_reconstructed_spectra, 123 | compute_full_loss=True 124 | ) 125 | diff = tf.reduce_sum(diff, axis=-1) 126 | 127 | return diff 128 | 129 | 130 | def normal_nll_diff(spectra, reconstructed_spectra, eps=1e-8): 131 | """Compute the Normal Negative Log-Likelihood. 132 | 133 | TODO: refactor to assume spectral inputs are in the same form 134 | 135 | Args: 136 | spectra: spectral samples, assumed to be in counts 137 | reconstructed_spectra: reconstructed spectra created using a 138 | dictionary with label proportion estimates 139 | """ 140 | total_counts = tf.reduce_sum(spectra, axis=-1) 141 | scaled_reconstructed_spectra = tf.multiply( 142 | reconstructed_spectra, 143 | tf.reshape(total_counts, (-1, 1)) 144 | ) 145 | 146 | var = tf.clip_by_value(spectra, clip_value_min=1, clip_value_max=np.inf) 147 | 148 | sigma_term = tf.math.log(2 * np.pi * var) 149 | mu_term = tf.math.divide(tf.math.square(scaled_reconstructed_spectra - spectra), var) 150 | diff = sigma_term + mu_term 151 | diff = 0.5 * tf.reduce_sum(diff, axis=-1) 152 | 153 | return diff 154 | 155 | 156 | def weighted_sse_diff(spectra, reconstructed_spectra): 157 | """Compute the Normal Negative Log-Likelihood under constant variance 158 | (this reduces to the SSE, just on a different scale). 
159 | 160 | Args: 161 | spectra: spectral samples, assumed to be in counts 162 | reconstructed_spectra: reconstructed spectra created using a 163 | dictionary with label proportion estimates 164 | """ 165 | total_counts = tf.reduce_sum(spectra, axis=1) 166 | scaled_reconstructed_spectra = tf.multiply( 167 | reconstructed_spectra, 168 | tf.reshape(total_counts, (-1, 1)) 169 | ) 170 | 171 | sample_variance = tf.sqrt(tf.math.reduce_variance(spectra, axis=1)) 172 | 173 | sigma_term = tf.math.log(2 * np.pi * sample_variance) 174 | 175 | mu_term = tf.math.divide( 176 | tf.math.square(scaled_reconstructed_spectra - spectra), 177 | tf.reshape(sample_variance, (-1, 1)) 178 | ) 179 | diff = 0.5 * (sigma_term + tf.reduce_sum(mu_term, axis=-1)) 180 | 181 | return diff 182 | 183 | 184 | def reconstruction_error(spectra, lpes, dictionary, diff_func): 185 | reconstructed_spectra = tf.matmul(lpes, dictionary) 186 | reconstruction_errors = diff_func(spectra, reconstructed_spectra) 187 | return reconstruction_errors 188 | 189 | 190 | def mish(x): 191 | return x * tf.math.tanh(tf.math.softplus(x)) 192 | 193 | 194 | def jensen_shannon_divergence(p, q): 195 | p_sum = tf.reduce_sum(p, axis=-1) 196 | p_norm = tf.divide( 197 | p, 198 | tf.reshape(p_sum, (-1, 1)) 199 | ) 200 | 201 | q_sum = tf.reduce_sum(q, axis=-1) 202 | q_norm = tf.divide( 203 | q, 204 | tf.reshape(q_sum, (-1, 1)) 205 | ) 206 | 207 | kld = tf.keras.losses.KLDivergence(reduction=tf.keras.losses.Reduction.NONE) 208 | m = (p_norm + q_norm) / 2 209 | jsd = (kld(p_norm, m) + kld(q_norm, m)) / 2 210 | return jsd 211 | 212 | 213 | def jensen_shannon_distance(p, q): 214 | divergence = jensen_shannon_divergence(p, q) 215 | return tf.math.sqrt(divergence) 216 | 217 | 218 | def chi_squared_diff(spectra, reconstructed_spectra): 219 | """Compute the Chi-Squared statistic. 220 | 221 | Args: 222 | spectra: spectral samples, assumed to be in counts 223 | reconstructed_spectra: reconstructed spectra created using a 224 | dictionary with label proportion estimates 225 | """ 226 | total_counts = tf.reduce_sum(spectra, axis=1) 227 | scaled_reconstructed_spectra = tf.multiply( 228 | reconstructed_spectra, 229 | tf.reshape(total_counts, (-1, 1)) 230 | ) 231 | 232 | diff = tf.math.subtract(spectra, scaled_reconstructed_spectra) 233 | squared_diff = tf.math.square(diff) 234 | variances = tf.clip_by_value(spectra, 1, np.inf) 235 | chi_squared = tf.math.divide(squared_diff, variances) 236 | return tf.reduce_sum(chi_squared, axis=-1) 237 | -------------------------------------------------------------------------------- /riid/losses/sparsemax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | 5 | # This code is based on TensorFlow Addons. THE ORIGINAL CODE HAS BEEN MODIFIED. 6 | # https://www.tensorflow.org/addons/ 7 | 8 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 9 | 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License.
12 | # You may obtain a copy of the License at 13 | 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # limitations under the License. 21 | """This module contains sparsemax-related functions.""" 22 | 23 | from typing import Optional 24 | 25 | import tensorflow as tf 26 | from typeguard import typechecked 27 | 28 | 29 | def sparsemax(logits, axis: int = -1) -> tf.Tensor: 30 | r"""Sparsemax activation function. 31 | 32 | For each batch \( i \), and class \( j \), 33 | compute the sparsemax activation function: 34 | 35 | $$ 36 | \mathrm{sparsemax}(x)[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0). 37 | $$ 38 | 39 | See 40 | [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification 41 | ](https://arxiv.org/abs/1602.02068). 42 | 43 | Usage: 44 | 45 | >>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]]) 46 | >>> tfa.activations.sparsemax(x) 47 | <tf.Tensor: shape=(2, 3), dtype=float32, numpy= 48 | array([[0., 0., 1.], 49 | [0., 0., 1.]], dtype=float32)> 50 | 51 | Args: 52 | logits: `Tensor` 53 | axis: `int`, axis along which the sparsemax operation is applied 54 | 55 | Returns: 56 | `Tensor`, output of sparsemax transformation (has the same type and shape as `logits`) 57 | 58 | Raises: 59 | `ValueError` when `dim(logits) == 1` 60 | """ 61 | logits = tf.convert_to_tensor(logits, name="logits") 62 | 63 | # We need its original shape for shape inference. 64 | shape = logits.get_shape() 65 | rank = shape.rank 66 | is_last_axis = (axis == -1) or (axis == rank - 1) 67 | 68 | if is_last_axis: 69 | output = _compute_2d_sparsemax(logits) 70 | output.set_shape(shape) 71 | return output 72 | 73 | # If dim is not the last dimension, we have to do a transpose so that we can 74 | # still perform sparsemax on its last dimension. 75 | 76 | # Swap logits' dimension of dim and its last dimension. 77 | rank_op = tf.rank(logits) 78 | axis_norm = axis % rank 79 | logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1)) 80 | 81 | # Do the actual sparsemax on its last dimension. 82 | output = _compute_2d_sparsemax(logits) 83 | output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1)) 84 | 85 | # Make shape inference work since transpose may erase its static shape. 86 | output.set_shape(shape) 87 | return output 88 | 89 | 90 | def _swap_axis(logits, dim_index, last_index, **kwargs): 91 | return tf.transpose( 92 | logits, 93 | tf.concat( 94 | [ 95 | tf.range(dim_index), 96 | [last_index], 97 | tf.range(dim_index + 1, last_index), 98 | [dim_index], 99 | ], 100 | 0, 101 | ), 102 | **kwargs, 103 | ) 104 | 105 | 106 | def _compute_2d_sparsemax(logits): 107 | """Perform the sparsemax operation when axis=-1.""" 108 | shape_op = tf.shape(logits) 109 | obs = tf.math.reduce_prod(shape_op[:-1]) 110 | dims = shape_op[-1] 111 | 112 | # In the paper, they call the logits z. 113 | # The mean(logits) can be subtracted from logits to make the algorithm 114 | # more numerically stable. The instability in this algorithm comes mostly 115 | # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close 116 | # to zero. However, in practice the numerical instability issues are very 117 | # minor and subtracting the mean causes extra issues with inf and nan 118 | # input.
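# Worked example (illustrative): for a single row z = [-1, 0, 1], only the largest entry is in the support, so k(z) = 1, tau(z) = (1 - 1) / 1 = 0, and the output is max(z - 0, 0) = [0, 0, 1], matching the doctest above.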
119 | # Reshape to [obs, dims] as it is almost free and means the remaining 120 | # code doesn't need to worry about the rank. 121 | z = tf.reshape(logits, [obs, dims]) 122 | 123 | # sort z 124 | z_sorted, _ = tf.nn.top_k(z, k=dims) 125 | 126 | # calculate k(z) 127 | z_cumsum = tf.math.cumsum(z_sorted, axis=-1) 128 | k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype) 129 | z_check = 1 + k * z_sorted > z_cumsum 130 | # because the z_check vector is always [1,1,...1,0,0,...0], finding the 131 | # (index + 1) of the last `1` is the same as just summing the number of 1s. 132 | k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1) 133 | 134 | # calculate tau(z) 135 | # If there are inf values or all values are -inf, the k_z will be zero, 136 | # this is mathematically invalid and will also cause the gather_nd to fail. 137 | # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then 138 | # fixed later (see p_safe) by returning p = nan. This results in the same 139 | # behavior as softmax. 140 | k_z_safe = tf.math.maximum(k_z, 1) 141 | indices = tf.stack([tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1) 142 | tau_sum = tf.gather_nd(z_cumsum, indices) 143 | tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype) 144 | 145 | # calculate p 146 | p = tf.math.maximum(tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1)) 147 | # If k_z = 0 or if z = nan, then the input is invalid 148 | p_safe = tf.where( 149 | tf.expand_dims( 150 | tf.math.logical_or(tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])), 151 | axis=-1, 152 | ), 153 | tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)), 154 | p, 155 | ) 156 | 157 | # Reshape back to original size 158 | p_safe = tf.reshape(p_safe, shape_op) 159 | return p_safe 160 | 161 | 162 | def sparsemax_loss(logits, sparsemax, labels, name: Optional[str] = None) -> tf.Tensor: 163 | r"""Sparsemax loss function ([1]). 164 | 165 | Computes the generalized multi-label classification loss for the sparsemax 166 | function. The implementation is a reformulation of the original loss 167 | function such that it uses the sparsemax probability output instead of the 168 | internal \( \tau \) variable. However, the output is identical to the original 169 | loss function. 170 | 171 | [1]: https://arxiv.org/abs/1602.02068 172 | 173 | Args: 174 | logits: `Tensor`. Must be one of the following types: `float32`, 175 | `float64`. 176 | sparsemax: `Tensor`. Must have the same type as `logits`. 177 | labels: `Tensor`. Must have the same type as `logits`. 178 | name: name for the operation (optional). 179 | 180 | Returns: 181 | A `Tensor`. Has the same type as `logits`. 182 | """ 183 | logits = tf.convert_to_tensor(logits, name="logits") 184 | sparsemax = tf.convert_to_tensor(sparsemax, name="sparsemax") 185 | labels = tf.convert_to_tensor(labels, name="labels") 186 | 187 | # In the paper, they call the logits z. 188 | # A constant can be subtracted from logits to make the algorithm 189 | # more numerically stable in theory. However, there is really no major 190 | # source of numerical instability in this algorithm. 191 | z = logits 192 | 193 | # sum over support 194 | # Use a conditional where instead of a multiplication to support z = -inf. 195 | # If z = -inf, and there is no support (sparsemax = 0), a multiplication 196 | # would cause 0 * -inf = nan, which is not correct in this case.
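# Concretely, sum_s is sparsemax * (z - 0.5 * sparsemax) wherever the sparsemax output is positive (i.e., on the support) and exactly zero elsewhere, so off-support entries never contribute to the loss.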
197 | sum_s = tf.where( 198 | tf.math.logical_or(sparsemax > 0, tf.math.is_nan(sparsemax)), 199 | sparsemax * (z - 0.5 * sparsemax), 200 | tf.zeros_like(sparsemax), 201 | ) 202 | 203 | # - z_k + ||q||^2 204 | q_part = labels * (0.5 * labels - z) 205 | # Fix the case where labels = 0 and z = -inf, where q_part would 206 | # otherwise be 0 * -inf = nan. But since the labels = 0, no cost for 207 | # z = -inf should be considered. 208 | # The code below also covers the case where z = inf. However, in this 209 | # case the sparsemax will be nan, which means the sum_s will also be nan, 210 | # therefore this case doesn't need additional special treatment. 211 | q_part_safe = tf.where( 212 | tf.math.logical_and(tf.math.equal(labels, 0), tf.math.is_inf(z)), 213 | tf.zeros_like(z), 214 | q_part, 215 | ) 216 | 217 | return tf.math.reduce_sum(sum_s + q_part_safe, axis=1) 218 | 219 | 220 | @tf.function 221 | @tf.keras.utils.register_keras_serializable(package="Addons") 222 | def sparsemax_loss_from_logits( 223 | y_true, logits_pred 224 | ) -> tf.Tensor: 225 | y_pred = sparsemax(logits_pred) 226 | loss = sparsemax_loss(logits_pred, y_pred, y_true) 227 | return loss 228 | 229 | 230 | @tf.keras.utils.register_keras_serializable(package="Addons") 231 | class SparsemaxLoss(tf.keras.losses.Loss): 232 | """Sparsemax loss function. 233 | 234 | Computes the generalized multi-label classification loss for the sparsemax 235 | function. 236 | 237 | Because the sparsemax loss function needs both the probability output and 238 | the logits to compute the loss value, `from_logits` must be `True`. 239 | 240 | Because it computes the generalized multi-label loss, the shape of both 241 | `y_pred` and `y_true` must be `[batch_size, num_classes]`. 242 | 243 | Args: 244 | from_logits: Whether `y_pred` is expected to be a logits tensor. Default 245 | is `True`, meaning `y_pred` is the logits. 246 | reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to 247 | loss. Default value is `SUM_OVER_BATCH_SIZE`. 248 | name: Optional name for the op 249 | """ 250 | 251 | @typechecked 252 | def __init__( 253 | self, 254 | from_logits: bool = True, 255 | reduction: str = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE, 256 | name: str = "sparsemax_loss", 257 | ): 258 | if from_logits is not True: 259 | raise ValueError("from_logits must be True") 260 | 261 | super().__init__(name=name, reduction=reduction) 262 | self.from_logits = from_logits 263 | 264 | def call(self, y_true, y_pred): 265 | return sparsemax_loss_from_logits(y_true, y_pred) 266 | 267 | def get_config(self): 268 | config = { 269 | "from_logits": self.from_logits, 270 | } 271 | base_config = super().get_config() 272 | return {**base_config, **config} 273 | -------------------------------------------------------------------------------- /riid/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module provides custom model metrics.""" 5 | import numpy as np 6 | import sklearn 7 | 8 | from riid import SampleSet 9 | 10 | 11 | def multi_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float: 12 | """Calculate a measure of the F1 score of two tensors. 13 | 14 | Values for `y_true` and `y_pred` are assumed to sum to 1.
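For example (illustrative): with `y_true = [0.5, 0.5, 0.0]` and `y_pred = [0.3, 0.5, 0.2]`, the false-positive mass is 0.2, so the returned value is 0.8.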
15 | 16 | Args: 17 | y_true: list of ground truth 18 | y_pred: list of predictions to compare against the ground truth 19 | 20 | Returns: 21 | Multi F1-score value(s) 22 | """ 23 | from keras.api import ops 24 | 25 | diff = y_true - y_pred 26 | negs = ops.clip(diff, -1.0, 0.0) 27 | false_positive = -ops.sum(negs, axis=-1) 28 | true_positive = 1.0 - false_positive 29 | 30 | return ops.mean(true_positive) 31 | 32 | 33 | def single_f1(y_true: np.ndarray, y_pred: np.ndarray): 34 | """Compute the weighted F1 score for the maximum prediction and maximum ground truth. 35 | 36 | Values for `y_true` and `y_pred` are assumed to sum to 1. 37 | 38 | Args: 39 | y_true: list of ground truth 40 | y_pred: list of predictions to compare against the ground truth 41 | 42 | Returns: 43 | F1-score value(s) 44 | """ 45 | import tensorflow as tf 46 | from keras.api import ops 47 | 48 | a = tf.dtypes.cast(y_true == ops.max(y_true, axis=1)[:, None], tf.float32) 49 | b = tf.dtypes.cast(y_pred == ops.max(y_pred, axis=1)[:, None], tf.float32) 50 | 51 | TP_mat = tf.dtypes.cast(ops.all(tf.stack([a, b]), axis=0), tf.float32) 52 | FP_mat = tf.dtypes.cast(ops.all(tf.stack([a != b, b == 1]), axis=0), tf.float32) 53 | FN_mat = tf.dtypes.cast(ops.all(tf.stack([a != b, a == 1]), axis=0), tf.float32) 54 | 55 | TPs = ops.sum(TP_mat, axis=0) 56 | FPs = ops.sum(FP_mat, axis=0) 57 | FNs = ops.sum(FN_mat, axis=0) 58 | 59 | F1s = 2 * TPs / (2*TPs + FNs + FPs + tf.fill(tf.shape(TPs), tf.keras.backend.epsilon())) 60 | 61 | support = ops.sum(a, axis=0) 62 | f1 = ops.sum(F1s * support) / ops.sum(support) 63 | return f1 64 | 65 | 66 | def harmonic_mean(x, y): 67 | """Compute the harmonic mean of two same-dimensional arrays. 68 | 69 | Used to compute F1 score: 70 | 71 | ``` 72 | f1_score = harmonic_mean(precision, recall) 73 | ``` 74 | 75 | Args: 76 | x (array-like): numeric or array_like of numerics 77 | y (array-like): numeric or array_like of numerics matching the shape/type of `x` 78 | 79 | Returns: 80 | Array-like harmonic mean of `x` and `y` 81 | """ 82 | return 2 * x * y / (x + y) 83 | 84 | 85 | def precision_recall_curve(ss: SampleSet, smooth: bool = True, multiclass: bool = None, 86 | include_micro: bool = True, target_level: str = "Isotope", 87 | minimum_contribution: float = 0.01): 88 | """Similar to `sklearn.metrics.precision_recall_curve`; however, this function 89 | computes the precision and recall for each class, and supports both multi-class 90 | and multi-label problems. 91 | 92 | The reason this is necessary is that in multi-class problems, for a single sample, 93 | all predictions are discarded except for the argmax. 94 | 95 | Args: 96 | ss: `SampleSet` that predictions were generated on 97 | smooth: if True, smoothing is applied to make the precision function 98 | monotonically decreasing 99 | multiclass: set to True if this is a multi-class (i.e. y_true is one-hot) as 100 | opposed to multi-label (i.e. labels are not mutually exclusive). If True, 101 | predictions will be masked such that non-argmax predictions are set to zero 102 | (this prevents inflating the precision by continuing past a point that could 103 | be pragmatically useful). Furthermore, in the multiclass case the recall is 104 | not guaranteed to reach 1.0.
105 | include_micro: if True, compute an additional precision and recall for the 106 | micro-average across all labels and put it under entry `"micro"` 107 | target_level: `SampleSet.sources` and `SampleSet.prediction_probas` column level to use 108 | minimum_contribution: threshold for a source to be considered a ground truth positive 109 | label. If this is set to `None`, the raw mixture ratios will be used as y_true. 110 | 111 | Returns: 112 | precision (dict): dict with keys for each label and values that are the 113 | monotonically increasing precision values at each threshold 114 | recall (dict): dict with keys for each label and values that are the 115 | monotonically decreasing recall values at each threshold 116 | thresholds (dict): dict with keys for each label and values that are the 117 | monotonically increasing thresholds on the decision function used to compute 118 | precision and recall 119 | 120 | References: 121 | - [Precision smoothing]( 122 | https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173) 123 | 124 | """ 125 | y_true = ss.sources.T.groupby(target_level, sort=False).sum().T 126 | if minimum_contribution is not None: 127 | y_true = (y_true > minimum_contribution).astype(int) 128 | y_pred = ss.prediction_probas.T.groupby(target_level, sort=False).sum().T 129 | 130 | # switch from pandas to numpy 131 | labels = y_true.columns 132 | n_classes = len(labels) 133 | y_true = y_true.values.copy() 134 | y_pred = y_pred.values.copy() 135 | 136 | if y_pred.shape != y_true.shape: 137 | raise ValueError( 138 | f"Shape mismatch between truth and predictions, " 139 | f"{y_true.shape} != {y_pred.shape}. " 140 | f"It is possible that the `target_level` is incorrect." 141 | ) 142 | 143 | if multiclass is None: 144 | # infer whether multi-class or multi-label 145 | multiclass = not np.any(y_true.sum(axis=1) != 1) 146 | 147 | # drop nans 148 | notnan = ~np.isnan(y_pred).any(axis=1) 149 | y_true = y_true[notnan, :] 150 | y_pred = y_pred[notnan, :] 151 | 152 | y_pred_min = None 153 | if multiclass: 154 | # shift predictions to force positive 155 | if np.any(y_pred < 0): 156 | y_pred_min = y_pred.min() 157 | y_pred -= y_pred_min 158 | 159 | # mask the predictions by argmax 160 | pred_mask = np.eye(n_classes)[np.argmax(y_pred, axis=1)] 161 | # mask the predictions 162 | y_pred *= pred_mask 163 | 164 | precision = dict() 165 | recall = dict() 166 | thresholds = dict() 167 | for i, label in enumerate(labels): 168 | precision[label], recall[label], thresholds[label] = _pr_curve( 169 | y_true[:, i], y_pred[:, i], multiclass=multiclass, smooth=smooth 170 | ) 171 | 172 | if include_micro: 173 | # A "micro-average": quantifying score on all classes jointly 174 | precision["micro"], recall["micro"], thresholds["micro"] = _pr_curve( 175 | y_true.ravel(), y_pred.ravel(), multiclass=multiclass, smooth=smooth 176 | ) 177 | 178 | # un-shift thresholds if predictions were shifted 179 | if y_pred_min is not None: 180 | thresholds = {k: v + y_pred_min for k, v in thresholds.items()} 181 | 182 | return precision, recall, thresholds 183 | 184 | 185 | def average_precision_score(precision, recall): 186 | """Compute the average precision (area under the curve) for each precision/recall 187 | pair.
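The area is computed via left-endpoint rectangular integration of each precision/recall pair (see `_integrate` below).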
188 | 189 | Args: 190 | precision (dict): return value of `precision_recall_curve()` 191 | recall (dict): return value of `precision_recall_curve()` 192 | 193 | Returns: 194 | (dict): average precision values (float) for each label in precision/recall 195 | """ 196 | return {label: _integrate(recall[label], precision[label]) for label in recall} 197 | 198 | 199 | def _step(x): 200 | """Compute the right-going maximum of `x` and all previous values of `x`. 201 | 202 | Args: 203 | x (array-like): 1D array to process 204 | 205 | Returns: 206 | (array-like): right-going maximum of `x` 207 | 208 | """ 209 | y = np.array(x) 210 | for i in range(1, len(y)): 211 | y[i] = max(y[i], y[i - 1]) 212 | return y 213 | 214 | 215 | def _integrate(x, y, y_left=True): 216 | """Integrate an (x, y) function pair. 217 | 218 | Args: 219 | x (array-like): 1D array of x values 220 | y (array-like): 1D array of y values 221 | y_left: if True, omit the last value of y; else, omit the first value 222 | 223 | Returns: 224 | (float): integrated "area under the curve" 225 | 226 | """ 227 | delta_x = x[1:] - x[:-1] 228 | y_trimmed = y[:-1] if y_left else y[1:] 229 | return np.abs(np.sum(delta_x * y_trimmed)) 230 | 231 | 232 | def _pr_curve(y_true, y_pred, multiclass, smooth): 233 | precision, recall, thresholds = sklearn.metrics.precision_recall_curve( 234 | y_true, y_pred 235 | ) 236 | 237 | if smooth: 238 | precision = _step(precision) 239 | 240 | if multiclass: 241 | # remove the point where threshold=0 and recall=1 242 | precision = precision[1:] 243 | recall = recall[1:] 244 | thresholds = thresholds[1:] 245 | 246 | return precision, recall, thresholds 247 | 248 | 249 | def build_keras_semisupervised_metric_func(keras_metric_func, activation_func, 250 | n_labels): 251 | def metric_func(y_true, y_pred): 252 | return keras_metric_func(y_true[:, :n_labels], activation_func(y_pred)) 253 | metric_func.__name__ = keras_metric_func.__class__.__name__ 254 | 255 | return metric_func 256 | -------------------------------------------------------------------------------- /riid/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module contains PyRIID models.""" 5 | from riid.models.bayes import PoissonBayesClassifier 6 | from riid.models.neural_nets import LabelProportionEstimator, MLPClassifier 7 | from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv1, ARADv2 8 | 9 | __all__ = ["PoissonBayesClassifier", "LabelProportionEstimator", "MLPClassifier", 10 | "ARADLatentPredictor", "ARADv1", "ARADv2"] 11 | -------------------------------------------------------------------------------- /riid/models/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains functionality shared across all PyRIID models.""" 5 | import json 6 | import os 7 | from pathlib import Path 8 | import uuid 9 | from abc import abstractmethod 10 | from enum import Enum 11 | 12 | import numpy as np 13 | import tensorflow as tf 14 | import tf2onnx 15 | from keras.api.models import Model 16 | from keras.api.utils import get_custom_objects 17 | 18 | import riid 19 | from riid import SampleSet, SpectraState 20 | from riid.data.labeling import label_to_index_element 21 | from riid.losses import mish 22 | from riid.metrics import multi_f1, single_f1 23 | 24 | get_custom_objects().update({ 25 | "multi_f1": multi_f1, 26 | "single_f1": single_f1, 27 | "mish": mish, 28 | }) 29 | 30 | 31 | class ModelInput(int, Enum): 32 | """Enumerates the potential input sources for a model.""" 33 | GrossSpectrum = 0 34 | BackgroundSpectrum = 1 35 | ForegroundSpectrum = 2 36 | 37 | 38 | class PyRIIDModel: 39 | """Base class for PyRIID models.""" 40 | 41 | def __init__(self, *args, **kwargs): 42 | self._info = {} 43 | self._temp_file_path = "temp_model.json" 44 | self._custom_objects = {} 45 | self._initialize_info() 46 | 47 | @property 48 | def seeds(self): 49 | return self._info["seeds"] 50 | 51 | @seeds.setter 52 | def seeds(self, value): 53 | self._info["seeds"] = value 54 | 55 | @property 56 | def info(self): 57 | return self._info 58 | 59 | @info.setter 60 | def info(self, value): 61 | self._info = value 62 | 63 | @property 64 | def target_level(self): 65 | return self._info["target_level"] 66 | 67 | @target_level.setter 68 | def target_level(self, value): 69 | if value in SampleSet.SOURCES_MULTI_INDEX_NAMES: 70 | self._info["target_level"] = value 71 | else: 72 | msg = ( 73 | f"Target level '{value}' is invalid. " 74 | f"Acceptable levels: {SampleSet.SOURCES_MULTI_INDEX_NAMES}" 75 | ) 76 | raise ValueError(msg) 77 | 78 | @property 79 | def model(self) -> Model: 80 | return self._model 81 | 82 | @model.setter 83 | def model(self, value: Model): 84 | self._model = value 85 | 86 | @property 87 | def model_id(self): 88 | return self._info["model_id"] 89 | 90 | @model_id.setter 91 | def model_id(self, value): 92 | self._info["model_id"] = value 93 | 94 | @property 95 | def model_inputs(self): 96 | return self._info["model_inputs"] 97 | 98 | @model_inputs.setter 99 | def model_inputs(self, value): 100 | self._info["model_inputs"] = value 101 | 102 | @property 103 | def model_outputs(self): 104 | return self._info["model_outputs"] 105 | 106 | @model_outputs.setter 107 | def model_outputs(self, value): 108 | self._info["model_outputs"] = value 109 | 110 | def get_model_outputs_as_label_tuples(self): 111 | return [ 112 | label_to_index_element(v, self.target_level) for v in self.model_outputs 113 | ] 114 | 115 | def _get_model_dict(self) -> dict: 116 | model_json = self.model.to_json() 117 | model_dict = json.loads(model_json) 118 | model_weights = self.model.get_weights() 119 | model_dict = { 120 | "info": self._info, 121 | "model": model_dict, 122 | "weights": model_weights, 123 | } 124 | return model_dict 125 | 126 | def _get_model_str(self) -> str: 127 | model_dict = self._get_model_dict() 128 | model_str = json.dumps(model_dict, indent=4, cls=PyRIIDModelJsonEncoder) 129 | return model_str 130 | 131 | def _initialize_info(self): 132 | init_info = { 133 | "model_id": str(uuid.uuid4()), 134 | "model_type": self.__class__.__name__, 135 | "normalization": SpectraState.Unknown, 136 | "pyriid_version": riid.__version__, 137 | } 138 | self._update_info(**init_info) 139 | 140 
| def _update_info(self, **kwargs): 141 | self._info.update(kwargs) 142 | 143 | def _update_custom_objects(self, key, value): 144 | self._custom_objects.update({key: value}) 145 | 146 | def load(self, model_path: str): 147 | """Load the model from a path. 148 | 149 | Args: 150 | model_path: path from which to load the model. 151 | """ 152 | if not os.path.exists(model_path): 153 | raise ValueError("Model file does not exist.") 154 | 155 | with open(model_path) as fin: 156 | model = json.load(fin) 157 | 158 | model_str = json.dumps(model["model"]) 159 | self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects) 160 | self.model.set_weights([np.array(x) for x in model["weights"]]) 161 | self.info = model["info"] 162 | 163 | def save(self, model_path: str, overwrite=False): 164 | """Save the model to a path. 165 | 166 | Args: 167 | model_path: path at which to save the model. 168 | overwrite: whether to overwrite an existing file if it already exists. 169 | 170 | Raises: 171 | `ValueError` when the given path already exists 172 | """ 173 | if os.path.exists(model_path) and not overwrite: 174 | raise ValueError("Model file already exists.") 175 | 176 | model_str = self._get_model_str() 177 | with open(model_path, "w") as fout: 178 | fout.write(model_str) 179 | 180 | def to_onnx(self, model_path, **tf2onnx_kwargs: dict): 181 | """Convert the model to an ONNX model. 182 | 183 | Args: 184 | model_path: path at which to save the model 185 | tf2onnx_kwargs: additional kwargs to pass to the conversion 186 | """ 187 | model_path = Path(model_path) 188 | if not str(model_path).endswith(riid.ONNX_MODEL_FILE_EXTENSION): 189 | raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}") 190 | if model_path.exists(): 191 | raise ValueError("Model file already exists.") 192 | 193 | tf2onnx.convert.from_keras( 194 | self.model, 195 | input_signature=[ 196 | tf.TensorSpec( 197 | shape=input_tensor.shape, 198 | dtype=input_tensor.dtype, 199 | name=input_tensor.name 200 | ) 201 | for input_tensor in self.model.inputs 202 | ], 203 | output_path=str(model_path), 204 | **tf2onnx_kwargs 205 | ) 206 | 207 | def to_tflite(self, model_path, quantize: bool = False, prune: bool = False): 208 | """Convert the model to a TFLite model, optionally applying quantization or pruning. 209 | 210 | Note: the model is first exported to the SavedModel format, then converted to TFLite.
211 | 212 | Args: 213 | model_path: file path at which to save the model 214 | quantize: whether to apply quantization 215 | prune: whether to apply pruning 216 | """ 217 | model_path = Path(model_path) 218 | if not str(model_path).endswith(riid.TFLITE_MODEL_FILE_EXTENSION): 219 | raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}") 220 | if model_path.exists(): 221 | raise ValueError("Model file already exists.") 222 | 223 | optimizations = [] 224 | if quantize: 225 | optimizations.append(tf.lite.Optimize.DEFAULT) 226 | if prune: 227 | optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY) 228 | 229 | saved_model_dir = model_path.stem 230 | self.model.export(saved_model_dir) 231 | converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir)) 232 | converter.optimizations = optimizations 233 | tflite_model = converter.convert() 234 | 235 | with open(model_path, "wb") as fout: 236 | fout.write(tflite_model) 237 | 238 | @abstractmethod 239 | def fit(self): 240 | pass 241 | 242 | @abstractmethod 243 | def predict(self): 244 | pass 245 | 246 | 247 | class PyRIIDModelJsonEncoder(json.JSONEncoder): 248 | """Custom JSON encoder for saving models. 249 | """ 250 | def default(self, o): 251 | """Converts certain types to JSON-compatible types. 252 | """ 253 | if isinstance(o, np.ndarray): 254 | return o.tolist() 255 | elif isinstance(o, np.float32): 256 | return o.astype(float) 257 | 258 | return super().default(o) 259 | -------------------------------------------------------------------------------- /riid/models/bayes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module contains the Poisson-Bayes classifier.""" 5 | import numpy as np 6 | import pandas as pd 7 | import tensorflow as tf 8 | from keras.api.layers import Add, Input, Multiply, Subtract 9 | from keras.api.models import Model 10 | 11 | from riid import SampleSet 12 | from riid.models.base import PyRIIDModel 13 | from riid.models.layers import (ClipByValueLayer, DivideLayer, ExpandDimsLayer, 14 | PoissonLogProbabilityLayer, ReduceMaxLayer, 15 | ReduceSumLayer, SeedLayer) 16 | 17 | 18 | class PoissonBayesClassifier(PyRIIDModel): 19 | """Classifier calculating the conditional Poisson log probability of each seed spectrum 20 | given the measurement. 21 | 22 | This implementation is an adaptation of a naive Bayes classifier, a formal description of 23 | which can be found in ESLII: 24 | 25 | Hastie, Trevor, et al. The elements of statistical learning: data mining, inference, and 26 | prediction. Vol. 2. New York: Springer, 2009. 27 | 28 | For this model, each spectrum channel is treated as a Poisson random variable and 29 | expectations are provided by the user in the form of seeds rather than learned. 30 | Like the model described in ESLII, all classes are considered equally likely and features 31 | are assumed to be conditionally independent.
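A minimal usage sketch (illustrative; assumes `seeds_ss`, `gross_ss`, and `bg_ss` are existing `SampleSet`s): `model = PoissonBayesClassifier()`, then `model.fit(seeds_ss)`, then `model.predict(gross_ss, bg_ss)`.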
32 | """ 33 | def __init__(self): 34 | super().__init__() 35 | 36 | self._update_custom_objects("ReduceSumLayer", ReduceSumLayer) 37 | self._update_custom_objects("ReduceMaxLayer", ReduceMaxLayer) 38 | self._update_custom_objects("DivideLayer", DivideLayer) 39 | self._update_custom_objects("ExpandDimsLayer", ExpandDimsLayer) 40 | self._update_custom_objects("ClipByValueLayer", ClipByValueLayer) 41 | self._update_custom_objects("PoissonLogProbabilityLayer", PoissonLogProbabilityLayer) 42 | self._update_custom_objects("SeedLayer", SeedLayer) 43 | 44 | def fit(self, seeds_ss: SampleSet): 45 | """Construct a TF-based implementation of a poisson-bayes classifier in terms 46 | of the given seeds. 47 | 48 | Args: 49 | seeds_ss: `SampleSet` of `n` foreground seed spectra where `n` >= 1. 50 | 51 | Raises: 52 | - `ValueError` when no seeds are provided 53 | - `NegativeSpectrumError` when any seed spectrum has negative counts in any bin 54 | - `ZeroTotalCountsError` when any seed spectrum contains zero total counts 55 | """ 56 | if seeds_ss.n_samples <= 0: 57 | raise ValueError("Argument 'seeds_ss' must contain at least one seed.") 58 | if (seeds_ss.spectra.values < 0).any(): 59 | msg = "Argument 'seeds_ss' can't contain any spectra with negative values." 60 | raise NegativeSpectrumError(msg) 61 | if (seeds_ss.spectra.values.sum(axis=1) <= 0).any(): 62 | msg = "Argument 'seeds_ss' can't contain any spectra with zero total counts." 63 | raise ZeroTotalCountsError(msg) 64 | 65 | self._seeds = tf.convert_to_tensor( 66 | seeds_ss.spectra.values, 67 | dtype=tf.float32 68 | ) 69 | 70 | # Inputs 71 | gross_spectrum_input = Input(shape=(seeds_ss.n_channels,), 72 | name="gross_spectrum") 73 | gross_live_time_input = Input(shape=(), 74 | name="gross_live_time") 75 | bg_spectrum_input = Input(shape=(seeds_ss.n_channels,), 76 | name="bg_spectrum") 77 | bg_live_time_input = Input(shape=(), 78 | name="bg_live_time") 79 | model_inputs = ( 80 | gross_spectrum_input, 81 | gross_live_time_input, 82 | bg_spectrum_input, 83 | bg_live_time_input, 84 | ) 85 | 86 | # Input statistics 87 | gross_total_counts = ReduceSumLayer(name="gross_total_counts")(gross_spectrum_input, axis=1) 88 | bg_total_counts = ReduceSumLayer(name="bg_total_counts")(bg_spectrum_input, axis=1) 89 | bg_count_rate = DivideLayer(name="bg_count_rate")([bg_total_counts, bg_live_time_input]) 90 | 91 | gross_spectrum_input_expanded = ExpandDimsLayer( 92 | name="gross_spectrum_input_expanded" 93 | )(gross_spectrum_input, axis=1) 94 | bg_total_counts_expanded = ExpandDimsLayer( 95 | name="bg_total_counts_expanded" 96 | )(bg_total_counts, axis=1) 97 | 98 | # Expectations 99 | seed_layer = SeedLayer(self._seeds)(model_inputs) 100 | seed_layer_expanded = ExpandDimsLayer()(seed_layer, axis=0) 101 | expected_bg_counts = Multiply( 102 | trainable=False, 103 | name="expected_bg_counts" 104 | )([bg_count_rate, gross_live_time_input]) 105 | expected_bg_counts_expanded = ExpandDimsLayer( 106 | name="expected_bg_counts_expanded" 107 | )(expected_bg_counts, axis=1) 108 | normalized_bg_spectrum = DivideLayer( 109 | name="normalized_bg_spectrum" 110 | )([bg_spectrum_input, bg_total_counts_expanded]) 111 | expected_bg_spectrum = Multiply( 112 | trainable=False, 113 | name="expected_bg_spectrum" 114 | )([normalized_bg_spectrum, expected_bg_counts_expanded]) 115 | expected_fg_counts = Subtract( 116 | trainable=False, 117 | name="expected_fg_counts" 118 | )([gross_total_counts, expected_bg_counts]) 119 | expected_fg_counts_expanded = ExpandDimsLayer( 120 | 
name="expected_fg_counts_expanded" 121 | )(expected_fg_counts, axis=-1) 122 | expected_fg_counts_expanded2 = ExpandDimsLayer( 123 | name="expected_fg_counts_expanded2" 124 | )(expected_fg_counts_expanded, axis=-1) 125 | expected_fg_spectrum = Multiply( 126 | trainable=False, 127 | name="expected_fg_spectrum" 128 | )([seed_layer_expanded, expected_fg_counts_expanded2]) 129 | max_fg_value = ReduceMaxLayer( 130 | name="max_fg_value" 131 | )(expected_fg_spectrum) 132 | expected_fg_spectrum = ClipByValueLayer( 133 | name="clip_expected_fg_spectrum" 134 | )(expected_fg_spectrum, clip_value_min=1e-8, clip_value_max=max_fg_value) 135 | expected_bg_spectrum_expanded = ExpandDimsLayer( 136 | name="expected_bg_spectrum_expanded" 137 | )(expected_bg_spectrum, axis=1) 138 | expected_gross_spectrum = Add( 139 | trainable=False, 140 | name="expected_gross_spectrum" 141 | )([expected_fg_spectrum, expected_bg_spectrum_expanded]) 142 | 143 | # Compute probabilities 144 | log_probabilities = PoissonLogProbabilityLayer( 145 | name="log_probabilities" 146 | )([expected_gross_spectrum, gross_spectrum_input_expanded]) 147 | summed_log_probabilities = ReduceSumLayer( 148 | name="summed_log_probabilities" 149 | )(log_probabilities, axis=2) 150 | 151 | # Assemble model 152 | self.model = Model(model_inputs, summed_log_probabilities) 153 | self.model.compile() 154 | 155 | self.target_level = "Seed" 156 | sources_df = seeds_ss.sources.T.groupby(self.target_level, sort=False).sum().T 157 | self.model_outputs = sources_df.columns.values.tolist() 158 | 159 | def predict(self, gross_ss: SampleSet, bg_ss: SampleSet, 160 | normalize_scores: bool = False, verbose: bool = False): 161 | """Compute the conditional Poisson log probability between spectra in a `SampleSet` and 162 | the seeds to which the model was fit. 163 | 164 | Args: 165 | gross_ss: `SampleSet` of `n` gross spectra where `n` >= 1 166 | bg_ss: `SampleSet` of `n` background spectra where `n` >= 1 167 | normalize_scores (bool): whether to normalize prediction probabilities 168 | When True, this makes the probabilities positive and rescales them 169 | by the minimum value present in given the dataset. 170 | While this can be helpful in terms of visualizing probabilities in log scale, 171 | it can adversely affects one's ability to detect significantly anomalous signatures. 
172 | """ 173 | gross_spectra = tf.convert_to_tensor(gross_ss.spectra.values, dtype=tf.float32) 174 | gross_lts = tf.convert_to_tensor(gross_ss.info.live_time.values, dtype=tf.float32) 175 | bg_spectra = tf.convert_to_tensor(bg_ss.spectra.values, dtype=tf.float32) 176 | bg_lts = tf.convert_to_tensor(bg_ss.info.live_time.values, dtype=tf.float32) 177 | 178 | prediction_probas = self.model.predict(( 179 | gross_spectra, gross_lts, bg_spectra, bg_lts 180 | ), batch_size=512, verbose=verbose) 181 | 182 | # Normalization 183 | if normalize_scores: 184 | rows_min = np.min(prediction_probas, axis=1) 185 | prediction_probas = prediction_probas - rows_min[:, np.newaxis] 186 | 187 | gross_ss.prediction_probas = pd.DataFrame( 188 | prediction_probas, 189 | columns=pd.MultiIndex.from_tuples( 190 | self.get_model_outputs_as_label_tuples(), 191 | names=SampleSet.SOURCES_MULTI_INDEX_NAMES 192 | ) 193 | ) 194 | 195 | 196 | class ZeroTotalCountsError(ValueError): 197 | """All spectrum channels are zero.""" 198 | pass 199 | 200 | 201 | class NegativeSpectrumError(ValueError): 202 | """At least one spectrum channel is negative.""" 203 | pass 204 | -------------------------------------------------------------------------------- /riid/models/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module contains custom Keras layers.""" 5 | import tensorflow as tf 6 | from keras.api.layers import Layer 7 | 8 | 9 | class ReduceSumLayer(Layer): 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | 13 | def call(self, x, axis): 14 | return tf.reduce_sum(x, axis=axis) 15 | 16 | 17 | class ReduceMaxLayer(Layer): 18 | def __init__(self, **kwargs): 19 | super().__init__(**kwargs) 20 | 21 | def call(self, x): 22 | return tf.reduce_max(x) 23 | 24 | 25 | class DivideLayer(Layer): 26 | def __init__(self, **kwargs): 27 | super().__init__(**kwargs) 28 | 29 | def call(self, x): 30 | return tf.divide(x[0], x[1]) 31 | 32 | 33 | class ExpandDimsLayer(Layer): 34 | def __init__(self, **kwargs): 35 | super().__init__(**kwargs) 36 | 37 | def call(self, x, axis): 38 | return tf.expand_dims(x, axis=axis) 39 | 40 | 41 | class ClipByValueLayer(Layer): 42 | def __init__(self, **kwargs): 43 | super().__init__(**kwargs) 44 | 45 | def call(self, x, clip_value_min, clip_value_max): 46 | return tf.clip_by_value(x, clip_value_min=clip_value_min, clip_value_max=clip_value_max) 47 | 48 | 49 | class PoissonLogProbabilityLayer(Layer): 50 | def __init__(self, **kwargs): 51 | super().__init__(**kwargs) 52 | 53 | def call(self, x): 54 | exp, value = x 55 | log_probas = tf.math.xlogy(value, exp) - exp - tf.math.lgamma(value + 1) 56 | return log_probas 57 | 58 | 59 | class SeedLayer(Layer): 60 | def __init__(self, seeds, **kwargs): 61 | super(SeedLayer, self).__init__(**kwargs) 62 | self.seeds = tf.convert_to_tensor(seeds) 63 | 64 | def get_config(self): 65 | config = super().get_config() 66 | config.update({ 67 | "seeds": self.seeds.numpy().tolist(), 68 | }) 69 | return config 70 | 71 | def call(self, inputs): 72 | return self.seeds 73 | 74 | 75 | class L1NormLayer(Layer): 76 | def __init__(self, **kwargs): 77 | super().__init__(**kwargs) 78 | 79 | def call(self, inputs): 80 | sums = tf.reduce_sum(inputs, axis=-1) 81 | l1_norm = inputs / tf.reshape(sums, (-1, 1)) 82 | return 
l1_norm
83 | 
--------------------------------------------------------------------------------
/riid/models/neural_nets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains neural network-based classifiers and regressors."""
5 | from riid.models.neural_nets.basic import MLPClassifier
6 | from riid.models.neural_nets.lpe import LabelProportionEstimator
7 | 
8 | __all__ = ["LabelProportionEstimator", "MLPClassifier"]
9 | 
--------------------------------------------------------------------------------
/riid/models/neural_nets/basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains a simple neural network."""
5 | import pandas as pd
6 | import tensorflow as tf
7 | from keras.api.callbacks import EarlyStopping
8 | from keras.api.layers import Dense, Dropout, Input
9 | from keras.api.losses import CategoricalCrossentropy
10 | from keras.api.metrics import F1Score, Precision, Recall
11 | from keras.api.models import Model
12 | from keras.api.optimizers import Adam
13 | from keras.api.regularizers import l1, l2
14 | from keras.api.utils import split_dataset
15 | 
16 | from riid import SampleSet, SpectraType
17 | from riid.models.base import ModelInput, PyRIIDModel
18 | 
19 | 
20 | class MLPClassifier(PyRIIDModel):
21 |     """Multi-layer perceptron classifier."""
22 |     def __init__(self, activation=None, loss=None, optimizer=None,
23 |                  metrics=None, l2_alpha: float = 1e-4,
24 |                  activity_regularizer=None, final_activation=None,
25 |                  dense_layer_size=None, dropout=None):
26 |         """
27 |         Args:
28 |             activation: activation function to use for each dense layer
29 |             loss: loss function to use for training
30 |             optimizer: tensorflow optimizer or optimizer name to use for training
31 |             metrics: list of metrics to be evaluated during training
32 |             l2_alpha: alpha value for the L2 regularization of each dense layer
33 |             activity_regularizer: regularizer function applied to each dense layer output
34 |             final_activation: final activation function to apply to model output
                dense_layer_size: number of neurons in the hidden dense layer;
                    defaults to half the number of input channels when None
                dropout: dropout rate applied after the hidden dense layer;
                    no dropout is applied when None
35 |         """
36 |         super().__init__()
37 | 
38 |         self.activation = activation
39 |         self.loss = loss
40 |         self.optimizer = optimizer
41 |         self.metrics = metrics
42 |         self.l2_alpha = l2_alpha
43 |         self.activity_regularizer = activity_regularizer
44 |         self.final_activation = final_activation
45 |         self.dense_layer_size = dense_layer_size
46 |         self.dropout = dropout
47 | 
48 |         # Fill in defaults for anything not provided
49 |         if self.activation is None:
50 |             self.activation = "relu"
51 |         if self.loss is None:
52 |             self.loss = CategoricalCrossentropy()
53 |         if self.optimizer is None:
54 |             self.optimizer = Adam(learning_rate=0.01, clipnorm=0.001)
55 |         if self.metrics is None:
56 |             self.metrics = [F1Score(), Precision(), Recall()]
57 |         if self.activity_regularizer is None:
58 |             self.activity_regularizer = l1(0.0)
59 |         if self.final_activation is None:
60 |             self.final_activation = "softmax"
61 | 
62 |         self.model = None
63 |         self._set_predict_fn()
64 | 
65 |     def fit(self, ss: SampleSet, batch_size: int = 200, epochs: int = 20,
66 |             validation_split: float = 0.2, callbacks=None,
67 |             patience: int = 15, es_monitor: str = "val_loss",
68 |             es_mode: str = "min", es_verbose=0, target_level="Isotope", verbose: bool = False):
69 |         """Fit a model to the given `SampleSet`.
70 | 
71 |         Args:
72 |             ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either
73 |                 foreground (AKA, "net") or gross.
74 |             batch_size: number of samples per gradient update
75 |             epochs: maximum number of training iterations
76 |             validation_split: fraction of the training data to use as validation data
77 |             callbacks: list of callbacks to be passed to the TensorFlow `Model.fit()` method
78 |             patience: number of epochs without improvement after which `EarlyStopping` halts training
79 |             es_monitor: quantity to be monitored for `EarlyStopping` object
80 |             es_mode: mode for `EarlyStopping` object
81 |             es_verbose: verbosity level for `EarlyStopping` object
82 |             target_level: `SampleSet.sources` column level to use
83 |             verbose: whether to show detailed model training output
84 | 
85 |         Returns:
86 |             `keras.callbacks.History` object.
87 | 
88 |         Raises:
89 |             `ValueError` when no spectra are provided as input
90 |         """
91 |         if ss.n_samples <= 0:
92 |             raise ValueError("No spectra provided!")
93 | 
94 |         if ss.spectra_type == SpectraType.Gross:
95 |             self.model_inputs = (ModelInput.GrossSpectrum,)
96 |         elif ss.spectra_type == SpectraType.Foreground:
97 |             self.model_inputs = (ModelInput.ForegroundSpectrum,)
98 |         elif ss.spectra_type == SpectraType.Background:
99 |             self.model_inputs = (ModelInput.BackgroundSpectrum,)
100 |         else:
101 |             raise ValueError(f"{ss.spectra_type} is not supported in this model.")
102 | 
103 |         X = ss.get_samples()
104 |         source_contributions_df = ss.sources.T.groupby(target_level, sort=False).sum().T
105 |         model_outputs = source_contributions_df.columns.values.tolist()
106 |         Y = source_contributions_df.values
107 | 
108 |         spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32)
109 |         labels_tensor = tf.convert_to_tensor(Y, dtype=tf.float32)
110 |         training_dataset = tf.data.Dataset.from_tensor_slices((spectra_tensor, labels_tensor))
111 |         training_dataset, validation_dataset = split_dataset(
112 |             training_dataset,
113 |             right_size=validation_split,
114 |             shuffle=True
115 |         )
116 |         training_dataset = training_dataset.batch(batch_size=batch_size)
117 |         validation_dataset = validation_dataset.batch(batch_size=batch_size)
118 | 
119 |         if not self.model:
120 |             inputs = Input(shape=(X.shape[1],), name="Spectrum")
121 |             if self.dense_layer_size is None:
122 |                 dense_layer_size = X.shape[1] // 2
123 |             else:
124 |                 dense_layer_size = self.dense_layer_size
125 |             dense_layer = Dense(
126 |                 dense_layer_size,
127 |                 activation=self.activation,
128 |                 activity_regularizer=self.activity_regularizer,
129 |                 kernel_regularizer=l2(self.l2_alpha),
130 |             )(inputs)
131 |             if self.dropout is not None:
132 |                 last_layer = Dropout(self.dropout)(dense_layer)
133 |             else:
134 |                 last_layer = dense_layer
135 |             outputs = Dense(Y.shape[1], activation=self.final_activation)(last_layer)
136 |             self.model = Model(inputs, outputs)
137 |             self.model.compile(loss=self.loss, optimizer=self.optimizer,
138 |                                metrics=self.metrics)
139 | 
140 |         es = EarlyStopping(
141 |             monitor=es_monitor,
142 |             patience=patience,
143 |             verbose=es_verbose,
144 |             restore_best_weights=True,
145 |             mode=es_mode,
146 |         )
147 |         if callbacks:
148 |             callbacks.append(es)
149 |         else:
150 |             callbacks = [es]
151 | 
152 |         history = self.model.fit(
153 |             training_dataset,
154 |             epochs=epochs,
155 |             verbose=verbose,
156 | 
validation_data=validation_dataset, 157 | callbacks=callbacks, 158 | ) 159 | 160 | # Update model information 161 | self._update_info( 162 | target_level=target_level, 163 | model_outputs=model_outputs, 164 | normalization=ss.spectra_state, 165 | ) 166 | 167 | # Define the predict function with tf.function and input_signature 168 | self._set_predict_fn() 169 | 170 | return history 171 | 172 | def _set_predict_fn(self): 173 | self._predict_fn = tf.function( 174 | self._predict, 175 | experimental_relax_shapes=True 176 | ) 177 | 178 | def _predict(self, input_tensor): 179 | return self.model(input_tensor, training=False) 180 | 181 | def predict(self, ss: SampleSet, bg_ss: SampleSet = None): 182 | """Classify the spectra in the provided `SampleSet`(s). 183 | 184 | Results are stored inside the first SampleSet's prediction-related properties. 185 | 186 | Args: 187 | ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either 188 | foreground (AKA, "net") or gross 189 | bg_ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are background 190 | """ 191 | x_test = ss.get_samples().astype(float) 192 | if bg_ss: 193 | X = [x_test, bg_ss.get_samples().astype(float)] 194 | else: 195 | X = x_test 196 | 197 | spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32) 198 | results = self._predict_fn(spectra_tensor) 199 | 200 | col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) 201 | col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] 202 | ss.prediction_probas = pd.DataFrame( 203 | data=results, 204 | columns=pd.MultiIndex.from_tuples( 205 | self.get_model_outputs_as_label_tuples(), 206 | names=col_level_subset 207 | ) 208 | ) 209 | 210 | ss.classified_by = self.model_id 211 | -------------------------------------------------------------------------------- /run_tests.ps1: -------------------------------------------------------------------------------- 1 | coverage run -m unittest discover -s tests/ -p *.py -v 2 | if ($LASTEXITCODE -ne 0) { throw "Tests failed!" } 3 | coverage report -i 4 | coverage xml -i 5 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | coverage run --source=./riid -m unittest tests/*.py 3 | coverage report -i --skip-empty 4 | -------------------------------------------------------------------------------- /tests/anomaly_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 
4 | """This module tests the anomaly module.""" 5 | import unittest 6 | 7 | import numpy as np 8 | 9 | from riid import PassbySynthesizer, SeedMixer, get_dummy_seeds 10 | from riid.anomaly import PoissonNChannelEventDetector 11 | 12 | 13 | class TestAnomaly(unittest.TestCase): 14 | """Test class for Anomaly module.""" 15 | def setUp(self): 16 | """Test setup.""" 17 | pass 18 | 19 | def test_event_detector(self): 20 | random_state = 42 21 | rng = np.random.default_rng(random_state) 22 | 23 | SAMPLE_INTERVAL = 0.5 24 | BG_RATE = 300 25 | seeds_ss = get_dummy_seeds(100) 26 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg() 27 | mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng)\ 28 | .generate(1) 29 | events = PassbySynthesizer(events_per_seed=1, 30 | sample_interval=SAMPLE_INTERVAL, 31 | bg_cps=BG_RATE, 32 | fwhm_function_args=(5,), 33 | dwell_time_function_args=(20, 20), 34 | snr_function_args=(20, 20), 35 | return_gross=True, 36 | rng=rng)\ 37 | .generate(fg_seeds_ss, mixed_bg_seeds_ss, verbose=False) 38 | 39 | _, gross_events = list(zip(*events)) 40 | passby_ss = gross_events[0] 41 | 42 | expected_bg_counts = SAMPLE_INTERVAL * BG_RATE 43 | expected_bg_measurement = mixed_bg_seeds_ss.spectra.iloc[0] * expected_bg_counts 44 | ed = PoissonNChannelEventDetector( 45 | long_term_duration=600, 46 | short_term_duration=10, 47 | pre_event_duration=1, 48 | max_event_duration=120, 49 | post_event_duration=10, 50 | tolerable_false_alarms_per_day=1e-5, 51 | anomaly_threshold_update_interval=60, 52 | ) 53 | cps_history = [] 54 | 55 | # Filling background 56 | measurement_id = 0 57 | while ed.background_percent_complete < 100: 58 | noisy_bg_measurement = np.random.poisson(expected_bg_measurement) 59 | cps_history.append(noisy_bg_measurement.sum() / SAMPLE_INTERVAL) 60 | _ = ed.add_measurement( 61 | measurement_id, 62 | noisy_bg_measurement, 63 | SAMPLE_INTERVAL, 64 | verbose=False 65 | ) 66 | measurement_id += 1 67 | 68 | # Create event using a synthesized passby 69 | for i in range(passby_ss.n_samples): 70 | gross_spectrum = passby_ss.spectra.iloc[i].values 71 | cps_history.append(gross_spectrum.sum() / SAMPLE_INTERVAL) 72 | event_result = ed.add_measurement( 73 | measurement_id=measurement_id, 74 | measurement=gross_spectrum, 75 | duration=SAMPLE_INTERVAL, 76 | verbose=False 77 | ) 78 | measurement_id += 1 79 | if event_result: 80 | break 81 | 82 | self.assertTrue(event_result is not None) 83 | -------------------------------------------------------------------------------- /tests/gadras_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module tests the gadras module.""" 5 | import unittest 6 | 7 | import pandas as pd 8 | 9 | from riid import get_dummy_seeds 10 | from riid.gadras.pcf import (_pack_compressed_text_buffer, 11 | _unpack_compressed_text_buffer) 12 | 13 | 14 | class TestGadras(unittest.TestCase): 15 | """Test class for Gadras.""" 16 | 17 | def test_pcf_header_formatting(self): 18 | """Tests the PCF header information is properly packed. 
19 | 
20 |         String formatting note:
21 |             f"{'hello':60.50}"
22 |                ^       ^  ^
23 |                |       |  '> The number of letters to allow from the input value
24 |                |       '> The length of the string
25 |                '> The input value to format
26 | 
27 |         """
28 |         FIELD_LENGTH = 10
29 |         test_cases = [
30 |             (
31 |                 "tttttttt ", "ddddddddd ", "ssssssssss",
32 |                 "tttttttt", "ddddddddd", "ssssssssss",
33 |                 "tttttttt ddddddddd ssssssssss",
34 |             ),
35 |             (
36 |                 "tttttttttt", "dddddddddd", "ssssssssss+",
37 |                 "tttttttttt", "", "ssssssssss+",
38 |                 "ÿttttttttttÿÿssssssssss+ "
39 |             ),
40 |             (
41 |                 "tttttttt ", "dddddddd ", "ssssssssss+",
42 |                 "tttttttt", "dddddddd", "ssssssssss+",
43 |                 "ÿttttttttÿddddddddÿssssssssss+"
44 |             ),
45 |             (
46 |                 "tt ", "dddddddddd+", "ssssssssss++",
47 |                 "tt", "dddddddddd+", "ssssssssss++",
48 |                 "ÿttÿdddddddddd+ÿssssssssss++ "
49 |             ),
50 |         ]
51 |         for case in test_cases:
52 |             title, desc, source, \
53 |                 expected_title, expected_desc, expected_source, \
54 |                 expected_ctb = case
55 |             actual_ctb = _pack_compressed_text_buffer(
56 |                 title,
57 |                 desc,
58 |                 source,
59 |                 field_len=FIELD_LENGTH
60 |             )
61 |             actual_title, actual_desc, actual_source = _unpack_compressed_text_buffer(
62 |                 actual_ctb,
63 |                 field_len=FIELD_LENGTH
64 |             )
65 |             self.assertEqual(expected_title, actual_title)
66 |             self.assertEqual(expected_desc, actual_desc)
67 |             self.assertEqual(expected_source, actual_source)
68 |             self.assertEqual(expected_ctb, actual_ctb)
69 | 
70 |     def test_to_pcf_with_various_sources_dataframes(self):
71 |         TEMP_PCF_PATH = "temp.pcf"
72 | 
73 |         # With all levels
74 |         ss = get_dummy_seeds()
75 |         ss.to_pcf(TEMP_PCF_PATH, verbose=False)
76 | 
77 |         # Without seed level (only category and isotope)
78 |         ss = get_dummy_seeds()
79 |         ss.sources.columns = ss.sources.columns.droplevel("Seed")  # droplevel returns a new index; assign it back
80 |         ss.to_pcf(TEMP_PCF_PATH, verbose=False)
81 | 
82 |         # Without seed and isotope levels (only category)
83 |         ss = get_dummy_seeds()
84 |         ss.sources.columns = ss.sources.columns.droplevel("Seed")
85 |         ss.sources.columns = ss.sources.columns.droplevel("Isotope")
86 |         ss.to_pcf(TEMP_PCF_PATH, verbose=False)
87 | 
88 |         # With no sources
89 |         ss = get_dummy_seeds()
90 |         ss.sources = pd.DataFrame()
91 |         ss.to_pcf(TEMP_PCF_PATH, verbose=False)
92 | 
93 | 
94 | if __name__ == "__main__":
95 |     unittest.main()
96 | 
--------------------------------------------------------------------------------
/tests/model_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module tests the bayes module.""" 5 | import os 6 | import unittest 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from riid import SampleSet, SeedMixer, StaticSynthesizer, get_dummy_seeds 12 | from riid.models import (ARADLatentPredictor, ARADv1, ARADv2, 13 | LabelProportionEstimator, MLPClassifier, 14 | PoissonBayesClassifier) 15 | from riid.models.base import PyRIIDModel 16 | from riid.models.bayes import NegativeSpectrumError, ZeroTotalCountsError 17 | 18 | 19 | class TestModels(unittest.TestCase): 20 | """Test class for PyRIID models.""" 21 | def setUp(self): 22 | """Test setup.""" 23 | pass 24 | 25 | @classmethod 26 | def setUpClass(self): 27 | self.seeds_ss = get_dummy_seeds(n_channels=128) 28 | self.fg_seeds_ss, self.bg_seeds_ss = self.seeds_ss.split_fg_and_bg() 29 | self.mixed_bg_seeds_ss = SeedMixer(self.bg_seeds_ss, mixture_size=3).generate(1) 30 | self.static_synth = StaticSynthesizer(samples_per_seed=5) 31 | self.train_ss, _ = self.static_synth.generate(self.fg_seeds_ss, self.mixed_bg_seeds_ss, 32 | verbose=False) 33 | self.train_ss.prediction_probas = self.train_ss.sources 34 | self.train_ss.normalize() 35 | self.test_ss, _ = self.static_synth.generate(self.fg_seeds_ss, self.mixed_bg_seeds_ss, 36 | verbose=False) 37 | self.test_ss.normalize() 38 | 39 | @classmethod 40 | def tearDownClass(self): 41 | pass 42 | 43 | def test_pb_constructor_errors(self): 44 | """Testing for constructor errors when different arguments are provided.""" 45 | pb_model = PoissonBayesClassifier() 46 | 47 | # Empty argument provided 48 | spectra = np.array([]) 49 | ss = SampleSet() 50 | ss.spectra = pd.DataFrame(spectra) 51 | self.assertRaises(ValueError, pb_model.fit, ss) 52 | 53 | # Negative channel argument provided 54 | spectra = np.array([ 55 | [1, 1, 1, 1], 56 | [1, 1, -1, 1], 57 | [1, 1, 1, 1] 58 | ]) 59 | ss = SampleSet() 60 | ss.spectra = pd.DataFrame(spectra) 61 | self.assertRaises(NegativeSpectrumError, pb_model.fit, ss) 62 | 63 | # Zero total counts argument provided 64 | spectra = np.array([ 65 | [1, 1, 1, 1], 66 | [0, 0, 0, 0], 67 | [1, 1, 1, 1] 68 | ]) 69 | ss = SampleSet() 70 | ss.spectra = pd.DataFrame(spectra) 71 | self.assertRaises(ZeroTotalCountsError, pb_model.fit, ss) 72 | 73 | def test_pb_predict(self): 74 | """Tests the constructor with a valid SampleSet.""" 75 | seeds_ss = get_dummy_seeds() 76 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg() 77 | bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1) 78 | 79 | # Create the PoissonBayesClassifier 80 | pb_model = PoissonBayesClassifier() 81 | pb_model.fit(fg_seeds_ss) 82 | 83 | # Get test samples 84 | gss = StaticSynthesizer( 85 | samples_per_seed=2, 86 | live_time_function_args=(4, 4), 87 | snr_function_args=(10, 10), 88 | rng=np.random.default_rng(42), 89 | return_fg=True, 90 | return_gross=True, 91 | ) 92 | test_fg_ss, test_gross_ss = gss.generate(fg_seeds_ss, bg_seeds_ss, verbose=False) 93 | test_bg_ss = test_gross_ss - test_fg_ss 94 | 95 | # Predict 96 | pb_model.predict(test_gross_ss, test_bg_ss) 97 | 98 | truth_labels = test_fg_ss.get_labels() 99 | predicted_labels = test_gross_ss.get_predictions() 100 | assert (truth_labels == predicted_labels).all() 101 | 102 | def test_pb_fit_save_load(self): 103 | _test_model_fit_save_load_predict(self, PoissonBayesClassifier, None, self.fg_seeds_ss) 104 | 105 | def test_mlp_fit_save_load_predict(self): 106 | _test_model_fit_save_load_predict(self, MLPClassifier, self.test_ss, self.train_ss, 107 | epochs=1) 108 | 109 | def 
109 |     def test_lpe_fit_save_load_predict(self):
110 |         _test_model_fit_save_load_predict(self, LabelProportionEstimator, self.test_ss,
111 |                                           self.fg_seeds_ss, self.train_ss, epochs=1)
112 | 
113 |     def test_aradv1_fit_save_load_predict(self):
114 |         _test_model_fit_save_load_predict(self, ARADv1, self.test_ss, self.train_ss, epochs=1)
115 | 
116 |     def test_aradv2_fit_save_load_predict(self):
117 |         _test_model_fit_save_load_predict(self, ARADv2, self.test_ss, self.train_ss, epochs=1)
118 | 
119 |     def test_alp_fit_save_load_predict(self):
120 |         arad_v2 = ARADv2()
121 |         arad_v2.fit(self.train_ss, epochs=1)
122 |         _test_model_fit_save_load_predict(self, ARADLatentPredictor, self.test_ss, arad_v2.model,
123 |                                           self.train_ss, target_info_columns=["snr"], epochs=1)
124 | 
125 | 
126 | def _try_remove_model_and_info(model_path: str):
127 |     if os.path.exists(model_path):
128 |         if os.path.isdir(model_path):
129 |             os.rmdir(model_path)
130 |         else:
131 |             os.remove(model_path)
132 | 
133 | 
134 | def _test_model_fit_save_load_predict(test_case: unittest.TestCase, model_class: PyRIIDModel,
135 |                                       test_ss: SampleSet = None, *args_for_fit, **kwargs_for_fit):
136 |     m1 = model_class()
137 |     m2 = model_class()
138 | 
139 |     m1.fit(*args_for_fit, **kwargs_for_fit)
140 | 
141 |     model_path = m1._temp_file_path
142 | 
143 |     _try_remove_model_and_info(model_path)
144 |     test_case.assertRaises(ValueError, m2.load, model_path)
145 | 
146 |     m1.save(model_path)
147 |     test_case.assertRaises(ValueError, m1.save, model_path)
148 | 
149 |     m2.load(model_path)
150 |     _try_remove_model_and_info(model_path)
151 | 
152 |     if test_ss is not None:
153 |         m1.predict(test_ss)
154 | 
155 | 
156 | if __name__ == "__main__":
157 |     unittest.main()
158 | 
--------------------------------------------------------------------------------
/tests/seedmixer_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module tests the seed mixer module."""
5 | import unittest
6 | 
7 | import numpy as np
8 | from scipy.spatial.distance import jensenshannon
9 | 
10 | from riid import SampleSet, SeedMixer, get_dummy_seeds
11 | 
12 | 
13 | class TestSeedMixer(unittest.TestCase):
14 |     """Test seed mixing functionality of `SeedMixer`.
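
    Mixture sketch (what the construction tests below verify): each mixed spectrum
    is a convex combination of seed spectra with weights drawn from a Dirichlet
    distribution, so it should be reconstructable as `mixed_ss.sources @ seeds_ss.spectra`,
    i.e., sit at a Jensen-Shannon distance of approximately zero from the reconstruction.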
15 | """ 16 | @classmethod 17 | def setUpClass(self): 18 | random_state = 42 19 | self.rng = np.random.default_rng(random_state) 20 | 21 | self.ss, _ = get_dummy_seeds().split_fg_and_bg() 22 | self.ss.normalize() 23 | self.sources = self.ss.get_labels().values 24 | 25 | self.two_mix_seeds_ss = SeedMixer( 26 | self.ss, 27 | mixture_size=2, 28 | dirichlet_alpha=10, 29 | ).generate(n_samples=20) 30 | 31 | self.three_mix_seeds_ss = SeedMixer( 32 | self.ss, 33 | mixture_size=3, 34 | dirichlet_alpha=10, 35 | ).generate(n_samples=20) 36 | 37 | def test_mixture_combinations(self): 38 | # check that each mixture contains unique isotopes and the correct mixture size 39 | two_mix_isotopes = [ 40 | x.split(" + ") 41 | for x in self.two_mix_seeds_ss.get_labels(target_level="Isotope", max_only=False) 42 | ] 43 | self.assertTrue(all([len(set(x)) == 2 for x in two_mix_isotopes])) 44 | 45 | three_mix_isotopes = [ 46 | x.split(" + ") 47 | for x in self.three_mix_seeds_ss.get_labels(target_level="Isotope", max_only=False) 48 | ] 49 | self.assertTrue(all([len(set(x)) == 3 for x in three_mix_isotopes])) 50 | 51 | def test_mixture_ratios(self): 52 | # check for valid probability distribution 53 | for each in self.two_mix_seeds_ss.get_source_contributions(target_level="Isotope"): 54 | self.assertAlmostEqual(each.sum(), 1.0) 55 | 56 | for each in self.three_mix_seeds_ss.get_source_contributions(target_level="Isotope"): 57 | self.assertAlmostEqual(each.sum(), 1.0) 58 | 59 | def test_mixture_number(self): 60 | # check that number of samples is less than largest possible combinations 61 | # (worst case scenario) 62 | self.assertEqual(self.two_mix_seeds_ss.n_samples, 20) 63 | self.assertEqual(self.two_mix_seeds_ss.n_samples, 20) 64 | 65 | def test_mixture_pdf(self): 66 | # check that each mixture sums to one 67 | for sample in range(self.two_mix_seeds_ss.n_samples): 68 | self.assertAlmostEqual(self.two_mix_seeds_ss.spectra.values[sample, :].sum(), 1.0) 69 | 70 | for sample in range(self.three_mix_seeds_ss.n_samples): 71 | self.assertAlmostEqual(self.three_mix_seeds_ss.spectra.values[sample, :].sum(), 1.0) 72 | 73 | def test_spectrum_construction_3seeds_2mix(self): 74 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg() 75 | mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2, rng=self.rng).generate(100) 76 | spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss) 77 | self.assertTrue(np.isclose(spectral_distances, 0.0).all()) 78 | 79 | def test_spectrum_construction_3seeds_3mix(self): 80 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg() 81 | mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=self.rng).generate(100) 82 | spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss) 83 | self.assertTrue(np.isclose(spectral_distances, 0.0).all()) 84 | 85 | def test_spectrum_construction_2seeds_2mix(self): 86 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg( 87 | bg_seed_names=SampleSet.DEFAULT_BG_SEED_NAMES[1:3] 88 | ) 89 | mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2, rng=self.rng).generate(100) 90 | spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss) 91 | self.assertTrue(np.isclose(spectral_distances, 0.0).all()) 92 | 93 | def test_spectrum_construction_2seeds_2mix_error(self): 94 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16).split_fg_and_bg( 95 | bg_seed_names=SampleSet.DEFAULT_BG_SEED_NAMES[1:3] 96 | ) 97 | mixer = SeedMixer(bg_seeds_ss, mixture_size=3) 98 | 
self.assertRaises(ValueError, mixer.generate, 100) 99 | 100 | 101 | def _get_spectral_distances(seeds_ss, mixed_ss): 102 | recon_spectra = np.dot( 103 | seeds_ss.spectra.values.T, 104 | mixed_ss.sources.values.T 105 | ).T 106 | spectral_distances = jensenshannon( 107 | recon_spectra, 108 | mixed_ss.spectra.values, 109 | axis=1 110 | ) 111 | spectral_distances = np.nan_to_num(spectral_distances, nan=0.0) 112 | return spectral_distances 113 | 114 | 115 | if __name__ == "__main__": 116 | unittest.main() 117 | -------------------------------------------------------------------------------- /tests/visualize_tests.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). 2 | # Under the terms of Contract DE-NA0003525 with NTESS, 3 | # the U.S. Government retains certain rights in this software. 4 | """This module tests the visualize module.""" 5 | import unittest 6 | 7 | import numpy as np 8 | 9 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds 10 | from riid.metrics import precision_recall_curve 11 | from riid.models import MLPClassifier 12 | from riid.visualize import (plot_correlation_between_all_labels, 13 | plot_count_rate_history, 14 | plot_label_and_prediction_distributions, 15 | plot_label_distribution, plot_learning_curve, 16 | plot_live_time_vs_snr, plot_precision_recall, 17 | plot_prediction_distribution, 18 | plot_score_distribution, plot_snr_vs_score, 19 | plot_spectra, plot_ss_comparison) 20 | 21 | 22 | class TestVisualize(unittest.TestCase): 23 | """Testing plot functions in the visualize module.""" 24 | @classmethod 25 | def setUpClass(self): 26 | """Test setup.""" 27 | self.fg_seeds_ss, self.bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() 28 | self.mixed_bg_seed_ss = SeedMixer(self.bg_seeds_ss, mixture_size=3).generate(10) 29 | 30 | self.static_synth = StaticSynthesizer( 31 | samples_per_seed=100, 32 | snr_function="log10", 33 | return_fg=False, 34 | return_gross=True 35 | ) 36 | _, self.train_ss = self.static_synth.generate( 37 | self.fg_seeds_ss, 38 | self.mixed_bg_seed_ss, 39 | verbose=False 40 | ) 41 | self.train_ss.normalize() 42 | 43 | model = MLPClassifier() 44 | self.history = model.fit(self.train_ss, epochs=10, patience=5).history 45 | model.predict(self.train_ss) 46 | 47 | # Generate some test data 48 | self.static_synth.samples_per_seed = 50 49 | _, self.test_ss = self.static_synth.generate( 50 | self.fg_seeds_ss, 51 | self.mixed_bg_seed_ss, 52 | verbose=False 53 | ) 54 | self.test_ss.normalize() 55 | model.predict(self.test_ss) 56 | 57 | def test_plot_live_time_vs_snr(self): 58 | """Plots SNR against live time for all samples in a SampleSet. 59 | Prediction and label information is used to distinguish between correct and incorrect 60 | classifications using color (green for correct, red for incorrect). 61 | """ 62 | plot_live_time_vs_snr(self.test_ss, show=False) 63 | plot_live_time_vs_snr(self.train_ss, self.test_ss, show=False) 64 | 65 | def test_plot_snr_vs_score(self): 66 | """Plots SNR against prediction score for all samples in a SampleSet. 67 | Prediction and label information is used to distinguish between correct and incorrect 68 | classifications using color (green for correct, red for incorrect). 
69 | """ 70 | plot_snr_vs_score(self.train_ss, show=False) 71 | plot_snr_vs_score(self.train_ss, self.test_ss, show=False) 72 | 73 | def test_plot_spectra(self): 74 | """Plots the spectra contained with a SampleSet.""" 75 | plot_spectra(self.fg_seeds_ss, ylim=(None, None), in_energy=False, show=False) 76 | plot_spectra(self.fg_seeds_ss, ylim=(None, None), in_energy=True, show=False) 77 | 78 | def test_plot_learning_curve(self): 79 | """Plots training and validation loss curves.""" 80 | plot_learning_curve(self.history["loss"], 81 | self.history["val_loss"], 82 | show=False) 83 | plot_learning_curve(self.history["loss"], 84 | self.history["val_loss"], 85 | smooth=True, 86 | show=False) 87 | 88 | def test_plot_count_rate_history(self): 89 | """Plots a count rate history.""" 90 | counts = np.random.normal(size=1000) 91 | histogram, _ = np.histogram(counts, bins=100, range=(0, 100)) 92 | plot_count_rate_history(histogram, 1, 80, 20, show=False) 93 | 94 | def test_plot_label_and_prediction_distributions(self): 95 | """Plots distributions of data labels, predictions, and prediction scores.""" 96 | plot_score_distribution(self.test_ss, show=False) 97 | plot_label_distribution(self.test_ss, show=False) 98 | plot_prediction_distribution(self.test_ss, show=False) 99 | plot_label_and_prediction_distributions(self.test_ss, show=False) 100 | 101 | def test_plot_correlation_between_all_labels(self): 102 | """Plots a correlation matrix of each label against each other label.""" 103 | plot_correlation_between_all_labels(self.bg_seeds_ss, show=False) 104 | plot_correlation_between_all_labels(self.bg_seeds_ss, mean=True, show=False) 105 | 106 | def test_plot_precision_recall(self): 107 | """Plots the multi-class or multi-label Precision-Recall curve and marks the optimal 108 | F1 score for each class. 
109 | """ 110 | precision, recall, _ = precision_recall_curve(self.test_ss) 111 | plot_precision_recall(precision=precision, recall=recall, show=False) 112 | 113 | def test_plot_ss_comparison(self): 114 | """Creates a plot for output from SampleSet.compare_to().""" 115 | SYNTHETIC_DATA_CONFIG = { 116 | "samples_per_seed": 100, 117 | "bg_cps": 100, 118 | "snr_function": "uniform", 119 | "snr_function_args": (1, 100), 120 | "live_time_function": "uniform", 121 | "live_time_function_args": (0.25, 10), 122 | "apply_poisson_noise": True, 123 | "return_fg": False, 124 | "return_gross": True, 125 | } 126 | 127 | _, gross_ss1 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\ 128 | .generate(self.fg_seeds_ss, self.mixed_bg_seed_ss, verbose=False) 129 | _, gross_ss2 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\ 130 | .generate(self.fg_seeds_ss, self.mixed_bg_seed_ss, verbose=False) 131 | 132 | # Compare two different gross sample sets - Live time 133 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2, 134 | density=False) 135 | plot_ss_comparison(ss1_stats, 136 | ss2_stats, 137 | col_comparisons, 138 | "live_time", 139 | show=False) 140 | 141 | # Compare two different gross sample sets - Total Counts 142 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2, 143 | density=True) 144 | plot_ss_comparison(ss1_stats, 145 | ss2_stats, 146 | col_comparisons, 147 | "total_counts", 148 | show=False) 149 | 150 | # Compare the same sampleset to itself - Live time 151 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss1, 152 | density=False) 153 | plot_ss_comparison(ss1_stats, 154 | ss2_stats, 155 | col_comparisons, 156 | "live_time", 157 | show=False) 158 | 159 | # Compare the same sampleset to itself - Total Counts 160 | ss1_stats, ss2_stats, col_comparisons = gross_ss2.compare_to(gross_ss2, 161 | density=True) 162 | plot_ss_comparison(ss1_stats, 163 | ss2_stats, 164 | col_comparisons, 165 | "total_counts", 166 | show=False) 167 | 168 | 169 | if __name__ == "__main__": 170 | unittest.main() 171 | --------------------------------------------------------------------------------