├── conftest.py
├── test-requirements.txt
├── doc
│   ├── _static
│   │   ├── git.png
│   │   ├── logo.png
│   │   └── RTD-advanced-conf.png
│   ├── _templates
│   │   ├── function.rst
│   │   └── class.rst
│   ├── api.rst
│   ├── theory.rst
│   ├── index.rst
│   ├── sphinxext
│   │   ├── math_dollar.py
│   │   ├── github.py
│   │   ├── numpydoc.py
│   │   └── docscrape_sphinx.py
│   ├── Makefile
│   └── conf.py
├── requirements.txt
├── annsa
│   ├── tests
│   │   ├── data_folder
│   │   │   ├── rocky_flats.npy
│   │   │   ├── gadras_template.npy
│   │   │   ├── rocky_flats_spectra.spe
│   │   │   └── gadras_template.spe
│   │   ├── test_readspectra.py
│   │   ├── test_templatesampling.py
│   │   ├── test_uenrich_data_gen.py
│   │   ├── test_training_dae.py
│   │   ├── test_training_cae.py
│   │   ├── test_baseclass.py
│   │   ├── test_training_dnn.py
│   │   └── test_training_cnn1d.py
│   ├── __init__.py
│   ├── annsa.py
│   ├── due.py
│   ├── version.py
│   ├── generate_uranium_templates.py
│   ├── load_dataset.py
│   └── load_pretrained_network.py
├── examples
│   ├── source-interdiction
│   │   ├── results-notebooks
│   │   │   ├── README.md
│   │   │   └── aux_functions.py
│   │   ├── hyperparameter-search
│   │   │   ├── README.md
│   │   │   ├── DNN-Hyperparameter-Search-Easy.ipynb
│   │   │   ├── DNN-Hyperparameter-Search-Full.ipynb
│   │   │   ├── CNN-Hyperparameter-Search-Easy.ipynb
│   │   │   ├── CNN-Hyperparameter-Search-Full.ipynb
│   │   │   └── hyperparameter_models.py
│   │   ├── training-notebooks
│   │   │   ├── README.md
│   │   │   ├── cae-manual-easy.ipynb
│   │   │   └── cae-manual-full.ipynb
│   │   └── dataset-generation
│   │       ├── README.md
│   │       ├── Dataset-Generation-Simple.ipynb
│   │       └── Dataset-Generation-Complete.ipynb
│   ├── uranium-enrichment
│   │   ├── dataset-generation
│   │   │   ├── README.md
│   │   │   └── Dataset_Generation.ipynb
│   │   └── training-notebooks
│   │       └── README.md
│   └── README.txt
├── setup.cfg
├── citation.md
├── Makefile
├── setup.py
├── README.md
├── LICENSE
├── .circleci
│   └── config.yml
├── .gitignore
└── contributing.md
/conftest.py:
--------------------------------------------------------------------------------
1 | #
2 |
--------------------------------------------------------------------------------
/test-requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-flake8
3 |
--------------------------------------------------------------------------------
/doc/_static/git.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arfc/annsa/HEAD/doc/_static/git.png
--------------------------------------------------------------------------------
/doc/_static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arfc/annsa/HEAD/doc/_static/logo.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | matplotlib
4 | pandas
5 | tensorflow
6 | scikit-learn
7 |
--------------------------------------------------------------------------------
/doc/_static/RTD-advanced-conf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arfc/annsa/HEAD/doc/_static/RTD-advanced-conf.png
--------------------------------------------------------------------------------
/annsa/tests/data_folder/rocky_flats.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arfc/annsa/HEAD/annsa/tests/data_folder/rocky_flats.npy
--------------------------------------------------------------------------------
/annsa/tests/data_folder/gadras_template.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arfc/annsa/HEAD/annsa/tests/data_folder/gadras_template.npy
--------------------------------------------------------------------------------
/annsa/tests/data_folder/rocky_flats_spectra.spe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/arfc/annsa/HEAD/annsa/tests/data_folder/rocky_flats_spectra.spe
--------------------------------------------------------------------------------
/examples/source-interdiction/results-notebooks/README.md:
--------------------------------------------------------------------------------
1 | ## Results Notebooks
2 | These notebooks were used to generate the results for the dissertation.
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | # pytest-flake8 ignores:
2 | # E501 : line too long
3 | # W504 : break after binary operator
4 | [tool:pytest]
5 | flake8-ignore = E501 W504
6 |
--------------------------------------------------------------------------------
/annsa/__init__.py:
--------------------------------------------------------------------------------
1 | # from __future__ import absolute_import, division, print_function
2 | # from .version import __version__ # noqa
3 | from .annsa import * # noqa
4 |
--------------------------------------------------------------------------------
/examples/uranium-enrichment/dataset-generation/README.md:
--------------------------------------------------------------------------------
1 | # Dataset Generation
2 | This notebook generates the training dataset for the uranium enrichment regression problem.
--------------------------------------------------------------------------------
/examples/uranium-enrichment/training-notebooks/README.md:
--------------------------------------------------------------------------------
1 | # Uranium Enrichment Training
2 | These notebooks were used to train models to perform uranium enrichment regression.
--------------------------------------------------------------------------------
/examples/source-interdiction/hyperparameter-search/README.md:
--------------------------------------------------------------------------------
1 | ## Hyperparameter Search
2 | These notebooks were used to perform the hyperparameter search. For the dissertation, multiple hyperparameter search notebooks were run in parallel for each model and dataset combination.
--------------------------------------------------------------------------------
/doc/_templates/function.rst:
--------------------------------------------------------------------------------
1 | {{ fullname }}
2 | {{ underline }}
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autofunction:: {{ objname }}
7 |
8 | .. include:: {{module}}.{{objname}}.examples
9 |
10 | .. raw:: html
11 |
12 |
13 |
--------------------------------------------------------------------------------
/examples/source-interdiction/training-notebooks/README.md:
--------------------------------------------------------------------------------
1 | ### Training
2 | These notebooks were used when training the source interdiction ANNs. Notebooks include those to pretrain the autoencoders and those to train each model with an increasing number of examples to create the learning curves.
--------------------------------------------------------------------------------
/examples/README.txt:
--------------------------------------------------------------------------------
1 | ## Dissertation Notebooks
2 |
3 | These examples outline the two projects performed for the dissertation. The first project is an isotope classification problem for source interdiction using neural networks. The second project demonstrates performing uranium enrichment regression using the same machine learning models.
4 |
--------------------------------------------------------------------------------
/doc/api.rst:
--------------------------------------------------------------------------------
1 | API
2 | ===
3 |
4 |
5 | Classes
6 | -------
7 |
8 | .. currentmodule:: annsa
9 |
10 | .. autosummary::
11 | :template: class.rst
12 | :toctree: gen_api
13 |
14 | DNN
15 |
16 |
17 | Functions
18 | ---------
19 |
20 | .. autosummary::
21 | :template: function.rst
22 | :toctree: gen_api
23 |
24 | f1_error
25 |
--------------------------------------------------------------------------------
/citation.md:
--------------------------------------------------------------------------------
1 | Mark Kamuda, annsa, (2018), GitHub repository, https://github.com/kamuda1/annsa
2 |
3 |
4 |
5 | ```latex
6 | @misc{kamuda2018,
7 | author = {Kamuda, Mark},
8 | title = {annsa},
9 | year = {2018},
10 | publisher = {GitHub},
11 | journal = {GitHub repository},
12 | howpublished = {\url{https://github.com/kamuda1/annsa}}
13 | }
14 | ```
15 |
--------------------------------------------------------------------------------
/doc/_templates/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname }}
2 | {{ underline }}
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 | :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__,__hash__
8 |
9 | .. include:: {{module}}.{{objname}}.examples
10 |
11 | .. raw:: html
12 |
13 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | flake8:
2 | @if command -v flake8 > /dev/null; then \
3 | echo "Running flake8"; \
4 | 		flake8 --ignore N802,N806 `find . -name \*.py | grep -v setup.py | grep -v /doc/`; \
5 | else \
6 | echo "flake8 not found, please install it!"; \
7 | exit 1; \
8 | fi;
9 | @echo "flake8 passed"
10 |
11 | test:
12 | 	py.test --pyargs annsa --cov-report term-missing --cov=annsa
13 |
--------------------------------------------------------------------------------
/doc/theory.rst:
--------------------------------------------------------------------------------
1 |
2 | This is a theory section
3 | ========================
4 |
5 | I might want to describe some equations:
6 |
7 | .. math::
8 |
9 | \int_0^\infty e^{-x^2} dx=\frac{\sqrt{\pi}}{2}
10 |
11 |
12 | And refer to a paper [author2015]_.
13 |
14 |
15 | .. [author2015] first a., second a., cheese b. (2015). The title of their
16 | paper. Journal of papers, *15*: 1023-1049.
17 |
18 |
--------------------------------------------------------------------------------
/examples/source-interdiction/dataset-generation/README.md:
--------------------------------------------------------------------------------
1 | ### Dataset Generation
2 | Using these notebooks, datasets can be constructed with various simulation parameters. For the dissertation, both a wide range of parameters and a simplified range of parameters were used.
3 | 
4 | Note that, due to the inclusion of shielding in the complete dataset, some source templates had zero spectral counts and were dropped when loading the templates.
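5 | 
6 | As an illustration of the note above, a minimal sketch of such a filter (assuming the templates are the rows of a 2-D NumPy array; the variable names are hypothetical):
7 | 
8 | ```python
9 | import numpy as np
10 | 
11 | # hypothetical array: one spectrum template per row
12 | templates = np.random.poisson(lam=1.0, size=(100, 1024))
13 | 
14 | # keep only templates that contain at least one count
15 | templates = templates[templates.sum(axis=1) > 0]
16 | ```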
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | PACKAGES = find_packages()
3 |
4 | with open("README.md", "r") as fh:
5 | long_description = fh.read()
6 |
7 | opts = dict(name='annsa',
8 | maintainer='Mark Kamuda',
9 | maintainer_email='kamuda1@illinois.edu',
10 | description='Neural networks applied to gamma-ray spectroscopy',
11 | long_description=long_description,
12 | url='https://github.com/arfc/annsa',
13 | license='BSD 3',
14 | author='Mark Kamuda',
15 | author_email='kamuda1@illinois.edu',
16 | version='0.1dev',
17 |             packages=PACKAGES,
18 | install_requires=['tensorflow', 'numpy', 'scipy'])
19 |
20 |
21 | if __name__ == '__main__':
22 | setup(**opts)
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## annsa
2 |
3 | Artificial neural networks for spectroscopic analysis (annsa) is a Python package for quickly prototyping gamma-ray spectra datasets and for running machine learning and data science experiments on those datasets.
4 |
5 | ### Install
6 |
7 | To install, either `git clone` the annsa repository or download and unzip its zip file. Then, from the command line, run:
8 |
9 | ```
10 | cd annsa
11 | python setup.py install
12 | ```
13 |
14 | #### Standard Naming Convention for Spectrum Files
15 | 
16 | Ex. `99MTc_500.0_50.0_lead_0.0_2.0.spe`
17 | 
18 | This filename describes a spectrum taken with the following parameters:
19 | * Technetium-99, metastable
20 | * 500 cm away from the detector
21 | * 50 cm above the ground
22 | * Lead shielding
23 | * 'Areal' density of the shielding (0.0)
24 | * Full width at half max of 2.0
25 | * File type: .spe
26 |
27 |
28 |
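29 | #### Quick usage sketch
30 | 
31 | A minimal sketch of reading a spectrum and unpacking the naming convention above. It uses one of the test spectra shipped with the repository; `read_spectrum` lives in `annsa/annsa.py`.
32 | 
33 | ```python
34 | import os
35 | from annsa.annsa import read_spectrum
36 | 
37 | # read a test spectrum included in the repository (run from the repo root)
38 | spectrum = read_spectrum('annsa/tests/data_folder/gadras_template.spe', float)
39 | 
40 | # unpack the fields encoded by the naming convention described above
41 | name = os.path.splitext('99MTc_500.0_50.0_lead_0.0_2.0.spe')[0]
42 | isotope, distance, height, shield, areal_density, fwhm = name.split('_')
43 | ```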
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | .. annsa documentation master file, created by sphinx-quickstart on Tue Apr 14 10:29:06 2015. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive.
2 |
3 | Welcome to annsa's documentation!
4 | ====================================
5 |
6 | Artificial neural networks for spectroscopic analysis (`annsa`) is a Python
7 | package for quickly prototyping gamma-ray spectra datasets and for running
8 | machine learning and data science experiments on those datasets.
9 |
10 | To see how to use it, please refer to the `README file
11 | <https://github.com/arfc/annsa>`_ in the GitHub repository.
12 |
13 | This is an example of documentation of the software, using sphinx_.
14 |
15 | .. _sphinx: http://sphinx-doc.org/
16 |
17 |
18 | Contents:
19 |
20 | .. toctree::
21 | :maxdepth: 2
22 |
23 | api
24 | theory
25 |
--------------------------------------------------------------------------------
/annsa/tests/test_readspectra.py:
--------------------------------------------------------------------------------
1 | import os
2 | from numpy import load
3 | from annsa.annsa import read_spectrum
4 |
5 |
6 | def test_read_spectrum_rocky_flats():
7 | truedata_filename = os.path.join(os.path.dirname(__file__), './data_folder/rocky_flats.npy')
8 | true_spectrum = load(truedata_filename)
9 |
10 | annsa_data_filename = os.path.join(os.path.dirname(__file__), './data_folder/rocky_flats_spectra.spe')
11 | annsa_spectrum = read_spectrum(annsa_data_filename)
12 | assert (annsa_spectrum == true_spectrum).all()
13 |
14 |
15 | def test_read_spectrum_gadras_template():
16 | truedata_filename = os.path.join(os.path.dirname(__file__), './data_folder/gadras_template.npy')
17 | true_spectrum = load(truedata_filename)
18 |
19 | annsa_data_filename = os.path.join(os.path.dirname(__file__), './data_folder/gadras_template.spe')
20 | annsa_spectrum = read_spectrum(annsa_data_filename, float)
21 | assert (annsa_spectrum == true_spectrum).all()
22 |
--------------------------------------------------------------------------------
/annsa/annsa.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 |
4 | def read_spectrum(filename,
5 | return_type=int):
6 | """
7 |     Reads a .spe file and returns its spectrum as a list of counts.
8 |
9 | Parameters:
10 | -----------
11 | filename : string
12 | Filename with .spe extension
13 | return_type : int or float
14 | Type of number to return
15 | Returns:
16 | --------
17 | spectrum : vector
18 | The source spectrum
19 |
20 | """
21 |
22 | try:
23 | with open(filename, 'r') as myFile:
24 | filecontent = myFile.readlines()
25 | for index, line in enumerate(filecontent):
26 | if '$DATA:' in line:
27 | break
28 | spec_len_index = index + 1
29 | spec_index = index + 2
30 | spec_len = filecontent[spec_len_index]
31 | except UnicodeDecodeError:
32 | with open(filename, 'rb') as myFile:
33 | filecontent = myFile.readlines()
34 | for index, line in enumerate(filecontent):
35 | if b'$DATA:' in line:
36 | break
37 | spec_len_index = index + 1
38 | spec_index = index + 2
39 | spec_len = filecontent[spec_len_index].decode()[:-2]
40 | else:
41 | print('spe in unknown encoding')
42 |
43 | spec_len = int(spec_len.split(' ')[1]) + 1
44 | spectrum = [return_type(x) for x in filecontent[spec_index:
45 | spec_index + spec_len]]
46 | return spectrum
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2018, Mark Kamuda
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/annsa/tests/test_templatesampling.py:
--------------------------------------------------------------------------------
1 | from annsa.template_sampling import (rebin_spectrum,
2 | apply_LLD,
3 | construct_spectrum,)
4 | from numpy.testing import assert_almost_equal
5 | import numpy as np
6 | import pytest
7 |
8 |
9 | @pytest.fixture(scope="session")
10 | def spectrum():
11 |     # define lambda = 1000
12 |     # define size = 1x1024 (the number of channels)
13 | dim = 1024
14 | lam = 1000
15 | spectrum = np.random.poisson(lam=lam, size=dim)
16 | return spectrum
17 |
18 |
19 | # rebinning unit test
20 | def test_rebinning_size(spectrum):
21 | output_len = 512
22 | spectrum_rebinned = rebin_spectrum(spectrum,
23 | output_len=output_len)
24 | assert(len(spectrum_rebinned) == output_len)
25 |
26 |
27 | # LLD test
28 | def test_lld(spectrum):
29 | lld = 10
30 | spectrum_lld = apply_LLD(spectrum, LLD=lld)
31 | assert(np.sum(spectrum_lld[0:lld]) == 0)
32 |
33 |
34 | # construct spectrum test
35 | def test_construct_spectrum_test_rescale_case1(spectrum):
36 | """case 1: Check if rescale returns correctly scaled template"""
37 | spectrum_counts = 10
38 | spectrum_rescaled = construct_spectrum(
39 | spectrum,
40 | spectrum_counts=spectrum_counts,)
41 | assert_almost_equal(np.sum(spectrum_rescaled), 10.0)
42 |
43 |
44 | def test_construct_spectrum_test_rescale_case2(spectrum):
45 | """case 2: Check if rescale returns values above zero"""
46 | spectrum_counts = 10
47 | spectrum_rescaled = construct_spectrum(
48 | spectrum,
49 | spectrum_counts=spectrum_counts,)
50 | assert(np.sum(spectrum_rescaled < 0) == 0)
51 |
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | # Python CircleCI 2.0 configuration file
2 | #
3 | # Check https://circleci.com/docs/2.0/language-python/ for more details
4 | #
5 | version: 2
6 | jobs:
7 | build:
8 | docker:
9 | # specify the version you desire here
10 | # use `-browsers` prefix for selenium tests, e.g. `3.6.1-browsers`
11 | - image: circleci/python:3.6
12 |
13 | # Specify service dependencies here if necessary
14 | # CircleCI maintains a library of pre-built images
15 | # documented at https://circleci.com/docs/2.0/circleci-images/
16 | # - image: circleci/postgres:9.4
17 |
18 | working_directory: ~/annsa
19 |
20 | steps:
21 | - checkout
22 |
23 | # Download and cache dependencies
24 | - restore_cache:
25 | keys:
26 | - v1-dependencies-{{ checksum "requirements.txt" }}
27 | # fallback to using the latest cache if no exact match is found
28 | - v1-dependencies-
29 |
30 | - run:
31 | name: install dependencies
32 | command: |
33 | python3 -m venv venv
34 | . venv/bin/activate
35 | pip install -r requirements.txt
36 | pip install -r test-requirements.txt
37 |
38 | - save_cache:
39 | paths:
40 | - ./venv
41 | key: v1-dependencies-{{ checksum "requirements.txt" }}
42 |
43 | # run tests!
44 | # https://pytest.org
45 | - run:
46 | name: run tests
47 | command: |
48 | . venv/bin/activate
49 | py.test
50 | py.test --flake8 annsa
51 |
52 | - store_artifacts:
53 | path: test-reports
54 | destination: test-reports
55 |
56 |
--------------------------------------------------------------------------------
/doc/sphinxext/math_dollar.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | def dollars_to_math(source):
4 | r"""
5 | Replace dollar signs with backticks.
6 |
7 | More precisely, do a regular expression search. Replace a plain
8 | dollar sign ($) by a backtick (`). Replace an escaped dollar sign
9 | (\$) by a dollar sign ($). Don't change a dollar sign preceded or
10 | followed by a backtick (`$ or $`), because of strings like
11 | "``$HOME``". Don't make any changes on lines starting with
12 | spaces, because those are indented and hence part of a block of
13 | code or examples.
14 |
15 | This also doesn't replaces dollar signs enclosed in curly braces,
16 | to avoid nested math environments, such as ::
17 |
18 | $f(n) = 0 \text{ if $n$ is prime}$
19 |
20 | Thus the above line would get changed to
21 |
22 | `f(n) = 0 \text{ if $n$ is prime}`
23 | """
24 | s = "\n".join(source)
25 | if s.find("$") == -1:
26 | return
27 | # This searches for "$blah$" inside a pair of curly braces --
28 | # don't change these, since they're probably coming from a nested
29 | # math environment. So for each match, we replace it with a temporary
30 | # string, and later on we substitute the original back.
31 | global _data
32 | _data = {}
33 | def repl(matchobj):
34 | global _data
35 | s = matchobj.group(0)
36 | t = "___XXX_REPL_%d___" % len(_data)
37 | _data[t] = s
38 | return t
39 | s = re.sub(r"({[^{}$]*\$[^{}$]*\$[^{}]*})", repl, s)
40 | # matches $...$
41 |     dollars = re.compile(r"(?
--------------------------------------------------------------------------------
/annsa/tests/test_training_dnn.py:
--------------------------------------------------------------------------------
91 |
92 |
93 | # dropout tests
94 | @pytest.mark.parametrize('dnn',
95 | ((0.999, 1024),),
96 | indirect=True,)
97 | def test_dropout_0(dnn):
98 | '''case 0: tests that dropout is applied when training.'''
99 | o_training_false = dnn.forward_pass(np.ones([1, 1024]),
100 | training=False).numpy()
101 | o_training_true = dnn.forward_pass(np.ones([1, 1024]),
102 | training=True).numpy()
103 | assert(np.array_equal(o_training_false, o_training_true) is False)
104 |
105 |
106 | @pytest.mark.parametrize('dnn',
107 | ((0.999, 1024),),
108 | indirect=True,)
109 | def test_dropout_1(dnn):
110 | '''case 1: tests that dropout is not applied in inference, when training
111 | is False.'''
112 | o_training_false_1 = dnn.forward_pass(np.ones([1, 1024]),
113 | training=False).numpy()
114 | o_training_false_2 = dnn.forward_pass(np.ones([1, 1024]),
115 | training=False).numpy()
116 | assert(np.array_equal(o_training_false_1, o_training_false_2))
117 |
118 |
119 | # training tests
120 | @pytest.mark.parametrize('cost', ['mse', 'cross_entropy'])
121 | @pytest.mark.parametrize('dnn',
122 | ((0.5, 64),),
123 | indirect=True,)
124 | def test_training_0(dnn, toy_dataset, cost):
125 | '''case 0: test if training on toy dataset reduces errors using
126 | both error functions'''
127 | (data, targets_binarized) = toy_dataset
128 | cost_function = getattr(dnn, cost)
129 | objective_cost, earlystop_cost = dnn.fit_batch(
130 | (data, targets_binarized),
131 | (data, targets_binarized),
132 | optimizer=tf.train.AdamOptimizer(1e-2),
133 | num_epochs=2,
134 | obj_cost=cost_function,
135 | data_augmentation=dnn.default_data_augmentation,)
136 | epoch0_error = objective_cost['test'][0].numpy()
137 | epoch1_error = objective_cost['test'][1].numpy()
138 | assert(epoch1_error < epoch0_error)
139 |
--------------------------------------------------------------------------------
/doc/sphinxext/github.py:
--------------------------------------------------------------------------------
1 | """Define text roles for GitHub
2 |
3 | * ghissue - Issue
4 | * ghpull - Pull Request
5 | * ghuser - User
6 |
7 | Adapted from bitbucket example here:
8 | https://bitbucket.org/birkenfeld/sphinx-contrib/src/tip/bitbucket/sphinxcontrib/bitbucket.py
9 |
10 | Authors
11 | -------
12 |
13 | * Doug Hellmann
14 | * Min RK
15 | """
16 | #
17 | # Original Copyright (c) 2010 Doug Hellmann. All rights reserved.
18 | #
19 |
20 | from docutils import nodes, utils
21 | from docutils.parsers.rst.roles import set_classes
22 |
23 | def make_link_node(rawtext, app, type, slug, options):
24 | """Create a link to a github resource.
25 |
26 | :param rawtext: Text being replaced with link node.
27 | :param app: Sphinx application context
28 | :param type: Link type (issues, changeset, etc.)
29 | :param slug: ID of the thing to link to
30 | :param options: Options dictionary passed to role func.
31 | """
32 |
33 | try:
34 | base = app.config.github_project_url
35 | if not base:
36 | raise AttributeError
37 | if not base.endswith('/'):
38 | base += '/'
39 | except AttributeError as err:
40 | raise ValueError('github_project_url configuration value is not set (%s)' % str(err))
41 |
42 | ref = base + type + '/' + slug + '/'
43 | set_classes(options)
44 | prefix = "#"
45 | if type == 'pull':
46 | prefix = "PR " + prefix
47 | node = nodes.reference(rawtext, prefix + utils.unescape(slug), refuri=ref,
48 | **options)
49 | return node
50 |
51 | def ghissue_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
52 | """Link to a GitHub issue.
53 |
54 | Returns 2 part tuple containing list of nodes to insert into the
55 | document and a list of system messages. Both are allowed to be
56 | empty.
57 |
58 | :param name: The role name used in the document.
59 | :param rawtext: The entire markup snippet, with role.
60 | :param text: The text marked with the role.
61 | :param lineno: The line number where rawtext appears in the input.
62 | :param inliner: The inliner instance that called us.
63 | :param options: Directive options for customization.
64 | :param content: The directive content for customization.
65 | """
66 |
67 | try:
68 | issue_num = int(text)
69 | if issue_num <= 0:
70 | raise ValueError
71 | except ValueError:
72 | msg = inliner.reporter.error(
73 | 'GitHub issue number must be a number greater than or equal to 1; '
74 | '"%s" is invalid.' % text, line=lineno)
75 | prb = inliner.problematic(rawtext, rawtext, msg)
76 | return [prb], [msg]
77 | app = inliner.document.settings.env.app
78 | #app.info('issue %r' % text)
79 | if 'pull' in name.lower():
80 | category = 'pull'
81 | elif 'issue' in name.lower():
82 | category = 'issues'
83 | else:
84 | msg = inliner.reporter.error(
85 | 'GitHub roles include "ghpull" and "ghissue", '
86 | '"%s" is invalid.' % name, line=lineno)
87 | prb = inliner.problematic(rawtext, rawtext, msg)
88 | return [prb], [msg]
89 | node = make_link_node(rawtext, app, category, str(issue_num), options)
90 | return [node], []
91 |
92 | def ghuser_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
93 | """Link to a GitHub user.
94 |
95 | Returns 2 part tuple containing list of nodes to insert into the
96 | document and a list of system messages. Both are allowed to be
97 | empty.
98 |
99 | :param name: The role name used in the document.
100 | :param rawtext: The entire markup snippet, with role.
101 | :param text: The text marked with the role.
102 | :param lineno: The line number where rawtext appears in the input.
103 | :param inliner: The inliner instance that called us.
104 | :param options: Directive options for customization.
105 | :param content: The directive content for customization.
106 | """
107 | app = inliner.document.settings.env.app
108 | #app.info('user link %r' % text)
109 | ref = 'https://www.github.com/' + text
110 | node = nodes.reference(rawtext, text, refuri=ref, **options)
111 | return [node], []
112 |
113 | def ghcommit_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
114 | """Link to a GitHub commit.
115 |
116 | Returns 2 part tuple containing list of nodes to insert into the
117 | document and a list of system messages. Both are allowed to be
118 | empty.
119 |
120 | :param name: The role name used in the document.
121 | :param rawtext: The entire markup snippet, with role.
122 | :param text: The text marked with the role.
123 | :param lineno: The line number where rawtext appears in the input.
124 | :param inliner: The inliner instance that called us.
125 | :param options: Directive options for customization.
126 | :param content: The directive content for customization.
127 | """
128 | app = inliner.document.settings.env.app
129 | #app.info('user link %r' % text)
130 | try:
131 | base = app.config.github_project_url
132 | if not base:
133 | raise AttributeError
134 | if not base.endswith('/'):
135 | base += '/'
136 | except AttributeError as err:
137 | raise ValueError('github_project_url configuration value is not set (%s)' % str(err))
138 |
139 | ref = base + text
140 | node = nodes.reference(rawtext, text[:6], refuri=ref, **options)
141 | return [node], []
142 |
143 |
144 | def setup(app):
145 | """Install the plugin.
146 |
147 | :param app: Sphinx application context.
148 | """
149 | app.info('Initializing GitHub plugin')
150 | app.add_role('ghissue', ghissue_role)
151 | app.add_role('ghpull', ghissue_role)
152 | app.add_role('ghuser', ghuser_role)
153 | app.add_role('ghcommit', ghcommit_role)
154 | app.add_config_value('github_project_url', None, 'env')
155 | return
156 |
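157 | # Example usage in reStructuredText, assuming ``github_project_url`` is set in
158 | # the Sphinx conf.py (the issue, PR, user, and commit values are hypothetical):
159 | #
160 | #   See :ghissue:`12`, :ghpull:`15`, :ghuser:`arfc`, and :ghcommit:`0123abc`.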
--------------------------------------------------------------------------------
/examples/source-interdiction/hyperparameter-search/DNN-Hyperparameter-Search-Easy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
11 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
12 | "model_id = 'DNN-hpsearch-easy'\n",
13 | "\n",
14 | "from sklearn.preprocessing import LabelBinarizer\n",
15 | "from sklearn.model_selection import StratifiedKFold\n",
16 | "import tensorflow as tf\n",
17 | "import pickle\n",
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "\n",
21 | "from annsa.template_sampling import *\n",
22 | "from annsa.load_pretrained_network import save_features"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from hyperparameter_models import make_dense_model as make_model"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "import tensorflow.contrib.eager as tfe"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "tf.enable_eager_execution()"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "#### Import model, training function "
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from annsa.model_classes import (dnn_model_features,\n",
66 | " DNN,\n",
67 | " save_model,\n",
68 | " train_earlystop)"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "## Load testing dataset "
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "dataset = np.load('../dataset_generation/hyperparametersearch_dataset_100_easy.npy')"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "all_spectra = np.float64(np.add(dataset.item()['sources'], dataset.item()['backgrounds']))\n",
94 | "all_keys = dataset.item()['keys']\n",
95 | "\n",
96 | "mlb=LabelBinarizer()\n",
97 | "\n",
98 | "all_keys_binarized = mlb.fit_transform(all_keys)"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "# Train network"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "### Define hyperparameters"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "number_hyperparameters_to_search = 256\n",
122 | "earlystop_errors_test = []"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "### Search hyperparameters"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "skf = StratifiedKFold(n_splits=5, random_state=5)\n",
139 | "testing_errors = []\n",
140 | "all_kf_errors = []\n",
141 | "\n",
142 | "for network_id in range(number_hyperparameters_to_search):\n",
143 | " print(network_id)\n",
144 | " model, model_features = make_model(all_keys_binarized)\n",
145 | " filename = os.path.join('hyperparameter-search-results',\n",
146 |     "                            model_id + '-' + str(network_id))\n",
147 | " save_features(model_features, filename)\n",
148 | " \n",
149 | " k_folds_errors = []\n",
150 | " for train_index, test_index in skf.split(all_spectra, all_keys):\n",
151 | " # reset model on each iteration\n",
152 | " model = DNN(model_features)\n",
153 | " optimizer = tf.train.AdamOptimizer(model_features.learining_rate)\n",
154 | "\n",
155 | " costfunction_errors_tmp, earlystop_errors_tmp = train_earlystop(\n",
156 | " training_data=all_spectra[train_index],\n",
157 | " training_keys=all_keys_binarized[train_index],\n",
158 | " testing_data=all_spectra[test_index],\n",
159 | " testing_keys=all_keys_binarized[test_index],\n",
160 | " model=model,\n",
161 | " optimizer=optimizer,\n",
162 | " num_epochs=200,\n",
163 | " obj_cost=model.cross_entropy,\n",
164 | " earlystop_cost_fn=model.f1_error,\n",
165 | " earlystop_patience=10,\n",
166 | " not_learning_patience=10,\n",
167 | " not_learning_threshold=0.9,\n",
168 | " verbose=True,\n",
169 | " fit_batch_verbose=10,\n",
170 | " data_augmentation=model.default_data_augmentation)\n",
171 | " k_folds_errors.append(earlystop_errors_tmp)\n",
172 | " all_kf_errors.append(earlystop_errors_tmp)\n",
173 | "\n",
174 | " testing_errors.append(np.average(k_folds_errors))\n",
175 | " np.save('./final-models/final_test_errors_'+model_id, testing_errors)\n",
176 | " np.save('./final-models/final_kf_errors_'+model_id, all_kf_errors)\n",
177 | " network_id += 1 "
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": []
186 | }
187 | ],
188 | "metadata": {
189 | "kernelspec": {
190 | "display_name": "Environment (conda_tensorflow_p36)",
191 | "language": "python",
192 | "name": "conda_tensorflow_p36"
193 | },
194 | "language_info": {
195 | "codemirror_mode": {
196 | "name": "ipython",
197 | "version": 3
198 | },
199 | "file_extension": ".py",
200 | "mimetype": "text/x-python",
201 | "name": "python",
202 | "nbconvert_exporter": "python",
203 | "pygments_lexer": "ipython3",
204 | "version": "3.6.7"
205 | }
206 | },
207 | "nbformat": 4,
208 | "nbformat_minor": 1
209 | }
210 |
--------------------------------------------------------------------------------
/examples/source-interdiction/hyperparameter-search/DNN-Hyperparameter-Search-Full.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
11 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
12 | "model_id = 'DNN-hpsearch-full'\n",
13 | "\n",
14 | "from sklearn.preprocessing import LabelBinarizer\n",
15 | "from sklearn.model_selection import StratifiedKFold\n",
16 | "import tensorflow as tf\n",
17 | "import pickle\n",
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "\n",
21 | "from annsa.template_sampling import *\n",
22 | "from annsa.load_pretrained_network import save_features"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from hyperparameter_models import make_dense_model as make_model"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "import tensorflow.contrib.eager as tfe"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "tf.enable_eager_execution()"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "#### Import model, training function "
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from annsa.model_classes import (dnn_model_features,\n",
66 | " DNN,\n",
67 | " save_model,\n",
68 | " train_earlystop)"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "## Load testing dataset "
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "dataset = np.load('../dataset_generation/hyperparametersearch_dataset_100_full.npy')"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "all_spectra = np.float64(np.add(dataset.item()['sources'], dataset.item()['backgrounds']))\n",
94 | "all_keys = dataset.item()['keys']\n",
95 | "\n",
96 | "mlb=LabelBinarizer()\n",
97 | "\n",
98 | "all_keys_binarized = mlb.fit_transform(all_keys)"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "# Train network"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "### Define hyperparameters"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "number_hyperparameters_to_search = 256\n",
122 | "earlystop_errors_test = []"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "### Search hyperparameters"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "skf = StratifiedKFold(n_splits=5, random_state=5)\n",
139 | "testing_errors = []\n",
140 | "all_kf_errors = []\n",
141 | "\n",
142 | "for network_id in range(number_hyperparameters_to_search):\n",
143 | " print(network_id)\n",
144 | " model, model_features = make_model(all_keys_binarized)\n",
145 | " filename = os.path.join('hyperparameter-search-results',\n",
146 |     "                            model_id + '-' + str(network_id))\n",
147 | " save_features(model_features, filename)\n",
148 | " \n",
149 | " k_folds_errors = []\n",
150 | " for train_index, test_index in skf.split(all_spectra, all_keys):\n",
151 | " # reset model on each iteration\n",
152 | " model = DNN(model_features)\n",
153 | " optimizer = tf.train.AdamOptimizer(model_features.learining_rate)\n",
154 | "\n",
155 | " costfunction_errors_tmp, earlystop_errors_tmp = train_earlystop(\n",
156 | " training_data=all_spectra[train_index],\n",
157 | " training_keys=all_keys_binarized[train_index],\n",
158 | " testing_data=all_spectra[test_index],\n",
159 | " testing_keys=all_keys_binarized[test_index],\n",
160 | " model=model,\n",
161 | " optimizer=optimizer,\n",
162 | " num_epochs=200,\n",
163 | " obj_cost=model.cross_entropy,\n",
164 | " earlystop_cost_fn=model.f1_error,\n",
165 | " earlystop_patience=10,\n",
166 | " not_learning_patience=10,\n",
167 | " not_learning_threshold=0.9,\n",
168 | " verbose=True,\n",
169 | " fit_batch_verbose=10,\n",
170 | " data_augmentation=model.default_data_augmentation)\n",
171 | " k_folds_errors.append(earlystop_errors_tmp)\n",
172 | " all_kf_errors.append(earlystop_errors_tmp)\n",
173 | "\n",
174 | " testing_errors.append(np.average(k_folds_errors))\n",
175 | " np.save('./final-models/final_test_errors_'+model_id, testing_errors)\n",
176 | " np.save('./final-models/final_kf_errors_'+model_id, all_kf_errors)\n",
177 | " network_id += 1 "
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": []
186 | }
187 | ],
188 | "metadata": {
189 | "kernelspec": {
190 | "display_name": "Environment (conda_tensorflow_p36)",
191 | "language": "python",
192 | "name": "conda_tensorflow_p36"
193 | },
194 | "language_info": {
195 | "codemirror_mode": {
196 | "name": "ipython",
197 | "version": 3
198 | },
199 | "file_extension": ".py",
200 | "mimetype": "text/x-python",
201 | "name": "python",
202 | "nbconvert_exporter": "python",
203 | "pygments_lexer": "ipython3",
204 | "version": "3.6.7"
205 | }
206 | },
207 | "nbformat": 4,
208 | "nbformat_minor": 1
209 | }
210 |
--------------------------------------------------------------------------------
/examples/source-interdiction/hyperparameter-search/CNN-Hyperparameter-Search-Easy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
11 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
12 | "model_id = 'CNN-hpsearch-easy'\n",
13 | "\n",
14 | "from sklearn.preprocessing import LabelBinarizer\n",
15 | "from sklearn.model_selection import StratifiedKFold\n",
16 | "import tensorflow as tf\n",
17 | "import pickle\n",
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "\n",
21 | "from annsa.template_sampling import *\n",
22 | "from annsa.load_pretrained_network import save_features"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from hyperparameter_models import make_conv1d_model as make_model"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "import tensorflow.contrib.eager as tfe"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "tf.enable_eager_execution()"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "#### Import model, training function "
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from annsa.model_classes import (cnn1d_model_features,\n",
66 | " CNN1D,\n",
67 | " generate_random_cnn1d_architecture,\n",
68 | " train_earlystop)"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "## Load testing dataset "
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "dataset = np.load('../dataset_generation/hyperparametersearch_dataset_100_easy.npy')"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "all_spectra = np.float64(np.add(dataset.item()['sources'], dataset.item()['backgrounds']))\n",
94 | "all_keys = dataset.item()['keys']\n",
95 | "\n",
96 | "mlb=LabelBinarizer()\n",
97 | "\n",
98 | "all_keys_binarized = mlb.fit_transform(all_keys)"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "# Train network"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "### Define hyperparameters"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "number_hyperparameters_to_search = 256\n",
122 | "earlystop_errors_test = []"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "### Search hyperparameters"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {
136 | "scrolled": false
137 | },
138 | "outputs": [],
139 | "source": [
140 | "skf = StratifiedKFold(n_splits=5, random_state=5)\n",
141 | "testing_errors = []\n",
142 | "all_kf_errors = []\n",
143 | "\n",
144 | "for network_id in range(number_hyperparameters_to_search):\n",
145 | " print(network_id)\n",
146 | " model, model_features = make_model(all_keys_binarized)\n",
147 | " filename = os.path.join('hyperparameter-search-results',\n",
148 |     "                            model_id + '-' + str(network_id))\n",
149 | " save_features(model_features, filename)\n",
150 | " \n",
151 | " k_folds_errors = []\n",
152 | " for train_index, test_index in skf.split(all_spectra, all_keys):\n",
153 | " # reset model on each iteration\n",
154 | " model = CNN1D(model_features)\n",
155 | " optimizer = tf.train.AdamOptimizer(model_features.learning_rate)\n",
156 | "\n",
157 | " costfunction_errors_tmp, earlystop_errors_tmp = train_earlystop(\n",
158 | " training_data=all_spectra[train_index],\n",
159 | " training_keys=all_keys_binarized[train_index],\n",
160 | " testing_data=all_spectra[test_index],\n",
161 | " testing_keys=all_keys_binarized[test_index],\n",
162 | " model=model,\n",
163 | " optimizer=optimizer,\n",
164 | " num_epochs=200,\n",
165 | " obj_cost=model.cross_entropy,\n",
166 | " earlystop_cost_fn=model.f1_error,\n",
167 | " earlystop_patience=10,\n",
168 | " not_learning_patience=10,\n",
169 | " not_learning_threshold=0.9,\n",
170 | " verbose=True,\n",
171 | " fit_batch_verbose=10,\n",
172 | " data_augmentation=model.default_data_augmentation)\n",
173 | " k_folds_errors.append(earlystop_errors_tmp)\n",
174 | " all_kf_errors.append(earlystop_errors_tmp)\n",
175 | "\n",
176 | " testing_errors.append(np.average(k_folds_errors))\n",
177 | " np.save('./final-models/final_test_errors_'+model_id, testing_errors)\n",
178 | " np.save('./final-models/final_kf_errors_'+model_id, all_kf_errors)\n",
179 | " network_id += 1 "
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": []
188 | }
189 | ],
190 | "metadata": {
191 | "kernelspec": {
192 | "display_name": "Environment (conda_tensorflow_p36)",
193 | "language": "python",
194 | "name": "conda_tensorflow_p36"
195 | },
196 | "language_info": {
197 | "codemirror_mode": {
198 | "name": "ipython",
199 | "version": 3
200 | },
201 | "file_extension": ".py",
202 | "mimetype": "text/x-python",
203 | "name": "python",
204 | "nbconvert_exporter": "python",
205 | "pygments_lexer": "ipython3",
206 | "version": "3.6.7"
207 | }
208 | },
209 | "nbformat": 4,
210 | "nbformat_minor": 1
211 | }
212 |
--------------------------------------------------------------------------------
/examples/source-interdiction/hyperparameter-search/CNN-Hyperparameter-Search-Full.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
11 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
12 | "model_id = 'CNN-hpsearch-full'\n",
13 | "\n",
14 | "from sklearn.preprocessing import LabelBinarizer\n",
15 | "from sklearn.model_selection import StratifiedKFold\n",
16 | "import tensorflow as tf\n",
17 | "import pickle\n",
18 | "import numpy as np\n",
19 | "import pandas as pd\n",
20 | "\n",
21 | "from annsa.template_sampling import *\n",
22 | "from annsa.load_pretrained_network import save_features"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from hyperparameter_models import make_conv1d_model as make_model"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "import tensorflow.contrib.eager as tfe"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "tf.enable_eager_execution()"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "#### Import model, training function "
57 | ]
58 | },
59 | {
60 | "cell_type": "code",
61 | "execution_count": null,
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "from annsa.model_classes import (cnn1d_model_features,\n",
66 | " CNN1D,\n",
67 | " generate_random_cnn1d_architecture,\n",
68 | " train_earlystop)"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "## Load testing dataset "
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "dataset = np.load('../dataset_generation/hyperparametersearch_dataset_100_full.npy')"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "all_spectra = np.float64(np.add(dataset.item()['sources'], dataset.item()['backgrounds']))\n",
94 | "all_keys = dataset.item()['keys']\n",
95 | "\n",
96 | "mlb=LabelBinarizer()\n",
97 | "\n",
98 | "all_keys_binarized = mlb.fit_transform(all_keys)"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {},
104 | "source": [
105 | "# Train network"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "### Define hyperparameters"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "number_hyperparameters_to_search = 256\n",
122 | "earlystop_errors_test = []"
123 | ]
124 | },
125 | {
126 | "cell_type": "markdown",
127 | "metadata": {},
128 | "source": [
129 | "### Search hyperparameters"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {
136 | "scrolled": false
137 | },
138 | "outputs": [],
139 | "source": [
140 | "skf = StratifiedKFold(n_splits=5, random_state=5)\n",
141 | "testing_errors = []\n",
142 | "all_kf_errors = []\n",
143 | "\n",
144 | "for network_id in range(number_hyperparameters_to_search):\n",
145 | " print(network_id)\n",
146 | " model, model_features = make_model(all_keys_binarized)\n",
147 | " filename = os.path.join('hyperparameter-search-results',\n",
148 |     "                            model_id + '-' + str(network_id))\n",
149 | " save_features(model_features, filename)\n",
150 | " \n",
151 | " k_folds_errors = []\n",
152 | " for train_index, test_index in skf.split(all_spectra, all_keys):\n",
153 | " # reset model on each iteration\n",
154 | " model = CNN1D(model_features)\n",
155 | " optimizer = tf.train.AdamOptimizer(model_features.learning_rate)\n",
156 | "\n",
157 | " costfunction_errors_tmp, earlystop_errors_tmp = train_earlystop(\n",
158 | " training_data=all_spectra[train_index],\n",
159 | " training_keys=all_keys_binarized[train_index],\n",
160 | " testing_data=all_spectra[test_index],\n",
161 | " testing_keys=all_keys_binarized[test_index],\n",
162 | " model=model,\n",
163 | " optimizer=optimizer,\n",
164 | " num_epochs=200,\n",
165 | " obj_cost=model.cross_entropy,\n",
166 | " earlystop_cost_fn=model.f1_error,\n",
167 | " earlystop_patience=10,\n",
168 | " not_learning_patience=10,\n",
169 | " not_learning_threshold=0.9,\n",
170 | " verbose=True,\n",
171 | " fit_batch_verbose=10,\n",
172 | " data_augmentation=model.default_data_augmentation)\n",
173 | " k_folds_errors.append(earlystop_errors_tmp)\n",
174 | " all_kf_errors.append(earlystop_errors_tmp)\n",
175 | "\n",
176 | " testing_errors.append(np.average(k_folds_errors))\n",
177 | " np.save('./final-models/final_test_errors_'+model_id, testing_errors)\n",
178 | " np.save('./final-models/final_kf_errors_'+model_id, all_kf_errors)\n",
179 | " network_id += 1 "
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "metadata": {},
186 | "outputs": [],
187 | "source": []
188 | }
189 | ],
190 | "metadata": {
191 | "kernelspec": {
192 | "display_name": "Environment (conda_tensorflow_p36)",
193 | "language": "python",
194 | "name": "conda_tensorflow_p36"
195 | },
196 | "language_info": {
197 | "codemirror_mode": {
198 | "name": "ipython",
199 | "version": 3
200 | },
201 | "file_extension": ".py",
202 | "mimetype": "text/x-python",
203 | "name": "python",
204 | "nbconvert_exporter": "python",
205 | "pygments_lexer": "ipython3",
206 | "version": "3.6.7"
207 | }
208 | },
209 | "nbformat": 4,
210 | "nbformat_minor": 1
211 | }
212 |
--------------------------------------------------------------------------------
/annsa/generate_uranium_templates.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import numpy as np
3 | from numpy.random import choice
4 | from annsa.template_sampling import (apply_LLD,
5 | rebin_spectrum,)
6 |
7 |
8 | def choose_uranium_template(uranium_dataset,
9 | sourcedist,
10 | sourceheight,
11 | shieldingdensity,
12 | fwhm,):
13 | '''
14 | Chooses a specific uranium template from a dataset.
15 |
16 | Inputs
17 | uranium_dataset : pandas dataframe
18 | Dataframe containing U232, U235, U238
19 | templates simulated in multiple conditions.
20 | sourcedist : int
21 | The source distance
22 | sourceheight : int
23 | The source height
24 | shieldingdensity : float
25 |         The shielding areal density in g/cm2
26 | fwhm : float
27 |         The full-width-at-half-max at 662 keV
28 |
29 | Outputs
30 | uranium_templates : dict
31 | Dictionary of a single template for each isotope
32 | Also contains an entry for FWHM.
33 | '''
34 |
35 | uranium_templates = {}
36 | sourcedist_choice = sourcedist
37 | sourceheight_choice = sourceheight
38 | shieldingdensity_choice = shieldingdensity
39 |
40 | source_dataset_tmp = uranium_dataset[
41 | uranium_dataset['sourcedist'] == sourcedist_choice]
42 | source_dataset_tmp = source_dataset_tmp[
43 | source_dataset_tmp['sourceheight'] == sourceheight_choice]
44 | source_dataset_tmp = source_dataset_tmp[
45 | source_dataset_tmp['shieldingdensity'] == shieldingdensity_choice]
46 | source_dataset_tmp = source_dataset_tmp[
47 | source_dataset_tmp['fwhm'] == fwhm]
48 |
49 | for isotope in ['u232', 'u235', 'u238']:
50 | spectrum_template = source_dataset_tmp[
51 | source_dataset_tmp['isotope'] == isotope].values[0][6:]
52 | uranium_templates[isotope] = np.abs(spectrum_template)
53 | uranium_templates[isotope] = uranium_templates[isotope].astype(int)
54 | uranium_templates['fwhm'] = source_dataset_tmp['fwhm']
55 |
56 | return uranium_templates
57 |
58 |
59 | def choose_random_uranium_template(uranium_dataset):
60 | '''
61 | Chooses a random uranium template from a dataset.
62 |
63 | Inputs
64 |     uranium_dataset : pandas dataframe
65 | Dataframe containing U232, U235, U238
66 | templates simulated in multiple conditions.
67 |
68 | Outputs
69 | uranium_templates : dict
70 | Dictionary of a single template for each isotope.
71 | '''
72 |
73 | uranium_templates = {}
74 |
75 | all_sourcedist = list(set(uranium_dataset['sourcedist']))
76 | sourcedist_choice = choice(all_sourcedist)
77 |
78 | all_sourceheight = list(set(uranium_dataset['sourceheight']))
79 | sourceheight_choice = choice(all_sourceheight)
80 |
81 | all_shieldingdensity = list(set(uranium_dataset['shieldingdensity']))
82 | shieldingdensity_choice = choice(all_shieldingdensity)
83 |
84 | all_fwhm = list(set(uranium_dataset['fwhm']))
85 | fwhm_choice = choice(all_fwhm)
86 |
87 | source_dataset_tmp = uranium_dataset[
88 | uranium_dataset['sourcedist'] == sourcedist_choice]
89 | source_dataset_tmp = source_dataset_tmp[
90 | source_dataset_tmp['sourceheight'] == sourceheight_choice]
91 | source_dataset_tmp = source_dataset_tmp[
92 | source_dataset_tmp['shieldingdensity'] == shieldingdensity_choice]
93 | source_dataset_tmp = source_dataset_tmp[
94 | source_dataset_tmp['fwhm'] == fwhm_choice]
95 |
96 | for isotope in ['u232', 'u235', 'u238']:
97 | spectrum_template = source_dataset_tmp[
98 | source_dataset_tmp['isotope'] == isotope].values[0][6:]
99 | uranium_templates[isotope] = np.abs(spectrum_template)
100 | uranium_templates[isotope] = uranium_templates[isotope].astype(int)
101 | uranium_templates['fwhm'] = source_dataset_tmp['fwhm']
102 |
103 | return uranium_templates
104 |
105 |
106 | def generate_uenriched_spectrum(uranium_templates,
107 | background_dataset,
108 | enrichment_level=0.93,
109 | integration_time=60,
110 | background_cps=200,
111 | calibration=[0, 1, 0],
112 | source_background_ratio=1.0,
113 | ):
114 | '''
115 |     Generates an enriched uranium spectrum from sampled uranium and background templates.
116 |
117 | Inputs
118 | uranium_template : dict
119 | Dictionary of a single template for each isotope.
120 | background_dataset : pandas dataframe
121 | Dataframe of background spectra with different FWHM parameters.
122 |
123 | Outputs
124 | full_spectrum : array
125 | Sampled source and background spectrum
126 | '''
127 |
128 | a = calibration[0]
129 | b = calibration[1]
130 | c = calibration[2]
131 |
132 |     template_measurement_time = 3600
133 |     time_scaler = integration_time / template_measurement_time
134 | mass_fraction_u232 = choice([0,
135 | np.random.uniform(0.4, 2.0)])
136 |
137 | uranium_component_magnitudes = {
138 | 'u235': time_scaler * enrichment_level,
139 | 'u232': time_scaler * mass_fraction_u232,
140 | 'u238': time_scaler * (1 - enrichment_level),
141 | }
142 |
143 | source_spectrum = np.zeros([1024])
144 | for isotope in uranium_component_magnitudes:
145 | source_spectrum += uranium_component_magnitudes[isotope] \
146 | * rebin_spectrum(
147 | uranium_templates[isotope], a, b, c)
148 | source_spectrum = apply_LLD(source_spectrum, 10)
149 | source_spectrum_sampled = np.random.poisson(source_spectrum)
150 | source_counts = np.sum(source_spectrum_sampled)
151 |
152 | background_counts = source_counts / source_background_ratio
153 | fwhm = uranium_templates['fwhm'].values[0]
154 | background_dataset = background_dataset[background_dataset['fwhm'] == fwhm]
155 | background_spectrum = background_dataset.sample().values[0][3:]
156 | background_spectrum = rebin_spectrum(background_spectrum,
157 | a, b, c)
158 | background_spectrum = np.array(background_spectrum, dtype='float64')
159 | background_spectrum = apply_LLD(background_spectrum, 10)
160 | background_spectrum /= np.sum(background_spectrum)
161 | background_spectrum_sampled = np.random.poisson(background_spectrum *
162 | background_counts)
163 |
164 | full_spectrum = np.sum(
165 | [source_spectrum_sampled[0:1024],
166 | background_spectrum_sampled[0:1024]],
167 | axis=0,)
168 |
169 | return full_spectrum
170 |
--------------------------------------------------------------------------------
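
Note on usage: the two template-selection helpers above feed directly into `generate_uenriched_spectrum`. A minimal sketch of wiring them together follows; the file names and the assumption that both dataframes carry the columns referenced above ('sourcedist', 'sourceheight', 'shieldingdensity', 'fwhm', 'isotope') are illustrative, not part of the module.

```python
# Sketch only: file names are placeholders; the dataframes are assumed to
# contain the columns used by the functions above.
import pandas as pd
from annsa.generate_uranium_templates import (choose_random_uranium_template,
                                              generate_uenriched_spectrum)

uranium_dataset = pd.read_csv('uranium_templates.csv')        # hypothetical
background_dataset = pd.read_csv('background_templates.csv')  # hypothetical

templates = choose_random_uranium_template(uranium_dataset)
spectrum = generate_uenriched_spectrum(templates,
                                       background_dataset,
                                       enrichment_level=0.93,
                                       integration_time=60)
print(spectrum.shape)  # (1024,) Poisson-sampled counts per channel
```
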
/annsa/tests/test_training_cnn1d.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import numpy as np
3 | import tensorflow as tf
4 | import pytest
5 | from sklearn.preprocessing import FunctionTransformer, LabelBinarizer
6 | from sklearn.pipeline import make_pipeline
7 | from sklearn.datasets import load_digits
8 | from annsa.model_classes import (generate_random_cnn1d_architecture,
9 | CNN1D)
10 |
11 | tf.enable_eager_execution()
12 |
13 |
14 | @pytest.fixture(params=[([10], 0.5, 64),
15 | ([], 0.5, 1024),
16 | ([], 0.999, 1024),
17 | ([10], 0.999, 1024), ])
18 | def cnn1d(request):
19 | '''
20 | Constructs a convolutional neural network with filters
21 | initialized to ones. Fixture params are either zero or one hidden
22 | dense layer.
23 | '''
24 | (dense_nodes, dropout_probability, input_size) = request.param
25 |
26 | scaler = make_pipeline(FunctionTransformer(np.abs, validate=False))
27 | model_features = generate_random_cnn1d_architecture(
28 | cnn_filters_choices=((4, 1),),
29 | cnn_kernel_choices=((4, ), ),
30 | pool_size_choices=((4, ), ))
31 | model_features.learning_rate = 1e-1
32 | model_features.trainable = True
33 | model_features.batch_size = 5
34 | model_features.output_size = 3
35 | model_features.output_function = None
36 | model_features.l2_regularization_scale = 1e1
37 | model_features.dropout_probability = dropout_probability
38 | model_features.scaler = scaler
39 | model_features.Pooling = tf.layers.MaxPooling1D
40 | model_features.activation_function = None
41 | model_features.dense_nodes = dense_nodes
42 | model = CNN1D(model_features)
43 | # forward pass to initialize cnn1d weights
44 | model.forward_pass(np.ones([1, input_size]), training=False)
45 | # set weights to ones
46 | weight_ones = [np.ones(weight.shape) if (index % 2 == 0) else weight for
47 | index, weight in enumerate(model.get_weights())]
48 | model.set_weights(weight_ones)
49 | return model
50 |
51 |
52 | @pytest.fixture()
53 | def toy_dataset():
54 | '''
55 | Constructs toy dataset of digits.
56 | '''
57 | data, target = load_digits(n_class=3, return_X_y=True)
58 | mlb = LabelBinarizer()
59 | targets_binarized = mlb.fit_transform(target)
60 | return (data, targets_binarized)
61 |
62 |
63 | # forward pass tests
64 | @pytest.mark.parametrize('cnn1d',
65 | (([], 0.5, 1024),
66 | ([], 0.999, 1024),
67 | ([10], 0.999, 1024)),
68 | indirect=True,)
69 | def test_forward_pass_0(cnn1d):
70 | '''case 0: test if output size is correct'''
71 | output = cnn1d.forward_pass(np.ones([1, 1024]), training=False)
72 | assert(output.shape[1] == 3)
73 |
74 |
75 | @pytest.mark.parametrize('cnn1d',
76 | (([], 0.5, 1024),),
77 | indirect=True,)
78 | def test_forward_pass_1(cnn1d):
79 |     '''case 1: tests the response to a spectrum of all ones when the
80 |     weight filters are all ones. Note: the layer before the output has an
81 |     activation of 64 in each of its 256 nodes, so the densely connected
82 |     output yields 64*256=16384 for each output node.'''
83 | output = cnn1d.forward_pass(np.ones([1, 1024]), training=False)
84 | output_value = output.numpy()[0][0]
85 | assert(output_value == 16384)
86 |
87 |
88 | # loss function tests
89 | @pytest.mark.parametrize('cnn1d',
90 | (([], 0.5, 1024),),
91 | indirect=True,)
92 | def test_loss_fn_0(cnn1d):
93 |     '''case 0: tests that l2 regularization does not add to the loss_fn
94 |     when there are no hidden dense layers.'''
95 | loss = cnn1d.loss_fn(
96 | input_data=np.ones([1, 1024]),
97 | targets=np.array([[16384, 16384, 16384]]),
98 | cost=cnn1d.mse,
99 | training=False)
100 | loss = loss.numpy()
101 | assert(loss == 0.)
102 |
103 |
104 | @pytest.mark.parametrize('cnn1d',
105 | (([10], 0.999, 1024),),
106 | indirect=True,)
107 | def test_loss_fn_1(cnn1d):
108 | '''case 1: tests if l2 regularization adds to loss_fn when there are
109 | dense hidden layers.'''
110 | loss = cnn1d.loss_fn(
111 | input_data=np.ones([1, 1024]),
112 | targets=np.array([[16384, 16384, 16384]]),
113 | cost=cnn1d.mse,
114 | training=False)
115 |     loss = loss.numpy()
116 | assert(loss > 0.)
117 |
118 |
119 | # dropout test
120 | @pytest.mark.parametrize('cnn1d',
121 | (([], 0.999, 1024),),
122 | indirect=True,)
123 | def test_dropout_0(cnn1d):
124 | '''case 0: tests that dropout is not applied when there are no dense
125 | hidden layers.'''
126 | o_training_false = cnn1d.forward_pass(np.ones([1, 1024]),
127 | training=False).numpy()
128 | o_training_true = cnn1d.forward_pass(np.ones([1, 1024]),
129 | training=True).numpy()
130 | assert(np.array_equal(o_training_false, o_training_true))
131 |
132 |
133 | @pytest.mark.parametrize('cnn1d',
134 | (([10], 0.999, 1024),),
135 | indirect=True,)
136 | def test_dropout_1(cnn1d):
137 | '''case 1: tests that dropout is applied when there are
138 | dense hidden layers'''
139 | o_training_false = cnn1d.forward_pass(np.ones([1, 1024]),
140 | training=False).numpy()
141 | o_training_true = cnn1d.forward_pass(np.ones([1, 1024]),
142 | training=True).numpy()
143 | assert(np.array_equal(o_training_false, o_training_true) is False)
144 |
145 |
146 | @pytest.mark.parametrize('cnn1d',
147 | (([10], 0.999, 1024),),
148 | indirect=True,)
149 | def test_dropout_2(cnn1d):
150 | '''case 2: tests that dropout is not applied during inference, when
151 | training is False.'''
152 | o_training_false_1 = cnn1d.forward_pass(np.ones([1, 1024]),
153 | training=False).numpy()
154 | o_training_false_2 = cnn1d.forward_pass(np.ones([1, 1024]),
155 | training=False).numpy()
156 | assert(np.array_equal(o_training_false_1, o_training_false_2))
157 |
158 |
159 | # training tests
160 | @pytest.mark.parametrize('cost', ['mse', 'cross_entropy'])
161 | @pytest.mark.parametrize('cnn1d',
162 | (([10], 0.5, 64),),
163 | indirect=True,)
164 | def test_training_0(cnn1d, toy_dataset, cost):
165 | '''case 0: test if training on toy dataset reduces errors using
166 | both error functions'''
167 | (data, targets_binarized) = toy_dataset
168 | cost_function = getattr(cnn1d, cost)
169 | objective_cost, earlystop_cost = cnn1d.fit_batch(
170 | (data, targets_binarized),
171 | (data, targets_binarized),
172 | optimizer=tf.train.AdamOptimizer(1e-3),
173 | num_epochs=2,
174 | obj_cost=cost_function,
175 | data_augmentation=cnn1d.default_data_augmentation,)
176 | epoch0_error = objective_cost['test'][0].numpy()
177 | epoch1_error = objective_cost['test'][-1].numpy()
178 | assert(epoch1_error < epoch0_error)
179 |
--------------------------------------------------------------------------------
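
The 16384 expected by `test_forward_pass_1` is just counting. A back-of-the-envelope check, assuming 'same' convolution padding and stride-2 pooling (so the 1024 input positions are halved twice to 256 before flattening):

```python
# Arithmetic behind test_forward_pass_1; the padding and pooling-stride
# assumptions are stated in the lead-in, not asserted by the test itself.
kernel_size = 4       # cnn_kernel_choices=((4,),)
layer1_filters = 4    # first entry of cnn_filters_choices=((4, 1),)

layer1_activation = kernel_size * 1                 # all-ones kernel on all-ones input
layer2_activation = kernel_size * layer1_filters * layer1_activation  # 4 * 4 * 4 = 64

flattened_length = 1024 // 2 // 2                   # two stride-2 pooling layers
print(layer2_activation * flattened_length)         # 64 * 256 = 16384
```
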
/doc/sphinxext/numpydoc.py:
--------------------------------------------------------------------------------
1 | """
2 | ========
3 | numpydoc
4 | ========
5 |
6 | Sphinx extension that handles docstrings in the Numpy standard format. [1]
7 |
8 | It will:
9 |
10 | - Convert Parameters etc. sections to field lists.
11 | - Convert See Also section to a See also entry.
12 | - Renumber references.
13 | - Extract the signature from the docstring, if it can't be determined
14 | otherwise.
15 |
16 | .. [1] https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt
17 |
18 | """
19 | from __future__ import division, absolute_import, print_function
20 |
21 | import sys
22 | import re
23 | import pydoc
24 | import sphinx
25 | import inspect
26 | import collections
27 |
28 | if sphinx.__version__ < '1.0.1':
29 | raise RuntimeError("Sphinx 1.0.1 or newer is required")
30 |
31 | from docscrape_sphinx import get_doc_object, SphinxDocString
32 | from sphinx.util.compat import Directive
33 |
34 | if sys.version_info[0] >= 3:
35 | sixu = lambda s: s
36 | else:
37 | sixu = lambda s: unicode(s, 'unicode_escape')
38 |
39 |
40 | def mangle_docstrings(app, what, name, obj, options, lines,
41 | reference_offset=[0]):
42 |
43 | cfg = {'use_plots': app.config.numpydoc_use_plots,
44 | 'show_class_members': app.config.numpydoc_show_class_members,
45 | 'show_inherited_class_members':
46 | app.config.numpydoc_show_inherited_class_members,
47 | 'class_members_toctree': app.config.numpydoc_class_members_toctree}
48 |
49 | u_NL = sixu('\n')
50 | if what == 'module':
51 | # Strip top title
52 | pattern = '^\\s*[#*=]{4,}\\n[a-z0-9 -]+\\n[#*=]{4,}\\s*'
53 | title_re = re.compile(sixu(pattern), re.I | re.S)
54 | lines[:] = title_re.sub(sixu(''), u_NL.join(lines)).split(u_NL)
55 | else:
56 | doc = get_doc_object(obj, what, u_NL.join(lines), config=cfg)
57 | if sys.version_info[0] >= 3:
58 | doc = str(doc)
59 | else:
60 | doc = unicode(doc)
61 | lines[:] = doc.split(u_NL)
62 |
63 | if (app.config.numpydoc_edit_link and hasattr(obj, '__name__') and
64 | obj.__name__):
65 | if hasattr(obj, '__module__'):
66 | v = dict(full_name=sixu("%s.%s") % (obj.__module__, obj.__name__))
67 | else:
68 | v = dict(full_name=obj.__name__)
69 | lines += [sixu(''), sixu('.. htmlonly::'), sixu('')]
70 | lines += [sixu(' %s') % x for x in
71 | (app.config.numpydoc_edit_link % v).split("\n")]
72 |
73 | # replace reference numbers so that there are no duplicates
74 | references = []
75 | for line in lines:
76 | line = line.strip()
77 | m = re.match(sixu('^.. \\[([a-z0-9_.-])\\]'), line, re.I)
78 | if m:
79 | references.append(m.group(1))
80 |
81 | # start renaming from the longest string, to avoid overwriting parts
82 | references.sort(key=lambda x: -len(x))
83 | if references:
84 | for i, line in enumerate(lines):
85 | for r in references:
86 | if re.match(sixu('^\\d+$'), r):
87 | new_r = sixu("R%d") % (reference_offset[0] + int(r))
88 | else:
89 | new_r = sixu("%s%d") % (r, reference_offset[0])
90 | lines[i] = lines[i].replace(sixu('[%s]_') % r,
91 | sixu('[%s]_') % new_r)
92 | lines[i] = lines[i].replace(sixu('.. [%s]') % r,
93 | sixu('.. [%s]') % new_r)
94 |
95 | reference_offset[0] += len(references)
96 |
97 |
98 | def mangle_signature(app, what, name, obj, options, sig, retann):
99 | # Do not try to inspect classes that don't define `__init__`
100 | if (inspect.isclass(obj) and
101 | (not hasattr(obj, '__init__') or
102 | 'initializes x; see ' in pydoc.getdoc(obj.__init__))):
103 | return '', ''
104 |
105 | if not (isinstance(obj, collections.Callable) or
106 | hasattr(obj, '__argspec_is_invalid_')):
107 | return
108 |
109 | if not hasattr(obj, '__doc__'):
110 | return
111 |
112 | doc = SphinxDocString(pydoc.getdoc(obj))
113 | if doc['Signature']:
114 | sig = re.sub(sixu("^[^(]*"), sixu(""), doc['Signature'])
115 | return sig, sixu('')
116 |
117 |
118 | def setup(app, get_doc_object_=get_doc_object):
119 | if not hasattr(app, 'add_config_value'):
120 | return # probably called by nose, better bail out
121 |
122 | global get_doc_object
123 | get_doc_object = get_doc_object_
124 |
125 | app.connect('autodoc-process-docstring', mangle_docstrings)
126 | app.connect('autodoc-process-signature', mangle_signature)
127 | app.add_config_value('numpydoc_edit_link', None, False)
128 | app.add_config_value('numpydoc_use_plots', None, False)
129 | app.add_config_value('numpydoc_show_class_members', True, True)
130 | app.add_config_value('numpydoc_show_inherited_class_members', True, True)
131 | app.add_config_value('numpydoc_class_members_toctree', True, True)
132 |
133 | # Extra mangling domains
134 | app.add_domain(NumpyPythonDomain)
135 | app.add_domain(NumpyCDomain)
136 |
137 | # ------------------------------------------------------------------------------
138 | # Docstring-mangling domains
139 | # ------------------------------------------------------------------------------
140 |
141 | from docutils.statemachine import ViewList
142 | from sphinx.domains.c import CDomain
143 | from sphinx.domains.python import PythonDomain
144 |
145 |
146 | class ManglingDomainBase(object):
147 | directive_mangling_map = {}
148 |
149 | def __init__(self, *a, **kw):
150 | super(ManglingDomainBase, self).__init__(*a, **kw)
151 | self.wrap_mangling_directives()
152 |
153 | def wrap_mangling_directives(self):
154 | for name, objtype in list(self.directive_mangling_map.items()):
155 | self.directives[name] = wrap_mangling_directive(
156 | self.directives[name], objtype)
157 |
158 |
159 | class NumpyPythonDomain(ManglingDomainBase, PythonDomain):
160 | name = 'np'
161 | directive_mangling_map = {
162 | 'function': 'function',
163 | 'class': 'class',
164 | 'exception': 'class',
165 | 'method': 'function',
166 | 'classmethod': 'function',
167 | 'staticmethod': 'function',
168 | 'attribute': 'attribute',
169 | }
170 | indices = []
171 |
172 |
173 | class NumpyCDomain(ManglingDomainBase, CDomain):
174 | name = 'np-c'
175 | directive_mangling_map = {
176 | 'function': 'function',
177 | 'member': 'attribute',
178 | 'macro': 'function',
179 | 'type': 'class',
180 | 'var': 'object',
181 | }
182 |
183 |
184 | def wrap_mangling_directive(base_directive, objtype):
185 | class directive(base_directive):
186 | def run(self):
187 | env = self.state.document.settings.env
188 |
189 | name = None
190 | if self.arguments:
191 | m = re.match(r'^(.*\s+)?(.*?)(\(.*)?', self.arguments[0])
192 | name = m.group(2).strip()
193 |
194 | if not name:
195 | name = self.arguments[0]
196 |
197 | lines = list(self.content)
198 | mangle_docstrings(env.app, objtype, name, None, None, lines)
199 | self.content = ViewList(lines, self.content.parent)
200 |
201 | return base_directive.run(self)
202 |
203 | return directive
204 |
--------------------------------------------------------------------------------
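
This extension is enabled from Sphinx's conf.py. A minimal sketch, assuming doc/sphinxext is put on sys.path so the bundled copy is importable (the actual doc/conf.py may differ in detail):

```python
# doc/conf.py (sketch): make the bundled extension importable and enable it.
import os
import sys

sys.path.insert(0, os.path.abspath('sphinxext'))

extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'numpydoc',  # the module shown above
]

# Options read by setup() above; values here are illustrative defaults.
numpydoc_show_class_members = True
numpydoc_class_members_toctree = True
```
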
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | PAPER =
8 | BUILDDIR = _build
9 |
10 | # User-friendly check for sphinx-build
11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
13 | endif
14 |
15 | # Internal variables.
16 | PAPEROPT_a4 = -D latex_paper_size=a4
17 | PAPEROPT_letter = -D latex_paper_size=letter
18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
19 | # the i18n builder cannot share the environment and doctrees with the others
20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
21 |
22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
23 |
24 | help:
25 | 	@echo "Please use \`make <target>' where <target> is one of"
26 | @echo " html to make standalone HTML files"
27 | @echo " dirhtml to make HTML files named index.html in directories"
28 | @echo " singlehtml to make a single large HTML file"
29 | @echo " pickle to make pickle files"
30 | @echo " json to make JSON files"
31 | @echo " htmlhelp to make HTML files and a HTML help project"
32 | @echo " qthelp to make HTML files and a qthelp project"
33 | @echo " devhelp to make HTML files and a Devhelp project"
34 | @echo " epub to make an epub"
35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
36 | @echo " latexpdf to make LaTeX files and run them through pdflatex"
37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
38 | @echo " text to make text files"
39 | @echo " man to make manual pages"
40 | @echo " texinfo to make Texinfo files"
41 | @echo " info to make Texinfo files and run them through makeinfo"
42 | @echo " gettext to make PO message catalogs"
43 | @echo " changes to make an overview of all changed/added/deprecated items"
44 | @echo " xml to make Docutils-native XML files"
45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes"
46 | @echo " linkcheck to check all external links for integrity"
47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)"
48 |
49 |
50 | clean:
51 | rm -rf $(BUILDDIR)/*
52 | rm -rf reference/*
53 |
54 | html:
55 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
56 | @echo
57 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
58 |
59 | dirhtml:
60 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
61 | @echo
62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
63 |
64 | singlehtml:
65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
66 | @echo
67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
68 |
69 | pickle:
70 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
71 | @echo
72 | @echo "Build finished; now you can process the pickle files."
73 |
74 | json:
75 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
76 | @echo
77 | @echo "Build finished; now you can process the JSON files."
78 |
79 | htmlhelp:
80 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
81 | @echo
82 | @echo "Build finished; now you can run HTML Help Workshop with the" \
83 | ".hhp project file in $(BUILDDIR)/htmlhelp."
84 |
85 | qthelp:
86 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
87 | @echo
88 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \
89 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
90 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/annsa.qhcp"
91 | @echo "To view the help file:"
92 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/annsa.qhc"
93 |
94 | devhelp:
95 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
96 | @echo
97 | @echo "Build finished."
98 | @echo "To view the help file:"
99 | @echo "# mkdir -p $$HOME/.local/share/devhelp/annsa"
100 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/annsa"
101 | @echo "# devhelp"
102 |
103 | epub:
104 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
105 | @echo
106 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
107 |
108 | latex:
109 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
110 | @echo
111 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
112 | @echo "Run \`make' in that directory to run these through (pdf)latex" \
113 | "(use \`make latexpdf' here to do that automatically)."
114 |
115 | latexpdf:
116 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
117 | @echo "Running LaTeX files through pdflatex..."
118 | $(MAKE) -C $(BUILDDIR)/latex all-pdf
119 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
120 |
121 | latexpdfja:
122 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
123 | @echo "Running LaTeX files through platex and dvipdfmx..."
124 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
125 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
126 |
127 | text:
128 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
129 | @echo
130 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
131 |
132 | man:
133 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
134 | @echo
135 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
136 |
137 | texinfo:
138 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
139 | @echo
140 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
141 | @echo "Run \`make' in that directory to run these through makeinfo" \
142 | "(use \`make info' here to do that automatically)."
143 |
144 | info:
145 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
146 | @echo "Running Texinfo files through makeinfo..."
147 | make -C $(BUILDDIR)/texinfo info
148 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
149 |
150 | gettext:
151 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
152 | @echo
153 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
154 |
155 | changes:
156 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
157 | @echo
158 | @echo "The overview file is in $(BUILDDIR)/changes."
159 |
160 | linkcheck:
161 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
162 | @echo
163 | @echo "Link check complete; look for any errors in the above output " \
164 | "or in $(BUILDDIR)/linkcheck/output.txt."
165 |
166 | doctest:
167 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
168 | @echo "Testing of doctests in the sources finished, look at the " \
169 | "results in $(BUILDDIR)/doctest/output.txt."
170 |
171 | xml:
172 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
173 | @echo
174 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
175 |
176 | pseudoxml:
177 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
178 | @echo
179 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
180 |
181 | show:
182 | @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/index.html')"
183 |
--------------------------------------------------------------------------------
/annsa/load_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from sklearn.datasets import make_classification
4 | from sklearn.preprocessing import LabelBinarizer
5 |
6 |
7 | def load_easy(source_dataset, background_dataset):
8 | source_dataset = source_dataset[source_dataset['fwhm'] == 7.5]
9 | source_dataset = source_dataset[source_dataset['sourcedist'] == 175.0]
10 | source_dataset = source_dataset[source_dataset['sourceheight'] == 100.0]
11 |
12 | # remove 80% shielding
13 | source_dataset = source_dataset[
14 | source_dataset['shieldingdensity'] != 13.16]
15 | source_dataset = source_dataset[
16 | source_dataset['shieldingdensity'] != 11.02]
17 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 1.61]
18 |
19 | # remove 60% shielding
20 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 7.49]
21 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 6.28]
22 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 0.92]
23 |
24 | # remove 40% shielding
25 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 4.18]
26 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 3.5]
27 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 0.51]
28 |
29 | # remove 20% shielding
30 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 1.82]
31 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 1.53]
32 | source_dataset = source_dataset[source_dataset['shieldingdensity'] != 0.22]
33 |
34 | # remove empty spectra
35 | zero_count_indicies = np.argwhere(
36 | np.sum(source_dataset.values[:, 6:], axis=1) == 0).flatten()
37 |
38 |     print('indices dropped: ' + str(zero_count_indicies))
39 |
40 | source_dataset.drop(
41 | source_dataset.index[zero_count_indicies], inplace=True)
42 |
43 | # Add empty spectra for background
44 | blank_spectra = []
45 | for fwhm in set(source_dataset['fwhm']):
46 | num_examples = source_dataset[(source_dataset['fwhm'] == fwhm) &
47 | (source_dataset['isotope'] ==
48 | source_dataset['isotope'].iloc()[0])
49 | ].shape[0]
50 | for k in range(num_examples):
51 | blank_spectra_tmp = [0] * 1200
52 | blank_spectra_tmp[5] = fwhm
53 | blank_spectra_tmp[0] = 'background'
54 | blank_spectra_tmp[3] = 'background'
55 | blank_spectra.append(blank_spectra_tmp)
56 |
57 | source_dataset = source_dataset.append(pd.DataFrame(blank_spectra,
58 | columns=source_dataset.columns))
59 |
60 | spectra_dataset = source_dataset.values[:, 5:].astype('float64')
61 | all_keys = source_dataset['isotope'].values
62 |
63 | return source_dataset, spectra_dataset, all_keys
64 |
65 |
66 | def load_full(source_dataset, background_dataset):
67 | source_dataset = source_dataset[(source_dataset['fwhm'] == 7.0) |
68 | (source_dataset['fwhm'] == 7.5) |
69 | (source_dataset['fwhm'] == 8.0)]
70 |
71 | source_dataset = source_dataset[(source_dataset['sourcedist'] == 50.5) |
72 | (source_dataset['sourcedist'] == 175.0) |
73 | (source_dataset['sourcedist'] == 300.0)]
74 |
75 | source_dataset = source_dataset[(source_dataset['sourceheight'] == 50.0) |
76 | (source_dataset['sourceheight'] == 100.0) |
77 | (source_dataset['sourceheight'] == 150.0)]
78 |
79 | # remove 80% shielding
80 | source_dataset = source_dataset[
81 | source_dataset['shieldingdensity'] != 13.16]
82 | source_dataset = source_dataset[
83 | source_dataset['shieldingdensity'] != 11.02]
84 | source_dataset = source_dataset[
85 | source_dataset['shieldingdensity'] != 1.61]
86 |
87 | # remove empty spectra
88 | zero_count_indicies = np.argwhere(np.sum(source_dataset.values[:, 6:],
89 | axis=1) == 0).flatten()
90 |
91 |     print('indices dropped: ' + str(zero_count_indicies))
92 |
93 | source_dataset.drop(source_dataset.index[zero_count_indicies],
94 | inplace=True)
95 |
96 | # Add empty spectra for background
97 | blank_spectra = []
98 | for fwhm in set(source_dataset['fwhm']):
99 | num_examples = source_dataset[
100 | (source_dataset['fwhm'] == fwhm) &
101 | (source_dataset['isotope'] ==
102 | source_dataset['isotope'].iloc()[0])].shape[0]
103 | for k in range(num_examples):
104 | blank_spectra_tmp = [0] * 1200
105 | blank_spectra_tmp[5] = fwhm
106 | blank_spectra_tmp[0] = 'background'
107 | blank_spectra_tmp[3] = 'background'
108 | blank_spectra.append(blank_spectra_tmp)
109 |
110 | source_dataset = source_dataset.append(pd.DataFrame(
111 | blank_spectra,
112 | columns=source_dataset.columns))
113 |
114 | spectra_dataset = source_dataset.values[:, 5:].astype('float64')
115 | all_keys = source_dataset['isotope'].values
116 |
117 | return source_dataset, spectra_dataset, all_keys
118 |
119 |
120 | def dataset_to_spectrakeys(dataset):
121 | '''
122 | Loads a dataset into spectra and corresponding isotope keys
123 | '''
124 |
125 | source_spectra = np.array(dataset.item()['sources'], dtype='float64')
126 | background_spectra = np.array(dataset.item()['backgrounds'],
127 | dtype='float64')
128 |
129 | spectra = np.random.poisson(np.add(source_spectra, background_spectra))
130 | keys = np.array(dataset.item()['keys'])
131 |
132 | return spectra, keys
133 |
134 |
135 | def load_dataset(kind='nn'):
136 | """
137 | Generates dummy data using 'sklearn.datasets.make_classification()'.
138 | See 'make_classification' documentation for more details.
139 |
140 | Parameters:
141 | kind : string, optional
142 | A string describing what kind of neural network this dataset
143 |             will be used for. Default is 'nn'.
144 | Accepts:
145 | 'nn' (standard convolution or dense neural networks)
146 | 'ae' (autoencoder)
147 |
148 |
149 | Returns:
150 | -------
151 | train_dataset : tuple of [train_data, training_keys_binarized]
152 | Contains the training data and the labels in a binarized
153 | format.
154 | test_dataset : tuple of [test_data, testing_keys_binarized]
155 | Contains the testing data and the labels in a binarized
156 | format.
157 | """
158 |
159 | training_dataset = make_classification(n_samples=100,
160 | n_features=1024,
161 | n_informative=200,
162 | n_classes=2)
163 |
164 | testing_dataset = make_classification(n_samples=100,
165 | n_features=1024,
166 | n_informative=200,
167 | n_classes=2)
168 |
169 | mlb = LabelBinarizer()
170 |
171 | # transform the training data
172 | training_data = np.abs(training_dataset[0])
173 | training_keys = training_dataset[1]
174 | training_keys_binarized = mlb.fit_transform(
175 | training_keys.reshape([training_data.shape[0], 1]))
176 |
177 | # transform the testing data
178 | testing_data = np.abs(testing_dataset[0])
179 | testing_keys = testing_dataset[1]
180 | testing_keys_binarized = mlb.transform(
181 | testing_keys.reshape([testing_data.shape[0], 1]))
182 |
183 | if kind == 'nn':
184 | test_dataset = [testing_data, testing_keys_binarized]
185 | train_dataset = [training_data, training_keys_binarized]
186 |
187 | elif kind == 'ae':
188 | train_dataset = [training_data, training_data]
189 | test_dataset = [testing_data, testing_data]
190 |
191 | return train_dataset, test_dataset
192 |
--------------------------------------------------------------------------------
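
`load_dataset` is handy for smoke tests because the dummy data is already shaped like a 1024-channel spectrum set with binarized labels. A quick sketch of what it returns:

```python
# Sketch: inspect the dummy datasets produced by load_dataset.
from annsa.load_dataset import load_dataset

train_dataset, test_dataset = load_dataset(kind='nn')
train_data, train_keys_binarized = train_dataset

print(train_data.shape)            # (100, 1024) non-negative dummy spectra
print(train_keys_binarized.shape)  # (100, 1) binarized labels for two classes
```
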
/doc/sphinxext/docscrape_sphinx.py:
--------------------------------------------------------------------------------
1 | import re, inspect, textwrap, pydoc
2 | import sphinx
3 | from docscrape import NumpyDocString, FunctionDoc, ClassDoc
4 |
5 | class SphinxDocString(NumpyDocString):
6 | def __init__(self, docstring, config={}):
7 | self.use_plots = config.get('use_plots', False)
8 | NumpyDocString.__init__(self, docstring, config=config)
9 |
10 | # string conversion routines
11 | def _str_header(self, name, symbol='`'):
12 | return ['.. rubric:: ' + name, '']
13 |
14 | def _str_field_list(self, name):
15 | return [':' + name + ':']
16 |
17 | def _str_indent(self, doc, indent=4):
18 | out = []
19 | for line in doc:
20 | out += [' '*indent + line]
21 | return out
22 |
23 | def _str_signature(self):
24 | return ['']
25 | if self['Signature']:
26 | return ['``%s``' % self['Signature']] + ['']
27 | else:
28 | return ['']
29 |
30 | def _str_summary(self):
31 | return self['Summary'] + ['']
32 |
33 | def _str_extended_summary(self):
34 | return self['Extended Summary'] + ['']
35 |
36 | def _str_param_list(self, name):
37 | out = []
38 | if self[name]:
39 | out += self._str_field_list(name)
40 | out += ['']
41 | for param,param_type,desc in self[name]:
42 | out += self._str_indent(['**%s** : %s' % (param.strip(),
43 | param_type)])
44 | out += ['']
45 | out += self._str_indent(desc,8)
46 | out += ['']
47 | return out
48 |
49 | @property
50 | def _obj(self):
51 | if hasattr(self, '_cls'):
52 | return self._cls
53 | elif hasattr(self, '_f'):
54 | return self._f
55 | return None
56 |
57 | def _str_member_list(self, name):
58 | """
59 | Generate a member listing, autosummary:: table where possible,
60 | and a table where not.
61 |
62 | """
63 | out = []
64 | if self[name]:
65 | out += ['.. rubric:: %s' % name, '']
66 | prefix = getattr(self, '_name', '')
67 |
68 | if prefix:
69 | prefix = '~%s.' % prefix
70 |
71 | autosum = []
72 | others = []
73 | for param, param_type, desc in self[name]:
74 | param = param.strip()
75 | if not self._obj or hasattr(self._obj, param):
76 | autosum += [" %s%s" % (prefix, param)]
77 | else:
78 | others.append((param, param_type, desc))
79 |
80 | if autosum:
81 | out += ['.. autosummary::', ' :toctree:', '']
82 | out += autosum
83 |
84 | if others:
85 | maxlen_0 = max([len(x[0]) for x in others])
86 | maxlen_1 = max([len(x[1]) for x in others])
87 | hdr = "="*maxlen_0 + " " + "="*maxlen_1 + " " + "="*10
88 | fmt = '%%%ds %%%ds ' % (maxlen_0, maxlen_1)
89 | n_indent = maxlen_0 + maxlen_1 + 4
90 | out += [hdr]
91 | for param, param_type, desc in others:
92 | out += [fmt % (param.strip(), param_type)]
93 | out += self._str_indent(desc, n_indent)
94 | out += [hdr]
95 | out += ['']
96 | return out
97 |
98 | def _str_section(self, name):
99 | out = []
100 | if self[name]:
101 | out += self._str_header(name)
102 | out += ['']
103 | content = textwrap.dedent("\n".join(self[name])).split("\n")
104 | out += content
105 | out += ['']
106 | return out
107 |
108 | def _str_see_also(self, func_role):
109 | out = []
110 | if self['See Also']:
111 | see_also = super(SphinxDocString, self)._str_see_also(func_role)
112 | out = ['.. seealso::', '']
113 | out += self._str_indent(see_also[2:])
114 | return out
115 |
116 | def _str_warnings(self):
117 | out = []
118 | if self['Warnings']:
119 | out = ['.. warning::', '']
120 | out += self._str_indent(self['Warnings'])
121 | return out
122 |
123 | def _str_index(self):
124 | idx = self['index']
125 | out = []
126 | if len(idx) == 0:
127 | return out
128 |
129 | out += ['.. index:: %s' % idx.get('default','')]
130 | for section, references in idx.iteritems():
131 | if section == 'default':
132 | continue
133 | elif section == 'refguide':
134 | out += [' single: %s' % (', '.join(references))]
135 | else:
136 | out += [' %s: %s' % (section, ','.join(references))]
137 | return out
138 |
139 | def _str_references(self):
140 | out = []
141 | if self['References']:
142 | out += self._str_header('References')
143 | if isinstance(self['References'], str):
144 | self['References'] = [self['References']]
145 | out.extend(self['References'])
146 | out += ['']
147 | # Latex collects all references to a separate bibliography,
148 | # so we need to insert links to it
149 | if sphinx.__version__ >= "0.6":
150 | out += ['.. only:: latex','']
151 | else:
152 | out += ['.. latexonly::','']
153 | items = []
154 | for line in self['References']:
155 | m = re.match(r'.. \[([a-z0-9._-]+)\]', line, re.I)
156 | if m:
157 | items.append(m.group(1))
158 | out += [' ' + ", ".join(["[%s]_" % item for item in items]), '']
159 | return out
160 |
161 | def _str_examples(self):
162 | examples_str = "\n".join(self['Examples'])
163 |
164 | if (self.use_plots and 'import matplotlib' in examples_str
165 | and 'plot::' not in examples_str):
166 | out = []
167 | out += self._str_header('Examples')
168 | out += ['.. plot::', '']
169 | out += self._str_indent(self['Examples'])
170 | out += ['']
171 | return out
172 | else:
173 | return self._str_section('Examples')
174 |
175 | def __str__(self, indent=0, func_role="obj"):
176 | out = []
177 | out += self._str_signature()
178 | out += self._str_index() + ['']
179 | out += self._str_summary()
180 | out += self._str_extended_summary()
181 | for param_list in ('Parameters', 'Returns', 'Other Parameters',
182 | 'Raises', 'Warns'):
183 | out += self._str_param_list(param_list)
184 | out += self._str_warnings()
185 | out += self._str_see_also(func_role)
186 | out += self._str_section('Notes')
187 | out += self._str_references()
188 | out += self._str_examples()
189 | for param_list in ('Attributes', 'Methods'):
190 | out += self._str_member_list(param_list)
191 | out = self._str_indent(out,indent)
192 | return '\n'.join(out)
193 |
194 | class SphinxFunctionDoc(SphinxDocString, FunctionDoc):
195 | def __init__(self, obj, doc=None, config={}):
196 | self.use_plots = config.get('use_plots', False)
197 | FunctionDoc.__init__(self, obj, doc=doc, config=config)
198 |
199 | class SphinxClassDoc(SphinxDocString, ClassDoc):
200 | def __init__(self, obj, doc=None, func_doc=None, config={}):
201 | self.use_plots = config.get('use_plots', False)
202 | ClassDoc.__init__(self, obj, doc=doc, func_doc=None, config=config)
203 |
204 | class SphinxObjDoc(SphinxDocString):
205 | def __init__(self, obj, doc=None, config={}):
206 | self._f = obj
207 | SphinxDocString.__init__(self, doc, config=config)
208 |
209 | def get_doc_object(obj, what=None, doc=None, config={}):
210 | if what is None:
211 | if inspect.isclass(obj):
212 | what = 'class'
213 | elif inspect.ismodule(obj):
214 | what = 'module'
215 | elif callable(obj):
216 | what = 'function'
217 | else:
218 | what = 'object'
219 | if what == 'class':
220 | return SphinxClassDoc(obj, func_doc=SphinxFunctionDoc, doc=doc,
221 | config=config)
222 | elif what in ('function', 'method'):
223 | return SphinxFunctionDoc(obj, doc=doc, config=config)
224 | else:
225 | if doc is None:
226 | doc = pydoc.getdoc(obj)
227 | return SphinxObjDoc(obj, doc, config=config)
228 |
--------------------------------------------------------------------------------
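
A quick way to see what the docstring mangling produces is to run `get_doc_object` on any NumPy-style docstring. A sketch, assuming docscrape.py sits next to this module so the import at the top resolves:

```python
# Sketch: render a NumPy-style docstring to the reST emitted for Sphinx.
import numpy as np
from docscrape_sphinx import get_doc_object

doc = get_doc_object(np.mean)   # any callable with a NumPy-style docstring
print(str(doc)[:300])           # field lists, rubrics, autosummary entries
```
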
/examples/source-interdiction/hyperparameter-search/hyperparameter_models.py:
--------------------------------------------------------------------------------
1 | from random import choice
2 | import numpy as np
3 | from annsa.model_classes import (dnn_model_features,
4 | cnn1d_model_features,
5 | cae_model_features,
6 | DNN,
7 | CNN1D,
8 | CAE,
9 | )
10 | import tensorflow as tf
11 | from sklearn.pipeline import make_pipeline
12 | from sklearn.preprocessing import FunctionTransformer, Normalizer
13 |
14 | scaler_choices = [make_pipeline(FunctionTransformer(np.log1p, validate=True)),
15 | make_pipeline(FunctionTransformer(np.log1p, validate=True),
16 | Normalizer(norm='l1')),
17 | make_pipeline(FunctionTransformer(np.log1p, validate=True),
18 | Normalizer(norm='max')),
19 | make_pipeline(FunctionTransformer(np.sqrt, validate=True)),
20 | make_pipeline(FunctionTransformer(np.sqrt, validate=True),
21 | Normalizer(norm='l1')),
22 | make_pipeline(FunctionTransformer(np.sqrt, validate=True),
23 | Normalizer(norm='max')), ]
24 |
25 |
26 | def make_dense_model(all_keys_binarized):
27 | """
28 | Makes a random dense model given some parameters.
29 |
30 | Parameters:
31 | -----------
32 | all_keys_binarized : list, bool
33 | List of binarized keys
34 | Returns:
35 | --------
36 | model : object
37 | A keras model of a dense network.
38 | model_features : class
39 |         Class that describes the structure of the dense network
40 | """
41 | number_layers = choice([1, 2, 3])
42 | dense_nodes = 2**np.random.randint(5, 10, number_layers)
43 | dense_nodes = np.sort(dense_nodes)
44 | dense_nodes = np.flipud(dense_nodes)
45 | model_features = dnn_model_features(
46 | learining_rate=10**np.random.uniform(-4, -1),
47 | l2_regularization_scale=10**np.random.uniform(-2, 0),
48 | dropout_probability=np.random.uniform(0, 1),
49 | batch_size=2**np.random.randint(4, 10),
50 | output_size=all_keys_binarized.shape[1],
51 | dense_nodes=dense_nodes,
52 | activation_function=choice([tf.nn.tanh, tf.nn.relu]),
53 | output_function=None,
54 | scaler=choice(scaler_choices))
55 |
56 | model = DNN(model_features)
57 |
58 | return model, model_features
59 |
60 |
61 | def generate_random_cnn1d_architecture(cnn_filters_choices,
62 | cnn_kernel_choices,
63 | pool_size_choices):
64 | """
65 | Makes a random convolutional model features given some parameters.
66 |
67 | Parameters:
68 | -----------
69 | cnn_filters_choices : list, int
70 |         List of the number of filters to use for each convolutional layer.
71 | cnn_kernel_choices : list, int
72 | List of filter lengths to use for each convolutional layer.
73 | pool_size_choices : list, int
74 | List of pooling lengths to use for each convolutional layer.
75 |
76 | Returns:
77 | --------
78 | model_features : class
79 |         Class that describes the structure of the 1D convolutional
80 | network
81 | """
82 |
83 | cnn_filters = choice(cnn_filters_choices)
84 | cnn_kernel_choice = choice(cnn_kernel_choices)
85 | pool_size_choice = choice(pool_size_choices)
86 |
87 | cnn_kernel = cnn_kernel_choice * (len(cnn_filters))
88 | cnn_strides = (1,) * (len(cnn_filters))
89 | pool_size = pool_size_choice * (len(cnn_filters))
90 | pool_strides = (2,) * (len(cnn_filters))
91 |
92 | number_layers = np.random.randint(1, 4)
93 | dense_nodes = (10 ** np.random.uniform(
94 | 1,
95 | np.log10(1024 / (2 ** len(cnn_filters))),
96 | number_layers)).astype('int')
97 | dense_nodes = np.sort(dense_nodes)
98 | dense_nodes = np.flipud(dense_nodes)
99 |
100 | model_features = cnn1d_model_features(
101 | trainable=None,
102 | learining_rate=None,
103 | batch_size=None,
104 | output_size=None,
105 | scaler=None,
106 | activation_function=None,
107 | output_function=None,
108 | Pooling=None,
109 | l2_regularization_scale=None,
110 | dropout_probability=None,
111 | cnn_filters=cnn_filters,
112 | cnn_kernel=cnn_kernel,
113 | cnn_strides=cnn_strides,
114 | pool_size=pool_size,
115 | pool_strides=pool_strides,
116 | dense_nodes=dense_nodes
117 | )
118 |
119 | return model_features
120 |
121 |
122 | def make_conv1d_model(all_keys_binarized):
123 | """
124 | Makes a random convolutional model and model features using
125 | predefined parameters.
126 |
127 | Parameters:
128 | -----------
129 | all_keys_binarized : list, bool
130 | List of binarized keys
131 |
132 | Returns:
133 | --------
134 | model : object
135 | A keras model of a convolutional network.
136 | model_features : class
137 |         Class that describes the structure of the convolutional network
138 | """
139 |
140 | cnn_filters_choices = (
141 | (4, 8),
142 | (8, 16),
143 | (16, 32),
144 | (4,),
145 | (8,),
146 | (16,),
147 | (32,),
148 | (4, 8, 16),
149 | (8, 16, 32),
150 | )
151 |
152 | cnn_kernel_choices = ((2,), (4,), (8,), (16,))
153 | pool_size_choices = ((2,), (4,), (8,), (16,))
154 |
155 | model_features = generate_random_cnn1d_architecture(
156 | cnn_filters_choices=cnn_filters_choices,
157 | cnn_kernel_choices=cnn_kernel_choices,
158 | pool_size_choices=pool_size_choices,
159 | )
160 | model_features.trainable = True
161 | model_features.learining_rate = 10 ** np.random.uniform(-4, -1)
162 | model_features.batch_size = 2 ** np.random.randint(4, 6)
163 | model_features.output_size = all_keys_binarized.shape[1]
164 | model_features.scaler = choice(scaler_choices)
165 |
166 | model_features.activation_function = tf.nn.relu
167 | model_features.output_function = None
168 | model_features.Pooling = tf.layers.MaxPooling1D
169 | model_features.l2_regularization_scale = 10 ** np.random.uniform(-3, 0)
170 | model_features.dropout_probability = np.random.uniform(0, 1)
171 | model_features.pool_strides = ((2, 2, 2))
172 | number_layers = choice([1, 2, 3])
173 | dense_nodes = 2 ** np.random.randint(4, 8, number_layers)
174 | dense_nodes = np.sort(dense_nodes)
175 | dense_nodes = np.flipud(dense_nodes)
176 | model_features.dense_nodes = dense_nodes
177 |
178 | model = CNN1D(model_features)
179 |
180 | return model, model_features
181 |
182 |
183 | def make_cae1d_model():
184 | """
185 | Makes a random convolutional model and model features using
186 | predefined parameters.
187 |
188 | Parameters:
189 | -----------
190 | None
191 |
192 | Returns:
193 | --------
194 | model : object
195 | A keras model of a convolutional autoencoder network.
196 | model_features : class
197 |         Class that describes the structure of the convolutional autoencoder
198 | network.
199 | """
200 |
201 | cnn_filters_encoder_choice = choice([(4, 1),
202 | (64, 1),
203 | (4, 8, 1),
204 | (32, 64, 1),
205 | (4, 8, 16, 1),
206 | (16, 32, 64, 1),
207 | ]
208 | )
209 |
210 | cnn_kernel_encoder_choice = choice([(2, ), (4, ), (8, ), (16, )])
211 | pool_size_choice = choice([(2, ), (4, ), (8, ), (16, )])
212 |
213 | scaler_choice = choice(scaler_choices)
214 |
215 | num_cnn_filters = len(cnn_filters_encoder_choice)
216 |
217 | model_features = cae_model_features(
218 | encoder_trainable=True,
219 | learning_rate=10 ** np.random.uniform(-4, -1),
220 | batch_size=2 ** np.random.randint(4, 6),
221 | scaler=scaler_choice,
222 | activation_function=tf.nn.tanh,
223 | output_function=None,
224 | Pooling=tf.layers.MaxPooling1D,
225 | cnn_filters_encoder=cnn_filters_encoder_choice,
226 | cnn_kernel_encoder=(cnn_kernel_encoder_choice,) * num_cnn_filters,
227 | cnn_strides_encoder=(1, ) * num_cnn_filters,
228 | pool_size_encoder=pool_size_choice * num_cnn_filters,
229 | pool_strides_encoder=(2, ) * num_cnn_filters,
230 | cnn_filters_decoder=cnn_filters_encoder_choice,
231 | cnn_kernel_decoder=(cnn_kernel_encoder_choice,) * num_cnn_filters,
232 | cnn_strides_decoder=(1, ) * num_cnn_filters)
233 |
234 | model = CAE(model_features)
235 |
236 | return model, model_features
237 |
--------------------------------------------------------------------------------
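
A sketch of how these factories are typically called when drawing random configurations; the three-isotope key list is illustrative, and the script is assumed to run from this directory so the local import resolves:

```python
# Sketch: draw one random dense-network configuration for a 3-class problem.
import numpy as np
from sklearn.preprocessing import LabelBinarizer
from hyperparameter_models import make_dense_model  # this file

keys = np.array(['u232', 'u235', 'u238'])
all_keys_binarized = LabelBinarizer().fit_transform(keys)

model, model_features = make_dense_model(all_keys_binarized)
print(model_features.output_size)   # 3, one output per binarized key
print(model_features.dense_nodes)   # e.g. a descending list such as [256, 64]
```
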
/annsa/load_pretrained_network.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import tensorflow as tf
3 | import pickle
4 | import numpy as np
5 |
6 | from annsa.model_classes import (cnn1d_model_features,
7 | CNN1D,
8 | CAE,
9 | DNN,
10 | DAE,
11 | dnn_model_features,
12 | )
13 |
14 |
15 | def save_features(model_features,
16 | filename):
17 | '''
18 | Saves model features to a pickled file
19 |
20 | inputs:
21 | model_features : model features class
22 | The class containing the model features
23 | filename : str
24 | The location of the pickled file
25 | outputs:
26 | None
27 | '''
28 | with open(filename, 'wb+') as f:
29 | pickle.dump(model_features, f)
30 | return None
31 |
32 |
33 | def load_features(filename):
34 | '''
35 | Loads model features from a pickled file
36 |
37 | inputs:
38 | filename : str
39 | The location of the pickled file
40 | outputs:
41 | model_features : model features class
42 | The class containing the model features
43 | '''
44 |
45 | with open(filename, "rb") as f:
46 | model_features = pickle.load(f)
47 | return model_features
48 |
49 |
50 | def load_trained_model(model_class,
51 | features_filename,
52 | weights_filename,
53 | ):
54 | '''
55 | Loads a trained model.
56 |
57 | inputs:
58 | model_class : Keras model class
59 | The keras model class for the features
60 | features_filename : str
61 |             Filename location of the pickle file containing the model features
62 |         weights_filename : str
63 |             Filename location of the file containing the pretrained weights
64 |
65 | outputs:
66 | model : Keras model class
67 | The class containing the keras model with pretrained weights
68 | model_features : model features class
69 | The class containing the model features.
70 | '''
71 |
72 | model_features = load_features(features_filename)
73 | trained_model = model_class(model_features)
74 | trained_model.load_weights(weights_filename)
75 |
76 | # need to do a forward pass to initialize weights
77 | dummy_data = np.zeros(1024)
78 | _ = trained_model.predict_class([dummy_data])
79 |
80 | return trained_model, model_features
81 |
82 |
83 | def load_pretrained_cae_into_cnn(cae_features_filename,
84 | cae_weights_filename,
85 | cnn_dense_nodes=[128],
86 | learning_rate=1e-4,
87 | batch_size=32,
88 | output_size=30,
89 | activation_function=tf.nn.tanh,
90 | l2_regularization_scale=0.0,
91 | dropout_probability=0.0):
92 | '''
93 |     Initializes a CNN with a pretrained CAE for fine-tuning.
94 |
95 | inputs:
96 | cae_features_filename : str
97 |             Filename location of the pickle file containing the CAE features
98 |         cae_weights_filename : str
99 |             Filename location of the file containing the CAE pretrained weights
100 | cnn_dense_nodes : list, int (optional)
101 | List of integers describing the dense part of the CNN
102 | learning_rate : float (optional)
103 | Learning rate for the CNN
104 | batch_size : int (optional)
105 | Training batch size for the CNN
106 | output_size : int (optional)
107 | Output size for the CNN
108 | activation_function : tensorflow function
109 | Activation function for the CNN
110 | l2_regularization_scale : float (optional)
111 |             The L2 regularization scale for the dense layer
112 | dropout_probability : float (optional)
113 | The dropout probability for the dense layer
114 | outputs:
115 | CNN_model : Keras model class
116 | The class containing the keras model with pretrained weights
117 | model_features_CNN : model features class
118 | The class containing the model features.
119 | '''
120 |
121 | cae_features = load_features(cae_features_filename)
122 | CAE_model = CAE(cae_features)
123 | if cae_weights_filename:
124 | CAE_model.load_weights(cae_weights_filename)
125 |
126 | # need to do a forward pass to initialize weights
127 | dummy_data = np.zeros(1024)
128 | _ = CAE_model.encoder([dummy_data])
129 |
130 | model_features_CNN = cnn1d_model_features(
131 | learning_rate=learning_rate,
132 | trainable=True,
133 | batch_size=batch_size,
134 | output_size=output_size,
135 | output_function=None,
136 | l2_regularization_scale=l2_regularization_scale,
137 | dropout_probability=dropout_probability,
138 | scaler=CAE_model.scaler,
139 | Pooling=tf.layers.MaxPooling1D,
140 | cnn_filters=cae_features.cnn_filters_encoder,
141 | cnn_kernel=cae_features.cnn_kernel_encoder,
142 | cnn_strides=cae_features.cnn_strides_encoder,
143 | pool_size=cae_features.pool_size_encoder,
144 | pool_strides=cae_features.pool_strides_encoder,
145 | dense_nodes=cnn_dense_nodes,
146 | activation_function=activation_function,
147 | )
148 |
149 | CNN_model = CNN1D(model_features_CNN)
150 |
151 | # need to do a forward pass to initialize weights
152 | dummy_data = np.zeros(1024)
153 | _ = CNN_model.predict_class([dummy_data])
154 |
155 | for i in range(len(cae_features.cnn_filters_encoder)):
156 | CNN_model.layers[i].set_weights(CAE_model.layers[i].get_weights())
157 |
158 | return CNN_model, model_features_CNN
159 |
160 |
161 | def load_pretrained_dae_into_dnn(dae_features_filename,
162 | dae_weights_filename,
163 | dnn_dense_nodes=[128],
164 | learning_rate=1e-4,
165 | batch_size=32,
166 | output_size=30,
167 | activation_function=tf.nn.tanh,
168 | l2_regularization_scale=0.0,
169 | dropout_probability=0.0):
170 | '''
171 |     Initializes a DNN with a pretrained DAE for fine-tuning.
172 |
173 | inputs:
174 | dae_features_filename : str
175 |             Filename location of the pickle file containing the DAE features
176 |         dae_weights_filename : str
177 |             Filename location of the file containing the DAE pretrained weights
178 | dnn_dense_nodes : list, int (optional)
179 | List of integers describing the untrained dense part of the DNN
180 | learning_rate : float (optional)
181 | Learning rate for the DNN
182 | batch_size : int (optional)
183 | Training batch size for the DNN
184 | output_size : int (optional)
185 | Output size for the DNN
186 | activation_function : tensorflow function
187 | Activation function for the entire DNN
188 | l2_regularization_scale : float (optional)
189 |             The L2 regularization scale for all dense layers
190 | dropout_probability : float (optional)
191 | The dropout probability for all dense layers
192 | outputs:
193 | DNN_model : Keras model class
194 | The class containing the keras model with pretrained weights
195 | model_features_DNN : model features class
196 | The class containing the model features.
197 | '''
198 |
199 | dae_features = load_features(dae_features_filename)
200 | DAE_model = DAE(dae_features)
201 | if dae_weights_filename:
202 | DAE_model.load_weights(dae_weights_filename)
203 |
204 | # need to do a forward pass to initialize weights
205 | dummy_data = np.zeros(1024)
206 | _ = DAE_model.encoder([dummy_data])
207 |
208 | dense_nodes = DAE_model.dense_nodes_encoder + dnn_dense_nodes
209 |
210 | model_features_DNN = dnn_model_features(
211 | learning_rate=learning_rate,
212 | batch_size=batch_size,
213 | output_size=output_size,
214 | output_function=None,
215 | l2_regularization_scale=l2_regularization_scale,
216 | dropout_probability=dropout_probability,
217 | scaler=DAE_model.scaler,
218 | dense_nodes=dense_nodes,
219 | activation_function=dae_features.activation_function,
220 | )
221 |
222 | DNN_model = DNN(model_features_DNN)
223 |
224 | # Do a forward pass to initialize weights
225 | dummy_data = np.zeros(1024)
226 | _ = DNN_model.predict_class([dummy_data])
227 |
228 | for i in range(len(dae_features.dense_nodes_encoder)):
229 | DNN_model.layers[i].set_weights(DAE_model.layers[i].get_weights())
230 |
231 | return DNN_model, model_features_DNN
232 |
--------------------------------------------------------------------------------
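
A minimal sketch of the round trip these helpers support; the file names are placeholders, and the corresponding save calls are assumed to have happened in a training script elsewhere:

```python
# Sketch: rebuild a trained CNN from its pickled features and saved weights.
import numpy as np
from annsa.model_classes import CNN1D
from annsa.load_pretrained_network import load_trained_model

cnn_model, cnn_features = load_trained_model(
    CNN1D,
    features_filename='final-models/cnn_features.pkl',  # placeholder path
    weights_filename='final-models/cnn_weights.h5')     # placeholder path

prediction = cnn_model.predict_class([np.zeros(1024)])
```
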
/examples/source-interdiction/training-notebooks/cae-manual-easy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "GPU_device_id = str(5)\n",
10 | "model_id_save_as = 'caepretrain-easy-final'\n",
11 | "architecture_id = '../hyperparameter_search/hyperparameter-search-results/CNN-kfoldseasy-final-1-reluupdate_33'\n",
12 | "model_class_id = 'CAE'\n",
13 | "training_dataset_id = '../dataset_generation/hyperparametersearch_dataset_200keV_easy_log10time_1000.npy'\n",
14 | "difficulty_setting = 'easy'\n",
15 | "\n",
16 | "earlystop_patience = 10\n",
17 | "num_epochs = 2000"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import matplotlib.pyplot as plt\n",
27 | "import os\n",
28 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n",
29 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = GPU_device_id\n",
30 | "\n",
31 | "\n",
32 | "from sklearn.pipeline import make_pipeline\n",
33 | "from sklearn.preprocessing import FunctionTransformer, LabelBinarizer\n",
34 | "from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit\n",
35 | "import tensorflow as tf\n",
36 | "import pickle\n",
37 | "import numpy as np\n",
38 | "import pandas as pd\n",
39 | "from random import choice\n",
40 | "\n",
41 | "from numpy.random import seed\n",
42 | "seed(5)\n",
43 | "from tensorflow import set_random_seed\n",
44 | "set_random_seed(5)\n",
45 | "\n"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "#### Import model, training function "
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 3,
58 | "metadata": {},
59 | "outputs": [
60 | {
61 | "name": "stderr",
62 | "output_type": "stream",
63 | "text": [
64 | "Using TensorFlow backend.\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "from annsa.model_classes import compile_model, f1, build_cae_model\n",
70 | "from annsa.load_dataset import load_easy, load_full, dataset_to_spectrakeys\n",
71 | "from annsa.load_pretrained_network import load_features"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "## Training Data Construction"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 4,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "training_dataset = np.load(training_dataset_id)\n",
88 | "training_source_spectra, training_background_spectra, training_keys = dataset_to_spectrakeys(training_dataset,\n",
89 | " sampled=False,\n",
90 | " separate_background=True)"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "## Load Model"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 5,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "model_features = load_features(architecture_id)\n",
107 | "model_features.loss = tf.keras.losses.mean_squared_error\n",
108 | "model_features.optimizer = tf.keras.optimizers.Adam\n",
109 | "model_features.metrics = ['mse']\n",
110 | "model_features.input_dim = 1024"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": 6,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "model_features.pool_sizes = model_features.pool_size\n",
120 | "\n",
121 | "cae_features = model_features.to_cae_model_features()"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Train network"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "# Scale input data"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 8,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "name": "stderr",
145 | "output_type": "stream",
146 | "text": [
147 | "/home/ubuntu/anaconda3/envs/tensorflow_p36_update/lib/python3.6/site-packages/sklearn/preprocessing/_function_transformer.py:161: RuntimeWarning: invalid value encountered in sqrt\n",
148 | " **(kw_args if kw_args else {}))\n"
149 | ]
150 | }
151 | ],
152 | "source": [
153 | "training_input = np.random.poisson(training_source_spectra+training_background_spectra)\n",
154 | "training_output = training_source_spectra\n",
155 | "\n",
156 | "training_input = cae_features.scaler.transform(training_input)\n",
157 | "training_output = cae_features.scaler.transform(training_output)\n"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 11,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "earlystop_callback = tf.keras.callbacks.EarlyStopping(\n",
167 | " monitor='val_mean_squared_error',\n",
168 | " patience=earlystop_patience,\n",
169 | " mode='min',\n",
170 | " restore_best_weights=True)\n",
171 | "\n",
172 | "csv_logger = tf.keras.callbacks.CSVLogger('./final-models-keras/'+model_id_save_as+'.log')"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 12,
178 | "metadata": {
179 | "scrolled": false
180 | },
181 | "outputs": [
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "Train on 27000 samples, validate on 3000 samples\n",
187 | "Epoch 1/500\n",
188 | "27000/27000 [==============================] - 34s 1ms/sample - loss: 0.0042 - mean_squared_error: 0.0042 - val_loss: 0.0021 - val_mean_squared_error: 0.0021\n",
189 | "Epoch 2/500\n",
190 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0034 - mean_squared_error: 0.0034 - val_loss: 0.0019 - val_mean_squared_error: 0.0019\n",
191 | "Epoch 3/500\n",
192 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0031 - mean_squared_error: 0.0031 - val_loss: 0.0023 - val_mean_squared_error: 0.0023\n",
193 | "Epoch 4/500\n",
194 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0027 - mean_squared_error: 0.0027 - val_loss: 0.0019 - val_mean_squared_error: 0.0019\n",
195 | "Epoch 5/500\n",
196 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0019 - mean_squared_error: 0.0019 - val_loss: 0.0027 - val_mean_squared_error: 0.0027\n",
197 | "Epoch 6/500\n",
198 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0016 - mean_squared_error: 0.0016 - val_loss: 0.0024 - val_mean_squared_error: 0.0024\n",
199 | "Epoch 7/500\n",
200 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0014 - mean_squared_error: 0.0014 - val_loss: 0.0028 - val_mean_squared_error: 0.0028\n",
201 | "Epoch 8/500\n",
202 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0013 - mean_squared_error: 0.0013 - val_loss: 0.0029 - val_mean_squared_error: 0.0029\n",
203 | "Epoch 9/500\n",
204 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0012 - mean_squared_error: 0.0012 - val_loss: 0.0029 - val_mean_squared_error: 0.0029\n",
205 | "Epoch 10/500\n",
206 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 0.0011 - mean_squared_error: 0.0011 - val_loss: 0.0031 - val_mean_squared_error: 0.0031\n",
207 | "Epoch 11/500\n",
208 | "27000/27000 [==============================] - 31s 1ms/sample - loss: 0.0010 - mean_squared_error: 0.0010 - val_loss: 0.0023 - val_mean_squared_error: 0.0023\n",
209 | "Epoch 12/500\n",
210 | "27000/27000 [==============================] - 32s 1ms/sample - loss: 9.8017e-04 - mean_squared_error: 9.8018e-04 - val_loss: 0.0021 - val_mean_squared_error: 0.0021\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "mlb=LabelBinarizer()\n",
216 | "model = compile_model(\n",
217 | " build_cae_model,\n",
218 | " cae_features)\n",
219 | "\n",
220 | "\n",
221 | "output = model.fit(\n",
222 | " x=training_input,\n",
223 | " y=training_output,\n",
224 | " batch_size=model_features.batch_size,\n",
225 | " validation_split=0.1,\n",
226 | " epochs=500,\n",
227 | " verbose=1,\n",
228 | " shuffle=True,\n",
229 | " callbacks=[earlystop_callback, ],\n",
230 | ")\n"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 23,
236 | "metadata": {},
237 | "outputs": [],
238 | "source": [
239 | "model.save('./final-models-keras/'+model_id_save_as+'.hdf5')"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": null,
245 | "metadata": {},
246 | "outputs": [],
247 | "source": []
248 | }
249 | ],
250 | "metadata": {
251 | "kernelspec": {
252 | "display_name": "Environment (conda_tensorflow_p36_update)",
253 | "language": "python",
254 | "name": "conda_tensorflow_p36_update"
255 | },
256 | "language_info": {
257 | "codemirror_mode": {
258 | "name": "ipython",
259 | "version": 3
260 | },
261 | "file_extension": ".py",
262 | "mimetype": "text/x-python",
263 | "name": "python",
264 | "nbconvert_exporter": "python",
265 | "pygments_lexer": "ipython3",
266 | "version": "3.6.7"
267 | }
268 | },
269 | "nbformat": 4,
270 | "nbformat_minor": 1
271 | }
272 |
--------------------------------------------------------------------------------
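
In the notebook above the convolutional autoencoder is trained as a denoiser: the network input is a Poisson-resampled source-plus-background spectrum, the reconstruction target is the noise-free source term, and both pass through the same scaler. The following is a rough standalone sketch of that input/output construction; the uniform templates and the square-root scaler are stand-ins for annsa's simulated dataset and fitted scaler pipeline.

import numpy as np

np.random.seed(5)

# Illustrative stand-ins: 1000 source and background templates, 1024 channels each
source = np.random.uniform(0.0, 10.0, size=(1000, 1024))
background = np.random.uniform(0.0, 5.0, size=(1000, 1024))

# Noisy network input: Poisson counts drawn around the source + background expectation
training_input = np.random.poisson(source + background).astype(float)

# Denoising target: the clean source contribution only
training_output = source.copy()

# Stand-in for cae_features.scaler; a square-root transform is used purely for illustration
training_input = np.sqrt(training_input)
training_output = np.sqrt(training_output)
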
/examples/source-interdiction/results-notebooks/aux_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from glob import glob
3 | from sklearn.metrics import f1_score
4 | import matplotlib.pyplot as plt
5 | import os
6 | from sklearn.preprocessing import LabelBinarizer
7 | import matplotlib.colors
8 | from annsa import read_spectrum
9 |
10 | def ensemble_predictions(members, scaler, testX):
11 | # scale inputs
12 | testX_scaled = scaler(testX)
13 | # make predictions
14 | yhats = [model.predict(testX_scaled) for model in members]
15 | yhats = np.array(yhats)
16 | # sum across ensemble members
17 | summed = np.sum(yhats, axis=0)
18 | # argmax across classes
19 | result = np.argmax(summed, axis=1)
20 | return result
21 |
22 |
23 | def ensemble_probas(members, scaler, testX):
24 | # scale inputs
25 | testX_scaled = scaler(testX)
26 | # make predictions
27 | yhats = [model.predict_proba(testX_scaled) for model in members]
28 | yhats = np.array(yhats)
29 |     # average across samples (yhats has shape (n_members, n_samples, n_classes))
30 |     member_average = np.mean(yhats, axis=1)
31 |     # average across ensemble members
32 |     average = np.mean(member_average, axis=0)
33 | return average
34 |
35 |
36 | def shielding_predictions(dataframe_data,
37 | all_models,
38 | scalers,
39 | isotope,
40 | shielding_material,
41 | shielding_amounts,
42 | shielding_strings,
43 | spectra_path,
44 | spectra_date,):
45 |
46 | for shielding_index, shielding_amount in enumerate(shielding_amounts):
47 | print(shielding_strings[shielding_index], end='\r')
48 | spectra = []
49 | for path in glob(os.path.join('..',
50 | 'training_testing_data',
51 | spectra_date,
52 | isotope,
53 | spectra_path+isotope+shielding_amount+'*.Spe')):
54 | spectra.append(read_spectrum(path))
55 | spectra_cumsum = np.cumsum(spectra, axis=0)
56 |
57 | for model_class in ['dnn', 'cnn', 'daednn', 'caednn',]:
58 | for mode in ['-easy', '-full']:
59 | all_output_probs = []
60 | model_id = model_class + mode
61 | for i in range(30):
62 | bagged_probs = ensemble_probas(all_models[model_id], scalers[model_id], [spectra_cumsum[i]])
63 | all_output_probs.append(bagged_probs)
64 |
65 | dataframe_data.append([model_id,
66 | shielding_material,
67 | shielding_strings[shielding_index],
68 | isotope,
69 | spectra_cumsum,
70 | all_output_probs])
71 |
72 | return dataframe_data
73 |
74 |
75 | def plot_measured_source_shielded_results(results_dataframe,
76 | isotope,
77 | gadras_isotope,
78 | shielding_material,
79 | shielding_strings,
80 | setting,):
81 |
82 | plt.rcParams.update({'font.size': 20})
83 |     gadras_index = np.argwhere(mlb.classes_ == gadras_isotope).flatten()[0]  # assumes 'mlb' (a fitted LabelBinarizer) and 'linestyles' are set at module scope by the caller
84 | plt.figure(figsize=(10,5))
85 | for option_index, shielding_string in enumerate(shielding_strings):
86 | for model_idindex, model_id in enumerate(['caednn-'+setting,
87 | 'daednn-'+setting,
88 | 'dnn-'+setting,
89 | 'cnn-'+setting,]):
90 | results_dataframe_tmp = results_dataframe[results_dataframe['model_id'] == model_id]
91 | results_dataframe_tmp = results_dataframe_tmp[results_dataframe_tmp['isotope'] == isotope]
92 | results_dataframe_tmp = results_dataframe_tmp[results_dataframe_tmp['shielding_material'] == shielding_material]
93 | results_dataframe_tmp = results_dataframe_tmp[results_dataframe_tmp['shielding_strings'] == shielding_string]
94 |
95 | plt.plot(np.linspace(10,300,30),
96 | 100*np.array(results_dataframe_tmp['posterior_prob'].values[0]).reshape(30,30)[:,gadras_index],
97 | linewidth=2.5,
98 | linestyle=linestyles[option_index],
99 | color=c1.colors[model_idindex],)
100 | plt.xlabel('Integration Time (seconds)')
101 | plt.ylabel('Posterior Probability')
102 | plt.ylim([0,110])
103 |
104 |
105 | def f1_score_bagged(model_ids,
106 | all_models,
107 | scalers,
108 | testing_spectra,
109 | testing_keys_binarized,):
110 | '''
111 | Bags a specific model's f1_score from a dictionary of models.
112 |
113 |     Inputs
114 |     model_ids : string, list
115 |         List of model_id strings to bag, e.g. 'dnn-full' or 'cae-easy'.
116 |     all_models : dict
117 |         Dictionary containing all models, keyed by model_id
118 |     scalers : dict
119 |         Dictionary of scaler callables, keyed by model_id
120 |     testing_spectra : numpy array
121 |         Array containing multiple gamma-ray spectra
122 |     testing_keys_binarized : numpy array
123 |         Array of one-hot encoded (binarized) keys corresponding to testing_spectra
124 |
125 | Outputs
126 | f1_scores : dict, str, float
127 | Dictionary indexed by model_id in model_ids, contains f1 score for
128 | that model
129 | '''
130 | f1_scores = {}
131 |
132 | for model_id in model_ids:
133 |
134 | predictions = ensemble_predictions(all_models[model_id], scalers[model_id], testing_spectra)
135 | true_labels = testing_keys_binarized.argmax(axis=1)
136 |
137 | f1_scores[model_id] = f1_score(true_labels,
138 | predictions,
139 | average='micro')
140 |
141 | return f1_scores
142 |
143 |
144 | def plot_f1_scores_bagged(dataframe,
145 | model_ids,
146 | all_models,
147 | scalers,
148 | indep_variable,
149 | plot_label,
150 | linestyle,
151 | color,
152 | **kwargs,
153 | ):
154 | '''
155 |     Plots the F1 scores for models in model_ids given some dataframe of spectra.
156 |
157 |     Inputs
158 |     dataframe : Pandas DataFrame
159 |         DataFrame containing spectra, isotope names, and parameter options.
160 |     model_ids : string, list
161 |         List of model_id strings to bag, e.g. 'dnn-full' or 'cae-easy'.
162 |     all_models : dict
163 |         Dictionary containing all models, keyed by model_id
164 |     scalers : dict
165 |         Dictionary of scaler callables, keyed by model_id
166 |     indep_variable : str
167 |         Key of the data column holding the independent variable (x-axis).
168 |     plot_label, linestyle, color
169 |         Label (falls back to model_id if empty), line style, and line color.
170 |     kwargs : column=value pairs used to filter the dataframe (parameter choices)
171 |
172 | Outputs
173 | None
174 | '''
175 | mlb = LabelBinarizer()
176 | keys = list(set(dataframe['isotope']))
177 | mlb.fit(keys)
178 |
179 | plt.rcParams.update({'font.size': 20})
180 | f1_scores_models = {}
181 | for key, value in kwargs.items():
182 | dataframe = dataframe[dataframe[key] == value]
183 | for model_id in model_ids:
184 | tmp_f1_scores = []
185 | for var in sorted(set(dataframe[indep_variable])):
186 |
187 | subset = dataframe[indep_variable] == var
188 | tmp_f1_score = f1_score_bagged([model_id],
189 | all_models,
190 | scalers,
191 | np.vstack(dataframe[subset]['spectrum'].to_numpy()),
192 | mlb.transform(dataframe['isotope'])[subset],)
193 | tmp_f1_scores.append(tmp_f1_score[model_id])
194 |
195 | # f1_scores_models[model_id] = tmp_f1_scores
196 | if plot_label:
197 | plt.plot(tmp_f1_scores,
198 | label=plot_label,
199 | linestyle=linestyle,
200 | linewidth=2.5,
201 | color=color,)
202 | else:
203 | plt.plot(tmp_f1_scores,
204 | label=model_id,
205 | linestyle=linestyle,
206 | linewidth=2.5,
207 | color=color,)
208 | # plt.legend()
209 | plt.xlabel(indep_variable)
210 | plt.ylabel('F1 Score')
211 | plt.ylim([0, 1])
212 | plt.xticks(
213 | range(len(sorted(set(dataframe[indep_variable])))),
214 | [round(var, 2) for var in sorted(set(dataframe[indep_variable]))])
215 |
216 |
217 | def categorical_cmap(nc, nsc, cmap="tab10", continuous=False):
218 | if nc > plt.get_cmap(cmap).N:
219 | raise ValueError("Too many categories for colormap.")
220 | if continuous:
221 | ccolors = plt.get_cmap(cmap)(np.linspace(0,1,nc))
222 | else:
223 | ccolors = plt.get_cmap(cmap)(np.arange(nc, dtype=int))
224 | cols = np.zeros((nc*nsc, 3))
225 | for i, c in enumerate(ccolors):
226 | chsv = matplotlib.colors.rgb_to_hsv(c[:3])
227 | arhsv = np.tile(chsv,nsc).reshape(nsc,3)
228 | arhsv[:,1] = np.linspace(chsv[1],0.25,nsc)
229 | arhsv[:,2] = np.linspace(chsv[2],1,nsc)
230 | rgb = matplotlib.colors.hsv_to_rgb(arhsv)
231 | cols[i*nsc:(i+1)*nsc,:] = rgb
232 | cmap = matplotlib.colors.ListedColormap(cols)
233 | return cmap
234 |
235 | c1 = categorical_cmap(5,1, cmap="tab10")
236 | plt.scatter(np.arange(5*1), np.ones(5*1)+1, c=np.arange(5*1), s=180, cmap=c1)  # palette preview; note this draws at import time
237 |
238 |
239 |
240 |
241 |
--------------------------------------------------------------------------------
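
The helpers in aux_functions.py expect a dictionary of model lists keyed by model_id together with a matching dictionary of scaler callables. The snippet below is only a rough usage sketch: the three randomly initialized Keras models, the identity-style scaler, and the synthetic spectra are placeholders for the trained ensembles and data used in the results notebooks, and it assumes aux_functions.py and an installed annsa are importable.

import numpy as np
import tensorflow as tf
from sklearn.preprocessing import LabelBinarizer
from aux_functions import ensemble_predictions, f1_score_bagged

def make_dummy_model():
    # Tiny stand-in classifier; real use would load trained models from disk
    return tf.keras.Sequential([
        tf.keras.layers.Dense(8, activation='relu', input_shape=(1024,)),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])

all_models = {'dnn-easy': [make_dummy_model() for _ in range(3)]}
scalers = {'dnn-easy': lambda x: np.asarray(x, dtype=np.float32)}

# Synthetic evaluation data: 10 Poisson spectra labelled with 3 isotope names
testing_spectra = np.random.poisson(5.0, size=(10, 1024)).astype(np.float32)
labels = ['Cs137', 'Co60', 'Ba133'] * 3 + ['Cs137']
testing_keys_binarized = LabelBinarizer().fit_transform(labels)

predictions = ensemble_predictions(all_models['dnn-easy'],
                                    scalers['dnn-easy'],
                                    testing_spectra)
scores = f1_score_bagged(['dnn-easy'], all_models, scalers,
                         testing_spectra, testing_keys_binarized)
print(predictions, scores['dnn-easy'])
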
/examples/source-interdiction/training-notebooks/cae-manual-full.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "GPU_device_id = str(4)\n",
10 | "model_id_save_as = 'caepretrain-full-final'\n",
11 | "architecture_id = '../hyperparameter_search/hyperparameter-search-results/CNN-kfoldsfull-final-2-reluupdate_28'\n",
12 | "model_class_id = 'CAE'\n",
13 | "training_dataset_id = '../dataset_generation/hyperparametersearch_dataset_200keV_full_log10time_1000.npy'\n",
14 | "difficulty_setting = 'full'\n",
15 | "\n",
16 | "earlystop_patience = 10\n",
17 | "num_epochs = 2000"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 4,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import matplotlib.pyplot as plt\n",
27 | "import os\n",
28 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\" # see issue #152\n",
29 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = GPU_device_id\n",
30 | "\n",
31 | "\n",
32 | "from sklearn.pipeline import make_pipeline\n",
33 | "from sklearn.preprocessing import FunctionTransformer, LabelBinarizer\n",
34 | "from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit\n",
35 | "import tensorflow as tf\n",
36 | "import pickle\n",
37 | "import numpy as np\n",
38 | "import pandas as pd\n",
39 | "from random import choice\n",
40 | "\n",
41 | "from numpy.random import seed\n",
42 | "seed(5)\n",
43 | "from tensorflow import set_random_seed\n",
44 | "set_random_seed(5)\n",
45 | "\n"
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "metadata": {},
51 | "source": [
52 | "#### Import model, training function "
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 5,
58 | "metadata": {},
59 | "outputs": [
60 | {
61 | "name": "stderr",
62 | "output_type": "stream",
63 | "text": [
64 | "Using TensorFlow backend.\n"
65 | ]
66 | }
67 | ],
68 | "source": [
69 | "from annsa.model_classes import compile_model, f1, build_cae_model, mean_normalized_kl_divergence\n",
70 | "from annsa.load_dataset import load_easy, load_full, dataset_to_spectrakeys\n",
71 | "from annsa.load_pretrained_network import load_features"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "## Training Data Construction"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 6,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "training_dataset = np.load(training_dataset_id, allow_pickle=True)\n",
88 | "training_source_spectra, training_background_spectra, training_keys = dataset_to_spectrakeys(training_dataset,\n",
89 | " sampled=False,\n",
90 | " separate_background=True)"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "## Load Model"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 8,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "model_features = load_features(architecture_id)\n",
107 | "model_features.loss = tf.keras.losses.mean_squared_error\n",
108 | "model_features.optimizer = tf.keras.optimizers.Adam\n",
109 | "model_features.input_dim = 1024"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 9,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "model_features.pool_sizes = model_features.pool_size\n",
119 | "\n",
120 | "cae_features = model_features.to_cae_model_features()\n",
121 | "cae_features.metrics = ['mse']"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "metadata": {},
127 | "source": [
128 | "## Train network"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "# Scale input data"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": 11,
141 | "metadata": {},
142 | "outputs": [
143 | {
144 | "name": "stderr",
145 | "output_type": "stream",
146 | "text": [
147 | "/home/ubuntu/anaconda3/envs/tensorflow_p36_update/lib/python3.6/site-packages/sklearn/preprocessing/_function_transformer.py:161: RuntimeWarning: invalid value encountered in sqrt\n",
148 | " **(kw_args if kw_args else {}))\n"
149 | ]
150 | }
151 | ],
152 | "source": [
153 | "training_input = np.random.poisson(training_source_spectra+training_background_spectra)\n",
154 | "training_output = training_source_spectra\n",
155 | "\n",
156 | "training_input = cae_features.scaler.transform(training_input)\n",
157 | "training_output = cae_features.scaler.transform(training_output)\n"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 15,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "earlystop_callback = tf.keras.callbacks.EarlyStopping(\n",
167 | " monitor='val_mean_squared_error',\n",
168 | " patience=earlystop_patience,\n",
169 | " mode='min',\n",
170 | " restore_best_weights=True)\n",
171 | "\n",
172 | "csv_logger = tf.keras.callbacks.CSVLogger('./final-models-keras/'+model_id_save_as+'.log')"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 16,
178 | "metadata": {
179 | "scrolled": false
180 | },
181 | "outputs": [
182 | {
183 | "name": "stdout",
184 | "output_type": "stream",
185 | "text": [
186 | "Train on 27000 samples, validate on 3000 samples\n",
187 | "Epoch 1/500\n",
188 | "27000/27000 [==============================] - 19s 707us/sample - loss: 0.0067 - mean_squared_error: 0.0067 - val_loss: 0.0062 - val_mean_squared_error: 0.0062\n",
189 | "Epoch 2/500\n",
190 | "27000/27000 [==============================] - 17s 634us/sample - loss: 0.0046 - mean_squared_error: 0.0046 - val_loss: 0.0062 - val_mean_squared_error: 0.0062\n",
191 | "Epoch 3/500\n",
192 | "27000/27000 [==============================] - 17s 634us/sample - loss: 0.0043 - mean_squared_error: 0.0043 - val_loss: 0.0051 - val_mean_squared_error: 0.0051\n",
193 | "Epoch 4/500\n",
194 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0041 - mean_squared_error: 0.0041 - val_loss: 0.0044 - val_mean_squared_error: 0.0044\n",
195 | "Epoch 5/500\n",
196 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0040 - mean_squared_error: 0.0040 - val_loss: 0.0052 - val_mean_squared_error: 0.0052\n",
197 | "Epoch 6/500\n",
198 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0039 - mean_squared_error: 0.0039 - val_loss: 0.0038 - val_mean_squared_error: 0.0038\n",
199 | "Epoch 7/500\n",
200 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0039 - mean_squared_error: 0.0039 - val_loss: 0.0042 - val_mean_squared_error: 0.0042\n",
201 | "Epoch 8/500\n",
202 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0038 - mean_squared_error: 0.0038 - val_loss: 0.0048 - val_mean_squared_error: 0.0048\n",
203 | "Epoch 9/500\n",
204 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0037 - mean_squared_error: 0.0037 - val_loss: 0.0041 - val_mean_squared_error: 0.0041\n",
205 | "Epoch 10/500\n",
206 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0036 - mean_squared_error: 0.0036 - val_loss: 0.0049 - val_mean_squared_error: 0.0049\n",
207 | "Epoch 11/500\n",
208 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0035 - mean_squared_error: 0.0035 - val_loss: 0.0042 - val_mean_squared_error: 0.0042\n",
209 | "Epoch 12/500\n",
210 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0033 - mean_squared_error: 0.0033 - val_loss: 0.0041 - val_mean_squared_error: 0.0041\n",
211 | "Epoch 13/500\n",
212 | "27000/27000 [==============================] - 17s 636us/sample - loss: 0.0032 - mean_squared_error: 0.0032 - val_loss: 0.0043 - val_mean_squared_error: 0.0043\n",
213 | "Epoch 14/500\n",
214 | "27000/27000 [==============================] - 17s 636us/sample - loss: 0.0031 - mean_squared_error: 0.0031 - val_loss: 0.0048 - val_mean_squared_error: 0.0048\n",
215 | "Epoch 15/500\n",
216 | "27000/27000 [==============================] - 17s 635us/sample - loss: 0.0030 - mean_squared_error: 0.0030 - val_loss: 0.0043 - val_mean_squared_error: 0.0043\n",
217 | "Epoch 16/500\n",
218 | "27000/27000 [==============================] - 17s 645us/sample - loss: 0.0030 - mean_squared_error: 0.0030 - val_loss: 0.0041 - val_mean_squared_error: 0.0041\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "mlb=LabelBinarizer()\n",
224 | "model = compile_model(\n",
225 | " build_cae_model,\n",
226 | " cae_features)\n",
227 | "\n",
228 | "output = model.fit(\n",
229 | " x=training_input,\n",
230 | " y=training_output,\n",
231 | " batch_size=model_features.batch_size,\n",
232 | " validation_split=0.1,\n",
233 | " epochs=500,\n",
234 | " verbose=1,\n",
235 | " shuffle=True,\n",
236 | " callbacks=[earlystop_callback, ],\n",
237 | ")\n"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": 17,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "model.save('./final-models-keras/'+model_id_save_as+'.hdf5')"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": []
255 | }
256 | ],
257 | "metadata": {
258 | "kernelspec": {
259 | "display_name": "Environment (conda_tensorflow_p36_update)",
260 | "language": "python",
261 | "name": "conda_tensorflow_p36_update"
262 | },
263 | "language_info": {
264 | "codemirror_mode": {
265 | "name": "ipython",
266 | "version": 3
267 | },
268 | "file_extension": ".py",
269 | "mimetype": "text/x-python",
270 | "name": "python",
271 | "nbconvert_exporter": "python",
272 | "pygments_lexer": "ipython3",
273 | "version": "3.6.7"
274 | }
275 | },
276 | "nbformat": 4,
277 | "nbformat_minor": 1
278 | }
279 |
--------------------------------------------------------------------------------
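
Both CAE notebooks finish by saving the trained autoencoder as an HDF5 file under final-models-keras/. A hedged sketch of reloading such a file with the standard Keras loader follows; the path simply mirrors the notebook's model_id_save_as naming, and no custom_objects are needed here because only the built-in mean-squared-error loss and metric were used.

import tensorflow as tf

# Reload the autoencoder saved at the end of the notebook above
cae = tf.keras.models.load_model('./final-models-keras/caepretrain-full-final.hdf5')
cae.summary()

# The reloaded model can then denoise scaled spectra, e.g.:
# denoised = cae.predict(scaled_spectra)
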
/annsa/tests/data_folder/gadras_template.spe:
--------------------------------------------------------------------------------
1 | $SPEC_ID:
2 | UXRAY,1uC
3 | $MEAS_TIM:
4 | 3599.79 3600.00
5 | $DATE_MEA:
6 | 8/15/2019 5:55:17 PM
7 | $DATA:
8 | 0 1023
9 | -0.000
10 | 0.000
11 | 0.000
12 | -0.000
13 | -0.000
14 | 0.000
15 | -0.000
16 | 0.000
17 | 0.314
18 | 14.109
19 | 56.162
20 | 68.364
21 | 77.340
22 | 95.332
23 | 118.244
24 | 143.160
25 | 172.189
26 | 238.420
27 | 413.388
28 | 701.352
29 | 926.234
30 | 962.682
31 | 922.008
32 | 866.506
33 | 781.751
34 | 881.396
35 | 1526.520
36 | 2543.326
37 | 2910.432
38 | 2130.944
39 | 1139.935
40 | 797.316
41 | 818.621
42 | 677.317
43 | 371.459
44 | 133.169
45 | 31.347
46 | 4.806
47 | 0.458
48 | 0.016
49 | -0.002
50 | 0.000
51 | -0.000
52 | 0.000
53 | 0
54 | 0
55 | 0
56 | 0
57 | 0
58 | 0
59 | 0
60 | 0
61 | 0
62 | 0
63 | 0
64 | 0
65 | 0
66 | 0
67 | 0
68 | 0
69 | 0
70 | 0
71 | 0
72 | 0
73 | 0
74 | 0
75 | 0
76 | 0
77 | 0
78 | 0
79 | 0
80 | 0
81 | 0
82 | 0
83 | 0
84 | 0
85 | 0
86 | 0
87 | 0
88 | 0
89 | 0
90 | 0
91 | 0
92 | 0
93 | 0
94 | 0
95 | 0
96 | 0
97 | 0
98 | 0
99 | 0
100 | 0
101 | 0
102 | 0
103 | 0
104 | 0
105 | 0
106 | 0
107 | 0
108 | 0
109 | 0
110 | 0
111 | 0
112 | 0
113 | 0
114 | 0
115 | 0
116 | 0
117 | 0
118 | 0
119 | 0
120 | 0
121 | 0
122 | 0
123 | 0
124 | 0
125 | 0
126 | 0
127 | 0
128 | 0
129 | 0
130 | 0
131 | 0
132 | 0
133 | 0
134 | 0
135 | 0
136 | 0
137 | 0
138 | 0
139 | 0
140 | 0
141 | 0
142 | 0
143 | 0
144 | 0
145 | 0
146 | 0
147 | 0
148 | 0
149 | 0
150 | 0
151 | 0
152 | 0
153 | 0
154 | 0
155 | 0
156 | 0
157 | 0
158 | 0
159 | 0
160 | 0
161 | 0
162 | 0
163 | 0
164 | 0
165 | 0
166 | 0
167 | 0
168 | 0
169 | 0
170 | 0
171 | 0
172 | 0
173 | 0
174 | 0
175 | 0
176 | 0
177 | 0
178 | 0
179 | 0
180 | 0
181 | 0
182 | 0
183 | 0
184 | 0
185 | 0
186 | 0
187 | 0
188 | 0
189 | 0
190 | 0
191 | 0
192 | 0
193 | 0
194 | 0
195 | 0
196 | 0
197 | 0
198 | 0
199 | 0
200 | 0
201 | 0
202 | 0
203 | 0
204 | 0
205 | 0
206 | 0
207 | 0
208 | 0
209 | 0
210 | 0
211 | 0
212 | 0
213 | 0
214 | 0
215 | 0
216 | 0
217 | 0
218 | 0
219 | 0
220 | 0
221 | 0
222 | 0
223 | 0
224 | 0
225 | 0
226 | 0
227 | 0
228 | 0
229 | 0
230 | 0
231 | 0
232 | 0
233 | 0
234 | 0
235 | 0
236 | 0
237 | 0
238 | 0
239 | 0
240 | 0
241 | 0
242 | 0
243 | 0
244 | 0
245 | 0
246 | 0
247 | 0
248 | 0
249 | 0
250 | 0
251 | 0
252 | 0
253 | 0
254 | 0
255 | 0
256 | 0
257 | 0
258 | 0
259 | 0
260 | 0
261 | 0
262 | 0
263 | 0
264 | 0
265 | 0
266 | 0
267 | 0
268 | 0
269 | 0
270 | 0
271 | 0
272 | 0
273 | 0
274 | 0
275 | 0
276 | 0
277 | 0
278 | 0
279 | 0
280 | 0
281 | 0
282 | 0
283 | 0
284 | 0
285 | 0
286 | 0
287 | 0
288 | 0
289 | 0
290 | 0
291 | 0
292 | 0
293 | 0
294 | 0
295 | 0
296 | 0
297 | 0
298 | 0
299 | 0
300 | 0
301 | 0
302 | 0
303 | 0
304 | 0
305 | 0
306 | 0
307 | 0
308 | 0
309 | 0
310 | 0
311 | 0
312 | 0
313 | 0
314 | 0
315 | 0
316 | 0
317 | 0
318 | 0
319 | 0
320 | 0
321 | 0
322 | 0
323 | 0
324 | 0
325 | 0
326 | 0
327 | 0
328 | 0
329 | 0
330 | 0
331 | 0
332 | 0
333 | 0
334 | 0
335 | 0
336 | 0
337 | 0
338 | 0
339 | 0
340 | 0
341 | 0
342 | 0
343 | 0
344 | 0
345 | 0
346 | 0
347 | 0
348 | 0
349 | 0
350 | 0
351 | 0
352 | 0
353 | 0
354 | 0
355 | 0
356 | 0
357 | 0
358 | 0
359 | 0
360 | 0
361 | 0
362 | 0
363 | 0
364 | 0
365 | 0
366 | 0
367 | 0
368 | 0
369 | 0
370 | 0
371 | 0
372 | 0
373 | 0
374 | 0
375 | 0
376 | 0
377 | 0
378 | 0
379 | 0
380 | 0
381 | 0
382 | 0
383 | 0
384 | 0
385 | 0
386 | 0
387 | 0
388 | 0
389 | 0
390 | 0
391 | 0
392 | 0
393 | 0
394 | 0
395 | 0
396 | 0
397 | 0
398 | 0
399 | 0
400 | 0
401 | 0
402 | 0
403 | 0
404 | 0
405 | 0
406 | 0
407 | 0
408 | 0
409 | 0
410 | 0
411 | 0
412 | 0
413 | 0
414 | 0
415 | 0
416 | 0
417 | 0
418 | 0
419 | 0
420 | 0
421 | 0
422 | 0
423 | 0
424 | 0
425 | 0
426 | 0
427 | 0
428 | 0
429 | 0
430 | 0
431 | 0
432 | 0
433 | 0
434 | 0
435 | 0
436 | 0
437 | 0
438 | 0
439 | 0
440 | 0
441 | 0
442 | 0
443 | 0
444 | 0
445 | 0
446 | 0
447 | 0
448 | 0
449 | 0
450 | 0
451 | 0
452 | 0
453 | 0
454 | 0
455 | 0
456 | 0
457 | 0
458 | 0
459 | 0
460 | 0
461 | 0
462 | 0
463 | 0
464 | 0
465 | 0
466 | 0
467 | 0
468 | 0
469 | 0
470 | 0
471 | 0
472 | 0
473 | 0
474 | 0
475 | 0
476 | 0
477 | 0
478 | 0
479 | 0
480 | 0
481 | 0
482 | 0
483 | 0
484 | 0
485 | 0
486 | 0
487 | 0
488 | 0
489 | 0
490 | 0
491 | 0
492 | 0
493 | 0
494 | 0
495 | 0
496 | 0
497 | 0
498 | 0
499 | 0
500 | 0
501 | 0
502 | 0
503 | 0
504 | 0
505 | 0
506 | 0
507 | 0
508 | 0
509 | 0
510 | 0
511 | 0
512 | 0
513 | 0
514 | 0
515 | 0
516 | 0
517 | 0
518 | 0
519 | 0
520 | 0
521 | 0
522 | 0
523 | 0
524 | 0
525 | 0
526 | 0
527 | 0
528 | 0
529 | 0
530 | 0
531 | 0
532 | 0
533 | 0
534 | 0
535 | 0
536 | 0
537 | 0
538 | 0
539 | 0
540 | 0
541 | 0
542 | 0
543 | 0
544 | 0
545 | 0
546 | 0
547 | 0
548 | 0
549 | 0
550 | 0
551 | 0
552 | 0
553 | 0
554 | 0
555 | 0
556 | 0
557 | 0
558 | 0
559 | 0
560 | 0
561 | 0
562 | 0
563 | 0
564 | 0
565 | 0
566 | 0
567 | 0
568 | 0
569 | 0
570 | 0
571 | 0
572 | 0
573 | 0
574 | 0
575 | 0
576 | 0
577 | 0
578 | 0
579 | 0
580 | 0
581 | 0
582 | 0
583 | 0
584 | 0
585 | 0
586 | 0
587 | 0
588 | 0
589 | 0
590 | 0
591 | 0
592 | 0
593 | 0
594 | 0
595 | 0
596 | 0
597 | 0
598 | 0
599 | 0
600 | 0
601 | 0
602 | 0
603 | 0
604 | 0
605 | 0
606 | 0
607 | 0
608 | 0
609 | 0
610 | 0
611 | 0
612 | 0
613 | 0
614 | 0
615 | 0
616 | 0
617 | 0
618 | 0
619 | 0
620 | 0
621 | 0
622 | 0
623 | 0
624 | 0
625 | 0
626 | 0
627 | 0
628 | 0
629 | 0
630 | 0
631 | 0
632 | 0
633 | 0
634 | 0
635 | 0
636 | 0
637 | 0
638 | 0
639 | 0
640 | 0
641 | 0
642 | 0
643 | 0
644 | 0
645 | 0
646 | 0
647 | 0
648 | 0
649 | 0
650 | 0
651 | 0
652 | 0
653 | 0
654 | 0
655 | 0
656 | 0
657 | 0
658 | 0
659 | 0
660 | 0
661 | 0
662 | 0
663 | 0
664 | 0
665 | 0
666 | 0
667 | 0
668 | 0
669 | 0
670 | 0
671 | 0
672 | 0
673 | 0
674 | 0
675 | 0
676 | 0
677 | 0
678 | 0
679 | 0
680 | 0
681 | 0
682 | 0
683 | 0
684 | 0
685 | 0
686 | 0
687 | 0
688 | 0
689 | 0
690 | 0
691 | 0
692 | 0
693 | 0
694 | 0
695 | 0
696 | 0
697 | 0
698 | 0
699 | 0
700 | 0
701 | 0
702 | 0
703 | 0
704 | 0
705 | 0
706 | 0
707 | 0
708 | 0
709 | 0
710 | 0
711 | 0
712 | 0
713 | 0
714 | 0
715 | 0
716 | 0
717 | 0
718 | 0
719 | 0
720 | 0
721 | 0
722 | 0
723 | 0
724 | 0
725 | 0
726 | 0
727 | 0
728 | 0
729 | 0
730 | 0
731 | 0
732 | 0
733 | 0
734 | 0
735 | 0
736 | 0
737 | 0
738 | 0
739 | 0
740 | 0
741 | 0
742 | 0
743 | 0
744 | 0
745 | 0
746 | 0
747 | 0
748 | 0
749 | 0
750 | 0
751 | 0
752 | 0
753 | 0
754 | 0
755 | 0
756 | 0
757 | 0
758 | 0
759 | 0
760 | 0
761 | 0
762 | 0
763 | 0
764 | 0
765 | 0
766 | 0
767 | 0
768 | 0
769 | 0
770 | 0
771 | 0
772 | 0
773 | 0
774 | 0
775 | 0
776 | 0
777 | 0
778 | 0
779 | 0
780 | 0
781 | 0
782 | 0
783 | 0
784 | 0
785 | 0
786 | 0
787 | 0
788 | 0
789 | 0
790 | 0
791 | 0
792 | 0
793 | 0
794 | 0
795 | 0
796 | 0
797 | 0
798 | 0
799 | 0
800 | 0
801 | 0
802 | 0
803 | 0
804 | 0
805 | 0
806 | 0
807 | 0
808 | 0
809 | 0
810 | 0
811 | 0
812 | 0
813 | 0
814 | 0
815 | 0
816 | 0
817 | 0
818 | 0
819 | 0
820 | 0
821 | 0
822 | 0
823 | 0
824 | 0
825 | 0
826 | 0
827 | 0
828 | 0
829 | 0
830 | 0
831 | 0
832 | 0
833 | 0
834 | 0
835 | 0
836 | 0
837 | 0
838 | 0
839 | 0
840 | 0
841 | 0
842 | 0
843 | 0
844 | 0
845 | 0
846 | 0
847 | 0
848 | 0
849 | 0
850 | 0
851 | 0
852 | 0
853 | 0
854 | 0
855 | 0
856 | 0
857 | 0
858 | 0
859 | 0
860 | 0
861 | 0
862 | 0
863 | 0
864 | 0
865 | 0
866 | 0
867 | 0
868 | 0
869 | 0
870 | 0
871 | 0
872 | 0
873 | 0
874 | 0
875 | 0
876 | 0
877 | 0
878 | 0
879 | 0
880 | 0
881 | 0
882 | 0
883 | 0
884 | 0
885 | 0
886 | 0
887 | 0
888 | 0
889 | 0
890 | 0
891 | 0
892 | 0
893 | 0
894 | 0
895 | 0
896 | 0
897 | 0
898 | 0
899 | 0
900 | 0
901 | 0
902 | 0
903 | 0
904 | 0
905 | 0
906 | 0
907 | 0
908 | 0
909 | 0
910 | 0
911 | 0
912 | 0
913 | 0
914 | 0
915 | 0
916 | 0
917 | 0
918 | 0
919 | 0
920 | 0
921 | 0
922 | 0
923 | 0
924 | 0
925 | 0
926 | 0
927 | 0
928 | 0
929 | 0
930 | 0
931 | 0
932 | 0
933 | 0
934 | 0
935 | 0
936 | 0
937 | 0
938 | 0
939 | 0
940 | 0
941 | 0
942 | 0
943 | 0
944 | 0
945 | 0
946 | 0
947 | 0
948 | 0
949 | 0
950 | 0
951 | 0
952 | 0
953 | 0
954 | 0
955 | 0
956 | 0
957 | 0
958 | 0
959 | 0
960 | 0
961 | 0
962 | 0
963 | 0
964 | 0
965 | 0
966 | 0
967 | 0
968 | 0
969 | 0
970 | 0
971 | 0
972 | 0
973 | 0
974 | 0
975 | 0
976 | 0
977 | 0
978 | 0
979 | 0
980 | 0
981 | 0
982 | 0
983 | 0
984 | 0
985 | 0
986 | 0
987 | 0
988 | 0
989 | 0
990 | 0
991 | 0
992 | 0
993 | 0
994 | 0
995 | 0
996 | 0
997 | 0
998 | 0
999 | 0
1000 | 0
1001 | 0
1002 | 0
1003 | 0
1004 | 0
1005 | 0
1006 | 0
1007 | 0
1008 | 0
1009 | 0
1010 | 0
1011 | 0
1012 | 0
1013 | 0
1014 | 0
1015 | 0
1016 | 0
1017 | 0
1018 | 0
1019 | 0
1020 | 0
1021 | 0
1022 | 0
1023 | 0
1024 | 0
1025 | 0
1026 | 0
1027 | 0
1028 | 0
1029 | 0
1030 | 0
1031 | 0
1032 | 0
1033 | $MCA_CAL:
1034 | 3
1035 | -1.7089844 3.4179688 0.00000000000000
1036 | $ENER_FIT:
1037 | -1.7089844 3.4179688 0.00000000000000
1038 | $MCA_CAL_DATA:
1039 | 3
1040 | 256 873
1041 | 512 1748
1042 | 768 2623
1043 | $CAMBIO:
1044 | Parent file specification: C:\GADRAS-DRF\Temp\uxray.pcf
1045 | Successful read of spectrum: Yes
1046 | Error message: None
1047 | Warning: None
1048 | Information: None
1049 | Neutrons: 0
1050 | Number of spectral records in the parent file: 1
1051 | Record number of this spectrum in the parent file: 1
1052 | Date and time of acquisition read successfully from the original file: Yes
1053 | Live time read successfully from the original file: Yes
1054 | True time read successfully from the original file: Yes
1055 | GADRAS Tag character is:
1056 | Version: 130122
1057 | Parent file type: Sandia GADRAS PCF
1058 |
--------------------------------------------------------------------------------
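
The .spe file above uses the ORTEC/Maestro-style ASCII layout: each $-prefixed block is a section, $DATA: is followed by a "first last" channel range and then one count per channel, and the calibration blocks ($MCA_CAL, $ENER_FIT) come at the end. annsa provides read_spectrum (imported in aux_functions.py) for reading these files; the function below is only an independent, minimal sketch of extracting the counts from such a file.

import numpy as np

def read_spe_counts(path):
    """Return the channel counts from the $DATA block of an ASCII .spe file."""
    with open(path) as f:
        lines = [line.strip() for line in f]
    start = lines.index('$DATA:')
    first_channel, last_channel = (int(v) for v in lines[start + 1].split())
    n_channels = last_channel - first_channel + 1
    counts = [float(v) for v in lines[start + 2:start + 2 + n_channels]]
    return np.array(counts)

# e.g. spectrum = read_spe_counts('annsa/tests/data_folder/gadras_template.spe')
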
/doc/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | #
4 | # annsa documentation build configuration file, created by
5 | # sphinx-quickstart on Tue Apr 14 10:29:06 2015.
6 | #
7 | # This file is execfile()d with the current directory set to its
8 | # containing dir.
9 | #
10 | # Note that not all possible configuration values are present in this
11 | # autogenerated file.
12 | #
13 | # All configuration values have a default; values that are commented out
14 | # serve to show the default.
15 |
16 | import sys
17 | import os
18 |
19 | # General information about the project.
20 | project = 'annsa'
21 | copyright = '2019, Advanced Reactors and Fuel Cycles group'
22 |
23 | currentdir = os.path.abspath(os.path.dirname(__file__))
24 | ver_file = os.path.join(currentdir, '..', project, 'version.py')
25 | with open(ver_file) as f:
26 | exec(f.read())
27 | source_version = __version__
28 |
29 | currentdir = os.path.abspath(os.path.dirname(__file__))
30 | sys.path.append(os.path.join(currentdir, 'tools'))
31 |
32 | # If extensions (or modules to document with autodoc) are in another directory,
33 | # add these directories to sys.path here. If the directory is relative to the
34 | # documentation root, use os.path.abspath to make it absolute, like shown here.
35 | sys.path.insert(0, os.path.abspath('../'))
36 |
37 | # -- General configuration ------------------------------------------------
38 |
39 | # If your documentation needs a minimal Sphinx version, state it here.
40 | needs_sphinx = '1.0'  # numpydoc requires sphinx >= 1.0
41 |
42 | # Add any Sphinx extension module names here, as strings. They can be
43 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
44 | # ones.
45 | sys.path.append(os.path.abspath('sphinxext'))
46 |
47 | extensions = ['sphinx.ext.autodoc',
48 | 'sphinx.ext.doctest',
49 | 'sphinx.ext.intersphinx',
50 | 'sphinx.ext.todo',
51 | 'sphinx.ext.coverage',
52 | 'sphinx.ext.ifconfig',
53 | 'sphinx.ext.autosummary',
54 | 'sphinx.ext.mathjax',
55 | 'math_dollar', # has to go before numpydoc
56 | 'numpydoc',
57 | 'github',
58 | 'sphinx_gallery.gen_gallery']
59 |
60 | # Add any paths that contain templates here, relative to this directory.
61 | templates_path = ['_templates']
62 |
63 | # The suffix of source filenames.
64 | source_suffix = '.rst'
65 |
66 | # The encoding of source files.
67 | # source_encoding = 'utf-8-sig'
68 |
69 | # The master toctree document.
70 | master_doc = 'index'
71 |
72 | # --- Sphinx Gallery ---
73 | sphinx_gallery_conf = {
74 | # path to your examples scripts
75 | 'examples_dirs': '../examples',
76 | # path where to save gallery generated examples
77 | 'gallery_dirs': 'auto_examples',
78 | # To auto-generate example sections in the API
79 | 'doc_module': ('annsa',),
80 | # Auto-generated mini-galleries go here
81 | 'backreferences_dir': 'gen_api'
82 | }
83 |
84 | # Automatically generate stub pages for API
85 | autosummary_generate = True
86 | autodoc_default_flags = ['members', 'inherited-members']
87 |
88 | # The version info for the project you're documenting, acts as replacement for
89 | # |version| and |release|, also used in various other places throughout the
90 | # built documents.
91 | #
92 | # The short X.Y version.
93 | version = '0.1'
94 | # The full version, including alpha/beta/rc tags.
95 | release = '0.1'
96 |
97 | # The language for content autogenerated by Sphinx. Refer to documentation
98 | # for a list of supported languages.
99 | #language = None
100 |
101 | # There are two options for replacing |today|: either, you set today to some
102 | # non-false value, then it is used:
103 | #today = ''
104 | # Else, today_fmt is used as the format for a strftime call.
105 | #today_fmt = '%B %d, %Y'
106 |
107 | # List of patterns, relative to source directory, that match files and
108 | # directories to ignore when looking for source files.
109 | exclude_patterns = ['_build']
110 |
111 | # The reST default role (used for this markup: `text`) to use for all
112 | # documents.
113 | #default_role = None
114 |
115 | # If true, '()' will be appended to :func: etc. cross-reference text.
116 | #add_function_parentheses = True
117 |
118 | # If true, the current module name will be prepended to all description
119 | # unit titles (such as .. function::).
120 | #add_module_names = True
121 |
122 | # If true, sectionauthor and moduleauthor directives will be shown in the
123 | # output. They are ignored by default.
124 | #show_authors = False
125 |
126 | # The name of the Pygments (syntax highlighting) style to use.
127 | pygments_style = 'sphinx'
128 |
129 | # A list of ignored prefixes for module index sorting.
130 | #modindex_common_prefix = []
131 |
132 | # If true, keep warnings as "system message" paragraphs in the built documents.
133 | #keep_warnings = False
134 |
135 |
136 | # -- Options for HTML output ----------------------------------------------
137 |
138 | # The theme to use for HTML and HTML Help pages. See the documentation for
139 | # a list of builtin themes.
140 | html_theme = 'alabaster'
141 |
142 | # Theme options are theme-specific and customize the look and feel of a theme
143 | # further. For a list of options available for each theme, see the
144 | # documentation.
145 | #html_theme_options = {}
146 |
147 | # Add any paths that contain custom themes here, relative to this directory.
148 | #html_theme_path = []
149 |
150 | # The name for this set of Sphinx documents. If None, it defaults to
151 | # " v documentation".
152 | #html_title = None
153 |
154 | # A shorter title for the navigation bar. Default is the same as html_title.
155 | #html_short_title = None
156 |
157 | # The name of an image file (relative to this directory) to place at the top
158 | # of the sidebar.
159 | html_logo = '_static/logo.png'
160 |
161 | # The name of an image file (within the static path) to use as favicon of the
162 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
163 | # pixels large.
164 | #html_favicon = None
165 |
166 | # Add any paths that contain custom static files (such as style sheets) here,
167 | # relative to this directory. They are copied after the builtin static files,
168 | # so a file named "default.css" will overwrite the builtin "default.css".
169 | html_static_path = ['_static']
170 |
171 | # Add any extra paths that contain custom files (such as robots.txt or
172 | # .htaccess) here, relative to this directory. These files are copied
173 | # directly to the root of the documentation.
174 | #html_extra_path = []
175 |
176 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
177 | # using the given strftime format.
178 | #html_last_updated_fmt = '%b %d, %Y'
179 |
180 | # If true, SmartyPants will be used to convert quotes and dashes to
181 | # typographically correct entities.
182 | #html_use_smartypants = True
183 |
184 | # Custom sidebar templates, maps document names to template names.
185 | html_sidebars = {'**': ['localtoc.html', 'searchbox.html']}
186 |
187 | # Additional templates that should be rendered to pages, maps page names to
188 | # template names.
189 | #html_additional_pages = {}
190 |
191 | # If false, no module index is generated.
192 | #html_domain_indices = True
193 | html_domain_indices = False
194 |
195 | # If false, no index is generated.
196 | #html_use_index = True
197 |
198 | # If true, the index is split into individual pages for each letter.
199 | #html_split_index = False
200 |
201 | # If true, links to the reST sources are added to the pages.
202 | #html_show_sourcelink = True
203 |
204 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
205 | #html_show_sphinx = True
206 |
207 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
208 | #html_show_copyright = True
209 |
210 | # If true, an OpenSearch description file will be output, and all pages will
211 | # contain a tag referring to it. The value of this option must be the
212 | # base URL from which the finished HTML is served.
213 | #html_use_opensearch = ''
214 |
215 | # This is the file name suffix for HTML files (e.g. ".xhtml").
216 | #html_file_suffix = None
217 |
218 | # Output file base name for HTML help builder.
219 | htmlhelp_basename = 'annsadoc'
220 |
221 |
222 | # -- Options for LaTeX output ---------------------------------------------
223 |
224 | latex_elements = {
225 | # The paper size ('letterpaper' or 'a4paper').
226 | #'papersize': 'letterpaper',
227 |
228 | # The font size ('10pt', '11pt' or '12pt').
229 | #'pointsize': '10pt',
230 |
231 | # Additional stuff for the LaTeX preamble.
232 | #'preamble': '',
233 | }
234 |
235 | # grouping the document tree into LaTeX files. List of tuples
236 | # (source start file, target name, title,
237 | # author, documentclass [howto, manual, or own class]).
238 | latex_documents = [
239 | ('index', 'annsa.tex', 'annsa Documentation',
240 | 'Advanced Reactors and Fuel Cycles group', 'manual'),
241 | ]
242 |
243 | # The name of an image file (relative to this directory) to place at the top of
244 | # the title page.
245 | #latex_logo = None
246 |
247 | # For "manual" documents, if this is true, then toplevel headings are parts,
248 | # not chapters.
249 | #latex_use_parts = False
250 |
251 | # If true, show page references after internal links.
252 | #latex_show_pagerefs = False
253 |
254 | # If true, show URL addresses after external links.
255 | #latex_show_urls = False
256 |
257 | # Documents to append as an appendix to all manuals.
258 | #latex_appendices = []
259 |
260 | # If false, no module index is generated.
261 | #latex_domain_indices = True
262 |
263 |
264 | # -- Options for manual page output ---------------------------------------
265 |
266 | # One entry per manual page. List of tuples
267 | # (source start file, name, description, authors, manual section).
268 | man_pages = [
269 | ('index', 'annsa', 'annsa Documentation',
270 | ['Advanced Reactors and Fuel Cycles group'], 1)
271 | ]
272 |
273 | # If true, show URL addresses after external links.
274 | #man_show_urls = False
275 |
276 |
277 | # -- Options for Texinfo output -------------------------------------------
278 |
279 | # grouping the document tree into Texinfo files. List of tuples
280 | # (source start file, target name, title, author,
281 | # dir menu entry, description, category)
282 | texinfo_documents = [
283 | ('index', 'annsa', 'annsa Documentation',
284 | 'Advanced Reactors and Fuel Cycles group', 'annsa', 'Artificial neural networks for spectral analysis.',
285 | 'Physics'),
286 | ]
287 |
288 | # Documents to append as an appendix to all manuals.
289 | #texinfo_appendices = []
290 |
291 | # If false, no module index is generated.
292 | texinfo_domain_indices = False
293 |
294 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
295 | #texinfo_show_urls = 'footnote'
296 |
297 | # If true, do not generate a @detailmenu in the "Top" node's menu.
298 | #texinfo_no_detailmenu = False
299 |
300 | # Example configuration for intersphinx: refer to the Python standard library.
301 | intersphinx_mapping = {'http://docs.python.org/': None}
302 |
--------------------------------------------------------------------------------