├── .coveragerc
├── .flake8
├── .github
│   └── workflows
│       ├── build-and-publish-docs.yml
│       ├── build-and-run-tests.yml
│       ├── publish-to-pypi.yml
│       ├── run-examples.yml
│       └── run-linting.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE.md
├── NOTICE.md
├── README.md
├── examples
│   ├── courses
│   │   ├── Primer 1
│   │   │   ├── Primer1.ipynb
│   │   │   ├── spec_csi_many_sources.yaml
│   │   │   ├── spec_nai_few_sources.yaml
│   │   │   └── spec_nai_many_sources.yaml
│   │   └── Primer 2
│   │       ├── Primer 2.ipynb
│   │       └── seed_configs
│   │           ├── advanced.yaml
│   │           ├── basic.yaml
│   │           └── problem.yaml
│   ├── data
│   │   ├── conversion
│   │   │   ├── aipt_to_ss.py
│   │   │   ├── pcf_to_ss.py
│   │   │   └── topcoder_to_ss.py
│   │   ├── difficulty_score.py
│   │   ├── preprocessing
│   │   │   └── energy_calibration.py
│   │   └── synthesis
│   │       ├── mix_seeds.py
│   │       ├── synthesize_passbys.py
│   │       ├── synthesize_seeds_advanced.py
│   │       ├── synthesize_seeds_basic.py
│   │       ├── synthesize_seeds_custom.py
│   │       └── synthesize_static.py
│   ├── modeling
│   │   ├── anomaly_detection.py
│   │   ├── arad.py
│   │   ├── arad_latent_prediction.py
│   │   ├── classifier_comparison.py
│   │   ├── custom_loss_and_metric.py
│   │   ├── label_proportion_estimation.py
│   │   └── neural_network_classifier.py
│   ├── run_examples.py
│   └── visualization
│       ├── confusion_matrix.py
│       ├── distance_matrix.py
│       ├── plot_sampleset_compare_to.py
│       └── plot_spectra.py
├── pdoc
│   ├── config.mako
│   └── requirements.txt
├── pyproject.toml
├── riid
│   ├── __init__.py
│   ├── anomaly.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── converters
│   │   │   ├── __init__.py
│   │   │   ├── aipt.py
│   │   │   └── topcoder.py
│   │   ├── labeling.py
│   │   ├── sampleset.py
│   │   └── synthetic
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── passby.py
│   │       ├── seed.py
│   │       └── static.py
│   ├── gadras
│   │   ├── __init__.py
│   │   ├── api.py
│   │   ├── api_schema.json
│   │   └── pcf.py
│   ├── losses
│   │   ├── __init__.py
│   │   └── sparsemax.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── bayes.py
│   │   ├── layers.py
│   │   └── neural_nets
│   │       ├── __init__.py
│   │       ├── arad.py
│   │       ├── basic.py
│   │       └── lpe.py
│   └── visualize.py
├── run_tests.ps1
├── run_tests.sh
└── tests
    ├── anomaly_tests.py
    ├── config_tests.py
    ├── data_tests.py
    ├── gadras_tests.py
    ├── model_tests.py
    ├── sampleset_tests.py
    ├── seedmixer_tests.py
    ├── staticsynth_tests.py
    └── visualize_tests.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit =
3 | */site-packages/*
4 | */distutils/*
5 | tests/*
6 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 100
3 | inline-quotes = double
4 | avoid-escape = False
5 |
--------------------------------------------------------------------------------
/.github/workflows/build-and-publish-docs.yml:
--------------------------------------------------------------------------------
1 | name: Build docs
2 | on:
3 | push:
4 | branches:
5 | - main
6 | permissions:
7 | contents: read
8 | jobs:
9 | build:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v4
13 | - uses: actions/setup-python@v5
14 | with:
15 | python-version: "3.10"
16 | cache: "pip"
17 | cache-dependency-path: "**/pyproject.toml"
18 | - run: pip install -e .
19 | - run: pip install -r pdoc/requirements.txt
20 | - run: pdoc riid -o docs/ --html --template-dir pdoc
21 | - run: echo '<meta http-equiv="refresh" content="0; url=./riid" />' > docs/index.html  # redirect the docs root to the generated riid module docs
22 | - uses: actions/upload-pages-artifact@v3
23 | with:
24 | path: docs/
25 | deploy:
26 | needs: build
27 | runs-on: ubuntu-latest
28 | permissions:
29 | pages: write
30 | id-token: write
31 | environment:
32 | name: github-pages
33 | url: ${{ steps.deployment.outputs.page_url }}
34 | steps:
35 | - name: Deploy to GitHub Pages
36 | id: deployment
37 | uses: actions/deploy-pages@v4
38 |
--------------------------------------------------------------------------------
/.github/workflows/build-and-run-tests.yml:
--------------------------------------------------------------------------------
1 | name: Run unit tests
2 | on:
3 | push:
4 | branches:
5 | - '**'
6 | schedule:
7 | - cron: '0 0 * * 1' # Trigger every Monday at midnight (UTC)
8 | concurrency:
9 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
10 | cancel-in-progress: true
11 | jobs:
12 | build:
13 | strategy:
14 | matrix:
15 | python-version: ["3.10", "3.11", "3.12"]
16 | os: [ubuntu-latest, windows-latest, macos-latest]
17 | runs-on: ${{ matrix.os }}
18 | steps:
19 | - name: Checkout
20 | uses: actions/checkout@v4
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v5
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | cache: "pip"
26 | cache-dependency-path: "**/pyproject.toml"
27 | - name: Install HDF5 (macOS only)
28 | if: runner.os == 'macOS'
29 | run: |
30 | brew install hdf5
31 | brew install c-blosc2
32 | - name: Set HDF5_DIR environment variable (macOS only)
33 | if: runner.os == 'macOS'
34 | run: |
35 | echo "HDF5_DIR=$(brew --prefix hdf5)" >> $GITHUB_ENV
36 | echo "BLOSC2_DIR=$(brew --prefix c-blosc2)" >> $GITHUB_ENV
37 | - name: Install dependencies
38 | run: |
39 | python -m pip install --upgrade pip
40 | pip install ".[dev]"
41 | - name: Run unit tests
42 | run: |
43 | sh run_tests.sh
44 | if: github.event_name == 'push' || (github.event_name == 'schedule' && github.ref == 'refs/heads/main')
45 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 | on:
3 | release:
4 | types: [published]
5 | jobs:
6 | build-n-publish:
7 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
8 | runs-on: ubuntu-latest
9 | environment:
10 | name: pypi
11 | url: https://pypi.org/p/riid
12 | permissions:
13 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
14 | steps:
15 | - name: Checkout
16 | uses: actions/checkout@v4
17 | - name: Set up Python
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: "3.10"
21 | - name: Install pypa/build
22 | run: >-
23 | python -m
24 | pip install
25 | build
26 | --user
27 | - name: Build a binary wheel and a source tarball
28 | run: >-
29 | python -m
30 | build
31 | --sdist
32 | --wheel
33 | --outdir dist/
34 | - name: Publish distribution 📦 to Test PyPI
35 | uses: pypa/gh-action-pypi-publish@release/v1
36 | with:
37 | repository-url: https://test.pypi.org/legacy/
38 | - name: Publish distribution 📦 to PyPI
39 | uses: pypa/gh-action-pypi-publish@release/v1
40 |
--------------------------------------------------------------------------------
/.github/workflows/run-examples.yml:
--------------------------------------------------------------------------------
1 | name: Run examples
2 | on: [push]
3 | concurrency:
4 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
5 | cancel-in-progress: true
6 | jobs:
7 | build:
8 | strategy:
9 | matrix:
10 | python-version: ["3.10", "3.11", "3.12"]
11 | os: [ubuntu-latest, windows-latest, macos-latest]
12 | runs-on: ${{ matrix.os }}
13 | steps:
14 | - name: Checkout
15 | uses: actions/checkout@v4
16 | - name: Set up Python ${{ matrix.python-version }}
17 | uses: actions/setup-python@v5
18 | with:
19 | python-version: ${{ matrix.python-version }}
20 | cache: "pip"
21 | cache-dependency-path: "**/pyproject.toml"
22 | - name: Install HDF5 (macOS only)
23 | if: runner.os == 'macOS'
24 | run: |
25 | brew install hdf5
26 | brew install c-blosc2
27 | - name: Set HDF5_DIR environment variable (macOS only)
28 | if: runner.os == 'macOS'
29 | run: |
30 | echo "HDF5_DIR=$(brew --prefix hdf5)" >> $GITHUB_ENV
31 | echo "BLOSC2_DIR=$(brew --prefix c-blosc2)" >> $GITHUB_ENV
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip setuptools wheel
35 | pip install -e ".[dev]"
36 | - name: Run examples
37 | run: |
38 | python examples/run_examples.py
39 |
--------------------------------------------------------------------------------
/.github/workflows/run-linting.yml:
--------------------------------------------------------------------------------
1 | name: Run linting
2 | on: [push]
3 | concurrency:
4 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
5 | cancel-in-progress: true
6 | jobs:
7 | linting:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@v3
11 | - name: Set up Python
12 | uses: actions/setup-python@v4
13 | with:
14 | python-version: "3.10"
15 | cache: "pip"
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip setuptools wheel
19 | pip install flake8 flake8-quotes
20 | - name: Run Flake8 linter
21 | run: |
22 | flake8 riid examples tests
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project-specific
2 | *.zip
3 | *.p
4 | *.png
5 | *.jpg
6 | *.jpeg
7 | *.pdf
8 | *.smpl
9 | *.h5
10 | *.pcf
11 | *.onnx
12 | *.tflite
13 | *.hdf
14 | examples/**/*.json
15 | */__pycache__
16 | _scripts/
17 | debug/
18 | *.whl
19 | docs/
20 |
21 | # Mac-specific
22 | *.DS_Store
23 | ._*
24 | .ipynb_checkpoints
25 |
26 | # VS code
27 | .vscode
28 |
29 | # History files
30 | .Rhistory
31 | .Rapp.history
32 |
33 | # Session Data files
34 | .RData
35 |
36 | # Example code in package build process
37 | *-Ex.R
38 |
39 | # Output files from R CMD build
40 | /*.tar.gz
41 |
42 | # Output files from R CMD check
43 | /*.Rcheck/
44 |
45 | # RStudio files
46 | .Rproj.user/
47 |
48 | # produced vignettes
49 | vignettes/*.html
50 | vignettes/*.pdf
51 |
52 | # OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
53 | .httr-oauth
54 |
55 | # knitr and R markdown default cache directories
56 | /*_cache/
57 | /cache/
58 |
59 | # Temporary files created by R markdown
60 | *.utf8.md
61 | *.knit.md
62 |
63 | # Emacs
64 | *~
65 |
66 | # Byte-compiled / optimized / DLL files
67 | __pycache__/
68 | *.py[cod]
69 | *$py.class
70 |
71 | # C extensions
72 | *.so
73 |
74 | # Distribution / packaging
75 | .Python
76 | build/
77 | develop-eggs/
78 | dist/
79 | downloads/
80 | eggs/
81 | .eggs/
82 | lib/
83 | lib64/
84 | parts/
85 | sdist/
86 | var/
87 | wheels/
88 | *.egg-info/
89 | .installed.cfg
90 | *.egg
91 | MANIFEST
92 |
93 | # PyInstaller
94 | # Usually these files are written by a python script from a template
95 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
96 | *.manifest
97 | *.spec
98 |
99 | # Installer logs
100 | pip-log.txt
101 | pip-delete-this-directory.txt
102 |
103 | # Unit test / coverage reports
104 | htmlcov/
105 | .tox/
106 | .coverage
107 | .coverage.*
108 | .cache
109 | nosetests.xml
110 | coverage.xml
111 | *.cover
112 | .hypothesis/
113 |
114 | # Translations
115 | *.mo
116 | *.pot
117 |
118 | # Django stuff:
119 | *.log
120 | .static_storage/
121 | .media/
122 | local_settings.py
123 |
124 | # Flask stuff:
125 | instance/
126 | .webassets-cache
127 |
128 | # Scrapy stuff:
129 | .scrapy
130 |
131 | # Sphinx documentation
132 | docs/_build/
133 |
134 | # PyBuilder
135 | target/
136 |
137 | # Jupyter Notebook
138 | .ipynb_checkpoints
139 |
140 | # pyenv
141 | .python-version
142 |
143 | # celery beat schedule file
144 | celerybeat-schedule
145 |
146 | # SageMath parsed files
147 | *.sage.py
148 |
149 | # Environments
150 | .env
151 | .venv
152 | env/
153 | venv/
154 | ENV/
155 | env.bak/
156 | venv.bak/
157 | venv_*/
158 | *_venv/
159 |
160 | # Spyder project settings
161 | .spyderproject
162 | .spyproject
163 |
164 | # Rope project settings
165 | .ropeproject
166 |
167 | # mkdocs documentation
168 | /site
169 |
170 | # mypy
171 | .mypy_cache/
172 |
173 | # files being used for development only
174 | sandbox.py
175 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, caste, color, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | [INSERT CONTACT METHOD].
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0].
120 |
121 | Community Impact Guidelines were inspired by
122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
123 |
124 | For answers to common questions about this code of conduct, see the FAQ at
125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available
126 | at [https://www.contributor-covenant.org/translations][translations].
127 |
128 | [homepage]: https://www.contributor-covenant.org
129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html
130 | [Mozilla CoC]: https://github.com/mozilla/diversity
131 | [FAQ]: https://www.contributor-covenant.org/faq
132 | [translations]: https://www.contributor-covenant.org/translations
133 |
134 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | # BSD 3-Clause License
2 |
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 |
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 |
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 |
9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10 |
11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
12 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/NOTICE.md:
--------------------------------------------------------------------------------
1 | This source code is part of the PyRIID project and is licensed under the BSD-style license.
2 | This project also contains code covered under the Apache-2.0 license, based on TensorFlow Addons functions, which can be found in `riid/losses/sparsemax.py`.
3 |
4 | The following is a list of the relevant copyright and license information.
5 |
6 | ---
7 |
8 | Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
9 | Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights in this software.
10 | This source code is licensed under the BSD-style license found [here](https://github.com/sandialabs/PyRIID/blob/main/LICENSE.md).
11 |
12 | ---
13 |
14 | Copyright 2016 The TensorFlow Authors. All Rights Reserved.
15 |
16 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
17 | You may obtain a copy of the License at
18 |
19 | http://www.apache.org/licenses/LICENSE-2.0
20 |
21 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
22 | See the License for the specific language governing permissions and limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 
6 | 
7 |
8 | PyRIID is a Python package providing modeling and data synthesis utilities for machine learning-based research and development of radioisotope-related detection, identification, and quantification.
9 |
10 | ## Installation
11 |
12 | Requirements:
13 |
14 | - Python version: 3.10 to 3.12
15 | - Note: we recommend the highest Python version you can manage; anecdotally, everything tends to run faster on newer versions.
16 | - Operating systems: Windows, Mac, or Ubuntu
17 |
18 | Tests and examples are run via Actions on many combinations of Python version and operating system.
19 | You can verify support for your platform by checking the workflow files.
20 |
21 | ### For Use
22 |
23 | To use the latest version on PyPI, run:
24 |
25 | ```sh
26 | pip install riid
27 | ```
28 |
29 | Note that changes are slower to appear on PyPI, so for the latest features, run:
30 |
31 | ```sh
32 | pip install git+https://github.com/sandialabs/pyriid.git@main
33 | ```
34 |
35 | ### For Development
36 |
37 | If you are developing PyRIID, clone this repository and run:
38 |
39 | ```sh
40 | pip install -e ".[dev]"
41 | ```
42 |
43 | If you encounter Pylance issues, try:
44 |
45 | ```sh
46 | pip install -e ".[dev]" --config-settings editable_mode=compat
47 | ```
48 |
49 | ## Examples
50 |
51 | Examples of how to use this package can be found [here](https://github.com/sandialabs/PyRIID/blob/main/examples).
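
As a quick orientation, here is a minimal sketch of the core synthesis workflow, adapted from the bundled `examples/data/difficulty_score.py` (parameter values are illustrative):

```python
from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds

# Split the bundled dummy seeds into foreground (source) and background seeds
fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()

# Combine the background seeds into a single mixed background seed
mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1)

# Generate labeled foreground and gross spectra across a range of SNR values
static_synth = StaticSynthesizer(samples_per_seed=100, snr_function="uniform")
fg_ss, gross_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
```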
52 |
53 | ## Tests
54 |
55 | Unit tests for this package can be found [here](https://github.com/sandialabs/PyRIID/blob/main/tests).
56 |
57 | Run all unit tests with the following:
58 |
59 | ```sh
60 | python -m unittest tests/*.py -v
61 | ```
62 |
63 | You can also run one of the `run_tests.*` scripts, whichever is appropriate for your platform.
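
For example, on Linux/macOS:

```sh
sh run_tests.sh
```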
64 |
65 | ## Docs
66 |
67 | API documentation can be found [here](https://sandialabs.github.io/PyRIID).
68 |
69 | Docs can be built locally with the following:
70 |
71 | ```sh
72 | pip install -r pdoc/requirements.txt
73 | pdoc riid -o docs/ --html --template-dir pdoc
74 | ```
75 |
76 | ## Contributing
77 |
78 | Pull requests are welcome.
79 | For major changes, please open an issue first to discuss what you would like to change.
80 |
81 | Please make sure to update tests as appropriate and adhere to our [code of conduct](https://github.com/sandialabs/PyRIID/blob/main/CODE_OF_CONDUCT.md).
82 |
83 | ## Contacts
84 |
85 | Maintainers and authors can be found [here](https://github.com/sandialabs/PyRIID/blob/main/pyproject.toml).
86 |
87 | ## Copyright
88 |
89 | Full copyright details can be found [here](https://github.com/sandialabs/PyRIID/blob/main/NOTICE.md).
90 |
91 | ## Acknowledgements
92 |
93 | **Thank you** to the U.S. Department of Energy, National Nuclear Security Administration,
94 | Office of Defense Nuclear Nonproliferation Research and Development (DNN R&D) for funding that has led to versions `2.0` and `2.1`.
95 |
96 | Additionally, **thank you** to the following individuals who have provided invaluable subject-matter expertise:
97 |
98 | - Paul Thelen (also an author)
99 | - Ben Maestas
100 | - Greg Thoreson
101 | - Michael Enghauser
102 | - Elliott Leonard
103 |
104 | ## Citing
105 |
106 | When citing PyRIID, please reference the U.S. Department of Energy Office of Scientific and Technical Information (OSTI) record here:
107 | [10.11578/dc.20221017.2](https://doi.org/10.11578/dc.20221017.2)
108 |
109 | ## Related Reports, Publications, and Projects
110 |
111 | 1. Alan Van Omen, *"A Semi-Supervised Model for Multi-Label Radioisotope Classification and Out-of-Distribution Detection."* Diss. 2023. doi: [10.7302/7200](https://dx.doi.org/10.7302/7200).
112 | 2. Tyler Morrow, *"Questionnaire for Radioisotope Identification and Estimation from Gamma Spectra using PyRIID v2."* United States: N. p., 2023. Web. doi: [10.2172/2229893](https://doi.org/10.2172/2229893).
113 | 3. Aaron Fjeldsted, Tyler Morrow, and Douglas Wolfe, *"Identifying Signal-to-Noise Ratios Representative of Gamma Detector Response in Realistic Scenarios,"* 2023 IEEE Nuclear Science Symposium, Medical Imaging Conference and International Symposium on Room-Temperature Semiconductor Detectors (NSS MIC RTSD), Vancouver, BC, Canada, 2023. doi: [10.1109/NSSMICRTSD49126.2023.10337860](https://doi.org/10.1109/NSSMICRTSD49126.2023.10337860).
114 | 4. Alan Van Omen and Tyler Morrow, *"A Semi-supervised Learning Method to Produce Explainable Radioisotope Proportion Estimates for NaI-based Synthetic and Measured Gamma Spectra."* United States: N. p., 2024. Web. doi: [10.2172/2335904](https://doi.org/10.2172/2335904).
115 | - [Code, data, and best model](https://zenodo.org/doi/10.5281/zenodo.10223445)
116 | 5. Alan Van Omen and Tyler Morrow, *"Controlling Radioisotope Proportions When Randomly Sampling from Dirichlet Distributions in PyRIID."* United States: N. p., 2024. Web. doi: [10.2172/2335905](https://doi.org/10.2172/2335905).
117 | 6. Alan Van Omen, Tyler Morrow, et al., *"Multilabel Proportion Prediction and Out-of-distribution Detection on Gamma Spectra of Short-lived Fission Products."* Annals of Nuclear Energy 208 (2024): 110777. doi: [10.1016/j.anucene.2024.110777](https://doi.org/10.1016/j.anucene.2024.110777).
118 | - [Code, data, and best models](https://zenodo.org/doi/10.5281/zenodo.12796964)
119 | 7. Aaron Fjeldsted, Tyler Morrow, et al., *"A Novel Methodology for Gamma-Ray Spectra Dataset Procurement over Varying Standoff Distances and Source Activities,"* Nuclear Instruments and Methods in Physics Research Section A (2024): 169681. doi: [10.1016/j.nima.2024.169681](https://doi.org/10.1016/j.nima.2024.169681).
120 |
--------------------------------------------------------------------------------
/examples/courses/Primer 1/spec_csi_many_sources.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | gamma_detector:
3 | name: Generic\\CsI\\2x4x16
4 | parameters:
5 | distance_cm: 1000
6 | height_cm: 45
7 | dead_time_per_pulse: 5
8 | latitude_deg: 35.0
9 | longitude_deg: 253.4
10 | elevation_m: 1620
11 | sources:
12 | - isotope: Am241
13 | configurations:
14 | - Am241,100uC
15 | - Am241,100uC {13,10}
16 | - Am241,100uC {13,30}
17 | - Am241,100uC {26,5}
18 | - Am241,100uC {26,20}
19 | - Am241,100uC {26,50}
20 | - Am241,100uC {50,2}
21 | - isotope: Ba133
22 | configurations:
23 | - Ba133,100uC
24 | - Ba133,100uC {10,20}
25 | - Ba133,100uC {26,10}
26 | - Ba133,100uC {26,30}
27 | - Ba133,100uC {74,20}
28 | - Ba133,100uC {50,30}
29 | - isotope: Bi207
30 | configurations:
31 | - Bi207,100uC
32 | - Bi207,100uC {10,30}
33 | - Bi207,100uC {26,10}
34 | - isotope: Cf249
35 | configurations:
36 | - Cf249,100uC
37 | - isotope: Co57
38 | configurations:
39 | - Co57,100uC
40 | - Co57,100uC {13,10}
41 | - isotope: Co60
42 | configurations:
43 | - Co60,100uC
44 | - Co60,100uC {10,10}
45 | - Co60,100uC {10,30}
46 | - Co60,100uC {26,20}
47 | - Co60,100uC {26,40}
48 | - Co60,100uC {82,30}
49 | - Co60,100uC {82,60}
50 | - isotope: Cosmic
51 | configurations:
52 | - Cosmic
53 | - isotope: Cs137
54 | configurations:
55 | - Cs137,100uC
56 | - Cs137,100uC {6,2}
57 | - Cs137,100uC {26,10}
58 | - Cs137,100uC {82,10}
59 | - Cs137,100uC {13,240;26,1}
60 | - Cs137InPine
61 | - Cs137InLead
62 | - Cs137InGradedShield1
63 | - Cs137InGradedShield2
64 | - isotope: Eu152
65 | configurations:
66 | - Eu152,100uC
67 | - Eu152,100uC {10,10}
68 | - Eu152,100uC {10,30}
69 | - Eu152,100uC {30,50}
70 | - Eu152,100uC {74,20}
71 | - isotope: Eu154
72 | configurations:
73 | - Eu154,100uC
74 | - Eu154,100uC {10,10}
75 | - Eu154,100uC {74,20}
76 | - isotope: F18
77 | configurations:
78 | - F18,100uC
79 | - F18,100uC {10,20}
80 | - F18,100uC {10,50}
81 | - F18,100uC {26,30}
82 | - isotope: Ga67
83 | configurations:
84 | - Ga67,100uC
85 | - Ga67,100uC {6,10}
86 | - Ga67,100uC {6,20}
87 | - Ga67,100uC {10,30}
88 | - Ga67,100uC {50,16}
89 | - Ga67,100uC {82,20}
90 | - isotope: Ho166m
91 | configurations:
92 | - Ho166m,100uC
93 | - Ho166m,100uC {10,20}
94 | - Ho166m,100uC {26,20}
95 | - Ho166m,100uC {74,30}
96 | - isotope: I123
97 | configurations:
98 | - I123,100uC
99 | - I123,100uC {10,30}
100 | - isotope: I131
101 | configurations:
102 | - I131,100uC
103 | - I131,100uC {10,10}
104 | - I131,100uC {10,30}
105 | - I131,100uC {16,50}
106 | - I131,100uC {20,20}
107 | - I131,100uC {82,10}
108 | - I131,100uC {10,20;50,5}
109 | - isotope: In111
110 | configurations:
111 | - In111,100uC
112 | - In111,100uC {10,20}
113 | - In111,100uC {50,20}
114 | - isotope: Ir192
115 | configurations:
116 | - Ir192,100uC
117 | - Ir192,100uC {10,20}
118 | - Ir192,100uC {26,40}
119 | - Ir192,100uC {26,100}
120 | - Ir192,100uC {82,30}
121 | - Ir192,100uC {82,160}
122 | - Ir192Shielded1
123 | - Ir192Shielded2
124 | - Ir192Shielded3
125 | - isotope: K40
126 | configurations:
127 | - PotassiumInSoil
128 | - K40,100uC
129 | - K40inCargo1
130 | - K40inCargo2
131 | - K40inCargo3
132 | - K40inCargo4
133 | - isotope: Mo99
134 | configurations:
135 | - Mo99,100uC
136 | - Mo99,100uC {26,20}
137 | - Mo99,100uC {50,40}
138 | - isotope: Na22
139 | configurations:
140 | - Na22,100uC
141 | - Na22,100uC {10,10}
142 | - Na22,100uC {10,30}
143 | - Na22,100uC {74,20}
144 | - isotope: Np237
145 | configurations:
146 | - 1gNp237,1kC
147 | - 1kgNp237
148 | - Np237inFe1
149 | - Np237inFe4
150 | - Np237Shielded2
151 | - isotope: Pu238
152 | configurations:
153 | - Pu238,100uC
154 | - Pu238,100uC {10,5}
155 | - Pu238,100uC {26,10}
156 | - isotope: Pu239
157 | configurations:
158 | - 1kgPu239
159 | - 1kgPu239,1C {40,4}
160 | - 1kgPu239inFe
161 | - 1kgPu239inPine
162 | - 1kgPu239InFeAndPine
163 | - 1kgPu239inW
164 | - isotope: Ra226
165 | configurations:
166 | - UraniumInSoil
167 | - Ra226,100uC
168 | - Ra226inCargo1
169 | - Ra226inCargo2
170 | - Ra226inCargo3
171 | - Ra226inSilica1
172 | - Ra226inSilica2
173 | - Ra226,100uC {60,100}
174 | - isotope: Sr90
175 | configurations:
176 | - Sr90InPoly1,1C
177 | - Sr90InPoly10,1C
178 | - Sr90InFe,1C
179 | - Sr90InSn,1C
180 | - isotope: Tc99m
181 | configurations:
182 | - Tc99m,100uC
183 | - Tc99m,100uC {7,10}
184 | - Tc99m,100uC {10,20}
185 | - Tc99m,100uC {13,30}
186 | - Tc99m,100uC {26,30}
187 | - isotope: Th232
188 | configurations:
189 | - ThoriumInSoil
190 | - Th232,100uC
191 | - Th232inCargo1
192 | - Th232inCargo2
193 | - Th232inCargo3
194 | - Th232inSilica1
195 | - Th232inSilica2
196 | - Th232,100uC {60,100}
197 | - ThPlate
198 | - ThPlate+Thxray,10uC
199 | - isotope: Tl201
200 | configurations:
201 | - Tl201,100uC
202 | - Tl201,100uC {8,40}
203 | - Tl201,100uC {10,10}
204 | - Tl201,100uC {10,30}
205 | - Tl201,100uC {26,10}
206 | - isotope: U232
207 | configurations:
208 | - U232,100uC
209 | - U232,100uC {10,10}
210 | - U232,100uC {10,30}
211 | - U232,100uC {13,50}
212 | - U232,100uC {26,30}
213 | - U232,100uC {26,60}
214 | - U232inLeadAndPine
215 | - U232inLeadAndFe
216 | - U232inPineAndFe
217 | - U232,100uC {82,30}
218 | - isotope: U233
219 | configurations:
220 | - 1kgU233At1yr
221 | - 1kgU233InFeAt1yr
222 | - 1kgU233At50yr
223 | - 1kgU233InFeAt50yr
224 | - isotope: U235
225 | configurations:
226 | - 1kgU235
227 | - 1kgU235inFe
228 | - 1kgU235inPine
229 | - 1kgU235inPineAndFe
230 | - isotope: U238
231 | configurations:
232 | - 1KGU238
233 | - 1KGU238Clad
234 | - 1KGU238inPine
235 | - 1KGU238inPine2
236 | - 1KGU238inPine3
237 | - 1KGU238inFe
238 | - 1KGU238inPineAndFe
239 | - 1KGU238inW
240 | - U238Fiesta
241 | - fiestaware
242 | - DUOxide
243 | - isotope: Y88
244 | configurations:
245 | - Y88,100uC
246 | - Y88,100uC {10,50}
247 | - Y88,100uC {26,30}
248 | - Y88,100uC {80,50}
249 | ...
250 |
--------------------------------------------------------------------------------
/examples/courses/Primer 1/spec_nai_few_sources.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | gamma_detector:
3 | name: Generic\\NaI\\2x4x16
4 | parameters:
5 | distance_cm: 1000
6 | height_cm: 45
7 | dead_time_per_pulse: 5
8 | latitude_deg: 35.0
9 | longitude_deg: 253.4
10 | elevation_m: 1620
11 | sources:
12 | - isotope: Am241
13 | configurations:
14 | - Am241,100uC
15 | - isotope: Ba133
16 | configurations:
17 | - Ba133,100uC
18 | - isotope: Cs137
19 | configurations:
20 | - Cs137,100uC
21 | - isotope: K40
22 | configurations:
23 | - PotassiumInSoil
24 | - isotope: Ra226
25 | configurations:
26 | - UraniumInSoil
27 | - isotope: Th232
28 | configurations:
29 | - ThoriumInSoil
30 | ...
31 |
--------------------------------------------------------------------------------
/examples/courses/Primer 1/spec_nai_many_sources.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | gamma_detector:
3 | name: Generic\\NaI\\2x4x16
4 | parameters:
5 | distance_cm: 1000
6 | height_cm: 45
7 | dead_time_per_pulse: 5
8 | latitude_deg: 35.0
9 | longitude_deg: 253.4
10 | elevation_m: 1620
11 | sources:
12 | - isotope: Am241
13 | configurations:
14 | - Am241,100uC
15 | - Am241,100uC {13,10}
16 | - Am241,100uC {13,30}
17 | - Am241,100uC {26,5}
18 | - Am241,100uC {26,20}
19 | - Am241,100uC {26,50}
20 | - Am241,100uC {50,2}
21 | - isotope: Ba133
22 | configurations:
23 | - Ba133,100uC
24 | - Ba133,100uC {10,20}
25 | - Ba133,100uC {26,10}
26 | - Ba133,100uC {26,30}
27 | - Ba133,100uC {74,20}
28 | - Ba133,100uC {50,30}
29 | - isotope: Bi207
30 | configurations:
31 | - Bi207,100uC
32 | - Bi207,100uC {10,30}
33 | - Bi207,100uC {26,10}
34 | - isotope: Cf249
35 | configurations:
36 | - Cf249,100uC
37 | - isotope: Co57
38 | configurations:
39 | - Co57,100uC
40 | - Co57,100uC {13,10}
41 | - isotope: Co60
42 | configurations:
43 | - Co60,100uC
44 | - Co60,100uC {10,10}
45 | - Co60,100uC {10,30}
46 | - Co60,100uC {26,20}
47 | - Co60,100uC {26,40}
48 | - Co60,100uC {82,30}
49 | - Co60,100uC {82,60}
50 | - isotope: Cosmic
51 | configurations:
52 | - Cosmic
53 | - isotope: Cs137
54 | configurations:
55 | - Cs137,100uC
56 | - Cs137,100uC {6,2}
57 | - Cs137,100uC {26,10}
58 | - Cs137,100uC {82,10}
59 | - Cs137,100uC {13,240;26,1}
60 | - Cs137InPine
61 | - Cs137InLead
62 | - Cs137InGradedShield1
63 | - Cs137InGradedShield2
64 | - isotope: Eu152
65 | configurations:
66 | - Eu152,100uC
67 | - Eu152,100uC {10,10}
68 | - Eu152,100uC {10,30}
69 | - Eu152,100uC {30,50}
70 | - Eu152,100uC {74,20}
71 | - isotope: Eu154
72 | configurations:
73 | - Eu154,100uC
74 | - Eu154,100uC {10,10}
75 | - Eu154,100uC {74,20}
76 | - isotope: F18
77 | configurations:
78 | - F18,100uC
79 | - F18,100uC {10,20}
80 | - F18,100uC {10,50}
81 | - F18,100uC {26,30}
82 | - isotope: Ga67
83 | configurations:
84 | - Ga67,100uC
85 | - Ga67,100uC {6,10}
86 | - Ga67,100uC {6,20}
87 | - Ga67,100uC {10,30}
88 | - Ga67,100uC {50,16}
89 | - Ga67,100uC {82,20}
90 | - isotope: Ho166m
91 | configurations:
92 | - Ho166m,100uC
93 | - Ho166m,100uC {10,20}
94 | - Ho166m,100uC {26,20}
95 | - Ho166m,100uC {74,30}
96 | - isotope: I123
97 | configurations:
98 | - I123,100uC
99 | - I123,100uC {10,30}
100 | - isotope: I131
101 | configurations:
102 | - I131,100uC
103 | - I131,100uC {10,10}
104 | - I131,100uC {10,30}
105 | - I131,100uC {16,50}
106 | - I131,100uC {20,20}
107 | - I131,100uC {82,10}
108 | - I131,100uC {10,20;50,5}
109 | - isotope: In111
110 | configurations:
111 | - In111,100uC
112 | - In111,100uC {10,20}
113 | - In111,100uC {50,20}
114 | - isotope: Ir192
115 | configurations:
116 | - Ir192,100uC
117 | - Ir192,100uC {10,20}
118 | - Ir192,100uC {26,40}
119 | - Ir192,100uC {26,100}
120 | - Ir192,100uC {82,30}
121 | - Ir192,100uC {82,160}
122 | - Ir192Shielded1
123 | - Ir192Shielded2
124 | - Ir192Shielded3
125 | - isotope: K40
126 | configurations:
127 | - PotassiumInSoil
128 | - K40,100uC
129 | - K40inCargo1
130 | - K40inCargo2
131 | - K40inCargo3
132 | - K40inCargo4
133 | - isotope: Mo99
134 | configurations:
135 | - Mo99,100uC
136 | - Mo99,100uC {26,20}
137 | - Mo99,100uC {50,40}
138 | - isotope: Na22
139 | configurations:
140 | - Na22,100uC
141 | - Na22,100uC {10,10}
142 | - Na22,100uC {10,30}
143 | - Na22,100uC {74,20}
144 | - isotope: Np237
145 | configurations:
146 | - 1gNp237,1kC
147 | - 1kgNp237
148 | - Np237inFe1
149 | - Np237inFe4
150 | - Np237Shielded2
151 | - isotope: Pu238
152 | configurations:
153 | - Pu238,100uC
154 | - Pu238,100uC {10,5}
155 | - Pu238,100uC {26,10}
156 | - isotope: Pu239
157 | configurations:
158 | - 1kgPu239
159 | - 1kgPu239,1C {40,4}
160 | - 1kgPu239inFe
161 | - 1kgPu239inPine
162 | - 1kgPu239InFeAndPine
163 | - 1kgPu239inW
164 | - isotope: Ra226
165 | configurations:
166 | - UraniumInSoil
167 | - Ra226,100uC
168 | - Ra226inCargo1
169 | - Ra226inCargo2
170 | - Ra226inCargo3
171 | - Ra226inSilica1
172 | - Ra226inSilica2
173 | - Ra226,100uC {60,100}
174 | - isotope: Sr90
175 | configurations:
176 | - Sr90InPoly1,1C
177 | - Sr90InPoly10,1C
178 | - Sr90InFe,1C
179 | - Sr90InSn,1C
180 | - isotope: Tc99m
181 | configurations:
182 | - Tc99m,100uC
183 | - Tc99m,100uC {7,10}
184 | - Tc99m,100uC {10,20}
185 | - Tc99m,100uC {13,30}
186 | - Tc99m,100uC {26,30}
187 | - isotope: Th232
188 | configurations:
189 | - ThoriumInSoil
190 | - Th232,100uC
191 | - Th232inCargo1
192 | - Th232inCargo2
193 | - Th232inCargo3
194 | - Th232inSilica1
195 | - Th232inSilica2
196 | - Th232,100uC {60,100}
197 | - ThPlate
198 | - ThPlate+Thxray,10uC
199 | - isotope: Tl201
200 | configurations:
201 | - Tl201,100uC
202 | - Tl201,100uC {8,40}
203 | - Tl201,100uC {10,10}
204 | - Tl201,100uC {10,30}
205 | - Tl201,100uC {26,10}
206 | - isotope: U232
207 | configurations:
208 | - U232,100uC
209 | - U232,100uC {10,10}
210 | - U232,100uC {10,30}
211 | - U232,100uC {13,50}
212 | - U232,100uC {26,30}
213 | - U232,100uC {26,60}
214 | - U232inLeadAndPine
215 | - U232inLeadAndFe
216 | - U232inPineAndFe
217 | - U232,100uC {82,30}
218 | - isotope: U233
219 | configurations:
220 | - 1kgU233At1yr
221 | - 1kgU233InFeAt1yr
222 | - 1kgU233At50yr
223 | - 1kgU233InFeAt50yr
224 | - isotope: U235
225 | configurations:
226 | - 1kgU235
227 | - 1kgU235inFe
228 | - 1kgU235inPine
229 | - 1kgU235inPineAndFe
230 | - isotope: U238
231 | configurations:
232 | - 1KGU238
233 | - 1KGU238Clad
234 | - 1KGU238inPine
235 | - 1KGU238inPine2
236 | - 1KGU238inPine3
237 | - 1KGU238inFe
238 | - 1KGU238inPineAndFe
239 | - 1KGU238inW
240 | - U238Fiesta
241 | - fiestaware
242 | - DUOxide
243 | - isotope: Y88
244 | configurations:
245 | - Y88,100uC
246 | - Y88,100uC {10,50}
247 | - Y88,100uC {26,30}
248 | - Y88,100uC {80,50}
249 | ...
250 |
--------------------------------------------------------------------------------
/examples/courses/Primer 2/seed_configs/advanced.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | gamma_detector:
3 | name: tbd
4 | parameters:
5 | distance_cm:
6 | - 10
7 | - 100
8 | - 1000
9 | height_cm: 100
10 | dead_time_per_pulse: 10
11 | latitude_deg: 35.0
12 | longitude_deg: 253.4
13 | elevation_m: 1620
14 | sources:
15 | - isotope: Cs137
16 | configurations:
17 | - Cs137,100uCi
18 | - name: Cs137
19 | activity:
20 | - 1
21 | - 0.5
22 | activity_units: Ci
23 | shielding_atomic_number:
24 | min: 10
25 | max: 40.0
26 | dist: uniform
27 | num_samples: 5
28 | shielding_aerial_density:
29 | mean: 120
30 | std: 2
31 | num_samples: 5
32 | - isotope: Cosmic
33 | configurations:
34 | - Cosmic
35 | - isotope: K40
36 | configurations:
37 | - PotassiumInSoil
38 | - isotope: Ra226
39 | configurations:
40 | - UraniumInSoil
41 | - isotope: Th232
42 | configurations:
43 | - ThoriumInSoil
44 | ...
--------------------------------------------------------------------------------
/examples/courses/Primer 2/seed_configs/basic.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | gamma_detector:
3 | name: Generic\\NaI\\3x3\\Front\\MidScat
4 | parameters:
5 | distance_cm: 100
6 | height_cm: 10
7 | dead_time_per_pulse: 5
8 | latitude_deg: 35.0
9 | longitude_deg: 253.4
10 | elevation_m: 1620
11 | sources:
12 | - isotope: Am241
13 | configurations:
14 | - Am241,100uC
15 | - isotope: Ba133
16 | configurations:
17 | - Ba133,100uC
18 | - isotope: Cs137
19 | configurations:
20 | - Cs137,100uC
21 | - isotope: Cosmic
22 | configurations:
23 | - Cosmic
24 | - isotope: K40
25 | configurations:
26 | - PotassiumInSoil
27 | - isotope: Ra226
28 | configurations:
29 | - UraniumInSoil
30 | - isotope: Th232
31 | configurations:
32 | - ThoriumInSoil
33 | ...
--------------------------------------------------------------------------------
/examples/courses/Primer 2/seed_configs/problem.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | gamma_detector:
3 | name: Generic\\NaI\\3x3\\Front\\LowScat
4 | parameters:
5 | distance_cm: 1000
6 | height_cm: 0
7 | dead_time_per_pulse: 10
8 | latitude_deg: 35.0
9 | longitude_deg: 253.4
10 | elevation_m: 1620
11 | sources:
12 | - isotope: Am241
13 | configurations:
14 | - Am241,100uC
15 | - Am241,100uC {13,10}
16 | - Am241,100uC {13,30}
17 | - Am241,100uC {26,5}
18 | - Am241,100uC {26,20}
19 | - Am241,100uC {26,50}
20 | - Am241,100uC {50,2}
21 | - isotope: Ba133
22 | configurations:
23 | - Ba133,100uC
24 | - Ba133,100uC {10,20}
25 | - Ba133,100uC {26,10}
26 | - Ba133,100uC {26,30}
27 | - Ba133,100uC {74,20}
28 | - Ba133,100uC {50,30}
29 | - isotope: Bi207
30 | configurations:
31 | - Bi207,100uC
32 | - Bi207,100uC {10,30}
33 | - Bi207,100uC {26,10}
34 | - isotope: Cf249
35 | configurations:
36 | - Cf249,100uC
37 | - isotope: Co57
38 | configurations:
39 | - Co57,100uC
40 | - Co57,100uC {13,10}
41 | - isotope: Co60
42 | configurations:
43 | - Co60,100uC
44 | - Co60,100uC {10,10}
45 | - Co60,100uC {10,30}
46 | - Co60,100uC {26,20}
47 | - Co60,100uC {26,40}
48 | - Co60,100uC {82,30}
49 | - Co60,100uC {82,60}
50 | - isotope: Cosmic
51 | configurations:
52 | - Cosmic
53 | - isotope: Cs137
54 | configurations:
55 | - Cs137,100uC
56 | - Cs137,100uC {6,2}
57 | - Cs137,100uC {26,10}
58 | - Cs137,100uC {82,10}
59 | - Cs137,100uC {13,240;26,1}
60 | - Cs137InPine
61 | - Cs137InLead
62 | - Cs137InGradedShield1
63 | - Cs137InGradedShield2
64 | - isotope: Eu152
65 | configurations:
66 | - Eu152,100uC
67 | - Eu152,100uC {10,10}
68 | - Eu152,100uC {10,30}
69 | - Eu152,100uC {30,50}
70 | - Eu152,100uC {74,20}
71 | - isotope: Eu154
72 | configurations:
73 | - Eu154,100uC
74 | - Eu154,100uC {10,10}
75 | - Eu154,100uC {74,20}
76 | - isotope: F18
77 | configurations:
78 | - F18,100uC
79 | - F18,100uC {10,20}
80 | - F18,100uC {10,50}
81 | - F18,100uC {26,30}
82 | - isotope: Ga67
83 | configurations:
84 | - Ga67,100uC
85 | - Ga67,100uC {6,10}
86 | - Ga67,100uC {6,20}
87 | - Ga67,100uC {10,30}
88 | - Ga67,100uC {50,16}
89 | - Ga67,100uC {82,20}
90 | - isotope: Ho166m
91 | configurations:
92 | - Ho166m,100uC
93 | - Ho166m,100uC {10,20}
94 | - Ho166m,100uC {26,20}
95 | - Ho166m,100uC {74,30}
96 | - isotope: I123
97 | configurations:
98 | - I123,100uC
99 | - I123,100uC {10,30}
100 | - isotope: I131
101 | configurations:
102 | - I131,100uC
103 | - I131,100uC {10,10}
104 | - I131,100uC {10,30}
105 | - I131,100uC {16,50}
106 | - I131,100uC {20,20}
107 | - I131,100uC {82,10}
108 | - I131,100uC {10,20;50,5}
109 | - isotope: In111
110 | configurations:
111 | - In111,100uC
112 | - In111,100uC {10,20}
113 | - In111,100uC {50,20}
114 | - isotope: Ir192
115 | configurations:
116 | - Ir192,100uC
117 | - Ir192,100uC {10,20}
118 | - Ir192,100uC {26,40}
119 | - Ir192,100uC {26,100}
120 | - Ir192,100uC {82,30}
121 | - Ir192,100uC {82,160}
122 | - Ir192Shielded1
123 | - Ir192Shielded2
124 | - Ir192Shielded3
125 | - isotope: K40
126 | configurations:
127 | - PotassiumInSoil
128 | - K40,100uC
129 | - K40inCargo1
130 | - K40inCargo2
131 | - K40inCargo3
132 | - K40inCargo4
133 | - isotope: Mo99
134 | configurations:
135 | - Mo99,100uC
136 | - Mo99,100uC {26,20}
137 | - Mo99,100uC {50,40}
138 | - isotope: Na22
139 | configurations:
140 | - Na22,100uC
141 | - Na22,100uC {10,10}
142 | - Na22,100uC {10,30}
143 | - Na22,100uC {74,20}
144 | - isotope: Ra226
145 | configurations:
146 | - UraniumInSoil
147 | - Ra226,100uC
148 | - Ra226inCargo1
149 | - Ra226inCargo2
150 | - Ra226inCargo3
151 | - Ra226inSilica1
152 | - Ra226inSilica2
153 | - Ra226,100uC {60,100}
154 | - isotope: Sr90
155 | configurations:
156 | - Sr90InPoly1,1C
157 | - Sr90InPoly10,1C
158 | - Sr90InFe,1C
159 | - Sr90InSn,1C
160 | - isotope: Tc99m
161 | configurations:
162 | - Tc99m,100uC
163 | - Tc99m,100uC {7,10}
164 | - Tc99m,100uC {10,20}
165 | - Tc99m,100uC {13,30}
166 | - Tc99m,100uC {26,30}
167 | - isotope: Th232
168 | configurations:
169 | - ThoriumInSoil
170 | - Th232,100uC
171 | - Th232inCargo1
172 | - Th232inCargo2
173 | - Th232inCargo3
174 | - Th232inSilica1
175 | - Th232inSilica2
176 | - Th232,100uC {60,100}
177 | - ThPlate
178 | - ThPlate+Thxray,10uC
179 | - isotope: Tl201
180 | configurations:
181 | - Tl201,100uC
182 | - Tl201,100uC {8,40}
183 | - Tl201,100uC {10,10}
184 | - Tl201,100uC {10,30}
185 | - Tl201,100uC {26,10}
186 | - isotope: Y88
187 | configurations:
188 | - Y88,100uC
189 | - Y88,100uC {10,50}
190 | - Y88,100uC {26,30}
191 | - Y88,100uC {80,50}
192 | ...
--------------------------------------------------------------------------------
/examples/data/conversion/aipt_to_ss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to convert AIPT data files to `SampleSet`s
5 | and save them as HDF5 files.
6 | """
7 | from riid.data.converters import convert_directory
8 | from riid.data.converters.aipt import convert_and_save
9 |
10 |
11 | if __name__ == "__main__":
12 | # Change the following to a valid path on your computer
13 | DIRECTORY_WITH_FILES_TO_CONVERT = "./data"
14 |
15 | convert_directory(
16 | DIRECTORY_WITH_FILES_TO_CONVERT,
17 | convert_and_save,
18 | file_ext="open",
19 | output_dir=DIRECTORY_WITH_FILES_TO_CONVERT + "/converted",
20 | )
21 |
--------------------------------------------------------------------------------
/examples/data/conversion/pcf_to_ss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to convert PCFs to `SampleSet`s
5 | and save them as HDF5 files.
6 |
7 | This example shows how the utilities originally built for converting Topcoder and AIPT data
8 | can be easily repurposed to bulk convert any type of file.
9 | PyRIID just so happens to already have PCF reading available (via `read_pcf()`),
10 | but a custom file format is easily handled by implementing your own `convert_and_save()` function.
11 | """
12 | import os
13 | from pathlib import Path
14 |
15 | from riid import SAMPLESET_HDF_FILE_EXTENSION, read_pcf
16 | from riid.data.converters import (_validate_and_create_output_dir,
17 | convert_directory)
18 |
19 |
20 | def convert_and_save(input_file_path: str, output_dir: str = None,
21 | skip_existing: bool = True, **kwargs):
22 | input_path = Path(input_file_path)
23 | if not output_dir:
24 | output_dir = input_path.parent
25 | _validate_and_create_output_dir(output_dir)
26 | output_file_path = os.path.join(output_dir, input_path.stem + SAMPLESET_HDF_FILE_EXTENSION)
27 | if skip_existing and os.path.exists(output_file_path):
28 | return
29 |
30 |     ss = read_pcf(input_file_path)
31 |     ss.to_hdf(output_file_path)
32 |
33 |
34 | if __name__ == "__main__":
35 |     # Change the following to a valid path on your computer
36 |     DIRECTORY_WITH_FILES_TO_CONVERT = "./data"
37 |
38 |     convert_directory(
39 |         DIRECTORY_WITH_FILES_TO_CONVERT,
40 |         convert_and_save,
41 |         file_ext="pcf",
42 |         output_dir=DIRECTORY_WITH_FILES_TO_CONVERT + "/converted",
43 |     )
44 |
--------------------------------------------------------------------------------
/examples/data/conversion/topcoder_to_ss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to convert TopCoder data files to `SampleSet`s
5 | and save them as HDF5 files.
6 | """
7 | from riid.data.converters import convert_directory
8 | from riid.data.converters.topcoder import convert_and_save
9 |
10 |
11 | if __name__ == "__main__":
12 | # Change the following to a valid path on your computer
13 | DIRECTORY_WITH_FILES_TO_CONVERT = "./data"
14 |
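    # Extra keyword arguments (e.g., sample_interval and pm_chunksize below)
    # are forwarded to the converter function for each file.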
15 | convert_directory(
16 | DIRECTORY_WITH_FILES_TO_CONVERT,
17 | convert_and_save,
18 | file_ext="csv",
19 | output_dir=DIRECTORY_WITH_FILES_TO_CONVERT + "/converted",
20 | sample_interval=1.0,
21 | pm_chunksize=50,
22 | )
23 |
--------------------------------------------------------------------------------
/examples/data/difficulty_score.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to compute the difficulty of a given SampleSet."""
5 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
6 |
7 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
8 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\
9 | .generate(1)
10 |
11 | static_synth = StaticSynthesizer(
12 | samples_per_seed=500,
13 | snr_function="uniform",
14 | )
15 | easy_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
16 |
17 | static_synth.snr_function = "log10"
18 | medium_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
19 |
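# Shrink the sampled SNR range to very low values so spectra are dominated by
# background; this should yield the highest difficulty score of the three.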
20 | static_synth.snr_function_args = (.00001, .1)
21 | hard_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
22 |
23 | easy_score = easy_ss.difficulty_score
24 | print(f"Difficulty score for Uniform: {easy_score:.5f}")
25 | medium_score = medium_ss.difficulty_score
26 | print(f"Difficulty score for Log10: {medium_score:.5f}")
27 | hard_score = hard_ss.difficulty_score
28 | print(f"Difficulty score for Log10 Low Signal: {hard_score:.5f}")
29 |
--------------------------------------------------------------------------------
/examples/data/preprocessing/energy_calibration.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates energy calibration
5 | configuration for a SampleSet."""
6 | import sys
7 |
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 |
11 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
12 |
13 | SYNTHETIC_DATA_CONFIG = {
14 | "samples_per_seed": 10,
15 | "bg_cps": 10,
16 | "snr_function": "uniform",
17 | "snr_function_args": (1, 100),
18 | "live_time_function": "uniform",
19 | "live_time_function_args": (0.25, 10),
20 | "apply_poisson_noise": True,
21 | "return_fg": True,
22 | "return_gross": True
23 | }
24 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
25 |
26 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\
27 | .generate(1)
28 |
29 | fg_ss, ss = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\
30 | .generate(fg_seeds_ss, mixed_bg_seed_ss)
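# The ecal_* values below are polynomial energy-calibration coefficients that
# map channel position to energy; with only the first-order term set (3100),
# the spectrum spans roughly 0 to 3100 keV, which is why the 661 keV Cs137
# line can be marked on the energy-axis plot further down.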
31 | ss.ecal_low_e = 0
32 | ss.ecal_order_0 = 0
33 | ss.ecal_order_1 = 3100
34 | ss.ecal_order_2 = 0
35 | ss.ecal_order_3 = 0
36 |
37 | index = ss[ss.get_labels() == "Cs137"]
38 | index = np.sum(index.spectra, axis=0).values.astype(int).argmax()
39 | plt.subplot(2, 1, 1)
40 | plt.step(
41 | np.arange(ss.n_channels),
42 | ss.spectra.values[index, :],
43 | where="mid"
44 | )
45 | plt.yscale("log")
46 | plt.xlim(0, ss.n_channels)
47 | plt.title(ss.get_labels()[index])
48 | plt.xlabel("Channels")
49 |
50 | energy_bins = np.linspace(0, 3100, 512)
51 | channel_energies = ss.get_channel_energies(sample_index=0,
52 | fractional_energy_bins=energy_bins)
53 |
54 | plt.subplot(2, 1, 2)
55 |
56 | plt.step(channel_energies, ss.spectra.values[index, :], where="mid")
57 | plt.fill_between(
58 | energy_bins,
59 | 0,
60 | ss.spectra.values[index, :],
61 | alpha=0.3,
62 | step="mid"
63 | )
64 |
65 | plt.yscale("log")
66 | plt.title(ss.get_labels()[index])
67 | plt.xlabel("Energy (keV)")
68 | plt.vlines(661, 1e-6, 1e8)
69 | plt.ylim(bottom=0.8, top=ss.spectra.values[index, :].max()*1.5)
70 | plt.xlim(0, energy_bins[-1])
71 | plt.subplots_adjust(hspace=1)
72 | if len(sys.argv) == 1:
73 | plt.show()
74 |
--------------------------------------------------------------------------------
/examples/data/synthesis/mix_seeds.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to randomly mix seeds together."""
5 | import numpy as np
6 |
7 | from riid import SeedMixer, get_dummy_seeds
8 |
9 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
10 |
11 | rng = np.random.default_rng(3)
12 | mixed_fg_seeds_ss = SeedMixer(fg_seeds_ss, mixture_size=2, rng=rng)\
13 | .generate(n_samples=10)
14 | mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng)\
15 | .generate(n_samples=10)
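# Each mixed seed is a weighted combination of mixture_size distinct seeds;
# passing a seeded rng makes the generated mixtures reproducible.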
16 |
17 | print(mixed_fg_seeds_ss)
18 | print(mixed_bg_seeds_ss)
19 |
--------------------------------------------------------------------------------
/examples/data/synthesis/synthesize_passbys.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """Example of generating synthetic passby gamma spectra from seeds."""
5 | import sys
6 |
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 |
10 | from riid import PassbySynthesizer, get_dummy_seeds
11 |
12 | if len(sys.argv) == 2:
13 | import matplotlib
14 | matplotlib.use("Agg")
15 |
16 |
17 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
18 | pbs = PassbySynthesizer(
19 | sample_interval=0.5,
20 | fwhm_function_args=(5, 5),
21 | snr_function_args=(100, 100),
22 | dwell_time_function_args=(5, 5),
23 | events_per_seed=1,
24 | return_fg=False,
25 | return_gross=True,
26 | )
27 |
28 | events = pbs.generate(fg_seeds_ss, bg_seeds_ss)
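# Each pass-by event is a (foreground, gross) pair; keep only the gross
# SampleSets and combine them into a single SampleSet below.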
29 | _, gross_passbys = list(zip(*events))
30 | passby_ss = gross_passbys[0]
31 | passby_ss.concat(gross_passbys[1:])
32 |
33 | passby_ss.sources.drop(bg_seeds_ss.sources.columns, axis=1, inplace=True)
34 | passby_ss.normalize_sources()
35 | passby_ss.normalize(p=2)
36 |
37 | plt.imshow(passby_ss.spectra.values, aspect="auto")
38 | plt.ylabel("Time")
39 | plt.xlabel("Channel")
40 | plt.show()
41 |
42 | offset = 0
43 | breaks = []
44 | labels = passby_ss.get_labels()
45 | distinct_labels = set(labels)
46 | # Iterate over set of unique labels
47 | for iso in distinct_labels:
48 | sub = passby_ss[labels == iso]
49 | plt.plot(
50 | pbs.sample_interval * (np.arange(sub.n_samples) + offset),
51 | sub.info.snr,
52 | label=f"{iso}",
53 | lw=1
54 | )
55 | offset += sub.n_samples
56 | breaks.append(offset*pbs.sample_interval)
57 |
58 | plt.vlines(breaks, 0, 1e4, linestyle="--", color="grey")
59 |
60 | plt.yscale("log")
61 | plt.ylim([.1, 1e4])
62 | plt.xlim(0, breaks[-1])
63 | plt.ylabel("SNR")
64 | plt.xlabel("Time")
65 | plt.legend(frameon=False)
66 | plt.show()
67 |
--------------------------------------------------------------------------------
/examples/data/synthesis/synthesize_seeds_advanced.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to generate synthetic seeds from GADRAS using PyRIID's
5 | configuration expansion features."""
6 | import yaml
7 |
8 | from riid import SeedSynthesizer
9 |
10 | seed_synth_config = """
11 | ---
12 | gamma_detector:
13 | name: Generic\\NaI\\3x3\\Front\\MidScat
14 | parameters:
15 | distance_cm:
16 | - 10
17 | - 100
18 | - 1000
19 | height_cm: 100
20 | dead_time_per_pulse: 5
21 | latitude_deg: 35.0
22 | longitude_deg: 253.4
23 | elevation_m: 1620
24 | sources:
25 | - isotope: Cs137
26 | configurations:
27 | - Cs137,100uCi
28 | - name: Cs137
29 | activity:
30 | - 1
31 | - 0.5
32 | activity_units: Ci
33 | shielding_atomic_number:
34 | min: 10
35 | max: 40.0
36 | dist: uniform
37 | num_samples: 5
38 | shielding_aerial_density:
39 | mean: 120
40 | std: 2
41 | num_samples: 5
42 | - isotope: Cosmic
43 | configurations:
44 | - Cosmic
45 | - isotope: K40
46 | configurations:
47 | - PotassiumInSoil
48 | - isotope: Ra226
49 | configurations:
50 | - UraniumInSoil
51 | - isotope: Th232
52 | configurations:
53 | - ThoriumInSoil
54 | ...
55 | """
56 | seed_synth_config = yaml.safe_load(seed_synth_config)
57 |
58 | try:
59 | seeds_ss = SeedSynthesizer().generate(
60 | seed_synth_config,
61 | verbose=True
62 | )
63 | print(seeds_ss)
64 |
65 | # At this point, you could save out the seeds via:
66 | seeds_ss.to_hdf("seeds.h5")
67 |
68 | # or start separating your backgrounds from foreground for use with the StaticSynthesizer
69 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg()
70 |
71 | print(fg_seeds_ss)
72 | print(bg_seeds_ss)
73 |
74 | fg_seeds_ss.to_hdf("./fg_seeds.h5")
75 | bg_seeds_ss.to_hdf("./bg_seeds.h5")
76 | except FileNotFoundError:
77 | pass # Happens when not on Windows
78 |
--------------------------------------------------------------------------------
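
The "configuration expansion" in this example multiplies out the parameter lists in the config. Assuming the synthesizer crosses every value list (the actual semantics are defined by PyRIID's config schema, not by this sketch), the Cs137 entry above expands roughly as follows; the counting below is a hypothetical illustration:

```python
# Cs137's second configuration: 2 activities x 5 atomic-number samples
# x 5 areal-density samples, plus the 1 literal "Cs137,100uCi" string.
n_cs137_configs = 1 + 2 * 5 * 5
# Each expanded source configuration is then injected at each detector
# distance (10, 100, and 1000 cm):
n_distances = 3
print(n_cs137_configs, "Cs137 configs x", n_distances, "distances =",
      n_cs137_configs * n_distances, "injects")
```
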
/examples/data/synthesis/synthesize_seeds_basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to generate synthetic seeds from GADRAS."""
5 | import yaml
6 |
7 | from riid import SeedSynthesizer
8 |
9 | seed_synth_config = """
10 | ---
11 | gamma_detector:
12 | name: Generic\\NaI\\3x3\\Front\\MidScat
13 | parameters:
14 | distance_cm: 100
15 | height_cm: 10
16 | dead_time_per_pulse: 5
17 | latitude_deg: 35.0
18 | longitude_deg: 253.4
19 | elevation_m: 1620
20 | sources:
21 | - isotope: Am241
22 | configurations:
23 |       - Am241,100uCi
24 | - isotope: Ba133
25 | configurations:
26 |       - Ba133,100uCi
27 | - isotope: Cs137
28 | configurations:
29 |       - Cs137,100uCi
30 | - isotope: Cosmic
31 | configurations:
32 | - Cosmic
33 | - isotope: K40
34 | configurations:
35 | - PotassiumInSoil
36 | - isotope: Ra226
37 | configurations:
38 | - UraniumInSoil
39 | - isotope: Th232
40 | configurations:
41 | - ThoriumInSoil
42 | ...
43 | """
44 | seed_synth_config = yaml.safe_load(seed_synth_config)
45 |
46 | try:
47 | seeds_ss = SeedSynthesizer().generate(
48 | seed_synth_config,
49 | verbose=True
50 | )
51 | print(seeds_ss)
52 |
53 | # At this point, you could save out the seeds via:
54 | seeds_ss.to_hdf("seeds.h5")
55 |
56 | # or start separating your backgrounds from foreground for use with the StaticSynthesizer
57 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg()
58 |
59 | print(fg_seeds_ss)
60 | print(bg_seeds_ss)
61 |
62 | fg_seeds_ss.to_hdf("./fg_seeds.h5")
63 | bg_seeds_ss.to_hdf("./bg_seeds.h5")
64 | except FileNotFoundError:
65 | pass # Happens when not on Windows
66 |
--------------------------------------------------------------------------------
/examples/data/synthesis/synthesize_seeds_custom.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates additional ways to generate synthetic seeds from GADRAS:
5 | - Example 1: inject everything in a folder ending in .gam
6 | - Example 2: build and inject point sources comprised of multiple radioisotopes
7 | """
8 | from pathlib import Path
9 |
10 | import numpy as np
11 | import pandas as pd
12 | import yaml
13 |
14 | from riid import SeedSynthesizer
15 | from riid.gadras.api import GADRAS_INSTALL_PATH
16 |
17 |
18 | def convert_df_row_to_inject_string(row):
19 | """Converts a row of the DataFrame to a proper GADRAS inject string"""
20 | isotopes = row.index.values
21 | activities = row.values
22 | isotopes_and_activities = [
23 | f"{iso},{round(act, 4)}uCi" for iso, act in zip(isotopes, activities)
24 | ] # Note the "uCi" specified here. You may need to change it.
25 | inject_string = " + ".join(isotopes_and_activities)
26 | return inject_string
27 |
28 |
29 | seed_synth_config = """
30 | ---
31 | gamma_detector:
32 | name: Generic\\NaI\\3x3\\Front\\MidScat
33 | parameters:
34 | distance_cm: 100
35 | height_cm: 10
36 | dead_time_per_pulse: 5
37 | latitude_deg: 35.0
38 | longitude_deg: 253.4
39 | elevation_m: 1620
40 | sources:
41 | - isotope: U235
42 | configurations: null
43 | ...
44 | """
45 | seed_synth_config = yaml.safe_load(seed_synth_config)
46 |
47 | try:
48 | seed_synth = SeedSynthesizer()
49 |
50 | # Example 1
51 | # Change "Continuum" to your own source directory
52 | gam_dir = Path(GADRAS_INSTALL_PATH).joinpath("Source/Continuum")
53 | gam_filenames = [x.stem for x in gam_dir.glob("*.gam")]
54 | seed_synth_config["sources"][0]["configurations"] = gam_filenames
55 | seeds = seed_synth.generate(seed_synth_config)
56 | seeds.to_hdf("seeds_from_gams.h5")
57 |
58 | # Example 2
59 | # For the following DataFrame, columns are isotopes, rows are samples, and cells are activity
60 | df = pd.DataFrame(
61 | np.random.rand(25, 5), # Reminder: ensure all activity values are in the same units
62 | columns=["Am241", "Ba133", "Co60", "Cs137", "U235"],
63 | )
64 | configurations = df.apply(convert_df_row_to_inject_string, axis=1).to_list()
65 | seed_synth_config["sources"][0]["configurations"] = configurations
66 |
67 | seeds = seed_synth.generate(seed_synth_config)
68 | seeds.to_hdf("seeds_from_df.h5")
69 | except FileNotFoundError:
70 | pass # Happens when not on Windows
71 |
--------------------------------------------------------------------------------
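
For reference, the `convert_df_row_to_inject_string` helper defined in the file above turns one DataFrame row into a single multi-isotope inject string. A quick demonstration on a hand-built row (uses that helper, so run it in the same context):

```python
import pandas as pd

row = pd.Series({"Am241": 0.25, "Cs137": 1.5})
print(convert_df_row_to_inject_string(row))
# -> "Am241,0.25uCi + Cs137,1.5uCi"
```
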
/examples/data/synthesis/synthesize_static.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to generate synthetic gamma spectra from seeds."""
5 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
6 |
7 | SYNTHETIC_DATA_CONFIG = {
8 | "samples_per_seed": 10000,
9 | "bg_cps": 10,
10 | "snr_function": "uniform",
11 | "snr_function_args": (1, 100),
12 | "live_time_function": "uniform",
13 | "live_time_function_args": (0.25, 10),
14 | "apply_poisson_noise": True,
15 | "return_fg": True,
16 | "return_gross": True,
17 | }
18 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
19 |
20 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\
21 | .generate(1)
22 |
23 | static_synth = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)
24 | fg_ss, gross_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
25 | """ | |
26 | | |
27 | | |> gross samples
28 | |> source-only samples
29 | """
30 | print(fg_ss)
31 | print(gross_ss)
32 |
--------------------------------------------------------------------------------
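
A useful sanity check on the output above: `samples_per_seed` is per foreground seed, so the returned sample counts should scale with the number of foreground seeds. A hypothetical check, continuing from the example and assuming one synthesis pass per seed (which is how the parameter name reads):

```python
# Continuing from the example above (assumption: one pass per fg seed)
expected_rows = 10000 * fg_seeds_ss.n_samples  # samples_per_seed * n_fg_seeds
assert fg_ss.n_samples == expected_rows
assert gross_ss.n_samples == expected_rows
```
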
/examples/modeling/anomaly_detection.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to obtain events using an anomaly detection
5 | algorithm.
6 | """
7 | import sys
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from matplotlib import cm
12 |
13 | from riid import PassbySynthesizer, SeedMixer, get_dummy_seeds
14 | from riid.anomaly import PoissonNChannelEventDetector
15 |
16 | if len(sys.argv) == 2:
17 | import matplotlib
18 | matplotlib.use("Agg")
19 |
20 | SAMPLE_INTERVAL = 0.5
21 | BG_RATE = 300
22 | EXPECTED_BG_COUNTS = SAMPLE_INTERVAL * BG_RATE
23 | SHORT_TERM_DURATION = 1.5
24 | POST_EVENT_DURATION = 1.5
25 | N_POST_EVENT_SAMPLES = (POST_EVENT_DURATION + SAMPLE_INTERVAL) / SAMPLE_INTERVAL
26 |
27 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
28 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\
29 | .generate(1)
30 |
31 | ed = PoissonNChannelEventDetector(
32 | long_term_duration=600,
33 | short_term_duration=SHORT_TERM_DURATION,
34 | pre_event_duration=5,
35 | max_event_duration=120,
36 | post_event_duration=POST_EVENT_DURATION,
37 | tolerable_false_alarms_per_day=2,
38 | anomaly_threshold_update_interval=60,
39 | )
40 | count_rate_history = []
41 |
42 | # Fill background buffer first
43 | print("Filling background")
44 | measurement_id = 0
45 | expected_bg_measurement = mixed_bg_seed_ss.spectra.iloc[0] * EXPECTED_BG_COUNTS
46 | while ed.background_percent_complete < 100:
47 | noisy_bg_measurement = np.random.poisson(expected_bg_measurement)
48 | count_rate = noisy_bg_measurement.sum() / SAMPLE_INTERVAL
49 | count_rate_history.append(count_rate)
50 | _ = ed.add_measurement(
51 | measurement_id,
52 | noisy_bg_measurement,
53 | SAMPLE_INTERVAL
54 | )
55 | measurement_id += 1
56 |
57 | # Now let's see how many false alarms we get.
58 | # You may want to increase this duration to get better statistics.
59 | print("Checking false alarm rate")
60 | FALSE_ALARM_CHECK_DURATION = 3 * 60 * 60
61 | false_alarm_check_range = range(
62 | measurement_id,
63 | int(FALSE_ALARM_CHECK_DURATION / SAMPLE_INTERVAL) + 1
64 | )
65 | false_alarms = 0
66 | for measurement_id in false_alarm_check_range:
67 | noisy_bg_measurement = np.random.poisson(expected_bg_measurement)
68 | count_rate = noisy_bg_measurement.sum() / SAMPLE_INTERVAL
69 | count_rate_history.append(count_rate)
70 | event_result = ed.add_measurement(
71 | measurement_id,
72 | noisy_bg_measurement,
73 | SAMPLE_INTERVAL
74 | )
75 | if event_result:
76 | false_alarms += 1
77 | measurement_id += 1
78 | false_alarm_rate = 60 * 60 * false_alarms / FALSE_ALARM_CHECK_DURATION
79 | print(f"False alarm rate: {false_alarm_rate:.2f}/hour")
80 |
81 | # Now let's make a passby
82 | print("Generating pass-by")
83 | events = PassbySynthesizer(
84 | fwhm_function_args=(1,),
85 | snr_function_args=(30, 30),
86 | dwell_time_function_args=(1, 1),
87 | events_per_seed=1,
88 | sample_interval=SAMPLE_INTERVAL,
89 | bg_cps=BG_RATE,
90 | return_fg=False,
91 | return_gross=True,
92 | ).generate(fg_seeds_ss, mixed_bg_seed_ss)
93 | _, gross_events = list(zip(*events))
94 | passby_ss = gross_events[0]
95 |
96 | print("Passing by...")
97 | passby_begin_idx = measurement_id
98 | passby_end_idx = passby_ss.n_samples + measurement_id
99 | passby_range = list(range(passby_begin_idx, passby_end_idx))
100 | for i, measurement_id in enumerate(passby_range):
101 | gross_spectrum = passby_ss.spectra.iloc[i].values
102 | count_rate = gross_spectrum.sum() / SAMPLE_INTERVAL
103 | count_rate_history.append(count_rate)
104 | event_result = ed.add_measurement(
105 | measurement_id=measurement_id,
106 | measurement=gross_spectrum,
107 | duration=SAMPLE_INTERVAL,
108 | )
109 | if event_result:
110 | break
111 |
112 | # A little extra background to close out any pending event
113 | if ed.event_in_progress:
114 | while not event_result:
115 | measurement_id += 1
116 | noisy_bg_measurement = np.random.poisson(expected_bg_measurement)
117 | count_rate = noisy_bg_measurement.sum() / SAMPLE_INTERVAL
118 | count_rate_history.append(count_rate)
119 | event_result = ed.add_measurement(
120 | measurement_id,
121 | noisy_bg_measurement,
122 | SAMPLE_INTERVAL
123 | )
124 |
125 | count_rate_history = np.array(count_rate_history)
126 | if event_result:
127 | event_measurement, event_bg_measurement, event_duration, event_measurement_ids = event_result
128 | print(f"Event Duration: {event_duration}")
129 | event_begin, event_end = event_measurement_ids[0], event_measurement_ids[-1]
130 | start_idx = int(passby_begin_idx - 30 / SAMPLE_INTERVAL) # include some lead up to event
131 | y = count_rate_history[start_idx:measurement_id]
132 | x = np.array(range(start_idx, measurement_id)) * SAMPLE_INTERVAL
133 |     fig, ax = plt.subplots()
134 | ax.plot(x, y)
135 | ax.axvspan(
136 | xmin=event_begin * SAMPLE_INTERVAL,
137 | xmax=event_end * SAMPLE_INTERVAL,
138 | facecolor=cm.tab10(1),
139 | alpha=0.35,
140 | )
141 | ax.set_ylabel("Count rate (cps)")
142 | ax.set_xlabel("Time (sec)")
143 | plt.show()
144 | else:
145 | print("Pass-by did not produce an event")
146 |
--------------------------------------------------------------------------------
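
The detector above is built on Poisson counting statistics: given an expected background count `mu` in one integration interval, a measurement is anomalous when its probability under Poisson(mu) falls below a false-alarm budget. A single-channel sketch of that idea (illustrative only, not PyRIID's exact algorithm):

```python
from scipy.stats import poisson

mu = 150.0                            # expected background counts per 0.5 s sample
samples_per_day = 24 * 60 * 60 / 0.5  # one sample every 0.5 s
p_fa = 2 / samples_per_day            # budget: ~2 tolerable false alarms per day
threshold = poisson.isf(p_fa, mu)     # alarm when observed counts exceed this
print(threshold)                      # ~200 counts for these numbers
```
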
/examples/modeling/arad.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to use the PyRIID implementations of ARAD.
5 | """
6 | import numpy as np
7 | import pandas as pd
8 |
9 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
10 | from riid.models import ARADv1, ARADv2
11 |
12 | # Config
13 | rng = np.random.default_rng(42)
14 | OOD_QUANTILE = 0.99
15 | VERBOSE = False
16 | # Some of the following parameters are set low because this example runs on GitHub Actions and
17 | # we don't want it taking a bunch of time.
18 | # When running this locally, change the values per their corresponding comment, otherwise
19 | # the results likely will not be meaningful.
20 | EPOCHS = 5 # Change this to 20+
21 | N_MIXTURES = 50  # Change this to 1000+
22 | TRAIN_SAMPLES_PER_SEED = 5 # Change this to 20+
23 | TEST_SAMPLES_PER_SEED = 5
24 |
25 | # Generate training data
26 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=128, rng=rng).split_fg_and_bg()
27 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng).generate(N_MIXTURES)
28 | static_synth = StaticSynthesizer(
29 | samples_per_seed=TRAIN_SAMPLES_PER_SEED,
30 | snr_function_args=(0, 0),
31 | return_fg=False,
32 | return_gross=True,
33 | rng=rng,
34 | )
35 | _, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
36 | gross_train_ss.normalize()
37 |
38 | # Generate test data
39 | static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
40 | _, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
41 | test_ss.normalize()
42 |
43 | # Train the models
44 | results = {}
45 | models = [ARADv1, ARADv2]
46 | for model_class in models:
47 | arad = model_class()
48 | model_name = arad.__class__.__name__
49 |
50 | print(f"Training and testing {model_name}...")
51 | arad.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
52 | arad.predict(gross_train_ss)
53 | ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE)
54 |
55 | reconstructions = arad.predict(test_ss, verbose=VERBOSE)
56 | ood = test_ss.info.recon_error.values > ood_threshold
57 | false_positive_rate = ood.mean()
58 | mean_recon_error = test_ss.info.recon_error.values.mean()
59 |
60 | results[model_name] = {
61 | "ood_threshold": f"{ood_threshold:.4f}",
62 | "mean_recon_error": mean_recon_error,
63 | "false_positive_rate": false_positive_rate,
64 | }
65 |
66 | print(f"Target False Positive Rate: {1-OOD_QUANTILE:.4f}")
67 | print(pd.DataFrame.from_dict(results))
68 |
--------------------------------------------------------------------------------
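
Why the printed target rate is `1 - OOD_QUANTILE`: the threshold is the 0.99 quantile of the training reconstruction errors, so roughly 1% of in-distribution data should land above it. A quick numeric check of that relationship with stand-in errors:

```python
import numpy as np

rng = np.random.default_rng(0)
recon_error = rng.gamma(shape=2.0, scale=1.0, size=100_000)  # stand-in errors
threshold = np.quantile(recon_error, 0.99)
print((recon_error > threshold).mean())  # ~0.01 by construction
```
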
/examples/modeling/arad_latent_prediction.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to train a regressor or classifier branch
5 | from an ARAD latent space.
6 | """
7 | import numpy as np
8 | from keras.api.metrics import Accuracy, CategoricalCrossentropy
9 | from sklearn.metrics import f1_score, mean_squared_error
10 |
11 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
12 | from riid.models import ARADLatentPredictor, ARADv2
13 |
14 | # Config
15 | rng = np.random.default_rng(42)
16 | VERBOSE = False
17 | # Some of the following parameters are set low because this example runs on GitHub Actions and
18 | # we don't want it taking a bunch of time.
19 | # When running this locally, change the values per their corresponding comment, otherwise
20 | # the results likely will not be meaningful.
21 | EPOCHS = 5 # Change this to 20+
22 | N_MIXTURES = 50 # Change this to 1000+
23 | TRAIN_SAMPLES_PER_SEED = 5 # Change this to 20+
24 | TEST_SAMPLES_PER_SEED = 5
25 |
26 | # Generate training data
27 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=128, rng=rng).split_fg_and_bg()
28 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng).generate(N_MIXTURES)
29 | static_synth = StaticSynthesizer(
30 | samples_per_seed=TRAIN_SAMPLES_PER_SEED,
31 | snr_function_args=(0, 0),
32 | return_fg=False,
33 | return_gross=True,
34 | rng=rng,
35 | )
36 | _, gross_train_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
37 | gross_train_ss.normalize()
38 |
39 | # Generate test data
40 | static_synth.samples_per_seed = TEST_SAMPLES_PER_SEED
41 | _, test_ss = static_synth.generate(fg_seeds_ss[0], mixed_bg_seed_ss)
42 | test_ss.normalize()
43 |
44 | # Train ARAD model
45 | print("Training ARAD")
46 | arad_v2 = ARADv2()
47 | arad_v2.fit(gross_train_ss, epochs=EPOCHS, verbose=VERBOSE)
48 |
49 | # Train regressor to predict live time
50 | print("Training Regressor")
51 | arad_regressor = ARADLatentPredictor()
52 | _ = arad_regressor.fit(
53 | arad_v2.model,
54 | gross_train_ss,
55 | target_info_columns=["live_time"],
56 | epochs=10,
57 | batch_size=5,
58 | verbose=VERBOSE,
59 | )
60 | regression_predictions = arad_regressor.predict(test_ss)
61 | regression_score = mean_squared_error(test_ss.info.live_time, regression_predictions)
62 | print("Regressor MSE: {:.3f}".format(regression_score))
63 |
64 | # Train classifier to predict isotope
65 | print("Training Classifier")
66 | arad_classifier = ARADLatentPredictor(
67 | loss="categorical_crossentropy",
68 | metrics=[Accuracy(), CategoricalCrossentropy()],
69 | final_activation="softmax"
70 | )
71 | arad_classifier.fit(
72 | arad_v2.model,
73 | gross_train_ss,
74 | target_level="Isotope",
75 | epochs=10,
76 | batch_size=5,
77 | verbose=VERBOSE,
78 | )
79 | arad_classifier.predict(test_ss)
80 | classification_score = f1_score(test_ss.get_labels(), test_ss.get_predictions(), average="micro")
81 | print("Classification F1 Score: {:.3f}".format(classification_score))
82 |
--------------------------------------------------------------------------------
/examples/modeling/classifier_comparison.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates a comparison between Poisson Bayes and MLP classifiers."""
5 | import sys
6 |
7 | import matplotlib.pyplot as plt
8 | from sklearn.metrics import f1_score
9 |
10 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
11 | from riid.metrics import precision_recall_curve
12 | from riid.models import MLPClassifier, PoissonBayesClassifier
13 | from riid.visualize import plot_precision_recall
14 |
15 | if len(sys.argv) == 2:
16 | import matplotlib
17 | matplotlib.use("Agg")
18 |
19 | # Generate some training data
20 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds(n_channels=64).split_fg_and_bg()
21 |
22 |
23 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(10)
24 |
25 | static_synth = StaticSynthesizer(
26 | samples_per_seed=100,
27 | live_time_function_args=(1, 10),
28 | snr_function="log10",
29 | snr_function_args=(1, 20),
30 | return_fg=True,
31 | return_gross=True,
32 | )
33 | train_fg_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss, verbose=False)
34 | train_fg_ss.normalize()
35 |
36 | model_nn = MLPClassifier()
37 | model_nn.fit(train_fg_ss, epochs=10, patience=5)
38 |
39 | # Create PB model
40 | model_pb = PoissonBayesClassifier()
41 | model_pb.fit(fg_seeds_ss)
42 |
43 | # Generate some test data
44 | static_synth.samples_per_seed = 50
45 | test_fg_ss, test_gross_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss,
46 | verbose=False)
47 | test_bg_ss = test_gross_ss - test_fg_ss
48 | test_fg_ss.normalize()
49 | test_gross_ss.sources.drop(bg_seeds_ss.sources.columns, axis=1, inplace=True)
50 | test_gross_ss.normalize_sources()
51 |
52 | # Plot
53 | fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
54 | results = {}
55 | for model, tag, ax in zip([model_nn, model_pb], ["NN", "PB"], axs):
56 | if tag == "NN":
57 | labels = test_fg_ss.get_labels()
58 | model.predict(test_fg_ss)
59 | predictions = test_fg_ss.get_predictions()
60 | precision, recall, thresholds = precision_recall_curve(test_fg_ss)
61 | elif tag == "PB":
62 | labels = test_gross_ss.get_labels()
63 | model.predict(test_gross_ss, test_bg_ss)
64 | predictions = test_gross_ss.get_predictions()
65 | precision, recall, thresholds = precision_recall_curve(test_gross_ss)
66 | else:
67 | raise ValueError()
68 |
69 | score = f1_score(labels, predictions, average="weighted")
70 | print(f"{tag} F1-score: {score:.3f}")
71 |
72 | plot_precision_recall(
73 | precision=precision,
74 | recall=recall,
75 | title=f"{tag}\nPrecision VS Recall",
76 | fig_ax=(fig, ax),
77 | show=False,
78 | )
79 |
80 | plt.show()
81 |
--------------------------------------------------------------------------------
/examples/modeling/custom_loss_and_metric.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates custom loss functions and metrics."""
5 | import tensorflow as tf
6 |
7 | from riid.losses import negative_log_f1
8 | from riid.metrics import multi_f1
9 |
10 | y_true = tf.constant([.524, .175, .1, .1, 0, .1])
11 | y_pred = tf.constant([.2, .2, .2, .1, .2, .1])
12 | f1 = multi_f1(y_true, y_pred)
13 | loss = negative_log_f1(y_true, y_pred)
14 | print(f"F1 Score: {f1:.3f}")
15 | print(f"Loss: {loss:.3f}")
16 |
--------------------------------------------------------------------------------
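
For intuition about the values printed above: a common way to define a differentiable ("soft") F1 between two probability vectors treats the elementwise minimum as soft true positives, and the loss is the negative log of that score. The sketch below is one such formulation under that assumption, not necessarily the exact definitions in `riid.metrics`/`riid.losses`:

```python
import numpy as np


def soft_f1(y_true, y_pred):
    tp = np.minimum(y_true, y_pred).sum()  # soft true positives
    return 2 * tp / (y_true.sum() + y_pred.sum())


y_true = np.array([.524, .175, .1, .1, 0, .1])
y_pred = np.array([.2, .2, .2, .1, .2, .1])
f1 = soft_f1(y_true, y_pred)
print(f1, -np.log(f1))  # ~0.675 and ~0.393 for these vectors
```
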
/examples/modeling/label_proportion_estimation.py:
--------------------------------------------------------------------------------
1 | """This example demonstrates how to train the Label Proportion
2 | Estimator with a semi-supervised loss function."""
3 | import os
4 |
5 | from sklearn.metrics import mean_absolute_error
6 |
7 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
8 | from riid.models import LabelProportionEstimator
9 |
10 | # Generate some mixture training data.
11 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
12 | mixed_bg_seeds_ss = SeedMixer(
13 | bg_seeds_ss,
14 | mixture_size=3,
15 | dirichlet_alpha=2
16 | ).generate(100)
17 |
18 | static_syn = StaticSynthesizer(
19 | samples_per_seed=100,
20 | bg_cps=300.0,
21 | live_time_function_args=(60, 600),
22 | snr_function_args=(0, 0),
23 | return_fg=False,
24 | return_gross=True,
25 | )
26 |
27 | _, bg_ss = static_syn.generate(fg_seeds_ss[0], mixed_bg_seeds_ss)
28 | bg_ss.drop_sources_columns_with_all_zeros()
29 | bg_ss.normalize()
30 |
31 | # Create the model
32 | model = LabelProportionEstimator(
33 | hidden_layers=(64,),
34 | # The supervised loss can either be "sparsemax"
35 | # or "categorical_crossentropy".
36 | sup_loss="categorical_crossentropy",
37 |     # The unsupervised loss can be "poisson_nll", "normal_nll",
38 | # "sse", or "weighted_sse".
39 | unsup_loss="poisson_nll",
40 | # This controls the tradeoff between the sup
41 |     # and unsup losses.
42 | beta=1e-4,
43 | optimizer="RMSprop",
44 | learning_rate=1e-2,
45 | hidden_layer_activation="relu",
46 | dropout=0.05,
47 | )
48 |
49 | # Train the model.
50 | model.fit(
51 | bg_seeds_ss,
52 | bg_ss,
53 | batch_size=10,
54 | epochs=2,
55 | validation_split=0.2,
56 | bg_cps=300
57 | )
58 |
59 | # Generate some test data.
60 | static_syn.samples_per_seed = 50
61 | _, test_bg_ss = static_syn.generate(fg_seeds_ss[0], mixed_bg_seeds_ss)
62 | test_bg_ss.normalize(p=1)
63 | test_bg_ss.drop_sources_columns_with_all_zeros()
64 |
65 | model.predict(test_bg_ss)
66 |
67 | test_mae = mean_absolute_error(
68 |     test_bg_ss.sources.values,
69 |     test_bg_ss.prediction_probas.values
70 | )
71 | print(f"Mean Test MAE: {test_mae:.3f}")
72 |
73 | # Save model
74 | model_path = "./model.json"
75 | model.save(model_path, overwrite=True)
76 |
77 | loaded_model = LabelProportionEstimator()
78 | loaded_model.load(model_path)
79 |
80 | loaded_model.predict(test_bg_ss)
81 | test_mae = mean_absolute_error(
82 | test_bg_ss.sources.values,
83 | test_bg_ss.prediction_probas.values
84 | )
85 |
86 | print(f"Mean Test MAE: {test_mae:.3f}")
87 |
88 | # Clean up model file - remove this if you want to keep the model
89 | os.remove(model_path)
90 |
--------------------------------------------------------------------------------
/examples/modeling/neural_network_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to use the MLP classifier."""
5 | import numpy as np
6 | from sklearn.metrics import f1_score
7 |
8 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
9 | from riid.models import MLPClassifier
10 |
11 | # Generate some training data
12 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
13 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1)
14 |
15 | static_synth = StaticSynthesizer(
16 | samples_per_seed=100,
17 | snr_function="log10",
18 | return_fg=False,
19 | return_gross=True,
20 | )
21 | _, train_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
22 | train_ss.normalize()
23 |
24 | model = MLPClassifier()
25 | model.fit(train_ss, epochs=10, patience=5)
26 |
27 | # Generate some test data
28 | static_synth.samples_per_seed = 50
29 | _, test_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss)
30 | test_ss.normalize()
31 |
32 | # Predict
33 | model.predict(test_ss)
34 |
35 | score = f1_score(test_ss.get_labels(), test_ss.get_predictions(), average="micro")
36 | print("F1 Score: {:.3f}".format(score))
37 |
38 | # Get confidences
39 | confidences = test_ss.get_confidences(
40 | fg_seeds_ss,
41 | bg_seed_ss=mixed_bg_seed_ss,
42 | bg_cps=300
43 | )
44 | print(f"Avg Confidence: {np.mean(confidences):.3f}")
45 |
--------------------------------------------------------------------------------
/examples/run_examples.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """Runs all Python files in the subfolders and reports whether they are successful."""
5 | import os
6 | import subprocess
7 | import sys
8 | from pathlib import Path
9 |
10 | import pandas as pd
11 | from tabulate import tabulate
12 |
13 | DIRS_TO_CHECK = ["data", "modeling", "visualization"]
14 | FILENAME_KEY = "File"
15 | RESULT_KEY = "Result"
16 | SUCCESS_STR = "Success"
17 | FAILURE_STR = "Fail"
18 |
19 | original_wdir = os.getcwd()
20 | example_dir = Path(__file__).parent
21 | os.chdir(example_dir)
22 |
23 | files_to_run = []
24 | for d in DIRS_TO_CHECK:
25 | file_paths = list(Path(d).rglob("*.py"))
26 | files_to_run.extend(file_paths)
27 | files_to_run = sorted(files_to_run)
28 |
29 | results = {}
30 | n_tests = len(files_to_run)
31 | for i, f in enumerate(files_to_run, start=1):
32 | print(f"Running example {i}/{n_tests}")
33 | return_code = 0
34 | output = None
35 | try:
36 | output = subprocess.check_output(f"python {f} hide",
37 | stderr=subprocess.STDOUT,
38 | shell=True)
39 |
40 | except subprocess.CalledProcessError as e:
41 |         if b"Error" in e.output:
42 | print(e.output.decode())
43 | return_code = e.returncode
44 |
45 | results[i] = {
46 | FILENAME_KEY: os.path.relpath(f, example_dir),
47 | RESULT_KEY: SUCCESS_STR if not return_code else FAILURE_STR
48 | }
49 | os.chdir(original_wdir)
50 |
51 | df = pd.DataFrame.from_dict(results, orient="index")
52 | tabulated_df = tabulate(df, headers="keys", tablefmt="psql")
53 | print(tabulated_df)
54 |
55 | all_succeeded = all([x[RESULT_KEY] == SUCCESS_STR for x in results.values()])
56 | sys.exit(0 if all_succeeded else 1)
57 |
--------------------------------------------------------------------------------
/examples/visualization/confusion_matrix.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to obtain confusion matrices."""
5 | import sys
6 |
7 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
8 | from riid.models import MLPClassifier
9 | from riid.visualize import confusion_matrix
10 |
11 | if len(sys.argv) == 2:
12 | import matplotlib
13 | matplotlib.use("Agg")
14 |
15 | SYNTHETIC_DATA_CONFIG = {
16 | "snr_function": "log10",
17 | "snr_function_args": (.01, 10),
18 | }
19 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
20 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\
21 | .generate(1)
22 |
23 | train_ss, _ = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\
24 | .generate(fg_seeds_ss, mixed_bg_seed_ss)
25 | train_ss.normalize()
26 |
27 | model = MLPClassifier()
28 | model.fit(train_ss, verbose=0, epochs=50)
29 |
30 | # Generate some test data
31 | SYNTHETIC_DATA_CONFIG = {
32 | "snr_function": "log10",
33 | "snr_function_args": (.01, 10),
34 | "samples_per_seed": 50,
35 | }
36 |
37 | test_ss, _ = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\
38 | .generate(fg_seeds_ss, mixed_bg_seed_ss)
39 | test_ss.normalize()
40 |
41 | # Predict and evaluate
42 | model.predict(test_ss)
43 |
44 | fig, ax = confusion_matrix(test_ss, show=True)
45 |
--------------------------------------------------------------------------------
/examples/visualization/distance_matrix.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to generate a distance matrix that
5 | compares every pair of spectra in a SampleSet.
6 | """
7 | import sys
8 |
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 |
12 | from riid import get_dummy_seeds
13 |
14 | if len(sys.argv) == 2:
15 | import matplotlib
16 | matplotlib.use("Agg")
17 |
18 | seeds_ss = get_dummy_seeds(n_channels=16)
19 | distance_df = seeds_ss.get_spectral_distance_matrix()
20 |
21 | sns.set(rc={"figure.figsize": (10, 7)})
22 | ax = sns.heatmap(
23 | distance_df,
24 | cbar_kws={"label": "Jensen-Shannon Distance"},
25 | vmin=0,
26 | vmax=1
27 | )
28 | _ = ax.set_title("Comparing All Seed Pairs Using Jensen-Shannon Distance")
29 | fig = ax.get_figure()
30 | fig.tight_layout()
31 | plt.show()
32 |
--------------------------------------------------------------------------------
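
The `vmin=0`/`vmax=1` bounds above come from the metric itself: Jensen-Shannon distance with base-2 logarithms lies in [0, 1], reaching 0 for identical spectra and 1 for non-overlapping ones. A quick check with scipy:

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

p = np.array([0.7, 0.2, 0.1])
q = np.array([0.1, 0.2, 0.7])
print(jensenshannon(p, p, base=2))  # 0.0 for identical distributions
print(jensenshannon(p, q, base=2))  # grows toward 1.0 as the overlap shrinks
```
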
/examples/visualization/plot_sampleset_compare_to.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to compare sample sets."""
5 | import sys
6 |
7 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
8 | from riid.visualize import plot_ss_comparison
9 |
10 | if len(sys.argv) == 2:
11 | import matplotlib
12 | matplotlib.use("Agg")
13 |
14 | SYNTHETIC_DATA_CONFIG = {
15 | "samples_per_seed": 100,
16 | "bg_cps": 100,
17 | "snr_function": "uniform",
18 | "snr_function_args": (1, 100),
19 | "live_time_function": "uniform",
20 | "live_time_function_args": (0.25, 10),
21 | "apply_poisson_noise": True,
22 | "return_fg": False,
23 | "return_gross": True,
24 | }
25 | fg_seeds_ss, bg_seeds_ss = get_dummy_seeds()\
26 | .split_fg_and_bg()
27 | mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\
28 | .generate(1)
29 |
30 | _, gross_ss1 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\
31 | .generate(fg_seeds_ss, mixed_bg_seed_ss)
32 | _, gross_ss2 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\
33 | .generate(fg_seeds_ss, mixed_bg_seed_ss)
34 |
35 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2,
36 | density=False)
37 | plot_ss_comparison(ss1_stats,
38 | ss2_stats,
39 | col_comparisons,
40 | "live_time",
41 | show=True)
42 |
43 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2,
44 | density=True)
45 | plot_ss_comparison(ss1_stats,
46 | ss2_stats,
47 | col_comparisons,
48 | "total_counts",
49 | show=True)
50 |
--------------------------------------------------------------------------------
/examples/visualization/plot_spectra.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This example demonstrates how to plot gamma spectra."""
5 | import sys
6 |
7 | from riid import get_dummy_seeds
8 | from riid.visualize import plot_spectra
9 |
10 | if len(sys.argv) == 2:
11 | import matplotlib
12 | matplotlib.use("Agg")
13 |
14 | seeds_ss = get_dummy_seeds()
15 |
16 | plot_spectra(seeds_ss, ylim=(None, None), in_energy=True)
17 |
--------------------------------------------------------------------------------
/pdoc/config.mako:
--------------------------------------------------------------------------------
1 | <%!
2 | # Template configuration. Copy over in your template directory
3 | # (used with `--template-dir`) and adapt as necessary.
4 | # Note, defaults are loaded from this distribution file, so your
5 | # config.mako only needs to contain values you want overridden.
6 | # You can also run pdoc with `--config KEY=VALUE` to override
7 | # individual values.
8 | html_lang = 'en'
9 |
10 | show_inherited_members = False
11 |
12 | extract_module_toc_into_sidebar = True
13 |
14 | list_class_variables_in_index = True
15 |
16 | sort_identifiers = True
17 |
18 | show_type_annotations = True
19 |
20 | # Show collapsed source code block next to each item.
21 | # Disabling this can improve rendering speed of large modules.
22 | show_source_code = True
23 |
24 | # If set, format links to objects in online source code repository
25 | # according to this template. Supported keywords for interpolation
26 | # are: commit, path, start_line, end_line.
27 | git_link_template = 'https://github.com/sandialabs/pyriid/blob/{commit}/{path}#L{start_line}-L{end_line}'
28 |
29 | # A prefix to use for every HTML hyperlink in the generated documentation.
30 | # No prefix results in all links being relative.
31 | link_prefix = ''
32 |
33 | # Enable syntax highlighting for code/source blocks by including Highlight.js
34 | syntax_highlighting = True
35 | # Set the style keyword such as 'atom-one-light' or 'github-gist'
36 | # Options: https://github.com/highlightjs/highlight.js/tree/master/src/styles
37 | # Demo: https://highlightjs.org/static/demo/
38 | hljs_style = 'github'
39 |
40 | # If set, insert Google Analytics tracking code. Value is GA
41 | # tracking id (UA-XXXXXX-Y).
42 | google_analytics = ''
43 |
44 | # If set, insert Google Custom Search search bar widget above the sidebar index.
45 | # The whitespace-separated tokens represent arbitrary extra queries (at least one
46 | # must match) passed to regular Google search. Example:
47 | #google_search_query = 'inurl:github.com/USER/PROJECT site:PROJECT.github.io site:PROJECT.website'
48 | google_search_query = ''
49 |
50 | # Enable offline search using Lunr.js. For explanation of 'fuzziness' parameter, which is
51 | # added to every query word, see: https://lunrjs.com/guides/searching.html#fuzzy-matches
52 | # If 'index_docstrings' is False, a shorter index is built, indexing only
53 | # the full object reference names.
54 | lunr_search = {'fuzziness': 1, 'index_docstrings': True}
55 |
56 | # If set, render LaTeX math syntax within \(...\) (inline equations),
57 | # or within \[...\] or $$...$$ or `.. math::` (block equations)
58 | # as nicely-formatted math formulas using MathJax.
59 | # Note: in Python docstrings, either all backslashes need to be escaped (\\)
60 | # or you need to use raw r-strings.
61 | latex_math = True
62 | %>
--------------------------------------------------------------------------------
/pdoc/requirements.txt:
--------------------------------------------------------------------------------
1 | pdoc3==0.10.0
2 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 68", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.setuptools.packages.find]
6 | include = ["riid*"]
7 | namespaces = false
8 |
9 | [tool.setuptools.package-data]
10 | "riid.gadras" = ["*.json"]
11 |
12 | [project]
13 | name = "riid"
14 | description = "Machine learning-based models and utilities for radioisotope identification"
15 | version = "2.2.0"
16 | maintainers = [
17 | {name="Tyler Morrow", email="tmorro@sandia.gov"},
18 | ]
19 | authors = [
20 | {name="Tyler Morrow"},
21 | {name="Nathan Price"},
22 | {name="Travis McGuire"},
23 | {name="Tyler Ganter"},
24 | {name="Aislinn Handley"},
25 | {name="Paul Thelen"},
26 | {name="Alan Van Omen"},
27 | {name="Leon Ross"},
28 | {name="Alyshia Bustos"},
29 | ]
30 | readme = "README.md"
31 | license = {file = "LICENSE.md"}
32 | classifiers = [
33 | 'Development Status :: 5 - Production/Stable',
34 | 'Intended Audience :: Developers',
35 | 'Intended Audience :: Education',
36 | 'Intended Audience :: Science/Research',
37 | 'License :: OSI Approved :: BSD License',
38 | 'Topic :: Scientific/Engineering',
39 | 'Topic :: Scientific/Engineering :: Mathematics',
40 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
41 | 'Topic :: Software Development',
42 | 'Topic :: Software Development :: Libraries',
43 | 'Topic :: Software Development :: Libraries :: Python Modules',
44 | 'Programming Language :: Python',
45 | 'Programming Language :: Python :: 3',
46 | 'Programming Language :: Python :: 3.10',
47 | 'Programming Language :: Python :: 3.11',
48 | 'Programming Language :: Python :: 3.12',
49 | ]
50 | keywords = ["pyriid", "riid", "machine learning", "radioisotope identification", "gamma spectrum"]
51 |
52 | requires-python = ">=3.10,<3.13"
53 | dependencies = [
54 | "jsonschema ==4.23.*", # 3.8 - 3.13
55 | "matplotlib ==3.9.*", # 3.9 - 3.12
56 | "numpy ==1.26.*", # 3.9 - 3.12, also to be limited by onnx 1.16.2
57 | "pandas ==2.2.*", # >= 3.9
58 | "pythonnet ==3.0.3; platform_system == 'Windows'", # 3.7 - 3.12
59 | "pyyaml ==6.0.*", # >= 3.6
60 | "tables ==3.10.*", # >= 3.9
61 | "scikit-learn ==1.5.*", # 3.9 - 3.12
62 | "scipy ==1.13.*", # >= 3.10
63 | "seaborn ==0.13.*", # >= 3.8
64 | "keras ==3.8.0",
65 | "tensorflow ==2.16.*", # 3.9 - 3.12
66 | "tensorflow-model-optimization ==0.8.*", # 3.7 - 3.12
67 | "onnx ==1.16.1", # 3.7 - 3.10
68 | "tf2onnx ==1.16.1", # 3.7 - 3.10
69 | "tqdm ==4.66.*", # >= 3.7
70 | "typeguard ==4.3.*", # 3.9 - 3.12
71 | ]
72 |
73 | [project.optional-dependencies]
74 | dev = [
75 | "coverage",
76 | "ipykernel",
77 | "flake8",
78 | "flake8-quotes",
79 | "tabulate",
80 | ]
81 |
82 | [project.urls]
83 | Documentation = "https://sandialabs.github.io/PyRIID"
84 | Repository = "https://github.com/sandialabs/PyRIID"
85 |
--------------------------------------------------------------------------------
/riid/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """
5 | .. include:: ../README.md
6 | """
7 | import logging
8 | import os
9 | import sys
10 | from importlib.metadata import version
11 |
12 | from riid.data.sampleset import (SampleSet, SpectraState, SpectraType,
13 | read_hdf, read_json, read_pcf)
14 | from riid.data.synthetic.passby import PassbySynthesizer
15 | from riid.data.synthetic.seed import (SeedMixer, SeedSynthesizer,
16 | get_dummy_seeds)
17 | from riid.data.synthetic.static import StaticSynthesizer
18 |
19 | HANDLER = logging.StreamHandler(sys.stdout)
20 | logging.root.addHandler(HANDLER)
21 | logging.root.setLevel(logging.DEBUG)
22 | MPL_LOGGER = logging.getLogger("matplotlib")
23 | MPL_LOGGER.setLevel(logging.WARNING)
24 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
25 |
26 | SAMPLESET_HDF_FILE_EXTENSION = ".h5"
27 | SAMPLESET_JSON_FILE_EXTENSION = ".json"
28 | PCF_FILE_EXTENSION = ".pcf"
29 | ONNX_MODEL_FILE_EXTENSION = ".onnx"
30 | TFLITE_MODEL_FILE_EXTENSION = ".tflite"
31 | RIID = "riid"
32 |
33 | __version__ = version(RIID)
34 |
35 | __pdoc__ = {
36 | "riid.data.synthetic.seed.SeedMixer.__call__": True,
37 | "riid.data.synthetic.passby.PassbySynthesizer._generate_single_passby": True,
38 | "riid.data.sampleset.SampleSet._channels_to_energies": True,
39 | }
40 |
41 | __all__ = ["SampleSet", "SpectraState", "SpectraType",
42 | "read_hdf", "read_json", "read_pcf", "get_dummy_seeds",
43 | "PassbySynthesizer", "SeedSynthesizer", "StaticSynthesizer", "SeedMixer"]
44 |
--------------------------------------------------------------------------------
/riid/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This sub-package contains all utilities for synthesizing, reading, writing, and converting data.
5 | """
6 | import numpy as np
7 |
8 |
9 | def get_expected_spectra(seeds: np.ndarray, expected_counts: np.ndarray) -> np.ndarray:
10 | """Multiply a 1-D array of expected counts by either a 1-D array or 2-D
11 | matrix of seed spectra.
12 |
13 |     The `expected_counts` array, of length `p`, is expanded to shape `(p, 1, 1)`
14 |     so that it broadcasts against the seed spectra, where:
15 | 
16 |     - m = # of seeds
17 |     - n = # of channels
18 | 
19 | The dimension of the `expected_counts` must be 1, but the length `p` can be
20 | any positive number.
21 |
22 |     The resulting expected spectra will be of shape `(m * p, n)`:
23 |     the same number of channels `n`, with each expected count value,
24 |     of which there are `p`, multiplied through each seed spectrum,
25 |     of which there are `m`.
26 |     The broadcast produces an intermediate of shape `(p, m, n)`; concatenating
27 |     along the first axis stacks the `p` blocks, eliminating the 3rd dimension.
28 | """
29 | if expected_counts.ndim != 1:
30 | raise ValueError("Expected counts array must be 1-D.")
31 | if expected_counts.shape[0] == 0:
32 | raise ValueError("Expected counts array cannot be empty.")
33 | if seeds.ndim > 2:
34 | raise InvalidSeedError("Seeds array must be 1-D or 2-D.")
35 |
36 | expected_spectra = np.concatenate(
37 | seeds * expected_counts[:, np.newaxis, np.newaxis]
38 | )
39 |
40 | return expected_spectra
41 |
42 |
43 | class InvalidSeedError(Exception):
44 | """Seed spectra data structure is not 1- or 2-dimensional."""
45 | pass
46 |
--------------------------------------------------------------------------------
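
A small worked example of the broadcasting described in the docstring above (m=2 seeds, n=3 channels, p=2 expected count values):

```python
import numpy as np

from riid.data import get_expected_spectra

seeds = np.array([
    [0.2, 0.3, 0.5],  # seed 0
    [0.6, 0.3, 0.1],  # seed 1
])
expected_counts = np.array([100.0, 1000.0])
out = get_expected_spectra(seeds, expected_counts)
print(out.shape)  # (4, 3)
# Rows: 100*seed0, 100*seed1, 1000*seed0, 1000*seed1
```
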
/riid/data/converters/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains utilities for converting known datasets into `SampleSet`s."""
5 | import glob
6 | from pathlib import Path
7 | from typing import Callable
8 |
9 | from joblib import Parallel, delayed
10 |
11 |
12 | def _validate_and_create_output_dir(output_dir: str):
13 | output_dir_path = Path(output_dir)
14 | output_dir_path.mkdir(exist_ok=True)
15 |     if not output_dir_path.is_dir():
16 | raise ValueError("`output_dir` already exists but is not a directory.")
17 |
18 |
19 | def convert_directory(input_dir_path: str, conversion_func: Callable, file_ext: str,
20 | n_jobs: int = 8, **kwargs):
21 | """Convert and save every file in a specified directory.
22 |
23 | Conversion functions can be found in sub-modules:
24 |
25 | - AIPT: `riid.data.converters.aipt.convert_and_save()`
26 | - TopCoder: `riid.data.converters.topcoder.convert_and_save()`
27 |
28 | Due to usage of parallel processing, be sure to run this function as follows:
29 |
30 | ```python
31 | if __name__ == "__main__":
32 | convert_directory(...)
33 | ```
34 |
35 |     Tip: for max utilization, consider setting `n_jobs` to `multiprocessing.cpu_count()`.
36 |
37 | Args:
38 | input_dir_path: directory path containing the input files
39 | conversion_func: function used to convert a data file to a `SampleSet`
40 | file_ext: file extension to read in for conversion
41 | n_jobs: `joblib.Parallel` parameter to set the # of jobs
42 | kwargs: additional keyword args passed to conversion_func
43 | """
44 | input_path = Path(input_dir_path)
45 | if not input_path.exists() or not input_path.is_dir():
46 | print(f"No directory at provided input path: '{input_dir_path}'")
47 | return
48 |
49 | input_file_paths = sorted(glob.glob(f"{input_dir_path}/*.{file_ext}"))
50 |
51 | Parallel(n_jobs, verbose=10)(
52 | delayed(conversion_func)(path, **kwargs) for path in input_file_paths
53 | )
54 |
--------------------------------------------------------------------------------
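
A minimal end-to-end usage sketch of `convert_directory` with the AIPT converter, following the `if __name__ == "__main__"` guidance in the docstring above; the input directory path is a placeholder for your own data:

```python
from multiprocessing import cpu_count

from riid.data.converters import convert_directory
from riid.data.converters.aipt import convert_and_save

if __name__ == "__main__":
    # "path/to/aipt_csvs" is a placeholder, not a real path in this repo
    convert_directory(
        "path/to/aipt_csvs",
        convert_and_save,
        "csv",
        n_jobs=cpu_count(),
    )
```
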
/riid/data/converters/aipt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module provides tools for handling data related to the multi-lab Algorithm Improvement
5 | Program Team (AIPT).
6 | """
7 | import os
8 | from pathlib import Path
9 | from typing import List
10 |
11 | import pandas as pd
12 |
13 | from riid import SAMPLESET_HDF_FILE_EXTENSION, SampleSet
14 | from riid.data.converters import _validate_and_create_output_dir
15 |
16 | ELEMENT_IDS_PER_FILE = [0, 1, 2, 3]
17 | DEFAULT_ECAL = [
18 | -3.00000,
19 | 3010.00000,
20 | 150.00000,
21 | 0.00000,
22 | 0.00000,
23 | ]
24 |
25 |
26 | def _element_to_ss(data_df: pd.DataFrame, eid: int, description) -> SampleSet:
27 | if eid not in ELEMENT_IDS_PER_FILE:
28 | msg = (
29 | f"Element #{eid} is invalid. "
30 |             f"The available options are: {ELEMENT_IDS_PER_FILE}"
31 | )
32 | raise ValueError(msg)
33 |
34 | ss = SampleSet()
35 | ss.spectra = data_df[f"spectrum-channels{eid}"]\
36 | .str\
37 | .split(",", expand=True)\
38 | .astype(int)
39 | ss.info.live_time = data_df[f"spectrum-lt{eid}"] / 1000
40 | ss.info.real_time = data_df[f"spectrum-rt{eid}"] / 1000
41 | ss.info.total_counts = data_df[f"gc{eid}"]
42 | ss.info.neutron_counts = data_df["nc0"]
43 | ss.info.ecal_order_0 = DEFAULT_ECAL[0]
44 | ss.info.ecal_order_1 = DEFAULT_ECAL[1]
45 | ss.info.ecal_order_2 = DEFAULT_ECAL[2]
46 | ss.info.ecal_order_3 = DEFAULT_ECAL[3]
47 | ss.info.ecal_low_e = DEFAULT_ECAL[4]
48 | ss.info.timestamp = data_df["utc-time"]
49 | ss.info.description = description
50 | ss.info["latitude"] = data_df["latitude"]
51 | ss.info["longitude"] = data_df["longitude"]
52 | ss.info["is_in_zone"] = data_df["is-in-zone"]
53 | ss.info["is_closest_approach"] = data_df["is-closest-approach"]
54 | ss.info["is_source_present"] = data_df["is-source-present"]
55 |
56 | detector_name = f"{data_df['detector'].unique()[0]}.{eid}"
57 | ss.detector_info = {
58 | "name": detector_name
59 | }
60 |
61 | return ss
62 |
63 |
64 | def aipt_file_to_ss_list(file_path: str) -> List[SampleSet]:
65 | """Process an AIPT CSV file into a list of SampleSets.
66 |
67 | Each file contains a series of spectra for multiple detectors running simultaneously.
68 | As such, each `SampleSet` in the list returned by this function represents the data
69 | collected by each detector.
70 | Each row of each `SampleSet` represents a measurement from each detector at the same
71 | moment in time.
72 | As such, after calling this function, you might consider summing all four spectra at
73 | each timestep into a single spectrum as another processing step.
74 |
75 | Args:
76 | file_path: file path of the CSV file
77 |
78 | Returns:
79 | List of `SampleSet`s each containing a series of spectra for a single run
80 | """
81 | data_df = pd.read_csv(file_path, header=0, sep="\t")
82 | base_description = os.path.splitext(os.path.basename(file_path))[0]
83 |
84 | ss_list = []
85 | for eid in ELEMENT_IDS_PER_FILE:
86 | description = f"{base_description}_{eid}"
87 | ss = _element_to_ss(data_df, eid, description)
88 | ss_list.append(ss)
89 |
90 | return ss_list
91 |
92 |
93 | def convert_and_save(input_file_path: str, output_dir: str = None,
94 | skip_existing: bool = True, **kwargs):
95 | """Convert AIPT file to SampleSet and save as HDF.
96 |
97 | Output file will have same name but appended with a detector identifier
98 | and having a different extension.
99 |
100 | Args:
101 | input_file_path: file path of the CSV file
102 | output_dir: alternative directory in which to save HDF files
103 | (defaults to `input_file_path` parent if not provided)
104 | skip_existing: whether to skip conversion if the file already exists
105 | kwargs: keyword args passed to `aipt_file_to_ss_list()` (not currently used)
106 | """
107 | input_path = Path(input_file_path)
108 | if not output_dir:
109 | output_dir = input_path.parent
110 | _validate_and_create_output_dir(output_dir)
111 | output_file_paths = [
112 | os.path.join(output_dir, input_path.stem + f"-{i}{SAMPLESET_HDF_FILE_EXTENSION}")
113 | for i in ELEMENT_IDS_PER_FILE
114 | ]
115 | all_output_files_exist = all([os.path.exists(p) for p in output_file_paths])
116 | if skip_existing and all_output_files_exist:
117 | return
118 |
119 | ss_list = aipt_file_to_ss_list(input_file_path, **kwargs)
120 | for output_file_path, ss in zip(output_file_paths, ss_list):
121 | ss.to_hdf(output_file_path)
122 |
--------------------------------------------------------------------------------
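
Following the suggestion in the `aipt_file_to_ss_list` docstring, the four per-detector `SampleSet`s can be collapsed into one spectrum per timestep by summing the time-aligned rows. A hypothetical sketch (the file path is a placeholder):

```python
import numpy as np

from riid.data.converters.aipt import aipt_file_to_ss_list

ss_list = aipt_file_to_ss_list("some_run.csv")  # placeholder path
# Rows are time-aligned across detectors, so elementwise summation
# yields one combined spectrum per timestep:
summed_spectra = np.sum([ss.spectra.values for ss in ss_list], axis=0)
print(summed_spectra.shape)
```
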
/riid/data/converters/topcoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module provides tools for handling data related to the
5 | "Detecting Radiological Threats in Urban Areas" TopCoder Challenge.
6 | https://doi.org/10.1038/s41597-020-00672-2
7 | """
8 | import csv
9 | import logging
10 | import os
11 | from pathlib import Path
12 |
13 | import numpy as np
14 | import pandas as pd
15 |
16 | from riid import SAMPLESET_HDF_FILE_EXTENSION, SampleSet
17 | from riid.data.converters import _validate_and_create_output_dir
18 | from riid.data.labeling import label_to_index_element
19 |
20 | SOURCE_ID_TO_LABEL = {
21 | 0: "Background",
22 | 1: "HEU",
23 | 2: "WGPu",
24 | 3: "I131",
25 | 4: "Co60",
26 | 5: "Tc99m",
27 | 6: "HEU + Tc99m",
28 | }
29 | DISTINCT_SOURCES = list(SOURCE_ID_TO_LABEL.values())
30 |
31 |
32 | def _get_answers(answer_file_path: str):
33 | answers = {}
34 | with open(answer_file_path) as csvfile:
35 | reader = csv.reader(csvfile, delimiter=",")
36 | _ = next(reader) # skip header
37 | # timestamp = 0 # in milliseconds
38 | for row in reader:
39 | run_id = row[0]
40 | source_id = int(row[1])
41 | source_time_secs = float(row[2])
42 | answers[run_id] = {
43 | "source_id": source_id,
44 | "source_time_secs": source_time_secs
45 | }
46 | return answers
47 |
48 |
49 | def topcoder_file_to_ss(file_path: str, sample_interval: float, n_bins: int = 1024,
50 | max_energy_kev: int = 4000, answers_path: str = None) -> SampleSet:
51 | """Convert a TopCoder CSV file of list-mode data into a SampleSet.
52 |
53 | Args:
54 | file_path: file path of the CSV file
55 | sample_interval: integration time (referred to as "live time" later) to use in seconds.
56 | Warning: the final sample in the set will likely be truncated, i.e., the count rate
57 | will appear low because the live time represented is too large.
58 | Consider ignoring the last sample.
59 | n_bins: desired number of bins in the resulting spectra.
60 | Bins will be uniformly spaced from 0 to `max_energy_kev`.
61 | max_energy_kev: desired maximum of the energy range represented in the resulting spectra.
62 | Intuition (assuming a fixed number of bins): a higher max energy value "compresses" the
63 | spectral information; a lower max energy value spreads out the spectral information and
64 |             counts are potentially lost off the high energy end of the spectrum.
65 | answers_path: path to the answer key for the data. If provided, this will fill out the
66 | `SampleSet.sources` DataFrame.
67 |
68 | Returns:
69 | `SampleSet` containing the series of spectra for a single run
70 | """
71 | file_name_with_dir = os.path.splitext(file_path)[0]
72 | file_name = os.path.basename(file_name_with_dir)
73 | slice_duration_ms = sample_interval * 1000 # in milliseconds
74 | events = []
75 | with open(file_path) as csvfile:
76 | reader = csv.reader(csvfile, delimiter=",")
77 | timestamp = 0 # in milliseconds
78 | for row in reader:
79 | timestamp += int(row[0]) / 1000 # microseconds to milliseconds
80 | energy = float(row[1])
81 | if energy > max_energy_kev:
82 | msg = (
83 | f"Encountered energy ({energy:.2f} keV) greater than "
84 | f"specified max energy ({max_energy_kev:.2f})"
85 | )
86 |                 logging.warning(msg)
87 | channel = int(n_bins * energy // max_energy_kev) # energy to bin
88 | events.append((timestamp, channel, 1))
89 |
90 | events_df = pd.DataFrame(
91 | events,
92 | columns=["timestamp", "channel", "counts"]
93 | )
94 | # Organize events into time intervals
95 | event_time_intervals = pd.cut(
96 | events_df["timestamp"],
97 | np.arange(
98 | start=0,
99 | stop=events_df["timestamp"].max()+slice_duration_ms,
100 | step=slice_duration_ms
101 | )
102 | )
103 | # Group events by time intervals
104 | event_time_groups = events_df.groupby(event_time_intervals)
105 | # Within time intervals, sum counts by channel
106 | result = event_time_groups.apply(
107 | lambda x: x.groupby("channel").sum().loc[:, ["counts"]]
108 | )
109 | # Create new dataframe where: row = time interval, column = channel
110 | spectra_df = result.unstack(level=-1, fill_value=0)
111 | spectra_df = spectra_df["counts"]
112 | # Add in missing columns as needed
113 | col_list = np.arange(0, n_bins)
114 | cols_to_add = np.setdiff1d(col_list, spectra_df.columns)
115 | missing_df = pd.DataFrame(
116 | 0,
117 | columns=cols_to_add,
118 | index=spectra_df.index
119 | )
120 | combined_unsorted_df = pd.concat([spectra_df, missing_df], axis=1)
121 | combined_df = combined_unsorted_df.reindex(
122 | sorted(combined_unsorted_df.columns),
123 | axis=1
124 | )
125 | spectra_df = combined_df.astype(int)
126 | spectra_df.reset_index(inplace=True, drop=True)
127 |
128 | # SampleSet creation
129 | ss = SampleSet()
130 | ss.measured_or_synthetic = "synthetic"
131 | ss.detector_info = {
132 | "name": "2\"x4\"x16\" NaI(Tl)",
133 | "height_cm": 100,
134 | "fwhm_at_661_kev": 0.075,
135 | }
136 | ss.spectra = spectra_df
137 | ss.info.total_counts = spectra_df.sum(axis=1)
138 | ss.info.live_time = sample_interval
139 | ss.info.description = file_name
140 | ss.info.ecal_order_0 = 0
141 | ss.info.ecal_order_1 = max_energy_kev
142 | ss.info.ecal_order_2 = 0
143 | ss.info.ecal_order_3 = 0
144 | ss.info.ecal_low_e = 0
145 |
146 | if answers_path:
147 | answers = _get_answers(answers_path)
148 | run_id = file_name.split("runID-")[-1].split(".")[0]
149 | ss.info.timestamp = answers[run_id]["source_time_secs"]
150 | source_id = answers[run_id]["source_id"]
151 | sources_mi = pd.MultiIndex.from_tuples([
152 | label_to_index_element(x, label_level="Seed")
153 | for x in DISTINCT_SOURCES
154 | ], names=SampleSet.SOURCES_MULTI_INDEX_NAMES)
155 | sources_df = pd.DataFrame(
156 | np.zeros((ss.n_samples, len(DISTINCT_SOURCES))),
157 | columns=sources_mi,
158 | )
159 | sources_df.sort_index(axis=1, inplace=True)
160 | source = label_to_index_element(SOURCE_ID_TO_LABEL[source_id], label_level="Seed")
161 | sources_df[source] = 1.0
162 | ss.sources = sources_df
163 |
164 | return ss
165 |
166 |
167 | def convert_and_save(input_file_path: str, output_dir: str = None,
168 | skip_existing: bool = True, **kwargs):
169 | """Convert TopCoder file to SampleSet and save as HDF.
170 |
171 |     Output file will have the same name with a different extension.
172 |
173 | Args:
174 | input_file_path: file path of the CSV file
175 | output_dir: alternative directory in which to save HDF files
176 | (defaults to `input_file_path` parent if not provided)
177 | skip_existing: whether to skip conversion if the file already exists
178 | kwargs: keyword args passed to `topcoder_file_to_ss()`
179 | """
180 | input_path = Path(input_file_path)
181 | if not output_dir:
182 | output_dir = input_path.parent
183 | _validate_and_create_output_dir(output_dir)
184 | output_file_path = os.path.join(output_dir, input_path.stem + SAMPLESET_HDF_FILE_EXTENSION)
185 | if skip_existing and os.path.exists(output_file_path):
186 | return
187 |
188 | ss = topcoder_file_to_ss(input_file_path, **kwargs)
189 | ss.to_hdf(output_file_path)
190 |
--------------------------------------------------------------------------------
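A minimal usage sketch of the converters above (not part of the repository source): the file name below is a hypothetical placeholder, and the keyword arguments mirror the signatures shown in this file.

    from riid.data.converters.topcoder import convert_and_save, topcoder_file_to_ss

    # Build a SampleSet of 1-second spectra from a list-mode run file whose rows
    # are "<microseconds since last event>,<energy in keV>".
    ss = topcoder_file_to_ss(
        "runID-0001.csv",  # hypothetical input file
        sample_interval=1.0,
        n_bins=1024,
        max_energy_kev=4000,
    )
    print(ss.n_samples, ss.spectra.shape)

    # Or convert and write the HDF file next to the input in one step;
    # extra keyword args are forwarded to topcoder_file_to_ss().
    convert_and_save("runID-0001.csv", sample_interval=1.0)
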
/riid/data/labeling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains utility functions for managing ground truth information."""
5 | import logging
6 | import re
7 |
8 | BACKGROUND_LABEL = "Background"
9 | NO_SEED = "Unknown"
10 | NO_ISOTOPE = "Unknown"
11 | NO_CATEGORY = "Uncategorized"
12 | CATEGORY_ISOTOPES = {
13 | "Fission Product": {
14 | "severity": 3,
15 | "isotopes": [
16 | "Ag112",
17 | "As78",
18 | "Ba139",
19 | "Ce143",
20 | "I132",
21 | "I133",
22 | "I134",
23 | "I135",
24 | "Kr85m",
25 | "Kr87",
26 | "La140",
27 | "La142",
28 | "Nd149",
29 | "Pm150",
30 | "Rh105",
31 | "Ru105",
32 | "Sb115",
33 | "Sb129",
34 | "Sr91",
35 | "Sr92",
36 | "Te132",
37 | "Y93",
38 | "Y91m",
39 | "Zr95",
40 | ]
41 | },
42 | "Industrial": {
43 | "severity": 2,
44 | "isotopes": [
45 | "Am241",
46 | "Ba133",
47 | "Ba140",
48 | "Bi207",
49 | "Cf249",
50 | "Cf250",
51 | "Cf251",
52 | "Cf252",
53 | "Cm244",
54 | "Co57",
55 | "Co60",
56 | "Cs137",
57 | "Eu152",
58 | "Eu154",
59 | "H3",
60 | "Ho166m",
61 | "Ir192",
62 | "Na22",
63 | "P32",
64 | "P33",
65 | "Po210",
66 | "Se75",
67 | "Sr90",
68 | "Tc99",
69 | "Y88",
70 | ],
71 | },
72 | "Medical": {
73 | "severity": 3,
74 | "isotopes": [
75 | "F18",
76 | "Ga67",
77 | "Ga68",
78 | "Ge68",
79 | "I123",
80 | "I124",
81 | "I125",
82 | "I129",
83 | "I131",
84 | "In111",
85 | "Lu177m",
86 | "Mo99",
87 | "Pd103",
88 | "Ra223",
89 | "Rb82",
90 | "Sm153",
91 | "Tc99m",
92 | "Tl201",
93 | "Xe133",
94 | ],
95 | },
96 | "NORM": {
97 | "severity": 1,
98 | "isotopes": [
99 | "Cosmic",
100 | "K40",
101 | "Pb210",
102 | "Ra226",
103 | "Th232",
104 | ],
105 | },
106 | "SNM": {
107 | "severity": 4,
108 | "isotopes": [
109 | "Np237",
110 | "Pu238",
111 | "Pu239",
112 | "U232",
113 | "U233",
114 | "U235",
115 | "U237",
116 | # 16,000 years from now, once the chemically separated uranium has equilibrated
117 | # with Ra226, then we will need to reconsider U238's categorization.
118 | "U238",
119 | ],
120 | },
121 | }
122 | ISOTOPES = sum(
123 | [c["isotopes"] for c in CATEGORY_ISOTOPES.values()],
124 | []
125 | ) # Concatenating the lists of isotopes into one list
126 | SEED_TO_ISOTOPE_SPECIAL_CASES = {
127 | "ThPlate": "Th232",
128 | "ThPlate+Thxray,10uC": "Th232",
129 | "fiestaware": "U238",
130 | "Uxray,100uC": "U238",
131 | "DUOxide": "U238",
132 | "ShieldedDU": "U238",
133 | "modified_berpball": "Pu239",
134 | "10 yr WGPu in Fe": "Pu239",
135 | "1gPuWG_0.5yr,3{an=10,ad=5}": "Pu239",
136 | "pu239_1yr": "Pu239",
137 | "pu239_5yr": "Pu239",
138 | "pu239_10yr": "Pu239",
139 | "pu239_25yr": "Pu239",
140 | "pu239_50yr": "Pu239",
141 | "1kg HEU + 800uCi Cs137": "U235",
142 | "WGPu + Cs137": "Pu239",
143 | "HEU": "U235",
144 | "HEU + Tc99m": "U235",
145 | "DU": "U238",
146 | "WGPu": "Pu239",
147 | "PuWG": "Pu239",
148 | "RTG": "Pu238",
149 | "PotassiumInSoil": "K40",
150 | "UraniumInSoil": "Ra226",
151 | "ThoriumInSoil": "Th232",
152 | "Cosmic": "Cosmic",
153 | }
154 |
155 |
156 | def _find_isotope(seed: str, verbose=True):
157 |     """Attempt to find the isotope for the given seed.
158 |
159 | Args:
160 | seed: string containing the isotope name
161 |         verbose: whether to log warnings
162 |
163 | Returns:
164 | Isotope if found, otherwise NO_ISOTOPE
165 | """
166 | if seed.lower() == BACKGROUND_LABEL.lower():
167 | return BACKGROUND_LABEL
168 | if seed == NO_ISOTOPE:
169 | return NO_ISOTOPE
170 |
171 | isotopes = []
172 | for i in ISOTOPES:
173 | if i in seed:
174 | isotopes.append(i)
175 |
176 | n_isotopes = len(isotopes)
177 | if n_isotopes > 1:
178 |         # Use the longest matching isotope (handles source strings for things like Tc99 vs Tc99m)
179 |         chosen_match = max(isotopes, key=len)
180 | if verbose:
181 | logging.warning((
182 | f"Found multiple isotopes whose names are subsets of '{seed}';"
183 | f" '{chosen_match}' was chosen."
184 | ))
185 | return chosen_match
186 | elif n_isotopes == 0:
187 | return NO_ISOTOPE
188 | else:
189 | return isotopes[0]
190 |
191 |
192 | def _find_category(isotope: str):
193 | """Attempt to find the category for the given isotope.
194 |
195 | Args:
196 | isotope: string containing the isotope name
197 |
198 | Returns:
199 | Category if found, otherwise NO_CATEGORY
200 | """
201 | if isotope.lower() == BACKGROUND_LABEL.lower():
202 | return BACKGROUND_LABEL
203 | if isotope == NO_CATEGORY:
204 | return NO_CATEGORY
205 |
206 | categories = []
207 | for c, v in CATEGORY_ISOTOPES.items():
208 | c_severity = v["severity"]
209 | c_isotopes = v["isotopes"]
210 | for i in c_isotopes:
211 | if i in isotope:
212 | categories.append((c, c_severity))
213 |
214 | n_categories = len(categories)
215 | if n_categories > 1:
216 | return max(categories, key=lambda x: x[1])[0]
217 | elif n_categories == 0:
218 | return NO_CATEGORY
219 | else:
220 | return categories[0][0]
221 |
222 |
223 | def label_to_index_element(label_val: str, label_level="Isotope", verbose=False) -> tuple:
224 | """Try to map a label to a tuple for use in `DataFrame` `MultiIndex` columns.
225 |
226 |     Depending on the level of the label value, you will get a different tuple:
227 |
228 | | Label Level | Resulting Tuple |
229 | |:-------------|:-------------------------|
230 | |Seed |(Category, Isotope, Seed) |
231 | |Isotope |(Category, Isotope) |
232 | |Category |(Category,) |
233 |
234 | Args:
235 | label_val: part of the label (Category, Isotope, Seed)
236 | from which to map the other two label values, if possible
237 |         label_level: level of the part of the label provided, e.g.,
238 | "Category", "Isotope", or "Seed"
239 |
240 | Returns:
241 |         Tuple containing the Category, Isotope, and/or Seed values identified
242 |         from the given label value.
243 | """
244 |
245 | old_label = label_val.strip()
246 | # Some files use 'background', others use 'Background'.
247 | if old_label.lower() == BACKGROUND_LABEL.lower():
248 | old_label = BACKGROUND_LABEL
249 |
250 | if label_level == "Category":
251 | return (old_label,)
252 |
253 | if label_level == "Isotope":
254 | category = _find_category(old_label)
255 | return (category, old_label)
256 |
257 | if label_level == "Seed":
258 | special_cases = SEED_TO_ISOTOPE_SPECIAL_CASES.keys()
259 | exact_match = old_label in special_cases
260 | first_partial_match = None
261 | if not exact_match:
262 | first_partial_match = next((x for x in special_cases if x in old_label), None)
263 | if exact_match:
264 | isotope = SEED_TO_ISOTOPE_SPECIAL_CASES[old_label]
265 | elif first_partial_match:
266 | isotope = SEED_TO_ISOTOPE_SPECIAL_CASES[first_partial_match]
267 | else:
268 | isotope = _find_isotope(old_label, verbose)
269 | category = _find_category(isotope)
270 | return (category, isotope, old_label)
271 |
272 |
273 | def isotope_name_is_valid(isotope: str):
274 | """Validate whether the given string contains a properly formatted radioisotope name.
275 |
276 | Note that this function does NOT look up a string to determine if the string corresponds
277 | to a radioisotope that actually exists, it just checks the format.
278 |
279 | The regular expression used by this function looks for the following (in order):
280 |
281 | - 1 capital letter
282 | - 0 to 1 lowercase letters
283 | - 1 to 3 numbers
284 | - an optional "m" for metastable
285 |
286 | Examples of properly formatted isotope names:
287 |
288 | - Y88
289 | - Ba133
290 | - Ho166m
291 |
292 | Args:
293 | isotope: string containing the isotope name
294 |
295 | Returns:
296 | Bool representing whether the name string is valid
297 | """
298 | validator = re.compile(r"^[A-Z]{1}[a-z]{0,1}[0-9]{1,3}m?$")
299 | other_valid_names = ["fiestaware"]
300 | match = validator.match(isotope)
301 | is_valid = match is not None or \
302 | isotope.lower() in other_valid_names
303 | return is_valid
304 |
--------------------------------------------------------------------------------
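A short sketch of the labeling helpers above (not part of the repository source); the expected results in the comments follow directly from the tables defined in this module.

    from riid.data.labeling import isotope_name_is_valid, label_to_index_element

    label_to_index_element("Am241", label_level="Isotope")  # ("Industrial", "Am241")
    label_to_index_element("WGPu", label_level="Seed")      # ("SNM", "Pu239", "WGPu"), via the special cases
    isotope_name_is_valid("Ho166m")                         # True
    isotope_name_is_valid("ho166m")                         # False: the pattern requires a capital first letter
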
/riid/data/synthetic/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains utilities for synthesizing gamma spectra."""
5 | # The following imports are left to not break previous imports; remove in v3
6 | from riid.data.synthetic.base import Synthesizer, get_distribution_values
7 | from riid.data.synthetic.seed import get_dummy_seeds
8 |
9 | __all__ = ["get_dummy_seeds", "Synthesizer", "get_distribution_values"]
10 |
--------------------------------------------------------------------------------
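A usage sketch (not part of the repository source): until the deprecated re-exports above are removed in v3, both import paths resolve to the same function; calling `get_dummy_seeds()` with its defaults is an assumption here.

    from riid.data.synthetic import get_dummy_seeds       # deprecated path, kept until v3
    from riid.data.synthetic.seed import get_dummy_seeds  # preferred path

    seeds_ss = get_dummy_seeds()  # small synthetic SampleSet, handy for tests and examples
    print(seeds_ss.n_samples)
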
/riid/gadras/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module wraps various GADRAS features by calling into the DLLs of a GADRAS installation."""
5 |
--------------------------------------------------------------------------------
/riid/gadras/api_schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "$id": "GADRAS_API_19.2.3_Seed_Synthesis_Schema",
3 | "$schema": "https://json-schema.org/draft/2020-12/schema",
4 | "title": "",
5 | "type": "object",
6 | "required": [
7 | "gamma_detector",
8 | "sources"
9 | ],
10 | "properties": {
11 | "random_seed": {
12 | "description": "The numpy random seed.",
13 | "type": "integer"
14 | },
15 | "gamma_detector": {
16 | "title": "Gamma Detector",
17 | "description": "The information about your detector.",
18 | "type": "object",
19 | "properties": {
20 | "name": {
21 | "description": "The directory containing your detector's DRF (.dat file) relative to the GADRAS Detector directory.",
22 | "type": "string"
23 | },
24 | "parameters": {
25 | "type": "object",
26 | "additionalProperties": true,
27 | "properties": {
28 | "distance_cm": {
29 | "description": "Distance between detector and source, in centimeters",
30 | "$ref": "#/$defs/detector_properties_types"
31 | },
32 | "height_cm": {
33 | "description": "Detector height off ground, in centimeters",
34 | "$ref": "#/$defs/detector_properties_types"
35 | },
36 | "dead_time_per_pulse": {
37 | "description": "Detector dead time, in microseconds",
38 | "$ref": "#/$defs/detector_properties_types"
39 | },
40 | "latitude_deg": {
41 | "description": "Latitude, in degrees",
42 | "$ref": "#/$defs/detector_properties_types"
43 | },
44 | "longitude_deg": {
45 | "description": "Longitude, in degrees",
46 | "$ref": "#/$defs/detector_properties_types"
47 | },
48 | "elevation_m": {
49 | "description": "Elevation, in meters",
50 | "$ref": "#/$defs/detector_properties_types"
51 | }
52 | },
53 | "required": [
54 | "distance_cm",
55 | "height_cm",
56 | "dead_time_per_pulse",
57 | "latitude_deg",
58 | "longitude_deg",
59 | "elevation_m"
60 | ]
61 | }
62 | },
63 | "required": [
64 | "name",
65 | "parameters"
66 | ]
67 | },
68 | "sources": {
69 | "title": "Sources",
70 | "description": "The list of sources to obtain via inject(s).",
71 | "type": "array",
72 | "items": {
73 | "type": "object",
74 | "properties": {
75 | "isotope": {
76 | "type": "string"
77 | },
78 | "configurations": {
79 | "type": "array",
80 | "items": {
81 | "anyOf": [
82 | {
83 | "type": "string"
84 | },
85 | {
86 | "$ref": "#/$defs/source_config_type"
87 | }
88 | ]
89 | }
90 | }
91 | },
92 | "additionalProperties": false
93 | }
94 | }
95 | },
96 | "$defs": {
97 | "detector_properties_types": {
98 | "anyOf": [
99 | {
100 | "type": "number"
101 | },
102 | {
103 | "type": "array",
104 | "items": {
105 | "anyOf": [
106 | {
107 | "type": "number"
108 | },
109 | {
110 | "$ref": "#/$defs/sample_range"
111 | },
112 | {
113 | "$ref": "#/$defs/sample_norm"
114 | }
115 | ]
116 | }
117 | },
118 | {
119 | "$ref": "#/$defs/sample_range"
120 | },
121 | {
122 | "$ref": "#/$defs/sample_norm"
123 | }
124 | ]
125 | },
126 | "sample_range": {
127 | "type": "object",
128 | "required": [
129 | "min",
130 | "max",
131 | "dist",
132 | "num_samples"
133 | ],
134 | "properties": {
135 | "min": {
136 | "type": "number",
137 | "description": "Minimum value of the range (inclusive)."
138 | },
139 | "max": {
140 | "type": "number",
141 | "description": "Maximum value of the range (inclusive)."
142 | },
143 | "dist": {
144 | "type": "string",
145 | "description": "The distribution from which to draw samples.",
146 | "enum": [
147 | "uniform",
148 | "log10"
149 | ]
150 | },
151 | "num_samples": {
152 | "type": "number",
153 | "description": "Number of samples to draw"
154 | }
155 | },
156 | "additionalProperties": false
157 | },
158 | "sample_norm": {
159 | "type": "object",
160 | "required": [
161 | "mean",
162 | "std",
163 | "num_samples"
164 | ],
165 | "properties": {
166 | "mean": {
167 | "type": "number",
168 | "description": "Mean of the normal distribution from which to draw samples."
169 | },
170 | "std": {
171 | "type": "number",
172 | "description": "Standard deviation of the normal distribution from which to draw samples."
173 | },
174 | "num_samples": {
175 | "type": "number",
176 | "description": "Number of samples to draw"
177 | }
178 | },
179 | "additionalProperties": false
180 | },
181 | "source_config_type": {
182 | "type": "object",
183 | "required": [
184 | "name"
185 | ],
186 | "properties": {
187 | "name": {
188 | "type": "string"
189 | },
190 | "activity": {
191 | "anyOf": [
192 | {
193 | "type": "number"
194 | },
195 | {
196 | "type": "array",
197 | "items": {
198 | "type": "number"
199 | }
200 | },
201 | {
202 | "$ref": "#/$defs/sample_range"
203 | },
204 | {
205 | "$ref": "#/$defs/sample_norm"
206 | }
207 | ]
208 | },
209 | "activity_units": {
210 | "type": "string",
211 | "enum": [
212 | "Ci",
213 | "uCi",
214 | "Bq"
215 | ]
216 | },
217 | "shielding_atomic_number": {
218 | "anyOf": [
219 | {
220 | "type": "number"
221 | },
222 | {
223 | "type": "array",
224 | "items": {
225 | "type": "number"
226 | }
227 | },
228 | {
229 | "$ref": "#/$defs/sample_range"
230 | },
231 | {
232 | "$ref": "#/$defs/sample_norm"
233 | }
234 | ]
235 | },
236 | "shielding_aerial_density": {
237 | "anyOf": [
238 | {
239 | "type": "number"
240 | },
241 | {
242 | "type": "array",
243 | "items": {
244 | "type": "number"
245 | }
246 | },
247 | {
248 | "$ref": "#/$defs/sample_range"
249 | },
250 | {
251 | "$ref": "#/$defs/sample_norm"
252 | }
253 | ]
254 | }
255 | },
256 | "additionalProperties": false
257 | }
258 | }
259 | }
--------------------------------------------------------------------------------
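A sketch (not part of the repository source) of a minimal config satisfying the schema above, checked with the third-party `jsonschema` package (an assumed extra dependency); the detector name and source configuration strings are hypothetical placeholders.

    import json

    from jsonschema import validate  # assumed third-party dependency

    config = {
        "gamma_detector": {
            "name": "Generic/NaI/2x4x16",  # hypothetical DRF directory
            "parameters": {
                "distance_cm": 100,
                "height_cm": 100,
                "dead_time_per_pulse": 5,
                "latitude_deg": 35.0,
                "longitude_deg": -106.0,
                "elevation_m": 1600,
            },
        },
        "sources": [
            # Configurations may be plain strings...
            {"isotope": "Cs137", "configurations": ["Cs137,100uCi"]},
            # ...or objects that sample a parameter, here drawing 10 activities
            # uniformly from [1, 100] uCi.
            {
                "isotope": "Co60",
                "configurations": [{
                    "name": "Co60,100uCi",
                    "activity": {"min": 1, "max": 100, "dist": "uniform", "num_samples": 10},
                    "activity_units": "uCi",
                }],
            },
        ],
    }

    with open("riid/gadras/api_schema.json") as f:
        schema = json.load(f)
    validate(config, schema)  # raises jsonschema.ValidationError on a malformed config
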
/riid/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains custom loss functions."""
5 | import numpy as np
6 | import tensorflow as tf
7 | from keras.api import ops
8 |
9 |
10 | def negative_log_f1(y_true: np.ndarray, y_pred: np.ndarray):
11 | """Calculate negative log F1 score.
12 |
13 | Args:
14 | y_true: list of ground truth
15 | y_pred: list of predictions to compare against the ground truth
16 |
17 | Returns:
18 | Custom loss score on a log scale
19 | """
20 | diff = y_true - y_pred
21 | negs = ops.clip(diff, -1.0, 0.0)
22 | false_positive = -ops.sum(negs, axis=-1)
23 | true_positive = 1.0 - false_positive
24 | lower_clip = 1e-20
25 | true_positive = ops.clip(true_positive, lower_clip, 1.0)
26 |
27 | return -ops.mean(ops.log(true_positive))
28 |
29 |
30 | def negative_f1(y_true, y_pred):
31 | """Calculate negative F1 score.
32 |
33 | Args:
34 | y_true: list of ground truth
35 | y_pred: list of predictions to compare against the ground truth
36 |
37 | Returns:
38 | Custom loss score
39 | """
40 | diff = y_true - y_pred
41 | negs = ops.clip(diff, -1.0, 0.0)
42 | false_positive = -ops.sum(negs, axis=-1)
43 | true_positive = 1.0 - false_positive
44 | lower_clip = 1e-20
45 | true_positive = ops.clip(true_positive, lower_clip, 1.0)
46 |
47 | return -ops.mean(true_positive)
48 |
49 |
50 | def build_keras_semisupervised_loss_func(supervised_loss_func,
51 | unsupervised_loss_func,
52 | dictionary, beta,
53 | activation, n_labels,
54 | normalize: bool = False,
55 | normalize_scaler: float = 1.0,
56 | normalize_func=tf.math.tanh):
57 | @tf.keras.utils.register_keras_serializable(package="Addons")
58 | def _semisupervised_loss_func(data, y_pred):
59 | """
60 | Args:
61 | data: Contains true labels and input features (spectra)
62 | y_pred: Model output (unactivated logits)
63 | """
64 | y_true = data[:, :n_labels]
65 | spectra = data[:, n_labels:]
66 | logits = y_pred
67 | lpes = activation(y_pred)
68 |
69 | sup_losses = supervised_loss_func(y_true, logits)
70 | unsup_losses = reconstruction_error(spectra, lpes, dictionary,
71 | unsupervised_loss_func)
72 | if normalize:
73 | sup_losses = normalize_func(normalize_scaler * sup_losses)
74 |
75 | semisup_losses = (1 - beta) * sup_losses + beta * unsup_losses
76 |
77 | return semisup_losses
78 |
79 | return _semisupervised_loss_func
80 |
81 |
82 | def sse_diff(spectra, reconstructed_spectra):
83 | """Compute the sum of squares error.
84 |
85 | TODO: refactor to assume spectral inputs are in the same form
86 |
87 | Args:
88 | spectra: spectral samples, assumed to be in counts
89 | reconstructed_spectra: reconstructed spectra created using a
90 | dictionary with label proportion estimates
91 | """
92 | total_counts = tf.reduce_sum(spectra, axis=1)
93 | scaled_reconstructed_spectra = tf.multiply(
94 | reconstructed_spectra,
95 | tf.reshape(total_counts, (-1, 1))
96 | )
97 |
98 | diff = spectra - scaled_reconstructed_spectra
99 | norm_diff = tf.norm(diff, axis=-1)
100 | squared_norm_diff = tf.square(norm_diff)
101 | return squared_norm_diff
102 |
103 |
104 | def poisson_nll_diff(spectra, reconstructed_spectra, eps=1e-8):
105 | """Compute the Poisson Negative Log-Likelihood.
106 |
107 | TODO: refactor to assume spectral inputs are in the same form
108 |
109 | Args:
110 | spectra: spectral samples, assumed to be in counts
111 | reconstructed_spectra: reconstructed spectra created using a
112 | dictionary with label proportion estimates
113 | """
114 | total_counts = tf.reduce_sum(spectra, axis=-1)
115 | scaled_reconstructed_spectra = tf.multiply(
116 | reconstructed_spectra,
117 | tf.reshape(total_counts, (-1, 1))
118 | )
119 | log_reconstructed_spectra = tf.math.log(scaled_reconstructed_spectra + eps)
120 | diff = tf.nn.log_poisson_loss(
121 | spectra,
122 | log_reconstructed_spectra,
123 | compute_full_loss=True
124 | )
125 | diff = tf.reduce_sum(diff, axis=-1)
126 |
127 | return diff
128 |
129 |
130 | def normal_nll_diff(spectra, reconstructed_spectra, eps=1e-8):
131 | """Compute the Normal Negative Log-Likelihood.
132 |
133 | TODO: refactor to assume spectral inputs are in the same form
134 |
135 | Args:
136 | spectra: spectral samples, assumed to be in counts
137 | reconstructed_spectra: reconstructed spectra created using a
138 | dictionary with label proportion estimates
139 | """
140 | total_counts = tf.reduce_sum(spectra, axis=-1)
141 | scaled_reconstructed_spectra = tf.multiply(
142 | reconstructed_spectra,
143 | tf.reshape(total_counts, (-1, 1))
144 | )
145 |
146 | var = tf.clip_by_value(spectra, clip_value_min=1, clip_value_max=np.inf)
147 |
148 | sigma_term = tf.math.log(2 * np.pi * var)
149 | mu_term = tf.math.divide(tf.math.square(scaled_reconstructed_spectra - spectra), var)
150 | diff = sigma_term + mu_term
151 | diff = 0.5 * tf.reduce_sum(diff, axis=-1)
152 |
153 | return diff
154 |
155 |
156 | def weighted_sse_diff(spectra, reconstructed_spectra):
157 | """Compute the Normal Negative Log-Likelihood under constant variance
158 | (this reduces to the SSE, just on a different scale).
159 |
160 | Args:
161 | spectra: spectral samples, assumed to be in counts
162 | reconstructed_spectra: reconstructed spectra created using a
163 | dictionary with label proportion estimates
164 | """
165 | total_counts = tf.reduce_sum(spectra, axis=1)
166 | scaled_reconstructed_spectra = tf.multiply(
167 | reconstructed_spectra,
168 | tf.reshape(total_counts, (-1, 1))
169 | )
170 |
171 | sample_variance = tf.sqrt(tf.math.reduce_variance(spectra, axis=1))
172 |
173 | sigma_term = tf.math.log(2 * np.pi * sample_variance)
174 |
175 | mu_term = tf.math.divide(
176 | tf.math.square(scaled_reconstructed_spectra - spectra),
177 | tf.reshape(sample_variance, (-1, 1))
178 | )
179 | diff = 0.5 * (sigma_term + tf.reduce_sum(mu_term, axis=-1))
180 |
181 | return diff
182 |
183 |
184 | def reconstruction_error(spectra, lpes, dictionary, diff_func):
185 | reconstructed_spectra = tf.matmul(lpes, dictionary)
186 | reconstruction_errors = diff_func(spectra, reconstructed_spectra)
187 | return reconstruction_errors
188 |
189 |
190 | def mish(x):
191 | return x * tf.math.tanh(tf.math.softplus(x))
192 |
193 |
194 | def jensen_shannon_divergence(p, q):
195 | p_sum = tf.reduce_sum(p, axis=-1)
196 | p_norm = tf.divide(
197 | p,
198 | tf.reshape(p_sum, (-1, 1))
199 | )
200 |
201 | q_sum = tf.reduce_sum(q, axis=-1)
202 | q_norm = tf.divide(
203 | q,
204 | tf.reshape(q_sum, (-1, 1))
205 | )
206 |
207 | kld = tf.keras.losses.KLDivergence(reduction=tf.keras.losses.Reduction.NONE)
208 | m = (p_norm + q_norm) / 2
209 | jsd = (kld(p_norm, m) + kld(q_norm, m)) / 2
210 | return jsd
211 |
212 |
213 | def jensen_shannon_distance(p, q):
214 | divergence = jensen_shannon_divergence(p, q)
215 | return tf.math.sqrt(divergence)
216 |
217 |
218 | def chi_squared_diff(spectra, reconstructed_spectra):
219 |     """Compute the Chi-Squared statistic.
220 |
221 | Args:
222 | spectra: spectral samples, assumed to be in counts
223 | reconstructed_spectra: reconstructed spectra created using a
224 | dictionary with label proportion estimates
225 | """
226 | total_counts = tf.reduce_sum(spectra, axis=1)
227 | scaled_reconstructed_spectra = tf.multiply(
228 | reconstructed_spectra,
229 | tf.reshape(total_counts, (-1, 1))
230 | )
231 |
232 | diff = tf.math.subtract(spectra, scaled_reconstructed_spectra)
233 | squared_diff = tf.math.square(diff)
234 | variances = tf.clip_by_value(spectra, 1, np.inf)
235 | chi_squared = tf.math.divide(squared_diff, variances)
236 | return tf.reduce_sum(chi_squared, axis=-1)
237 |
--------------------------------------------------------------------------------
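A sketch (not part of the repository source) exercising two of the losses above on tiny hand-made tensors; the values in the comments were worked out by hand from the definitions in this file.

    import tensorflow as tf

    from riid.losses import chi_squared_diff, negative_log_f1

    y_true = tf.constant([[1.0, 0.0], [0.0, 1.0]])
    y_pred = tf.constant([[0.9, 0.1], [0.2, 0.8]])
    # -mean(log(true_positive)) = -(ln 0.9 + ln 0.8) / 2 ~= 0.164
    print(float(negative_log_f1(y_true, y_pred)))

    spectra = tf.constant([[90.0, 10.0]])       # counts
    reconstruction = tf.constant([[0.8, 0.2]])  # unit-normalized; scaled by total counts internally
    # (90-80)^2/90 + (10-20)^2/10 ~= 11.11
    print(chi_squared_diff(spectra, reconstruction).numpy())
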
/riid/losses/sparsemax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 |
5 | # This code is based on Tensorflow-Addons. THE ORIGINAL CODE HAS BEEN MODIFIED.
6 | # https://www.tensorflow.org/addons/
7 |
8 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
9 |
10 | # Licensed under the Apache License, Version 2.0 (the "License");
11 | # you may not use this file except in compliance with the License.
12 | # You may obtain a copy of the License at
13 |
14 | # http://www.apache.org/licenses/LICENSE-2.0
15 |
16 | # Unless required by applicable law or agreed to in writing, software
17 | # distributed under the License is distributed on an "AS IS" BASIS,
18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 | # See the License for the specific language governing permissions and
20 | # limitations under the License.
21 | """This module contains sparsemax-related functions."""
22 |
23 | from typing import Optional
24 |
25 | import tensorflow as tf
26 | from typeguard import typechecked
27 |
28 |
29 | def sparsemax(logits, axis: int = -1) -> tf.Tensor:
30 | r"""Sparsemax activation function.
31 |
32 | For each batch \( i \), and class \( j \),
33 | compute sparsemax activation function:
34 |
35 | $$
36 | \mathrm{sparsemax}(x)[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0).
37 | $$
38 |
39 | See
40 | [From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification
41 | ](https://arxiv.org/abs/1602.02068).
42 |
43 | Usage:
44 |
45 | >>> x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
46 |     >>> sparsemax(x)
47 |     <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
48 |     array([[0., 0., 1.],
49 |            [0., 0., 1.]], dtype=float32)>
50 |
51 | Args:
52 | logits: `Tensor`
53 | axis: `int`, axis along which the sparsemax operation is applied
54 |
55 | Returns:
56 | `Tensor`, output of sparsemax transformation (has the same type and shape as `logits`)
57 |
58 | Raises:
59 | `ValueError` when `dim(logits) == 1`
60 | """
61 | logits = tf.convert_to_tensor(logits, name="logits")
62 |
63 | # We need its original shape for shape inference.
64 | shape = logits.get_shape()
65 | rank = shape.rank
66 | is_last_axis = (axis == -1) or (axis == rank - 1)
67 |
68 | if is_last_axis:
69 | output = _compute_2d_sparsemax(logits)
70 | output.set_shape(shape)
71 | return output
72 |
73 | # If dim is not the last dimension, we have to do a transpose so that we can
74 | # still perform softmax on its last dimension.
75 |
76 | # Swap logits' dimension of dim and its last dimension.
77 | rank_op = tf.rank(logits)
78 | axis_norm = axis % rank
79 | logits = _swap_axis(logits, axis_norm, tf.math.subtract(rank_op, 1))
80 |
81 | # Do the actual softmax on its last dimension.
82 | output = _compute_2d_sparsemax(logits)
83 | output = _swap_axis(output, axis_norm, tf.math.subtract(rank_op, 1))
84 |
85 | # Make shape inference work since transpose may erase its static shape.
86 | output.set_shape(shape)
87 | return output
88 |
89 |
90 | def _swap_axis(logits, dim_index, last_index, **kwargs):
91 | return tf.transpose(
92 | logits,
93 | tf.concat(
94 | [
95 | tf.range(dim_index),
96 | [last_index],
97 | tf.range(dim_index + 1, last_index),
98 | [dim_index],
99 | ],
100 | 0,
101 | ),
102 | **kwargs,
103 | )
104 |
105 |
106 | def _compute_2d_sparsemax(logits):
107 | """Perform the sparsemax operation when axis=-1."""
108 | shape_op = tf.shape(logits)
109 | obs = tf.math.reduce_prod(shape_op[:-1])
110 | dims = shape_op[-1]
111 |
112 | # In the paper, they call the logits z.
113 |     # The mean(logits) can be subtracted from logits to make the algorithm
114 |     # more numerically stable. The instability in this algorithm comes mostly
115 |     # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
116 |     # to zero. However, in practice the numerical instability issues are very
117 |     # minor and subtracting the mean causes extra issues with inf and nan
118 |     # input.
119 |     # Reshape to [obs, dims] as it is almost free and means the remaining
120 |     # code doesn't need to worry about the rank.
121 | z = tf.reshape(logits, [obs, dims])
122 |
123 | # sort z
124 | z_sorted, _ = tf.nn.top_k(z, k=dims)
125 |
126 | # calculate k(z)
127 | z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
128 | k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
129 | z_check = 1 + k * z_sorted > z_cumsum
130 | # because the z_check vector is always [1,1,...1,0,0,...0] finding the
131 | # (index + 1) of the last `1` is the same as just summing the number of 1.
132 | k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)
133 |
134 | # calculate tau(z)
135 | # If there are inf values or all values are -inf, the k_z will be zero,
136 | # this is mathematically invalid and will also cause the gather_nd to fail.
137 | # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then
138 | # fixed later (see p_safe) by returning p = nan. This results in the same
139 | # behavior as softmax.
140 | k_z_safe = tf.math.maximum(k_z, 1)
141 | indices = tf.stack([tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
142 | tau_sum = tf.gather_nd(z_cumsum, indices)
143 | tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)
144 |
145 | # calculate p
146 | p = tf.math.maximum(tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
147 | # If k_z = 0 or if z = nan, then the input is invalid
148 | p_safe = tf.where(
149 | tf.expand_dims(
150 | tf.math.logical_or(tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
151 | axis=-1,
152 | ),
153 | tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)),
154 | p,
155 | )
156 |
157 | # Reshape back to original size
158 | p_safe = tf.reshape(p_safe, shape_op)
159 | return p_safe
160 |
161 |
162 | def sparsemax_loss(logits, sparsemax, labels, name: Optional[str] = None) -> tf.Tensor:
163 | r"""Sparsemax loss function ([1]).
164 |
165 | Computes the generalized multi-label classification loss for the sparsemax
166 | function. The implementation is a reformulation of the original loss
167 | function such that it uses the sparsemax probability output instead of the
168 | internal \( \tau \) variable. However, the output is identical to the original
169 | loss function.
170 |
171 | [1]: https://arxiv.org/abs/1602.02068
172 |
173 | Args:
174 | logits: `Tensor`. Must be one of the following types: `float32`,
175 | `float64`.
176 | sparsemax: `Tensor`. Must have the same type as `logits`.
177 | labels: `Tensor`. Must have the same type as `logits`.
178 | name: name for the operation (optional).
179 |
180 | Returns:
181 | A `Tensor`. Has the same type as `logits`.
182 | """
183 | logits = tf.convert_to_tensor(logits, name="logits")
184 | sparsemax = tf.convert_to_tensor(sparsemax, name="sparsemax")
185 | labels = tf.convert_to_tensor(labels, name="labels")
186 |
187 | # In the paper, they call the logits z.
188 |     # A constant can be subtracted from logits to make the algorithm
189 |     # more numerically stable in theory. However, there are really no major
190 |     # sources of numerical instability in this algorithm.
191 | z = logits
192 |
193 | # sum over support
194 | # Use a conditional where instead of a multiplication to support z = -inf.
195 | # If z = -inf, and there is no support (sparsemax = 0), a multiplication
196 | # would cause 0 * -inf = nan, which is not correct in this case.
197 | sum_s = tf.where(
198 | tf.math.logical_or(sparsemax > 0, tf.math.is_nan(sparsemax)),
199 | sparsemax * (z - 0.5 * sparsemax),
200 | tf.zeros_like(sparsemax),
201 | )
202 |
203 | # - z_k + ||q||^2
204 | q_part = labels * (0.5 * labels - z)
205 |     # Fix the case where labels = 0 and z = -inf, where q_part would
206 |     # otherwise be 0 * -inf = nan. But since labels = 0, no cost for
207 |     # z = -inf should be considered.
208 |     # The code below also covers the case where z = inf. However, in this
209 |     # case the sparsemax will be nan, which means the sum_s will also be nan,
210 |     # therefore this case doesn't need additional special treatment.
211 | q_part_safe = tf.where(
212 | tf.math.logical_and(tf.math.equal(labels, 0), tf.math.is_inf(z)),
213 | tf.zeros_like(z),
214 | q_part,
215 | )
216 |
217 | return tf.math.reduce_sum(sum_s + q_part_safe, axis=1)
218 |
219 |
220 | @tf.function
221 | @tf.keras.utils.register_keras_serializable(package="Addons")
222 | def sparsemax_loss_from_logits(
223 | y_true, logits_pred
224 | ) -> tf.Tensor:
225 | y_pred = sparsemax(logits_pred)
226 | loss = sparsemax_loss(logits_pred, y_pred, y_true)
227 | return loss
228 |
229 |
230 | @tf.keras.utils.register_keras_serializable(package="Addons")
231 | class SparsemaxLoss(tf.keras.losses.Loss):
232 | """Sparsemax loss function.
233 |
234 | Computes the generalized multi-label classification loss for the sparsemax
235 | function.
236 |
237 | Because the sparsemax loss function needs both the probability output and
238 | the logits to compute the loss value, `from_logits` must be `True`.
239 |
240 | Because it computes the generalized multi-label loss, the shape of both
241 | `y_pred` and `y_true` must be `[batch_size, num_classes]`.
242 |
243 | Args:
244 | from_logits: Whether `y_pred` is expected to be a logits tensor. Default
245 | is `True`, meaning `y_pred` is the logits.
246 | reduction: (Optional) Type of `tf.keras.losses.Reduction` to apply to
247 | loss. Default value is `SUM_OVER_BATCH_SIZE`.
248 | name: Optional name for the op
249 | """
250 |
251 | @typechecked
252 | def __init__(
253 | self,
254 | from_logits: bool = True,
255 | reduction: str = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
256 | name: str = "sparsemax_loss",
257 | ):
258 | if from_logits is not True:
259 | raise ValueError("from_logits must be True")
260 |
261 | super().__init__(name=name, reduction=reduction)
262 | self.from_logits = from_logits
263 |
264 | def call(self, y_true, y_pred):
265 | return sparsemax_loss_from_logits(y_true, y_pred)
266 |
267 | def get_config(self):
268 | config = {
269 | "from_logits": self.from_logits,
270 | }
271 | base_config = super().get_config()
272 | return {**base_config, **config}
273 |
--------------------------------------------------------------------------------
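A sketch (not part of the repository source): unlike softmax, sparsemax can assign exact zeros, and `SparsemaxLoss` wraps the logits-based loss for use as a Keras loss; the outputs in the comment were computed by hand from the algorithm above.

    import tensorflow as tf

    from riid.losses.sparsemax import SparsemaxLoss, sparsemax

    logits = tf.constant([[-1.0, 0.0, 1.0], [2.0, 1.9, -5.0]])
    probs = sparsemax(logits)
    print(probs.numpy())  # [[0. 0. 1.], [0.55 0.45 0.]]: rows sum to 1, small logits clip to exactly 0

    loss_fn = SparsemaxLoss()  # from_logits=True is required (and the default)
    y_true = tf.constant([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    print(float(loss_fn(y_true, logits)))
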
/riid/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module provides custom model metrics."""
5 | import numpy as np
6 | import sklearn
7 |
8 | from riid import SampleSet
9 |
10 |
11 | def multi_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
12 | """Calculate a measure of the F1 score of two tensors.
13 |
14 | Values for `y_true` and `y_pred` are assumed to sum to 1.
15 |
16 | Args:
17 | y_true: list of ground truth
18 | y_pred: list of predictions to compare against the ground truth
19 |
20 | Returns:
21 | Multi F1-score value(s)
22 | """
23 | from keras.api import ops
24 |
25 | diff = y_true - y_pred
26 | negs = ops.clip(diff, -1.0, 0.0)
27 | false_positive = -ops.sum(negs, axis=-1)
28 | true_positive = 1.0 - false_positive
29 |
30 | return ops.mean(true_positive)
31 |
32 |
33 | def single_f1(y_true: np.ndarray, y_pred: np.ndarray):
34 | """Compute the weighted F1 score for the maximum prediction and maximum ground truth.
35 |
36 | Values for `y_true` and `y_pred` are assumed to sum to 1.
37 |
38 | Args:
39 | y_true: list of ground truth
40 | y_pred: list of predictions to compare against the ground truth
41 |
42 | Returns:
43 | F1-score value(s)
44 | """
45 | import tensorflow as tf
46 | from keras.api import ops
47 |
48 | a = tf.dtypes.cast(y_true == ops.max(y_true, axis=1)[:, None], tf.float32)
49 | b = tf.dtypes.cast(y_pred == ops.max(y_pred, axis=1)[:, None], tf.float32)
50 |
51 | TP_mat = tf.dtypes.cast(ops.all(tf.stack([a, b]), axis=0), tf.float32)
52 | FP_mat = tf.dtypes.cast(ops.all(tf.stack([a != b, b == 1]), axis=0), tf.float32)
53 | FN_mat = tf.dtypes.cast(ops.all(tf.stack([a != b, a == 1]), axis=0), tf.float32)
54 |
55 | TPs = ops.sum(TP_mat, axis=0)
56 | FPs = ops.sum(FP_mat, axis=0)
57 | FNs = ops.sum(FN_mat, axis=0)
58 |
59 | F1s = 2 * TPs / (2*TPs + FNs + FPs + tf.fill(tf.shape(TPs), tf.keras.backend.epsilon()))
60 |
61 | support = ops.sum(a, axis=0)
62 | f1 = ops.sum(F1s * support) / ops.sum(support)
63 | return f1
64 |
65 |
66 | def harmonic_mean(x, y):
67 | """Compute the harmonic mean of two same-dimensional arrays.
68 |
69 | Used to compute F1 score:
70 |
71 | ```
72 | f1_score = harmonic_mean(precision, recall)
73 | ```
74 |
75 | Args:
76 | x (array-like): numeric or array_like of numerics
77 | y (array-like): numeric or array_like of numerics matching the shape/type of `x`
78 |
79 | Returns:
80 | Array-like harmonic mean of `x` and `y`
81 | """
82 | return 2 * x * y / (x + y)
83 |
84 |
85 | def precision_recall_curve(ss: SampleSet, smooth: bool = True, multiclass: bool = None,
86 | include_micro: bool = True, target_level: str = "Isotope",
87 | minimum_contribution: float = 0.01):
88 |     """Similar to `sklearn.metrics.precision_recall_curve`; however, this function
89 |     computes precision and recall for each class and supports both multi-class
90 |     and multi-label problems.
91 |
92 | The reason this is necessary is that in multi-class problems, for a single sample,
93 | all predictions are discarded except for the argmax.
94 |
95 | Args:
96 | ss: `SampleSet` that predictions were generated on
97 |         smooth: if True, smoothing is applied to make the precision function
98 |             monotonically decreasing
99 |         multiclass: set to True if this is a multi-class (i.e. y_true is one-hot) as
100 |             opposed to multi-label (i.e. labels are not mutually exclusive). If True,
101 | predictions will be masked such that non-argmax predictions are set to zero
102 | (this prevents inflating the precision by continuing past a point that could
103 | be pragmatically useful). Furthermore, in the multiclass case the recall is
104 | not guaranteed to reach 1.0.
105 | include_micro: if True, compute an additional precision and recall for the
106 | micro-average across all labels and put it under entry `"micro"`
107 | target_level: `SampleSet.sources` and `SampleSet.prediction_probas` column level to use
108 | minimum_contribution: threshold for a source to be considered a ground truth positive
109 |             label. If set to `None`, the raw mixture ratios will be used as y_true.
110 |
111 | Returns:
112 | precision (dict): dict with keys for each label and values that are the
113 | monotonically increasing precision values at each threshold
114 | recall (dict): dict with keys for each label and values that are the
115 | monotonically decreasing recall values at each threshold
116 | thresholds (dict): dict with keys for each label and values that are the
117 | monotonically increasing thresholds on the decision function used to compute
118 | precision and recall
119 |
120 | References:
121 | - [Precision smoothing](
122 | https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173)
123 |
124 | """
125 | y_true = ss.sources.T.groupby(target_level, sort=False).sum().T
126 | if minimum_contribution is not None:
127 | y_true = (y_true > minimum_contribution).astype(int)
128 | y_pred = ss.prediction_probas.T.groupby(target_level, sort=False).sum().T
129 |
130 | # switch from pandas to numpy
131 | labels = y_true.columns
132 | n_classes = len(labels)
133 | y_true = y_true.values.copy()
134 | y_pred = y_pred.values.copy()
135 |
136 | if y_pred.shape != y_true.shape:
137 | raise ValueError(
138 | f"Shape mismatch between truth and predictions, "
139 | f"{y_true.shape} != {y_pred.shape}. "
140 | f"It is possible that the `target_level` is incorrect."
141 | )
142 |
143 | if multiclass is None:
144 | # infer whether multi-class or multi-label
145 | multiclass = not np.any(y_true.sum(axis=1) != 1)
146 |
147 | # drop nans
148 | notnan = ~np.isnan(y_pred).any(axis=1)
149 | y_true = y_true[notnan, :]
150 | y_pred = y_pred[notnan, :]
151 |
152 | y_pred_min = None
153 | if multiclass:
154 | # shift predictions to force positive
155 | if np.any(y_pred < 0):
156 | y_pred_min = y_pred.min()
157 | y_pred -= y_pred_min
158 |
159 | # mask the predictions by argmax
160 | pred_mask = np.eye(n_classes)[np.argmax(y_pred, axis=1)]
161 | # mask the predictions
162 | y_pred *= pred_mask
163 |
164 | precision = dict()
165 | recall = dict()
166 | thresholds = dict()
167 | for i, label in enumerate(labels):
168 | precision[label], recall[label], thresholds[label] = _pr_curve(
169 | y_true[:, i], y_pred[:, i], multiclass=multiclass, smooth=smooth
170 | )
171 |
172 | if include_micro:
173 | # A "micro-average": quantifying score on all classes jointly
174 | precision["micro"], recall["micro"], thresholds["micro"] = _pr_curve(
175 | y_true.ravel(), y_pred.ravel(), multiclass=multiclass, smooth=smooth
176 | )
177 |
178 | # un-shift thresholds if predictions were shifted
179 | if y_pred_min is not None:
180 | thresholds = {k: v + y_pred_min for k, v in thresholds.items()}
181 |
182 | return precision, recall, thresholds
183 |
184 |
185 | def average_precision_score(precision, recall):
186 | """Compute the average precision (area under the curve) for each precision/recall
187 | pair.
188 |
189 | Args:
190 |         precision (dict): return value of `precision_recall_curve()`
191 |         recall (dict): return value of `precision_recall_curve()`
192 |
193 | Returns:
194 | (dict): average precision values (float) for each label in precision/recall
195 | """
196 | return {label: _integrate(recall[label], precision[label]) for label in recall}
197 |
198 |
199 | def _step(x):
200 |     """Compute the right-going maximum of `x` and all previous values of `x`.
201 |
202 | Args:
203 | x (array-like): 1D array to process
204 |
205 | Returns:
206 | (array-like): right-going maximum of `x`
207 |
208 | """
209 | y = np.array(x)
210 | for i in range(1, len(y)):
211 | y[i] = max(y[i], y[i - 1])
212 | return y
213 |
214 |
215 | def _integrate(x, y, y_left=True):
216 | """Integrate an (x, y) function pair.
217 |
218 | Args:
219 | x (array-like): 1D array of x values
220 | y (array-like): 1D array of y values
221 | y_left: if true, omit the last value of y, else, omit the first value
222 |
223 | Returns:
224 | (float): integrated "area under the curve"
225 |
226 | """
227 | delta_x = x[1:] - x[:-1]
228 | y_trimmed = y[:-1] if y_left else y[1:]
229 | return np.abs(np.sum(delta_x * y_trimmed))
230 |
231 |
232 | def _pr_curve(y_true, y_pred, multiclass, smooth):
233 | precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
234 | y_true, y_pred
235 | )
236 |
237 | if smooth:
238 | precision = _step(precision)
239 |
240 | if multiclass:
241 | # remove the point where threshold=0 and recall=1
242 | precision = precision[1:]
243 | recall = recall[1:]
244 | thresholds = thresholds[1:]
245 |
246 | return precision, recall, thresholds
247 |
248 |
249 | def build_keras_semisupervised_metric_func(keras_metric_func, activation_func,
250 | n_labels):
251 | def metric_func(y_true, y_pred):
252 | return keras_metric_func(y_true[:, :n_labels], activation_func(y_pred))
253 | metric_func.__name__ = keras_metric_func.__class__.__name__
254 |
255 | return metric_func
256 |
--------------------------------------------------------------------------------
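A small sketch (not part of the repository source): `harmonic_mean` is exactly how the F1 score combines precision and recall.

    import numpy as np

    from riid.metrics import harmonic_mean

    precision = np.array([1.0, 0.5])
    recall = np.array([0.5, 0.5])
    print(harmonic_mean(precision, recall))  # [0.66666667 0.5]
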
/riid/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains PyRIID models."""
5 | from riid.models.bayes import PoissonBayesClassifier
6 | from riid.models.neural_nets import LabelProportionEstimator, MLPClassifier
7 | from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv1, ARADv2
8 |
9 | __all__ = ["PoissonBayesClassifier", "LabelProportionEstimator", "MLPClassifier",
10 | "ARADLatentPredictor", "ARADv1", "ARADv2"]
11 |
--------------------------------------------------------------------------------
/riid/models/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains functionality shared across all PyRIID models."""
5 | import json
6 | import os
7 | from pathlib import Path
8 | import uuid
9 | from abc import abstractmethod
10 | from enum import Enum
11 |
12 | import numpy as np
13 | import tensorflow as tf
14 | import tf2onnx
15 | from keras.api.models import Model
16 | from keras.api.utils import get_custom_objects
17 |
18 | import riid
19 | from riid import SampleSet, SpectraState
20 | from riid.data.labeling import label_to_index_element
21 | from riid.losses import mish
22 | from riid.metrics import multi_f1, single_f1
23 |
24 | get_custom_objects().update({
25 | "multi_f1": multi_f1,
26 | "single_f1": single_f1,
27 | "mish": mish,
28 | })
29 |
30 |
31 | class ModelInput(int, Enum):
32 | """Enumerates the potential input sources for a model."""
33 | GrossSpectrum = 0
34 | BackgroundSpectrum = 1
35 | ForegroundSpectrum = 2
36 |
37 |
38 | class PyRIIDModel:
39 | """Base class for PyRIID models."""
40 |
41 | def __init__(self, *args, **kwargs):
42 | self._info = {}
43 | self._temp_file_path = "temp_model.json"
44 | self._custom_objects = {}
45 | self._initialize_info()
46 |
47 | @property
48 | def seeds(self):
49 | return self._info["seeds"]
50 |
51 | @seeds.setter
52 | def seeds(self, value):
53 | self._info["seeds"] = value
54 |
55 | @property
56 | def info(self):
57 | return self._info
58 |
59 | @info.setter
60 | def info(self, value):
61 | self._info = value
62 |
63 | @property
64 | def target_level(self):
65 | return self._info["target_level"]
66 |
67 | @target_level.setter
68 | def target_level(self, value):
69 | if value in SampleSet.SOURCES_MULTI_INDEX_NAMES:
70 | self._info["target_level"] = value
71 | else:
72 | msg = (
73 | f"Target level '{value}' is invalid. "
74 | f"Acceptable levels: {SampleSet.SOURCES_MULTI_INDEX_NAMES}"
75 | )
76 | raise ValueError(msg)
77 |
78 | @property
79 | def model(self) -> Model:
80 | return self._model
81 |
82 | @model.setter
83 | def model(self, value: Model):
84 | self._model = value
85 |
86 | @property
87 | def model_id(self):
88 | return self._info["model_id"]
89 |
90 | @model_id.setter
91 | def model_id(self, value):
92 | self._info["model_id"] = value
93 |
94 | @property
95 | def model_inputs(self):
96 | return self._info["model_inputs"]
97 |
98 | @model_inputs.setter
99 | def model_inputs(self, value):
100 | self._info["model_inputs"] = value
101 |
102 | @property
103 | def model_outputs(self):
104 | return self._info["model_outputs"]
105 |
106 | @model_outputs.setter
107 | def model_outputs(self, value):
108 | self._info["model_outputs"] = value
109 |
110 | def get_model_outputs_as_label_tuples(self):
111 | return [
112 | label_to_index_element(v, self.target_level) for v in self.model_outputs
113 | ]
114 |
115 | def _get_model_dict(self) -> dict:
116 | model_json = self.model.to_json()
117 | model_dict = json.loads(model_json)
118 | model_weights = self.model.get_weights()
119 | model_dict = {
120 | "info": self._info,
121 | "model": model_dict,
122 | "weights": model_weights,
123 | }
124 | return model_dict
125 |
126 | def _get_model_str(self) -> str:
127 | model_dict = self._get_model_dict()
128 | model_str = json.dumps(model_dict, indent=4, cls=PyRIIDModelJsonEncoder)
129 | return model_str
130 |
131 | def _initialize_info(self):
132 | init_info = {
133 | "model_id": str(uuid.uuid4()),
134 | "model_type": self.__class__.__name__,
135 | "normalization": SpectraState.Unknown,
136 | "pyriid_version": riid.__version__,
137 | }
138 | self._update_info(**init_info)
139 |
140 | def _update_info(self, **kwargs):
141 | self._info.update(kwargs)
142 |
143 | def _update_custom_objects(self, key, value):
144 | self._custom_objects.update({key: value})
145 |
146 | def load(self, model_path: str):
147 | """Load the model from a path.
148 |
149 | Args:
150 | model_path: path from which to load the model.
151 | """
152 | if not os.path.exists(model_path):
153 | raise ValueError("Model file does not exist.")
154 |
155 | with open(model_path) as fin:
156 | model = json.load(fin)
157 |
158 | model_str = json.dumps(model["model"])
159 | self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects)
160 | self.model.set_weights([np.array(x) for x in model["weights"]])
161 | self.info = model["info"]
162 |
163 | def save(self, model_path: str, overwrite=False):
164 | """Save the model to a path.
165 |
166 | Args:
167 | model_path: path at which to save the model.
168 | overwrite: whether to overwrite an existing file if it already exists.
169 |
170 | Raises:
171 | `ValueError` when the given path already exists
172 | """
173 | if os.path.exists(model_path) and not overwrite:
174 | raise ValueError("Model file already exists.")
175 |
176 | model_str = self._get_model_str()
177 | with open(model_path, "w") as fout:
178 | fout.write(model_str)
179 |
180 | def to_onnx(self, model_path, **tf2onnx_kwargs: dict):
181 | """Convert the model to an ONNX model.
182 |
183 | Args:
184 | model_path: path at which to save the model
185 | tf2onnx_kwargs: additional kwargs to pass to the conversion
186 | """
187 | model_path = Path(model_path)
188 | if not str(model_path).endswith(riid.ONNX_MODEL_FILE_EXTENSION):
189 | raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}")
190 | if model_path.exists():
191 | raise ValueError("Model file already exists.")
192 |
193 | tf2onnx.convert.from_keras(
194 | self.model,
195 | input_signature=[
196 | tf.TensorSpec(
197 | shape=input_tensor.shape,
198 | dtype=input_tensor.dtype,
199 | name=input_tensor.name
200 | )
201 | for input_tensor in self.model.inputs
202 | ],
203 | output_path=str(model_path),
204 | **tf2onnx_kwargs
205 | )
206 |
207 | def to_tflite(self, model_path, quantize: bool = False, prune: bool = False):
208 |         """Convert the model to a TFLite model, optionally applying quantization or pruning.
209 |
210 | Note: requires export to SavedModel format first, then conversion to TFLite occurs.
211 |
212 | Args:
213 | model_path: file path at which to save the model
214 | quantize: whether to apply quantization
215 | prune: whether to apply pruning
216 | """
217 | model_path = Path(model_path)
218 | if not str(model_path).endswith(riid.TFLITE_MODEL_FILE_EXTENSION):
219 | raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}")
220 | if model_path.exists():
221 | raise ValueError("Model file already exists.")
222 |
223 | optimizations = []
224 | if quantize:
225 | optimizations.append(tf.lite.Optimize.DEFAULT)
226 | if prune:
227 | optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY)
228 |
229 | saved_model_dir = model_path.stem
230 | self.model.export(saved_model_dir)
231 | converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir))
232 | converter.optimizations = optimizations
233 | tflite_model = converter.convert()
234 |
235 | with open(model_path, "wb") as fout:
236 | fout.write(tflite_model)
237 |
238 | @abstractmethod
239 | def fit(self):
240 | pass
241 |
242 | @abstractmethod
243 | def predict(self):
244 | pass
245 |
246 |
247 | class PyRIIDModelJsonEncoder(json.JSONEncoder):
248 | """Custom JSON encoder for saving models.
249 | """
250 | def default(self, o):
251 | """Converts certain types to JSON-compatible types.
252 | """
253 | if isinstance(o, np.ndarray):
254 | return o.tolist()
255 | elif isinstance(o, np.float32):
256 | return o.astype(float)
257 |
258 | return super().default(o)
259 |
--------------------------------------------------------------------------------
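A sketch (not part of the repository source) of the metadata every PyRIID model carries. `PyRIIDModel` is instantiable here because it does not inherit from `abc.ABC`; the save/load file name in the comments is a hypothetical placeholder.

    from riid.models.base import PyRIIDModel

    model = PyRIIDModel()
    print(model.info["model_id"])    # fresh UUID4 assigned at construction
    print(model.info["model_type"])  # "PyRIIDModel"

    # For a concrete, fitted subclass, the Keras architecture, weights, and
    # info dict round-trip through a single JSON file, e.g.:
    #   model.save("my_model.json", overwrite=True)
    #   restored = MLPClassifier(); restored.load("my_model.json")  # same subclass as saved
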
/riid/models/bayes.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains the Poisson-Bayes classifier."""
5 | import numpy as np
6 | import pandas as pd
7 | import tensorflow as tf
8 | from keras.api.layers import Add, Input, Multiply, Subtract
9 | from keras.api.models import Model
10 |
11 | from riid import SampleSet
12 | from riid.models.base import PyRIIDModel
13 | from riid.models.layers import (ClipByValueLayer, DivideLayer, ExpandDimsLayer,
14 | PoissonLogProbabilityLayer, ReduceMaxLayer,
15 | ReduceSumLayer, SeedLayer)
16 |
17 |
18 | class PoissonBayesClassifier(PyRIIDModel):
19 | """Classifier calculating the conditional Poisson log probability of each seed spectrum
20 | given the measurement.
21 |
22 | This implementation is an adaptation of a naive Bayes classifier, a formal description of
23 | which can be found in ESLII:
24 |
25 | Hastie, Trevor, et al. The elements of statistical learning: data mining, inference, and
26 | prediction. Vol. 2. New York. Springer, 2009.
27 |
28 | For this model, each spectrum channel is treated as a Poisson random variable and
29 | expectations are provided by the user in the form of seeds rather than learned.
30 | Like the model described in ESLII, all classes are considered equally likely and features
31 | are assumed to be conditionally independent.
32 | """
33 | def __init__(self):
34 | super().__init__()
35 |
36 | self._update_custom_objects("ReduceSumLayer", ReduceSumLayer)
37 | self._update_custom_objects("ReduceMaxLayer", ReduceMaxLayer)
38 | self._update_custom_objects("DivideLayer", DivideLayer)
39 | self._update_custom_objects("ExpandDimsLayer", ExpandDimsLayer)
40 | self._update_custom_objects("ClipByValueLayer", ClipByValueLayer)
41 | self._update_custom_objects("PoissonLogProbabilityLayer", PoissonLogProbabilityLayer)
42 | self._update_custom_objects("SeedLayer", SeedLayer)
43 |
44 | def fit(self, seeds_ss: SampleSet):
45 | """Construct a TF-based implementation of a Poisson-Bayes classifier in terms
46 | of the given seeds.
47 |
48 | Args:
49 | seeds_ss: `SampleSet` of `n` foreground seed spectra where `n` >= 1.
50 |
51 | Raises:
52 | - `ValueError` when no seeds are provided
53 | - `NegativeSpectrumError` when any seed spectrum has negative counts in any bin
54 | - `ZeroTotalCountsError` when any seed spectrum contains zero total counts
55 | """
56 | if seeds_ss.n_samples <= 0:
57 | raise ValueError("Argument 'seeds_ss' must contain at least one seed.")
58 | if (seeds_ss.spectra.values < 0).any():
59 | msg = "Argument 'seeds_ss' can't contain any spectra with negative values."
60 | raise NegativeSpectrumError(msg)
61 | if (seeds_ss.spectra.values.sum(axis=1) <= 0).any():
62 | msg = "Argument 'seeds_ss' can't contain any spectra with zero total counts."
63 | raise ZeroTotalCountsError(msg)
64 |
65 | self._seeds = tf.convert_to_tensor(
66 | seeds_ss.spectra.values,
67 | dtype=tf.float32
68 | )
69 |
70 | # Inputs
71 | gross_spectrum_input = Input(shape=(seeds_ss.n_channels,),
72 | name="gross_spectrum")
73 | gross_live_time_input = Input(shape=(),
74 | name="gross_live_time")
75 | bg_spectrum_input = Input(shape=(seeds_ss.n_channels,),
76 | name="bg_spectrum")
77 | bg_live_time_input = Input(shape=(),
78 | name="bg_live_time")
79 | model_inputs = (
80 | gross_spectrum_input,
81 | gross_live_time_input,
82 | bg_spectrum_input,
83 | bg_live_time_input,
84 | )
85 |
86 | # Input statistics
87 | gross_total_counts = ReduceSumLayer(name="gross_total_counts")(gross_spectrum_input, axis=1)
88 | bg_total_counts = ReduceSumLayer(name="bg_total_counts")(bg_spectrum_input, axis=1)
89 | bg_count_rate = DivideLayer(name="bg_count_rate")([bg_total_counts, bg_live_time_input])
90 |
91 | gross_spectrum_input_expanded = ExpandDimsLayer(
92 | name="gross_spectrum_input_expanded"
93 | )(gross_spectrum_input, axis=1)
94 | bg_total_counts_expanded = ExpandDimsLayer(
95 | name="bg_total_counts_expanded"
96 | )(bg_total_counts, axis=1)
97 |
98 | # Expectations
99 | seed_layer = SeedLayer(self._seeds)(model_inputs)
100 | seed_layer_expanded = ExpandDimsLayer()(seed_layer, axis=0)
101 | expected_bg_counts = Multiply(
102 | trainable=False,
103 | name="expected_bg_counts"
104 | )([bg_count_rate, gross_live_time_input])
105 | expected_bg_counts_expanded = ExpandDimsLayer(
106 | name="expected_bg_counts_expanded"
107 | )(expected_bg_counts, axis=1)
108 | normalized_bg_spectrum = DivideLayer(
109 | name="normalized_bg_spectrum"
110 | )([bg_spectrum_input, bg_total_counts_expanded])
111 | expected_bg_spectrum = Multiply(
112 | trainable=False,
113 | name="expected_bg_spectrum"
114 | )([normalized_bg_spectrum, expected_bg_counts_expanded])
115 | expected_fg_counts = Subtract(
116 | trainable=False,
117 | name="expected_fg_counts"
118 | )([gross_total_counts, expected_bg_counts])
119 | expected_fg_counts_expanded = ExpandDimsLayer(
120 | name="expected_fg_counts_expanded"
121 | )(expected_fg_counts, axis=-1)
122 | expected_fg_counts_expanded2 = ExpandDimsLayer(
123 | name="expected_fg_counts_expanded2"
124 | )(expected_fg_counts_expanded, axis=-1)
125 | expected_fg_spectrum = Multiply(
126 | trainable=False,
127 | name="expected_fg_spectrum"
128 | )([seed_layer_expanded, expected_fg_counts_expanded2])
129 | max_fg_value = ReduceMaxLayer(
130 | name="max_fg_value"
131 | )(expected_fg_spectrum)
132 | expected_fg_spectrum = ClipByValueLayer(
133 | name="clip_expected_fg_spectrum"
134 | )(expected_fg_spectrum, clip_value_min=1e-8, clip_value_max=max_fg_value)
135 | expected_bg_spectrum_expanded = ExpandDimsLayer(
136 | name="expected_bg_spectrum_expanded"
137 | )(expected_bg_spectrum, axis=1)
138 | expected_gross_spectrum = Add(
139 | trainable=False,
140 | name="expected_gross_spectrum"
141 | )([expected_fg_spectrum, expected_bg_spectrum_expanded])
142 |
143 | # Compute probabilities
144 | log_probabilities = PoissonLogProbabilityLayer(
145 | name="log_probabilities"
146 | )([expected_gross_spectrum, gross_spectrum_input_expanded])
147 | summed_log_probabilities = ReduceSumLayer(
148 | name="summed_log_probabilities"
149 | )(log_probabilities, axis=2)
150 |
151 | # Assemble model
152 | self.model = Model(model_inputs, summed_log_probabilities)
153 | self.model.compile()
154 |
155 | self.target_level = "Seed"
156 | sources_df = seeds_ss.sources.T.groupby(self.target_level, sort=False).sum().T
157 | self.model_outputs = sources_df.columns.values.tolist()
158 |
159 | def predict(self, gross_ss: SampleSet, bg_ss: SampleSet,
160 | normalize_scores: bool = False, verbose: bool = False):
161 | """Compute the conditional Poisson log probability between spectra in a `SampleSet` and
162 | the seeds to which the model was fit.
163 |
164 | Args:
165 | gross_ss: `SampleSet` of `n` gross spectra where `n` >= 1
166 | bg_ss: `SampleSet` of `n` background spectra where `n` >= 1
167 | normalize_scores: whether to normalize prediction probabilities.
168 | When True, each row of scores is shifted by its minimum value,
169 | making all probabilities non-negative relative to the given dataset.
170 | While this can be helpful for visualizing probabilities in log scale,
171 | it can adversely affect one's ability to detect significantly anomalous signatures.
172 | """
173 | gross_spectra = tf.convert_to_tensor(gross_ss.spectra.values, dtype=tf.float32)
174 | gross_lts = tf.convert_to_tensor(gross_ss.info.live_time.values, dtype=tf.float32)
175 | bg_spectra = tf.convert_to_tensor(bg_ss.spectra.values, dtype=tf.float32)
176 | bg_lts = tf.convert_to_tensor(bg_ss.info.live_time.values, dtype=tf.float32)
177 |
178 | prediction_probas = self.model.predict((
179 | gross_spectra, gross_lts, bg_spectra, bg_lts
180 | ), batch_size=512, verbose=verbose)
181 |
182 | # Normalization
183 | if normalize_scores:
184 | rows_min = np.min(prediction_probas, axis=1)
185 | prediction_probas = prediction_probas - rows_min[:, np.newaxis]
186 |
187 | gross_ss.prediction_probas = pd.DataFrame(
188 | prediction_probas,
189 | columns=pd.MultiIndex.from_tuples(
190 | self.get_model_outputs_as_label_tuples(),
191 | names=SampleSet.SOURCES_MULTI_INDEX_NAMES
192 | )
193 | )
194 |
195 |
196 | class ZeroTotalCountsError(ValueError):
197 | """All spectrum channels are zero."""
198 | pass
199 |
200 |
201 | class NegativeSpectrumError(ValueError):
202 | """At least one spectrum channel is negative."""
203 | pass
204 |
--------------------------------------------------------------------------------
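
A condensed sketch of fitting and applying the classifier above, following the flow in this repository's model tests (synthesizer arguments are illustrative):

    import numpy as np
    from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
    from riid.models import PoissonBayesClassifier

    fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
    mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1)

    pb_model = PoissonBayesClassifier()
    pb_model.fit(fg_seeds_ss)  # expectations come from the seeds, not from training

    synth = StaticSynthesizer(samples_per_seed=2, return_fg=True, return_gross=True,
                              rng=np.random.default_rng(42))
    test_fg_ss, test_gross_ss = synth.generate(fg_seeds_ss, mixed_bg_seeds_ss, verbose=False)
    test_bg_ss = test_gross_ss - test_fg_ss

    # Log probabilities land in test_gross_ss.prediction_probas;
    # normalize_scores=True would shift each row by its minimum.
    pb_model.predict(test_gross_ss, test_bg_ss)
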
/riid/models/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains custom Keras layers."""
5 | import tensorflow as tf
6 | from keras.api.layers import Layer
7 |
8 |
9 | class ReduceSumLayer(Layer):
10 | def __init__(self, **kwargs):
11 | super().__init__(**kwargs)
12 |
13 | def call(self, x, axis):
14 | return tf.reduce_sum(x, axis=axis)
15 |
16 |
17 | class ReduceMaxLayer(Layer):
18 | def __init__(self, **kwargs):
19 | super().__init__(**kwargs)
20 |
21 | def call(self, x):
22 | return tf.reduce_max(x)
23 |
24 |
25 | class DivideLayer(Layer):
26 | def __init__(self, **kwargs):
27 | super().__init__(**kwargs)
28 |
29 | def call(self, x):
30 | return tf.divide(x[0], x[1])
31 |
32 |
33 | class ExpandDimsLayer(Layer):
34 | def __init__(self, **kwargs):
35 | super().__init__(**kwargs)
36 |
37 | def call(self, x, axis):
38 | return tf.expand_dims(x, axis=axis)
39 |
40 |
41 | class ClipByValueLayer(Layer):
42 | def __init__(self, **kwargs):
43 | super().__init__(**kwargs)
44 |
45 | def call(self, x, clip_value_min, clip_value_max):
46 | return tf.clip_by_value(x, clip_value_min=clip_value_min, clip_value_max=clip_value_max)
47 |
48 |
49 | class PoissonLogProbabilityLayer(Layer):
50 | def __init__(self, **kwargs):
51 | super().__init__(**kwargs)
52 |
53 | def call(self, x):
54 | exp, value = x
55 | log_probas = tf.math.xlogy(value, exp) - exp - tf.math.lgamma(value + 1)
56 | return log_probas
57 |
58 |
59 | class SeedLayer(Layer):
60 | def __init__(self, seeds, **kwargs):
61 | super(SeedLayer, self).__init__(**kwargs)
62 | self.seeds = tf.convert_to_tensor(seeds)
63 |
64 | def get_config(self):
65 | config = super().get_config()
66 | config.update({
67 | "seeds": self.seeds.numpy().tolist(),
68 | })
69 | return config
70 |
71 | def call(self, inputs):
72 | return self.seeds
73 |
74 |
75 | class L1NormLayer(Layer):
76 | def __init__(self, **kwargs):
77 | super().__init__(**kwargs)
78 |
79 | def call(self, inputs):
80 | sums = tf.reduce_sum(inputs, axis=-1)
81 | l1_norm = inputs / tf.reshape(sums, (-1, 1))
82 | return l1_norm
83 |
--------------------------------------------------------------------------------
/riid/models/neural_nets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains neural network-based classifiers and regressors."""
5 | from riid.models.neural_nets.basic import MLPClassifier
6 | from riid.models.neural_nets.lpe import LabelProportionEstimator
7 |
8 | __all__ = ["LabelProportionEstimator", "MLPClassifier"]
9 |
--------------------------------------------------------------------------------
/riid/models/neural_nets/basic.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module contains a simple neural network."""
5 | import pandas as pd
6 | import tensorflow as tf
7 | from keras.api.callbacks import EarlyStopping
8 | from keras.api.layers import Dense, Input, Dropout
9 | from keras.api.losses import CategoricalCrossentropy
10 | from keras.api.metrics import F1Score, Precision, Recall
11 | from keras.api.models import Model
12 | from keras.api.optimizers import Adam
13 | from keras.api.regularizers import l1, l2
14 | from keras.api.utils import split_dataset
15 |
16 | from riid import SampleSet, SpectraType
17 | from riid.models.base import ModelInput, PyRIIDModel
18 |
19 |
20 | class MLPClassifier(PyRIIDModel):
21 | """Multi-layer perceptron classifier."""
22 | def __init__(self, activation=None, loss=None, optimizer=None,
23 | metrics=None, l2_alpha: float = 1e-4,
24 | activity_regularizer=None, final_activation=None,
25 | dense_layer_size=None, dropout=None):
26 | """
27 | Args:
28 | activation: activation function to use for each dense layer
29 | loss: loss function to use for training
30 | optimizer: tensorflow optimizer or optimizer name to use for training
31 | metrics: list of metrics to be evaluated during training
32 | l2_alpha: alpha value for the L2 regularization of each dense layer
33 | activity_regularizer: regularizer function applied to each dense layer output
34 | final_activation: final activation function to apply to model output
35 | dense_layer_size: size of the hidden dense layer (defaults to half the input size)
36 | dropout: dropout rate to apply after the hidden dense layer (None disables dropout)
37 | """
38 | super().__init__()
39 |
40 | self.activation = activation
41 | self.loss = loss
42 | self.optimizer = optimizer
43 | self.metrics = metrics
44 | self.l2_alpha = l2_alpha
45 | self.activity_regularizer = activity_regularizer
46 | self.final_activation = final_activation
47 | self.dense_layer_size = dense_layer_size
48 | self.dropout = dropout
49 | if self.activation is None:
50 | self.activation = "relu"
51 | if self.loss is None:
52 | self.loss = CategoricalCrossentropy()
53 | if optimizer is None:
54 | self.optimizer = Adam(learning_rate=0.01, clipnorm=0.001)
55 | if self.metrics is None:
56 | self.metrics = [F1Score(), Precision(), Recall()]
57 | if self.activity_regularizer is None:
58 | self.activity_regularizer = l1(0.0)
59 | if self.final_activation is None:
60 | self.final_activation = "softmax"
61 |
62 | self.model = None
63 | self._set_predict_fn()
64 |
65 | def fit(self, ss: SampleSet, batch_size: int = 200, epochs: int = 20,
66 | validation_split: float = 0.2, callbacks=None,
67 | patience: int = 15, es_monitor: str = "val_loss",
68 | es_mode: str = "min", es_verbose=0, target_level="Isotope", verbose: bool = False):
69 | """Fit a model to the given `SampleSet`(s).
70 |
71 | Args:
72 | ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either
73 | foreground (AKA, "net") or gross.
74 | batch_size: number of samples per gradient update
75 | epochs: maximum number of training iterations
76 | validation_split: percentage of the training data to use as validation data
77 | callbacks: list of callbacks to be passed to the TensorFlow `Model.fit()` method
78 | patience: number of epochs to wait for `EarlyStopping` object
79 | es_monitor: quantity to be monitored for `EarlyStopping` object
80 | es_mode: mode for `EarlyStopping` object
81 | es_verbose: verbosity level for `EarlyStopping` object
82 | target_level: `SampleSet.sources` column level to use
83 | verbose: whether to show detailed model training output
84 |
85 | Returns:
86 | `keras.callbacks.History` object.
87 |
88 | Raises:
89 | `ValueError` when no spectra are provided as input
90 | """
91 | if ss.n_samples <= 0:
92 | raise ValueError("No spectra provided!")
93 |
94 | if ss.spectra_type == SpectraType.Gross:
95 | self.model_inputs = (ModelInput.GrossSpectrum,)
96 | elif ss.spectra_type == SpectraType.Foreground:
97 | self.model_inputs = (ModelInput.ForegroundSpectrum,)
98 | elif ss.spectra_type == SpectraType.Background:
99 | self.model_inputs = (ModelInput.BackgroundSpectrum,)
100 | else:
101 | raise ValueError(f"{ss.spectra_type} is not supported in this model.")
102 |
103 | X = ss.get_samples()
104 | source_contributions_df = ss.sources.T.groupby(target_level, sort=False).sum().T
105 | model_outputs = source_contributions_df.columns.values.tolist()
106 | Y = source_contributions_df.values
107 |
108 | spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32)
109 | labels_tensor = tf.convert_to_tensor(Y, dtype=tf.float32)
110 | training_dataset = tf.data.Dataset.from_tensor_slices((spectra_tensor, labels_tensor))
111 | training_dataset, validation_dataset = split_dataset(
112 | training_dataset,
113 | right_size=validation_split,
114 | shuffle=True
115 | )
116 | training_dataset = training_dataset.batch(batch_size=batch_size)
117 | validation_dataset = validation_dataset.batch(batch_size=batch_size)
118 |
119 | if not self.model:
120 | inputs = Input(shape=(X.shape[1],), name="Spectrum")
121 | if self.dense_layer_size is None:
122 | dense_layer_size = X.shape[1] // 2
123 | else:
124 | dense_layer_size = self.dense_layer_size
125 | dense_layer = Dense(
126 | dense_layer_size,
127 | activation=self.activation,
128 | activity_regularizer=self.activity_regularizer,
129 | kernel_regularizer=l2(self.l2_alpha),
130 | )(inputs)
131 | if self.dropout is not None:
132 | last_layer = Dropout(self.dropout)(dense_layer)
133 | else:
134 | last_layer = dense_layer
135 | outputs = Dense(Y.shape[1], activation=self.final_activation)(last_layer)
136 | self.model = Model(inputs, outputs)
137 | self.model.compile(loss=self.loss, optimizer=self.optimizer,
138 | metrics=self.metrics)
139 |
140 | es = EarlyStopping(
141 | monitor=es_monitor,
142 | patience=patience,
143 | verbose=es_verbose,
144 | restore_best_weights=True,
145 | mode=es_mode,
146 | )
147 | if callbacks:
148 | callbacks.append(es)
149 | else:
150 | callbacks = [es]
151 |
152 | history = self.model.fit(
153 | training_dataset,
154 | epochs=epochs,
155 | verbose=verbose,
156 | validation_data=validation_dataset,
157 | callbacks=callbacks,
158 | )
159 |
160 | # Update model information
161 | self._update_info(
162 | target_level=target_level,
163 | model_outputs=model_outputs,
164 | normalization=ss.spectra_state,
165 | )
166 |
167 | # Define the predict function with tf.function and input_signature
168 | self._set_predict_fn()
169 |
170 | return history
171 |
172 | def _set_predict_fn(self):
173 | self._predict_fn = tf.function(
174 | self._predict,
175 | experimental_relax_shapes=True
176 | )
177 |
178 | def _predict(self, input_tensor):
179 | return self.model(input_tensor, training=False)
180 |
181 | def predict(self, ss: SampleSet, bg_ss: SampleSet = None):
182 | """Classify the spectra in the provided `SampleSet`(s).
183 |
184 | Results are stored inside the first SampleSet's prediction-related properties.
185 |
186 | Args:
187 | ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either
188 | foreground (AKA, "net") or gross
189 | bg_ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are background
190 | """
191 | x_test = ss.get_samples().astype(float)
192 | if bg_ss:
193 | X = [x_test, bg_ss.get_samples().astype(float)]
194 | else:
195 | X = x_test
196 |
197 | spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32)
198 | results = self._predict_fn(spectra_tensor)
199 |
200 | col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level)
201 | col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1]
202 | ss.prediction_probas = pd.DataFrame(
203 | data=results,
204 | columns=pd.MultiIndex.from_tuples(
205 | self.get_model_outputs_as_label_tuples(),
206 | names=col_level_subset
207 | )
208 | )
209 |
210 | ss.classified_by = self.model_id
211 |
--------------------------------------------------------------------------------
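
A short sketch of the `MLPClassifier` train/predict loop above (setup follows the visualization tests later in this repository; hyperparameters are illustrative):

    from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
    from riid.models import MLPClassifier

    fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
    mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(10)

    synth = StaticSynthesizer(samples_per_seed=100, return_fg=False, return_gross=True)
    _, train_ss = synth.generate(fg_seeds_ss, mixed_bg_ss, verbose=False)
    train_ss.normalize()

    model = MLPClassifier()
    history = model.fit(train_ss, epochs=10, patience=5).history

    # Predictions are written back onto the SampleSet
    model.predict(train_ss)
    print(train_ss.get_predictions())
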
/run_tests.ps1:
--------------------------------------------------------------------------------
1 | coverage run -m unittest discover -s tests/ -p *.py -v
2 | if ($LASTEXITCODE -ne 0) { throw "Tests failed!" }
3 | coverage report -i
4 | coverage xml -i
5 |
--------------------------------------------------------------------------------
/run_tests.sh:
--------------------------------------------------------------------------------
1 | set -e
2 | coverage run --source=./riid -m unittest tests/*.py
3 | coverage report -i --skip-empty
4 |
--------------------------------------------------------------------------------
/tests/anomaly_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module tests the anomaly module."""
5 | import unittest
6 |
7 | import numpy as np
8 |
9 | from riid import PassbySynthesizer, SeedMixer, get_dummy_seeds
10 | from riid.anomaly import PoissonNChannelEventDetector
11 |
12 |
13 | class TestAnomaly(unittest.TestCase):
14 | """Test class for Anomaly module."""
15 | def setUp(self):
16 | """Test setup."""
17 | pass
18 |
19 | def test_event_detector(self):
20 | random_state = 42
21 | rng = np.random.default_rng(random_state)
22 |
23 | SAMPLE_INTERVAL = 0.5
24 | BG_RATE = 300
25 | seeds_ss = get_dummy_seeds(100)
26 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg()
27 | mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng)\
28 | .generate(1)
29 | events = PassbySynthesizer(events_per_seed=1,
30 | sample_interval=SAMPLE_INTERVAL,
31 | bg_cps=BG_RATE,
32 | fwhm_function_args=(5,),
33 | dwell_time_function_args=(20, 20),
34 | snr_function_args=(20, 20),
35 | return_gross=True,
36 | rng=rng)\
37 | .generate(fg_seeds_ss, mixed_bg_seeds_ss, verbose=False)
38 |
39 | _, gross_events = list(zip(*events))
40 | passby_ss = gross_events[0]
41 |
42 | expected_bg_counts = SAMPLE_INTERVAL * BG_RATE
43 | expected_bg_measurement = mixed_bg_seeds_ss.spectra.iloc[0] * expected_bg_counts
44 | ed = PoissonNChannelEventDetector(
45 | long_term_duration=600,
46 | short_term_duration=10,
47 | pre_event_duration=1,
48 | max_event_duration=120,
49 | post_event_duration=10,
50 | tolerable_false_alarms_per_day=1e-5,
51 | anomaly_threshold_update_interval=60,
52 | )
53 | cps_history = []
54 |
55 | # Filling background
56 | measurement_id = 0
57 | while ed.background_percent_complete < 100:
58 | noisy_bg_measurement = rng.poisson(expected_bg_measurement)
59 | cps_history.append(noisy_bg_measurement.sum() / SAMPLE_INTERVAL)
60 | _ = ed.add_measurement(
61 | measurement_id,
62 | noisy_bg_measurement,
63 | SAMPLE_INTERVAL,
64 | verbose=False
65 | )
66 | measurement_id += 1
67 |
68 | # Create event using a synthesized passby
69 | for i in range(passby_ss.n_samples):
70 | gross_spectrum = passby_ss.spectra.iloc[i].values
71 | cps_history.append(gross_spectrum.sum() / SAMPLE_INTERVAL)
72 | event_result = ed.add_measurement(
73 | measurement_id=measurement_id,
74 | measurement=gross_spectrum,
75 | duration=SAMPLE_INTERVAL,
76 | verbose=False
77 | )
78 | measurement_id += 1
79 | if event_result:
80 | break
81 |
82 | self.assertIsNotNone(event_result)
83 |
--------------------------------------------------------------------------------
/tests/gadras_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module tests the gadras module."""
5 | import unittest
6 |
7 | import pandas as pd
8 |
9 | from riid import get_dummy_seeds
10 | from riid.gadras.pcf import (_pack_compressed_text_buffer,
11 | _unpack_compressed_text_buffer)
12 |
13 |
14 | class TestGadras(unittest.TestCase):
15 | """Test class for Gadras."""
16 |
17 | def test_pcf_header_formatting(self):
18 | """Tests the PCF header information is properly packed.
19 |
20 | String formatting note:
21 | f"{'hello':60.50}"
22 | ^ ^ ^
23 | | | '> The number of letters to allow from the input value
24 | | '> The length of the string
25 | '> The input value to format
26 |
27 | """
28 | FIELD_LENGTH = 10
29 | test_cases = [
30 | (
31 | "tttttttt ", "ddddddddd ", "ssssssssss",
32 | "tttttttt", "ddddddddd", "ssssssssss",
33 | "tttttttt ddddddddd ssssssssss",
34 | ),
35 | (
36 | "tttttttttt", "dddddddddd", "ssssssssss+",
37 | "tttttttttt", "", "ssssssssss+",
38 | "ÿttttttttttÿÿssssssssss+ "
39 | ),
40 | (
41 | "tttttttt ", "dddddddd ", "ssssssssss+",
42 | "tttttttt", "dddddddd", "ssssssssss+",
43 | "ÿttttttttÿddddddddÿssssssssss+"
44 | ),
45 | (
46 | "tt ", "dddddddddd+", "ssssssssss++",
47 | "tt", "dddddddddd+", "ssssssssss++",
48 | "ÿttÿdddddddddd+ÿssssssssss++ "
49 | ),
50 | ]
51 | for case in test_cases:
52 | title, desc, source, \
53 | expected_title, expected_desc, expected_source, \
54 | expected_ctb = case
55 | actual_ctb = _pack_compressed_text_buffer(
56 | title,
57 | desc,
58 | source,
59 | field_len=FIELD_LENGTH
60 | )
61 | actual_title, actual_desc, actual_source = _unpack_compressed_text_buffer(
62 | actual_ctb,
63 | field_len=FIELD_LENGTH
64 | )
65 | self.assertEqual(expected_title, actual_title)
66 | self.assertEqual(expected_desc, actual_desc)
67 | self.assertEqual(expected_source, actual_source)
68 | self.assertEqual(expected_ctb, actual_ctb)
69 |
70 | def test_to_pcf_with_various_sources_dataframes(self):
71 | TEMP_PCF_PATH = "temp.pcf"
72 |
73 | # With all levels
74 | ss = get_dummy_seeds()
75 | ss.to_pcf(TEMP_PCF_PATH, verbose=False)
76 |
77 | # Without seed level (only category and isotope)
78 | ss = get_dummy_seeds()
79 | ss.sources.columns = ss.sources.columns.droplevel("Seed")
80 | ss.to_pcf(TEMP_PCF_PATH, verbose=False)
81 |
82 | # Without seed and isotope levels (only category)
83 | ss = get_dummy_seeds()
84 | ss.sources.columns = ss.sources.columns.droplevel("Seed")
85 | ss.sources.columns = ss.sources.columns.droplevel("Isotope")
86 | ss.to_pcf(TEMP_PCF_PATH, verbose=False)
87 |
88 | # With no sources
89 | ss = get_dummy_seeds()
90 | ss.sources = pd.DataFrame()
91 | ss.to_pcf(TEMP_PCF_PATH, verbose=False)
92 |
93 |
94 | if __name__ == "__main__":
95 | unittest.main()
96 |
--------------------------------------------------------------------------------
/tests/model_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module tests PyRIID models."""
5 | import os
6 | import unittest
7 |
8 | import numpy as np
9 | import pandas as pd
10 |
11 | from riid import SampleSet, SeedMixer, StaticSynthesizer, get_dummy_seeds
12 | from riid.models import (ARADLatentPredictor, ARADv1, ARADv2,
13 | LabelProportionEstimator, MLPClassifier,
14 | PoissonBayesClassifier)
15 | from riid.models.base import PyRIIDModel
16 | from riid.models.bayes import NegativeSpectrumError, ZeroTotalCountsError
17 |
18 |
19 | class TestModels(unittest.TestCase):
20 | """Test class for PyRIID models."""
21 | def setUp(self):
22 | """Test setup."""
23 | pass
24 |
25 | @classmethod
26 | def setUpClass(self):
27 | self.seeds_ss = get_dummy_seeds(n_channels=128)
28 | self.fg_seeds_ss, self.bg_seeds_ss = self.seeds_ss.split_fg_and_bg()
29 | self.mixed_bg_seeds_ss = SeedMixer(self.bg_seeds_ss, mixture_size=3).generate(1)
30 | self.static_synth = StaticSynthesizer(samples_per_seed=5)
31 | self.train_ss, _ = self.static_synth.generate(self.fg_seeds_ss, self.mixed_bg_seeds_ss,
32 | verbose=False)
33 | self.train_ss.prediction_probas = self.train_ss.sources
34 | self.train_ss.normalize()
35 | self.test_ss, _ = self.static_synth.generate(self.fg_seeds_ss, self.mixed_bg_seeds_ss,
36 | verbose=False)
37 | self.test_ss.normalize()
38 |
39 | @classmethod
40 | def tearDownClass(self):
41 | pass
42 |
43 | def test_pb_constructor_errors(self):
44 | """Testing for constructor errors when different arguments are provided."""
45 | pb_model = PoissonBayesClassifier()
46 |
47 | # Empty argument provided
48 | spectra = np.array([])
49 | ss = SampleSet()
50 | ss.spectra = pd.DataFrame(spectra)
51 | self.assertRaises(ValueError, pb_model.fit, ss)
52 |
53 | # Negative channel argument provided
54 | spectra = np.array([
55 | [1, 1, 1, 1],
56 | [1, 1, -1, 1],
57 | [1, 1, 1, 1]
58 | ])
59 | ss = SampleSet()
60 | ss.spectra = pd.DataFrame(spectra)
61 | self.assertRaises(NegativeSpectrumError, pb_model.fit, ss)
62 |
63 | # Zero total counts argument provided
64 | spectra = np.array([
65 | [1, 1, 1, 1],
66 | [0, 0, 0, 0],
67 | [1, 1, 1, 1]
68 | ])
69 | ss = SampleSet()
70 | ss.spectra = pd.DataFrame(spectra)
71 | self.assertRaises(ZeroTotalCountsError, pb_model.fit, ss)
72 |
73 | def test_pb_predict(self):
74 | """Tests the constructor with a valid SampleSet."""
75 | seeds_ss = get_dummy_seeds()
76 | fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg()
77 | bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(1)
78 |
79 | # Create the PoissonBayesClassifier
80 | pb_model = PoissonBayesClassifier()
81 | pb_model.fit(fg_seeds_ss)
82 |
83 | # Get test samples
84 | gss = StaticSynthesizer(
85 | samples_per_seed=2,
86 | live_time_function_args=(4, 4),
87 | snr_function_args=(10, 10),
88 | rng=np.random.default_rng(42),
89 | return_fg=True,
90 | return_gross=True,
91 | )
92 | test_fg_ss, test_gross_ss = gss.generate(fg_seeds_ss, bg_seeds_ss, verbose=False)
93 | test_bg_ss = test_gross_ss - test_fg_ss
94 |
95 | # Predict
96 | pb_model.predict(test_gross_ss, test_bg_ss)
97 |
98 | truth_labels = test_fg_ss.get_labels()
99 | predicted_labels = test_gross_ss.get_predictions()
100 | assert (truth_labels == predicted_labels).all()
101 |
102 | def test_pb_fit_save_load(self):
103 | _test_model_fit_save_load_predict(self, PoissonBayesClassifier, None, self.fg_seeds_ss)
104 |
105 | def test_mlp_fit_save_load_predict(self):
106 | _test_model_fit_save_load_predict(self, MLPClassifier, self.test_ss, self.train_ss,
107 | epochs=1)
108 |
109 | def test_lpe_fit_save_load_predict(self):
110 | _test_model_fit_save_load_predict(self, LabelProportionEstimator, self.test_ss,
111 | self.fg_seeds_ss, self.train_ss, epochs=1)
112 |
113 | def test_aradv1_fit_save_load_predict(self):
114 | _test_model_fit_save_load_predict(self, ARADv1, self.test_ss, self.train_ss, epochs=1)
115 |
116 | def test_aradv2_fit_save_load_predict(self):
117 | _test_model_fit_save_load_predict(self, ARADv2, self.test_ss, self.train_ss, epochs=1)
118 |
119 | def test_alp_fit_save_load_predict(self):
120 | arad_v2 = ARADv2()
121 | arad_v2.fit(self.train_ss, epochs=1)
122 | _test_model_fit_save_load_predict(self, ARADLatentPredictor, self.test_ss, arad_v2.model,
123 | self.train_ss, target_info_columns=["snr"], epochs=1)
124 |
125 |
126 | def _try_remove_model_and_info(model_path: str):
127 | if os.path.exists(model_path):
128 | if os.path.isdir(model_path):
129 | os.rmdir(model_path)
130 | else:
131 | os.remove(model_path)
132 |
133 |
134 | def _test_model_fit_save_load_predict(test_case: unittest.TestCase, model_class: PyRIIDModel,
135 | test_ss: SampleSet = None, *args_for_fit, **kwargs_for_fit):
136 | m1 = model_class()
137 | m2 = model_class()
138 |
139 | m1.fit(*args_for_fit, **kwargs_for_fit)
140 |
141 | model_path = m1._temp_file_path
142 |
143 | _try_remove_model_and_info(model_path)
144 | test_case.assertRaises(ValueError, m2.load, model_path)
145 |
146 | m1.save(model_path)
147 | test_case.assertRaises(ValueError, m1.save, model_path)
148 |
149 | m2.load(model_path)
150 | _try_remove_model_and_info(model_path)
151 |
152 | if test_ss is not None:
153 | m1.predict(test_ss)
154 |
155 |
156 | if __name__ == "__main__":
157 | unittest.main()
158 |
--------------------------------------------------------------------------------
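
The save/load round trip that `_test_model_fit_save_load_predict` exercises, spelled out as a sketch (`train_ss` is assumed to be built as in `setUpClass` above; the path reuses the model's own `_temp_file_path`, as in the helper):

    from riid.models import MLPClassifier

    m1 = MLPClassifier()
    m1.fit(train_ss, epochs=1)

    model_path = m1._temp_file_path
    m1.save(model_path)   # saving again to the same path raises ValueError

    m2 = MLPClassifier()
    m2.load(model_path)   # m2 now mirrors m1's architecture and weights
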
/tests/seedmixer_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module tests the seed mixer."""
5 | import unittest
6 |
7 | import numpy as np
8 | from scipy.spatial.distance import jensenshannon
9 |
10 | from riid import SampleSet, SeedMixer, get_dummy_seeds
11 |
12 |
13 | class TestSeedMixer(unittest.TestCase):
14 | """Test seed mixing functionality of SampleSet.
15 | """
16 | @classmethod
17 | def setUpClass(self):
18 | random_state = 42
19 | self.rng = np.random.default_rng(random_state)
20 |
21 | self.ss, _ = get_dummy_seeds().split_fg_and_bg()
22 | self.ss.normalize()
23 | self.sources = self.ss.get_labels().values
24 |
25 | self.two_mix_seeds_ss = SeedMixer(
26 | self.ss,
27 | mixture_size=2,
28 | dirichlet_alpha=10,
29 | ).generate(n_samples=20)
30 |
31 | self.three_mix_seeds_ss = SeedMixer(
32 | self.ss,
33 | mixture_size=3,
34 | dirichlet_alpha=10,
35 | ).generate(n_samples=20)
36 |
37 | def test_mixture_combinations(self):
38 | # check that each mixture contains unique isotopes and the correct mixture size
39 | two_mix_isotopes = [
40 | x.split(" + ")
41 | for x in self.two_mix_seeds_ss.get_labels(target_level="Isotope", max_only=False)
42 | ]
43 | self.assertTrue(all([len(set(x)) == 2 for x in two_mix_isotopes]))
44 |
45 | three_mix_isotopes = [
46 | x.split(" + ")
47 | for x in self.three_mix_seeds_ss.get_labels(target_level="Isotope", max_only=False)
48 | ]
49 | self.assertTrue(all([len(set(x)) == 3 for x in three_mix_isotopes]))
50 |
51 | def test_mixture_ratios(self):
52 | # check for valid probability distribution
53 | for each in self.two_mix_seeds_ss.get_source_contributions(target_level="Isotope"):
54 | self.assertAlmostEqual(each.sum(), 1.0)
55 |
56 | for each in self.three_mix_seeds_ss.get_source_contributions(target_level="Isotope"):
57 | self.assertAlmostEqual(each.sum(), 1.0)
58 |
59 | def test_mixture_number(self):
60 | # check that the requested number of samples was generated
61 | # for each mixture size
62 | self.assertEqual(self.two_mix_seeds_ss.n_samples, 20)
63 | self.assertEqual(self.three_mix_seeds_ss.n_samples, 20)
64 |
65 | def test_mixture_pdf(self):
66 | # check that each mixture sums to one
67 | for sample in range(self.two_mix_seeds_ss.n_samples):
68 | self.assertAlmostEqual(self.two_mix_seeds_ss.spectra.values[sample, :].sum(), 1.0)
69 |
70 | for sample in range(self.three_mix_seeds_ss.n_samples):
71 | self.assertAlmostEqual(self.three_mix_seeds_ss.spectra.values[sample, :].sum(), 1.0)
72 |
73 | def test_spectrum_construction_3seeds_2mix(self):
74 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg()
75 | mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2, rng=self.rng).generate(100)
76 | spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss)
77 | self.assertTrue(np.isclose(spectral_distances, 0.0).all())
78 |
79 | def test_spectrum_construction_3seeds_3mix(self):
80 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg()
81 | mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=self.rng).generate(100)
82 | spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss)
83 | self.assertTrue(np.isclose(spectral_distances, 0.0).all())
84 |
85 | def test_spectrum_construction_2seeds_2mix(self):
86 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg(
87 | bg_seed_names=SampleSet.DEFAULT_BG_SEED_NAMES[1:3]
88 | )
89 | mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2, rng=self.rng).generate(100)
90 | spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss)
91 | self.assertTrue(np.isclose(spectral_distances, 0.0).all())
92 |
93 | def test_spectrum_construction_2seeds_2mix_error(self):
94 | _, bg_seeds_ss = get_dummy_seeds(n_channels=16).split_fg_and_bg(
95 | bg_seed_names=SampleSet.DEFAULT_BG_SEED_NAMES[1:3]
96 | )
97 | mixer = SeedMixer(bg_seeds_ss, mixture_size=3)
98 | self.assertRaises(ValueError, mixer.generate, 100)
99 |
100 |
101 | def _get_spectral_distances(seeds_ss, mixed_ss):
102 | recon_spectra = np.dot(
103 | seeds_ss.spectra.values.T,
104 | mixed_ss.sources.values.T
105 | ).T
106 | spectral_distances = jensenshannon(
107 | recon_spectra,
108 | mixed_ss.spectra.values,
109 | axis=1
110 | )
111 | spectral_distances = np.nan_to_num(spectral_distances, nan=0.0)
112 | return spectral_distances
113 |
114 |
115 | if __name__ == "__main__":
116 | unittest.main()
117 |
--------------------------------------------------------------------------------
/tests/visualize_tests.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
2 | # Under the terms of Contract DE-NA0003525 with NTESS,
3 | # the U.S. Government retains certain rights in this software.
4 | """This module tests the visualize module."""
5 | import unittest
6 |
7 | import numpy as np
8 |
9 | from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds
10 | from riid.metrics import precision_recall_curve
11 | from riid.models import MLPClassifier
12 | from riid.visualize import (plot_correlation_between_all_labels,
13 | plot_count_rate_history,
14 | plot_label_and_prediction_distributions,
15 | plot_label_distribution, plot_learning_curve,
16 | plot_live_time_vs_snr, plot_precision_recall,
17 | plot_prediction_distribution,
18 | plot_score_distribution, plot_snr_vs_score,
19 | plot_spectra, plot_ss_comparison)
20 |
21 |
22 | class TestVisualize(unittest.TestCase):
23 | """Testing plot functions in the visualize module."""
24 | @classmethod
25 | def setUpClass(self):
26 | """Test setup."""
27 | self.fg_seeds_ss, self.bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
28 | self.mixed_bg_seed_ss = SeedMixer(self.bg_seeds_ss, mixture_size=3).generate(10)
29 |
30 | self.static_synth = StaticSynthesizer(
31 | samples_per_seed=100,
32 | snr_function="log10",
33 | return_fg=False,
34 | return_gross=True
35 | )
36 | _, self.train_ss = self.static_synth.generate(
37 | self.fg_seeds_ss,
38 | self.mixed_bg_seed_ss,
39 | verbose=False
40 | )
41 | self.train_ss.normalize()
42 |
43 | model = MLPClassifier()
44 | self.history = model.fit(self.train_ss, epochs=10, patience=5).history
45 | model.predict(self.train_ss)
46 |
47 | # Generate some test data
48 | self.static_synth.samples_per_seed = 50
49 | _, self.test_ss = self.static_synth.generate(
50 | self.fg_seeds_ss,
51 | self.mixed_bg_seed_ss,
52 | verbose=False
53 | )
54 | self.test_ss.normalize()
55 | model.predict(self.test_ss)
56 |
57 | def test_plot_live_time_vs_snr(self):
58 | """Plots SNR against live time for all samples in a SampleSet.
59 | Prediction and label information is used to distinguish between correct and incorrect
60 | classifications using color (green for correct, red for incorrect).
61 | """
62 | plot_live_time_vs_snr(self.test_ss, show=False)
63 | plot_live_time_vs_snr(self.train_ss, self.test_ss, show=False)
64 |
65 | def test_plot_snr_vs_score(self):
66 | """Plots SNR against prediction score for all samples in a SampleSet.
67 | Prediction and label information is used to distinguish between correct and incorrect
68 | classifications using color (green for correct, red for incorrect).
69 | """
70 | plot_snr_vs_score(self.train_ss, show=False)
71 | plot_snr_vs_score(self.train_ss, self.test_ss, show=False)
72 |
73 | def test_plot_spectra(self):
74 | """Plots the spectra contained with a SampleSet."""
75 | plot_spectra(self.fg_seeds_ss, ylim=(None, None), in_energy=False, show=False)
76 | plot_spectra(self.fg_seeds_ss, ylim=(None, None), in_energy=True, show=False)
77 |
78 | def test_plot_learning_curve(self):
79 | """Plots training and validation loss curves."""
80 | plot_learning_curve(self.history["loss"],
81 | self.history["val_loss"],
82 | show=False)
83 | plot_learning_curve(self.history["loss"],
84 | self.history["val_loss"],
85 | smooth=True,
86 | show=False)
87 |
88 | def test_plot_count_rate_history(self):
89 | """Plots a count rate history."""
90 | counts = np.random.normal(size=1000)
91 | histogram, _ = np.histogram(counts, bins=100, range=(0, 100))
92 | plot_count_rate_history(histogram, 1, 80, 20, show=False)
93 |
94 | def test_plot_label_and_prediction_distributions(self):
95 | """Plots distributions of data labels, predictions, and prediction scores."""
96 | plot_score_distribution(self.test_ss, show=False)
97 | plot_label_distribution(self.test_ss, show=False)
98 | plot_prediction_distribution(self.test_ss, show=False)
99 | plot_label_and_prediction_distributions(self.test_ss, show=False)
100 |
101 | def test_plot_correlation_between_all_labels(self):
102 | """Plots a correlation matrix of each label against each other label."""
103 | plot_correlation_between_all_labels(self.bg_seeds_ss, show=False)
104 | plot_correlation_between_all_labels(self.bg_seeds_ss, mean=True, show=False)
105 |
106 | def test_plot_precision_recall(self):
107 | """Plots the multi-class or multi-label Precision-Recall curve and marks the optimal
108 | F1 score for each class.
109 | """
110 | precision, recall, _ = precision_recall_curve(self.test_ss)
111 | plot_precision_recall(precision=precision, recall=recall, show=False)
112 |
113 | def test_plot_ss_comparison(self):
114 | """Creates a plot for output from SampleSet.compare_to()."""
115 | SYNTHETIC_DATA_CONFIG = {
116 | "samples_per_seed": 100,
117 | "bg_cps": 100,
118 | "snr_function": "uniform",
119 | "snr_function_args": (1, 100),
120 | "live_time_function": "uniform",
121 | "live_time_function_args": (0.25, 10),
122 | "apply_poisson_noise": True,
123 | "return_fg": False,
124 | "return_gross": True,
125 | }
126 |
127 | _, gross_ss1 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\
128 | .generate(self.fg_seeds_ss, self.mixed_bg_seed_ss, verbose=False)
129 | _, gross_ss2 = StaticSynthesizer(**SYNTHETIC_DATA_CONFIG)\
130 | .generate(self.fg_seeds_ss, self.mixed_bg_seed_ss, verbose=False)
131 |
132 | # Compare two different gross sample sets - Live time
133 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2,
134 | density=False)
135 | plot_ss_comparison(ss1_stats,
136 | ss2_stats,
137 | col_comparisons,
138 | "live_time",
139 | show=False)
140 |
141 | # Compare two different gross sample sets - Total Counts
142 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss2,
143 | density=True)
144 | plot_ss_comparison(ss1_stats,
145 | ss2_stats,
146 | col_comparisons,
147 | "total_counts",
148 | show=False)
149 |
150 | # Compare the same sampleset to itself - Live time
151 | ss1_stats, ss2_stats, col_comparisons = gross_ss1.compare_to(gross_ss1,
152 | density=False)
153 | plot_ss_comparison(ss1_stats,
154 | ss2_stats,
155 | col_comparisons,
156 | "live_time",
157 | show=False)
158 |
159 | # Compare the same sampleset to itself - Total Counts
160 | ss1_stats, ss2_stats, col_comparisons = gross_ss2.compare_to(gross_ss2,
161 | density=True)
162 | plot_ss_comparison(ss1_stats,
163 | ss2_stats,
164 | col_comparisons,
165 | "total_counts",
166 | show=False)
167 |
168 |
169 | if __name__ == "__main__":
170 | unittest.main()
171 |
--------------------------------------------------------------------------------