├── .coveragerc
├── .github
│   └── workflows
│       ├── ci.yml
│       └── python-publish.yml
├── .gitignore
├── DESCRIPTION.md
├── LICENSE
├── README.md
├── environment.yml
├── examples
│   ├── audio
│   │   ├── devel_dog.wav
│   │   ├── devel_seagull.wav
│   │   ├── test_dog.wav
│   │   ├── test_seagull.wav
│   │   ├── train_dog.wav
│   │   └── train_seagull.wav
│   ├── class_config.json
│   ├── hp_config.json
│   └── label.csv
├── notebooks
│   └── explain
│       ├── environment.yml
│       ├── model
│       │   └── ccs
│       │       └── model.h5
│       ├── shap.ipynb
│       ├── shap
│       │   └── ccs
│       │       └── 100
│       │           └── shap_ccs_100.pdf
│       └── spectrograms
│           └── ccs
│               ├── Negative
│               │   └── test_010.png
│               └── Positive
│                   └── test_039.png
├── pyproject.toml
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
├── setup.py
├── src
│   └── deepspectrumlite
│       ├── __init__.py
│       ├── __main__.py
│       ├── cli
│       │   ├── __init__.py
│       │   ├── config
│       │   │   ├── .gitignore
│       │   │   ├── class_config.json
│       │   │   └── hp_config.json
│       │   ├── convert.py
│       │   ├── create_preprocessor.py
│       │   ├── devel_test.py
│       │   ├── predict.py
│       │   ├── stats.py
│       │   ├── tflite_stats.py
│       │   ├── train.py
│       │   └── utils.py
│       └── lib
│           ├── __init__.py
│           ├── data
│           │   ├── __init__.py
│           │   ├── data_pipeline.py
│           │   ├── embedded
│           │   │   ├── __init__.py
│           │   │   └── preprocessor.py
│           │   ├── parser
│           │   │   ├── ComParEParser.py
│           │   │   └── __init__.py
│           │   └── plot
│           │       ├── __init__.py
│           │       ├── color_maps
│           │       │   ├── __init__.py
│           │       │   ├── abstract_colormap.py
│           │       │   ├── cividis.py
│           │       │   ├── inferno.py
│           │       │   ├── magma.py
│           │       │   ├── plasma.py
│           │       │   └── viridis.py
│           │       └── colormap.py
│           ├── hyperparameter.py
│           ├── model
│           │   ├── TransferBaseModel.py
│           │   ├── __init__.py
│           │   ├── ai_model.py
│           │   ├── config
│           │   │   ├── __init__.py
│           │   │   └── gridsearch.py
│           │   └── modules
│           │       ├── __init__.py
│           │       ├── arelu.py
│           │       ├── attention_module.py
│           │       ├── augmentable_model.py
│           │       └── squeeze_net.py
│           └── util
│               ├── __init__.py
│               └── audio_utils.py
└── tests
    ├── __init__.py
    └── cli
        ├── __init__.py
        ├── audio
        │   ├── dog
        │   │   └── file0.wav
        │   └── seagull
        │       └── file1.wav
        ├── config
        │   ├── cividis_hp_config.json
        │   ├── inferno_hp_config.json
        │   ├── magma_hp_config.json
        │   ├── plasma_hp_config.json
        │   ├── regression_class_config.json
        │   └── regression_hp_config.json
        ├── label
        │   └── regression_label.csv
        ├── test_cli.py
        └── test_train_config.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | source = deepspectrumlite
3 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: CI
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 |
12 | jobs:
13 | pip:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: [3.8]
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install libsndfile
26 | run: sudo apt-get install libsndfile1-dev
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | pip install .
31 | if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
32 | - name: Test with pytest
33 | run: |
34 | py.test -s tests/ --cov deepspectrumlite --cov-report xml:coverage.xml
35 | - name: Coverage
36 | uses: codecov/codecov-action@v1
37 | with:
38 | file: ./coverage.xml # optional
39 | fail_ci_if_error: true # optional (default = false)
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.8'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.idea
3 | *.DS_Store
4 | *.tflite
5 | venv
6 | env
7 | *.gz
8 | /output/*
9 | recorded
10 | __pycache__
11 | build/*
12 | dist/*
13 | *.egg-info
14 | .coverage
15 | .coverage.*
16 | *.pytest_cache
17 | .direnv/
18 |
--------------------------------------------------------------------------------
/DESCRIPTION.md:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | DeepSpectrumLite is a Python toolkit to design and train light-weight Deep Neural Networks (DNNs) for classification tasks on raw audio data.
3 | The trained models run on embedded devices.
4 |
5 | DeepSpectrumLite features an extraction pipeline that first creates visual representations of the audio data - spectrogram plots.
6 | These plots are then fed to a DNN, which can be a pre-trained image Convolutional Neural Network (CNN).
7 | The activations of a specific layer then form the feature vectors used for the final classification.
8 |
9 | The trained models can easily be converted to TensorFlow Lite models. During the conversion, the model is made smaller and faster, optimised for inference on embedded devices.
10 |
11 | **(c) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk, Sandra Ottl, Björn Schuller: Universität Augsburg**
12 | Published under GPLv3, please see the `LICENSE` file for details.
13 |
14 | Please direct any questions or requests to Shahin Amiriparian (shahin.amiriparian at informatik.uni-augsburg.de) or Tobias Hübner (tobias.huebner at informatik.uni-augsburg.de).
15 |
16 | # Why DeepSpectrumLite?
17 | DeepSpectrumLite is built upon TensorFlow Lite, a specialised version of TensorFlow for embedded devices.
18 | However, TensorFlow Lite does not support all basic TensorFlow functions for audio signal processing and plot image generation. DeepSpectrumLite provides implementations for these unsupported functions.
--------------------------------------------------------------------------------
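
DESCRIPTION.md above states that trained models can be converted to TensorFlow Lite. A minimal sketch of such a conversion, modelled on the repository's `src/deepspectrumlite/cli/convert.py` (the file paths here are placeholders, not files shipped with the repository):

```python
# Minimal conversion sketch based on src/deepspectrumlite/cli/convert.py.
# "model.h5" and "model.tflite" are placeholder paths; AugmentableModel and
# ARelu are the custom objects the repository registers when loading its models.
import tensorflow as tf
from deepspectrumlite import AugmentableModel, ARelu

model = tf.keras.models.load_model(
    "model.h5",
    custom_objects={"AugmentableModel": AugmentableModel, "ARelu": ARelu},
    compile=False,
)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Some audio-processing and plotting ops are not TFLite builtins,
# so selected TensorFlow ops are kept available.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```

The packaged `deepspectrumlite convert` command (see `convert.py` below) wraps this flow and additionally sanity-checks the converted model with a TFLite interpreter.
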
/environment.yml:
--------------------------------------------------------------------------------
1 | name: DeepSpectrumLite
2 | channels:
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python=3.8
7 | - pip
8 | - cudatoolkit=11.2
9 | - cudnn=8.1
10 | - pip:
11 | - click
12 | - librosa
13 | - numba
14 | - pillow
15 | - pandas
16 | - scikit-learn
17 | - tensorflow==2.5.1
18 | - tensorboard==2.5
19 | - keras-applications
20 | - -e .
21 |
--------------------------------------------------------------------------------
/examples/audio/devel_dog.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/devel_dog.wav
--------------------------------------------------------------------------------
/examples/audio/devel_seagull.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/devel_seagull.wav
--------------------------------------------------------------------------------
/examples/audio/test_dog.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/test_dog.wav
--------------------------------------------------------------------------------
/examples/audio/test_seagull.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/test_seagull.wav
--------------------------------------------------------------------------------
/examples/audio/train_dog.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/train_dog.wav
--------------------------------------------------------------------------------
/examples/audio/train_seagull.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/train_seagull.wav
--------------------------------------------------------------------------------
/examples/class_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "dog": "Dog",
3 | "seagull": "Seagull"
4 | }
--------------------------------------------------------------------------------
/examples/hp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"],
3 | "model_name": ["TransferBaseModel"],
4 | "prediction_type": ["categorical"],
5 | "basemodel_name": ["densenet121"],
6 | "weights": ["imagenet"],
7 | "tb_experiment": ["densenet_exp"],
8 | "tb_run_id": ["densenet121_run"],
9 | "num_units": [4],
10 | "dropout": [0.25],
11 | "optimizer": ["adadelta"],
12 | "learning_rate": [0.001],
13 | "fine_learning_rate": [0.0001],
14 | "finetune_layer": [0.7],
15 | "loss": ["categorical_crossentropy"],
16 | "activation": ["arelu"],
17 | "pre_epochs": [1],
18 | "epochs": [1],
19 | "batch_size": [1],
20 |
21 | "sample_rate": [16000],
22 |
23 | "chunk_size": [2.0],
24 | "chunk_hop_size": [1.0],
25 | "normalize_audio": [true],
26 |
27 | "stft_window_size": [0.128],
28 | "stft_hop_size": [0.064],
29 | "stft_fft_length": [0.128],
30 |
31 | "mel_scale": [true],
32 | "lower_edge_hertz": [0.0],
33 | "upper_edge_hertz": [8000.0],
34 | "num_mel_bins": [128],
35 | "num_mfccs": [0],
36 | "cep_lifter": [0],
37 | "db_scale": [true],
38 | "use_plot_images": [true],
39 | "color_map": ["viridis"],
40 | "image_width": [224],
41 | "image_height": [224],
42 | "resize_method": ["nearest"],
43 | "anti_alias": [false],
44 |
45 | "sap_aug_a": [0.5],
46 | "sap_aug_s": [10],
47 | "augment_cutmix": [true],
48 | "augment_specaug": [true],
49 | "da_prob_min": [0.1],
50 | "da_prob_max": [0.5],
51 | "cutmix_min": [0.075],
52 | "cutmix_max": [0.25],
53 | "specaug_freq_min": [0.1],
54 | "specaug_freq_max": [0.3],
55 | "specaug_time_min": [0.1],
56 | "specaug_time_max": [0.3],
57 | "specaug_freq_mask_num": [1],
58 | "specaug_time_mask_num": [1]
59 | }
--------------------------------------------------------------------------------
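
Every value in the `hp_config.json` above is wrapped in a list because the file describes a hyperparameter grid; the package's `HyperParameterList` resolves a concrete combination for each `iteration_no`. As a rough illustration only (not the repository's implementation), such a config can be expanded into individual combinations like this:

```python
# Illustrative sketch (not the repository's HyperParameterList): expand an
# hp_config.json, where every key maps to a list of candidate values, into
# the Cartesian product of concrete hyperparameter combinations.
import itertools
import json

with open("examples/hp_config.json") as f:   # path assumes the repository root
    grid = json.load(f)

keys = list(grid)
combinations = [dict(zip(keys, values))
                for values in itertools.product(*(grid[k] for k in keys))]

print(len(combinations), "combination(s)")   # 1 for the example config
print(combinations[0]["basemodel_name"],     # densenet121
      combinations[0]["learning_rate"])      # 0.001
```

With the example config every list holds a single value, so the grid collapses to one combination; listing a second learning rate, for instance, would double the number of training iterations.
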
/examples/label.csv:
--------------------------------------------------------------------------------
1 | filename,label,duration_frames
2 | train_seagull.wav,seagull,117249
3 | train_dog.wav,dog,87865
4 | devel_seagull.wav,seagull,117249
5 | devel_dog.wav,dog,87865
6 | test_seagull.wav,seagull,117249
7 | test_dog.wav,dog,87865
--------------------------------------------------------------------------------
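
The `duration_frames` column in `label.csv` appears to hold the length of each recording in samples; `predict.py` further below derives the same value with librosa when it assembles its own label table. A short sketch, assuming it is run from the repository root:

```python
# Sketch: determine the duration_frames value of a wav file as its number of
# samples at the native sample rate, mirroring what predict.py does.
import librosa

y, sr = librosa.load("examples/audio/train_dog.wav", sr=None)  # keep native rate
print(len(y), sr)  # label.csv lists 87865 frames for train_dog.wav
```
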
/notebooks/explain/environment.yml:
--------------------------------------------------------------------------------
1 | name: shap
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=conda_forge
8 | - _openmp_mutex=4.5=1_gnu
9 | - alsa-lib=1.2.3=h516909a_0
10 | - aom=3.2.0=h9c3ff4c_2
11 | - asttokens=2.0.5=pyhd8ed1ab_0
12 | - backcall=0.2.0=pyh9f0ad1d_0
13 | - backports=1.0=py_2
14 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
15 | - bzip2=1.0.8=h7f98852_4
16 | - c-ares=1.18.1=h7f98852_0
17 | - ca-certificates=2021.10.8=ha878542_0
18 | - cairo=1.16.0=h6cf1ce9_1008
19 | - certifi=2021.10.8=py38h578d9bd_1
20 | - dbus=1.13.6=h5008d03_3
21 | - debugpy=1.5.1=py38h709712a_0
22 | - decorator=5.1.1=pyhd8ed1ab_0
23 | - entrypoints=0.4=pyhd8ed1ab_0
24 | - executing=0.8.2=pyhd8ed1ab_0
25 | - expat=2.4.4=h9c3ff4c_0
26 | - ffmpeg=4.4.1=h6987444_1
27 | - fontconfig=2.13.96=ha180cfb_0
28 | - freetype=2.10.4=h0708190_1
29 | - gettext=0.19.8.1=h73d1719_1008
30 | - gmp=6.2.1=h58526e2_0
31 | - gnutls=3.6.13=h85f3911_1
32 | - graphite2=1.3.13=h58526e2_1001
33 | - gst-plugins-base=1.18.5=hf529b03_3
34 | - gstreamer=1.18.5=h9f60fe5_3
35 | - harfbuzz=2.9.1=h83ec7ef_1
36 | - hdf5=1.12.1=nompi_h2750804_103
37 | - icu=68.2=h9c3ff4c_0
38 | - ipykernel=6.9.1=py38he5a9106_0
39 | - ipython=8.0.1=py38h578d9bd_1
40 | - jasper=1.900.1=h07fcdf6_1006
41 | - jbig=2.1=h7f98852_2003
42 | - jedi=0.18.1=py38h578d9bd_0
43 | - jpeg=9e=h7f98852_0
44 | - jupyter_client=7.1.2=pyhd8ed1ab_0
45 | - jupyter_core=4.9.2=py38h578d9bd_0
46 | - krb5=1.19.2=hcc1bbae_3
47 | - lame=3.100=h7f98852_1001
48 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
49 | - lerc=3.0=h9c3ff4c_0
50 | - libblas=3.9.0=13_linux64_openblas
51 | - libcblas=3.9.0=13_linux64_openblas
52 | - libclang=11.1.0=default_ha53f305_1
53 | - libcurl=7.81.0=h2574ce0_0
54 | - libdeflate=1.10=h7f98852_0
55 | - libdrm=2.4.109=h7f98852_0
56 | - libedit=3.1.20191231=he28a2e2_2
57 | - libev=4.33=h516909a_1
58 | - libevent=2.1.10=h9b69904_4
59 | - libffi=3.4.2=h7f98852_5
60 | - libgcc-ng=11.2.0=h1d223b6_12
61 | - libgfortran-ng=11.2.0=h69a702a_12
62 | - libgfortran5=11.2.0=h5c6108e_12
63 | - libglib=2.70.2=h174f98d_4
64 | - libgomp=11.2.0=h1d223b6_12
65 | - libiconv=1.16=h516909a_0
66 | - liblapack=3.9.0=13_linux64_openblas
67 | - liblapacke=3.9.0=13_linux64_openblas
68 | - libllvm11=11.1.0=hf817b99_3
69 | - libnghttp2=1.46.0=h812cca2_0
70 | - libnsl=2.0.0=h7f98852_0
71 | - libogg=1.3.4=h7f98852_1
72 | - libopenblas=0.3.18=pthreads_h8fe5266_0
73 | - libopencv=4.5.3=py38hafa78d9_3
74 | - libopus=1.3.1=h7f98852_1
75 | - libpciaccess=0.16=h516909a_0
76 | - libpng=1.6.37=h21135ba_2
77 | - libpq=13.5=hd57d9b9_1
78 | - libprotobuf=3.18.1=h780b84a_0
79 | - libsodium=1.0.18=h36c2ea0_1
80 | - libssh2=1.10.0=ha56f1ee_2
81 | - libstdcxx-ng=11.2.0=he4da1e4_12
82 | - libtiff=4.3.0=h542a066_3
83 | - libuuid=2.32.1=h7f98852_1000
84 | - libva=2.14.0=h7f98852_0
85 | - libvorbis=1.3.7=h9c3ff4c_0
86 | - libvpx=1.11.0=h9c3ff4c_3
87 | - libwebp-base=1.2.2=h7f98852_1
88 | - libxcb=1.13=h7f98852_1004
89 | - libxkbcommon=1.0.3=he3ba5ed_0
90 | - libxml2=2.9.12=h72842e0_0
91 | - libzlib=1.2.11=h36c2ea0_1013
92 | - lz4-c=1.9.3=h9c3ff4c_1
93 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0
94 | - mysql-common=8.0.28=ha770c72_0
95 | - mysql-libs=8.0.28=hfa10184_0
96 | - ncurses=6.3=h9c3ff4c_0
97 | - nest-asyncio=1.5.4=pyhd8ed1ab_0
98 | - nettle=3.6=he412f7d_0
99 | - nspr=4.32=h9c3ff4c_1
100 | - nss=3.74=hb5efdd6_0
101 | - opencv=4.5.3=py38h578d9bd_3
102 | - openh264=2.1.1=h780b84a_0
103 | - openssl=1.1.1l=h7f98852_0
104 | - parso=0.8.3=pyhd8ed1ab_0
105 | - pcre=8.45=h9c3ff4c_0
106 | - pexpect=4.8.0=pyh9f0ad1d_2
107 | - pickleshare=0.7.5=py_1003
108 | - pip=22.0.3=pyhd8ed1ab_0
109 | - pixman=0.40.0=h36c2ea0_0
110 | - prompt-toolkit=3.0.27=pyha770c72_0
111 | - pthread-stubs=0.4=h36c2ea0_1001
112 | - ptyprocess=0.7.0=pyhd3deb0d_0
113 | - pure_eval=0.2.2=pyhd8ed1ab_0
114 | - py-opencv=4.5.3=py38he5a9106_3
115 | - pygments=2.11.2=pyhd8ed1ab_0
116 | - python=3.8.12=ha38a3c6_3_cpython
117 | - python-dateutil=2.8.2=pyhd8ed1ab_0
118 | - python_abi=3.8=2_cp38
119 | - pyzmq=22.3.0=py38h2035c66_1
120 | - qt=5.12.9=hda022c4_4
121 | - readline=8.1=h46c0cb4_0
122 | - setuptools=60.9.2=py38h578d9bd_0
123 | - sqlite=3.37.0=h9cd32fc_0
124 | - stack_data=0.2.0=pyhd8ed1ab_0
125 | - svt-av1=0.9.0=h9c3ff4c_0
126 | - tk=8.6.12=h27826a3_0
127 | - tornado=6.1=py38h497a2fe_2
128 | - traitlets=5.1.1=pyhd8ed1ab_0
129 | - wcwidth=0.2.5=pyh9f0ad1d_2
130 | - wheel=0.37.1=pyhd8ed1ab_0
131 | - x264=1!161.3030=h7f98852_1
132 | - x265=3.5=h4bd325d_1
133 | - xorg-fixesproto=5.0=h7f98852_1002
134 | - xorg-kbproto=1.0.7=h7f98852_1002
135 | - xorg-libice=1.0.10=h7f98852_0
136 | - xorg-libsm=1.2.3=hd9c2040_1000
137 | - xorg-libx11=1.7.2=h7f98852_0
138 | - xorg-libxau=1.0.9=h7f98852_0
139 | - xorg-libxdmcp=1.1.3=h7f98852_0
140 | - xorg-libxext=1.3.4=h7f98852_1
141 | - xorg-libxfixes=5.0.3=h7f98852_1004
142 | - xorg-libxrender=0.9.10=h7f98852_1003
143 | - xorg-renderproto=0.11.1=h7f98852_1002
144 | - xorg-xextproto=7.3.0=h7f98852_1002
145 | - xorg-xproto=7.0.31=h7f98852_1007
146 | - xz=5.2.5=h516909a_1
147 | - zeromq=4.3.4=h9c3ff4c_1
148 | - zlib=1.2.11=h36c2ea0_1013
149 | - zstd=1.5.2=ha95c52a_0
150 | - pip:
151 | - absl-py==0.15.0
152 | - appdirs==1.4.4
153 | - astunparse==1.6.3
154 | - audioread==2.1.9
155 | - cachetools==4.2.4
156 | - cffi==1.15.0
157 | - charset-normalizer==2.0.11
158 | - click==8.0.3
159 | - cloudpickle==2.0.0
160 | - cycler==0.11.0
161 | - deepspectrumlite==1.0.2
162 | - flatbuffers==1.12
163 | - fonttools==4.29.1
164 | - gast==0.4.0
165 | - google-auth==1.35.0
166 | - google-auth-oauthlib==0.4.6
167 | - google-pasta==0.2.0
168 | - grpcio==1.34.1
169 | - h5py==3.1.0
170 | - idna==3.3
171 | - importlib-metadata==4.10.1
172 | - joblib==1.1.0
173 | - keras-applications==1.0.8
174 | - keras-nightly==2.5.0.dev2021032900
175 | - keras-preprocessing==1.1.2
176 | - kiwisolver==1.3.2
177 | - librosa==0.9.0
178 | - llvmlite==0.38.0
179 | - markdown==3.3.6
180 | - matplotlib==3.5.1
181 | - numba==0.55.1
182 | - numpy==1.19.5
183 | - oauthlib==3.2.0
184 | - opt-einsum==3.3.0
185 | - p2j==1.3.2
186 | - packaging==21.3
187 | - pandas==1.4.0
188 | - pillow==9.0.1
189 | - pooch==1.6.0
190 | - protobuf==3.19.4
191 | - pyasn1==0.4.8
192 | - pyasn1-modules==0.2.8
193 | - pycparser==2.21
194 | - pyparsing==3.0.7
195 | - pytz==2021.3
196 | - requests==2.27.1
197 | - requests-oauthlib==1.3.1
198 | - resampy==0.2.2
199 | - rsa==4.8
200 | - scikit-learn==1.0.2
201 | - scipy==1.8.0
202 | - shap==0.40.0
203 | - six==1.15.0
204 | - slicer==0.0.7
205 | - soundfile==0.10.3.post1
206 | - tensorboard==2.5.0
207 | - tensorboard-data-server==0.6.1
208 | - tensorboard-plugin-wit==1.8.1
209 | - tensorflow==2.5.1
210 | - tensorflow-estimator==2.5.0
211 | - termcolor==1.1.0
212 | - threadpoolctl==3.1.0
213 | - tqdm==4.62.3
214 | - typing-extensions==3.7.4.3
215 | - urllib3==1.26.8
216 | - werkzeug==2.0.3
217 | - wrapt==1.12.1
218 | - zipp==3.7.0
--------------------------------------------------------------------------------
/notebooks/explain/model/ccs/model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/model/ccs/model.h5
--------------------------------------------------------------------------------
/notebooks/explain/shap/ccs/100/shap_ccs_100.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/shap/ccs/100/shap_ccs_100.pdf
--------------------------------------------------------------------------------
/notebooks/explain/spectrograms/ccs/Negative/test_010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/spectrograms/ccs/Negative/test_010.png
--------------------------------------------------------------------------------
/notebooks/explain/spectrograms/ccs/Positive/test_039.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/spectrograms/ccs/Positive/test_039.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel"
5 | ]
6 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | pytest==6.2.3
2 | coveralls
3 | pytest-cov
4 | librosa
5 | numba
6 | click
7 | pillow
8 | pandas
9 | scikit-learn
10 | tensorflow==2.5.1
11 | tensorboard==2.5
12 | keras-applications
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa
2 | numba
3 | click
4 | pillow
5 | pandas
6 | scikit-learn
7 | tensorflow==2.5.1
8 | tensorboard==2.5
9 | keras-applications
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | license_files = LICENSE
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import re
3 | import sys
4 | import warnings
5 |
6 | warnings.filterwarnings('ignore', category=DeprecationWarning)
7 | warnings.filterwarnings('ignore', category=FutureWarning)
8 |
9 | from setuptools import setup, find_packages
10 | from subprocess import CalledProcessError, check_output
11 |
12 | PROJECT = "DeepSpectrumLite"
13 | VERSION = "1.0.2"
14 | LICENSE = "GPLv3+"
15 | AUTHOR = "Tobias Hübner"
16 | AUTHOR_EMAIL = "tobias.huebner@informatik.uni-augsburg.de"
17 | URL = 'https://github.com/DeepSpectrum/DeepSpectrumLite'
18 |
19 | with open("DESCRIPTION.md", "r") as fh:
20 | LONG_DESCRIPTION = fh.read()
21 |
22 | install_requires = [
23 | "librosa",
24 | "numba",
25 | "pillow",
26 | "pandas",
27 | "scikit-learn",
28 | "click",
29 | "tensorflow==2.5.1",
30 | "tensorboard==2.5",
31 | "keras-applications"
32 | ]
33 |
34 | tests_require = ['pytest>=4.4.1', 'pytest-cov>=2.7.1']
35 | needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv)
36 | setup_requires = ['pytest-runner'] if needs_pytest else []
37 | packages = find_packages('src')
38 |
39 | setup(
40 | name=PROJECT,
41 | version=VERSION,
42 | license=LICENSE,
43 | author=AUTHOR,
44 | author_email=AUTHOR_EMAIL,
45 | long_description=LONG_DESCRIPTION,
46 | long_description_content_type="text/markdown",
47 |     description="DeepSpectrumLite is a Python toolkit for training light-weight CNNs targeted at embedded devices.",
48 | platforms=["Any"],
49 | scripts=[],
50 | provides=[],
51 | python_requires="~=3.8.0",
52 | install_requires=install_requires,
53 | setup_requires=setup_requires,
54 | tests_require=tests_require,
55 | namespace_packages=[],
56 | packages=packages,
57 | package_dir={'': 'src'},
58 | include_package_data=True,
59 | entry_points={
60 | "console_scripts": [
61 | "deepspectrumlite = deepspectrumlite.__main__:cli",
62 | ]
63 | },
64 | classifiers=[
65 | # How mature is this project? Common values are
66 | # 3 - Alpha
67 | # 4 - Beta
68 | # 5 - Production/Stable
69 | 'Development Status :: 4 - Beta',
70 |
71 | 'Environment :: GPU :: NVIDIA CUDA :: 11.0',
72 | # Indicate who your project is intended for
73 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
74 | 'Intended Audience :: Science/Research',
75 |
76 | # Pick your license as you wish (should match "license" above)
77 | 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
78 |
79 | 'Programming Language :: Python :: 3.8',
80 | ],
81 | keywords='machine-learning audio-analysis science research',
82 | project_urls={
83 | 'Source': 'https://github.com/DeepSpectrum/DeepSpectrumLite',
84 | 'Tracker': 'https://github.com/DeepSpectrum/DeepSpectrumLite/issues',
85 | },
86 | url=URL,
87 | zip_safe=False,
88 | )
--------------------------------------------------------------------------------
/src/deepspectrumlite/__init__.py:
--------------------------------------------------------------------------------
1 | from .lib.util import *
2 | import deepspectrumlite.lib.data.plot as plot
3 | from .lib import HyperParameterList
4 | from .lib.model.ai_model import Model
5 | from .lib.model.TransferBaseModel import TransferBaseModel
6 | from .lib.data.embedded.preprocessor import *
7 | from .lib.data.data_pipeline import DataPipeline
8 | from .lib.model.modules.augmentable_model import *
9 | from .lib.model.config.gridsearch import *
10 | from .lib.model.modules.arelu import *
11 | from .lib.model.modules.squeeze_net import *
12 | import logging
13 | import sys
14 |
15 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
16 | __version__ = '1.0.2'
17 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/__main__.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import sys, os
20 |
21 | import warnings
22 | warnings.filterwarnings('ignore', category=DeprecationWarning)
23 | warnings.filterwarnings('ignore', category=FutureWarning)
24 |
25 | import click
26 | import logging
27 | import logging.config
28 | import pkg_resources
29 |
30 | from deepspectrumlite.cli.train import train
31 | from deepspectrumlite.cli.devel_test import devel_test
32 | from deepspectrumlite.cli.stats import stats
33 | from deepspectrumlite.cli.tflite_stats import tflite_stats
34 | from deepspectrumlite.cli.create_preprocessor import create_preprocessor
35 | from deepspectrumlite.cli.convert import convert
36 | from deepspectrumlite.cli.predict import predict
37 | from deepspectrumlite.cli.utils import add_options
38 | from deepspectrumlite import __version__ as VERSION
39 |
40 |
41 | _global_options = [
42 | click.option('-v', '--verbose', count=True),
43 | ]
44 |
45 |
46 | version_str = f"DeepSpectrumLite %(version)s\nCopyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk, Sandra Ottl, " \
47 | "Björn Schuller\n" \
48 | "License GPLv3+: GNU GPL version 3 or later .\n" \
49 | "This is free software: you are free to change and redistribute it.\n" \
50 | "There is NO WARRANTY, to the extent permitted by law."
51 |
52 | @click.group()
53 | @add_options(_global_options)
54 | @click.version_option(VERSION, message=version_str)
55 | @click.pass_context
56 | def cli(ctx, verbose):
57 | log_levels = ['ERROR', 'INFO', 'DEBUG']
58 | verbose = min(2, verbose)
59 | ctx.ensure_object(dict)
60 | ctx.obj['verbose'] = verbose
61 |
62 | if verbose == 2:
63 | level = logging.DEBUG
64 | elif verbose == 1:
65 | level = logging.INFO
66 | else:
67 | level = logging.ERROR
68 | logging.basicConfig()
69 | logging.config.dictConfig({
70 | 'version': 1,
71 |         'disable_existing_loggers': False,  # keep loggers created before this config is applied
72 | 'formatters': {
73 | 'standard': {
74 | 'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
75 | },
76 | },
77 | 'handlers': {
78 | 'default': {
79 | 'level': log_levels[verbose],
80 | 'class': 'logging.StreamHandler',
81 | 'formatter': 'standard',
82 | 'stream': 'ext://sys.stdout'
83 | },
84 | },
85 | 'loggers': {
86 | '': {
87 | 'handlers': ['default'],
88 | 'level': log_levels[verbose],
89 | 'propagate': True
90 | }
91 | }
92 | })
93 |
94 | logging.debug('Verbosity: %s' % log_levels[verbose])
95 | # logging.error("error test")
96 | # logging.debug("debug test")
97 | # logging.info("info test")
98 |
99 | os.environ['GLOG_minloglevel'] = '2'
100 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
101 |
102 | cli.add_command(train)
103 | cli.add_command(devel_test)
104 | cli.add_command(stats)
105 | cli.add_command(convert)
106 | cli.add_command(tflite_stats)
107 | cli.add_command(create_preprocessor)
108 | cli.add_command(predict)
109 |
110 | if __name__ == '__main__':
111 | cli(obj={})
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/src/deepspectrumlite/cli/__init__.py
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/config/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !class_config.json
4 | !hp_config.json
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/config/class_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "negative": "Negative",
3 | "positive": "Positive"
4 | }
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/config/hp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"],
3 | "model_name": ["TransferBaseModel"],
4 | "prediction_type": ["categorical"],
5 | "basemodel_name": ["densenet121"],
6 | "weights": ["imagenet"],
7 | "tb_experiment": ["densenet_exp"],
8 | "tb_run_id": ["densenet121_run"],
9 | "num_units": [512],
10 | "dropout": [0.25],
11 | "optimizer": ["adadelta"],
12 | "learning_rate": [0.001],
13 | "fine_learning_rate": [0.0001],
14 | "finetune_layer": [0.7],
15 | "loss": ["categorical_crossentropy"],
16 | "activation": ["arelu"],
17 | "output_activation": ["softmax"],
18 | "pre_epochs": [40],
19 | "epochs": [100],
20 | "batch_size": [32],
21 |
22 | "sample_rate": [16000],
23 |
24 | "chunk_size": [4.0],
25 | "chunk_hop_size": [2.0],
26 | "normalize_audio": [false],
27 |
28 |
29 | "stft_window_size": [0.128],
30 | "stft_hop_size": [0.064],
31 | "stft_fft_length": [0.128],
32 |
33 | "mel_scale": [true],
34 | "lower_edge_hertz": [0.0],
35 | "upper_edge_hertz": [8000.0],
36 | "num_mel_bins": [128],
37 | "num_mfccs": [0],
38 | "cep_lifter": [0],
39 | "db_scale": [true],
40 | "use_plot_images": [true],
41 | "color_map": ["viridis"],
42 | "image_width": [224],
43 | "image_height": [224],
44 | "resize_method": ["nearest"],
45 | "anti_alias": [false],
46 |
47 | "sap_aug_a": [0.5],
48 | "sap_aug_s": [10],
49 | "augment_cutmix": [true],
50 | "augment_specaug": [true],
51 | "da_prob_min": [0.1],
52 | "da_prob_max": [0.5],
53 | "cutmix_min": [0.075],
54 | "cutmix_max": [0.25],
55 | "specaug_freq_min": [0.1],
56 | "specaug_freq_max": [0.3],
57 | "specaug_time_min": [0.1],
58 | "specaug_time_max": [0.3],
59 | "specaug_freq_mask_num": [1],
60 | "specaug_time_mask_num": [1]
61 | }
62 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/convert.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import logging
20 | import click
21 | from .utils import add_options
22 | import os
23 | import sys
24 | import tensorflow as tf
25 | from tensorflow import keras
26 | import math
27 | from tensorflow.keras import backend as K
28 | from tensorflow.python.saved_model import loader_impl
29 | from deepspectrumlite import AugmentableModel, ARelu
30 | import numpy as np
31 | from os.path import join, dirname, realpath
32 |
33 | log = logging.getLogger(__name__)
34 |
35 | _DESCRIPTION = 'Converts a DeepSpectrumLite model to a TFLite model file.'
36 |
37 | @add_options(
38 | [
39 | click.option(
40 | "-s",
41 | "--source",
42 | type=click.Path(exists=True, writable=False, readable=True),
43 | help="Source HD5 model file",
44 | required=True
45 | ),
46 | click.option(
47 | "-d",
48 | "--destination",
49 | type=click.Path(exists=False, writable=True, readable=True),
50 | help="Destination TFLite model file",
51 | required=True
52 | )
53 | ]
54 | )
55 |
56 | @click.command(help=_DESCRIPTION)
57 | def convert(source, destination, **kwargs):
58 | log.info("Load model: " + source)
59 |
60 | # loader_impl.parse_saved_model(source)
61 |
62 | new_model = tf.keras.models.load_model(source,
63 | custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
64 | compile=False)
65 | log.info("Successfully loaded model: " + source)
66 |
67 | converter = tf.lite.TFLiteConverter.from_keras_model(new_model)
68 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
69 | tf.lite.OpsSet.SELECT_TF_OPS]
70 | converter.experimental_new_converter = True
71 | tflite_quant_model = converter.convert()
72 | open(destination, "wb").write(tflite_quant_model)
73 |
74 | log.info("Model was saved as tflite as " + destination)
75 |
76 | interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
77 |
78 | input_details = interpreter.get_input_details()
79 | output_details = interpreter.get_output_details()
80 |
81 | interpreter.allocate_tensors()
82 |
83 | # interpreter.set_tensor(input_details[0]['index'], tf.convert_to_tensor(np.expand_dims(audio_data, 0), dtype=tf.float32))
84 |
85 | interpreter.invoke()
86 |
87 | output = interpreter.get_tensor(output_details[0]['index'])
88 |
89 | # Test model on random input data.
90 | input_shape = input_details[0]['shape']
91 | log.info("input shape: ")
92 | log.info(input_shape)
93 | log.info("output shape: ",)
94 | log.info(output_details[0]['shape'])
95 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
96 | interpreter.set_tensor(input_details[0]['index'], input_data)
97 |
98 | interpreter.invoke()
99 | output_data = interpreter.get_tensor(output_details[0]['index'])
100 |
101 | log.info(output_data)
102 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/create_preprocessor.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import logging
20 | import click
21 | from .utils import add_options
22 | import os
23 | import sys
24 | import tensorflow as tf
25 | from tensorflow import keras
26 | import math
27 | from tensorflow.keras import backend as K
28 | from tensorflow.python.saved_model import loader_impl
29 | from deepspectrumlite import HyperParameterList, PreprocessAudio
30 | import time
31 | import numpy as np
32 | from os.path import join, dirname, realpath
33 |
34 | log = logging.getLogger(__name__)
35 |
36 | _DESCRIPTION = 'Creates a DeepSpectrumLite preprocessor TFLite file.'
37 |
38 | @add_options(
39 | [
40 | click.option(
41 | "-hc",
42 | "--hyper-config",
43 | type=click.Path(exists=True, writable=False, readable=True),
44 | help="Directory for the hyper parameter config file.",
45 | default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
46 | ),
47 | click.option(
48 | "-d",
49 | "--destination",
50 | type=click.Path(exists=False, writable=True, readable=True),
51 | help="Destination of the TFLite preprocessor file",
52 | required=True
53 | )
54 | ]
55 | )
56 |
57 | @click.command(help=_DESCRIPTION)
58 | def create_preprocessor(hyper_config, destination, **kwargs):
59 | hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)
60 | hparam_values = hyper_parameter_list.get_values(iteration_no=0)
61 | working_directory = dirname(destination)
62 |
63 | preprocess = PreprocessAudio(hparams=hparam_values, name="dsl_audio_preprocessor")
64 | input = tf.convert_to_tensor(np.array(np.random.random_sample((1, 16000)), dtype=np.float32), dtype=tf.float32)
65 | result = preprocess.preprocess(input)
66 |
67 | # ATTENTION: antialias is not supported in tflite
68 | tmp_save_path = os.path.join(working_directory, "preprocessor")
69 | os.makedirs(tmp_save_path, exist_ok=True)
70 | tf.saved_model.save(preprocess, tmp_save_path)
71 |
72 | # new_model = preprocess
73 | converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=tmp_save_path)
74 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
75 | tf.lite.OpsSet.SELECT_TF_OPS]
76 | converter.experimental_new_converter = True
77 | tflite_quant_model = converter.convert()
78 | open(destination, "wb").write(tflite_quant_model)
79 |
80 | interpreter = tf.lite.Interpreter(model_path=destination)
81 | input_details = interpreter.get_input_details()
82 | output_details = interpreter.get_output_details()
83 | log.info(input_details)
84 | log.info(output_details)
85 |
86 | interpreter.allocate_tensors()
87 |
88 | interpreter.set_tensor(input_details[0]['index'], tf.convert_to_tensor(np.array(np.random.random_sample((1, 16000)), dtype=np.float32), dtype=tf.float32))
89 |
90 | interpreter.invoke()
91 |
92 | output = interpreter.get_tensor(output_details[0]['index'])
93 |
94 | # Test model on random input data.
95 | input_shape = input_details[0]['shape']
96 | log.info("input shape:")
97 | log.info(input_shape)
98 | log.info("output shape:")
99 | log.info(output_details[0]['shape'])
100 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
101 | interpreter.set_tensor(input_details[0]['index'], input_data)
102 | start_time = time.time()
103 | interpreter.invoke()
104 | stop_time = time.time()
105 | output_data = interpreter.get_tensor(output_details[0]['index'])
106 |
107 | log.info(output_data)
108 | log.info('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
109 | log.info("Finished creating the TFLite preprocessor")
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/devel_test.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import logging
20 | import click
21 | from .utils import add_options
22 | import os
23 | import sys
24 | import tensorflow as tf
25 | from tensorflow import keras
26 | import math
27 | from tensorflow.keras import backend as K
28 | from tensorflow.python.saved_model import loader_impl
29 | from deepspectrumlite import AugmentableModel, DataPipeline, HyperParameterList, ARelu
30 | import glob
31 | from pathlib import Path
32 | import numpy as np
33 | import importlib
34 | import json
35 | from sklearn.metrics import recall_score, classification_report, confusion_matrix
36 | import csv
37 | import json
38 | import pandas as pd
39 | from collections import Counter
40 | from os.path import join, dirname, realpath
41 |
42 | log = logging.getLogger(__name__)
43 |
44 | _DESCRIPTION = 'Test a DeepSpectrumLite transfer learning model.'
45 |
46 | @add_options(
47 | [
48 | click.option(
49 | "-d",
50 | "--data-dir",
51 | type=click.Path(exists=True),
52 | help="Directory of data class categories containing folders of each data class.",
53 | required=True
54 | ),
55 | click.option(
56 | "-md",
57 | "--model-dir",
58 | type=click.Path(exists=False, writable=True),
59 | help="Path to HD5 model file",
60 | required=True
61 | ),
62 | click.option(
63 | "-hc",
64 | "--hyper-config",
65 | type=click.Path(exists=True, writable=False, readable=True),
66 | help="Directory for the hyper parameter config file.",
67 | default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
68 | ),
69 | click.option(
70 | "-cc",
71 | "--class-config",
72 | type=click.Path(exists=True, writable=False, readable=True),
73 | help="Directory for the class config file.",
74 | default=join(dirname(realpath(__file__)), "config/class_config.json"), show_default=True
75 | ),
76 | click.option(
77 | "-l",
78 | "--label-file",
79 | type=click.Path(exists=True, writable=False, readable=True),
80 | help="Directory for the label file.",
81 | required=True
82 | )
83 | ]
84 | )
85 |
86 | @click.command(help=_DESCRIPTION)
87 | @click.pass_context
88 | def devel_test(ctx, model_dir, data_dir, class_config, hyper_config, label_file, **kwargs):
89 | verbose = ctx.obj['verbose']
90 | f = open(class_config)
91 | data = json.load(f)
92 | f.close()
93 |
94 | data_dir = os.path.join(data_dir, '')
95 |
96 | data_classes = data
97 |
98 | if data_classes is None:
99 | raise ValueError('no data classes defined')
100 |
101 | class_list = {}
102 | for i, data_class in enumerate(data_classes):
103 | class_list[data_class] = i
104 |
105 | hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)
106 |
107 | log.info("Search by rule: " + model_dir)
108 | model_dir_list = glob.glob(model_dir)
109 | log.info("Found " + str(len(model_dir_list)) + " files")
110 |
111 | for model_filename in model_dir_list:
112 | log.info("Load " + model_filename)
113 | p = Path(model_filename)
114 | parent = p.parent
115 | directory = parent.name
116 |
117 | result_dir = os.path.join(parent, "evaluation")
118 |
119 | iteration_no = int(directory.split("_")[-1])
120 |
121 | log.info('--- Testing trial: %s' % iteration_no)
122 | hparam_values = hyper_parameter_list.get_values(iteration_no=iteration_no)
123 | log.info(hparam_values)
124 |
125 | label_parser_key = hparam_values['label_parser']
126 |
127 | if ":" not in label_parser_key:
128 | raise ValueError('Please provide the parser in the following format: path.to.parser_file.py:ParserClass')
129 |
130 | log.info(f'Using custom external parser: {label_parser_key}')
131 | path, class_name = label_parser_key.split(':')
132 | module_name = os.path.splitext(os.path.basename(path))[0]
133 | dir_path = os.path.dirname(os.path.realpath(__file__))
134 | path = os.path.join(dir_path, path)
135 | spec = importlib.util.spec_from_file_location(module_name, path)
136 | foo = importlib.util.module_from_spec(spec)
137 | spec.loader.exec_module(foo)
138 | parser_class = getattr(foo, class_name)
139 |
140 | parser = parser_class(file_path=label_file)
141 | _, devel_data, test_data = parser.parse_labels()
142 | log.info("Successfully parsed labels: " + label_file)
143 | model = tf.keras.models.load_model(model_filename,
144 | custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
145 | compile=False)
146 | model.set_hyper_parameters(hparam_values)
147 | log.info("Successfully loaded model: " + model_filename)
148 |
149 | dataset_list = ["devel", "test"]
150 |
151 | for dataset_name in dataset_list:
152 | log.info("===== Dataset Partition: " + dataset_name)
153 | data_raw = []
154 | if dataset_name == 'devel':
155 | data_raw = devel_data # [:10]
156 | elif dataset_name == 'test':
157 | data_raw = test_data # [:10]
158 |
159 | dataset_result_dir = os.path.join(result_dir, dataset_name)
160 |
161 | os.makedirs(dataset_result_dir, exist_ok=True)
162 |
163 | data_pipeline = DataPipeline(name=dataset_name + '_data_set', data_classes=data_classes,
164 | enable_gpu=True, verbose=True, enable_augmentation=False,
165 | hparams=hparam_values, run_id=iteration_no)
166 | data_pipeline.set_data(data_raw)
167 | data_pipeline.set_filename_prepend(prepend_filename_str=data_dir)
168 | data_pipeline.preprocess()
169 | filename_list = data_pipeline.filenames
170 | dataset = data_pipeline.pipeline(cache=False, shuffle=False, drop_remainder=False)
171 |
172 | X_pred = model.predict(x=dataset, verbose=verbose)
173 | true_categories = tf.concat([y for x, y in dataset], axis=0)
174 |
175 | X_pred = tf.argmax(X_pred, axis=1)
176 | X_pred_ny = X_pred.numpy()
177 |
178 | true_categories = tf.argmax(true_categories, axis=1)
179 | true_np = true_categories.numpy()
180 | cm = tf.math.confusion_matrix(true_categories, X_pred)
181 | log.info("Confusion Matrix (chunks):")
182 | log.info(cm.numpy())
183 |
184 | target_names = []
185 | for data_class in data_classes:
186 | target_names.append(data_class)
187 |
188 | log.info(classification_report(y_true=true_categories.numpy(), y_pred=X_pred_ny,
189 | target_names=target_names,
190 | digits=4))
191 |
192 | recall = recall_score(y_true=true_categories.numpy(), y_pred=X_pred_ny, average='macro')
193 | log.info("UAR: " + str(recall * 100))
194 |
195 | json_cm_dir = os.path.join(dataset_result_dir, dataset_name + ".chunks.metrics.json")
196 | with open(json_cm_dir, 'w') as f:
197 | json.dump({"cm": cm.numpy().tolist(), "uar": round(recall * 100, 4)}, f)
198 |
199 | X_pred_pd = pd.DataFrame(data=X_pred_ny, columns=["prediction"])
200 | pd_filename_list = pd.DataFrame(data=filename_list[..., 0], columns=["filename"])
201 |
202 | df = pd_filename_list.join(X_pred_pd, how='outer')
203 | df['filename'] = df['filename'].apply(lambda x: os.path.basename(x))
204 |
205 | df.to_csv(os.path.join(dataset_result_dir, dataset_name + ".chunks.predictions.csv"), index=False)
206 |
207 | ###### grouped #######
208 |
209 | grouped_data = df.groupby('filename', as_index=False).agg(lambda x: Counter(x).most_common(1)[0][0])
210 | grouped_data.to_csv(os.path.join(dataset_result_dir, dataset_name + ".grouped.predictions.csv"),
211 | index=False)
212 | grouped_X_pred = grouped_data.values[..., 1].tolist()
213 |
214 | # test
215 | pd_filename_list = pd.DataFrame(data=filename_list[..., 0], columns=["filename"])
216 | true_pd = pd.DataFrame(data=true_np, columns=["label"])
217 | df = pd_filename_list.join(true_pd, how='outer')
218 | df['filename'] = df['filename'].apply(lambda x: os.path.basename(x))
219 | data_raw_labels = df.groupby('filename', as_index=False).agg(lambda x: Counter(x).most_common(1)[0][0])
220 |
221 | # data_raw_labels = data_raw
222 | # data_raw_labels['label'] = data_raw_labels['label'].apply(lambda x: class_list[x])
223 | grouped_true = data_raw_labels.values[..., 1].tolist()
224 | cm = confusion_matrix(grouped_true, grouped_X_pred)
225 | log.info("Confusion Matrix (grouped):")
226 | log.info(cm)
227 |
228 | log.info(classification_report(y_true=grouped_true, y_pred=grouped_X_pred,
229 | target_names=target_names,
230 | digits=4))
231 |
232 | recall = recall_score(y_true=grouped_true, y_pred=grouped_X_pred, average='macro')
233 | log.info("UAR: " + str(recall * 100))
234 |
235 | json_cm_dir = os.path.join(dataset_result_dir, dataset_name + ".grouped.metrics.json")
236 | with open(json_cm_dir, 'w') as f:
237 | json.dump({"cm": cm.tolist(), "uar": round(recall * 100, 4)}, f)
238 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/predict.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import logging
20 | import click
21 | from .utils import add_options
22 | import os
23 | import sys
24 | import tensorflow as tf
25 | from tensorflow import keras
26 | import math
27 | from tensorflow.keras import backend as K
28 | from tensorflow.python.saved_model import loader_impl
29 | from deepspectrumlite import AugmentableModel, DataPipeline, HyperParameterList, ARelu
30 | import glob
31 | from pathlib import Path
32 | import numpy as np
33 | import importlib
34 | import json
35 | from sklearn.metrics import recall_score
36 | from sklearn.metrics import classification_report, confusion_matrix
37 | import csv
38 | import json
39 | import pandas as pd
40 | import librosa
41 | from collections import Counter
42 | from os.path import join, dirname, realpath
43 |
44 | log = logging.getLogger(__name__)
45 |
46 | _DESCRIPTION = 'Predict a file using an existing DeepSpectrumLite transfer learning model.'
47 |
48 | @add_options(
49 | [
50 | click.option(
51 | "-d",
52 | "--data-dir",
53 | type=click.Path(exists=True),
54 | help="Directory of data class categories containing folders of each data class.",
55 | required=True
56 | ),
57 | click.option(
58 | "-md",
59 | "--model-dir",
60 | type=click.Path(exists=True, writable=False),
61 | help="HD5 file of the DeepSpectrumLite model",
62 | required=True
63 | ),
64 | click.option(
65 | "-hc",
66 | "--hyper-config",
67 | type=click.Path(exists=True, writable=False, readable=True),
68 | help="Directory for the hyper parameter config file.",
69 | default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
70 | ),
71 | click.option(
72 | "-cc",
73 | "--class-config",
74 | type=click.Path(exists=True, writable=False, readable=True),
75 | help="Directory for the class config file.",
76 | default=join(dirname(realpath(__file__)), "config/class_config.json"), show_default=True
77 | )
78 | ]
79 | )
80 |
81 | @click.command(help=_DESCRIPTION)
82 | @click.pass_context
83 | def predict(ctx, model_dir, data_dir, class_config, hyper_config, **kwargs):
84 | verbose = ctx.obj['verbose']
85 | f = open(class_config)
86 | data = json.load(f)
87 | f.close()
88 |
89 | data_dir = os.path.join(data_dir, '')
90 |
91 | data_classes = data
92 | wav_files = sorted(glob.glob(f'{data_dir}/**/*.wav', recursive=True))
93 | filenames, labels, duration_frames = list(map(lambda x: os.path.relpath(x, start=data_dir), wav_files)), [list(data_classes.keys())[0]]*len(wav_files), []
94 | for fn in filenames:
95 | y, sr = librosa.load(os.path.join(data_dir, fn), sr=None)
96 | duration_frames.append(y.shape[0])
97 |
98 | log.info('Found %d wav files' % len(filenames))
99 |
100 | if data_classes is None:
101 | raise ValueError('no data classes defined')
102 |
103 | class_list = {}
104 | for i, data_class in enumerate(data_classes):
105 | class_list[data_class] = i
106 |
107 | hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)
108 | log.info("Search within rule: " + model_dir)
109 | model_dir_list = glob.glob(model_dir)
110 | log.info("Found "+ str(len(model_dir_list)) + " files")
111 |
112 | for model_filename in model_dir_list:
113 | log.info("Load " + model_filename)
114 | p = Path(model_filename)
115 | parent = p.parent
116 | directory = parent.name
117 |
118 | result_dir = os.path.join(parent, "test")
119 | iteration_no = int(directory.split("_")[-1])
120 |
121 | log.info('--- Testing trial: %s' % iteration_no)
122 | hparam_values = hyper_parameter_list.get_values(iteration_no=iteration_no)
123 | log.info(hparam_values)
124 |
125 | test_data = pd.DataFrame({'filename': filenames, 'label': labels, 'duration_frames': duration_frames})
126 |
127 | print("Loading model: " + model_filename)
128 | model = tf.keras.models.load_model(model_filename,
129 | custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
130 | compile=False)
131 | model.set_hyper_parameters(hparam_values)
132 | log.info("Successfully loaded model: " + model_filename)
133 |
134 | data_raw = test_data # [:10]
135 | dataset_name = 'test'
136 |
137 | dataset_result_dir = os.path.join(result_dir, dataset_name)
138 |
139 | os.makedirs(dataset_result_dir, exist_ok=True)
140 |
141 | data_pipeline = DataPipeline(name=dataset_name+'_data_set', data_classes=data_classes,
142 | enable_gpu=True, verbose=True, enable_augmentation=False,
143 | hparams=hparam_values, run_id=iteration_no)
144 | data_pipeline.set_data(data_raw)
145 | data_pipeline.set_filename_prepend(prepend_filename_str=data_dir)
146 | data_pipeline.preprocess()
147 | filename_list = data_pipeline.filenames
148 | dataset = data_pipeline.pipeline(cache=False, shuffle=False, drop_remainder=False)
149 |
150 | X_probs = model.predict(x=dataset, verbose=verbose)
151 | true_categories = tf.concat([y for x, y in dataset], axis=0)
152 | X_pred = tf.argmax(X_probs, axis=1)
153 | X_pred_ny = X_pred.numpy()
154 |
155 |
156 | target_names = []
157 | for data_class in data_classes:
158 | target_names.append(data_class)
159 |
160 | df = pd.DataFrame(data=filename_list[...,0], columns=["filename"])
161 |
162 | df['filename'] = df['filename'].apply(lambda x: os.path.basename(x))
163 | df['time'] = list(map(lambda x: int(x)/sr, filename_list[...,1]))  # sr is the rate of the last loaded wav; assumes all files share one sample rate
164 | for i, target in enumerate(target_names):
165 | df[f'prob_{target}'] = X_probs[:, i]
166 | df['prediction'] = list(map(lambda x: target_names[x], X_pred))
167 |
168 | df.to_csv(os.path.join(dataset_result_dir, dataset_name+".chunks.predictions.csv"), index=False)
169 |
170 | log.info("Finished testing")
171 |
172 |
173 |
174 |
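For reference, a minimal sketch of the class config consumed by predict (and train): a JSON object whose keys are the class names, with the key order defining the label indices derived above. The class names and values below are illustrative assumptions, not the shipped class_config.json.

    # hypothetical class_config.json:
    #   {"dog": 0, "seagull": 1}
    import json

    with open("class_config.json") as f:
        data_classes = json.load(f)

    # same index assignment as in predict() above
    class_list = {name: i for i, name in enumerate(data_classes)}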
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/stats.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import logging
20 | import click
21 | from .utils import add_options
22 | import os
23 | import tensorflow as tf
24 | import numpy as np
25 | from deepspectrumlite import AugmentableModel, ARelu, HyperParameterList
26 |
27 | def get_detailed_stats(model_h5_path):
28 | session = tf.compat.v1.Session()
29 | graph = tf.compat.v1.get_default_graph()
30 |
31 | parent_dir = dirname(model_h5_path)
32 |
33 |
34 | with graph.as_default():
35 | with session.as_default():
36 | new_model = tf.keras.models.load_model(model_h5_path, custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu}, compile=False)
37 | new_model.summary(print_fn=log.info)
38 | run_meta = tf.compat.v1.RunMetadata()
39 | input_details = new_model.get_config()
40 | input_shape = input_details['layers'][0]['config']['batch_input_shape']
41 |
42 | _ = session.run(new_model.output, {
43 | 'input_1:0': np.random.normal(size=(1, input_shape[1], input_shape[2], input_shape[3]))},
44 | options=tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE),
45 | run_metadata=run_meta)
46 |
47 | # '''
48 | ProfileOptionBuilder = tf.compat.v1.profiler.ProfileOptionBuilder
49 | opts = ProfileOptionBuilder(ProfileOptionBuilder.time_and_memory()
50 | ).with_step(1).with_timeline_output('test.json').build()
51 |
52 | tf.compat.v1.profiler.profile(
53 | tf.compat.v1.get_default_graph(),
54 | run_meta=run_meta,
55 | cmd='code',
56 | options=opts)
57 | # '''
58 | # Print to stdout an analysis of the memory usage and the timing information
59 | # broken down by operation types.
60 | json_export = tf.compat.v1.profiler.profile(
61 | tf.compat.v1.get_default_graph(),
62 | run_meta=run_meta,
63 | cmd='op',
64 | options=tf.compat.v1.profiler.ProfileOptionBuilder.time_and_memory())
65 |
66 | text_file = open(os.path.join(parent_dir, "profiler.json"), "w")
67 | text_file.write(str(json_export))
68 | text_file.close()
69 | # print(json_export)
70 | tf.compat.v1.reset_default_graph()
71 |
72 | # deprecated
73 | def get_flops(model_h5_path): # pragma: no cover
74 | tf.compat.v1.enable_eager_execution()
75 | session = tf.compat.v1.Session()
76 | graph = tf.compat.v1.get_default_graph()
77 |
78 |
79 | with graph.as_default():
80 | with session.as_default():
81 | tf.keras.models.load_model(model_h5_path, custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu}, compile=False)
82 | run_meta = tf.compat.v1.RunMetadata()
83 | opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
84 |
85 | # Optional: save printed results to file
86 | # flops_log_path = os.path.join(tempfile.gettempdir(), 'tf_flops_log.txt')
87 | # opts['output'] = 'file:outfile={}'.format(flops_log_path)
88 |
89 | # We use the Keras session graph in the call to the profiler.
90 | flops = tf.compat.v1.profiler.profile(graph=graph,
91 | run_meta=run_meta, cmd='op', options=opts)
92 | tf.compat.v1.reset_default_graph()
93 | return flops.total_float_ops
94 |
95 | from os.path import join, dirname, realpath
96 |
97 | log = logging.getLogger(__name__)
98 |
99 | _DESCRIPTION = 'Retrieve statistics of a DeepSpectrumLite transfer learning model.'
100 |
101 | @add_options(
102 | [
103 | click.option(
104 | "-md",
105 | "--model-dir",
106 | type=click.Path(exists=False, writable=True),
107 | help="Path to a DeepSpectrumLite HDF5 model file.",
108 | required=True
109 | )
110 | ]
111 | )
112 |
113 | @click.command(help=_DESCRIPTION)
114 | def stats(model_dir, **kwargs):
115 | tf.compat.v1.enable_eager_execution()
116 | tf.config.run_functions_eagerly(True)
117 | # reset seed values
118 | np.random.seed(0)
119 | tf.compat.v1.set_random_seed(0)
120 |
121 | # new_model = tf.keras.models.load_model(model_dir, custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu}, compile=False)
122 | # new_model.summary(print_fn=log.info)
123 | get_detailed_stats(model_dir)
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/tflite_stats.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import logging
20 | import click
21 | import os
22 | import tensorflow as tf
23 | from .utils import add_options
24 | import numpy as np
25 | import argparse
26 | import time
27 | import h5py
28 | import sys
29 | from deepspectrumlite import AugmentableModel, ARelu
30 |
31 |
32 | def get_detailed_stats(model_h5_path):
33 | session = tf.compat.v1.Session()
34 | graph = tf.compat.v1.get_default_graph()
35 |
36 | with graph.as_default():
37 | with session.as_default():
38 | new_model = tf.keras.models.load_model(model_h5_path, custom_objects={'AugmentableModel': AugmentableModel,
39 | 'ARelu': ARelu}, compile=False)
40 | run_meta = tf.compat.v1.RunMetadata()
41 | input_details = new_model.get_config()
42 | input_shape = input_details['layers'][0]['config']['batch_input_shape']
43 |
44 | _ = session.run(new_model.output, {
45 | 'input_1:0': np.random.normal(size=(1, input_shape[1], input_shape[2], input_shape[3]))},
46 | options=tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE),
47 | run_metadata=run_meta)
48 |
49 | # '''
50 | ProfileOptionBuilder = tf.compat.v1.profiler.ProfileOptionBuilder
51 | opts = ProfileOptionBuilder(ProfileOptionBuilder.time_and_memory()
52 | ).with_step(0).with_timeline_output('test.json').build()
53 |
54 | tf.compat.v1.profiler.profile(
55 | tf.compat.v1.get_default_graph(),
56 | run_meta=run_meta,
57 | cmd='code',
58 | options=opts)
59 | # '''
60 | # Print to stdout an analysis of the memory usage and the timing information
61 | # broken down by operation types.
62 | json_export = tf.compat.v1.profiler.profile(
63 | tf.compat.v1.get_default_graph(),
64 | run_meta=run_meta,
65 | cmd='op',
66 | options=tf.compat.v1.profiler.ProfileOptionBuilder.time_and_memory())
67 |
68 | text_file = open("profiler.json", "w")
69 | text_file.write(str(json_export))
70 | text_file.close()
71 | # print(json_export)
72 | tf.compat.v1.reset_default_graph()
73 |
74 | from os.path import join, dirname, realpath
75 |
76 | log = logging.getLogger(__name__)
77 |
78 | _DESCRIPTION = 'Test a TensorFlow Lite model and retrieve statistics about it.'
79 |
80 | @add_options(
81 | [
82 | click.option(
83 | "-md",
84 | "--model-dir",
85 | type=click.Path(exists=False, writable=True),
86 | help="Path to the TensorFlow Lite model",
87 | required=True
88 | )
89 | ]
90 | )
91 |
92 | @click.command(help=_DESCRIPTION)
93 | def tflite_stats(model_dir, **kwargs):
94 | tf.compat.v1.enable_eager_execution()
95 | tf.config.run_functions_eagerly(True)
96 |
97 | # reset seed values
98 | np.random.seed(0)
99 | tf.compat.v1.set_random_seed(0)
100 |
101 | model_sub_dir = dirname(model_dir)
102 |
103 | interpreter = tf.lite.Interpreter(model_path=model_dir)
104 | input_details = interpreter.get_input_details()
105 | output_details = interpreter.get_output_details()
106 | all_layers_details = interpreter.get_tensor_details()
107 |
108 | interpreter.allocate_tensors()
109 |
110 | f = h5py.File(os.path.join(model_sub_dir, "converted_model_weights_infos.hdf5"), "w")
111 | parameters = 0
112 | for layer in all_layers_details:
113 | # to create a group in an hdf5 file
114 | grp = f.create_group(str(layer['index']))
115 |
116 | # to store layer's metadata in group's metadata
117 | grp.attrs["name"] = layer['name']
118 | grp.attrs["shape"] = layer['shape']
119 | # grp.attrs["dtype"] = all_layers_details[i]['dtype']
120 | grp.attrs["quantization"] = layer['quantization']
121 | weights = interpreter.get_tensor(layer['index'])
122 | # print(weights.size)
123 | parameters += weights.size
124 | # to store the weights in a dataset
125 | grp.create_dataset("weights", data=weights)
126 |
127 | f.close()
128 | log.info(str(parameters))
129 |
130 | # interpreter.set_tensor(input_details[0]['index'], tf.convert_to_tensor(np.expand_dims(audio_data, 0), dtype=tf.float32))
131 |
132 | interpreter.invoke()
133 |
134 | output = interpreter.get_tensor(output_details[0]['index'])
135 |
136 | # Test model on random input data.
137 | input_shape = input_details[0]['shape']
138 | log.info("input shape: ")
139 | log.info(input_shape)
140 | log.info("output shape: ",)
141 | log.info(output_details[0]['shape'])
142 | input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
143 | interpreter.set_tensor(input_details[0]['index'], input_data)
144 | time.sleep(10.0)  # let the system settle before timing
145 | log.info("start")
146 | start_time = time.time()
147 | i = 50
148 | while i > 0:
149 | interpreter.invoke()
150 | i = i - 1
151 | stop_time = time.time()
152 | output_data = interpreter.get_tensor(output_details[0]['index'])
153 |
154 | log.info(output_data)
155 | log.info('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
156 | log.info('mean time: {:.3f}ms'.format((stop_time - start_time) * 1000 / 50))
157 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/train.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import logging
20 | import click
21 | from os import environ
22 | from .utils import add_options
23 | import shutil
24 | import sys
25 | import json
26 | import os
27 | from os.path import join, dirname, realpath
28 | import platform
29 |
30 | log = logging.getLogger(__name__)
31 |
32 | _DESCRIPTION = 'Train a DeepSpectrumLite transfer learning model.'
33 |
34 | @add_options(
35 | [
36 | click.option(
37 | "-d",
38 | "--data-dir",
39 | type=click.Path(exists=True),
40 | help="Directory of the audio data, containing one folder per data class.",
41 | required=True
42 | ),
43 | click.option(
44 | "-md",
45 | "--model-dir",
46 | type=click.Path(exists=False, writable=True),
47 | help="Directory for all training output (logs and final model files).",
48 | required=True
49 | ),
50 | click.option(
51 | "-hc",
52 | "--hyper-config",
53 | type=click.Path(exists=True, writable=False, readable=True),
54 | help="Path to the hyper parameter config file.",
55 | default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
56 | ),
57 | click.option(
58 | "-cc",
59 | "--class-config",
60 | type=click.Path(exists=True, writable=False, readable=True),
61 | help="Path to the class config file.",
62 | default=join(dirname(realpath(__file__)), "config/class_config.json"), show_default=True
63 | ),
64 | click.option(
65 | "-l",
66 | "--label-file",
67 | type=click.Path(exists=True, writable=False, readable=True),
68 | help="Path to the label file.",
69 | required=True
70 | ),
71 | click.option(
72 | "-dc",
73 | "--disable-cache",
74 | is_flag=True, default=False,
75 | help="Disables the in-memory caching."
76 | )
77 | ]
78 | )
79 |
80 | @click.command(help=_DESCRIPTION)
81 | @click.pass_context
82 | def train(ctx, model_dir, data_dir, class_config, hyper_config, label_file, disable_cache, **kwargs):
83 | import tensorflow as tf
84 | # tf.compat.v1.enable_eager_execution()
85 | # tf.config.experimental_run_functions_eagerly(True)
86 | from tensorboard.plugins.hparams import api as hp
87 | import numpy as np
88 | import importlib
89 | from deepspectrumlite import HyperParameterList, TransferBaseModel, DataPipeline, \
90 | METRIC_ACCURACY, METRIC_MAE, METRIC_RMSE, METRIC_RECALL, METRIC_PRECISION, METRIC_F_SCORE, METRIC_LOSS, METRIC_MSE
91 | import math
92 |
93 | verbose = ctx.obj['verbose']
94 |
95 | enable_cache = not disable_cache
96 | data_dir = os.path.join(data_dir, '') # add trailing slash
97 |
98 | f = open(class_config)
99 | data = json.load(f)
100 | f.close()
101 |
102 | data_classes = data
103 |
104 | if data_classes is None:
105 | raise ValueError('no data classes defined')
106 |
107 | tensorboard_initialised = False
108 |
109 | log.info("Physical devices:")
110 | physical_devices = tf.config.experimental.list_physical_devices('GPU')
111 | log.info(physical_devices)
112 | del physical_devices
113 |
114 | hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)
115 |
116 | max_iterations = hyper_parameter_list.get_max_iteration()
117 | log.info('Loaded hyperparameter configuration.')
118 | log.info("Recognised combinations of settings: " + str(max_iterations) + "")
119 |
120 | slurm_jobid = os.getenv('SLURM_ARRAY_TASK_ID')
121 |
122 | if slurm_jobid is not None:
123 | slurm_jobid = int(slurm_jobid)
124 |
125 | if slurm_jobid >= max_iterations:
126 | raise ValueError('slurm jobid ' + str(slurm_jobid) + ' is out of bounds')
127 |
128 | for iteration_no in range(max_iterations):
129 | if slurm_jobid is not None:
130 | iteration_no = slurm_jobid
131 | hparam_values = hyper_parameter_list.get_values(iteration_no=iteration_no)
132 | hparam_values_tensorboard = hyper_parameter_list.get_values_tensorboard(iteration_no=iteration_no)
133 |
134 | run_identifier = hparam_values['tb_run_id'] + '_config_' + str(iteration_no)
135 |
136 | tensorboard_dir = hparam_values['tb_experiment']
137 |
138 | log_dir = os.path.join(model_dir, 'logs', tensorboard_dir)
139 | run_log_dir = os.path.join(log_dir, run_identifier)
140 | run_model_dir = os.path.join(model_dir, 'models', tensorboard_dir, run_identifier)  # keep model_dir itself unchanged across iterations
141 | # delete old log
142 | if os.path.isdir(run_log_dir):
143 | shutil.rmtree(run_log_dir)
144 |
145 | if not tensorboard_initialised:
146 | # create tensorboard
147 | with tf.summary.create_file_writer(log_dir).as_default():
148 | hp.hparams_config(
149 | hparams=hyper_parameter_list.get_hparams(),
150 | metrics=[hp.Metric(METRIC_ACCURACY, display_name='accuracy'),
151 | hp.Metric(METRIC_PRECISION, display_name='precision'),
152 | hp.Metric(METRIC_RECALL, display_name='unweighted recall'),
153 | hp.Metric(METRIC_F_SCORE, display_name='f1 score'),
154 | hp.Metric(METRIC_MAE, display_name='mae'),
155 | hp.Metric(METRIC_RMSE, display_name='rmse')
156 | ],
157 | )
158 | tensorboard_initialised = True
159 |
160 | # Use a label file parser to load data
161 | label_parser_key = hparam_values['label_parser']
162 |
163 | if ":" not in label_parser_key:
164 | raise ValueError('Please provide the parser in the following format: path.to.parser_file.py:ParserClass')
165 |
166 | log.info(f'Using custom external parser: {label_parser_key}')
167 | if platform.system() == "Windows": # need to consider : after drive letter in windows paths
168 | # split
169 | s = label_parser_key.split(":")
170 | assert len(s) == 3
171 | # reconstruct the path from the first two segments (drive letter + rest of the path)
172 | path = s[0] + ":" + s[1]
173 | class_name = s[-1]
174 |
175 | else: # Linux or Mac
176 | path, class_name = label_parser_key.split(':')
177 | module_name = os.path.splitext(os.path.basename(path))[0]
178 | dir_path = os.path.dirname(os.path.realpath(__file__))
179 | path = os.path.join(dir_path, path)
180 | spec = importlib.util.spec_from_file_location(module_name, path)
181 | foo = importlib.util.module_from_spec(spec)
182 | spec.loader.exec_module(foo)
183 | parser_class = getattr(foo, class_name)
184 |
185 | parser = parser_class(file_path=label_file)
186 | train_data, devel_data, test_data = parser.parse_labels()
187 |
188 | # reset seed values to make keras reproducible
189 | np.random.seed(0)
190 | tf.compat.v1.set_random_seed(0)
191 |
192 | log.info('--- Starting trial: %s' % run_identifier)
193 | log.info({h.name: hparam_values_tensorboard[h] for h in hparam_values_tensorboard})
194 |
195 | log.info("Load data pipeline ...")
196 |
197 | ########### TRAIN DATA ###########
198 | train_data_pipeline = DataPipeline(name='train_data_set', data_classes=data_classes,
199 | enable_gpu=True, verbose=True, enable_augmentation=False,
200 | hparams=hparam_values, run_id=iteration_no)
201 | train_data_pipeline.set_data(train_data)
202 | train_data_pipeline.set_filename_prepend(prepend_filename_str=data_dir)
203 | train_data_pipeline.preprocess()
204 | train_data_pipeline.up_sample()
205 | train_dataset = train_data_pipeline.pipeline(cache=enable_cache)
206 |
207 | ########### DEVEL DATA ###########
208 | devel_data_pipeline = DataPipeline(name='devel_data_set', data_classes=data_classes,
209 | enable_gpu=True, verbose=True, enable_augmentation=False,
210 | hparams=hparam_values, run_id=iteration_no)
211 | devel_data_pipeline.set_data(devel_data)
212 | devel_data_pipeline.set_filename_prepend(prepend_filename_str=data_dir)
213 | devel_dataset = devel_data_pipeline.pipeline(cache=enable_cache, shuffle=False, drop_remainder=False)
214 |
215 | ########### TEST DATA ###########
216 | test_data_pipeline = DataPipeline(name='test_data_set', data_classes=data_classes,
217 | enable_gpu=True, verbose=True, enable_augmentation=False,
218 | hparams=hparam_values, run_id=iteration_no)
219 | test_data_pipeline.set_data(test_data)
220 | test_data_pipeline.set_filename_prepend(prepend_filename_str=data_dir)
221 | test_dataset = test_data_pipeline.pipeline(cache=enable_cache, shuffle=False, drop_remainder=False)
222 |
223 | log.info("All data pipelines have been successfully loaded.")
224 | log.info("Caching in memory is: " + str(enable_cache))
225 |
226 | model_name = hparam_values['model_name']
227 |
228 | available_ai_models = {
229 | 'TransferBaseModel': TransferBaseModel
230 | }
231 |
232 | if model_name in available_ai_models:
233 | model = available_ai_models[model_name](hyper_parameter_list,
234 | train_data_pipeline.get_model_input_shape(),
235 | run_dir=run_log_dir,
236 | data_classes=data_classes,
237 | use_ram=True,
238 | run_id=iteration_no,
239 | verbose=verbose)
240 |
241 | model.run(train_dataset=train_dataset,
242 | test_dataset=test_dataset,
243 | devel_dataset=devel_dataset,
244 | save_model=True,
245 | save_dir=run_model_dir)
246 | else:
247 | raise ValueError("Unknown model name: " + model_name)
248 |
249 | if slurm_jobid is not None:
250 | break
251 |
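For clarity, a small sketch of the "path:ClassName" convention followed by the label_parser hyperparameter above; relative paths are resolved against this cli directory. The concrete parser path and label file name are illustrative assumptions, not the shipped hp_config.json.

    import importlib.util
    import os

    # hypothetical hparam value pointing at the bundled ComParE parser
    label_parser_key = "../lib/data/parser/ComParEParser.py:ComParEParser"

    path, class_name = label_parser_key.split(":")
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), path)
    module_name = os.path.splitext(os.path.basename(path))[0]

    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    parser = getattr(module, class_name)(file_path="label.csv")  # label.csv: hypothetical label file
    train_data, devel_data, test_data = parser.parse_labels()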
--------------------------------------------------------------------------------
/src/deepspectrumlite/cli/utils.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | def add_options(options):
20 | def _add_options(func):
21 | for option in reversed(options):
22 | func = option(func)
23 | return func
24 |
25 | return _add_options
26 |
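A minimal usage sketch for add_options, mirroring how the commands above attach their option lists; the greet command and its two options are hypothetical.

    import click

    from deepspectrumlite.cli.utils import add_options  # the helper defined above

    _options = [
        click.option("-n", "--name", default="world", show_default=True, help="Name to greet."),
        click.option("-c", "--count", type=int, default=1, show_default=True, help="Number of greetings."),
    ]

    @add_options(_options)
    @click.command()
    def greet(name, count, **kwargs):
        for _ in range(count):
            click.echo(f"Hello, {name}!")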
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .data.plot import *
2 | from .hyperparameter import HyperParameterList
3 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/__init__.py:
--------------------------------------------------------------------------------
1 | """data handling operations for models"""
2 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/embedded/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/src/deepspectrumlite/lib/data/embedded/__init__.py
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/embedded/preprocessor.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import tensorflow as tf
20 | from tensorflow.python.keras.applications import imagenet_utils
21 | from tensorflow.python.ops.image_ops_impl import ResizeMethod
22 |
23 | from deepspectrumlite import power_to_db, amplitude_to_db
24 | from deepspectrumlite.lib.data.plot import create_map_from_array, CividisColorMap, InfernoColorMap, MagmaColorMap, \
25 | PlasmaColorMap, ViridisColorMap
26 | import numpy as np
27 | import math
28 |
29 |
30 | # TODO refactor
31 | class PreprocessAudio(tf.Module):
32 | def __init__(self, hparams, *args, **kwargs):
33 | super(PreprocessAudio, self).__init__(*args, **kwargs)
34 | self.hparams = hparams
35 |
36 | self.resize_method = ResizeMethod.BILINEAR
37 | if 'resize_method' in self.hparams:
38 | self.resize_method = self.hparams['resize_method']
39 |
40 | self.anti_alias = True
41 | if 'anti_alias' in self.hparams:
42 | self.anti_alias = self.hparams['anti_alias']
43 |
44 | self.colormap = ViridisColorMap()
45 |
46 | available_color_maps = {
47 | "cividis": CividisColorMap,
48 | "inferno": InfernoColorMap,
49 | "magma": MagmaColorMap,
50 | "plasma": PlasmaColorMap,
51 | "viridis": ViridisColorMap
52 | }
53 |
54 | if self.hparams['color_map'] in available_color_maps:
55 | self.colormap = available_color_maps[self.hparams['color_map']]()
56 |
57 | self.preprocessors = {
58 | "vgg16":
59 | tf.keras.applications.vgg16.preprocess_input,
60 | "vgg19":
61 | tf.keras.applications.vgg19.preprocess_input,
62 | "resnet50":
63 | tf.keras.applications.resnet50.preprocess_input,
64 | "xception":
65 | tf.keras.applications.xception.preprocess_input,
66 | "inception_v3":
67 | tf.keras.applications.inception_v3.preprocess_input,
68 | "densenet121":
69 | tf.keras.applications.densenet.preprocess_input,
70 | "densenet169":
71 | tf.keras.applications.densenet.preprocess_input,
72 | "densenet201":
73 | tf.keras.applications.densenet.preprocess_input,
74 | "mobilenet":
75 | tf.keras.applications.mobilenet.preprocess_input,
76 | "mobilenet_v2":
77 | tf.keras.applications.mobilenet_v2.preprocess_input,
78 | "nasnet_large":
79 | tf.keras.applications.nasnet.preprocess_input,
80 | "nasnet_mobile":
81 | tf.keras.applications.nasnet.preprocess_input,
82 | "inception_resnet_v2":
83 | tf.keras.applications.inception_resnet_v2.preprocess_input,
84 | "squeezenet_v1":
85 | tf.keras.applications.imagenet_utils.preprocess_input,
86 | }
87 |
88 | def __preprocess_vgg(self, x, data_format=None):
89 | """
90 | Legacy function for VGG16 and VGG19 preprocessing without centering.
91 | """
92 | x = x[:, :, :, ::-1]
93 | return x
94 |
95 | @tf.function(input_signature=[tf.TensorSpec(shape=(1, 16000), dtype=tf.float32)])
96 | def preprocess(self, audio_signal): # pragma: no cover
97 | decoded_audio = audio_signal * (0.7079 / tf.reduce_max(tf.abs(audio_signal)))
98 |
99 | frame_length = int(self.hparams['stft_window_size'] * self.hparams['sample_rate'])
100 | frame_step = int(self.hparams['stft_hop_size'] * self.hparams['sample_rate'])
101 | fft_length = int(self.hparams['stft_fft_length'] * self.hparams['sample_rate'])
102 |
103 | stfts = tf.signal.stft(decoded_audio, frame_length=frame_length, frame_step=frame_step,
104 | fft_length=fft_length)
105 | spectrograms = tf.abs(stfts, name="magnitude_spectrograms") ** 2
106 |
107 | # Warp the linear scale spectrograms into the mel-scale.
108 | num_spectrogram_bins = stfts.shape[-1]
109 | lower_edge_hertz, upper_edge_hertz, num_mel_bins = self.hparams['lower_edge_hertz'], \
110 | self.hparams['upper_edge_hertz'], \
111 | self.hparams['num_mel_bins']
112 |
113 | linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
114 | num_mel_bins, num_spectrogram_bins, self.hparams['sample_rate'], lower_edge_hertz,
115 | upper_edge_hertz)
116 |
117 | mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
118 | mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
119 | linear_to_mel_weight_matrix.shape[-1:]))
120 | num_mfcc = self.hparams['num_mfccs']
121 |
122 | if num_mfcc:
123 | if self.hparams['db_scale']:
124 | mel_spectrograms = amplitude_to_db(mel_spectrograms, top_db=None)
125 | else:
126 | mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)
127 |
128 | # Compute MFCCs from mel_spectrograms
129 | mfccs = tf.signal.mfccs_from_log_mel_spectrograms(mel_spectrograms)[..., :num_mfcc]
130 |
131 | # cep filter
132 | if self.hparams['cep_lifter'] > 0:
133 | cep_lifter = self.hparams['cep_lifter']
134 |
135 | (nframes, ncoeff) = mfccs.shape[-2], mfccs.shape[-1]
136 | n = tf.keras.backend.arange(start=0, stop=ncoeff, dtype=tf.float32)
137 | lift = 1 + (cep_lifter / 2) * tf.sin(math.pi * n / cep_lifter)
138 |
139 | mfccs *= lift
140 |
141 | output = mfccs
142 | else:
143 | if self.hparams['db_scale']:
144 | output = power_to_db(mel_spectrograms, top_db=None)
145 | else:
146 | output = mel_spectrograms
147 |
148 | if self.hparams['use_plot_images']:
149 |
150 | color_map = np.array(self.colormap.get_color_map(), dtype=np.float32)
151 | # color_map = tf.Variable(initial_value=self.colormap.get_color_map(), name="color_map")
152 |
153 | image_data = create_map_from_array(output, color_map=color_map)
154 |
155 | image_data = tf.image.resize(
156 | image_data, (self.hparams['image_width'], self.hparams['image_height']),
157 | method=self.resize_method, preserve_aspect_ratio=False, antialias=False  # self.anti_alias is configured above but not applied here
158 | )
159 |
160 | image_data = image_data * 255.
161 | image_data = tf.clip_by_value(image_data, clip_value_min=0., clip_value_max=255.)
162 | image_data = tf.image.rot90(image_data, k=1)
163 |
164 | def _preprocess(x):
165 | # values in the range [0, 255] are expected!!
166 | model_key = self.hparams['basemodel_name']
167 |
168 | if model_key in self.preprocessors:
169 | return self.preprocessors[model_key](x, data_format='channels_last')
170 |
171 | return x
172 |
173 | # values in the range [0, 255] are expected!!
174 | image_data = _preprocess(image_data)
175 |
176 | else:
177 | image_data = tf.expand_dims(output, axis=3)
178 |
179 | return image_data
180 |
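A usage sketch for PreprocessAudio, assuming a plausible hparams dictionary; the keys are the ones read in __init__ and preprocess above, but the concrete values are illustrative only.

    import numpy as np
    import tensorflow as tf

    hparams = {
        'sample_rate': 16000,
        'stft_window_size': 0.128,  # window, hop and FFT length in seconds
        'stft_hop_size': 0.064,
        'stft_fft_length': 0.128,
        'lower_edge_hertz': 0.0,
        'upper_edge_hertz': 8000.0,
        'num_mel_bins': 128,
        'num_mfccs': 0,             # falsy -> keep the (mel) spectrogram instead of MFCCs
        'cep_lifter': 0,
        'db_scale': True,
        'use_plot_images': True,
        'color_map': 'viridis',
        'image_width': 224,
        'image_height': 224,
        'basemodel_name': 'mobilenet_v2',
    }

    module = PreprocessAudio(hparams=hparams, name='preprocess')
    audio = tf.constant(np.random.uniform(-1.0, 1.0, size=(1, 16000)), dtype=tf.float32)
    image = module.preprocess(audio)  # colour-mapped, resized spectrogram image batch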
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/parser/ComParEParser.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import re
20 | import pandas as pd
21 | import numpy as np
22 |
23 |
24 | class ComParEParser: # pragma: no cover
25 |
26 | def __init__(self, file_path: str, delimiter=','): # pragma: no cover
27 | self._file_path = file_path
28 | self._delimiter = delimiter
29 |
30 | def parse_labels(self): # pragma: no cover
31 | complete = pd.read_csv(self._file_path, sep=self._delimiter)
32 | complete.columns = ['filename', 'label', 'duration_frames']
33 |
34 | train_data = complete[complete.filename.str.startswith('train')] # 1-3
35 | devel_data = complete[complete.filename.str.startswith('devel')] # 4
36 | test_data = complete[complete.filename.str.startswith('test')] # 5
37 |
38 | return train_data, devel_data, test_data
39 |
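A small sketch of the label CSV that ComParEParser expects and how it is used; the rows below are made up for illustration, and the partition is taken from the filename prefix (train/devel/test).

    # hypothetical label.csv (the header row is renamed to filename,label,duration_frames on load):
    #
    #   filename,label,duration_frames
    #   train_dog.wav,dog,16000
    #   devel_seagull.wav,seagull,16000
    #   test_dog.wav,dog,16000

    parser = ComParEParser(file_path="label.csv", delimiter=",")
    train_data, devel_data, test_data = parser.parse_labels()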
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/parser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/src/deepspectrumlite/lib/data/parser/__init__.py
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/__init__.py:
--------------------------------------------------------------------------------
1 | """ implementation of our spectrogram image plotting """
2 | from .colormap import *
3 | from .color_maps import *
4 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/color_maps/__init__.py:
--------------------------------------------------------------------------------
1 | """implementation of various color maps"""
2 | from .abstract_colormap import AbstractColorMap
3 | from .cividis import *
4 | from .inferno import *
5 | from .magma import *
6 | from .plasma import *
7 | from .viridis import *
8 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/color_maps/abstract_colormap.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 |
4 | class AbstractColorMap(object):
5 | __metaclass__ = abc.ABCMeta
6 |
7 | def __init__(self):
8 | self.color_map = None
9 |
10 | def set_color_map(self, color_map):
11 | self.color_map = color_map
12 |
13 | def get_color_map(self):
14 | return self.color_map
15 |
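A minimal sketch of adding a further map on top of AbstractColorMap; the grayscale table below is a hypothetical stand-in for the dense RGB tables used by the bundled maps (cividis, inferno, magma, plasma, viridis).

    class GrayscaleColorMap(AbstractColorMap):
        """Hypothetical map: equal R, G and B values from black to white."""

        def __init__(self):
            super().__init__()
            # the bundled maps provide a dense list of [r, g, b] rows in [0, 1]
            self.set_color_map([[i / 255.0] * 3 for i in range(256)])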
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/color_maps/cividis.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from .abstract_colormap import AbstractColorMap
3 |
4 |
5 | class CividisColorMap(AbstractColorMap):
6 | __metaclass__ = abc.ABCMeta
7 |
8 | def __init__(self):
9 | super().__init__()
10 | self.set_color_map([
11 | [0.0, 0.135112, 0.304751],
12 | [0.0, 0.138068, 0.311105],
13 | [0.0, 0.141013, 0.317579],
14 | [0.0, 0.143951, 0.323982],
15 | [0.0, 0.146877, 0.330479],
16 | [0.0, 0.149791, 0.337065],
17 | [0.0, 0.152673, 0.343704],
18 | [0.0, 0.155377, 0.3505],
19 | [0.0, 0.157932, 0.357521],
20 | [0.0, 0.160495, 0.364534],
21 | [0.0, 0.163058, 0.371608],
22 | [0.0, 0.165621, 0.378769],
23 | [0.0, 0.168204, 0.385902],
24 | [0.0, 0.1708, 0.3931],
25 | [0.0, 0.17342, 0.400353],
26 | [0.0, 0.176082, 0.407577],
27 | [0.0, 0.178802, 0.414764],
28 | [0.0, 0.18161, 0.421859],
29 | [0.0, 0.18455, 0.428802],
30 | [0.0, 0.186915, 0.435532],
31 | [0.0, 0.188769, 0.439563],
32 | [0.0, 0.19095, 0.441085],
33 | [0.0, 0.193366, 0.441561],
34 | [0.003602, 0.195911, 0.441564],
35 | [0.017852, 0.198528, 0.441248],
36 | [0.03211, 0.201199, 0.440785],
37 | [0.046205, 0.203903, 0.440196],
38 | [0.058378, 0.206629, 0.439531],
39 | [0.068968, 0.209372, 0.438863],
40 | [0.078624, 0.212122, 0.438105],
41 | [0.087465, 0.214879, 0.437342],
42 | [0.095645, 0.217643, 0.436593],
43 | [0.103401, 0.220406, 0.43579],
44 | [0.110658, 0.22317, 0.435067],
45 | [0.117612, 0.225935, 0.434308],
46 | [0.124291, 0.228697, 0.433547],
47 | [0.130669, 0.231458, 0.43284],
48 | [0.13683, 0.234216, 0.432148],
49 | [0.142852, 0.236972, 0.431404],
50 | [0.148638, 0.239724, 0.430752],
51 | [0.154261, 0.242475, 0.43012],
52 | [0.159733, 0.245221, 0.429528],
53 | [0.165113, 0.247965, 0.428908],
54 | [0.170362, 0.250707, 0.428325],
55 | [0.17549, 0.253444, 0.42779],
56 | [0.180503, 0.25618, 0.427299],
57 | [0.185453, 0.258914, 0.426788],
58 | [0.190303, 0.261644, 0.426329],
59 | [0.195057, 0.264372, 0.425924],
60 | [0.199764, 0.267099, 0.425497],
61 | [0.204385, 0.269823, 0.425126],
62 | [0.208926, 0.272546, 0.424809],
63 | [0.213431, 0.275266, 0.42448],
64 | [0.217863, 0.277985, 0.424206],
65 | [0.222264, 0.280702, 0.423914],
66 | [0.226598, 0.283419, 0.423678],
67 | [0.230871, 0.286134, 0.423498],
68 | [0.23512, 0.288848, 0.423304],
69 | [0.239312, 0.291562, 0.423167],
70 | [0.243485, 0.294274, 0.423014],
71 | [0.247605, 0.296986, 0.422917],
72 | [0.251675, 0.299698, 0.422873],
73 | [0.255731, 0.302409, 0.422814],
74 | [0.25974, 0.30512, 0.42281],
75 | [0.263738, 0.307831, 0.422789],
76 | [0.267693, 0.310542, 0.422821],
77 | [0.271639, 0.313253, 0.422837],
78 | [0.275513, 0.315965, 0.422979],
79 | [0.279411, 0.318677, 0.423031],
80 | [0.28324, 0.32139, 0.423211],
81 | [0.287065, 0.324103, 0.423373],
82 | [0.290884, 0.326816, 0.423517],
83 | [0.294669, 0.329531, 0.423716],
84 | [0.298421, 0.332247, 0.423973],
85 | [0.302169, 0.334963, 0.424213],
86 | [0.305886, 0.337681, 0.424512],
87 | [0.309601, 0.340399, 0.42479],
88 | [0.313287, 0.34312, 0.42512],
89 | [0.316941, 0.345842, 0.425512],
90 | [0.320595, 0.348565, 0.425889],
91 | [0.32425, 0.351289, 0.42625],
92 | [0.327875, 0.354016, 0.42667],
93 | [0.331474, 0.356744, 0.427144],
94 | [0.335073, 0.359474, 0.427605],
95 | [0.338673, 0.362206, 0.428053],
96 | [0.342246, 0.364939, 0.428559],
97 | [0.345793, 0.367676, 0.429127],
98 | [0.349341, 0.370414, 0.429685],
99 | [0.352892, 0.373153, 0.430226],
100 | [0.356418, 0.375896, 0.430823],
101 | [0.359916, 0.378641, 0.431501],
102 | [0.363446, 0.381388, 0.432075],
103 | [0.366923, 0.384139, 0.432796],
104 | [0.37043, 0.38689, 0.433428],
105 | [0.373884, 0.389646, 0.434209],
106 | [0.377371, 0.392404, 0.43489],
107 | [0.38083, 0.395164, 0.435653],
108 | [0.384268, 0.397928, 0.436475],
109 | [0.387705, 0.400694, 0.437305],
110 | [0.391151, 0.403464, 0.438096],
111 | [0.394568, 0.406236, 0.438986],
112 | [0.397991, 0.409011, 0.439848],
113 | [0.401418, 0.41179, 0.440708],
114 | [0.40482, 0.414572, 0.441642],
115 | [0.408226, 0.417357, 0.44257],
116 | [0.411607, 0.420145, 0.443577],
117 | [0.414992, 0.422937, 0.444578],
118 | [0.418383, 0.425733, 0.44556],
119 | [0.421748, 0.428531, 0.44664],
120 | [0.42512, 0.431334, 0.447692],
121 | [0.428462, 0.43414, 0.448864],
122 | [0.431817, 0.43695, 0.449982],
123 | [0.435168, 0.439763, 0.451134],
124 | [0.438504, 0.44258, 0.452341],
125 | [0.44181, 0.445402, 0.453659],
126 | [0.445148, 0.448226, 0.454885],
127 | [0.448447, 0.451053, 0.456264],
128 | [0.451759, 0.453887, 0.457582],
129 | [0.455072, 0.456718, 0.458976],
130 | [0.458366, 0.459552, 0.460457],
131 | [0.461616, 0.462405, 0.461969],
132 | [0.464947, 0.465241, 0.463395],
133 | [0.468254, 0.468083, 0.464908],
134 | [0.471501, 0.47096, 0.466357],
135 | [0.474812, 0.473832, 0.467681],
136 | [0.478186, 0.476699, 0.468845],
137 | [0.481622, 0.479573, 0.469767],
138 | [0.485141, 0.482451, 0.470384],
139 | [0.488697, 0.485318, 0.471008],
140 | [0.492278, 0.488198, 0.471453],
141 | [0.495913, 0.491076, 0.471751],
142 | [0.499552, 0.49396, 0.472032],
143 | [0.503185, 0.496851, 0.472305],
144 | [0.506866, 0.499743, 0.472432],
145 | [0.51054, 0.502643, 0.47255],
146 | [0.514226, 0.505546, 0.47264],
147 | [0.51792, 0.508454, 0.472707],
148 | [0.521643, 0.511367, 0.472639],
149 | [0.525348, 0.514285, 0.47266],
150 | [0.529086, 0.517207, 0.472543],
151 | [0.532829, 0.520135, 0.472401],
152 | [0.536553, 0.523067, 0.472352],
153 | [0.540307, 0.526005, 0.472163],
154 | [0.544069, 0.528948, 0.471947],
155 | [0.54784, 0.531895, 0.471704],
156 | [0.551612, 0.534849, 0.471439],
157 | [0.555393, 0.537807, 0.471147],
158 | [0.559181, 0.540771, 0.470829],
159 | [0.562972, 0.543741, 0.470488],
160 | [0.566802, 0.546715, 0.469988],
161 | [0.570607, 0.549695, 0.469593],
162 | [0.574417, 0.552682, 0.469172],
163 | [0.578236, 0.555673, 0.468724],
164 | [0.582087, 0.55867, 0.468118],
165 | [0.585916, 0.561674, 0.467618],
166 | [0.589753, 0.564682, 0.46709],
167 | [0.593622, 0.567697, 0.466401],
168 | [0.597469, 0.570718, 0.465821],
169 | [0.601354, 0.573743, 0.465074],
170 | [0.605211, 0.576777, 0.464441],
171 | [0.609105, 0.579816, 0.463638],
172 | [0.612977, 0.582861, 0.46295],
173 | [0.616852, 0.585913, 0.462237],
174 | [0.620765, 0.58897, 0.461351],
175 | [0.624654, 0.592034, 0.460583],
176 | [0.628576, 0.595104, 0.459641],
177 | [0.632506, 0.59818, 0.458668],
178 | [0.636412, 0.601264, 0.457818],
179 | [0.640352, 0.604354, 0.456791],
180 | [0.64427, 0.60745, 0.455886],
181 | [0.648222, 0.610553, 0.454801],
182 | [0.652178, 0.613664, 0.453689],
183 | [0.656114, 0.61678, 0.452702],
184 | [0.660082, 0.619904, 0.451534],
185 | [0.664055, 0.623034, 0.450338],
186 | [0.668008, 0.626171, 0.44927],
187 | [0.671991, 0.629316, 0.448018],
188 | [0.675981, 0.632468, 0.446736],
189 | [0.679979, 0.635626, 0.445424],
190 | [0.68395, 0.638793, 0.444251],
191 | [0.687957, 0.641966, 0.442886],
192 | [0.691971, 0.645145, 0.441491],
193 | [0.695985, 0.648334, 0.440072],
194 | [0.700008, 0.651529, 0.438624],
195 | [0.704037, 0.654731, 0.437147],
196 | [0.708067, 0.657942, 0.435647],
197 | [0.712105, 0.66116, 0.434117],
198 | [0.716177, 0.664384, 0.432386],
199 | [0.720222, 0.667618, 0.430805],
200 | [0.724274, 0.670859, 0.429194],
201 | [0.728334, 0.674107, 0.427554],
202 | [0.732422, 0.677364, 0.425717],
203 | [0.736488, 0.680629, 0.424028],
204 | [0.740589, 0.6839, 0.422131],
205 | [0.744664, 0.687181, 0.420393],
206 | [0.748772, 0.69047, 0.418448],
207 | [0.752886, 0.693766, 0.416472],
208 | [0.756975, 0.697071, 0.414659],
209 | [0.761096, 0.700384, 0.412638],
210 | [0.765223, 0.703705, 0.410587],
211 | [0.769353, 0.707035, 0.408516],
212 | [0.773486, 0.710373, 0.406422],
213 | [0.777651, 0.713719, 0.404112],
214 | [0.781795, 0.717074, 0.401966],
215 | [0.785965, 0.720438, 0.399613],
216 | [0.790116, 0.72381, 0.397423],
217 | [0.794298, 0.72719, 0.395016],
218 | [0.79848, 0.73058, 0.392597],
219 | [0.802667, 0.733978, 0.390153],
220 | [0.806859, 0.737385, 0.387684],
221 | [0.811054, 0.740801, 0.385198],
222 | [0.815274, 0.744226, 0.382504],
223 | [0.819499, 0.747659, 0.379785],
224 | [0.823729, 0.751101, 0.377043],
225 | [0.827959, 0.754553, 0.374292],
226 | [0.832192, 0.758014, 0.371529],
227 | [0.836429, 0.761483, 0.368747],
228 | [0.840693, 0.764962, 0.365746],
229 | [0.844957, 0.76845, 0.362741],
230 | [0.849223, 0.771947, 0.359729],
231 | [0.853515, 0.775454, 0.3565],
232 | [0.857809, 0.778969, 0.353259],
233 | [0.862105, 0.782494, 0.350011],
234 | [0.866421, 0.786028, 0.346571],
235 | [0.870717, 0.789572, 0.343333],
236 | [0.875057, 0.793125, 0.339685],
237 | [0.879378, 0.796687, 0.336241],
238 | [0.88372, 0.800258, 0.332599],
239 | [0.888081, 0.803839, 0.32877],
240 | [0.89244, 0.80743, 0.324968],
241 | [0.896818, 0.81103, 0.320982],
242 | [0.901195, 0.814639, 0.317021],
243 | [0.905589, 0.818257, 0.312889],
244 | [0.91, 0.821885, 0.308594],
245 | [0.914407, 0.825522, 0.304348],
246 | [0.918828, 0.829168, 0.29996],
247 | [0.923279, 0.832822, 0.295244],
248 | [0.927724, 0.836486, 0.290611],
249 | [0.93218, 0.840159, 0.28588],
250 | [0.93666, 0.843841, 0.280876],
251 | [0.941147, 0.84753, 0.275815],
252 | [0.945654, 0.851228, 0.270532],
253 | [0.950178, 0.854933, 0.265085],
254 | [0.954725, 0.858646, 0.259365],
255 | [0.959284, 0.862365, 0.253563],
256 | [0.963872, 0.866089, 0.247445],
257 | [0.968469, 0.869819, 0.24131],
258 | [0.973114, 0.87355, 0.234677],
259 | [0.97778, 0.877281, 0.227954],
260 | [0.982497, 0.881008, 0.220878],
261 | [0.987293, 0.884718, 0.213336],
262 | [0.992218, 0.888385, 0.205468],
263 | [0.994847, 0.892954, 0.203445],
264 | [0.995249, 0.898384, 0.207561],
265 | [0.995503, 0.903866, 0.21237],
266 | [0.995737, 0.909344, 0.217772]
267 | ])
268 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/color_maps/inferno.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from .abstract_colormap import AbstractColorMap
3 |
4 | class InfernoColorMap(AbstractColorMap):
5 | __metaclass__ = abc.ABCMeta
6 |
7 | def __init__(self):
8 | super().__init__()
9 | self.set_color_map([
10 | [0.001462, 0.000466, 0.013866],
11 | [0.002267, 0.00127, 0.01857],
12 | [0.003299, 0.002249, 0.024239],
13 | [0.004547, 0.003392, 0.030909],
14 | [0.006006, 0.004692, 0.038558],
15 | [0.007676, 0.006136, 0.046836],
16 | [0.009561, 0.007713, 0.055143],
17 | [0.011663, 0.009417, 0.06346],
18 | [0.013995, 0.011225, 0.071862],
19 | [0.016561, 0.013136, 0.080282],
20 | [0.019373, 0.015133, 0.088767],
21 | [0.022447, 0.017199, 0.097327],
22 | [0.025793, 0.019331, 0.10593],
23 | [0.029432, 0.021503, 0.114621],
24 | [0.033385, 0.023702, 0.123397],
25 | [0.037668, 0.025921, 0.132232],
26 | [0.042253, 0.028139, 0.141141],
27 | [0.046915, 0.030324, 0.150164],
28 | [0.051644, 0.032474, 0.159254],
29 | [0.056449, 0.034569, 0.168414],
30 | [0.06134, 0.03659, 0.177642],
31 | [0.066331, 0.038504, 0.186962],
32 | [0.071429, 0.040294, 0.196354],
33 | [0.076637, 0.041905, 0.205799],
34 | [0.081962, 0.043328, 0.215289],
35 | [0.087411, 0.044556, 0.224813],
36 | [0.09299, 0.045583, 0.234358],
37 | [0.098702, 0.046402, 0.243904],
38 | [0.104551, 0.047008, 0.25343],
39 | [0.110536, 0.047399, 0.262912],
40 | [0.116656, 0.047574, 0.272321],
41 | [0.122908, 0.047536, 0.281624],
42 | [0.129285, 0.047293, 0.290788],
43 | [0.135778, 0.046856, 0.299776],
44 | [0.142378, 0.046242, 0.308553],
45 | [0.149073, 0.045468, 0.317085],
46 | [0.15585, 0.044559, 0.325338],
47 | [0.162689, 0.043554, 0.333277],
48 | [0.169575, 0.042489, 0.340874],
49 | [0.176493, 0.041402, 0.348111],
50 | [0.183429, 0.040329, 0.354971],
51 | [0.190367, 0.039309, 0.361447],
52 | [0.197297, 0.0384, 0.367535],
53 | [0.204209, 0.037632, 0.373238],
54 | [0.211095, 0.03703, 0.378563],
55 | [0.217949, 0.036615, 0.383522],
56 | [0.224763, 0.036405, 0.388129],
57 | [0.231538, 0.036405, 0.3924],
58 | [0.238273, 0.036621, 0.396353],
59 | [0.244967, 0.037055, 0.400007],
60 | [0.25162, 0.037705, 0.403378],
61 | [0.258234, 0.038571, 0.406485],
62 | [0.26481, 0.039647, 0.409345],
63 | [0.271347, 0.040922, 0.411976],
64 | [0.27785, 0.042353, 0.414392],
65 | [0.284321, 0.043933, 0.416608],
66 | [0.290763, 0.045644, 0.418637],
67 | [0.297178, 0.04747, 0.420491],
68 | [0.303568, 0.049396, 0.422182],
69 | [0.309935, 0.051407, 0.423721],
70 | [0.316282, 0.05349, 0.425116],
71 | [0.32261, 0.055634, 0.426377],
72 | [0.328921, 0.057827, 0.427511],
73 | [0.335217, 0.06006, 0.428524],
74 | [0.3415, 0.062325, 0.429425],
75 | [0.347771, 0.064616, 0.430217],
76 | [0.354032, 0.066925, 0.430906],
77 | [0.360284, 0.069247, 0.431497],
78 | [0.366529, 0.071579, 0.431994],
79 | [0.372768, 0.073915, 0.4324],
80 | [0.379001, 0.076253, 0.432719],
81 | [0.385228, 0.078591, 0.432955],
82 | [0.391453, 0.080927, 0.433109],
83 | [0.397674, 0.083257, 0.433183],
84 | [0.403894, 0.08558, 0.433179],
85 | [0.410113, 0.087896, 0.433098],
86 | [0.416331, 0.090203, 0.432943],
87 | [0.422549, 0.092501, 0.432714],
88 | [0.428768, 0.09479, 0.432412],
89 | [0.434987, 0.097069, 0.432039],
90 | [0.441207, 0.099338, 0.431594],
91 | [0.447428, 0.101597, 0.43108],
92 | [0.453651, 0.103848, 0.430498],
93 | [0.459875, 0.106089, 0.429846],
94 | [0.4661, 0.108322, 0.429125],
95 | [0.472328, 0.110547, 0.428334],
96 | [0.478558, 0.112764, 0.427475],
97 | [0.484789, 0.114974, 0.426548],
98 | [0.491022, 0.117179, 0.425552],
99 | [0.497257, 0.119379, 0.424488],
100 | [0.503493, 0.121575, 0.423356],
101 | [0.50973, 0.123769, 0.422156],
102 | [0.515967, 0.12596, 0.420887],
103 | [0.522206, 0.12815, 0.419549],
104 | [0.528444, 0.130341, 0.418142],
105 | [0.534683, 0.132534, 0.416667],
106 | [0.54092, 0.134729, 0.415123],
107 | [0.547157, 0.136929, 0.413511],
108 | [0.553392, 0.139134, 0.411829],
109 | [0.559624, 0.141346, 0.410078],
110 | [0.565854, 0.143567, 0.408258],
111 | [0.572081, 0.145797, 0.406369],
112 | [0.578304, 0.148039, 0.404411],
113 | [0.584521, 0.150294, 0.402385],
114 | [0.590734, 0.152563, 0.40029],
115 | [0.59694, 0.154848, 0.398125],
116 | [0.603139, 0.157151, 0.395891],
117 | [0.60933, 0.159474, 0.393589],
118 | [0.615513, 0.161817, 0.391219],
119 | [0.621685, 0.164184, 0.388781],
120 | [0.627847, 0.166575, 0.386276],
121 | [0.633998, 0.168992, 0.383704],
122 | [0.640135, 0.171438, 0.381065],
123 | [0.64626, 0.173914, 0.378359],
124 | [0.652369, 0.176421, 0.375586],
125 | [0.658463, 0.178962, 0.372748],
126 | [0.66454, 0.181539, 0.369846],
127 | [0.670599, 0.184153, 0.366879],
128 | [0.676638, 0.186807, 0.363849],
129 | [0.682656, 0.189501, 0.360757],
130 | [0.688653, 0.192239, 0.357603],
131 | [0.694627, 0.195021, 0.354388],
132 | [0.700576, 0.197851, 0.351113],
133 | [0.7065, 0.200728, 0.347777],
134 | [0.712396, 0.203656, 0.344383],
135 | [0.718264, 0.206636, 0.340931],
136 | [0.724103, 0.20967, 0.337424],
137 | [0.729909, 0.212759, 0.333861],
138 | [0.735683, 0.215906, 0.330245],
139 | [0.741423, 0.219112, 0.326576],
140 | [0.747127, 0.222378, 0.322856],
141 | [0.752794, 0.225706, 0.319085],
142 | [0.758422, 0.229097, 0.315266],
143 | [0.76401, 0.232554, 0.311399],
144 | [0.769556, 0.236077, 0.307485],
145 | [0.775059, 0.239667, 0.303526],
146 | [0.780517, 0.243327, 0.299523],
147 | [0.785929, 0.247056, 0.295477],
148 | [0.791293, 0.250856, 0.29139],
149 | [0.796607, 0.254728, 0.287264],
150 | [0.801871, 0.258674, 0.283099],
151 | [0.807082, 0.262692, 0.278898],
152 | [0.812239, 0.266786, 0.274661],
153 | [0.817341, 0.270954, 0.27039],
154 | [0.822386, 0.275197, 0.266085],
155 | [0.827372, 0.279517, 0.26175],
156 | [0.832299, 0.283913, 0.257383],
157 | [0.837165, 0.288385, 0.252988],
158 | [0.841969, 0.292933, 0.248564],
159 | [0.846709, 0.297559, 0.244113],
160 | [0.851384, 0.30226, 0.239636],
161 | [0.855992, 0.307038, 0.235133],
162 | [0.860533, 0.311892, 0.230606],
163 | [0.865006, 0.316822, 0.226055],
164 | [0.869409, 0.321827, 0.221482],
165 | [0.873741, 0.326906, 0.216886],
166 | [0.878001, 0.33206, 0.212268],
167 | [0.882188, 0.337287, 0.207628],
168 | [0.886302, 0.342586, 0.202968],
169 | [0.890341, 0.347957, 0.198286],
170 | [0.894305, 0.353399, 0.193584],
171 | [0.898192, 0.358911, 0.18886],
172 | [0.902003, 0.364492, 0.184116],
173 | [0.905735, 0.37014, 0.17935],
174 | [0.90939, 0.375856, 0.174563],
175 | [0.912966, 0.381636, 0.169755],
176 | [0.916462, 0.387481, 0.164924],
177 | [0.919879, 0.393389, 0.16007],
178 | [0.923215, 0.399359, 0.155193],
179 | [0.92647, 0.405389, 0.150292],
180 | [0.929644, 0.411479, 0.145367],
181 | [0.932737, 0.417627, 0.140417],
182 | [0.935747, 0.423831, 0.13544],
183 | [0.938675, 0.430091, 0.130438],
184 | [0.941521, 0.436405, 0.125409],
185 | [0.944285, 0.442772, 0.120354],
186 | [0.946965, 0.449191, 0.115272],
187 | [0.949562, 0.45566, 0.110164],
188 | [0.952075, 0.462178, 0.105031],
189 | [0.954506, 0.468744, 0.099874],
190 | [0.956852, 0.475356, 0.094695],
191 | [0.959114, 0.482014, 0.089499],
192 | [0.961293, 0.488716, 0.084289],
193 | [0.963387, 0.495462, 0.079073],
194 | [0.965397, 0.502249, 0.073859],
195 | [0.967322, 0.509078, 0.068659],
196 | [0.969163, 0.515946, 0.063488],
197 | [0.970919, 0.522853, 0.058367],
198 | [0.97259, 0.529798, 0.053324],
199 | [0.974176, 0.53678, 0.048392],
200 | [0.975677, 0.543798, 0.043618],
201 | [0.977092, 0.55085, 0.03905],
202 | [0.978422, 0.557937, 0.034931],
203 | [0.979666, 0.565057, 0.031409],
204 | [0.980824, 0.572209, 0.028508],
205 | [0.981895, 0.579392, 0.02625],
206 | [0.982881, 0.586606, 0.024661],
207 | [0.983779, 0.593849, 0.02377],
208 | [0.984591, 0.601122, 0.023606],
209 | [0.985315, 0.608422, 0.024202],
210 | [0.985952, 0.61575, 0.025592],
211 | [0.986502, 0.623105, 0.027814],
212 | [0.986964, 0.630485, 0.030908],
213 | [0.987337, 0.63789, 0.034916],
214 | [0.987622, 0.64532, 0.039886],
215 | [0.987819, 0.652773, 0.045581],
216 | [0.987926, 0.66025, 0.05175],
217 | [0.987945, 0.667748, 0.058329],
218 | [0.987874, 0.675267, 0.065257],
219 | [0.987714, 0.682807, 0.072489],
220 | [0.987464, 0.690366, 0.07999],
221 | [0.987124, 0.697944, 0.087731],
222 | [0.986694, 0.70554, 0.095694],
223 | [0.986175, 0.713153, 0.103863],
224 | [0.985566, 0.720782, 0.112229],
225 | [0.984865, 0.728427, 0.120785],
226 | [0.984075, 0.736087, 0.129527],
227 | [0.983196, 0.743758, 0.138453],
228 | [0.982228, 0.751442, 0.147565],
229 | [0.981173, 0.759135, 0.156863],
230 | [0.980032, 0.766837, 0.166353],
231 | [0.978806, 0.774545, 0.176037],
232 | [0.977497, 0.782258, 0.185923],
233 | [0.976108, 0.789974, 0.196018],
234 | [0.974638, 0.797692, 0.206332],
235 | [0.973088, 0.805409, 0.216877],
236 | [0.971468, 0.813122, 0.227658],
237 | [0.969783, 0.820825, 0.238686],
238 | [0.968041, 0.828515, 0.249972],
239 | [0.966243, 0.836191, 0.261534],
240 | [0.964394, 0.843848, 0.273391],
241 | [0.962517, 0.851476, 0.285546],
242 | [0.960626, 0.859069, 0.29801],
243 | [0.95872, 0.866624, 0.31082],
244 | [0.956834, 0.874129, 0.323974],
245 | [0.954997, 0.881569, 0.337475],
246 | [0.953215, 0.888942, 0.351369],
247 | [0.951546, 0.896226, 0.365627],
248 | [0.950018, 0.903409, 0.380271],
249 | [0.948683, 0.910473, 0.395289],
250 | [0.947594, 0.917399, 0.410665],
251 | [0.946809, 0.924168, 0.426373],
252 | [0.946392, 0.930761, 0.442367],
253 | [0.946403, 0.937159, 0.458592],
254 | [0.946903, 0.943348, 0.47497],
255 | [0.947937, 0.949318, 0.491426],
256 | [0.949545, 0.955063, 0.50786],
257 | [0.95174, 0.960587, 0.524203],
258 | [0.954529, 0.965896, 0.540361],
259 | [0.957896, 0.971003, 0.556275],
260 | [0.961812, 0.975924, 0.571925],
261 | [0.966249, 0.980678, 0.587206],
262 | [0.971162, 0.985282, 0.602154],
263 | [0.976511, 0.989753, 0.61676],
264 | [0.982257, 0.994109, 0.631017],
265 | [0.988362, 0.998364, 0.644924]
266 | ])
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/color_maps/magma.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from .abstract_colormap import AbstractColorMap
3 |
4 | class MagmaColorMap(AbstractColorMap):
5 | __metaclass__ = abc.ABCMeta
6 |
7 | def __init__(self):
8 | super().__init__()
9 | self.set_color_map([
10 | [0.001462, 0.000466, 0.013866],
11 | [0.002258, 0.001295, 0.018331],
12 | [0.003279, 0.002305, 0.023708],
13 | [0.004512, 0.00349, 0.029965],
14 | [0.00595, 0.004843, 0.03713],
15 | [0.007588, 0.006356, 0.044973],
16 | [0.009426, 0.008022, 0.052844],
17 | [0.011465, 0.009828, 0.06075],
18 | [0.013708, 0.011771, 0.068667],
19 | [0.016156, 0.01384, 0.076603],
20 | [0.018815, 0.016026, 0.084584],
21 | [0.021692, 0.01832, 0.09261],
22 | [0.024792, 0.020715, 0.100676],
23 | [0.028123, 0.023201, 0.108787],
24 | [0.031696, 0.025765, 0.116965],
25 | [0.03552, 0.028397, 0.125209],
26 | [0.039608, 0.03109, 0.133515],
27 | [0.04383, 0.03383, 0.141886],
28 | [0.048062, 0.036607, 0.150327],
29 | [0.05232, 0.039407, 0.158841],
30 | [0.056615, 0.04216, 0.167446],
31 | [0.060949, 0.044794, 0.176129],
32 | [0.06533, 0.047318, 0.184892],
33 | [0.069764, 0.049726, 0.193735],
34 | [0.074257, 0.052017, 0.20266],
35 | [0.078815, 0.054184, 0.211667],
36 | [0.083446, 0.056225, 0.220755],
37 | [0.088155, 0.058133, 0.229922],
38 | [0.092949, 0.059904, 0.239164],
39 | [0.097833, 0.061531, 0.248477],
40 | [0.102815, 0.06301, 0.257854],
41 | [0.107899, 0.064335, 0.267289],
42 | [0.113094, 0.065492, 0.276784],
43 | [0.118405, 0.066479, 0.286321],
44 | [0.123833, 0.067295, 0.295879],
45 | [0.12938, 0.067935, 0.305443],
46 | [0.135053, 0.068391, 0.315],
47 | [0.140858, 0.068654, 0.324538],
48 | [0.146785, 0.068738, 0.334011],
49 | [0.152839, 0.068637, 0.343404],
50 | [0.159018, 0.068354, 0.352688],
51 | [0.165308, 0.067911, 0.361816],
52 | [0.171713, 0.067305, 0.370771],
53 | [0.178212, 0.066576, 0.379497],
54 | [0.184801, 0.065732, 0.387973],
55 | [0.19146, 0.064818, 0.396152],
56 | [0.198177, 0.063862, 0.404009],
57 | [0.204935, 0.062907, 0.411514],
58 | [0.211718, 0.061992, 0.418647],
59 | [0.218512, 0.061158, 0.425392],
60 | [0.225302, 0.060445, 0.431742],
61 | [0.232077, 0.059889, 0.437695],
62 | [0.238826, 0.059517, 0.443256],
63 | [0.245543, 0.059352, 0.448436],
64 | [0.25222, 0.059415, 0.453248],
65 | [0.258857, 0.059706, 0.45771],
66 | [0.265447, 0.060237, 0.46184],
67 | [0.271994, 0.060994, 0.46566],
68 | [0.278493, 0.061978, 0.46919],
69 | [0.284951, 0.063168, 0.472451],
70 | [0.291366, 0.064553, 0.475462],
71 | [0.29774, 0.066117, 0.478243],
72 | [0.304081, 0.067835, 0.480812],
73 | [0.310382, 0.069702, 0.483186],
74 | [0.316654, 0.07169, 0.48538],
75 | [0.322899, 0.073782, 0.487408],
76 | [0.329114, 0.075972, 0.489287],
77 | [0.335308, 0.078236, 0.491024],
78 | [0.341482, 0.080564, 0.492631],
79 | [0.347636, 0.082946, 0.494121],
80 | [0.353773, 0.085373, 0.495501],
81 | [0.359898, 0.087831, 0.496778],
82 | [0.366012, 0.090314, 0.49796],
83 | [0.372116, 0.092816, 0.499053],
84 | [0.378211, 0.095332, 0.500067],
85 | [0.384299, 0.097855, 0.501002],
86 | [0.390384, 0.100379, 0.501864],
87 | [0.396467, 0.102902, 0.502658],
88 | [0.402548, 0.10542, 0.503386],
89 | [0.408629, 0.10793, 0.504052],
90 | [0.414709, 0.110431, 0.504662],
91 | [0.420791, 0.11292, 0.505215],
92 | [0.426877, 0.115395, 0.505714],
93 | [0.432967, 0.117855, 0.50616],
94 | [0.439062, 0.120298, 0.506555],
95 | [0.445163, 0.122724, 0.506901],
96 | [0.451271, 0.125132, 0.507198],
97 | [0.457386, 0.127522, 0.507448],
98 | [0.463508, 0.129893, 0.507652],
99 | [0.46964, 0.132245, 0.507809],
100 | [0.47578, 0.134577, 0.507921],
101 | [0.481929, 0.136891, 0.507989],
102 | [0.488088, 0.139186, 0.508011],
103 | [0.494258, 0.141462, 0.507988],
104 | [0.500438, 0.143719, 0.50792],
105 | [0.506629, 0.145958, 0.507806],
106 | [0.512831, 0.148179, 0.507648],
107 | [0.519045, 0.150383, 0.507443],
108 | [0.52527, 0.152569, 0.507192],
109 | [0.531507, 0.154739, 0.506895],
110 | [0.537755, 0.156894, 0.506551],
111 | [0.544015, 0.159033, 0.506159],
112 | [0.550287, 0.161158, 0.505719],
113 | [0.556571, 0.163269, 0.50523],
114 | [0.562866, 0.165368, 0.504692],
115 | [0.569172, 0.167454, 0.504105],
116 | [0.57549, 0.16953, 0.503466],
117 | [0.581819, 0.171596, 0.502777],
118 | [0.588158, 0.173652, 0.502035],
119 | [0.594508, 0.175701, 0.501241],
120 | [0.600868, 0.177743, 0.500394],
121 | [0.607238, 0.179779, 0.499492],
122 | [0.613617, 0.181811, 0.498536],
123 | [0.620005, 0.18384, 0.497524],
124 | [0.626401, 0.185867, 0.496456],
125 | [0.632805, 0.187893, 0.495332],
126 | [0.639216, 0.189921, 0.49415],
127 | [0.645633, 0.191952, 0.49291],
128 | [0.652056, 0.193986, 0.491611],
129 | [0.658483, 0.196027, 0.490253],
130 | [0.664915, 0.198075, 0.488836],
131 | [0.671349, 0.200133, 0.487358],
132 | [0.677786, 0.202203, 0.485819],
133 | [0.684224, 0.204286, 0.484219],
134 | [0.690661, 0.206384, 0.482558],
135 | [0.697098, 0.208501, 0.480835],
136 | [0.703532, 0.210638, 0.479049],
137 | [0.709962, 0.212797, 0.477201],
138 | [0.716387, 0.214982, 0.47529],
139 | [0.722805, 0.217194, 0.473316],
140 | [0.729216, 0.219437, 0.471279],
141 | [0.735616, 0.221713, 0.46918],
142 | [0.742004, 0.224025, 0.467018],
143 | [0.748378, 0.226377, 0.464794],
144 | [0.754737, 0.228772, 0.462509],
145 | [0.761077, 0.231214, 0.460162],
146 | [0.767398, 0.233705, 0.457755],
147 | [0.773695, 0.236249, 0.455289],
148 | [0.779968, 0.238851, 0.452765],
149 | [0.786212, 0.241514, 0.450184],
150 | [0.792427, 0.244242, 0.447543],
151 | [0.798608, 0.24704, 0.444848],
152 | [0.804752, 0.249911, 0.442102],
153 | [0.810855, 0.252861, 0.439305],
154 | [0.816914, 0.255895, 0.436461],
155 | [0.822926, 0.259016, 0.433573],
156 | [0.828886, 0.262229, 0.430644],
157 | [0.834791, 0.26554, 0.427671],
158 | [0.840636, 0.268953, 0.424666],
159 | [0.846416, 0.272473, 0.421631],
160 | [0.852126, 0.276106, 0.418573],
161 | [0.857763, 0.279857, 0.415496],
162 | [0.86332, 0.283729, 0.412403],
163 | [0.868793, 0.287728, 0.409303],
164 | [0.874176, 0.291859, 0.406205],
165 | [0.879464, 0.296125, 0.403118],
166 | [0.884651, 0.30053, 0.400047],
167 | [0.889731, 0.305079, 0.397002],
168 | [0.8947, 0.309773, 0.393995],
169 | [0.899552, 0.314616, 0.391037],
170 | [0.904281, 0.31961, 0.388137],
171 | [0.908884, 0.324755, 0.385308],
172 | [0.913354, 0.330052, 0.382563],
173 | [0.917689, 0.3355, 0.379915],
174 | [0.921884, 0.341098, 0.377376],
175 | [0.925937, 0.346844, 0.374959],
176 | [0.929845, 0.352734, 0.372677],
177 | [0.933606, 0.358764, 0.370541],
178 | [0.937221, 0.364929, 0.368567],
179 | [0.940687, 0.371224, 0.366762],
180 | [0.944006, 0.377643, 0.365136],
181 | [0.94718, 0.384178, 0.363701],
182 | [0.95021, 0.39082, 0.362468],
183 | [0.953099, 0.397563, 0.361438],
184 | [0.955849, 0.4044, 0.360619],
185 | [0.958464, 0.411324, 0.360014],
186 | [0.960949, 0.418323, 0.35963],
187 | [0.96331, 0.42539, 0.359469],
188 | [0.965549, 0.432519, 0.359529],
189 | [0.967671, 0.439703, 0.35981],
190 | [0.96968, 0.446936, 0.360311],
191 | [0.971582, 0.45421, 0.36103],
192 | [0.973381, 0.46152, 0.361965],
193 | [0.975082, 0.468861, 0.363111],
194 | [0.97669, 0.476226, 0.364466],
195 | [0.97821, 0.483612, 0.366025],
196 | [0.979645, 0.491014, 0.367783],
197 | [0.981, 0.498428, 0.369734],
198 | [0.982279, 0.505851, 0.371874],
199 | [0.983485, 0.51328, 0.374198],
200 | [0.984622, 0.520713, 0.376698],
201 | [0.985693, 0.528148, 0.379371],
202 | [0.9867, 0.535582, 0.38221],
203 | [0.987646, 0.543015, 0.38521],
204 | [0.988533, 0.550446, 0.388365],
205 | [0.989363, 0.557873, 0.391671],
206 | [0.990138, 0.565296, 0.395122],
207 | [0.990871, 0.572706, 0.398714],
208 | [0.991558, 0.580107, 0.402441],
209 | [0.992196, 0.587502, 0.406299],
210 | [0.992785, 0.594891, 0.410283],
211 | [0.993326, 0.602275, 0.41439],
212 | [0.993834, 0.609644, 0.418613],
213 | [0.994309, 0.616999, 0.42295],
214 | [0.994738, 0.62435, 0.427397],
215 | [0.995122, 0.631696, 0.431951],
216 | [0.99548, 0.639027, 0.436607],
217 | [0.99581, 0.646344, 0.441361],
218 | [0.996096, 0.653659, 0.446213],
219 | [0.996341, 0.660969, 0.45116],
220 | [0.99658, 0.668256, 0.456192],
221 | [0.996775, 0.675541, 0.461314],
222 | [0.996925, 0.682828, 0.466526],
223 | [0.997077, 0.690088, 0.471811],
224 | [0.997186, 0.697349, 0.477182],
225 | [0.997254, 0.704611, 0.482635],
226 | [0.997325, 0.711848, 0.488154],
227 | [0.997351, 0.719089, 0.493755],
228 | [0.997351, 0.726324, 0.499428],
229 | [0.997341, 0.733545, 0.505167],
230 | [0.997285, 0.740772, 0.510983],
231 | [0.997228, 0.747981, 0.516859],
232 | [0.997138, 0.75519, 0.522806],
233 | [0.997019, 0.762398, 0.528821],
234 | [0.996898, 0.769591, 0.534892],
235 | [0.996727, 0.776795, 0.541039],
236 | [0.996571, 0.783977, 0.547233],
237 | [0.996369, 0.791167, 0.553499],
238 | [0.996162, 0.798348, 0.55982],
239 | [0.995932, 0.805527, 0.566202],
240 | [0.99568, 0.812706, 0.572645],
241 | [0.995424, 0.819875, 0.57914],
242 | [0.995131, 0.827052, 0.585701],
243 | [0.994851, 0.834213, 0.592307],
244 | [0.994524, 0.841387, 0.598983],
245 | [0.994222, 0.84854, 0.605696],
246 | [0.993866, 0.855711, 0.612482],
247 | [0.993545, 0.862859, 0.619299],
248 | [0.99317, 0.870024, 0.626189],
249 | [0.992831, 0.877168, 0.633109],
250 | [0.99244, 0.88433, 0.640099],
251 | [0.992089, 0.89147, 0.647116],
252 | [0.991688, 0.898627, 0.654202],
253 | [0.991332, 0.905763, 0.661309],
254 | [0.99093, 0.912915, 0.668481],
255 | [0.99057, 0.920049, 0.675675],
256 | [0.990175, 0.927196, 0.682926],
257 | [0.989815, 0.934329, 0.690198],
258 | [0.989434, 0.94147, 0.697519],
259 | [0.989077, 0.948604, 0.704863],
260 | [0.988717, 0.955742, 0.712242],
261 | [0.988367, 0.962878, 0.719649],
262 | [0.988033, 0.970012, 0.727077],
263 | [0.987691, 0.977154, 0.734536],
264 | [0.987387, 0.984288, 0.742002],
265 | [0.987053, 0.991438, 0.749504]
266 | ])
267 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/color_maps/plasma.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from .abstract_colormap import AbstractColorMap
3 |
4 | class PlasmaColorMap(AbstractColorMap):
5 | __metaclass__ = abc.ABCMeta
6 |
7 | def __init__(self):
8 | super().__init__()
9 | self.set_color_map([
10 | [0.050383, 0.029803, 0.527975],
11 | [0.063536, 0.028426, 0.533124],
12 | [0.075353, 0.027206, 0.538007],
13 | [0.086222, 0.026125, 0.542658],
14 | [0.096379, 0.025165, 0.547103],
15 | [0.10598, 0.024309, 0.551368],
16 | [0.115124, 0.023556, 0.555468],
17 | [0.123903, 0.022878, 0.559423],
18 | [0.132381, 0.022258, 0.56325],
19 | [0.140603, 0.021687, 0.566959],
20 | [0.148607, 0.021154, 0.570562],
21 | [0.156421, 0.020651, 0.574065],
22 | [0.16407, 0.020171, 0.577478],
23 | [0.171574, 0.019706, 0.580806],
24 | [0.17895, 0.019252, 0.584054],
25 | [0.186213, 0.018803, 0.587228],
26 | [0.193374, 0.018354, 0.59033],
27 | [0.200445, 0.017902, 0.593364],
28 | [0.207435, 0.017442, 0.596333],
29 | [0.21435, 0.016973, 0.599239],
30 | [0.221197, 0.016497, 0.602083],
31 | [0.227983, 0.016007, 0.604867],
32 | [0.234715, 0.015502, 0.607592],
33 | [0.241396, 0.014979, 0.610259],
34 | [0.248032, 0.014439, 0.612868],
35 | [0.254627, 0.013882, 0.615419],
36 | [0.261183, 0.013308, 0.617911],
37 | [0.267703, 0.012716, 0.620346],
38 | [0.274191, 0.012109, 0.622722],
39 | [0.280648, 0.011488, 0.625038],
40 | [0.287076, 0.010855, 0.627295],
41 | [0.293478, 0.010213, 0.62949],
42 | [0.299855, 0.009561, 0.631624],
43 | [0.30621, 0.008902, 0.633694],
44 | [0.312543, 0.008239, 0.6357],
45 | [0.318856, 0.007576, 0.63764],
46 | [0.32515, 0.006915, 0.639512],
47 | [0.331426, 0.006261, 0.641316],
48 | [0.337683, 0.005618, 0.643049],
49 | [0.343925, 0.004991, 0.64471],
50 | [0.35015, 0.004382, 0.646298],
51 | [0.356359, 0.003798, 0.64781],
52 | [0.362553, 0.003243, 0.649245],
53 | [0.368733, 0.002724, 0.650601],
54 | [0.374897, 0.002245, 0.651876],
55 | [0.381047, 0.001814, 0.653068],
56 | [0.387183, 0.001434, 0.654177],
57 | [0.393304, 0.001114, 0.655199],
58 | [0.399411, 0.000859, 0.656133],
59 | [0.405503, 0.000678, 0.656977],
60 | [0.41158, 0.000577, 0.65773],
61 | [0.417642, 0.000564, 0.65839],
62 | [0.423689, 0.000646, 0.658956],
63 | [0.429719, 0.000831, 0.659425],
64 | [0.435734, 0.001127, 0.659797],
65 | [0.441732, 0.00154, 0.660069],
66 | [0.447714, 0.00208, 0.66024],
67 | [0.453677, 0.002755, 0.66031],
68 | [0.459623, 0.003574, 0.660277],
69 | [0.46555, 0.004545, 0.660139],
70 | [0.471457, 0.005678, 0.659897],
71 | [0.477344, 0.00698, 0.659549],
72 | [0.48321, 0.00846, 0.659095],
73 | [0.489055, 0.010127, 0.658534],
74 | [0.494877, 0.01199, 0.657865],
75 | [0.500678, 0.014055, 0.657088],
76 | [0.506454, 0.016333, 0.656202],
77 | [0.512206, 0.018833, 0.655209],
78 | [0.517933, 0.021563, 0.654109],
79 | [0.523633, 0.024532, 0.652901],
80 | [0.529306, 0.027747, 0.651586],
81 | [0.534952, 0.031217, 0.650165],
82 | [0.54057, 0.03495, 0.64864],
83 | [0.546157, 0.038954, 0.64701],
84 | [0.551715, 0.043136, 0.645277],
85 | [0.557243, 0.047331, 0.643443],
86 | [0.562738, 0.051545, 0.641509],
87 | [0.568201, 0.055778, 0.639477],
88 | [0.573632, 0.060028, 0.637349],
89 | [0.579029, 0.064296, 0.635126],
90 | [0.584391, 0.068579, 0.632812],
91 | [0.589719, 0.072878, 0.630408],
92 | [0.595011, 0.07719, 0.627917],
93 | [0.600266, 0.081516, 0.625342],
94 | [0.605485, 0.085854, 0.622686],
95 | [0.610667, 0.090204, 0.619951],
96 | [0.615812, 0.094564, 0.61714],
97 | [0.620919, 0.098934, 0.614257],
98 | [0.625987, 0.103312, 0.611305],
99 | [0.631017, 0.107699, 0.608287],
100 | [0.636008, 0.112092, 0.605205],
101 | [0.640959, 0.116492, 0.602065],
102 | [0.645872, 0.120898, 0.598867],
103 | [0.650746, 0.125309, 0.595617],
104 | [0.65558, 0.129725, 0.592317],
105 | [0.660374, 0.134144, 0.588971],
106 | [0.665129, 0.138566, 0.585582],
107 | [0.669845, 0.142992, 0.582154],
108 | [0.674522, 0.147419, 0.578688],
109 | [0.67916, 0.151848, 0.575189],
110 | [0.683758, 0.156278, 0.57166],
111 | [0.688318, 0.160709, 0.568103],
112 | [0.69284, 0.165141, 0.564522],
113 | [0.697324, 0.169573, 0.560919],
114 | [0.701769, 0.174005, 0.557296],
115 | [0.706178, 0.178437, 0.553657],
116 | [0.710549, 0.182868, 0.550004],
117 | [0.714883, 0.187299, 0.546338],
118 | [0.719181, 0.191729, 0.542663],
119 | [0.723444, 0.196158, 0.538981],
120 | [0.72767, 0.200586, 0.535293],
121 | [0.731862, 0.205013, 0.531601],
122 | [0.736019, 0.209439, 0.527908],
123 | [0.740143, 0.213864, 0.524216],
124 | [0.744232, 0.218288, 0.520524],
125 | [0.748289, 0.222711, 0.516834],
126 | [0.752312, 0.227133, 0.513149],
127 | [0.756304, 0.231555, 0.509468],
128 | [0.760264, 0.235976, 0.505794],
129 | [0.764193, 0.240396, 0.502126],
130 | [0.76809, 0.244817, 0.498465],
131 | [0.771958, 0.249237, 0.494813],
132 | [0.775796, 0.253658, 0.491171],
133 | [0.779604, 0.258078, 0.487539],
134 | [0.783383, 0.2625, 0.483918],
135 | [0.787133, 0.266922, 0.480307],
136 | [0.790855, 0.271345, 0.476706],
137 | [0.794549, 0.27577, 0.473117],
138 | [0.798216, 0.280197, 0.469538],
139 | [0.801855, 0.284626, 0.465971],
140 | [0.805467, 0.289057, 0.462415],
141 | [0.809052, 0.293491, 0.45887],
142 | [0.812612, 0.297928, 0.455338],
143 | [0.816144, 0.302368, 0.451816],
144 | [0.819651, 0.306812, 0.448306],
145 | [0.823132, 0.311261, 0.444806],
146 | [0.826588, 0.315714, 0.441316],
147 | [0.830018, 0.320172, 0.437836],
148 | [0.833422, 0.324635, 0.434366],
149 | [0.836801, 0.329105, 0.430905],
150 | [0.840155, 0.33358, 0.427455],
151 | [0.843484, 0.338062, 0.424013],
152 | [0.846788, 0.342551, 0.420579],
153 | [0.850066, 0.347048, 0.417153],
154 | [0.853319, 0.351553, 0.413734],
155 | [0.856547, 0.356066, 0.410322],
156 | [0.85975, 0.360588, 0.406917],
157 | [0.862927, 0.365119, 0.403519],
158 | [0.866078, 0.36966, 0.400126],
159 | [0.869203, 0.374212, 0.396738],
160 | [0.872303, 0.378774, 0.393355],
161 | [0.875376, 0.383347, 0.389976],
162 | [0.878423, 0.387932, 0.3866],
163 | [0.881443, 0.392529, 0.383229],
164 | [0.884436, 0.397139, 0.37986],
165 | [0.887402, 0.401762, 0.376494],
166 | [0.89034, 0.406398, 0.37313],
167 | [0.89325, 0.411048, 0.369768],
168 | [0.896131, 0.415712, 0.366407],
169 | [0.898984, 0.420392, 0.363047],
170 | [0.901807, 0.425087, 0.359688],
171 | [0.904601, 0.429797, 0.356329],
172 | [0.907365, 0.434524, 0.35297],
173 | [0.910098, 0.439268, 0.34961],
174 | [0.9128, 0.444029, 0.346251],
175 | [0.915471, 0.448807, 0.34289],
176 | [0.918109, 0.453603, 0.339529],
177 | [0.920714, 0.458417, 0.336166],
178 | [0.923287, 0.463251, 0.332801],
179 | [0.925825, 0.468103, 0.329435],
180 | [0.928329, 0.472975, 0.326067],
181 | [0.930798, 0.477867, 0.322697],
182 | [0.933232, 0.48278, 0.319325],
183 | [0.93563, 0.487712, 0.315952],
184 | [0.93799, 0.492667, 0.312575],
185 | [0.940313, 0.497642, 0.309197],
186 | [0.942598, 0.502639, 0.305816],
187 | [0.944844, 0.507658, 0.302433],
188 | [0.947051, 0.512699, 0.299049],
189 | [0.949217, 0.517763, 0.295662],
190 | [0.951344, 0.52285, 0.292275],
191 | [0.953428, 0.52796, 0.288883],
192 | [0.95547, 0.533093, 0.28549],
193 | [0.957469, 0.53825, 0.282096],
194 | [0.959424, 0.543431, 0.278701],
195 | [0.961336, 0.548636, 0.275305],
196 | [0.963203, 0.553865, 0.271909],
197 | [0.965024, 0.559118, 0.268513],
198 | [0.966798, 0.564396, 0.265118],
199 | [0.968526, 0.5697, 0.261721],
200 | [0.970205, 0.575028, 0.258325],
201 | [0.971835, 0.580382, 0.254931],
202 | [0.973416, 0.585761, 0.25154],
203 | [0.974947, 0.591165, 0.248151],
204 | [0.976428, 0.596595, 0.244767],
205 | [0.977856, 0.602051, 0.241387],
206 | [0.979233, 0.607532, 0.238013],
207 | [0.980556, 0.613039, 0.234646],
208 | [0.981826, 0.618572, 0.231287],
209 | [0.983041, 0.624131, 0.227937],
210 | [0.984199, 0.629718, 0.224595],
211 | [0.985301, 0.63533, 0.221265],
212 | [0.986345, 0.640969, 0.217948],
213 | [0.987332, 0.646633, 0.214648],
214 | [0.98826, 0.652325, 0.211364],
215 | [0.989128, 0.658043, 0.2081],
216 | [0.989935, 0.663787, 0.204859],
217 | [0.990681, 0.669558, 0.201642],
218 | [0.991365, 0.675355, 0.198453],
219 | [0.991985, 0.681179, 0.195295],
220 | [0.992541, 0.68703, 0.19217],
221 | [0.993032, 0.692907, 0.189084],
222 | [0.993456, 0.69881, 0.186041],
223 | [0.993814, 0.704741, 0.183043],
224 | [0.994103, 0.710698, 0.180097],
225 | [0.994324, 0.716681, 0.177208],
226 | [0.994474, 0.722691, 0.174381],
227 | [0.994553, 0.728728, 0.171622],
228 | [0.994561, 0.734791, 0.168938],
229 | [0.994495, 0.74088, 0.166335],
230 | [0.994355, 0.746995, 0.163821],
231 | [0.994141, 0.753137, 0.161404],
232 | [0.993851, 0.759304, 0.159092],
233 | [0.993482, 0.765499, 0.156891],
234 | [0.993033, 0.77172, 0.154808],
235 | [0.992505, 0.777967, 0.152855],
236 | [0.991897, 0.784239, 0.151042],
237 | [0.991209, 0.790537, 0.149377],
238 | [0.990439, 0.796859, 0.14787],
239 | [0.989587, 0.803205, 0.146529],
240 | [0.988648, 0.809579, 0.145357],
241 | [0.987621, 0.815978, 0.144363],
242 | [0.986509, 0.822401, 0.143557],
243 | [0.985314, 0.828846, 0.142945],
244 | [0.984031, 0.835315, 0.142528],
245 | [0.982653, 0.841812, 0.142303],
246 | [0.98119, 0.848329, 0.142279],
247 | [0.979644, 0.854866, 0.142453],
248 | [0.977995, 0.861432, 0.142808],
249 | [0.976265, 0.868016, 0.143351],
250 | [0.974443, 0.874622, 0.144061],
251 | [0.97253, 0.88125, 0.144923],
252 | [0.970533, 0.887896, 0.145919],
253 | [0.968443, 0.894564, 0.147014],
254 | [0.966271, 0.901249, 0.14818],
255 | [0.964021, 0.90795, 0.14937],
256 | [0.961681, 0.914672, 0.15052],
257 | [0.959276, 0.921407, 0.151566],
258 | [0.956808, 0.928152, 0.152409],
259 | [0.954287, 0.934908, 0.152921],
260 | [0.951726, 0.941671, 0.152925],
261 | [0.949151, 0.948435, 0.152178],
262 | [0.946602, 0.95519, 0.150328],
263 | [0.944152, 0.961916, 0.146861],
264 | [0.941896, 0.96859, 0.140956],
265 | [0.940015, 0.975158, 0.131326]
266 | ])
267 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/color_maps/viridis.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from .abstract_colormap import AbstractColorMap
3 |
4 | class ViridisColorMap(AbstractColorMap):
5 | __metaclass__ = abc.ABCMeta
6 |
7 | def __init__(self):
8 | super().__init__()
9 | self.set_color_map([
10 | [0.267004, 0.004874, 0.329415],
11 | [0.26851, 0.009605, 0.335427],
12 | [0.269944, 0.014625, 0.341379],
13 | [0.271305, 0.019942, 0.347269],
14 | [0.272594, 0.025563, 0.353093],
15 | [0.273809, 0.031497, 0.358853],
16 | [0.274952, 0.037752, 0.364543],
17 | [0.276022, 0.044167, 0.370164],
18 | [0.277018, 0.050344, 0.375715],
19 | [0.277941, 0.056324, 0.381191],
20 | [0.278791, 0.062145, 0.386592],
21 | [0.279566, 0.067836, 0.391917],
22 | [0.280267, 0.073417, 0.397163],
23 | [0.280894, 0.078907, 0.402329],
24 | [0.281446, 0.08432, 0.407414],
25 | [0.281924, 0.089666, 0.412415],
26 | [0.282327, 0.094955, 0.417331],
27 | [0.282656, 0.100196, 0.42216],
28 | [0.28291, 0.105393, 0.426902],
29 | [0.283091, 0.110553, 0.431554],
30 | [0.283197, 0.11568, 0.436115],
31 | [0.283229, 0.120777, 0.440584],
32 | [0.283187, 0.125848, 0.44496],
33 | [0.283072, 0.130895, 0.449241],
34 | [0.282884, 0.13592, 0.453427],
35 | [0.282623, 0.140926, 0.457517],
36 | [0.28229, 0.145912, 0.46151],
37 | [0.281887, 0.150881, 0.465405],
38 | [0.281412, 0.155834, 0.469201],
39 | [0.280868, 0.160771, 0.472899],
40 | [0.280255, 0.165693, 0.476498],
41 | [0.279574, 0.170599, 0.479997],
42 | [0.278826, 0.17549, 0.483397],
43 | [0.278012, 0.180367, 0.486697],
44 | [0.277134, 0.185228, 0.489898],
45 | [0.276194, 0.190074, 0.493001],
46 | [0.275191, 0.194905, 0.496005],
47 | [0.274128, 0.199721, 0.498911],
48 | [0.273006, 0.20452, 0.501721],
49 | [0.271828, 0.209303, 0.504434],
50 | [0.270595, 0.214069, 0.507052],
51 | [0.269308, 0.218818, 0.509577],
52 | [0.267968, 0.223549, 0.512008],
53 | [0.26658, 0.228262, 0.514349],
54 | [0.265145, 0.232956, 0.516599],
55 | [0.263663, 0.237631, 0.518762],
56 | [0.262138, 0.242286, 0.520837],
57 | [0.260571, 0.246922, 0.522828],
58 | [0.258965, 0.251537, 0.524736],
59 | [0.257322, 0.25613, 0.526563],
60 | [0.255645, 0.260703, 0.528312],
61 | [0.253935, 0.265254, 0.529983],
62 | [0.252194, 0.269783, 0.531579],
63 | [0.250425, 0.27429, 0.533103],
64 | [0.248629, 0.278775, 0.534556],
65 | [0.246811, 0.283237, 0.535941],
66 | [0.244972, 0.287675, 0.53726],
67 | [0.243113, 0.292092, 0.538516],
68 | [0.241237, 0.296485, 0.539709],
69 | [0.239346, 0.300855, 0.540844],
70 | [0.237441, 0.305202, 0.541921],
71 | [0.235526, 0.309527, 0.542944],
72 | [0.233603, 0.313828, 0.543914],
73 | [0.231674, 0.318106, 0.544834],
74 | [0.229739, 0.322361, 0.545706],
75 | [0.227802, 0.326594, 0.546532],
76 | [0.225863, 0.330805, 0.547314],
77 | [0.223925, 0.334994, 0.548053],
78 | [0.221989, 0.339161, 0.548752],
79 | [0.220057, 0.343307, 0.549413],
80 | [0.21813, 0.347432, 0.550038],
81 | [0.21621, 0.351535, 0.550627],
82 | [0.214298, 0.355619, 0.551184],
83 | [0.212395, 0.359683, 0.55171],
84 | [0.210503, 0.363727, 0.552206],
85 | [0.208623, 0.367752, 0.552675],
86 | [0.206756, 0.371758, 0.553117],
87 | [0.204903, 0.375746, 0.553533],
88 | [0.203063, 0.379716, 0.553925],
89 | [0.201239, 0.38367, 0.554294],
90 | [0.19943, 0.387607, 0.554642],
91 | [0.197636, 0.391528, 0.554969],
92 | [0.19586, 0.395433, 0.555276],
93 | [0.1941, 0.399323, 0.555565],
94 | [0.192357, 0.403199, 0.555836],
95 | [0.190631, 0.407061, 0.556089],
96 | [0.188923, 0.41091, 0.556326],
97 | [0.187231, 0.414746, 0.556547],
98 | [0.185556, 0.41857, 0.556753],
99 | [0.183898, 0.422383, 0.556944],
100 | [0.182256, 0.426184, 0.55712],
101 | [0.180629, 0.429975, 0.557282],
102 | [0.179019, 0.433756, 0.55743],
103 | [0.177423, 0.437527, 0.557565],
104 | [0.175841, 0.44129, 0.557685],
105 | [0.174274, 0.445044, 0.557792],
106 | [0.172719, 0.448791, 0.557885],
107 | [0.171176, 0.45253, 0.557965],
108 | [0.169646, 0.456262, 0.55803],
109 | [0.168126, 0.459988, 0.558082],
110 | [0.166617, 0.463708, 0.558119],
111 | [0.165117, 0.467423, 0.558141],
112 | [0.163625, 0.471133, 0.558148],
113 | [0.162142, 0.474838, 0.55814],
114 | [0.160665, 0.47854, 0.558115],
115 | [0.159194, 0.482237, 0.558073],
116 | [0.157729, 0.485932, 0.558013],
117 | [0.15627, 0.489624, 0.557936],
118 | [0.154815, 0.493313, 0.55784],
119 | [0.153364, 0.497, 0.557724],
120 | [0.151918, 0.500685, 0.557587],
121 | [0.150476, 0.504369, 0.55743],
122 | [0.149039, 0.508051, 0.55725],
123 | [0.147607, 0.511733, 0.557049],
124 | [0.14618, 0.515413, 0.556823],
125 | [0.144759, 0.519093, 0.556572],
126 | [0.143343, 0.522773, 0.556295],
127 | [0.141935, 0.526453, 0.555991],
128 | [0.140536, 0.530132, 0.555659],
129 | [0.139147, 0.533812, 0.555298],
130 | [0.13777, 0.537492, 0.554906],
131 | [0.136408, 0.541173, 0.554483],
132 | [0.135066, 0.544853, 0.554029],
133 | [0.133743, 0.548535, 0.553541],
134 | [0.132444, 0.552216, 0.553018],
135 | [0.131172, 0.555899, 0.552459],
136 | [0.129933, 0.559582, 0.551864],
137 | [0.128729, 0.563265, 0.551229],
138 | [0.127568, 0.566949, 0.550556],
139 | [0.126453, 0.570633, 0.549841],
140 | [0.125394, 0.574318, 0.549086],
141 | [0.124395, 0.578002, 0.548287],
142 | [0.123463, 0.581687, 0.547445],
143 | [0.122606, 0.585371, 0.546557],
144 | [0.121831, 0.589055, 0.545623],
145 | [0.121148, 0.592739, 0.544641],
146 | [0.120565, 0.596422, 0.543611],
147 | [0.120092, 0.600104, 0.54253],
148 | [0.119738, 0.603785, 0.5414],
149 | [0.119512, 0.607464, 0.540218],
150 | [0.119423, 0.611141, 0.538982],
151 | [0.119483, 0.614817, 0.537692],
152 | [0.119699, 0.61849, 0.536347],
153 | [0.120081, 0.622161, 0.534946],
154 | [0.120638, 0.625828, 0.533488],
155 | [0.12138, 0.629492, 0.531973],
156 | [0.122312, 0.633153, 0.530398],
157 | [0.123444, 0.636809, 0.528763],
158 | [0.12478, 0.640461, 0.527068],
159 | [0.126326, 0.644107, 0.525311],
160 | [0.128087, 0.647749, 0.523491],
161 | [0.130067, 0.651384, 0.521608],
162 | [0.132268, 0.655014, 0.519661],
163 | [0.134692, 0.658636, 0.517649],
164 | [0.137339, 0.662252, 0.515571],
165 | [0.14021, 0.665859, 0.513427],
166 | [0.143303, 0.669459, 0.511215],
167 | [0.146616, 0.67305, 0.508936],
168 | [0.150148, 0.676631, 0.506589],
169 | [0.153894, 0.680203, 0.504172],
170 | [0.157851, 0.683765, 0.501686],
171 | [0.162016, 0.687316, 0.499129],
172 | [0.166383, 0.690856, 0.496502],
173 | [0.170948, 0.694384, 0.493803],
174 | [0.175707, 0.6979, 0.491033],
175 | [0.180653, 0.701402, 0.488189],
176 | [0.185783, 0.704891, 0.485273],
177 | [0.19109, 0.708366, 0.482284],
178 | [0.196571, 0.711827, 0.479221],
179 | [0.202219, 0.715272, 0.476084],
180 | [0.20803, 0.718701, 0.472873],
181 | [0.214, 0.722114, 0.469588],
182 | [0.220124, 0.725509, 0.466226],
183 | [0.226397, 0.728888, 0.462789],
184 | [0.232815, 0.732247, 0.459277],
185 | [0.239374, 0.735588, 0.455688],
186 | [0.24607, 0.73891, 0.452024],
187 | [0.252899, 0.742211, 0.448284],
188 | [0.259857, 0.745492, 0.444467],
189 | [0.266941, 0.748751, 0.440573],
190 | [0.274149, 0.751988, 0.436601],
191 | [0.281477, 0.755203, 0.432552],
192 | [0.288921, 0.758394, 0.428426],
193 | [0.296479, 0.761561, 0.424223],
194 | [0.304148, 0.764704, 0.419943],
195 | [0.311925, 0.767822, 0.415586],
196 | [0.319809, 0.770914, 0.411152],
197 | [0.327796, 0.77398, 0.40664],
198 | [0.335885, 0.777018, 0.402049],
199 | [0.344074, 0.780029, 0.397381],
200 | [0.35236, 0.783011, 0.392636],
201 | [0.360741, 0.785964, 0.387814],
202 | [0.369214, 0.788888, 0.382914],
203 | [0.377779, 0.791781, 0.377939],
204 | [0.386433, 0.794644, 0.372886],
205 | [0.395174, 0.797475, 0.367757],
206 | [0.404001, 0.800275, 0.362552],
207 | [0.412913, 0.803041, 0.357269],
208 | [0.421908, 0.805774, 0.35191],
209 | [0.430983, 0.808473, 0.346476],
210 | [0.440137, 0.811138, 0.340967],
211 | [0.449368, 0.813768, 0.335384],
212 | [0.458674, 0.816363, 0.329727],
213 | [0.468053, 0.818921, 0.323998],
214 | [0.477504, 0.821444, 0.318195],
215 | [0.487026, 0.823929, 0.312321],
216 | [0.496615, 0.826376, 0.306377],
217 | [0.506271, 0.828786, 0.300362],
218 | [0.515992, 0.831158, 0.294279],
219 | [0.525776, 0.833491, 0.288127],
220 | [0.535621, 0.835785, 0.281908],
221 | [0.545524, 0.838039, 0.275626],
222 | [0.555484, 0.840254, 0.269281],
223 | [0.565498, 0.84243, 0.262877],
224 | [0.575563, 0.844566, 0.256415],
225 | [0.585678, 0.846661, 0.249897],
226 | [0.595839, 0.848717, 0.243329],
227 | [0.606045, 0.850733, 0.236712],
228 | [0.616293, 0.852709, 0.230052],
229 | [0.626579, 0.854645, 0.223353],
230 | [0.636902, 0.856542, 0.21662],
231 | [0.647257, 0.8584, 0.209861],
232 | [0.657642, 0.860219, 0.203082],
233 | [0.668054, 0.861999, 0.196293],
234 | [0.678489, 0.863742, 0.189503],
235 | [0.688944, 0.865448, 0.182725],
236 | [0.699415, 0.867117, 0.175971],
237 | [0.709898, 0.868751, 0.169257],
238 | [0.720391, 0.87035, 0.162603],
239 | [0.730889, 0.871916, 0.156029],
240 | [0.741388, 0.873449, 0.149561],
241 | [0.751884, 0.874951, 0.143228],
242 | [0.762373, 0.876424, 0.137064],
243 | [0.772852, 0.877868, 0.131109],
244 | [0.783315, 0.879285, 0.125405],
245 | [0.79376, 0.880678, 0.120005],
246 | [0.804182, 0.882046, 0.114965],
247 | [0.814576, 0.883393, 0.110347],
248 | [0.82494, 0.88472, 0.106217],
249 | [0.83527, 0.886029, 0.102646],
250 | [0.845561, 0.887322, 0.099702],
251 | [0.85581, 0.888601, 0.097452],
252 | [0.866013, 0.889868, 0.095953],
253 | [0.876168, 0.891125, 0.09525],
254 | [0.886271, 0.892374, 0.095374],
255 | [0.89632, 0.893616, 0.096335],
256 | [0.906311, 0.894855, 0.098125],
257 | [0.916242, 0.896091, 0.100717],
258 | [0.926106, 0.89733, 0.104071],
259 | [0.935904, 0.89857, 0.108131],
260 | [0.945636, 0.899815, 0.112838],
261 | [0.9553, 0.901065, 0.118128],
262 | [0.964894, 0.902323, 0.123941],
263 | [0.974417, 0.90359, 0.130215],
264 | [0.983868, 0.904867, 0.136897],
265 | [0.993248, 0.906157, 0.143936]
266 | ])
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/data/plot/colormap.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | # Based on the TypeScript implementation
5 | # @see https://github.com/alesgenova/colormap
6 |
7 |
8 | def linear_scale(domain, out_range, value): # pragma: no cover
9 | d0, d1 = domain
10 | r0, r1 = out_range
11 |
12 | return r0 + (r1 - r0) * ((value - d0) / (d1 - d0))
13 |
14 |
15 | def linear_mixer(value, lower_node_value, upper_node_value): # pragma: no cover
16 | frac = (value - lower_node_value) / (upper_node_value - lower_node_value)
17 | return 1. - frac, frac
18 |
19 |
20 | def color_combination(a, X, b, Y): # pragma: no cover
21 | return [
22 | a * X[0] + b * Y[0],
23 | a * X[1] + b * Y[1],
24 | a * X[2] + b * Y[2]
25 | ]
26 |
27 |
28 | def create_map_from_array(v, color_map): # pragma: no cover
29 | channel_position = 3
30 | # rgb:
31 | v = tf.expand_dims(v, channel_position) # add channel axis
32 | v = tf.repeat(v, 3, axis=channel_position) # duplicate values to all channels
33 |
34 | domain = (tf.math.reduce_min(v, axis=(1, 2, 3), keepdims=True),
35 | tf.math.reduce_max(v, axis=(1, 2, 3), keepdims=True))
36 |
37 | scaled_value = linear_scale(domain=domain, out_range=(0, 1), value=v)
38 | vri_len = color_map.shape[0]
39 | index_float = (vri_len - 1) * scaled_value
40 |
41 | t1 = tf.math.less_equal(index_float, 0)
42 | t2 = tf.math.greater_equal(index_float, vri_len - 1)
43 |
44 | result = tf.where(t1, color_map[0], tf.where(t2, color_map[vri_len - 1], [-1, -2, -3]))
45 |
46 | index = tf.math.floor(index_float)
47 | index2 = tf.where(index >= vri_len - 1, tf.cast(vri_len, dtype=tf.float32) - 1, index + 1.)
48 |
49 | coeff0, coeff1 = linear_mixer(value=index_float, lower_node_value=index, upper_node_value=index2)
50 | index = tf.cast(index, dtype=tf.int32)
51 | index2 = tf.cast(index2, dtype=tf.int32)
52 |
53 | # red mask
54 | v1_r = tf.gather(color_map[:, 0], indices=index)
55 | v2_r = tf.gather(color_map[:, 0], indices=index2)
56 | com_r = coeff0 * v1_r + coeff1 * v2_r
57 | # green mask
58 | v1_g = tf.gather(color_map[:, 1], indices=index)
59 | v2_g = tf.gather(color_map[:, 1], indices=index2)
60 | com_g = coeff0 * v1_g + coeff1 * v2_g
61 | # blue mask
62 | v1_b = tf.gather(color_map[:, 2], indices=index)
63 | v2_b = tf.gather(color_map[:, 2], indices=index2)
64 | com_b = coeff0 * v1_b + coeff1 * v2_b
65 | # apply color masks
66 | result = tf.where(tf.math.equal(result, -1), com_r, result)
67 | result = tf.where(tf.math.equal(result, -2), com_g, result)
68 | result = tf.where(tf.math.equal(result, -3), com_b, result)
69 |
70 | return result
71 |
--------------------------------------------------------------------------------
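
A minimal usage sketch for create_map_from_array above (not part of the repository; the import path is inferred from the file tree, and the tensor shapes and the two-node grey ramp are illustrative):

    import tensorflow as tf
    from deepspectrumlite.lib.data.plot.colormap import create_map_from_array

    # batch of greyscale spectrograms: (batch, time, frequency)
    spec = tf.random.uniform((2, 128, 64), dtype=tf.float32)
    # any (N, 3) table of RGB nodes works; here a two-node black-to-white ramp
    grey_ramp = tf.constant([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], dtype=tf.float32)
    rgb = create_map_from_array(spec, grey_ramp)
    print(rgb.shape)  # (2, 128, 64, 3): each value interpolated linearly between the nodes

In the actual pipeline the node table comes from one of the colour map classes above (cividis, magma, plasma, viridis), each of which stores a 256-entry RGB table via set_color_map.
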
/src/deepspectrumlite/lib/hyperparameter.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import itertools
20 | import json
21 | from tensorboard.plugins.hparams import api as hp
22 |
23 |
24 | class HyperParameterList:
25 | def __init__(self, config_file_name: str):
26 | f = open(config_file_name)
27 | data = json.load(f)
28 | f.close()
29 |
30 | self._config = data
31 | self._param_list = {}
32 |
33 | self.load_configuration()
34 |
35 | def get_hparams(self):
36 | hparams = []
37 |
38 | for key in self._config:
39 | hparams.append(self._param_list[key])
40 |
41 | return hparams
42 |
43 | def load_configuration(self):
44 | self._param_list = {}
45 |
46 | for key in self._config:
47 | self._param_list[key] = hp.HParam(key, hp.Discrete(self._config[key]))
48 |
49 | def get_max_iteration(self):
50 | count = 1
51 |
52 | for key in self._config:
53 | count = count * len(self._config[key])
54 |
55 | return count
56 |
57 | def get_values_tensorboard(self, iteration_no: int):
58 | if iteration_no >= self.get_max_iteration():
59 |             raise ValueError('iteration_no must be in [0, ' + str(self.get_max_iteration()) + ')')
60 |
61 | configuration_space = []
62 | for key in self._config:
63 | configurations = []
64 | for v in self._config[key]:
65 | configurations.append({key: v})
66 | configuration_space.append(configurations)
67 | perturbations = list(itertools.product(*configuration_space))
68 |
69 | perturbation = perturbations[iteration_no]
70 | hparams = {}
71 | for param in perturbation:
72 | for key in param:
73 | k = self._param_list[key]
74 | hparams[k] = param[key]
75 |
76 | return hparams
77 |
78 | def get_values(self, iteration_no: int):
79 | if iteration_no >= self.get_max_iteration():
80 |             raise ValueError('iteration_no must be in [0, ' + str(self.get_max_iteration()) + ')')
81 |
82 | configuration_space = []
83 | for key in self._config:
84 | configurations = []
85 | for v in self._config[key]:
86 | configurations.append({key: v})
87 | configuration_space.append(configurations)
88 | perturbations = list(itertools.product(*configuration_space))
89 |
90 | perturbation = perturbations[iteration_no]
91 | hparams = {}
92 | for param in perturbation:
93 | for key in param:
94 | hparams[key] = param[key]
95 |
96 | return hparams
97 |
--------------------------------------------------------------------------------
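
A small sketch of how HyperParameterList expands a JSON grid into individual runs (the grid values and the temporary file are illustrative; the import path is inferred from the file tree):

    import json
    import tempfile

    from deepspectrumlite.lib.hyperparameter import HyperParameterList

    grid = {"learning_rate": [0.001, 0.0001], "batch_size": [16, 32]}
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(grid, f)

    hp_list = HyperParameterList(f.name)
    print(hp_list.get_max_iteration())        # 4 (= 2 * 2 grid points)
    print(hp_list.get_values(0))              # {'learning_rate': 0.001, 'batch_size': 16}
    print(hp_list.get_values_tensorboard(0))  # same values, keyed by hp.HParam objects
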
/src/deepspectrumlite/lib/model/TransferBaseModel.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | from tensorflow import keras
20 | import tensorflow as tf
21 | from tensorflow.python.keras.metrics import categorical_accuracy
22 | from deepspectrumlite import Model
23 | from .modules.augmentable_model import AugmentableModel
24 | from .modules.squeeze_net import SqueezeNet
25 | import logging
26 |
27 | log = logging.getLogger(__name__)
28 |
29 | class TransferBaseModel(Model):
30 | base_model = None
31 |
32 | def retrain_model(self):
33 | if self.hy_params['finetune_layer'] > 0:
34 | self.base_model.trainable = True
35 |
36 | layer_count = len(self.base_model.layers)
37 | keep_until = int(layer_count * self.hy_params['finetune_layer'])
38 |
39 | for layer in self.base_model.layers[:keep_until]:
40 | layer.trainable = False
41 |
42 | optimizer = self.get_optimizer_fn()
43 | optimizer._set_hyper('learning_rate', self.hy_params['fine_learning_rate'])
44 | self.get_model().compile(loss=self.hy_params['loss'], optimizer=optimizer,
45 | metrics=self._metrics)
46 |
47 | self.get_model().summary(print_fn=log.info)
48 |
49 | def create_model(self):
50 | hy_params = self.hy_params
51 |
52 | input = keras.Input(shape=self.input_shape, dtype=tf.float32)
53 |
54 | weights = None
55 | if hy_params['weights'] != '':
56 | weights = hy_params['weights']
57 |
58 | available_models = {
59 | "vgg16":
60 | tf.keras.applications.vgg16.VGG16,
61 | "vgg19":
62 | tf.keras.applications.vgg19.VGG19,
63 | "resnet50":
64 | tf.keras.applications.resnet50.ResNet50,
65 | "xception":
66 | tf.keras.applications.xception.Xception,
67 | "inception_v3":
68 |                 tf.keras.applications.inception_v3.InceptionV3,
69 | "densenet121":
70 | tf.keras.applications.densenet.DenseNet121,
71 | "densenet169":
72 | tf.keras.applications.densenet.DenseNet169,
73 | "densenet201":
74 | tf.keras.applications.densenet.DenseNet201,
75 | "mobilenet":
76 | tf.keras.applications.mobilenet.MobileNet,
77 | "mobilenet_v2":
78 | tf.keras.applications.mobilenet_v2.MobileNetV2,
79 | "nasnet_large":
80 | tf.keras.applications.nasnet.NASNetLarge,
81 | "nasnet_mobile":
82 | tf.keras.applications.nasnet.NASNetMobile,
83 | "inception_resnet_v2":
84 | tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
85 | "squeezenet_v1":
86 | SqueezeNet,
87 | }
88 |
89 | model_key = hy_params['basemodel_name']
90 |
91 | if model_key in available_models:
92 | self.base_model = available_models[model_key](weights=weights, include_top=False)
93 | else:
94 | raise ValueError(model_key + ' is not implemented')
95 |
96 | if hy_params['weights'] != '':
97 | training = False
98 | else:
99 | training = True
100 |
101 | self.base_model.trainable = training
102 |
103 | feature_batch_average = tf.keras.layers.GlobalAveragePooling2D()(self.base_model(input, training=training))
104 | flatten = keras.layers.Flatten()(feature_batch_average)
105 | dense_1 = keras.layers.Dense(hy_params['num_units'], activation=self.get_activation_fn())(flatten)
106 | dropout_1 = keras.layers.Dropout(rate=hy_params['dropout'])(dense_1)
107 |
108 | activation = 'softmax' if self.prediction_type == 'categorical' else 'linear'
109 | if 'output_activation' in hy_params:
110 | activation = hy_params['output_activation']
111 | predictions = tf.keras.layers.Dense(len(self.data_classes), activation=activation)(dropout_1)
112 |
113 | model = AugmentableModel(inputs=input, outputs=predictions, name=hy_params['basemodel_name'])
114 | model.set_hyper_parameters(hy_params=hy_params)
115 | self.model = model
116 | self.compile_model()
117 |
118 | def train(self, train_dataset: tf.data.Dataset, devel_dataset: tf.data.Dataset):
119 | """
120 |         Trains the model on the given training data, validating on the development set.
121 | """
122 | epochs_first = self.hy_params['pre_epochs']
123 | if self.hy_params['weights'] == '':
124 | epochs_first = self.hy_params['epochs'] + self.hy_params['pre_epochs']
125 |
126 | history = self.get_model().fit(x=train_dataset, epochs=epochs_first,
127 | batch_size=self.hy_params['batch_size'],
128 | shuffle=True,
129 | validation_data=devel_dataset,
130 | callbacks=self.get_callbacks(), verbose=self.verbose)
131 |
132 | if self.hy_params['weights'] != '' and self.hy_params['finetune_layer'] > 0:
133 | self.retrain_model()
134 |
135 | self.get_model().fit(x=train_dataset, epochs=epochs_first + self.hy_params['epochs'],
136 | initial_epoch=history.epoch[-1],
137 | batch_size=self.hy_params['batch_size'],
138 | shuffle=True,
139 | validation_data=devel_dataset,
140 | callbacks=self.get_callbacks(), verbose=self.verbose)
141 |
--------------------------------------------------------------------------------
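
For reference, these are the hy_params keys that TransferBaseModel reads in the file above, collected into one illustrative dictionary (a sketch only, with made-up values; the shipped hp_config.json files define the real search grids, and the base class in ai_model.py may read additional keys not visible in this excerpt):

    hy_params_sketch = {
        "basemodel_name": "mobilenet_v2",    # key into available_models
        "weights": "imagenet",               # '' trains the base network from scratch
        "finetune_layer": 0.7,               # fraction of base layers kept frozen (> 0 enables fine-tuning)
        "fine_learning_rate": 1e-5,          # learning rate used after unfreezing
        "loss": "categorical_crossentropy",
        "num_units": 512,                    # width of the dense head
        "dropout": 0.25,
        "output_activation": "softmax",      # optional; otherwise derived from prediction_type
        "pre_epochs": 10,
        "epochs": 20,
        "batch_size": 16,
    }
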
/src/deepspectrumlite/lib/model/__init__.py:
--------------------------------------------------------------------------------
1 | """Implementations of several models"""
2 | from .ai_model import Model
3 | from .config.gridsearch import *
4 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/model/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/src/deepspectrumlite/lib/model/config/__init__.py
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/model/config/gridsearch.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | METRIC_ACCURACY = 'accuracy'
20 | METRIC_PRECISION = 'precision'
21 | METRIC_RECALL = 'recall'
22 | METRIC_F_SCORE = 'f1_score'
23 | METRIC_MAE = 'mae'
24 | METRIC_RMSE = 'rmse'
25 | METRIC_MSE = 'mse'
26 | METRIC_LOSS = 'loss'
27 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/model/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/src/deepspectrumlite/lib/model/modules/__init__.py
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/model/modules/arelu.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import tensorflow as tf
20 |
21 |
22 | def ARelu(x, alpha=0.9, beta=2.0):
23 | alpha = tf.clip_by_value(alpha, clip_value_min=0.01, clip_value_max=0.99)
24 | beta = 1 + tf.math.sigmoid(beta)
25 | return tf.nn.relu(x) * beta - tf.nn.relu(-x) * alpha
26 |
--------------------------------------------------------------------------------
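
A quick numeric check of ARelu with its default parameters (the input values are arbitrary; the import path is inferred from the file tree):

    import tensorflow as tf

    from deepspectrumlite.lib.model.modules.arelu import ARelu

    x = tf.constant([-2.0, -0.5, 0.0, 1.0, 3.0])
    print(ARelu(x).numpy())
    # negatives are scaled by alpha = 0.9, positives by beta = 1 + sigmoid(2.0) ≈ 1.88:
    # approximately [-1.8, -0.45, 0.0, 1.88, 5.64]
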
/src/deepspectrumlite/lib/model/modules/attention_module.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow import keras
3 |
4 | # code taken from https://github.com/kobiso/CBAM-tensorflow
5 | # Author: Byung Soo Ko
6 |
7 |
8 | def se_block(residual, name, ratio=8):
9 | """Contains the implementation of Squeeze-and-Excitation(SE) block.
10 | As described in https://arxiv.org/abs/1709.01507.
11 | """
12 |
13 | kernel_initializer = tf.keras.initializers.VarianceScaling()
14 | bias_initializer = tf.constant_initializer(value=0.0)
15 |
16 | with tf.name_scope(name):
17 | channel = residual.get_shape()[-1]
18 | # Global average pooling
19 | squeeze = tf.reduce_mean(residual, axis=[1, 2], keepdims=True)
20 | assert squeeze.get_shape()[1:] == (1, 1, channel)
21 | excitation = keras.layers.Dense(
22 | units=channel // ratio,
23 | activation=tf.nn.relu,
24 | kernel_initializer=kernel_initializer,
25 | bias_initializer=bias_initializer,
26 | name='bottleneck_fc')(squeeze)
27 | assert excitation.get_shape()[1:] == (1, 1, channel // ratio)
28 | excitation = keras.layers.Dense(
29 | units=channel,
30 | activation=tf.nn.sigmoid,
31 | kernel_initializer=kernel_initializer,
32 | bias_initializer=bias_initializer,
33 | name='recover_fc')(excitation)
34 | assert excitation.get_shape()[1:] == (1, 1, channel)
35 | # top = tf.multiply(bottom, se, name='scale')
36 | scale = residual * excitation
37 | return scale
38 |
39 |
40 | def cbam_block(input_feature, name, ratio=8):
41 | """Contains the implementation of Convolutional Block Attention Module(CBAM) block.
42 | As described in https://arxiv.org/abs/1807.06521.
43 | """
44 |
45 | with tf.name_scope(name):
46 | attention_feature = channel_attention(input_feature, 'ch_at', ratio)
47 | attention_feature = spatial_attention(attention_feature, 'sp_at')
48 | return attention_feature
49 |
50 |
51 | def channel_attention(input_feature, name, ratio=8):
52 | kernel_initializer = tf.keras.initializers.VarianceScaling()
53 | bias_initializer = tf.constant_initializer(value=0.0)
54 |
55 | with tf.name_scope(name):
56 | channel = input_feature.get_shape()[-1]
57 | avg_pool = tf.reduce_mean(input_feature, axis=[1, 2], keepdims=True)
58 |
59 | assert avg_pool.get_shape()[1:] == (1, 1, channel)
60 | avg_pool = keras.layers.Dense(
61 | units=channel // ratio,
62 | activation=tf.nn.relu,
63 | kernel_initializer=kernel_initializer,
64 | bias_initializer=bias_initializer)(avg_pool)
65 | assert avg_pool.get_shape()[1:] == (1, 1, channel // ratio)
66 | avg_pool = keras.layers.Dense(
67 | units=channel,
68 | kernel_initializer=kernel_initializer,
69 | bias_initializer=bias_initializer)(avg_pool)
70 | assert avg_pool.get_shape()[1:] == (1, 1, channel)
71 |
72 | max_pool = tf.reduce_max(input_feature, axis=[1, 2], keepdims=True)
73 | assert max_pool.get_shape()[1:] == (1, 1, channel)
74 | max_pool = keras.layers.Dense(
75 | units=channel // ratio,
76 | activation=tf.nn.relu)(max_pool)
77 | assert max_pool.get_shape()[1:] == (1, 1, channel // ratio)
78 | max_pool = keras.layers.Dense(
79 | units=channel)(max_pool)
80 | assert max_pool.get_shape()[1:] == (1, 1, channel)
81 |
82 | scale = tf.sigmoid(avg_pool + max_pool, 'sigmoid')
83 |
84 | return input_feature * scale
85 |
86 |
87 | def spatial_attention(input_feature, name):
88 | kernel_size = 7
89 | kernel_initializer = tf.keras.initializers.VarianceScaling()
90 | with tf.name_scope(name):
91 | avg_pool = tf.reduce_mean(input_feature, axis=[3], keepdims=True)
92 | assert avg_pool.get_shape()[-1] == 1
93 | max_pool = tf.reduce_max(input_feature, axis=[3], keepdims=True)
94 | assert max_pool.get_shape()[-1] == 1
95 | concat = tf.concat([avg_pool, max_pool], 3)
96 | assert concat.get_shape()[-1] == 2
97 |
98 | concat = keras.layers.Conv2D(
99 | filters=1,
100 | kernel_size=[kernel_size, kernel_size],
101 | strides=[1, 1],
102 | padding="same",
103 | activation=None,
104 | kernel_initializer=kernel_initializer,
105 | use_bias=False)(concat)
106 | assert concat.get_shape()[-1] == 1
107 | concat = tf.sigmoid(concat, 'sigmoid')
108 |
109 | return input_feature * concat
110 |
--------------------------------------------------------------------------------
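
A minimal shape check for the attention blocks above (the random NHWC tensor is purely for illustration; the import path is inferred from the file tree):

    import tensorflow as tf

    from deepspectrumlite.lib.model.modules.attention_module import cbam_block, se_block

    feat = tf.random.normal((4, 32, 32, 64))                  # (batch, height, width, channels)
    se_out = se_block(feat, name="se_demo", ratio=8)          # channel re-weighting only
    cbam_out = cbam_block(feat, name="cbam_demo", ratio=8)    # channel then spatial attention
    print(se_out.shape, cbam_out.shape)                       # both keep the shape (4, 32, 32, 64)
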
/src/deepspectrumlite/lib/model/modules/augmentable_model.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import tensorflow as tf
20 | from tensorflow.python.eager import backprop
21 | from tensorflow.python.keras.engine import data_adapter
22 |
23 | from deepspectrumlite.lib.data.data_pipeline import preprocess_scalar_zero
24 |
25 |
26 | '''
27 | AugmentableModel implements a SapAugment data augmentation policy
28 |
29 | Hu, Ting-yao et al. "SapAugment: Learning A Sample Adaptive Policy for Data Augmentation" (2020).
30 | @see https://arxiv.org/abs/2011.01156
31 | '''
32 | class AugmentableModel(tf.keras.Model):
33 |
34 | def __init__(self, *args, **kwargs):
35 | super(AugmentableModel, self).__init__(*args, **kwargs)
36 | self.hy_params = {}
37 | self._batch_size = None
38 | self.lambda_sap = None
39 | self._sap_augment_a = None
40 | self._sap_augment_s = None
41 |
42 | def set_hyper_parameters(self, hy_params):
43 | self.hy_params = hy_params
44 |
45 | self._batch_size = self.hy_params['batch_size']
46 |
47 |         self.lambda_sap = tf.zeros(shape=(self._batch_size,), dtype=tf.float32, name="lambda_sap")
48 |
49 | self._sap_augment_a = self.hy_params['sap_aug_a']
50 | self._sap_augment_s = self.hy_params['sap_aug_s']
51 |
52 | def set_batch_size(self, batch_size):
53 | self._batch_size = batch_size
54 |
55 | def set_sap_augment(self, a, s):
56 | self._sap_augment_a = a
57 | self._sap_augment_s = s
58 |
59 | def train_step(self, data):
60 | data = data_adapter.expand_1d(data)
61 | x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
62 |
63 | with backprop.GradientTape() as tape:
64 | y_pred = self(x, training=False)
65 | loss_list = []
66 | for i in range(self._batch_size):
67 | # TODO sample_weight None
68 | instance_loss = self.compiled_loss(
69 | y[i,:], y_pred[i, :], None, regularization_losses=self.losses)
70 | loss_list.append(instance_loss)
71 |
72 | sorted_keys = tf.argsort(loss_list, axis=-1, direction='ASCENDING')
73 |
74 | a = self._sap_augment_a
75 | s = self._sap_augment_s
76 |
77 | alpha = s * (1 - a)
78 | beta = s * a
79 | new_lambda = tf.zeros(shape=(self._batch_size, 1))
80 |
81 | for i in range(self._batch_size):
82 | ranking_relative = (i + 1) / self._batch_size
83 | beta_inc = 1 - tf.math.betainc(alpha, beta, ranking_relative) # D = 1 -> 0
84 |
85 | j = sorted_keys[i]
86 | p1 = tf.zeros(shape=(j, 1))
87 | p2 = tf.zeros(shape=(self._batch_size-1-j, 1))
88 | pf = tf.concat([p1, [(beta_inc,)]], 0)
89 | pf = tf.concat([pf, p2], 0)
90 | new_lambda = new_lambda + pf
91 |
92 | self.lambda_sap = new_lambda
93 |
94 | # import numpy as np
95 | #
96 | # for i in range(self._batch_size):
97 | # image = x[i]
98 | # image = np.array(image)
99 | # lambda_value = self.lambda_sap[i][0].numpy()
100 | # loss_value = loss_list[i].numpy()
101 | # path = '/Users/tobias/Downloads/debug/' + str(i) + '-' + str(
102 | # lambda_value) + '-' + str(
103 | # loss_value) + '-orig.png'
104 | # tf.keras.preprocessing.image.save_img(
105 | # path=path, x=image, data_format='channels_last', scale=True
106 | # )
107 |
108 | # tf.print("x=")
109 | # tf.print(tf.cast(tf.round(tf.squeeze(x[0])), tf.int32), summarize=-1, sep=",")
110 | if self.hy_params['augment_cutmix']:
111 | x, y = self.cutmix(x, y)
112 |
113 | # for i in range(self._batch_size):
114 | # image = x[i]
115 | # image = np.array(image)
116 | # lambda_value = self.lambda_sap[i][0].numpy()
117 | # loss_value = loss_list[i].numpy()
118 | # path = '/Users/tobias/Downloads/debug/' + str(i) + '-' + str(
119 | # lambda_value) + '-' + str(
120 | # loss_value) + '-cutmix.png'
121 | # tf.keras.preprocessing.image.save_img(
122 | # path=path, x=image, data_format='channels_last', scale=True
123 | # )
124 |
125 | # tf.print("cutmix=")
126 | # tf.print(tf.cast(tf.round(tf.squeeze(x[0])), tf.int32), summarize=-1, sep=",")
127 | if self.hy_params['augment_specaug']:
128 | x, y = self.apply_spec_aug(x, y)
129 | # tf.print("spec_aug=")
130 | # tf.print(tf.cast(tf.round(tf.squeeze(x[0])), tf.int32), summarize=-1, sep=",")
131 | # self.stop_training = True
132 |
133 | # for i in range(self._batch_size):
134 | # image = x[i]
135 | # image = np.array(image)
136 | # lambda_value = self.lambda_sap[i][0].numpy()
137 | # loss_value = loss_list[i].numpy()
138 | # path = '/Users/tobias/Downloads/debug/' + str(i) + '-' + str(
139 | # lambda_value) + '-' + str(
140 | # loss_value) + '-both.png'
141 | # tf.keras.preprocessing.image.save_img(
142 | # path=path, x=image, data_format='channels_last', scale=True
143 | # )
144 | # self.stop_training = True
145 |
146 | #tf.print("x_NEW=")
147 | #tf.print(x[0])
148 |
149 | y_pred = self(x, training=True) # batch size, softmax output
150 | loss = self.compiled_loss(
151 | y, y_pred, sample_weight, regularization_losses=self.losses)
152 |
153 | self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
154 | self.compiled_metrics.update_state(y, y_pred, sample_weight)
155 | return {m.name: m.result() for m in self.metrics}
156 |
157 | def cutmix(self, images, labels):
158 | """
159 | Implements cutmix data augmentation
160 | Yun, Sangdoo et al. "CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" (2019).
161 |
162 | @param images: batch of images
163 | @param labels: batch of labels (one hot encoded)
164 | @return: images, labels
165 | @see https://arxiv.org/abs/1905.04899
166 | """
167 | spectrogram_x = self.input_shape[1]
168 | spectrogram_y = self.input_shape[2]
169 |
170 | min_augment = self.hy_params['cutmix_min']
171 | max_augment = self.hy_params['cutmix_max']-self.hy_params['cutmix_min']
172 |
173 | probability_min = self.hy_params['da_prob_min']
174 | probability_max = self.hy_params['da_prob_max']
175 |
176 | image_list = []
177 | label_list = []
178 | for j in range(self._batch_size):
179 | probability_threshold = probability_min + probability_max * tf.squeeze(self.lambda_sap[j])
180 | P = tf.cast(tf.random.uniform([], 0, 1) <= probability_threshold, tf.int32)
181 |
182 | # works as we have square images only
183 | w = tf.cast(tf.round(min_augment * spectrogram_x + max_augment * spectrogram_x * tf.squeeze(self.lambda_sap[j])), tf.int32) * P
184 |
185 | k = tf.cast(tf.random.uniform([], 0, self._batch_size), tf.int32)
186 | x = tf.cast(tf.random.uniform([], 0, spectrogram_x), tf.int32)
187 | y = tf.cast(tf.random.uniform([], 0, spectrogram_y), tf.int32)
188 |
189 | xa = tf.math.maximum(0, x - w // 2) # xa denotes the start
190 | xb = tf.math.minimum(spectrogram_x, x + w // 2) # xb denotes the end
191 |
192 |             ya = tf.math.maximum(0, y - w // 2) # ya denotes the start
193 |             yb = tf.math.minimum(spectrogram_y, y + w // 2) # yb denotes the end
194 |
195 | piece = tf.concat([
196 | images[j, xa:xb, 0:ya, ],
197 | images[k, xa:xb, ya:yb, ],
198 | images[j, xa:xb, yb:spectrogram_y, ]
199 | ], axis=1)
200 |
201 | image = tf.concat([images[j, 0:xa, :, ],
202 | piece,
203 | images[j, xb:spectrogram_x, :, ]], axis=0)
204 | image_list.append(image)
205 |
206 | a = tf.cast((w**2) / (spectrogram_x*spectrogram_y), tf.float32)
207 |
208 | label_1 = labels[j,]
209 | label_2 = labels[k,]
210 | label_list.append((1 - a) * label_1 + a * label_2)
211 |
212 |
213 | x = tf.reshape(tf.stack(image_list), (self._batch_size, spectrogram_x, spectrogram_y, self.input_shape[3]))
214 | y = tf.reshape(tf.stack(label_list), (self._batch_size, self.output_shape[1]))
215 |
216 | return x, y
217 |
218 | def apply_spec_aug(self, images, labels):
219 | """
220 | Implements SpecAugment
221 | Park, Daniel et al. "SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition" (2019).
222 |
223 | @param images: batch of images
224 | @param labels: batch of labels
225 | @return: images, labels
226 | @see https://arxiv.org/abs/1904.08779
227 | """
228 | with tf.name_scope("apply_spec_aug"):
229 | n = self.input_shape[1] # x = time
230 | v = self.input_shape[2] # y = freq
231 |
232 | frequency_mask_num = self.hy_params['specaug_freq_mask_num']
233 | time_mask_num = self.hy_params['specaug_time_mask_num']
234 |
235 | freq_min_augment = self.hy_params['specaug_freq_min']
236 | freq_max_augment = self.hy_params['specaug_freq_max']-self.hy_params['specaug_freq_min']
237 | time_min_augment = self.hy_params['specaug_time_min']
238 | time_max_augment = self.hy_params['specaug_time_max']-self.hy_params['specaug_time_min']
239 |
240 | probability_min = self.hy_params['da_prob_min']
241 | probability_max = self.hy_params['da_prob_max']
242 |
243 | output_mel_spectrogram = []
244 |
245 | for j in range(self._batch_size):
246 | probability_threshold = probability_min + probability_max * tf.squeeze(self.lambda_sap[j])
247 | P = tf.cast(tf.random.uniform([], 0, 1) <= probability_threshold, tf.int32)
248 |
249 | f = tf.cast(tf.round((freq_min_augment * v + freq_max_augment * v * tf.squeeze(self.lambda_sap[j])) / frequency_mask_num), tf.int32) * P
250 | t = tf.cast(tf.round((time_min_augment * n + time_max_augment * n * tf.squeeze(self.lambda_sap[j])) / time_mask_num), tf.int32) * P
251 |
252 | tmp = images
253 |
254 | for i in range(frequency_mask_num):
255 | f0 = tf.random.uniform([], minval=0, maxval=v - f, dtype=tf.int32)
256 |
257 | mask = tf.concat((tf.ones(shape=(1, n, v - f0 - f, 1)),
258 | tf.zeros(shape=(1, n, f, 1)),
259 | tf.ones(shape=(1, n, f0, 1)),
260 | ), 2)
261 | tmp = tmp * mask
262 |
263 | for i in range(time_mask_num):
264 | t0 = tf.random.uniform([], minval=0, maxval=n - t, dtype=tf.int32)
265 |
266 | mask = tf.concat((tf.ones(shape=(1, n - t0 - t, v, 1)),
267 | tf.zeros(shape=(1, t, v, 1)),
268 | tf.ones(shape=(1, t0, v, 1)),
269 | ), 1)
270 | tmp = tmp * mask
271 |
272 | output_mel_spectrogram.append(tmp[j])
273 |
274 | images = tf.stack(output_mel_spectrogram)
275 | #tf.print(mel_spectrogram[0], summarize=-1, sep=",")
276 | # replace zero values by the mean
277 | preprocessed_mask = preprocess_scalar_zero(self.hy_params['basemodel_name'])
278 | images = tf.where(images == 0.0, preprocessed_mask, images)
279 |
280 | #tf.print(mel_spectrogram[0], summarize=-1, sep=",")
281 | #self.stop_training = True
282 | return images, labels
--------------------------------------------------------------------------------
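The cutmix() method above mixes pairs of spectrogram images and their one-hot labels. As a self-contained illustration (not part of the repository), the following NumPy sketch reproduces the core CutMix arithmetic on toy data; the helper name toy_cutmix and the toy inputs are hypothetical. Note that the repository weights the labels by w**2 / (spectrogram_x * spectrogram_y), whereas this sketch uses the clipped patch area, which coincides with that ratio away from the image border.

import numpy as np

def toy_cutmix(img_a, img_b, label_a, label_b, x, y, w):
    """Paste a roughly (w x w) patch of img_b into img_a and mix the one-hot
    labels in proportion to the replaced area (Yun et al., 2019)."""
    size_x, size_y = img_a.shape[:2]
    xa, xb = max(0, x - w // 2), min(size_x, x + w // 2)
    ya, yb = max(0, y - w // 2), min(size_y, y + w // 2)

    mixed = img_a.copy()
    mixed[xa:xb, ya:yb] = img_b[xa:xb, ya:yb]

    a = (xb - xa) * (yb - ya) / (size_x * size_y)  # fraction of pixels replaced
    return mixed, (1 - a) * label_a + a * label_b

# toy 8x8 square "spectrograms" and two-class one-hot labels
img_dog, img_seagull = np.zeros((8, 8)), np.ones((8, 8))
lbl_dog, lbl_seagull = np.array([1.0, 0.0]), np.array([0.0, 1.0])

_, lbl = toy_cutmix(img_dog, img_seagull, lbl_dog, lbl_seagull, x=4, y=4, w=4)
print(lbl)  # [0.75 0.25] -- a quarter of the pixels came from the seagull image

--------------------------------------------------------------------------------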
/src/deepspectrumlite/lib/model/modules/squeeze_net.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 |
20 | from tensorflow import keras
21 | from keras_applications.imagenet_utils import _obtain_input_shape
22 | from tensorflow.keras import backend as K
23 | from tensorflow.keras.layers import Input, Convolution2D, MaxPooling2D, Activation, concatenate, Dropout
24 | from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D
25 | from tensorflow.keras.models import Model
26 | from tensorflow.keras.utils import get_source_inputs
27 | from tensorflow.keras.utils import get_file
28 | from tensorflow.python.keras.utils import layer_utils
29 |
30 | sq1x1 = "squeeze1x1"
31 | exp1x1 = "expand1x1"
32 | exp3x3 = "expand3x3"
33 | relu = "relu_"
34 |
35 | WEIGHTS_PATH = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels.h5"
36 | WEIGHTS_PATH_NO_TOP = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5"
37 |
38 | # source https://github.com/rcmalli/keras-squeezenet
39 |
40 | # Modular function for Fire Node
41 |
42 | def fire_module(x, fire_id, squeeze=16, expand=64):
43 | s_id = 'fire' + str(fire_id) + '/'
44 |
45 | if K.image_data_format() == 'channels_first':
46 | channel_axis = 1
47 | else:
48 | channel_axis = 3
49 |
50 | x = Convolution2D(squeeze, (1, 1), padding='valid', name=s_id + sq1x1)(x)
51 | x = Activation('relu', name=s_id + relu + sq1x1)(x)
52 |
53 | left = Convolution2D(expand, (1, 1), padding='valid', name=s_id + exp1x1)(x)
54 | left = Activation('relu', name=s_id + relu + exp1x1)(left)
55 |
56 | right = Convolution2D(expand, (3, 3), padding='same', name=s_id + exp3x3)(x)
57 | right = Activation('relu', name=s_id + relu + exp3x3)(right)
58 |
59 | x = concatenate([left, right], axis=channel_axis, name=s_id + 'concat')
60 | return x
61 |
62 |
63 | # Original SqueezeNet from paper.
64 |
65 | def SqueezeNet(include_top=True, weights='imagenet',
66 | input_tensor=None, input_shape=None,
67 | pooling=None,
68 | classes=1000):
69 | """Instantiates the SqueezeNet architecture.
70 | """
71 |
72 | if weights not in {'imagenet', None}:
73 | raise ValueError('The `weights` argument should be either '
74 | '`None` (random initialization) or `imagenet` '
75 | '(pre-training on ImageNet).')
76 |
77 |     if weights == 'imagenet' and include_top and classes != 1000:
78 | raise ValueError('If using `weights` as imagenet with `include_top`'
79 | ' as true, `classes` should be 1000')
80 |
81 | input_shape = _obtain_input_shape(input_shape,
82 | default_size=227,
83 | min_size=48,
84 | data_format=K.image_data_format(),
85 | require_flatten=include_top)
86 |
87 | if input_tensor is None:
88 | img_input = Input(shape=input_shape)
89 | else:
90 | if not K.is_keras_tensor(input_tensor):
91 | img_input = Input(tensor=input_tensor, shape=input_shape)
92 | else:
93 | img_input = input_tensor
94 |
95 | x = Convolution2D(64, (3, 3), strides=(2, 2), padding='valid', name='conv1')(img_input)
96 | x = Activation('relu', name='relu_conv1')(x)
97 | x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1')(x)
98 |
99 | x = fire_module(x, fire_id=2, squeeze=16, expand=64)
100 | x = fire_module(x, fire_id=3, squeeze=16, expand=64)
101 | x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3')(x)
102 |
103 | x = fire_module(x, fire_id=4, squeeze=32, expand=128)
104 | x = fire_module(x, fire_id=5, squeeze=32, expand=128)
105 | x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool5')(x)
106 |
107 | x = fire_module(x, fire_id=6, squeeze=48, expand=192)
108 | x = fire_module(x, fire_id=7, squeeze=48, expand=192)
109 | x = fire_module(x, fire_id=8, squeeze=64, expand=256)
110 | x = fire_module(x, fire_id=9, squeeze=64, expand=256)
111 |
112 | if include_top:
113 | # It's not obvious where to cut the network...
114 | # Could do the 8th or 9th layer... some work recommends cutting earlier layers.
115 |
116 | x = Dropout(0.5, name='drop9')(x)
117 |
118 | x = Convolution2D(classes, (1, 1), padding='valid', name='conv10')(x)
119 | x = Activation('relu', name='relu_conv10')(x)
120 | x = GlobalAveragePooling2D()(x)
121 | x = Activation('softmax', name='loss')(x)
122 | else:
123 | if pooling == 'avg':
124 | x = GlobalAveragePooling2D()(x)
125 | elif pooling == 'max':
126 | x = GlobalMaxPooling2D()(x)
127 |         elif pooling is None:
128 | pass
129 | else:
130 | raise ValueError("Unknown argument for 'pooling'=" + pooling)
131 |
132 | # Ensure that the model takes into account
133 | # any potential predecessors of `input_tensor`.
134 | if input_tensor is not None:
135 | inputs = get_source_inputs(input_tensor)
136 | else:
137 | inputs = img_input
138 |
139 | model = Model(inputs, x, name='squeezenet')
140 |
141 | # load weights
142 | if weights == 'imagenet':
143 | if include_top:
144 | weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels.h5',
145 | WEIGHTS_PATH,
146 | cache_subdir='models')
147 | else:
148 | weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5',
149 | WEIGHTS_PATH_NO_TOP,
150 | cache_subdir='models')
151 |
152 | model.load_weights(weights_path)
153 | if K.backend() == 'theano':
154 | raise ValueError('theano backend not supported')
155 | # layer_utils.convert_all_kernels_in_model(model)
156 |
157 | if K.image_data_format() == 'channels_first':
158 |
159 | if K.backend() == 'tensorflow':
160 | print('You are using the TensorFlow backend, yet you '
161 | 'are using the Theano '
162 | 'image data format convention '
163 | '(`image_data_format="channels_first"`). '
164 | 'For best performance, set '
165 | '`image_data_format="channels_last"` in '
166 | 'your Keras config '
167 | 'at ~/.keras/keras.json.')
168 | return model
169 |
170 |
--------------------------------------------------------------------------------
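As a quick orientation (a sketch, not from the repository), the SqueezeNet factory above can serve as a headless feature extractor for 224x224 RGB spectrogram plots. The import path assumes the package is installed from this repository and that the keras_applications dependency is available:

from deepspectrumlite.lib.model.modules.squeeze_net import SqueezeNet

# Headless SqueezeNet with ImageNet weights; fire9 concatenates two 256-channel
# expand branches, so global average pooling yields 512-dimensional features.
base = SqueezeNet(include_top=False, weights='imagenet',
                  input_shape=(224, 224, 3), pooling='avg')
base.summary()

--------------------------------------------------------------------------------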
/src/deepspectrumlite/lib/util/__init__.py:
--------------------------------------------------------------------------------
1 | """Some useful utility scripts."""
2 | from .audio_utils import *
3 |
--------------------------------------------------------------------------------
/src/deepspectrumlite/lib/util/audio_utils.py:
--------------------------------------------------------------------------------
1 | # DeepSpectrumLite
2 | # ==============================================================================
3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import tensorflow as tf
20 |
21 |
22 | def amplitude_to_db(S, amin=1e-16, top_db=80.0): # pragma: no cover
23 | magnitude = tf.abs(S)
24 | ref_value = tf.reduce_max(magnitude)
25 |
26 | power = tf.square(magnitude)
27 | return power_to_db(power, ref=ref_value ** 2, amin=amin ** 2, top_db=top_db)
28 |
29 |
30 | def power_to_db(S, ref=1.0, amin=1e-16, top_db=80.0): # pragma: no cover
31 |     """Convert a power spectrogram (magnitude squared) to decibel (dB) units.
32 |     Computes the scaling ``10 * log10(S / ref)`` in a numerically
33 | stable way.
34 | Based on:
35 | https://librosa.github.io/librosa/generated/librosa.core.power_to_db.html
36 | """
37 |
38 | # @tf.function
39 | def _tf_log10(x):
40 | numerator = tf.math.log(x)
41 | denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
42 | return numerator / denominator
43 |
44 |     # Scale power relative to the reference value `ref`. Zeros in the output
45 | # correspond to positions where S == ref.
46 | ref_value = tf.abs(ref)
47 |
48 | log_spec = 10.0 * _tf_log10(tf.maximum(amin, S))
49 | log_spec -= 10.0 * _tf_log10(tf.maximum(amin, ref_value))
50 |
51 | if top_db is not None:
52 | log_spec = tf.maximum(log_spec, tf.reduce_max(log_spec) - top_db)
53 |
54 | return log_spec
55 |
--------------------------------------------------------------------------------
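A small worked example of the dB conversion above (a sketch, assuming the installed package exposes this module path): with ref set to the maximum power, the largest bin maps to 0 dB and anything more than top_db below it is clipped.

import tensorflow as tf
from deepspectrumlite.lib.util.audio_utils import power_to_db

# power values spanning ten orders of magnitude
S = tf.constant([1e-10, 1e-4, 1.0])
db = power_to_db(S, ref=tf.reduce_max(S), top_db=80.0)
print(db.numpy())  # approx. [-80. -40.   0.] -- the -100 dB bin is clipped to -80 dB

--------------------------------------------------------------------------------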
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/tests/__init__.py
--------------------------------------------------------------------------------
/tests/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/tests/cli/__init__.py
--------------------------------------------------------------------------------
/tests/cli/audio/dog/file0.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/tests/cli/audio/dog/file0.wav
--------------------------------------------------------------------------------
/tests/cli/audio/seagull/file1.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/tests/cli/audio/seagull/file1.wav
--------------------------------------------------------------------------------
/tests/cli/config/cividis_hp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"],
3 | "model_name": ["TransferBaseModel"],
4 | "prediction_type": ["categorical"],
5 | "basemodel_name": ["densenet121"],
6 | "weights": ["imagenet"],
7 | "tb_experiment": ["cividis_exp"],
8 | "tb_run_id": ["cividis_run"],
9 | "num_units": [4],
10 | "dropout": [0.25],
11 | "optimizer": ["adam"],
12 | "learning_rate": [0.001],
13 | "fine_learning_rate": [0.0001],
14 | "finetune_layer": [0.7],
15 | "loss": ["categorical_crossentropy"],
16 | "activation": ["arelu"],
17 | "pre_epochs": [1],
18 | "epochs": [1],
19 | "batch_size": [1],
20 |
21 | "sample_rate": [16000],
22 |
23 | "chunk_size": [2.0],
24 | "chunk_hop_size": [1.0],
25 | "normalize_audio": [false],
26 |
27 | "stft_window_size": [0.128],
28 | "stft_hop_size": [0.064],
29 | "stft_fft_length": [0.128],
30 |
31 | "mel_scale": [true],
32 | "lower_edge_hertz": [0.0],
33 | "upper_edge_hertz": [8000.0],
34 | "num_mel_bins": [128],
35 | "num_mfccs": [0],
36 | "cep_lifter": [0],
37 | "db_scale": [true],
38 | "use_plot_images": [true],
39 | "color_map": ["cividis"],
40 | "image_width": [224],
41 | "image_height": [224],
42 | "resize_method": ["nearest"],
43 | "anti_alias": [false],
44 |
45 | "sap_aug_a": [0.5],
46 | "sap_aug_s": [10],
47 | "augment_cutmix": [false],
48 | "augment_specaug": [false],
49 | "da_prob_min": [0.1],
50 | "da_prob_max": [0.5],
51 | "cutmix_min": [0.075],
52 | "cutmix_max": [0.25],
53 | "specaug_freq_min": [0.1],
54 | "specaug_freq_max": [0.3],
55 | "specaug_time_min": [0.1],
56 | "specaug_time_max": [0.3],
57 | "specaug_freq_mask_num": [1],
58 | "specaug_time_mask_num": [1]
59 | }
--------------------------------------------------------------------------------
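Every key in the hyperparameter configs in this directory maps to a list of candidate values. The grid-search code that consumes them (lib/model/config/gridsearch.py, not reproduced here) is assumed to expand the lists into the cross product of all combinations; the sketch below shows that expansion generically with itertools.product and is an illustration only, not the repository's implementation.

import itertools
import json

with open('tests/cli/config/cividis_hp_config.json') as f:
    hp_space = json.load(f)

keys = list(hp_space)
# one dict per grid point; the single-element lists in the test configs
# collapse the grid to exactly one configuration
grid = [dict(zip(keys, values)) for values in itertools.product(*hp_space.values())]
print(len(grid), grid[0]['optimizer'], grid[0]['color_map'])  # 1 adam cividis

--------------------------------------------------------------------------------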
/tests/cli/config/inferno_hp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"],
3 | "model_name": ["TransferBaseModel"],
4 | "prediction_type": ["categorical"],
5 | "basemodel_name": ["densenet121"],
6 | "weights": ["imagenet"],
7 | "tb_experiment": ["inferno_exp"],
8 | "tb_run_id": ["inferno_run"],
9 | "num_units": [4],
10 | "dropout": [0.25],
11 | "optimizer": ["adagrad"],
12 | "learning_rate": [0.001],
13 | "fine_learning_rate": [0.0001],
14 | "finetune_layer": [0.7],
15 | "loss": ["categorical_crossentropy"],
16 | "activation": ["arelu"],
17 | "pre_epochs": [1],
18 | "epochs": [1],
19 | "batch_size": [1],
20 |
21 | "sample_rate": [16000],
22 |
23 | "chunk_size": [2.0],
24 | "chunk_hop_size": [1.0],
25 | "normalize_audio": [false],
26 |
27 | "stft_window_size": [0.128],
28 | "stft_hop_size": [0.064],
29 | "stft_fft_length": [0.128],
30 |
31 | "mel_scale": [true],
32 | "lower_edge_hertz": [0.0],
33 | "upper_edge_hertz": [8000.0],
34 | "num_mel_bins": [128],
35 | "num_mfccs": [0],
36 | "cep_lifter": [0],
37 | "db_scale": [true],
38 | "use_plot_images": [true],
39 | "color_map": ["inferno"],
40 | "image_width": [224],
41 | "image_height": [224],
42 | "resize_method": ["nearest"],
43 | "anti_alias": [false],
44 |
45 | "sap_aug_a": [0.5],
46 | "sap_aug_s": [10],
47 | "augment_cutmix": [false],
48 | "augment_specaug": [false],
49 | "da_prob_min": [0.1],
50 | "da_prob_max": [0.5],
51 | "cutmix_min": [0.075],
52 | "cutmix_max": [0.25],
53 | "specaug_freq_min": [0.1],
54 | "specaug_freq_max": [0.3],
55 | "specaug_time_min": [0.1],
56 | "specaug_time_max": [0.3],
57 | "specaug_freq_mask_num": [1],
58 | "specaug_time_mask_num": [1]
59 | }
--------------------------------------------------------------------------------
/tests/cli/config/magma_hp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"],
3 | "model_name": ["TransferBaseModel"],
4 | "prediction_type": ["categorical"],
5 | "basemodel_name": ["densenet121"],
6 | "weights": ["imagenet"],
7 | "tb_experiment": ["magma_exp"],
8 | "tb_run_id": ["magma_run"],
9 | "num_units": [4],
10 | "dropout": [0.25],
11 | "optimizer": ["nadam"],
12 | "learning_rate": [0.001],
13 | "fine_learning_rate": [0.0001],
14 | "finetune_layer": [0.7],
15 | "loss": ["categorical_crossentropy"],
16 | "activation": ["arelu"],
17 | "pre_epochs": [1],
18 | "epochs": [1],
19 | "batch_size": [1],
20 |
21 | "sample_rate": [16000],
22 |
23 | "chunk_size": [2.0],
24 | "chunk_hop_size": [1.0],
25 | "normalize_audio": [false],
26 |
27 | "stft_window_size": [0.128],
28 | "stft_hop_size": [0.064],
29 | "stft_fft_length": [0.128],
30 |
31 | "mel_scale": [true],
32 | "lower_edge_hertz": [0.0],
33 | "upper_edge_hertz": [8000.0],
34 | "num_mel_bins": [128],
35 | "num_mfccs": [0],
36 | "cep_lifter": [0],
37 | "db_scale": [true],
38 | "use_plot_images": [true],
39 | "color_map": ["magma"],
40 | "image_width": [224],
41 | "image_height": [224],
42 | "resize_method": ["nearest"],
43 | "anti_alias": [false],
44 |
45 | "sap_aug_a": [0.5],
46 | "sap_aug_s": [10],
47 | "augment_cutmix": [false],
48 | "augment_specaug": [false],
49 | "da_prob_min": [0.1],
50 | "da_prob_max": [0.5],
51 | "cutmix_min": [0.075],
52 | "cutmix_max": [0.25],
53 | "specaug_freq_min": [0.1],
54 | "specaug_freq_max": [0.3],
55 | "specaug_time_min": [0.1],
56 | "specaug_time_max": [0.3],
57 | "specaug_freq_mask_num": [1],
58 | "specaug_time_mask_num": [1]
59 | }
--------------------------------------------------------------------------------
/tests/cli/config/plasma_hp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"],
3 | "model_name": ["TransferBaseModel"],
4 | "prediction_type": ["categorical"],
5 | "basemodel_name": ["densenet121"],
6 | "weights": ["imagenet"],
7 | "tb_experiment": ["plasma_exp"],
8 | "tb_run_id": ["plasma_run"],
9 | "num_units": [4],
10 | "dropout": [0.25],
11 | "optimizer": ["rmsprop"],
12 | "learning_rate": [0.001],
13 | "fine_learning_rate": [0.0001],
14 | "finetune_layer": [0.7],
15 | "loss": ["categorical_crossentropy"],
16 | "activation": ["arelu"],
17 | "pre_epochs": [1],
18 | "epochs": [1],
19 | "batch_size": [1],
20 |
21 | "sample_rate": [16000],
22 |
23 | "chunk_size": [2.0],
24 | "chunk_hop_size": [1.0],
25 | "normalize_audio": [false],
26 |
27 | "stft_window_size": [0.128],
28 | "stft_hop_size": [0.064],
29 | "stft_fft_length": [0.128],
30 |
31 | "mel_scale": [true],
32 | "lower_edge_hertz": [0.0],
33 | "upper_edge_hertz": [8000.0],
34 | "num_mel_bins": [128],
35 | "num_mfccs": [0],
36 | "cep_lifter": [0],
37 | "db_scale": [true],
38 | "use_plot_images": [true],
39 | "color_map": ["plasma"],
40 | "image_width": [224],
41 | "image_height": [224],
42 | "resize_method": ["nearest"],
43 | "anti_alias": [false],
44 |
45 | "sap_aug_a": [0.5],
46 | "sap_aug_s": [10],
47 | "augment_cutmix": [false],
48 | "augment_specaug": [false],
49 | "da_prob_min": [0.1],
50 | "da_prob_max": [0.5],
51 | "cutmix_min": [0.075],
52 | "cutmix_max": [0.25],
53 | "specaug_freq_min": [0.1],
54 | "specaug_freq_max": [0.3],
55 | "specaug_time_min": [0.1],
56 | "specaug_time_max": [0.3],
57 | "specaug_freq_mask_num": [1],
58 | "specaug_time_mask_num": [1]
59 | }
--------------------------------------------------------------------------------
/tests/cli/config/regression_class_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "r": "r"
3 | }
--------------------------------------------------------------------------------
/tests/cli/config/regression_hp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"],
3 | "model_name": ["TransferBaseModel"],
4 | "prediction_type": ["regression"],
5 | "basemodel_name": ["squeezenet_v1"],
6 | "weights": ["imagenet"],
7 | "tb_experiment": ["regression_exp"],
8 | "tb_run_id": ["regression_run"],
9 | "num_units": [4],
10 | "dropout": [0.25],
11 | "optimizer": ["ftrl"],
12 | "learning_rate": [0.001],
13 | "fine_learning_rate": [0.0001],
14 | "finetune_layer": [0.7],
15 | "loss": ["mae"],
16 | "activation": ["relu"],
17 | "output_activation": ["sigmoid"],
18 | "pre_epochs": [1],
19 | "epochs": [1],
20 | "batch_size": [1],
21 |
22 | "sample_rate": [16000],
23 |
24 | "chunk_size": [2.0],
25 | "chunk_hop_size": [1.0],
26 | "normalize_audio": [true],
27 |
28 | "stft_window_size": [0.128],
29 | "stft_hop_size": [0.064],
30 | "stft_fft_length": [0.128],
31 |
32 | "mel_scale": [true],
33 | "lower_edge_hertz": [0.0],
34 | "upper_edge_hertz": [8000.0],
35 | "num_mel_bins": [128],
36 | "num_mfccs": [0],
37 | "cep_lifter": [0],
38 | "db_scale": [false],
39 | "use_plot_images": [true],
40 | "color_map": ["viridis"],
41 | "image_width": [224],
42 | "image_height": [224],
43 | "resize_method": ["nearest"],
44 | "anti_alias": [false],
45 |
46 | "sap_aug_a": [0.5],
47 | "sap_aug_s": [10],
48 | "augment_cutmix": [false],
49 | "augment_specaug": [false],
50 | "da_prob_min": [0.1],
51 | "da_prob_max": [0.5],
52 | "cutmix_min": [0.075],
53 | "cutmix_max": [0.25],
54 | "specaug_freq_min": [0.1],
55 | "specaug_freq_max": [0.3],
56 | "specaug_time_min": [0.1],
57 | "specaug_time_max": [0.3],
58 | "specaug_freq_mask_num": [1],
59 | "specaug_time_mask_num": [1]
60 | }
--------------------------------------------------------------------------------
/tests/cli/label/regression_label.csv:
--------------------------------------------------------------------------------
1 | filename,label,duration_frames
2 | train_seagull.wav,0.5,117249
3 | train_dog.wav,0.9,87865
4 | devel_seagull.wav,0.5,117249
5 | devel_dog.wav,0.9,87865
6 | test_seagull.wav,0.5,117249
7 | test_dog.wav,0.9,87865
--------------------------------------------------------------------------------
/tests/cli/test_cli.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import tempfile
4 |
5 | def get_tmp_dir():
6 | tmpdir = tempfile.mkdtemp()
7 | subdir = os.path.join(tmpdir, "pytest_test_train")
8 | os.mkdir(subdir)
9 | return os.path.join(subdir, "")
10 |
11 | temp_dir = get_tmp_dir()
12 |
13 | def test_train():
14 | tmpdir = temp_dir
15 | from deepspectrumlite.__main__ import cli
16 | from click.testing import CliRunner
17 | from os.path import dirname, join
18 |
19 | cur_dir = dirname(__file__)
20 | examples = join(dirname(dirname(cur_dir)), 'examples')
21 |
22 | runner = CliRunner()
23 | result = runner.invoke(cli,
24 | args=[
25 | '-vv', 'train',
26 | '-d', join(examples, 'audio'),
27 | '-md', join(tmpdir, 'output'),
28 | '-hc', join(examples, 'hp_config.json'),
29 | '-cc', join(examples, 'class_config.json'),
30 | '-l', join(examples, 'label.csv')
31 | ],
32 | catch_exceptions=False)
33 |     assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"
34 |
35 | def test_inference():
36 | tmpdir = temp_dir
37 | from deepspectrumlite.__main__ import cli
38 | from click.testing import CliRunner
39 | from os.path import dirname, join
40 |
41 | cur_dir = dirname(__file__)
42 | examples = join(dirname(dirname(cur_dir)), 'examples')
43 |
44 | runner = CliRunner()
45 | result = runner.invoke(cli,
46 | args=[
47 | '-vv', 'predict',
48 | '-d', join(cur_dir, 'audio'),
49 | '-md', join(tmpdir, 'output', 'models', 'densenet_exp', 'densenet121_run_config_0',
50 | 'model.h5'),
51 | '-hc', join(examples, 'hp_config.json'),
52 | '-cc', join(examples, 'class_config.json')
53 | ],
54 | catch_exceptions=False)
55 |     assert result.exit_code == 0, f"Exit code for predict is not 0 but {result.exit_code}"
56 |
57 |
58 | def test_devel_test():
59 | tmpdir = temp_dir
60 | from deepspectrumlite.__main__ import cli
61 | from click.testing import CliRunner
62 | from os.path import dirname, join
63 |
64 | cur_dir = dirname(__file__)
65 | examples = join(dirname(dirname(cur_dir)), 'examples')
66 |
67 | runner = CliRunner()
68 | result = runner.invoke(cli,
69 | args=[
70 | '-vv', 'devel-test',
71 | '-d', join(examples, 'audio'),
72 | '-md', join(tmpdir, 'output', 'models', 'densenet_exp', 'densenet121_run_config_0',
73 | 'model.h5'),
74 | '-hc', join(examples, 'hp_config.json'),
75 | '-cc', join(examples, 'class_config.json'),
76 | '-l', join(examples, 'label.csv')
77 | ],
78 | catch_exceptions=False)
79 |     assert result.exit_code == 0, f"Exit code for devel-test is not 0 but {result.exit_code}"
80 |
81 |
82 | def test_stats():
83 | tmpdir = temp_dir
84 | from deepspectrumlite.__main__ import cli
85 | from click.testing import CliRunner
86 | from os.path import dirname, join
87 |
88 | runner = CliRunner()
89 | result = runner.invoke(cli,
90 | args=[
91 | '-vv', 'stats',
92 | '-md', join(tmpdir, 'output', 'models', 'densenet_exp', 'densenet121_run_config_0',
93 | 'model.h5')
94 | ],
95 | catch_exceptions=False)
96 |     assert result.exit_code == 0, f"Exit code for stats is not 0 but {result.exit_code}"
97 |
98 |
99 | def test_convert():
100 | tmpdir = temp_dir
101 | from deepspectrumlite.__main__ import cli
102 | from click.testing import CliRunner
103 | from os.path import dirname, join
104 |
105 | runner = CliRunner()
106 | result = runner.invoke(cli,
107 | args=[
108 | '-vv', 'convert',
109 | '-s', join(tmpdir, 'output', 'models', 'densenet_exp', 'densenet121_run_config_0',
110 | 'model.h5'),
111 | '-d', join(tmpdir, 'output', 'models', 'densenet_exp', 'densenet121_run_config_0',
112 | 'converted_model.tflite')
113 | ],
114 | catch_exceptions=False)
115 |     assert result.exit_code == 0, f"Exit code for convert is not 0 but {result.exit_code}"
116 |
117 |
118 | def test_tflite_stats():
119 | tmpdir = temp_dir
120 | from deepspectrumlite.__main__ import cli
121 | from click.testing import CliRunner
122 | from os.path import join
123 |
124 | runner = CliRunner()
125 | result = runner.invoke(cli,
126 | args=[
127 | '-vv', 'tflite-stats',
128 | '-md', join(tmpdir, 'output', 'models', 'densenet_exp', 'densenet121_run_config_0',
129 | 'converted_model.tflite')
130 | ],
131 | catch_exceptions=False)
132 |     assert result.exit_code == 0, f"Exit code for tflite-stats is not 0 but {result.exit_code}"
133 |
134 |
135 | def test_create_preprocessor():
136 | tmpdir = temp_dir
137 | from deepspectrumlite.__main__ import cli
138 | from click.testing import CliRunner
139 | from os.path import dirname, join
140 |
141 | cur_dir = dirname(__file__)
142 | examples = join(dirname(dirname(cur_dir)), 'examples')
143 |
144 | runner = CliRunner()
145 | result = runner.invoke(cli,
146 | args=[
147 | '-vv', 'create-preprocessor',
148 | '-hc', join(examples, 'hp_config.json'),
149 | '-d', join(tmpdir, 'output', 'preprocessor.tflite')
150 | ],
151 | catch_exceptions=False)
152 |     assert result.exit_code == 0, f"Exit code for create_preprocessor is not 0 but {result.exit_code}"
153 |
154 |
155 | def pytest_sessionfinish():
156 | shutil.rmtree(temp_dir)
157 |
158 | # if __name__ == '__main__':
159 | # print(test_create_preprocessor())
--------------------------------------------------------------------------------
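Outside of pytest, the training command exercised by test_train() above can be reproduced directly. The snippet below mirrors that test with illustrative relative paths (an equivalent shell invocation would go through `python -m deepspectrumlite`, since the package ships a __main__.py):

from click.testing import CliRunner
from deepspectrumlite.__main__ import cli

# Paths are illustrative and assume the repository root as working directory.
result = CliRunner().invoke(cli, [
    '-vv', 'train',
    '-d', 'examples/audio',
    '-md', 'output',
    '-hc', 'examples/hp_config.json',
    '-cc', 'examples/class_config.json',
    '-l', 'examples/label.csv',
], catch_exceptions=False)
print(result.exit_code)

--------------------------------------------------------------------------------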
/tests/cli/test_train_config.py:
--------------------------------------------------------------------------------
1 | def test_cividis(tmpdir):
2 | from deepspectrumlite.__main__ import cli
3 | from click.testing import CliRunner
4 | from os.path import dirname, join
5 |
6 | cur_dir = dirname(__file__)
7 | examples = join(dirname(dirname(cur_dir)), 'examples')
8 |
9 | runner = CliRunner()
10 | result = runner.invoke(cli,
11 | args=[
12 | '-vv', 'train',
13 | '-d', join(examples, 'audio'),
14 | '-md', join(tmpdir, 'output'),
15 | '-hc', join(cur_dir, 'config', 'cividis_hp_config.json'),
16 | '-cc', join(examples, 'class_config.json'),
17 | '-l', join(examples, 'label.csv')
18 | ],
19 | catch_exceptions=False)
20 |     assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"
21 |
22 | def test_plasma(tmpdir):
23 | from deepspectrumlite.__main__ import cli
24 | from click.testing import CliRunner
25 | from os.path import dirname, join
26 |
27 | cur_dir = dirname(__file__)
28 | examples = join(dirname(dirname(cur_dir)), 'examples')
29 |
30 | runner = CliRunner()
31 | result = runner.invoke(cli,
32 | args=[
33 | '-vv', 'train',
34 | '-d', join(examples, 'audio'),
35 | '-md', join(tmpdir, 'output'),
36 | '-hc', join(cur_dir, 'config', 'plasma_hp_config.json'),
37 | '-cc', join(examples, 'class_config.json'),
38 | '-l', join(examples, 'label.csv')
39 | ],
40 | catch_exceptions=False)
41 |     assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"
42 |
43 | def test_inferno(tmpdir):
44 | from deepspectrumlite.__main__ import cli
45 | from click.testing import CliRunner
46 | from os.path import dirname, join
47 |
48 | cur_dir = dirname(__file__)
49 | examples = join(dirname(dirname(cur_dir)), 'examples')
50 |
51 | runner = CliRunner()
52 | result = runner.invoke(cli,
53 | args=[
54 | '-vv', 'train',
55 | '-d', join(examples, 'audio'),
56 | '-md', join(tmpdir, 'output'),
57 | '-hc', join(cur_dir, 'config', 'inferno_hp_config.json'),
58 | '-cc', join(examples, 'class_config.json'),
59 | '-l', join(examples, 'label.csv')
60 | ],
61 | catch_exceptions=False)
62 |     assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"
63 |
64 | def test_magma(tmpdir):
65 | from deepspectrumlite.__main__ import cli
66 | from click.testing import CliRunner
67 | from os.path import dirname, join
68 |
69 | cur_dir = dirname(__file__)
70 | examples = join(dirname(dirname(cur_dir)), 'examples')
71 |
72 | runner = CliRunner()
73 | result = runner.invoke(cli,
74 | args=[
75 | '-vv', 'train',
76 | '-d', join(examples, 'audio'),
77 | '-md', join(tmpdir, 'output'),
78 | '-hc', join(cur_dir, 'config', 'magma_hp_config.json'),
79 | '-cc', join(examples, 'class_config.json'),
80 | '-l', join(examples, 'label.csv')
81 | ],
82 | catch_exceptions=False)
83 |     assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"
84 |
85 | def test_regression(tmpdir):
86 | from deepspectrumlite.__main__ import cli
87 | from click.testing import CliRunner
88 | from os.path import dirname, join
89 |
90 | cur_dir = dirname(__file__)
91 | examples = join(dirname(dirname(cur_dir)), 'examples')
92 |
93 | runner = CliRunner()
94 | result = runner.invoke(cli,
95 | args=[
96 | '-vv', 'train',
97 | '-d', join(examples, 'audio'),
98 | '-md', join(tmpdir, 'output'),
99 | '-hc', join(cur_dir, 'config', 'regression_hp_config.json'),
100 | '-cc', join(cur_dir, 'config', 'regression_class_config.json'),
101 | '-l', join(cur_dir, 'label', 'regression_label.csv')
102 | ],
103 | catch_exceptions=False)
104 |     assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"
--------------------------------------------------------------------------------