├── .coveragerc ├── .github └── workflows │ ├── ci.yml │ └── python-publish.yml ├── .gitignore ├── DESCRIPTION.md ├── LICENSE ├── README.md ├── environment.yml ├── examples ├── audio │ ├── devel_dog.wav │ ├── devel_seagull.wav │ ├── test_dog.wav │ ├── test_seagull.wav │ ├── train_dog.wav │ └── train_seagull.wav ├── class_config.json ├── hp_config.json └── label.csv ├── notebooks └── explain │ ├── environment.yml │ ├── model │ └── ccs │ │ └── model.h5 │ ├── shap.ipynb │ ├── shap │ └── ccs │ │ └── 100 │ │ └── shap_ccs_100.pdf │ └── spectrograms │ └── ccs │ ├── Negative │ └── test_010.png │ └── Positive │ └── test_039.png ├── pyproject.toml ├── requirements-test.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── deepspectrumlite │ ├── __init__.py │ ├── __main__.py │ ├── cli │ ├── __init__.py │ ├── config │ │ ├── .gitignore │ │ ├── class_config.json │ │ └── hp_config.json │ ├── convert.py │ ├── create_preprocessor.py │ ├── devel_test.py │ ├── predict.py │ ├── stats.py │ ├── tflite_stats.py │ ├── train.py │ └── utils.py │ └── lib │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── data_pipeline.py │ ├── embedded │ │ ├── __init__.py │ │ └── preprocessor.py │ ├── parser │ │ ├── ComParEParser.py │ │ └── __init__.py │ └── plot │ │ ├── __init__.py │ │ ├── color_maps │ │ ├── __init__.py │ │ ├── abstract_colormap.py │ │ ├── cividis.py │ │ ├── inferno.py │ │ ├── magma.py │ │ ├── plasma.py │ │ └── viridis.py │ │ └── colormap.py │ ├── hyperparameter.py │ ├── model │ ├── TransferBaseModel.py │ ├── __init__.py │ ├── ai_model.py │ ├── config │ │ ├── __init__.py │ │ └── gridsearch.py │ └── modules │ │ ├── __init__.py │ │ ├── arelu.py │ │ ├── attention_module.py │ │ ├── augmentable_model.py │ │ └── squeeze_net.py │ └── util │ ├── __init__.py │ └── audio_utils.py └── tests ├── __init__.py └── cli ├── __init__.py ├── audio ├── dog │ └── file0.wav └── seagull │ └── file1.wav ├── config ├── cividis_hp_config.json ├── inferno_hp_config.json ├── magma_hp_config.json ├── 
plasma_hp_config.json ├── regression_class_config.json └── regression_hp_config.json ├── label └── regression_label.csv ├── test_cli.py └── test_train_config.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source = deepspectrumlite 3 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | pip: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: [3.8] 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install libsndfile 26 | run: sudo apt-get install libsndfile1-dev 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install . 
31 | if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi 32 | - name: Test with pytest 33 | run: | 34 | py.test -s tests/ --cov deepspectrumlite --cov-report xml:coverage.xml 35 | - name: Coverage 36 | uses: codecov/codecov-action@v1 37 | with: 38 | file: ./coverage.xml # optional 39 | fail_ci_if_error: true # optional (default = false) -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.idea 3 | *.DS_Store 4 | *.tflite 5 | venv 6 | env 7 | *.gz 8 | /output/* 9 | recorded 10 | __pycache__ 11 | build/* 12 | dist/* 13 | *.egg-info 14 | .coverage 15 | .coverage.* 16 | *.pytest_cache 17 | .direnv/ 18 | -------------------------------------------------------------------------------- /DESCRIPTION.md: 
-------------------------------------------------------------------------------- 1 | # DeepSpectrumLite 2 | DeepSpectrumLite is a Python toolkit to design and train light-weight Deep Neural Networks (DNNs) for classification tasks from raw audio data. 3 | The trained models run on embedded devices. 4 | 5 | DeepSpectrumLite features an extraction pipeline which first creates visual representations for audio data - plots of spectrograms. 6 | The image plots are then fed to a DNN. This could be a pre-trained Image Convolutional Neural Network (CNN). 7 | Activations of a specific layer then form the final feature vectors which are used for the final classification. 8 | 9 | The trained models can be easily converted to a TensorFlow Lite model. During the conversion process, the model becomes smaller and faster, and is optimised for inference on embedded devices. 10 | 11 | **(c) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk, Sandra Ottl, Björn Schuller: Universität Augsburg** 12 | Published under GPLv3, please see the `LICENSE` file for details. 13 | 14 | Please direct any questions or requests to Shahin Amiriparian (shahin.amiriparian at informatik.uni-augsburg.de) or Tobias Hübner (tobias.huebner at informatik.uni-augsburg.de). 15 | 16 | # Why DeepSpectrumLite? 17 | DeepSpectrumLite is built upon TensorFlow Lite which is a specialised version of TensorFlow that supports embedded devices. 18 | However, TensorFlow Lite does not support all basic TensorFlow functions for audio signal processing and plot image generation. DeepSpectrumLite offers implementations for unsupported functions.
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: DeepSpectrumLite 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.8 7 | - pip 8 | - cudatoolkit=11.2 9 | - cudnn=8.1 10 | - pip: 11 | - click 12 | - librosa 13 | - numba 14 | - pillow 15 | - pandas 16 | - scikit-learn 17 | - tensorflow==2.5.1 18 | - tensorboard==2.5 19 | - keras-applications 20 | - -e . 21 | -------------------------------------------------------------------------------- /examples/audio/devel_dog.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/devel_dog.wav -------------------------------------------------------------------------------- /examples/audio/devel_seagull.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/devel_seagull.wav -------------------------------------------------------------------------------- /examples/audio/test_dog.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/test_dog.wav -------------------------------------------------------------------------------- /examples/audio/test_seagull.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/test_seagull.wav -------------------------------------------------------------------------------- /examples/audio/train_dog.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/train_dog.wav -------------------------------------------------------------------------------- /examples/audio/train_seagull.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/examples/audio/train_seagull.wav -------------------------------------------------------------------------------- /examples/class_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "dog": "Dog", 3 | "seagull": "Seagull" 4 | } -------------------------------------------------------------------------------- /examples/hp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"], 3 | "model_name": ["TransferBaseModel"], 4 | "prediction_type": ["categorical"], 5 | "basemodel_name": ["densenet121"], 6 | "weights": ["imagenet"], 7 | "tb_experiment": ["densenet_exp"], 8 | "tb_run_id": ["densenet121_run"], 9 | "num_units": [4], 10 | "dropout": [0.25], 11 | "optimizer": ["adadelta"], 12 | "learning_rate": [0.001], 13 | "fine_learning_rate": [0.0001], 14 | "finetune_layer": [0.7], 15 | "loss": ["categorical_crossentropy"], 16 | "activation": ["arelu"], 17 | "pre_epochs": [1], 18 | "epochs": [1], 19 | "batch_size": [1], 20 | 21 | "sample_rate": [16000], 22 | 23 | "chunk_size": [2.0], 24 | "chunk_hop_size": [1.0], 25 | "normalize_audio": [true], 26 | 27 | "stft_window_size": [0.128], 28 | "stft_hop_size": [0.064], 29 | "stft_fft_length": [0.128], 30 | 31 | "mel_scale": [true], 32 | "lower_edge_hertz": [0.0], 33 | "upper_edge_hertz": [8000.0], 34 | "num_mel_bins": [128], 35 | "num_mfccs": [0], 36 | 
"cep_lifter": [0], 37 | "db_scale": [true], 38 | "use_plot_images": [true], 39 | "color_map": ["viridis"], 40 | "image_width": [224], 41 | "image_height": [224], 42 | "resize_method": ["nearest"], 43 | "anti_alias": [false], 44 | 45 | "sap_aug_a": [0.5], 46 | "sap_aug_s": [10], 47 | "augment_cutmix": [true], 48 | "augment_specaug": [true], 49 | "da_prob_min": [0.1], 50 | "da_prob_max": [0.5], 51 | "cutmix_min": [0.075], 52 | "cutmix_max": [0.25], 53 | "specaug_freq_min": [0.1], 54 | "specaug_freq_max": [0.3], 55 | "specaug_time_min": [0.1], 56 | "specaug_time_max": [0.3], 57 | "specaug_freq_mask_num": [1], 58 | "specaug_time_mask_num": [1] 59 | } -------------------------------------------------------------------------------- /examples/label.csv: -------------------------------------------------------------------------------- 1 | filename,label,duration_frames 2 | train_seagull.wav,seagull,117249 3 | train_dog.wav,dog,87865 4 | devel_seagull.wav,seagull,117249 5 | devel_dog.wav,dog,87865 6 | test_seagull.wav,seagull,117249 7 | test_dog.wav,dog,87865 -------------------------------------------------------------------------------- /notebooks/explain/environment.yml: -------------------------------------------------------------------------------- 1 | name: shap 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=1_gnu 9 | - alsa-lib=1.2.3=h516909a_0 10 | - aom=3.2.0=h9c3ff4c_2 11 | - asttokens=2.0.5=pyhd8ed1ab_0 12 | - backcall=0.2.0=pyh9f0ad1d_0 13 | - backports=1.0=py_2 14 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 15 | - bzip2=1.0.8=h7f98852_4 16 | - c-ares=1.18.1=h7f98852_0 17 | - ca-certificates=2021.10.8=ha878542_0 18 | - cairo=1.16.0=h6cf1ce9_1008 19 | - certifi=2021.10.8=py38h578d9bd_1 20 | - dbus=1.13.6=h5008d03_3 21 | - debugpy=1.5.1=py38h709712a_0 22 | - decorator=5.1.1=pyhd8ed1ab_0 23 | - entrypoints=0.4=pyhd8ed1ab_0 24 | - executing=0.8.2=pyhd8ed1ab_0 25 | - 
expat=2.4.4=h9c3ff4c_0 26 | - ffmpeg=4.4.1=h6987444_1 27 | - fontconfig=2.13.96=ha180cfb_0 28 | - freetype=2.10.4=h0708190_1 29 | - gettext=0.19.8.1=h73d1719_1008 30 | - gmp=6.2.1=h58526e2_0 31 | - gnutls=3.6.13=h85f3911_1 32 | - graphite2=1.3.13=h58526e2_1001 33 | - gst-plugins-base=1.18.5=hf529b03_3 34 | - gstreamer=1.18.5=h9f60fe5_3 35 | - harfbuzz=2.9.1=h83ec7ef_1 36 | - hdf5=1.12.1=nompi_h2750804_103 37 | - icu=68.2=h9c3ff4c_0 38 | - ipykernel=6.9.1=py38he5a9106_0 39 | - ipython=8.0.1=py38h578d9bd_1 40 | - jasper=1.900.1=h07fcdf6_1006 41 | - jbig=2.1=h7f98852_2003 42 | - jedi=0.18.1=py38h578d9bd_0 43 | - jpeg=9e=h7f98852_0 44 | - jupyter_client=7.1.2=pyhd8ed1ab_0 45 | - jupyter_core=4.9.2=py38h578d9bd_0 46 | - krb5=1.19.2=hcc1bbae_3 47 | - lame=3.100=h7f98852_1001 48 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 49 | - lerc=3.0=h9c3ff4c_0 50 | - libblas=3.9.0=13_linux64_openblas 51 | - libcblas=3.9.0=13_linux64_openblas 52 | - libclang=11.1.0=default_ha53f305_1 53 | - libcurl=7.81.0=h2574ce0_0 54 | - libdeflate=1.10=h7f98852_0 55 | - libdrm=2.4.109=h7f98852_0 56 | - libedit=3.1.20191231=he28a2e2_2 57 | - libev=4.33=h516909a_1 58 | - libevent=2.1.10=h9b69904_4 59 | - libffi=3.4.2=h7f98852_5 60 | - libgcc-ng=11.2.0=h1d223b6_12 61 | - libgfortran-ng=11.2.0=h69a702a_12 62 | - libgfortran5=11.2.0=h5c6108e_12 63 | - libglib=2.70.2=h174f98d_4 64 | - libgomp=11.2.0=h1d223b6_12 65 | - libiconv=1.16=h516909a_0 66 | - liblapack=3.9.0=13_linux64_openblas 67 | - liblapacke=3.9.0=13_linux64_openblas 68 | - libllvm11=11.1.0=hf817b99_3 69 | - libnghttp2=1.46.0=h812cca2_0 70 | - libnsl=2.0.0=h7f98852_0 71 | - libogg=1.3.4=h7f98852_1 72 | - libopenblas=0.3.18=pthreads_h8fe5266_0 73 | - libopencv=4.5.3=py38hafa78d9_3 74 | - libopus=1.3.1=h7f98852_1 75 | - libpciaccess=0.16=h516909a_0 76 | - libpng=1.6.37=h21135ba_2 77 | - libpq=13.5=hd57d9b9_1 78 | - libprotobuf=3.18.1=h780b84a_0 79 | - libsodium=1.0.18=h36c2ea0_1 80 | - libssh2=1.10.0=ha56f1ee_2 81 | - libstdcxx-ng=11.2.0=he4da1e4_12 
82 | - libtiff=4.3.0=h542a066_3 83 | - libuuid=2.32.1=h7f98852_1000 84 | - libva=2.14.0=h7f98852_0 85 | - libvorbis=1.3.7=h9c3ff4c_0 86 | - libvpx=1.11.0=h9c3ff4c_3 87 | - libwebp-base=1.2.2=h7f98852_1 88 | - libxcb=1.13=h7f98852_1004 89 | - libxkbcommon=1.0.3=he3ba5ed_0 90 | - libxml2=2.9.12=h72842e0_0 91 | - libzlib=1.2.11=h36c2ea0_1013 92 | - lz4-c=1.9.3=h9c3ff4c_1 93 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 94 | - mysql-common=8.0.28=ha770c72_0 95 | - mysql-libs=8.0.28=hfa10184_0 96 | - ncurses=6.3=h9c3ff4c_0 97 | - nest-asyncio=1.5.4=pyhd8ed1ab_0 98 | - nettle=3.6=he412f7d_0 99 | - nspr=4.32=h9c3ff4c_1 100 | - nss=3.74=hb5efdd6_0 101 | - opencv=4.5.3=py38h578d9bd_3 102 | - openh264=2.1.1=h780b84a_0 103 | - openssl=1.1.1l=h7f98852_0 104 | - parso=0.8.3=pyhd8ed1ab_0 105 | - pcre=8.45=h9c3ff4c_0 106 | - pexpect=4.8.0=pyh9f0ad1d_2 107 | - pickleshare=0.7.5=py_1003 108 | - pip=22.0.3=pyhd8ed1ab_0 109 | - pixman=0.40.0=h36c2ea0_0 110 | - prompt-toolkit=3.0.27=pyha770c72_0 111 | - pthread-stubs=0.4=h36c2ea0_1001 112 | - ptyprocess=0.7.0=pyhd3deb0d_0 113 | - pure_eval=0.2.2=pyhd8ed1ab_0 114 | - py-opencv=4.5.3=py38he5a9106_3 115 | - pygments=2.11.2=pyhd8ed1ab_0 116 | - python=3.8.12=ha38a3c6_3_cpython 117 | - python-dateutil=2.8.2=pyhd8ed1ab_0 118 | - python_abi=3.8=2_cp38 119 | - pyzmq=22.3.0=py38h2035c66_1 120 | - qt=5.12.9=hda022c4_4 121 | - readline=8.1=h46c0cb4_0 122 | - setuptools=60.9.2=py38h578d9bd_0 123 | - sqlite=3.37.0=h9cd32fc_0 124 | - stack_data=0.2.0=pyhd8ed1ab_0 125 | - svt-av1=0.9.0=h9c3ff4c_0 126 | - tk=8.6.12=h27826a3_0 127 | - tornado=6.1=py38h497a2fe_2 128 | - traitlets=5.1.1=pyhd8ed1ab_0 129 | - wcwidth=0.2.5=pyh9f0ad1d_2 130 | - wheel=0.37.1=pyhd8ed1ab_0 131 | - x264=1!161.3030=h7f98852_1 132 | - x265=3.5=h4bd325d_1 133 | - xorg-fixesproto=5.0=h7f98852_1002 134 | - xorg-kbproto=1.0.7=h7f98852_1002 135 | - xorg-libice=1.0.10=h7f98852_0 136 | - xorg-libsm=1.2.3=hd9c2040_1000 137 | - xorg-libx11=1.7.2=h7f98852_0 138 | - xorg-libxau=1.0.9=h7f98852_0 
139 | - xorg-libxdmcp=1.1.3=h7f98852_0 140 | - xorg-libxext=1.3.4=h7f98852_1 141 | - xorg-libxfixes=5.0.3=h7f98852_1004 142 | - xorg-libxrender=0.9.10=h7f98852_1003 143 | - xorg-renderproto=0.11.1=h7f98852_1002 144 | - xorg-xextproto=7.3.0=h7f98852_1002 145 | - xorg-xproto=7.0.31=h7f98852_1007 146 | - xz=5.2.5=h516909a_1 147 | - zeromq=4.3.4=h9c3ff4c_1 148 | - zlib=1.2.11=h36c2ea0_1013 149 | - zstd=1.5.2=ha95c52a_0 150 | - pip: 151 | - absl-py==0.15.0 152 | - appdirs==1.4.4 153 | - astunparse==1.6.3 154 | - audioread==2.1.9 155 | - cachetools==4.2.4 156 | - cffi==1.15.0 157 | - charset-normalizer==2.0.11 158 | - click==8.0.3 159 | - cloudpickle==2.0.0 160 | - cycler==0.11.0 161 | - deepspectrumlite==1.0.2 162 | - flatbuffers==1.12 163 | - fonttools==4.29.1 164 | - gast==0.4.0 165 | - google-auth==1.35.0 166 | - google-auth-oauthlib==0.4.6 167 | - google-pasta==0.2.0 168 | - grpcio==1.34.1 169 | - h5py==3.1.0 170 | - idna==3.3 171 | - importlib-metadata==4.10.1 172 | - joblib==1.1.0 173 | - keras-applications==1.0.8 174 | - keras-nightly==2.5.0.dev2021032900 175 | - keras-preprocessing==1.1.2 176 | - kiwisolver==1.3.2 177 | - librosa==0.9.0 178 | - llvmlite==0.38.0 179 | - markdown==3.3.6 180 | - matplotlib==3.5.1 181 | - numba==0.55.1 182 | - numpy==1.19.5 183 | - oauthlib==3.2.0 184 | - opt-einsum==3.3.0 185 | - p2j==1.3.2 186 | - packaging==21.3 187 | - pandas==1.4.0 188 | - pillow==9.0.1 189 | - pooch==1.6.0 190 | - protobuf==3.19.4 191 | - pyasn1==0.4.8 192 | - pyasn1-modules==0.2.8 193 | - pycparser==2.21 194 | - pyparsing==3.0.7 195 | - pytz==2021.3 196 | - requests==2.27.1 197 | - requests-oauthlib==1.3.1 198 | - resampy==0.2.2 199 | - rsa==4.8 200 | - scikit-learn==1.0.2 201 | - scipy==1.8.0 202 | - shap==0.40.0 203 | - six==1.15.0 204 | - slicer==0.0.7 205 | - soundfile==0.10.3.post1 206 | - tensorboard==2.5.0 207 | - tensorboard-data-server==0.6.1 208 | - tensorboard-plugin-wit==1.8.1 209 | - tensorflow==2.5.1 210 | - tensorflow-estimator==2.5.0 211 | - 
termcolor==1.1.0 212 | - threadpoolctl==3.1.0 213 | - tqdm==4.62.3 214 | - typing-extensions==3.7.4.3 215 | - urllib3==1.26.8 216 | - werkzeug==2.0.3 217 | - wrapt==1.12.1 218 | - zipp==3.7.0 -------------------------------------------------------------------------------- /notebooks/explain/model/ccs/model.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/model/ccs/model.h5 -------------------------------------------------------------------------------- /notebooks/explain/shap/ccs/100/shap_ccs_100.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/shap/ccs/100/shap_ccs_100.pdf -------------------------------------------------------------------------------- /notebooks/explain/spectrograms/ccs/Negative/test_010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/spectrograms/ccs/Negative/test_010.png -------------------------------------------------------------------------------- /notebooks/explain/spectrograms/ccs/Positive/test_039.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/notebooks/explain/spectrograms/ccs/Positive/test_039.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 
#!/usr/bin/env python
"""Packaging script for DeepSpectrumLite.

Reads the long description from ``DESCRIPTION.md`` and registers the
``deepspectrumlite`` console script that dispatches to the click CLI.
"""
import re
import sys
import warnings

warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

from setuptools import setup, find_packages
from subprocess import CalledProcessError, check_output

PROJECT = "DeepSpectrumLite"
VERSION = "1.0.2"
LICENSE = "GPLv3+"
AUTHOR = "Tobias Hübner"
AUTHOR_EMAIL = "tobias.huebner@informatik.uni-augsburg.de"
URL = 'https://github.com/DeepSpectrum/DeepSpectrumLite'

# DESCRIPTION.md contains non-ASCII characters (author names), so the
# encoding is pinned instead of relying on the platform default.
with open("DESCRIPTION.md", "r", encoding="utf-8") as fh:
    LONG_DESCRIPTION = fh.read()

install_requires = [
    "librosa",
    "numba",
    "pillow",
    "pandas",
    "scikit-learn",
    "click",
    "tensorflow==2.5.1",
    "tensorboard==2.5",
    "keras-applications"
]

tests_require = ['pytest>=4.4.1', 'pytest-cov>=2.7.1']
# pytest-runner is only needed when a test command is actually requested.
needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv)
setup_requires = ['pytest-runner'] if needs_pytest else []
packages = find_packages('src')

setup(
    name=PROJECT,
    version=VERSION,
    license=LICENSE,
    author=AUTHOR,
    author_email=AUTHOR_EMAIL,
    long_description=LONG_DESCRIPTION,
    long_description_content_type="text/markdown",
    # The keyword must be spelled ``description``; a misspelled keyword is
    # reported as an unknown option and the summary is left empty.
    description="DeepSpectrumLite is a Python toolkit for training light-weight CNN networks targeted at embedded devices.",
    platforms=["Any"],
    scripts=[],
    provides=[],
    python_requires="~=3.8.0",
    install_requires=install_requires,
    setup_requires=setup_requires,
    tests_require=tests_require,
    namespace_packages=[],
    packages=packages,
    package_dir={'': 'src'},
    include_package_data=True,
    entry_points={
        "console_scripts": [
            "deepspectrumlite = deepspectrumlite.__main__:cli",
        ]
    },
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 4 - Beta',

        'Environment :: GPU :: NVIDIA CUDA :: 11.0',
        # Indicate who your project is intended for
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Intended Audience :: Science/Research',

        # Pick your license as you wish (should match "license" above)
        'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',

        'Programming Language :: Python :: 3.8',
    ],
    keywords='machine-learning audio-analysis science research',
    project_urls={
        'Source': 'https://github.com/DeepSpectrum/DeepSpectrumLite',
        'Tracker': 'https://github.com/DeepSpectrum/DeepSpectrumLite/issues',
    },
    url=URL,
    zip_safe=False,
)
# Options shared by the top-level command group; `-v` may be repeated.
_global_options = [
    click.option('-v', '--verbose', count=True),
]

# Template for `--version`; click substitutes %(version)s itself, so this is a
# plain string rather than an f-string.
# NOTE(review): the text after "version 3 or later" looks as if a license URL
# may be missing — confirm against the upstream file.
version_str = (
    "DeepSpectrumLite %(version)s\nCopyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk, Sandra Ottl, "
    "Björn Schuller\n"
    "License GPLv3+: GNU GPL version 3 or later .\n"
    "This is free software: you are free to change and redistribute it.\n"
    "There is NO WARRANTY, to the extent permitted by law."
)


# Root command group of the DeepSpectrumLite command line interface.
#
# Stores the clamped verbosity in ctx.obj['verbose'] and configures the root
# logger to write to stdout at ERROR (default), INFO (-v) or DEBUG (-vv).
#
# Documented with comments rather than a docstring on purpose: click uses the
# callback's docstring as the group's help text when no `help=` is given, so
# adding one would change the `--help` output.
@click.group()
@add_options(_global_options)
@click.version_option(VERSION, message=version_str)
@click.pass_context
def cli(ctx, verbose):
    log_levels = ['ERROR', 'INFO', 'DEBUG']
    # `-v` is a counter; anything beyond -vv is treated as DEBUG.
    verbose = min(2, verbose)
    level_name = log_levels[verbose]

    ctx.ensure_object(dict)
    ctx.obj['verbose'] = verbose

    # dictConfig replaces the root logger's handlers, so no preceding
    # logging.basicConfig() call is needed.
    logging.config.dictConfig({
        'version': 1,
        # Keep loggers that were created at import time enabled.
        'disable_existing_loggers': False,
        'formatters': {
            'standard': {
                'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s'
            },
        },
        'handlers': {
            'default': {
                'level': level_name,
                'class': 'logging.StreamHandler',
                'formatter': 'standard',
                'stream': 'ext://sys.stdout'
            },
        },
        'loggers': {
            # The empty name addresses the root logger.
            '': {
                'handlers': ['default'],
                'level': level_name,
                'propagate': True
            }
        }
    })

    logging.debug('Verbosity: %s', level_name)

    # Silence native (glog / TensorFlow C++) log output.
    # NOTE(review): the cli submodules are imported at module import time, so
    # TensorFlow may already be loaded when these are set — confirm they
    # still take effect.
    os.environ['GLOG_minloglevel'] = '2'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


cli.add_command(train)
cli.add_command(devel_test)
cli.add_command(stats)
cli.add_command(convert)
cli.add_command(tflite_stats)
cli.add_command(create_preprocessor)
cli.add_command(predict)

if __name__ == '__main__':
    cli(obj={})
https://raw.githubusercontent.com/DeepSpectrum/DeepSpectrumLite/a09e668a2d8880c37ccb2a797f235c5f8c46acfd/src/deepspectrumlite/cli/__init__.py -------------------------------------------------------------------------------- /src/deepspectrumlite/cli/config/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !class_config.json 4 | !hp_config.json -------------------------------------------------------------------------------- /src/deepspectrumlite/cli/config/class_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "negative": "Negative", 3 | "positive": "Positive" 4 | } -------------------------------------------------------------------------------- /src/deepspectrumlite/cli/config/hp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"], 3 | "model_name": ["TransferBaseModel"], 4 | "prediction_type": ["categorical"], 5 | "basemodel_name": ["densenet121"], 6 | "weights": ["imagenet"], 7 | "tb_experiment": ["densenet_exp"], 8 | "tb_run_id": ["densenet121_run"], 9 | "num_units": [512], 10 | "dropout": [0.25], 11 | "optimizer": ["adadelta"], 12 | "learning_rate": [0.001], 13 | "fine_learning_rate": [0.0001], 14 | "finetune_layer": [0.7], 15 | "loss": ["categorical_crossentropy"], 16 | "activation": ["arelu"], 17 | "output_activation": ["softmax"], 18 | "pre_epochs": [40], 19 | "epochs": [100], 20 | "batch_size": [32], 21 | 22 | "sample_rate": [16000], 23 | 24 | "chunk_size": [4.0], 25 | "chunk_hop_size": [2.0], 26 | "normalize_audio": [false], 27 | 28 | 29 | "stft_window_size": [0.128], 30 | "stft_hop_size": [0.064], 31 | "stft_fft_length": [0.128], 32 | 33 | "mel_scale": [true], 34 | "lower_edge_hertz": [0.0], 35 | "upper_edge_hertz": [8000.0], 36 | "num_mel_bins": [128], 37 | "num_mfccs": [0], 38 | "cep_lifter": [0], 39 | 
"db_scale": [true], 40 | "use_plot_images": [true], 41 | "color_map": ["viridis"], 42 | "image_width": [224], 43 | "image_height": [224], 44 | "resize_method": ["nearest"], 45 | "anti_alias": [false], 46 | 47 | "sap_aug_a": [0.5], 48 | "sap_aug_s": [10], 49 | "augment_cutmix": [true], 50 | "augment_specaug": [true], 51 | "da_prob_min": [0.1], 52 | "da_prob_max": [0.5], 53 | "cutmix_min": [0.075], 54 | "cutmix_max": [0.25], 55 | "specaug_freq_min": [0.1], 56 | "specaug_freq_max": [0.3], 57 | "specaug_time_min": [0.1], 58 | "specaug_time_max": [0.3], 59 | "specaug_freq_mask_num": [1], 60 | "specaug_time_mask_num": [1] 61 | } 62 | -------------------------------------------------------------------------------- /src/deepspectrumlite/cli/convert.py: -------------------------------------------------------------------------------- 1 | # DeepSpectrumLite 2 | # ============================================================================== 3 | # Copyright (C) 2020-2021 Shahin Amiriparian, Tobias Hübner, Maurice Gerczuk, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 
import logging
import click
from .utils import add_options
import os
import sys
import tensorflow as tf
from tensorflow import keras
import math
from tensorflow.keras import backend as K
from tensorflow.python.saved_model import loader_impl
from deepspectrumlite import AugmentableModel, ARelu
import numpy as np
from os.path import join, dirname, realpath

log = logging.getLogger(__name__)

_DESCRIPTION = 'Converts a DeepSpectrumLite model to a TFLite model file.'


@add_options(
    [
        click.option(
            "-s",
            "--source",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Source HD5 model file",
            required=True
        ),
        click.option(
            "-d",
            "--destination",
            type=click.Path(exists=False, writable=True, readable=True),
            help="Destination TFLite model file",
            required=True
        )
    ]
)
@click.command(help=_DESCRIPTION)
def convert(source, destination, **kwargs):
    """Convert a Keras model file into a TFLite flatbuffer and smoke-test it.

    The model at ``source`` is loaded with the DeepSpectrumLite custom objects,
    converted with TFLite builtin ops plus select TensorFlow ops, written to
    ``destination`` and finally run once on random input.

    Args:
        source: Path of the Keras model file to load.
        destination: Path the TFLite model is written to.
        **kwargs: Additional CLI options; not used by this command.
    """
    log.info("Load model: %s", source)

    new_model = tf.keras.models.load_model(
        source,
        custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
        compile=False)
    log.info("Successfully loaded model: %s", source)

    converter = tf.lite.TFLiteConverter.from_keras_model(new_model)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                           tf.lite.OpsSet.SELECT_TF_OPS]
    converter.experimental_new_converter = True
    tflite_quant_model = converter.convert()

    # Context manager so the handle is flushed and closed even if the write fails.
    with open(destination, "wb") as tflite_file:
        tflite_file.write(tflite_quant_model)

    log.info("Model was saved as tflite as %s", destination)

    # Smoke test: run the converted model once on random input data.
    interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    interpreter.allocate_tensors()

    input_shape = input_details[0]['shape']
    log.info("input shape: %s", input_shape)
    log.info("output shape: %s", output_details[0]['shape'])

    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    log.info(output_data)
import logging
import click
from .utils import add_options
import os
import sys
import tensorflow as tf
from tensorflow import keras
import math
from tensorflow.keras import backend as K
from tensorflow.python.saved_model import loader_impl
from deepspectrumlite import HyperParameterList, PreprocessAudio
import time
import numpy as np
from os.path import join, dirname, realpath

log = logging.getLogger(__name__)

_DESCRIPTION = 'Creates a DeepSpectrumLite preprocessor TFLite file.'


@add_options(
    [
        click.option(
            "-hc",
            "--hyper-config",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the hyper parameter config file.",
            default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
        ),
        click.option(
            "-d",
            "--destination",
            type=click.Path(exists=False, writable=True, readable=True),
            help="Destination of the TFLite preprocessor file",
            required=True
        )
    ]
)
@click.command(help=_DESCRIPTION)
def create_preprocessor(hyper_config, destination, **kwargs):
    """Export the audio preprocessor as a TFLite file and time one inference.

    A ``PreprocessAudio`` module is built from the first hyper-parameter
    combination (``iteration_no=0``), saved as a SavedModel in a
    ``preprocessor`` directory next to ``destination``, converted to TFLite and
    written to ``destination``. The written file is then loaded again and run
    on random input.

    Args:
        hyper_config: Path of the hyper-parameter config file.
        destination: Path the TFLite preprocessor is written to.
        **kwargs: Additional CLI options; not used by this command.
    """
    hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)
    hparam_values = hyper_parameter_list.get_values(iteration_no=0)
    working_directory = dirname(destination)

    preprocess = PreprocessAudio(hparams=hparam_values, name="dsl_audio_preprocessor")
    dummy_audio = tf.convert_to_tensor(
        np.array(np.random.random_sample((1, 16000)), dtype=np.float32), dtype=tf.float32)
    # NOTE(review): presumably this call traces the preprocessing function so
    # that it can be exported below -- confirm before removing it.
    preprocess.preprocess(dummy_audio)

    # ATTENTION: antialias is not supported in tflite
    tmp_save_path = os.path.join(working_directory, "preprocessor")
    os.makedirs(tmp_save_path, exist_ok=True)
    tf.saved_model.save(preprocess, tmp_save_path)

    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=tmp_save_path)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                           tf.lite.OpsSet.SELECT_TF_OPS]
    converter.experimental_new_converter = True
    tflite_quant_model = converter.convert()

    # Context manager so the file is closed before the interpreter reads it back.
    with open(destination, "wb") as tflite_file:
        tflite_file.write(tflite_quant_model)

    interpreter = tf.lite.Interpreter(model_path=destination)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    log.info(input_details)
    log.info(output_details)

    interpreter.allocate_tensors()

    # First, untimed run on a (1, 16000) random signal.
    interpreter.set_tensor(
        input_details[0]['index'],
        tf.convert_to_tensor(np.array(np.random.random_sample((1, 16000)), dtype=np.float32),
                             dtype=tf.float32))
    interpreter.invoke()

    # Test model on random input data.
    input_shape = input_details[0]['shape']
    log.info("input shape:")
    log.info(input_shape)
    log.info("output shape:")
    log.info(output_details[0]['shape'])
    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # perf_counter is monotonic and high resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()
    interpreter.invoke()
    stop_time = time.perf_counter()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    log.info(output_data)
    log.info('time: {:.3f}ms'.format((stop_time - start_time) * 1000))
    log.info("Finished creating the TFLite preprocessor")
Hübner, Maurice Gerczuk, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import logging 20 | import click 21 | from .utils import add_options 22 | import os 23 | import sys 24 | import tensorflow as tf 25 | from tensorflow import keras 26 | import math 27 | from tensorflow.keras import backend as K 28 | from tensorflow.python.saved_model import loader_impl 29 | from deepspectrumlite import AugmentableModel, DataPipeline, HyperParameterList, ARelu 30 | import glob 31 | from pathlib import Path 32 | import numpy as np 33 | import importlib 34 | import json 35 | from sklearn.metrics import recall_score, classification_report, confusion_matrix 36 | import csv 37 | import json 38 | import pandas as pd 39 | from collections import Counter 40 | from os.path import join, dirname, realpath 41 | 42 | log = logging.getLogger(__name__) 43 | 44 | _DESCRIPTION = 'Test a DeepSpectrumLite transer learning model.' 
def _load_parser_class(label_parser_key):
    """Import the label parser class named by a ``path/to/file.py:ClassName`` key.

    The file path is joined onto the directory that contains this module.

    Args:
        label_parser_key: Parser specification taken from the hyper parameters.

    Returns:
        The class object found under ``ClassName`` in the loaded module.

    Raises:
        ValueError: If the key contains no ``:`` separator.
    """
    # Function-scope import: a bare ``import importlib`` does not guarantee
    # that the ``importlib.util`` submodule has been loaded.
    import importlib.util

    if ":" not in label_parser_key:
        raise ValueError('Please provide the parser in the following format: path.to.parser_file.py:ParserClass')

    log.info(f'Using custom external parser: {label_parser_key}')
    # Split on the last colon only, so a path that itself contains a colon
    # (e.g. a drive letter) is still accepted.
    path, class_name = label_parser_key.rsplit(':', 1)
    module_name = os.path.splitext(os.path.basename(path))[0]
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), path)
    spec = importlib.util.spec_from_file_location(module_name, path)
    parser_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(parser_module)
    return getattr(parser_module, class_name)


def _chunks_by_file(chunk_filenames, values, value_name):
    """Pair each chunk's file basename with its value in a two-column frame."""
    frame = pd.DataFrame(data=chunk_filenames, columns=["filename"]).join(
        pd.DataFrame(data=values, columns=[value_name]), how='outer')
    frame['filename'] = frame['filename'].apply(os.path.basename)
    return frame


def _majority_vote(frame):
    """Collapse chunk rows to one row per file holding its most common value."""
    return frame.groupby('filename', as_index=False).agg(lambda x: Counter(x).most_common(1)[0][0])


@add_options(
    [
        click.option(
            "-d",
            "--data-dir",
            type=click.Path(exists=True),
            help="Directory of data class categories containing folders of each data class.",
            required=True
        ),
        click.option(
            "-md",
            "--model-dir",
            type=click.Path(exists=False, writable=True),
            help="Path to HD5 model file",
            required=True
        ),
        click.option(
            "-hc",
            "--hyper-config",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the hyper parameter config file.",
            default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
        ),
        click.option(
            "-cc",
            "--class-config",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the class config file.",
            default=join(dirname(realpath(__file__)), "config/class_config.json"), show_default=True
        ),
        click.option(
            "-l",
            "--label-file",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the label file.",
            required=True
        )
    ]
)
@click.command(help=_DESCRIPTION)
@click.pass_context
def devel_test(ctx, model_dir, data_dir, class_config, hyper_config, label_file, **kwargs):
    """Evaluate every model matching ``model_dir`` on the devel and test partitions.

    ``model_dir`` is treated as a glob pattern. For each matching model file
    the trial number is read from the suffix after the last ``_`` of the
    parent directory name. Chunk-level and per-file (majority vote) confusion
    matrices, classification reports and UAR are logged, and metrics and
    predictions are written below ``<model parent>/evaluation/<partition>/``.

    Args:
        ctx: Click context; ``ctx.obj['verbose']`` is passed to ``model.predict``.
        model_dir: Glob pattern selecting the model files.
        data_dir: Directory that is prepended to the file names of the labels.
        class_config: Path of the JSON file defining the data classes.
        hyper_config: Path of the hyper-parameter config file.
        label_file: Path of the label file handed to the label parser.
        **kwargs: Additional CLI options; not used by this command.

    Raises:
        ValueError: If the class config is empty (``null``) or the label
            parser key is malformed.
    """
    verbose = ctx.obj['verbose']
    with open(class_config) as f:
        data_classes = json.load(f)

    if data_classes is None:
        raise ValueError('no data classes defined')

    # Guarantee a trailing separator because the pipeline prepends this string
    # to the file names.
    data_dir = os.path.join(data_dir, '')
    target_names = list(data_classes)

    hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)

    log.info("Search by rule: " + model_dir)
    model_dir_list = glob.glob(model_dir)
    log.info("Found " + str(len(model_dir_list)) + " files")

    for model_filename in model_dir_list:
        log.info("Load " + model_filename)
        parent = Path(model_filename).parent
        result_dir = os.path.join(parent, "evaluation")
        iteration_no = int(parent.name.split("_")[-1])

        log.info('--- Testing trial: %s' % iteration_no)
        hparam_values = hyper_parameter_list.get_values(iteration_no=iteration_no)
        log.info(hparam_values)

        parser_class = _load_parser_class(hparam_values['label_parser'])
        parser = parser_class(file_path=label_file)
        _, devel_data, test_data = parser.parse_labels()
        log.info("Successfully parsed labels: " + label_file)

        model = tf.keras.models.load_model(
            model_filename,
            custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
            compile=False)
        model.set_hyper_parameters(hparam_values)
        log.info("Successfully loaded model: " + model_filename)

        partitions = {"devel": devel_data, "test": test_data}
        for dataset_name, data_raw in partitions.items():
            log.info("===== Dataset Partition: " + dataset_name)

            dataset_result_dir = os.path.join(result_dir, dataset_name)
            os.makedirs(dataset_result_dir, exist_ok=True)

            data_pipeline = DataPipeline(name=dataset_name + '_data_set', data_classes=data_classes,
                                         enable_gpu=True, verbose=True, enable_augmentation=False,
                                         hparams=hparam_values, run_id=iteration_no)
            data_pipeline.set_data(data_raw)
            data_pipeline.set_filename_prepend(prepend_filename_str=data_dir)
            data_pipeline.preprocess()
            filename_list = data_pipeline.filenames
            dataset = data_pipeline.pipeline(cache=False, shuffle=False, drop_remainder=False)

            X_pred = model.predict(x=dataset, verbose=verbose)
            true_categories = tf.concat([y for x, y in dataset], axis=0)

            # Reduce class scores / label vectors to class indices.
            X_pred = tf.argmax(X_pred, axis=1)
            X_pred_ny = X_pred.numpy()
            true_categories = tf.argmax(true_categories, axis=1)
            true_np = true_categories.numpy()

            # ---- chunk-level metrics ----
            cm = tf.math.confusion_matrix(true_categories, X_pred)
            log.info("Confusion Matrix (chunks):")
            log.info(cm.numpy())

            log.info(classification_report(y_true=true_np, y_pred=X_pred_ny,
                                           target_names=target_names,
                                           digits=4))

            recall = recall_score(y_true=true_np, y_pred=X_pred_ny, average='macro')
            log.info("UAR: " + str(recall * 100))

            json_cm_dir = os.path.join(dataset_result_dir, dataset_name + ".chunks.metrics.json")
            with open(json_cm_dir, 'w') as f:
                json.dump({"cm": cm.numpy().tolist(), "uar": round(recall * 100, 4)}, f)

            # Column 0 of the pipeline's filename table holds the chunk's file path.
            chunk_files = filename_list[..., 0]
            df = _chunks_by_file(chunk_files, X_pred_ny, "prediction")
            df.to_csv(os.path.join(dataset_result_dir, dataset_name + ".chunks.predictions.csv"), index=False)

            # ---- per-file metrics (majority vote over the chunks of a file) ----
            grouped_data = _majority_vote(df)
            grouped_data.to_csv(os.path.join(dataset_result_dir, dataset_name + ".grouped.predictions.csv"),
                                index=False)
            grouped_X_pred = grouped_data.values[..., 1].tolist()

            data_raw_labels = _majority_vote(_chunks_by_file(chunk_files, true_np, "label"))
            grouped_true = data_raw_labels.values[..., 1].tolist()

            cm = confusion_matrix(grouped_true, grouped_X_pred)
            log.info("Confusion Matrix (grouped):")
            log.info(cm)

            log.info(classification_report(y_true=grouped_true, y_pred=grouped_X_pred,
                                           target_names=target_names,
                                           digits=4))

            recall = recall_score(y_true=grouped_true, y_pred=grouped_X_pred, average='macro')
            log.info("UAR: " + str(recall * 100))

            json_cm_dir = os.path.join(dataset_result_dir, dataset_name + ".grouped.metrics.json")
            with open(json_cm_dir, 'w') as f:
                json.dump({"cm": cm.tolist(), "uar": round(recall * 100, 4)}, f)
5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import logging 20 | import click 21 | from .utils import add_options 22 | import os 23 | import sys 24 | import tensorflow as tf 25 | from tensorflow import keras 26 | import math 27 | from tensorflow.keras import backend as K 28 | from tensorflow.python.saved_model import loader_impl 29 | from deepspectrumlite import AugmentableModel, DataPipeline, HyperParameterList, ARelu 30 | import glob 31 | from pathlib import Path 32 | import numpy as np 33 | import importlib 34 | import json 35 | from sklearn.metrics import recall_score 36 | from sklearn.metrics import classification_report, confusion_matrix 37 | import csv 38 | import json 39 | import pandas as pd 40 | import librosa 41 | from collections import Counter 42 | from os.path import join, dirname, realpath 43 | 44 | log = logging.getLogger(__name__) 45 | 46 | _DESCRIPTION = 'Predict a file using an existing DeepSpectrumLite transer learning model.' 
@add_options(
    [
        click.option(
            "-d",
            "--data-dir",
            type=click.Path(exists=True),
            help="Directory of data class categories containing folders of each data class.",
            required=True
        ),
        click.option(
            "-md",
            "--model-dir",
            type=click.Path(exists=True, writable=False),
            help="HD5 file of the DeepSpectrumLite model",
            required=True
        ),
        click.option(
            "-hc",
            "--hyper-config",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the hyper parameter config file.",
            default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
        ),
        click.option(
            "-cc",
            "--class-config",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the class config file.",
            default=join(dirname(realpath(__file__)), "config/class_config.json"), show_default=True
        )
    ]
)
@click.command(help=_DESCRIPTION)
@click.pass_context
def predict(ctx, model_dir, data_dir, class_config, hyper_config, **kwargs):
    """Predict all wav files below ``data_dir`` with every model matching ``model_dir``.

    Wav files are collected recursively. For each model the chunk-level class
    scores and the predicted class are written to
    ``<model parent>/test/test/test.chunks.predictions.csv``.

    Args:
        ctx: Click context; ``ctx.obj['verbose']`` is passed to ``model.predict``.
        model_dir: Glob pattern selecting the model files.
        data_dir: Directory that is searched recursively for ``*.wav`` files.
        class_config: Path of the JSON file defining the data classes.
        hyper_config: Path of the hyper-parameter config file.
        **kwargs: Additional CLI options; not used by this command.

    Raises:
        ValueError: If the class config is empty (``null``).
    """
    verbose = ctx.obj['verbose']
    with open(class_config) as f:
        data_classes = json.load(f)

    # Validate before the classes are used to build the placeholder labels.
    if data_classes is None:
        raise ValueError('no data classes defined')

    data_dir = os.path.join(data_dir, '')

    wav_files = sorted(glob.glob(f'{data_dir}/**/*.wav', recursive=True))
    filenames = [os.path.relpath(wav_file, start=data_dir) for wav_file in wav_files]
    # No ground truth exists at prediction time: every file gets the first
    # class as a placeholder label.
    labels = [list(data_classes.keys())[0]] * len(wav_files)

    duration_frames = []
    sr = None
    for fn in filenames:
        y, sr = librosa.load(os.path.join(data_dir, fn), sr=None)
        duration_frames.append(y.shape[0])

    log.info('Found %d wav files' % len(filenames))

    target_names = list(data_classes)

    hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)
    log.info("Search within rule: " + model_dir)
    model_dir_list = glob.glob(model_dir)
    log.info("Found " + str(len(model_dir_list)) + " files")

    for model_filename in model_dir_list:
        log.info("Load " + model_filename)
        parent = Path(model_filename).parent
        result_dir = os.path.join(parent, "test")
        iteration_no = int(parent.name.split("_")[-1])

        log.info('--- Testing trial: %s' % iteration_no)
        hparam_values = hyper_parameter_list.get_values(iteration_no=iteration_no)
        log.info(hparam_values)

        test_data = pd.DataFrame({'filename': filenames, 'label': labels, 'duration_frames': duration_frames})

        log.info("Loading model: " + model_filename)
        model = tf.keras.models.load_model(
            model_filename,
            custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
            compile=False)
        model.set_hyper_parameters(hparam_values)
        log.info("Successfully loaded model: " + model_filename)

        dataset_name = 'test'
        dataset_result_dir = os.path.join(result_dir, dataset_name)
        os.makedirs(dataset_result_dir, exist_ok=True)

        data_pipeline = DataPipeline(name=dataset_name + '_data_set', data_classes=data_classes,
                                     enable_gpu=True, verbose=True, enable_augmentation=False,
                                     hparams=hparam_values, run_id=iteration_no)
        data_pipeline.set_data(test_data)
        data_pipeline.set_filename_prepend(prepend_filename_str=data_dir)
        data_pipeline.preprocess()
        filename_list = data_pipeline.filenames
        dataset = data_pipeline.pipeline(cache=False, shuffle=False, drop_remainder=False)

        X_probs = model.predict(x=dataset, verbose=verbose)
        X_pred_ny = tf.argmax(X_probs, axis=1).numpy()

        df = pd.DataFrame(data=filename_list[..., 0], columns=["filename"])
        df['filename'] = df['filename'].apply(os.path.basename)
        # NOTE(review): `sr` is the sample rate of the last wav file that was
        # loaded above; this assumes column 1 holds a frame offset expressed
        # at that rate for every file -- confirm against DataPipeline.
        df['time'] = [int(offset) / sr for offset in filename_list[..., 1]]
        for i, target in enumerate(target_names):
            df[f'prob_{target}'] = X_probs[:, i]
        df['prediction'] = [target_names[class_index] for class_index in X_pred_ny]

        df.to_csv(os.path.join(dataset_result_dir, dataset_name + ".chunks.predictions.csv"), index=False)

    log.info("Finished testing")
import logging
import click
from .utils import add_options
import os
from os.path import join, dirname, realpath
import tensorflow as tf
import numpy as np
from deepspectrumlite import AugmentableModel, ARelu, HyperParameterList

# Bound before the functions below, which use `log` and `dirname`.
log = logging.getLogger(__name__)

_DESCRIPTION = 'Retrieve statistics of a DeepSpectrumLite transer learning model.'


def get_detailed_stats(model_h5_path):
    """Profile one forward pass of a Keras model and store the report.

    The model is loaded into a TF1-style session, its summary is logged, and
    a single traced run on random normal input is profiled twice: once per
    code location (with a timeline written to ``test.json``) and once per
    operation type. The string form of the latter is written to
    ``profiler.json`` in the directory of ``model_h5_path``.

    Args:
        model_h5_path: Path of the Keras model file.
    """
    session = tf.compat.v1.Session()
    graph = tf.compat.v1.get_default_graph()

    parent_dir = dirname(model_h5_path)

    with graph.as_default():
        with session.as_default():
            new_model = tf.keras.models.load_model(
                model_h5_path,
                custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
                compile=False)
            new_model.summary(print_fn=log.info)
            run_meta = tf.compat.v1.RunMetadata()
            input_details = new_model.get_config()
            input_shape = input_details['layers'][0]['config']['batch_input_shape']

            # NOTE(review): relies on the input tensor being named
            # 'input_1:0' -- confirm for models built in other sessions.
            _ = session.run(new_model.output, {
                'input_1:0': np.random.normal(size=(1, input_shape[1], input_shape[2], input_shape[3]))},
                            options=tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE),
                            run_metadata=run_meta)

            ProfileOptionBuilder = tf.compat.v1.profiler.ProfileOptionBuilder
            opts = ProfileOptionBuilder(ProfileOptionBuilder.time_and_memory()
                                        ).with_step(1).with_timeline_output('test.json').build()

            tf.compat.v1.profiler.profile(
                tf.compat.v1.get_default_graph(),
                run_meta=run_meta,
                cmd='code',
                options=opts)

            # Print to stdout an analysis of the memory usage and the timing information
            # broken down by operation types.
            json_export = tf.compat.v1.profiler.profile(
                tf.compat.v1.get_default_graph(),
                run_meta=run_meta,
                cmd='op',
                options=tf.compat.v1.profiler.ProfileOptionBuilder.time_and_memory())

    # Context manager so the report is flushed and closed even if writing fails.
    with open(os.path.join(parent_dir, "profiler.json"), "w") as text_file:
        text_file.write(str(json_export))

    tf.compat.v1.reset_default_graph()


# deprecated
def get_flops(model_h5_path):  # pragma: no cover
    """Return the total float operation count the TF1 profiler reports.

    Args:
        model_h5_path: Path of the Keras model file.

    Returns:
        ``total_float_ops`` of the profiler result for the default graph.
    """
    tf.compat.v1.enable_eager_execution()
    session = tf.compat.v1.Session()
    graph = tf.compat.v1.get_default_graph()

    with graph.as_default():
        with session.as_default():
            tf.keras.models.load_model(
                model_h5_path,
                custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
                compile=False)
            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

            # We use the Keras session graph in the call to the profiler.
            flops = tf.compat.v1.profiler.profile(graph=graph,
                                                  run_meta=run_meta, cmd='op', options=opts)

    tf.compat.v1.reset_default_graph()
    return flops.total_float_ops
_MODEL_DIR_OPTION = click.option(
    "-md",
    "--model-dir",
    type=click.Path(exists=False, writable=True),
    help="Directory of a DeepSpectrumLite HD5 Model.",
    required=True,
)


@add_options([_MODEL_DIR_OPTION])
@click.command(help=_DESCRIPTION)
def stats(model_dir, **kwargs):
    """Profile the model at ``model_dir`` via ``get_detailed_stats``.

    Eager execution is switched on and the NumPy and TensorFlow random seeds
    are both reset to 0 before the profiling run.

    Args:
        model_dir: Path handed to ``get_detailed_stats``.
        **kwargs: Additional CLI options; not used by this command.
    """
    seed = 0

    tf.compat.v1.enable_eager_execution()
    tf.config.run_functions_eagerly(True)

    # reset seed values
    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)

    get_detailed_stats(model_dir)
import logging
import click
import os
from os.path import join, dirname, realpath
import tensorflow as tf
from .utils import add_options
import numpy as np
import argparse
import time
import h5py
import sys
from deepspectrumlite import AugmentableModel, ARelu

log = logging.getLogger(__name__)

_DESCRIPTION = 'Test a TensorFlowLite model and retrieve statistics about it.'


def get_detailed_stats(model_h5_path):
    """Profile one forward pass of a Keras model and store the report.

    A single traced run on random normal input is profiled per code location
    (with a timeline written to ``test.json``) and per operation type. The
    string form of the latter is written to ``profiler.json`` in the current
    working directory.

    Args:
        model_h5_path: Path of the Keras model file.
    """
    session = tf.compat.v1.Session()
    graph = tf.compat.v1.get_default_graph()

    with graph.as_default():
        with session.as_default():
            new_model = tf.keras.models.load_model(
                model_h5_path,
                custom_objects={'AugmentableModel': AugmentableModel, 'ARelu': ARelu},
                compile=False)
            run_meta = tf.compat.v1.RunMetadata()
            input_details = new_model.get_config()
            input_shape = input_details['layers'][0]['config']['batch_input_shape']

            # NOTE(review): relies on the input tensor being named
            # 'input_1:0' -- confirm for models built in other sessions.
            _ = session.run(new_model.output, {
                'input_1:0': np.random.normal(size=(1, input_shape[1], input_shape[2], input_shape[3]))},
                            options=tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE),
                            run_metadata=run_meta)

            ProfileOptionBuilder = tf.compat.v1.profiler.ProfileOptionBuilder
            opts = ProfileOptionBuilder(ProfileOptionBuilder.time_and_memory()
                                        ).with_step(0).with_timeline_output('test.json').build()

            tf.compat.v1.profiler.profile(
                tf.compat.v1.get_default_graph(),
                run_meta=run_meta,
                cmd='code',
                options=opts)

            # Print to stdout an analysis of the memory usage and the timing information
            # broken down by operation types.
            json_export = tf.compat.v1.profiler.profile(
                tf.compat.v1.get_default_graph(),
                run_meta=run_meta,
                cmd='op',
                options=tf.compat.v1.profiler.ProfileOptionBuilder.time_and_memory())

    # Context manager so the report is flushed and closed even if writing fails.
    with open("profiler.json", "w") as text_file:
        text_file.write(str(json_export))

    tf.compat.v1.reset_default_graph()


@add_options(
    [
        click.option(
            "-md",
            "--model-dir",
            type=click.Path(exists=False, writable=True),
            help="Path to the TensorFlow Lite model",
            required=True
        )
    ]
)
@click.command(help=_DESCRIPTION)
def tflite_stats(model_dir, **kwargs):
    """Dump the tensors of a TFLite model and measure its inference time.

    Every tensor reported by the interpreter is stored, together with its
    name, shape and quantization, in ``converted_model_weights_infos.hdf5``
    next to the model file, and the total number of tensor elements is
    logged. The model is then invoked 50 times on one random input; the total
    and the mean duration of those invocations are logged.

    Args:
        model_dir: Path of the TFLite model file.
        **kwargs: Additional CLI options; not used by this command.
    """
    tf.compat.v1.enable_eager_execution()
    tf.config.run_functions_eagerly(True)

    # reset seed values
    np.random.seed(0)
    tf.compat.v1.set_random_seed(0)

    model_sub_dir = dirname(model_dir)

    interpreter = tf.lite.Interpreter(model_path=model_dir)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    all_layers_details = interpreter.get_tensor_details()

    interpreter.allocate_tensors()

    parameters = 0
    weights_path = os.path.join(model_sub_dir, "converted_model_weights_infos.hdf5")
    # Context manager so the HDF5 file is closed even if a tensor cannot be read.
    with h5py.File(weights_path, "w") as f:
        for layer in all_layers_details:
            # One group per tensor, keyed by the tensor index.
            grp = f.create_group(str(layer['index']))

            # Tensor metadata goes into the group's attributes.
            grp.attrs["name"] = layer['name']
            grp.attrs["shape"] = layer['shape']
            grp.attrs["quantization"] = layer['quantization']

            weights = interpreter.get_tensor(layer['index'])
            parameters += weights.size
            grp.create_dataset("weights", data=weights)

    log.info(str(parameters))

    # One untimed invocation before the measured runs.
    interpreter.invoke()

    # Test model on random input data.
    input_shape = input_details[0]['shape']
    log.info("input shape: ")
    log.info(input_shape)
    log.info("output shape: ")
    log.info(output_details[0]['shape'])
    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    # NOTE(review): the reason for this pause is not evident from this file;
    # it is kept, but it must not be part of the measured interval.
    time.sleep(10.0)
    log.info("start")

    runs = 50
    start_time = time.perf_counter()
    for _ in range(runs):
        interpreter.invoke()
    stop_time = time.perf_counter()
    output_data = interpreter.get_tensor(output_details[0]['index'])

    elapsed_ms = (stop_time - start_time) * 1000
    log.info(output_data)
    log.info('time: {:.3f}ms'.format(elapsed_ms))
    log.info('mean time: {:.3f}ms'.format(elapsed_ms / runs))
import logging
import click
from os import environ
from .utils import add_options
import shutil
import sys
import json
import os
from os.path import join, dirname, realpath
import platform

log = logging.getLogger(__name__)

_DESCRIPTION = 'Train a DeepSpectrumLite transfer learning model.'

_PARSER_FORMAT_ERROR = 'Please provide the parser in the following format: path.to.parser_file.py:ParserClass'


def _load_parser_class(label_parser_key):
    """Resolve a ``path/to/parser_file.py:ParserClass`` key to the class object.

    Relative paths are resolved against the directory of this module.

    Raises:
        ValueError: If the key is not of the form ``<path>:<class>`` or the
            file cannot be loaded as a Python module.
    """
    import importlib.util  # the submodule must be imported explicitly

    # Split on the LAST colon so that a Windows drive letter ("C:\...") stays
    # part of the path; this replaces the former platform-specific parsing.
    path, separator, class_name = label_parser_key.rpartition(':')
    if not separator or not path or not class_name:
        raise ValueError(_PARSER_FORMAT_ERROR)

    module_name = os.path.splitext(os.path.basename(path))[0]
    path = os.path.join(dirname(realpath(__file__)), path)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ValueError('Cannot load label parser module from: ' + path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, class_name)


@add_options(
    [
        click.option(
            "-d",
            "--data-dir",
            type=click.Path(exists=True),
            help="Directory of data class categories containing folders of each data class.",
            required=True
        ),
        click.option(
            "-md",
            "--model-dir",
            type=click.Path(exists=False, writable=True),
            help="Directory for all training output (logs and final model files).",
            required=True
        ),
        click.option(
            "-hc",
            "--hyper-config",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the hyper parameter config file.",
            default=join(dirname(realpath(__file__)), "config/hp_config.json"), show_default=True
        ),
        click.option(
            "-cc",
            "--class-config",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the class config file.",
            default=join(dirname(realpath(__file__)), "config/class_config.json"), show_default=True
        ),
        click.option(
            "-l",
            "--label-file",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Directory for the label file.",
            required=True
        ),
        # NOTE(review): declared as a path option, so "-dc" needs an existing
        # path as its value; it reads like it was meant to be `is_flag=True`.
        # Left unchanged here because switching would alter the CLI contract.
        click.option(
            "-dc",
            "--disable-cache",
            type=click.Path(exists=True, writable=False, readable=True),
            help="Disables the in-memory caching."
        )
    ]
)
@click.command(help=_DESCRIPTION)
@click.pass_context
def train(ctx, model_dir, data_dir, class_config, hyper_config, label_file, disable_cache, **kwargs):
    """Train one model per hyper-parameter combination of the grid search.

    For every combination (or only the one selected through the
    ``SLURM_ARRAY_TASK_ID`` environment variable) the label file is parsed
    with the configured parser class, train/devel/test data pipelines are
    built and the configured model is trained and saved.

    Args:
        ctx: Click context; ``ctx.obj['verbose']`` is forwarded to the model.
        model_dir: Root output directory; logs go to ``<model_dir>/logs`` and
            models to ``<model_dir>/models``.
        data_dir: Directory that is prepended to every file name of the labels.
        class_config: Path of the JSON class configuration.
        hyper_config: Path of the JSON hyper-parameter configuration.
        label_file: Path of the label file handed to the label parser.
        disable_cache: Any truthy value disables in-memory dataset caching.
        **kwargs: Additional CLI options; ignored.

    Raises:
        ValueError: If no data classes are defined, the SLURM task id is out
            of range, the parser key is malformed, or the model name is unknown.
    """
    import tensorflow as tf
    from tensorboard.plugins.hparams import api as hp
    import numpy as np
    from deepspectrumlite import HyperParameterList, TransferBaseModel, DataPipeline, \
        METRIC_ACCURACY, METRIC_MAE, METRIC_RMSE, METRIC_RECALL, METRIC_PRECISION, METRIC_F_SCORE

    verbose = ctx.obj['verbose']

    enable_cache = not disable_cache
    data_dir = os.path.join(data_dir, '')  # add trailing slash

    with open(class_config) as config_file:
        data_classes = json.load(config_file)

    if data_classes is None:
        raise ValueError('no data classes defined')

    log.info("Physical devices:")
    log.info(tf.config.experimental.list_physical_devices('GPU'))

    hyper_parameter_list = HyperParameterList(config_file_name=hyper_config)

    max_iterations = hyper_parameter_list.get_max_iteration()
    log.info('Loaded hyperparameter configuration.')
    log.info("Recognised combinations of settings: " + str(max_iterations) + "")

    # Inside a SLURM job array each task trains exactly one configuration.
    slurm_jobid = os.getenv('SLURM_ARRAY_TASK_ID')
    if slurm_jobid is not None:
        slurm_jobid = int(slurm_jobid)
        if slurm_jobid >= max_iterations:
            raise ValueError('slurm jobid ' + str(slurm_jobid) + ' is out of bound')
        iterations = [slurm_jobid]
    else:
        iterations = range(max_iterations)

    available_ai_models = {
        'TransferBaseModel': TransferBaseModel
    }
    tensorboard_initialised = False

    for iteration_no in iterations:
        hparam_values = hyper_parameter_list.get_values(iteration_no=iteration_no)
        hparam_values_tensorboard = hyper_parameter_list.get_values_tensorboard(iteration_no=iteration_no)

        # Fail before any expensive data loading; the error object used to be
        # created but never raised, silently skipping the configuration.
        model_name = hparam_values['model_name']
        if model_name not in available_ai_models:
            raise ValueError("Unknown model name: " + model_name)

        run_identifier = hparam_values['tb_run_id'] + '_config_' + str(iteration_no)
        tensorboard_dir = hparam_values['tb_experiment']

        log_dir = os.path.join(model_dir, 'logs', tensorboard_dir)
        run_log_dir = os.path.join(log_dir, run_identifier)
        # Separate name on purpose: re-assigning `model_dir` here made every
        # later iteration nest its paths below the previous run's directory.
        run_model_dir = os.path.join(model_dir, 'models', tensorboard_dir, run_identifier)

        # delete old log
        if os.path.isdir(run_log_dir):
            shutil.rmtree(run_log_dir)

        if not tensorboard_initialised:
            # create tensorboard
            with tf.summary.create_file_writer(log_dir).as_default():
                hp.hparams_config(
                    hparams=hyper_parameter_list.get_hparams(),
                    metrics=[hp.Metric(METRIC_ACCURACY, display_name='accuracy'),
                             hp.Metric(METRIC_PRECISION, display_name='precision'),
                             hp.Metric(METRIC_RECALL, display_name='unweighted recall'),
                             hp.Metric(METRIC_F_SCORE, display_name='f1 score'),
                             hp.Metric(METRIC_MAE, display_name='mae'),
                             hp.Metric(METRIC_RMSE, display_name='rmse')
                             ],
                )
            tensorboard_initialised = True

        # Use a label file parser to load data
        label_parser_key = hparam_values['label_parser']
        if ":" not in label_parser_key:
            raise ValueError(_PARSER_FORMAT_ERROR)

        log.info(f'Using custom external parser: {label_parser_key}')
        parser_class = _load_parser_class(label_parser_key)

        parser = parser_class(file_path=label_file)
        train_data, devel_data, test_data = parser.parse_labels()

        # reset seed values to make keras reproducible
        np.random.seed(0)
        tf.compat.v1.set_random_seed(0)

        log.info('--- Starting trial: %s' % run_identifier)
        log.info({h.name: hparam_values_tensorboard[h] for h in hparam_values_tensorboard})

        log.info("Load data pipeline ...")

        def new_pipeline(name, data):
            # Shared construction for the three partitions of this iteration.
            pipeline = DataPipeline(name=name, data_classes=data_classes,
                                    enable_gpu=True, verbose=True, enable_augmentation=False,
                                    hparams=hparam_values, run_id=iteration_no)
            pipeline.set_data(data)
            pipeline.set_filename_prepend(prepend_filename_str=data_dir)
            return pipeline

        ########### TRAIN DATA ###########
        train_data_pipeline = new_pipeline('train_data_set', train_data)
        train_data_pipeline.preprocess()
        train_data_pipeline.up_sample()
        train_dataset = train_data_pipeline.pipeline(cache=enable_cache)

        ########### DEVEL DATA ###########
        devel_data_pipeline = new_pipeline('devel_data_set', devel_data)
        devel_dataset = devel_data_pipeline.pipeline(cache=enable_cache, shuffle=False, drop_remainder=False)

        ########### TEST DATA ###########
        test_data_pipeline = new_pipeline('test_data_set', test_data)
        test_dataset = test_data_pipeline.pipeline(cache=enable_cache, shuffle=False, drop_remainder=False)

        log.info("All data pipelines have been successfully loaded.")
        log.info("Caching in memory is: " + str(enable_cache))

        model = available_ai_models[model_name](hyper_parameter_list,
                                                train_data_pipeline.get_model_input_shape(),
                                                run_dir=run_log_dir,
                                                data_classes=data_classes,
                                                use_ram=True,
                                                run_id=iteration_no,
                                                verbose=verbose)

        model.run(train_dataset=train_dataset,
                  test_dataset=test_dataset,
                  devel_dataset=devel_dataset,
                  save_model=True,
                  save_dir=run_model_dir)
def add_options(options):
    """Bundle several decorators (e.g. click options) into a single decorator.

    The decorators are applied from the last entry of *options* to the first,
    so the first entry ends up outermost, exactly as if they had been stacked
    above the function in the listed order.

    Args:
        options: Sequence of decorators, each taking and returning a callable.

    Returns:
        A decorator that applies every entry of *options* to its argument.
    """
    def _add_options(func):
        wrapped = func
        remaining = list(options)
        # Popping from the end walks the options in reverse order.
        while remaining:
            wrapped = remaining.pop()(wrapped)
        return wrapped

    return _add_options
import tensorflow as tf
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.ops.image_ops_impl import ResizeMethod

from deepspectrumlite import power_to_db, amplitude_to_db
from deepspectrumlite.lib.data.plot import create_map_from_array, CividisColorMap, InfernoColorMap, MagmaColorMap, \
    PlasmaColorMap, ViridisColorMap
import numpy as np
import math


# TODO refactor
class PreprocessAudio(tf.Module):
    """Exportable ``tf.Module`` that turns a raw audio chunk into model input.

    The signal is converted into a mel spectrogram (optionally MFCCs) and, if
    ``use_plot_images`` is set, rendered as a colour-mapped image that is run
    through the preprocessing function of the configured base model.

    Hyper-parameters read from ``hparams``:
        Optional: ``resize_method``, ``anti_alias``, ``color_map``.
        Required by :meth:`preprocess`: ``stft_window_size``, ``stft_hop_size``,
        ``stft_fft_length``, ``sample_rate``, ``lower_edge_hertz``,
        ``upper_edge_hertz``, ``num_mel_bins``, ``num_mfccs``, ``db_scale``,
        ``use_plot_images``; depending on the branch taken also
        ``cep_lifter``, ``image_width``, ``image_height`` and
        ``basemodel_name``.
    """

    def __init__(self, hparams, *args, **kwargs):
        """Store the hyper-parameters and resolve colour map and preprocessors.

        Args:
            hparams: Mapping of hyper-parameter names to values.
            *args: Forwarded to ``tf.Module.__init__``.
            **kwargs: Forwarded to ``tf.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.hparams = hparams

        self.resize_method = ResizeMethod.BILINEAR
        if 'resize_method' in self.hparams:
            self.resize_method = self.hparams['resize_method']

        # NOTE(review): this value is stored but `preprocess` passes a
        # hard-coded `antialias=False` to tf.image.resize - confirm which of
        # the two is intended before wiring it through.
        self.anti_alias = True
        if 'anti_alias' in self.hparams:
            self.anti_alias = self.hparams['anti_alias']

        available_color_maps = {
            "cividis": CividisColorMap,
            "inferno": InfernoColorMap,
            "magma": MagmaColorMap,
            "plasma": PlasmaColorMap,
            "viridis": ViridisColorMap
        }

        # Viridis is the fallback for a missing or unknown `color_map`; the
        # key is now optional like the other optional hyper-parameters
        # instead of raising KeyError when absent.
        self.colormap = ViridisColorMap()
        if 'color_map' in self.hparams and self.hparams['color_map'] in available_color_maps:
            self.colormap = available_color_maps[self.hparams['color_map']]()

        # Base-model name -> Keras input preprocessing function.
        self.preprocessors = {
            "vgg16":
                tf.keras.applications.vgg16.preprocess_input,
            "vgg19":
                tf.keras.applications.vgg19.preprocess_input,
            "resnet50":
                tf.keras.applications.resnet50.preprocess_input,
            "xception":
                tf.keras.applications.xception.preprocess_input,
            "inception_v3":
                tf.keras.applications.inception_v3.preprocess_input,
            "densenet121":
                tf.keras.applications.densenet.preprocess_input,
            "densenet169":
                tf.keras.applications.densenet.preprocess_input,
            "densenet201":
                tf.keras.applications.densenet.preprocess_input,
            "mobilenet":
                tf.keras.applications.mobilenet.preprocess_input,
            "mobilenet_v2":
                tf.keras.applications.mobilenet_v2.preprocess_input,
            "nasnet_large":
                tf.keras.applications.nasnet.preprocess_input,
            "nasnet_mobile":
                tf.keras.applications.nasnet.preprocess_input,
            "inception_resnet_v2":
                tf.keras.applications.inception_resnet_v2.preprocess_input,
            "squeezenet_v1":
                tf.keras.applications.imagenet_utils.preprocess_input,
        }

    def __preprocess_vgg(self, x, data_format=None):
        """
        Legacy function for VGG16 and VGG19 preprocessing without centering.

        Only reverses the order of the last axis of the 4-D input ``x``;
        ``data_format`` is accepted but not used.
        """
        x = x[:, :, :, ::-1]
        return x

    @tf.function(input_signature=[tf.TensorSpec(shape=(1, 16000), dtype=tf.float32)])
    def preprocess(self, audio_signal):  # pragma: no cover
        """Convert one audio chunk into the tensor expected by the model.

        Args:
            audio_signal: float32 tensor of shape ``(1, 16000)`` (fixed by the
                ``tf.function`` input signature).

        Returns:
            With ``use_plot_images``: the colour-mapped, resized, rotated and
            base-model-preprocessed image tensor.  Otherwise the spectrogram
            (or MFCC) tensor with an extra axis appended at position 3.
        """
        # Scale the signal so that its absolute peak becomes 0.7079.
        decoded_audio = audio_signal * (0.7079 / tf.reduce_max(tf.abs(audio_signal)))

        # STFT parameters are configured in seconds and converted to samples.
        sample_rate = self.hparams['sample_rate']
        frame_length = int(self.hparams['stft_window_size'] * sample_rate)
        frame_step = int(self.hparams['stft_hop_size'] * sample_rate)
        fft_length = int(self.hparams['stft_fft_length'] * sample_rate)

        stfts = tf.signal.stft(decoded_audio, frame_length=frame_length, frame_step=frame_step,
                               fft_length=fft_length)
        spectrograms = tf.abs(stfts, name="magnitude_spectrograms") ** 2

        # Warp the linear scale spectrograms into the mel-scale.
        num_spectrogram_bins = stfts.shape[-1]
        lower_edge_hertz = self.hparams['lower_edge_hertz']
        upper_edge_hertz = self.hparams['upper_edge_hertz']
        num_mel_bins = self.hparams['num_mel_bins']

        linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
            upper_edge_hertz)

        mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
        # tensordot loses the static shape; restore it for the later ops.
        mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate(
            linear_to_mel_weight_matrix.shape[-1:]))
        num_mfcc = self.hparams['num_mfccs']

        if num_mfcc:
            if self.hparams['db_scale']:
                mel_spectrograms = amplitude_to_db(mel_spectrograms, top_db=None)
            else:
                # Small offset avoids log(0).
                mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)

            # Compute MFCCs from mel_spectrograms
            mfccs = tf.signal.mfccs_from_log_mel_spectrograms(mel_spectrograms)[..., :num_mfcc]

            # cep filter
            if self.hparams['cep_lifter'] > 0:
                cep_lifter = self.hparams['cep_lifter']

                ncoeff = mfccs.shape[-1]
                n = tf.keras.backend.arange(start=0, stop=ncoeff, dtype=tf.float32)
                lift = 1 + (cep_lifter / 2) * tf.sin(math.pi * n / cep_lifter)

                mfccs *= lift

            output = mfccs
        else:
            if self.hparams['db_scale']:
                output = power_to_db(mel_spectrograms, top_db=None)
            else:
                output = mel_spectrograms

        if self.hparams['use_plot_images']:
            color_map = np.array(self.colormap.get_color_map(), dtype=np.float32)

            image_data = create_map_from_array(output, color_map=color_map)

            image_data = tf.image.resize(
                image_data, (self.hparams['image_width'], self.hparams['image_height']),
                method=self.resize_method, preserve_aspect_ratio=False, antialias=False
            )

            image_data = image_data * 255.
            image_data = tf.clip_by_value(image_data, clip_value_min=0., clip_value_max=255.)
            image_data = tf.image.rot90(image_data, k=1)

            def _preprocess(x):
                # values in the range [0, 255] are expected!!
                model_key = self.hparams['basemodel_name']

                if model_key in self.preprocessors:
                    return self.preprocessors[model_key](x, data_format='channels_last')

                # Unknown base models receive the image unchanged.
                return x

            # values in the range [0, 255] are expected!!
            image_data = _preprocess(image_data)

        else:
            image_data = tf.expand_dims(output, axis=3)

        return image_data
import re
import pandas as pd
import numpy as np


class ComParEParser:  # pragma: no cover
    """Label parser that splits a label table by file-name prefix."""

    def __init__(self, file_path: str, delimiter=','):  # pragma: no cover
        """Remember where the label file lives and how its columns are separated."""
        self._delimiter = delimiter
        self._file_path = file_path

    def parse_labels(self):  # pragma: no cover
        """Read the label file and split it into its three partitions.

        The columns of the file are renamed to ``filename``, ``label`` and
        ``duration_frames``.  Rows are assigned to a partition by the prefix
        of their file name (``train``, ``devel`` or ``test``).

        Returns:
            Tuple ``(train_data, devel_data, test_data)`` of DataFrames.
        """
        table = pd.read_csv(self._file_path, sep=self._delimiter)
        table.columns = ['filename', 'label', 'duration_frames']

        file_names = table.filename
        return tuple(
            table[file_names.str.startswith(prefix)]
            for prefix in ('train', 'devel', 'test')
        )
-------------------------------------------------------------------------------- 1 | """ implementation of our spectrogram image plotting """ 2 | from .colormap import * 3 | from .color_maps import * 4 | -------------------------------------------------------------------------------- /src/deepspectrumlite/lib/data/plot/color_maps/__init__.py: -------------------------------------------------------------------------------- 1 | """implementation of various color maps""" 2 | from .abstract_colormap import AbstractColorMap 3 | from .cividis import * 4 | from .inferno import * 5 | from .magma import * 6 | from .plasma import * 7 | from .viridis import * 8 | -------------------------------------------------------------------------------- /src/deepspectrumlite/lib/data/plot/color_maps/abstract_colormap.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class AbstractColorMap(object): 5 | __metaclass__ = abc.ABCMeta 6 | 7 | def __init__(self): 8 | self.color_map = None 9 | 10 | def set_color_map(self, color_map): 11 | self.color_map = color_map 12 | 13 | def get_color_map(self): 14 | return self.color_map 15 | -------------------------------------------------------------------------------- /src/deepspectrumlite/lib/data/plot/color_maps/cividis.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from .abstract_colormap import AbstractColorMap 3 | 4 | 5 | class CividisColorMap(AbstractColorMap): 6 | __metaclass__ = abc.ABCMeta 7 | 8 | def __init__(self): 9 | super().__init__() 10 | self.set_color_map([ 11 | [0.0, 0.135112, 0.304751], 12 | [0.0, 0.138068, 0.311105], 13 | [0.0, 0.141013, 0.317579], 14 | [0.0, 0.143951, 0.323982], 15 | [0.0, 0.146877, 0.330479], 16 | [0.0, 0.149791, 0.337065], 17 | [0.0, 0.152673, 0.343704], 18 | [0.0, 0.155377, 0.3505], 19 | [0.0, 0.157932, 0.357521], 20 | [0.0, 0.160495, 0.364534], 21 | [0.0, 0.163058, 0.371608], 22 | [0.0, 
0.165621, 0.378769], 23 | [0.0, 0.168204, 0.385902], 24 | [0.0, 0.1708, 0.3931], 25 | [0.0, 0.17342, 0.400353], 26 | [0.0, 0.176082, 0.407577], 27 | [0.0, 0.178802, 0.414764], 28 | [0.0, 0.18161, 0.421859], 29 | [0.0, 0.18455, 0.428802], 30 | [0.0, 0.186915, 0.435532], 31 | [0.0, 0.188769, 0.439563], 32 | [0.0, 0.19095, 0.441085], 33 | [0.0, 0.193366, 0.441561], 34 | [0.003602, 0.195911, 0.441564], 35 | [0.017852, 0.198528, 0.441248], 36 | [0.03211, 0.201199, 0.440785], 37 | [0.046205, 0.203903, 0.440196], 38 | [0.058378, 0.206629, 0.439531], 39 | [0.068968, 0.209372, 0.438863], 40 | [0.078624, 0.212122, 0.438105], 41 | [0.087465, 0.214879, 0.437342], 42 | [0.095645, 0.217643, 0.436593], 43 | [0.103401, 0.220406, 0.43579], 44 | [0.110658, 0.22317, 0.435067], 45 | [0.117612, 0.225935, 0.434308], 46 | [0.124291, 0.228697, 0.433547], 47 | [0.130669, 0.231458, 0.43284], 48 | [0.13683, 0.234216, 0.432148], 49 | [0.142852, 0.236972, 0.431404], 50 | [0.148638, 0.239724, 0.430752], 51 | [0.154261, 0.242475, 0.43012], 52 | [0.159733, 0.245221, 0.429528], 53 | [0.165113, 0.247965, 0.428908], 54 | [0.170362, 0.250707, 0.428325], 55 | [0.17549, 0.253444, 0.42779], 56 | [0.180503, 0.25618, 0.427299], 57 | [0.185453, 0.258914, 0.426788], 58 | [0.190303, 0.261644, 0.426329], 59 | [0.195057, 0.264372, 0.425924], 60 | [0.199764, 0.267099, 0.425497], 61 | [0.204385, 0.269823, 0.425126], 62 | [0.208926, 0.272546, 0.424809], 63 | [0.213431, 0.275266, 0.42448], 64 | [0.217863, 0.277985, 0.424206], 65 | [0.222264, 0.280702, 0.423914], 66 | [0.226598, 0.283419, 0.423678], 67 | [0.230871, 0.286134, 0.423498], 68 | [0.23512, 0.288848, 0.423304], 69 | [0.239312, 0.291562, 0.423167], 70 | [0.243485, 0.294274, 0.423014], 71 | [0.247605, 0.296986, 0.422917], 72 | [0.251675, 0.299698, 0.422873], 73 | [0.255731, 0.302409, 0.422814], 74 | [0.25974, 0.30512, 0.42281], 75 | [0.263738, 0.307831, 0.422789], 76 | [0.267693, 0.310542, 0.422821], 77 | [0.271639, 0.313253, 0.422837], 78 | [0.275513, 
0.315965, 0.422979], 79 | [0.279411, 0.318677, 0.423031], 80 | [0.28324, 0.32139, 0.423211], 81 | [0.287065, 0.324103, 0.423373], 82 | [0.290884, 0.326816, 0.423517], 83 | [0.294669, 0.329531, 0.423716], 84 | [0.298421, 0.332247, 0.423973], 85 | [0.302169, 0.334963, 0.424213], 86 | [0.305886, 0.337681, 0.424512], 87 | [0.309601, 0.340399, 0.42479], 88 | [0.313287, 0.34312, 0.42512], 89 | [0.316941, 0.345842, 0.425512], 90 | [0.320595, 0.348565, 0.425889], 91 | [0.32425, 0.351289, 0.42625], 92 | [0.327875, 0.354016, 0.42667], 93 | [0.331474, 0.356744, 0.427144], 94 | [0.335073, 0.359474, 0.427605], 95 | [0.338673, 0.362206, 0.428053], 96 | [0.342246, 0.364939, 0.428559], 97 | [0.345793, 0.367676, 0.429127], 98 | [0.349341, 0.370414, 0.429685], 99 | [0.352892, 0.373153, 0.430226], 100 | [0.356418, 0.375896, 0.430823], 101 | [0.359916, 0.378641, 0.431501], 102 | [0.363446, 0.381388, 0.432075], 103 | [0.366923, 0.384139, 0.432796], 104 | [0.37043, 0.38689, 0.433428], 105 | [0.373884, 0.389646, 0.434209], 106 | [0.377371, 0.392404, 0.43489], 107 | [0.38083, 0.395164, 0.435653], 108 | [0.384268, 0.397928, 0.436475], 109 | [0.387705, 0.400694, 0.437305], 110 | [0.391151, 0.403464, 0.438096], 111 | [0.394568, 0.406236, 0.438986], 112 | [0.397991, 0.409011, 0.439848], 113 | [0.401418, 0.41179, 0.440708], 114 | [0.40482, 0.414572, 0.441642], 115 | [0.408226, 0.417357, 0.44257], 116 | [0.411607, 0.420145, 0.443577], 117 | [0.414992, 0.422937, 0.444578], 118 | [0.418383, 0.425733, 0.44556], 119 | [0.421748, 0.428531, 0.44664], 120 | [0.42512, 0.431334, 0.447692], 121 | [0.428462, 0.43414, 0.448864], 122 | [0.431817, 0.43695, 0.449982], 123 | [0.435168, 0.439763, 0.451134], 124 | [0.438504, 0.44258, 0.452341], 125 | [0.44181, 0.445402, 0.453659], 126 | [0.445148, 0.448226, 0.454885], 127 | [0.448447, 0.451053, 0.456264], 128 | [0.451759, 0.453887, 0.457582], 129 | [0.455072, 0.456718, 0.458976], 130 | [0.458366, 0.459552, 0.460457], 131 | [0.461616, 0.462405, 0.461969], 132 | 
[0.464947, 0.465241, 0.463395], 133 | [0.468254, 0.468083, 0.464908], 134 | [0.471501, 0.47096, 0.466357], 135 | [0.474812, 0.473832, 0.467681], 136 | [0.478186, 0.476699, 0.468845], 137 | [0.481622, 0.479573, 0.469767], 138 | [0.485141, 0.482451, 0.470384], 139 | [0.488697, 0.485318, 0.471008], 140 | [0.492278, 0.488198, 0.471453], 141 | [0.495913, 0.491076, 0.471751], 142 | [0.499552, 0.49396, 0.472032], 143 | [0.503185, 0.496851, 0.472305], 144 | [0.506866, 0.499743, 0.472432], 145 | [0.51054, 0.502643, 0.47255], 146 | [0.514226, 0.505546, 0.47264], 147 | [0.51792, 0.508454, 0.472707], 148 | [0.521643, 0.511367, 0.472639], 149 | [0.525348, 0.514285, 0.47266], 150 | [0.529086, 0.517207, 0.472543], 151 | [0.532829, 0.520135, 0.472401], 152 | [0.536553, 0.523067, 0.472352], 153 | [0.540307, 0.526005, 0.472163], 154 | [0.544069, 0.528948, 0.471947], 155 | [0.54784, 0.531895, 0.471704], 156 | [0.551612, 0.534849, 0.471439], 157 | [0.555393, 0.537807, 0.471147], 158 | [0.559181, 0.540771, 0.470829], 159 | [0.562972, 0.543741, 0.470488], 160 | [0.566802, 0.546715, 0.469988], 161 | [0.570607, 0.549695, 0.469593], 162 | [0.574417, 0.552682, 0.469172], 163 | [0.578236, 0.555673, 0.468724], 164 | [0.582087, 0.55867, 0.468118], 165 | [0.585916, 0.561674, 0.467618], 166 | [0.589753, 0.564682, 0.46709], 167 | [0.593622, 0.567697, 0.466401], 168 | [0.597469, 0.570718, 0.465821], 169 | [0.601354, 0.573743, 0.465074], 170 | [0.605211, 0.576777, 0.464441], 171 | [0.609105, 0.579816, 0.463638], 172 | [0.612977, 0.582861, 0.46295], 173 | [0.616852, 0.585913, 0.462237], 174 | [0.620765, 0.58897, 0.461351], 175 | [0.624654, 0.592034, 0.460583], 176 | [0.628576, 0.595104, 0.459641], 177 | [0.632506, 0.59818, 0.458668], 178 | [0.636412, 0.601264, 0.457818], 179 | [0.640352, 0.604354, 0.456791], 180 | [0.64427, 0.60745, 0.455886], 181 | [0.648222, 0.610553, 0.454801], 182 | [0.652178, 0.613664, 0.453689], 183 | [0.656114, 0.61678, 0.452702], 184 | [0.660082, 0.619904, 0.451534], 185 | 
[0.664055, 0.623034, 0.450338], 186 | [0.668008, 0.626171, 0.44927], 187 | [0.671991, 0.629316, 0.448018], 188 | [0.675981, 0.632468, 0.446736], 189 | [0.679979, 0.635626, 0.445424], 190 | [0.68395, 0.638793, 0.444251], 191 | [0.687957, 0.641966, 0.442886], 192 | [0.691971, 0.645145, 0.441491], 193 | [0.695985, 0.648334, 0.440072], 194 | [0.700008, 0.651529, 0.438624], 195 | [0.704037, 0.654731, 0.437147], 196 | [0.708067, 0.657942, 0.435647], 197 | [0.712105, 0.66116, 0.434117], 198 | [0.716177, 0.664384, 0.432386], 199 | [0.720222, 0.667618, 0.430805], 200 | [0.724274, 0.670859, 0.429194], 201 | [0.728334, 0.674107, 0.427554], 202 | [0.732422, 0.677364, 0.425717], 203 | [0.736488, 0.680629, 0.424028], 204 | [0.740589, 0.6839, 0.422131], 205 | [0.744664, 0.687181, 0.420393], 206 | [0.748772, 0.69047, 0.418448], 207 | [0.752886, 0.693766, 0.416472], 208 | [0.756975, 0.697071, 0.414659], 209 | [0.761096, 0.700384, 0.412638], 210 | [0.765223, 0.703705, 0.410587], 211 | [0.769353, 0.707035, 0.408516], 212 | [0.773486, 0.710373, 0.406422], 213 | [0.777651, 0.713719, 0.404112], 214 | [0.781795, 0.717074, 0.401966], 215 | [0.785965, 0.720438, 0.399613], 216 | [0.790116, 0.72381, 0.397423], 217 | [0.794298, 0.72719, 0.395016], 218 | [0.79848, 0.73058, 0.392597], 219 | [0.802667, 0.733978, 0.390153], 220 | [0.806859, 0.737385, 0.387684], 221 | [0.811054, 0.740801, 0.385198], 222 | [0.815274, 0.744226, 0.382504], 223 | [0.819499, 0.747659, 0.379785], 224 | [0.823729, 0.751101, 0.377043], 225 | [0.827959, 0.754553, 0.374292], 226 | [0.832192, 0.758014, 0.371529], 227 | [0.836429, 0.761483, 0.368747], 228 | [0.840693, 0.764962, 0.365746], 229 | [0.844957, 0.76845, 0.362741], 230 | [0.849223, 0.771947, 0.359729], 231 | [0.853515, 0.775454, 0.3565], 232 | [0.857809, 0.778969, 0.353259], 233 | [0.862105, 0.782494, 0.350011], 234 | [0.866421, 0.786028, 0.346571], 235 | [0.870717, 0.789572, 0.343333], 236 | [0.875057, 0.793125, 0.339685], 237 | [0.879378, 0.796687, 0.336241], 238 
| [0.88372, 0.800258, 0.332599], 239 | [0.888081, 0.803839, 0.32877], 240 | [0.89244, 0.80743, 0.324968], 241 | [0.896818, 0.81103, 0.320982], 242 | [0.901195, 0.814639, 0.317021], 243 | [0.905589, 0.818257, 0.312889], 244 | [0.91, 0.821885, 0.308594], 245 | [0.914407, 0.825522, 0.304348], 246 | [0.918828, 0.829168, 0.29996], 247 | [0.923279, 0.832822, 0.295244], 248 | [0.927724, 0.836486, 0.290611], 249 | [0.93218, 0.840159, 0.28588], 250 | [0.93666, 0.843841, 0.280876], 251 | [0.941147, 0.84753, 0.275815], 252 | [0.945654, 0.851228, 0.270532], 253 | [0.950178, 0.854933, 0.265085], 254 | [0.954725, 0.858646, 0.259365], 255 | [0.959284, 0.862365, 0.253563], 256 | [0.963872, 0.866089, 0.247445], 257 | [0.968469, 0.869819, 0.24131], 258 | [0.973114, 0.87355, 0.234677], 259 | [0.97778, 0.877281, 0.227954], 260 | [0.982497, 0.881008, 0.220878], 261 | [0.987293, 0.884718, 0.213336], 262 | [0.992218, 0.888385, 0.205468], 263 | [0.994847, 0.892954, 0.203445], 264 | [0.995249, 0.898384, 0.207561], 265 | [0.995503, 0.903866, 0.21237], 266 | [0.995737, 0.909344, 0.217772] 267 | ]) 268 | -------------------------------------------------------------------------------- /src/deepspectrumlite/lib/data/plot/color_maps/inferno.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from .abstract_colormap import AbstractColorMap 3 | 4 | class InfernoColorMap(AbstractColorMap): 5 | __metaclass__ = abc.ABCMeta 6 | 7 | def __init__(self): 8 | super().__init__() 9 | self.set_color_map([ 10 | [0.001462, 0.000466, 0.013866], 11 | [0.002267, 0.00127, 0.01857], 12 | [0.003299, 0.002249, 0.024239], 13 | [0.004547, 0.003392, 0.030909], 14 | [0.006006, 0.004692, 0.038558], 15 | [0.007676, 0.006136, 0.046836], 16 | [0.009561, 0.007713, 0.055143], 17 | [0.011663, 0.009417, 0.06346], 18 | [0.013995, 0.011225, 0.071862], 19 | [0.016561, 0.013136, 0.080282], 20 | [0.019373, 0.015133, 0.088767], 21 | [0.022447, 0.017199, 0.097327], 22 | [0.025793, 
0.019331, 0.10593], 23 | [0.029432, 0.021503, 0.114621], 24 | [0.033385, 0.023702, 0.123397], 25 | [0.037668, 0.025921, 0.132232], 26 | [0.042253, 0.028139, 0.141141], 27 | [0.046915, 0.030324, 0.150164], 28 | [0.051644, 0.032474, 0.159254], 29 | [0.056449, 0.034569, 0.168414], 30 | [0.06134, 0.03659, 0.177642], 31 | [0.066331, 0.038504, 0.186962], 32 | [0.071429, 0.040294, 0.196354], 33 | [0.076637, 0.041905, 0.205799], 34 | [0.081962, 0.043328, 0.215289], 35 | [0.087411, 0.044556, 0.224813], 36 | [0.09299, 0.045583, 0.234358], 37 | [0.098702, 0.046402, 0.243904], 38 | [0.104551, 0.047008, 0.25343], 39 | [0.110536, 0.047399, 0.262912], 40 | [0.116656, 0.047574, 0.272321], 41 | [0.122908, 0.047536, 0.281624], 42 | [0.129285, 0.047293, 0.290788], 43 | [0.135778, 0.046856, 0.299776], 44 | [0.142378, 0.046242, 0.308553], 45 | [0.149073, 0.045468, 0.317085], 46 | [0.15585, 0.044559, 0.325338], 47 | [0.162689, 0.043554, 0.333277], 48 | [0.169575, 0.042489, 0.340874], 49 | [0.176493, 0.041402, 0.348111], 50 | [0.183429, 0.040329, 0.354971], 51 | [0.190367, 0.039309, 0.361447], 52 | [0.197297, 0.0384, 0.367535], 53 | [0.204209, 0.037632, 0.373238], 54 | [0.211095, 0.03703, 0.378563], 55 | [0.217949, 0.036615, 0.383522], 56 | [0.224763, 0.036405, 0.388129], 57 | [0.231538, 0.036405, 0.3924], 58 | [0.238273, 0.036621, 0.396353], 59 | [0.244967, 0.037055, 0.400007], 60 | [0.25162, 0.037705, 0.403378], 61 | [0.258234, 0.038571, 0.406485], 62 | [0.26481, 0.039647, 0.409345], 63 | [0.271347, 0.040922, 0.411976], 64 | [0.27785, 0.042353, 0.414392], 65 | [0.284321, 0.043933, 0.416608], 66 | [0.290763, 0.045644, 0.418637], 67 | [0.297178, 0.04747, 0.420491], 68 | [0.303568, 0.049396, 0.422182], 69 | [0.309935, 0.051407, 0.423721], 70 | [0.316282, 0.05349, 0.425116], 71 | [0.32261, 0.055634, 0.426377], 72 | [0.328921, 0.057827, 0.427511], 73 | [0.335217, 0.06006, 0.428524], 74 | [0.3415, 0.062325, 0.429425], 75 | [0.347771, 0.064616, 0.430217], 76 | [0.354032, 0.066925, 0.430906], 
77 | [0.360284, 0.069247, 0.431497], 78 | [0.366529, 0.071579, 0.431994], 79 | [0.372768, 0.073915, 0.4324], 80 | [0.379001, 0.076253, 0.432719], 81 | [0.385228, 0.078591, 0.432955], 82 | [0.391453, 0.080927, 0.433109], 83 | [0.397674, 0.083257, 0.433183], 84 | [0.403894, 0.08558, 0.433179], 85 | [0.410113, 0.087896, 0.433098], 86 | [0.416331, 0.090203, 0.432943], 87 | [0.422549, 0.092501, 0.432714], 88 | [0.428768, 0.09479, 0.432412], 89 | [0.434987, 0.097069, 0.432039], 90 | [0.441207, 0.099338, 0.431594], 91 | [0.447428, 0.101597, 0.43108], 92 | [0.453651, 0.103848, 0.430498], 93 | [0.459875, 0.106089, 0.429846], 94 | [0.4661, 0.108322, 0.429125], 95 | [0.472328, 0.110547, 0.428334], 96 | [0.478558, 0.112764, 0.427475], 97 | [0.484789, 0.114974, 0.426548], 98 | [0.491022, 0.117179, 0.425552], 99 | [0.497257, 0.119379, 0.424488], 100 | [0.503493, 0.121575, 0.423356], 101 | [0.50973, 0.123769, 0.422156], 102 | [0.515967, 0.12596, 0.420887], 103 | [0.522206, 0.12815, 0.419549], 104 | [0.528444, 0.130341, 0.418142], 105 | [0.534683, 0.132534, 0.416667], 106 | [0.54092, 0.134729, 0.415123], 107 | [0.547157, 0.136929, 0.413511], 108 | [0.553392, 0.139134, 0.411829], 109 | [0.559624, 0.141346, 0.410078], 110 | [0.565854, 0.143567, 0.408258], 111 | [0.572081, 0.145797, 0.406369], 112 | [0.578304, 0.148039, 0.404411], 113 | [0.584521, 0.150294, 0.402385], 114 | [0.590734, 0.152563, 0.40029], 115 | [0.59694, 0.154848, 0.398125], 116 | [0.603139, 0.157151, 0.395891], 117 | [0.60933, 0.159474, 0.393589], 118 | [0.615513, 0.161817, 0.391219], 119 | [0.621685, 0.164184, 0.388781], 120 | [0.627847, 0.166575, 0.386276], 121 | [0.633998, 0.168992, 0.383704], 122 | [0.640135, 0.171438, 0.381065], 123 | [0.64626, 0.173914, 0.378359], 124 | [0.652369, 0.176421, 0.375586], 125 | [0.658463, 0.178962, 0.372748], 126 | [0.66454, 0.181539, 0.369846], 127 | [0.670599, 0.184153, 0.366879], 128 | [0.676638, 0.186807, 0.363849], 129 | [0.682656, 0.189501, 0.360757], 130 | [0.688653, 
0.192239, 0.357603], 131 | [0.694627, 0.195021, 0.354388], 132 | [0.700576, 0.197851, 0.351113], 133 | [0.7065, 0.200728, 0.347777], 134 | [0.712396, 0.203656, 0.344383], 135 | [0.718264, 0.206636, 0.340931], 136 | [0.724103, 0.20967, 0.337424], 137 | [0.729909, 0.212759, 0.333861], 138 | [0.735683, 0.215906, 0.330245], 139 | [0.741423, 0.219112, 0.326576], 140 | [0.747127, 0.222378, 0.322856], 141 | [0.752794, 0.225706, 0.319085], 142 | [0.758422, 0.229097, 0.315266], 143 | [0.76401, 0.232554, 0.311399], 144 | [0.769556, 0.236077, 0.307485], 145 | [0.775059, 0.239667, 0.303526], 146 | [0.780517, 0.243327, 0.299523], 147 | [0.785929, 0.247056, 0.295477], 148 | [0.791293, 0.250856, 0.29139], 149 | [0.796607, 0.254728, 0.287264], 150 | [0.801871, 0.258674, 0.283099], 151 | [0.807082, 0.262692, 0.278898], 152 | [0.812239, 0.266786, 0.274661], 153 | [0.817341, 0.270954, 0.27039], 154 | [0.822386, 0.275197, 0.266085], 155 | [0.827372, 0.279517, 0.26175], 156 | [0.832299, 0.283913, 0.257383], 157 | [0.837165, 0.288385, 0.252988], 158 | [0.841969, 0.292933, 0.248564], 159 | [0.846709, 0.297559, 0.244113], 160 | [0.851384, 0.30226, 0.239636], 161 | [0.855992, 0.307038, 0.235133], 162 | [0.860533, 0.311892, 0.230606], 163 | [0.865006, 0.316822, 0.226055], 164 | [0.869409, 0.321827, 0.221482], 165 | [0.873741, 0.326906, 0.216886], 166 | [0.878001, 0.33206, 0.212268], 167 | [0.882188, 0.337287, 0.207628], 168 | [0.886302, 0.342586, 0.202968], 169 | [0.890341, 0.347957, 0.198286], 170 | [0.894305, 0.353399, 0.193584], 171 | [0.898192, 0.358911, 0.18886], 172 | [0.902003, 0.364492, 0.184116], 173 | [0.905735, 0.37014, 0.17935], 174 | [0.90939, 0.375856, 0.174563], 175 | [0.912966, 0.381636, 0.169755], 176 | [0.916462, 0.387481, 0.164924], 177 | [0.919879, 0.393389, 0.16007], 178 | [0.923215, 0.399359, 0.155193], 179 | [0.92647, 0.405389, 0.150292], 180 | [0.929644, 0.411479, 0.145367], 181 | [0.932737, 0.417627, 0.140417], 182 | [0.935747, 0.423831, 0.13544], 183 | [0.938675, 
0.430091, 0.130438], 184 | [0.941521, 0.436405, 0.125409], 185 | [0.944285, 0.442772, 0.120354], 186 | [0.946965, 0.449191, 0.115272], 187 | [0.949562, 0.45566, 0.110164], 188 | [0.952075, 0.462178, 0.105031], 189 | [0.954506, 0.468744, 0.099874], 190 | [0.956852, 0.475356, 0.094695], 191 | [0.959114, 0.482014, 0.089499], 192 | [0.961293, 0.488716, 0.084289], 193 | [0.963387, 0.495462, 0.079073], 194 | [0.965397, 0.502249, 0.073859], 195 | [0.967322, 0.509078, 0.068659], 196 | [0.969163, 0.515946, 0.063488], 197 | [0.970919, 0.522853, 0.058367], 198 | [0.97259, 0.529798, 0.053324], 199 | [0.974176, 0.53678, 0.048392], 200 | [0.975677, 0.543798, 0.043618], 201 | [0.977092, 0.55085, 0.03905], 202 | [0.978422, 0.557937, 0.034931], 203 | [0.979666, 0.565057, 0.031409], 204 | [0.980824, 0.572209, 0.028508], 205 | [0.981895, 0.579392, 0.02625], 206 | [0.982881, 0.586606, 0.024661], 207 | [0.983779, 0.593849, 0.02377], 208 | [0.984591, 0.601122, 0.023606], 209 | [0.985315, 0.608422, 0.024202], 210 | [0.985952, 0.61575, 0.025592], 211 | [0.986502, 0.623105, 0.027814], 212 | [0.986964, 0.630485, 0.030908], 213 | [0.987337, 0.63789, 0.034916], 214 | [0.987622, 0.64532, 0.039886], 215 | [0.987819, 0.652773, 0.045581], 216 | [0.987926, 0.66025, 0.05175], 217 | [0.987945, 0.667748, 0.058329], 218 | [0.987874, 0.675267, 0.065257], 219 | [0.987714, 0.682807, 0.072489], 220 | [0.987464, 0.690366, 0.07999], 221 | [0.987124, 0.697944, 0.087731], 222 | [0.986694, 0.70554, 0.095694], 223 | [0.986175, 0.713153, 0.103863], 224 | [0.985566, 0.720782, 0.112229], 225 | [0.984865, 0.728427, 0.120785], 226 | [0.984075, 0.736087, 0.129527], 227 | [0.983196, 0.743758, 0.138453], 228 | [0.982228, 0.751442, 0.147565], 229 | [0.981173, 0.759135, 0.156863], 230 | [0.980032, 0.766837, 0.166353], 231 | [0.978806, 0.774545, 0.176037], 232 | [0.977497, 0.782258, 0.185923], 233 | [0.976108, 0.789974, 0.196018], 234 | [0.974638, 0.797692, 0.206332], 235 | [0.973088, 0.805409, 0.216877], 236 | [0.971468, 
0.813122, 0.227658], 237 | [0.969783, 0.820825, 0.238686], 238 | [0.968041, 0.828515, 0.249972], 239 | [0.966243, 0.836191, 0.261534], 240 | [0.964394, 0.843848, 0.273391], 241 | [0.962517, 0.851476, 0.285546], 242 | [0.960626, 0.859069, 0.29801], 243 | [0.95872, 0.866624, 0.31082], 244 | [0.956834, 0.874129, 0.323974], 245 | [0.954997, 0.881569, 0.337475], 246 | [0.953215, 0.888942, 0.351369], 247 | [0.951546, 0.896226, 0.365627], 248 | [0.950018, 0.903409, 0.380271], 249 | [0.948683, 0.910473, 0.395289], 250 | [0.947594, 0.917399, 0.410665], 251 | [0.946809, 0.924168, 0.426373], 252 | [0.946392, 0.930761, 0.442367], 253 | [0.946403, 0.937159, 0.458592], 254 | [0.946903, 0.943348, 0.47497], 255 | [0.947937, 0.949318, 0.491426], 256 | [0.949545, 0.955063, 0.50786], 257 | [0.95174, 0.960587, 0.524203], 258 | [0.954529, 0.965896, 0.540361], 259 | [0.957896, 0.971003, 0.556275], 260 | [0.961812, 0.975924, 0.571925], 261 | [0.966249, 0.980678, 0.587206], 262 | [0.971162, 0.985282, 0.602154], 263 | [0.976511, 0.989753, 0.61676], 264 | [0.982257, 0.994109, 0.631017], 265 | [0.988362, 0.998364, 0.644924] 266 | ]) -------------------------------------------------------------------------------- /src/deepspectrumlite/lib/data/plot/color_maps/magma.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from .abstract_colormap import AbstractColorMap 3 | 4 | class MagmaColorMap(AbstractColorMap): 5 | __metaclass__ = abc.ABCMeta 6 | 7 | def __init__(self): 8 | super().__init__() 9 | self.set_color_map([ 10 | [0.001462, 0.000466, 0.013866], 11 | [0.002258, 0.001295, 0.018331], 12 | [0.003279, 0.002305, 0.023708], 13 | [0.004512, 0.00349, 0.029965], 14 | [0.00595, 0.004843, 0.03713], 15 | [0.007588, 0.006356, 0.044973], 16 | [0.009426, 0.008022, 0.052844], 17 | [0.011465, 0.009828, 0.06075], 18 | [0.013708, 0.011771, 0.068667], 19 | [0.016156, 0.01384, 0.076603], 20 | [0.018815, 0.016026, 0.084584], 21 | [0.021692, 0.01832, 
0.09261], 22 | [0.024792, 0.020715, 0.100676], 23 | [0.028123, 0.023201, 0.108787], 24 | [0.031696, 0.025765, 0.116965], 25 | [0.03552, 0.028397, 0.125209], 26 | [0.039608, 0.03109, 0.133515], 27 | [0.04383, 0.03383, 0.141886], 28 | [0.048062, 0.036607, 0.150327], 29 | [0.05232, 0.039407, 0.158841], 30 | [0.056615, 0.04216, 0.167446], 31 | [0.060949, 0.044794, 0.176129], 32 | [0.06533, 0.047318, 0.184892], 33 | [0.069764, 0.049726, 0.193735], 34 | [0.074257, 0.052017, 0.20266], 35 | [0.078815, 0.054184, 0.211667], 36 | [0.083446, 0.056225, 0.220755], 37 | [0.088155, 0.058133, 0.229922], 38 | [0.092949, 0.059904, 0.239164], 39 | [0.097833, 0.061531, 0.248477], 40 | [0.102815, 0.06301, 0.257854], 41 | [0.107899, 0.064335, 0.267289], 42 | [0.113094, 0.065492, 0.276784], 43 | [0.118405, 0.066479, 0.286321], 44 | [0.123833, 0.067295, 0.295879], 45 | [0.12938, 0.067935, 0.305443], 46 | [0.135053, 0.068391, 0.315], 47 | [0.140858, 0.068654, 0.324538], 48 | [0.146785, 0.068738, 0.334011], 49 | [0.152839, 0.068637, 0.343404], 50 | [0.159018, 0.068354, 0.352688], 51 | [0.165308, 0.067911, 0.361816], 52 | [0.171713, 0.067305, 0.370771], 53 | [0.178212, 0.066576, 0.379497], 54 | [0.184801, 0.065732, 0.387973], 55 | [0.19146, 0.064818, 0.396152], 56 | [0.198177, 0.063862, 0.404009], 57 | [0.204935, 0.062907, 0.411514], 58 | [0.211718, 0.061992, 0.418647], 59 | [0.218512, 0.061158, 0.425392], 60 | [0.225302, 0.060445, 0.431742], 61 | [0.232077, 0.059889, 0.437695], 62 | [0.238826, 0.059517, 0.443256], 63 | [0.245543, 0.059352, 0.448436], 64 | [0.25222, 0.059415, 0.453248], 65 | [0.258857, 0.059706, 0.45771], 66 | [0.265447, 0.060237, 0.46184], 67 | [0.271994, 0.060994, 0.46566], 68 | [0.278493, 0.061978, 0.46919], 69 | [0.284951, 0.063168, 0.472451], 70 | [0.291366, 0.064553, 0.475462], 71 | [0.29774, 0.066117, 0.478243], 72 | [0.304081, 0.067835, 0.480812], 73 | [0.310382, 0.069702, 0.483186], 74 | [0.316654, 0.07169, 0.48538], 75 | [0.322899, 0.073782, 0.487408], 76 | 
[0.329114, 0.075972, 0.489287], 77 | [0.335308, 0.078236, 0.491024], 78 | [0.341482, 0.080564, 0.492631], 79 | [0.347636, 0.082946, 0.494121], 80 | [0.353773, 0.085373, 0.495501], 81 | [0.359898, 0.087831, 0.496778], 82 | [0.366012, 0.090314, 0.49796], 83 | [0.372116, 0.092816, 0.499053], 84 | [0.378211, 0.095332, 0.500067], 85 | [0.384299, 0.097855, 0.501002], 86 | [0.390384, 0.100379, 0.501864], 87 | [0.396467, 0.102902, 0.502658], 88 | [0.402548, 0.10542, 0.503386], 89 | [0.408629, 0.10793, 0.504052], 90 | [0.414709, 0.110431, 0.504662], 91 | [0.420791, 0.11292, 0.505215], 92 | [0.426877, 0.115395, 0.505714], 93 | [0.432967, 0.117855, 0.50616], 94 | [0.439062, 0.120298, 0.506555], 95 | [0.445163, 0.122724, 0.506901], 96 | [0.451271, 0.125132, 0.507198], 97 | [0.457386, 0.127522, 0.507448], 98 | [0.463508, 0.129893, 0.507652], 99 | [0.46964, 0.132245, 0.507809], 100 | [0.47578, 0.134577, 0.507921], 101 | [0.481929, 0.136891, 0.507989], 102 | [0.488088, 0.139186, 0.508011], 103 | [0.494258, 0.141462, 0.507988], 104 | [0.500438, 0.143719, 0.50792], 105 | [0.506629, 0.145958, 0.507806], 106 | [0.512831, 0.148179, 0.507648], 107 | [0.519045, 0.150383, 0.507443], 108 | [0.52527, 0.152569, 0.507192], 109 | [0.531507, 0.154739, 0.506895], 110 | [0.537755, 0.156894, 0.506551], 111 | [0.544015, 0.159033, 0.506159], 112 | [0.550287, 0.161158, 0.505719], 113 | [0.556571, 0.163269, 0.50523], 114 | [0.562866, 0.165368, 0.504692], 115 | [0.569172, 0.167454, 0.504105], 116 | [0.57549, 0.16953, 0.503466], 117 | [0.581819, 0.171596, 0.502777], 118 | [0.588158, 0.173652, 0.502035], 119 | [0.594508, 0.175701, 0.501241], 120 | [0.600868, 0.177743, 0.500394], 121 | [0.607238, 0.179779, 0.499492], 122 | [0.613617, 0.181811, 0.498536], 123 | [0.620005, 0.18384, 0.497524], 124 | [0.626401, 0.185867, 0.496456], 125 | [0.632805, 0.187893, 0.495332], 126 | [0.639216, 0.189921, 0.49415], 127 | [0.645633, 0.191952, 0.49291], 128 | [0.652056, 0.193986, 0.491611], 129 | [0.658483, 0.196027, 
0.490253], 130 | [0.664915, 0.198075, 0.488836], 131 | [0.671349, 0.200133, 0.487358], 132 | [0.677786, 0.202203, 0.485819], 133 | [0.684224, 0.204286, 0.484219], 134 | [0.690661, 0.206384, 0.482558], 135 | [0.697098, 0.208501, 0.480835], 136 | [0.703532, 0.210638, 0.479049], 137 | [0.709962, 0.212797, 0.477201], 138 | [0.716387, 0.214982, 0.47529], 139 | [0.722805, 0.217194, 0.473316], 140 | [0.729216, 0.219437, 0.471279], 141 | [0.735616, 0.221713, 0.46918], 142 | [0.742004, 0.224025, 0.467018], 143 | [0.748378, 0.226377, 0.464794], 144 | [0.754737, 0.228772, 0.462509], 145 | [0.761077, 0.231214, 0.460162], 146 | [0.767398, 0.233705, 0.457755], 147 | [0.773695, 0.236249, 0.455289], 148 | [0.779968, 0.238851, 0.452765], 149 | [0.786212, 0.241514, 0.450184], 150 | [0.792427, 0.244242, 0.447543], 151 | [0.798608, 0.24704, 0.444848], 152 | [0.804752, 0.249911, 0.442102], 153 | [0.810855, 0.252861, 0.439305], 154 | [0.816914, 0.255895, 0.436461], 155 | [0.822926, 0.259016, 0.433573], 156 | [0.828886, 0.262229, 0.430644], 157 | [0.834791, 0.26554, 0.427671], 158 | [0.840636, 0.268953, 0.424666], 159 | [0.846416, 0.272473, 0.421631], 160 | [0.852126, 0.276106, 0.418573], 161 | [0.857763, 0.279857, 0.415496], 162 | [0.86332, 0.283729, 0.412403], 163 | [0.868793, 0.287728, 0.409303], 164 | [0.874176, 0.291859, 0.406205], 165 | [0.879464, 0.296125, 0.403118], 166 | [0.884651, 0.30053, 0.400047], 167 | [0.889731, 0.305079, 0.397002], 168 | [0.8947, 0.309773, 0.393995], 169 | [0.899552, 0.314616, 0.391037], 170 | [0.904281, 0.31961, 0.388137], 171 | [0.908884, 0.324755, 0.385308], 172 | [0.913354, 0.330052, 0.382563], 173 | [0.917689, 0.3355, 0.379915], 174 | [0.921884, 0.341098, 0.377376], 175 | [0.925937, 0.346844, 0.374959], 176 | [0.929845, 0.352734, 0.372677], 177 | [0.933606, 0.358764, 0.370541], 178 | [0.937221, 0.364929, 0.368567], 179 | [0.940687, 0.371224, 0.366762], 180 | [0.944006, 0.377643, 0.365136], 181 | [0.94718, 0.384178, 0.363701], 182 | [0.95021, 0.39082, 
0.362468], 183 | [0.953099, 0.397563, 0.361438], 184 | [0.955849, 0.4044, 0.360619], 185 | [0.958464, 0.411324, 0.360014], 186 | [0.960949, 0.418323, 0.35963], 187 | [0.96331, 0.42539, 0.359469], 188 | [0.965549, 0.432519, 0.359529], 189 | [0.967671, 0.439703, 0.35981], 190 | [0.96968, 0.446936, 0.360311], 191 | [0.971582, 0.45421, 0.36103], 192 | [0.973381, 0.46152, 0.361965], 193 | [0.975082, 0.468861, 0.363111], 194 | [0.97669, 0.476226, 0.364466], 195 | [0.97821, 0.483612, 0.366025], 196 | [0.979645, 0.491014, 0.367783], 197 | [0.981, 0.498428, 0.369734], 198 | [0.982279, 0.505851, 0.371874], 199 | [0.983485, 0.51328, 0.374198], 200 | [0.984622, 0.520713, 0.376698], 201 | [0.985693, 0.528148, 0.379371], 202 | [0.9867, 0.535582, 0.38221], 203 | [0.987646, 0.543015, 0.38521], 204 | [0.988533, 0.550446, 0.388365], 205 | [0.989363, 0.557873, 0.391671], 206 | [0.990138, 0.565296, 0.395122], 207 | [0.990871, 0.572706, 0.398714], 208 | [0.991558, 0.580107, 0.402441], 209 | [0.992196, 0.587502, 0.406299], 210 | [0.992785, 0.594891, 0.410283], 211 | [0.993326, 0.602275, 0.41439], 212 | [0.993834, 0.609644, 0.418613], 213 | [0.994309, 0.616999, 0.42295], 214 | [0.994738, 0.62435, 0.427397], 215 | [0.995122, 0.631696, 0.431951], 216 | [0.99548, 0.639027, 0.436607], 217 | [0.99581, 0.646344, 0.441361], 218 | [0.996096, 0.653659, 0.446213], 219 | [0.996341, 0.660969, 0.45116], 220 | [0.99658, 0.668256, 0.456192], 221 | [0.996775, 0.675541, 0.461314], 222 | [0.996925, 0.682828, 0.466526], 223 | [0.997077, 0.690088, 0.471811], 224 | [0.997186, 0.697349, 0.477182], 225 | [0.997254, 0.704611, 0.482635], 226 | [0.997325, 0.711848, 0.488154], 227 | [0.997351, 0.719089, 0.493755], 228 | [0.997351, 0.726324, 0.499428], 229 | [0.997341, 0.733545, 0.505167], 230 | [0.997285, 0.740772, 0.510983], 231 | [0.997228, 0.747981, 0.516859], 232 | [0.997138, 0.75519, 0.522806], 233 | [0.997019, 0.762398, 0.528821], 234 | [0.996898, 0.769591, 0.534892], 235 | [0.996727, 0.776795, 0.541039], 
236 | [0.996571, 0.783977, 0.547233], 237 | [0.996369, 0.791167, 0.553499], 238 | [0.996162, 0.798348, 0.55982], 239 | [0.995932, 0.805527, 0.566202], 240 | [0.99568, 0.812706, 0.572645], 241 | [0.995424, 0.819875, 0.57914], 242 | [0.995131, 0.827052, 0.585701], 243 | [0.994851, 0.834213, 0.592307], 244 | [0.994524, 0.841387, 0.598983], 245 | [0.994222, 0.84854, 0.605696], 246 | [0.993866, 0.855711, 0.612482], 247 | [0.993545, 0.862859, 0.619299], 248 | [0.99317, 0.870024, 0.626189], 249 | [0.992831, 0.877168, 0.633109], 250 | [0.99244, 0.88433, 0.640099], 251 | [0.992089, 0.89147, 0.647116], 252 | [0.991688, 0.898627, 0.654202], 253 | [0.991332, 0.905763, 0.661309], 254 | [0.99093, 0.912915, 0.668481], 255 | [0.99057, 0.920049, 0.675675], 256 | [0.990175, 0.927196, 0.682926], 257 | [0.989815, 0.934329, 0.690198], 258 | [0.989434, 0.94147, 0.697519], 259 | [0.989077, 0.948604, 0.704863], 260 | [0.988717, 0.955742, 0.712242], 261 | [0.988367, 0.962878, 0.719649], 262 | [0.988033, 0.970012, 0.727077], 263 | [0.987691, 0.977154, 0.734536], 264 | [0.987387, 0.984288, 0.742002], 265 | [0.987053, 0.991438, 0.749504] 266 | ]) 267 | -------------------------------------------------------------------------------- /src/deepspectrumlite/lib/data/plot/color_maps/plasma.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from .abstract_colormap import AbstractColorMap 3 | 4 | class PlasmaColorMap(AbstractColorMap): 5 | __metaclass__ = abc.ABCMeta 6 | 7 | def __init__(self): 8 | super().__init__() 9 | self.set_color_map([ 10 | [0.050383, 0.029803, 0.527975], 11 | [0.063536, 0.028426, 0.533124], 12 | [0.075353, 0.027206, 0.538007], 13 | [0.086222, 0.026125, 0.542658], 14 | [0.096379, 0.025165, 0.547103], 15 | [0.10598, 0.024309, 0.551368], 16 | [0.115124, 0.023556, 0.555468], 17 | [0.123903, 0.022878, 0.559423], 18 | [0.132381, 0.022258, 0.56325], 19 | [0.140603, 0.021687, 0.566959], 20 | [0.148607, 0.021154, 0.570562], 21 | 
[0.156421, 0.020651, 0.574065], 22 | [0.16407, 0.020171, 0.577478], 23 | [0.171574, 0.019706, 0.580806], 24 | [0.17895, 0.019252, 0.584054], 25 | [0.186213, 0.018803, 0.587228], 26 | [0.193374, 0.018354, 0.59033], 27 | [0.200445, 0.017902, 0.593364], 28 | [0.207435, 0.017442, 0.596333], 29 | [0.21435, 0.016973, 0.599239], 30 | [0.221197, 0.016497, 0.602083], 31 | [0.227983, 0.016007, 0.604867], 32 | [0.234715, 0.015502, 0.607592], 33 | [0.241396, 0.014979, 0.610259], 34 | [0.248032, 0.014439, 0.612868], 35 | [0.254627, 0.013882, 0.615419], 36 | [0.261183, 0.013308, 0.617911], 37 | [0.267703, 0.012716, 0.620346], 38 | [0.274191, 0.012109, 0.622722], 39 | [0.280648, 0.011488, 0.625038], 40 | [0.287076, 0.010855, 0.627295], 41 | [0.293478, 0.010213, 0.62949], 42 | [0.299855, 0.009561, 0.631624], 43 | [0.30621, 0.008902, 0.633694], 44 | [0.312543, 0.008239, 0.6357], 45 | [0.318856, 0.007576, 0.63764], 46 | [0.32515, 0.006915, 0.639512], 47 | [0.331426, 0.006261, 0.641316], 48 | [0.337683, 0.005618, 0.643049], 49 | [0.343925, 0.004991, 0.64471], 50 | [0.35015, 0.004382, 0.646298], 51 | [0.356359, 0.003798, 0.64781], 52 | [0.362553, 0.003243, 0.649245], 53 | [0.368733, 0.002724, 0.650601], 54 | [0.374897, 0.002245, 0.651876], 55 | [0.381047, 0.001814, 0.653068], 56 | [0.387183, 0.001434, 0.654177], 57 | [0.393304, 0.001114, 0.655199], 58 | [0.399411, 0.000859, 0.656133], 59 | [0.405503, 0.000678, 0.656977], 60 | [0.41158, 0.000577, 0.65773], 61 | [0.417642, 0.000564, 0.65839], 62 | [0.423689, 0.000646, 0.658956], 63 | [0.429719, 0.000831, 0.659425], 64 | [0.435734, 0.001127, 0.659797], 65 | [0.441732, 0.00154, 0.660069], 66 | [0.447714, 0.00208, 0.66024], 67 | [0.453677, 0.002755, 0.66031], 68 | [0.459623, 0.003574, 0.660277], 69 | [0.46555, 0.004545, 0.660139], 70 | [0.471457, 0.005678, 0.659897], 71 | [0.477344, 0.00698, 0.659549], 72 | [0.48321, 0.00846, 0.659095], 73 | [0.489055, 0.010127, 0.658534], 74 | [0.494877, 0.01199, 0.657865], 75 | [0.500678, 0.014055, 
0.657088], 76 | [0.506454, 0.016333, 0.656202], 77 | [0.512206, 0.018833, 0.655209], 78 | [0.517933, 0.021563, 0.654109], 79 | [0.523633, 0.024532, 0.652901], 80 | [0.529306, 0.027747, 0.651586], 81 | [0.534952, 0.031217, 0.650165], 82 | [0.54057, 0.03495, 0.64864], 83 | [0.546157, 0.038954, 0.64701], 84 | [0.551715, 0.043136, 0.645277], 85 | [0.557243, 0.047331, 0.643443], 86 | [0.562738, 0.051545, 0.641509], 87 | [0.568201, 0.055778, 0.639477], 88 | [0.573632, 0.060028, 0.637349], 89 | [0.579029, 0.064296, 0.635126], 90 | [0.584391, 0.068579, 0.632812], 91 | [0.589719, 0.072878, 0.630408], 92 | [0.595011, 0.07719, 0.627917], 93 | [0.600266, 0.081516, 0.625342], 94 | [0.605485, 0.085854, 0.622686], 95 | [0.610667, 0.090204, 0.619951], 96 | [0.615812, 0.094564, 0.61714], 97 | [0.620919, 0.098934, 0.614257], 98 | [0.625987, 0.103312, 0.611305], 99 | [0.631017, 0.107699, 0.608287], 100 | [0.636008, 0.112092, 0.605205], 101 | [0.640959, 0.116492, 0.602065], 102 | [0.645872, 0.120898, 0.598867], 103 | [0.650746, 0.125309, 0.595617], 104 | [0.65558, 0.129725, 0.592317], 105 | [0.660374, 0.134144, 0.588971], 106 | [0.665129, 0.138566, 0.585582], 107 | [0.669845, 0.142992, 0.582154], 108 | [0.674522, 0.147419, 0.578688], 109 | [0.67916, 0.151848, 0.575189], 110 | [0.683758, 0.156278, 0.57166], 111 | [0.688318, 0.160709, 0.568103], 112 | [0.69284, 0.165141, 0.564522], 113 | [0.697324, 0.169573, 0.560919], 114 | [0.701769, 0.174005, 0.557296], 115 | [0.706178, 0.178437, 0.553657], 116 | [0.710549, 0.182868, 0.550004], 117 | [0.714883, 0.187299, 0.546338], 118 | [0.719181, 0.191729, 0.542663], 119 | [0.723444, 0.196158, 0.538981], 120 | [0.72767, 0.200586, 0.535293], 121 | [0.731862, 0.205013, 0.531601], 122 | [0.736019, 0.209439, 0.527908], 123 | [0.740143, 0.213864, 0.524216], 124 | [0.744232, 0.218288, 0.520524], 125 | [0.748289, 0.222711, 0.516834], 126 | [0.752312, 0.227133, 0.513149], 127 | [0.756304, 0.231555, 0.509468], 128 | [0.760264, 0.235976, 0.505794], 129 | 
[0.764193, 0.240396, 0.502126], 130 | [0.76809, 0.244817, 0.498465], 131 | [0.771958, 0.249237, 0.494813], 132 | [0.775796, 0.253658, 0.491171], 133 | [0.779604, 0.258078, 0.487539], 134 | [0.783383, 0.2625, 0.483918], 135 | [0.787133, 0.266922, 0.480307], 136 | [0.790855, 0.271345, 0.476706], 137 | [0.794549, 0.27577, 0.473117], 138 | [0.798216, 0.280197, 0.469538], 139 | [0.801855, 0.284626, 0.465971], 140 | [0.805467, 0.289057, 0.462415], 141 | [0.809052, 0.293491, 0.45887], 142 | [0.812612, 0.297928, 0.455338], 143 | [0.816144, 0.302368, 0.451816], 144 | [0.819651, 0.306812, 0.448306], 145 | [0.823132, 0.311261, 0.444806], 146 | [0.826588, 0.315714, 0.441316], 147 | [0.830018, 0.320172, 0.437836], 148 | [0.833422, 0.324635, 0.434366], 149 | [0.836801, 0.329105, 0.430905], 150 | [0.840155, 0.33358, 0.427455], 151 | [0.843484, 0.338062, 0.424013], 152 | [0.846788, 0.342551, 0.420579], 153 | [0.850066, 0.347048, 0.417153], 154 | [0.853319, 0.351553, 0.413734], 155 | [0.856547, 0.356066, 0.410322], 156 | [0.85975, 0.360588, 0.406917], 157 | [0.862927, 0.365119, 0.403519], 158 | [0.866078, 0.36966, 0.400126], 159 | [0.869203, 0.374212, 0.396738], 160 | [0.872303, 0.378774, 0.393355], 161 | [0.875376, 0.383347, 0.389976], 162 | [0.878423, 0.387932, 0.3866], 163 | [0.881443, 0.392529, 0.383229], 164 | [0.884436, 0.397139, 0.37986], 165 | [0.887402, 0.401762, 0.376494], 166 | [0.89034, 0.406398, 0.37313], 167 | [0.89325, 0.411048, 0.369768], 168 | [0.896131, 0.415712, 0.366407], 169 | [0.898984, 0.420392, 0.363047], 170 | [0.901807, 0.425087, 0.359688], 171 | [0.904601, 0.429797, 0.356329], 172 | [0.907365, 0.434524, 0.35297], 173 | [0.910098, 0.439268, 0.34961], 174 | [0.9128, 0.444029, 0.346251], 175 | [0.915471, 0.448807, 0.34289], 176 | [0.918109, 0.453603, 0.339529], 177 | [0.920714, 0.458417, 0.336166], 178 | [0.923287, 0.463251, 0.332801], 179 | [0.925825, 0.468103, 0.329435], 180 | [0.928329, 0.472975, 0.326067], 181 | [0.930798, 0.477867, 0.322697], 182 | 
[0.933232, 0.48278, 0.319325], 183 | [0.93563, 0.487712, 0.315952], 184 | [0.93799, 0.492667, 0.312575], 185 | [0.940313, 0.497642, 0.309197], 186 | [0.942598, 0.502639, 0.305816], 187 | [0.944844, 0.507658, 0.302433], 188 | [0.947051, 0.512699, 0.299049], 189 | [0.949217, 0.517763, 0.295662], 190 | [0.951344, 0.52285, 0.292275], 191 | [0.953428, 0.52796, 0.288883], 192 | [0.95547, 0.533093, 0.28549], 193 | [0.957469, 0.53825, 0.282096], 194 | [0.959424, 0.543431, 0.278701], 195 | [0.961336, 0.548636, 0.275305], 196 | [0.963203, 0.553865, 0.271909], 197 | [0.965024, 0.559118, 0.268513], 198 | [0.966798, 0.564396, 0.265118], 199 | [0.968526, 0.5697, 0.261721], 200 | [0.970205, 0.575028, 0.258325], 201 | [0.971835, 0.580382, 0.254931], 202 | [0.973416, 0.585761, 0.25154], 203 | [0.974947, 0.591165, 0.248151], 204 | [0.976428, 0.596595, 0.244767], 205 | [0.977856, 0.602051, 0.241387], 206 | [0.979233, 0.607532, 0.238013], 207 | [0.980556, 0.613039, 0.234646], 208 | [0.981826, 0.618572, 0.231287], 209 | [0.983041, 0.624131, 0.227937], 210 | [0.984199, 0.629718, 0.224595], 211 | [0.985301, 0.63533, 0.221265], 212 | [0.986345, 0.640969, 0.217948], 213 | [0.987332, 0.646633, 0.214648], 214 | [0.98826, 0.652325, 0.211364], 215 | [0.989128, 0.658043, 0.2081], 216 | [0.989935, 0.663787, 0.204859], 217 | [0.990681, 0.669558, 0.201642], 218 | [0.991365, 0.675355, 0.198453], 219 | [0.991985, 0.681179, 0.195295], 220 | [0.992541, 0.68703, 0.19217], 221 | [0.993032, 0.692907, 0.189084], 222 | [0.993456, 0.69881, 0.186041], 223 | [0.993814, 0.704741, 0.183043], 224 | [0.994103, 0.710698, 0.180097], 225 | [0.994324, 0.716681, 0.177208], 226 | [0.994474, 0.722691, 0.174381], 227 | [0.994553, 0.728728, 0.171622], 228 | [0.994561, 0.734791, 0.168938], 229 | [0.994495, 0.74088, 0.166335], 230 | [0.994355, 0.746995, 0.163821], 231 | [0.994141, 0.753137, 0.161404], 232 | [0.993851, 0.759304, 0.159092], 233 | [0.993482, 0.765499, 0.156891], 234 | [0.993033, 0.77172, 0.154808], 235 | 
[0.992505, 0.777967, 0.152855], 236 | [0.991897, 0.784239, 0.151042], 237 | [0.991209, 0.790537, 0.149377], 238 | [0.990439, 0.796859, 0.14787], 239 | [0.989587, 0.803205, 0.146529], 240 | [0.988648, 0.809579, 0.145357], 241 | [0.987621, 0.815978, 0.144363], 242 | [0.986509, 0.822401, 0.143557], 243 | [0.985314, 0.828846, 0.142945], 244 | [0.984031, 0.835315, 0.142528], 245 | [0.982653, 0.841812, 0.142303], 246 | [0.98119, 0.848329, 0.142279], 247 | [0.979644, 0.854866, 0.142453], 248 | [0.977995, 0.861432, 0.142808], 249 | [0.976265, 0.868016, 0.143351], 250 | [0.974443, 0.874622, 0.144061], 251 | [0.97253, 0.88125, 0.144923], 252 | [0.970533, 0.887896, 0.145919], 253 | [0.968443, 0.894564, 0.147014], 254 | [0.966271, 0.901249, 0.14818], 255 | [0.964021, 0.90795, 0.14937], 256 | [0.961681, 0.914672, 0.15052], 257 | [0.959276, 0.921407, 0.151566], 258 | [0.956808, 0.928152, 0.152409], 259 | [0.954287, 0.934908, 0.152921], 260 | [0.951726, 0.941671, 0.152925], 261 | [0.949151, 0.948435, 0.152178], 262 | [0.946602, 0.95519, 0.150328], 263 | [0.944152, 0.961916, 0.146861], 264 | [0.941896, 0.96859, 0.140956], 265 | [0.940015, 0.975158, 0.131326] 266 | ]) 267 | -------------------------------------------------------------------------------- /src/deepspectrumlite/lib/data/plot/color_maps/viridis.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from .abstract_colormap import AbstractColorMap 3 | 4 | class ViridisColorMap(AbstractColorMap): 5 | __metaclass__ = abc.ABCMeta 6 | 7 | def __init__(self): 8 | super().__init__() 9 | self.set_color_map([ 10 | [0.267004, 0.004874, 0.329415], 11 | [0.26851, 0.009605, 0.335427], 12 | [0.269944, 0.014625, 0.341379], 13 | [0.271305, 0.019942, 0.347269], 14 | [0.272594, 0.025563, 0.353093], 15 | [0.273809, 0.031497, 0.358853], 16 | [0.274952, 0.037752, 0.364543], 17 | [0.276022, 0.044167, 0.370164], 18 | [0.277018, 0.050344, 0.375715], 19 | [0.277941, 0.056324, 0.381191], 20 | 
[0.278791, 0.062145, 0.386592], 21 | [0.279566, 0.067836, 0.391917], 22 | [0.280267, 0.073417, 0.397163], 23 | [0.280894, 0.078907, 0.402329], 24 | [0.281446, 0.08432, 0.407414], 25 | [0.281924, 0.089666, 0.412415], 26 | [0.282327, 0.094955, 0.417331], 27 | [0.282656, 0.100196, 0.42216], 28 | [0.28291, 0.105393, 0.426902], 29 | [0.283091, 0.110553, 0.431554], 30 | [0.283197, 0.11568, 0.436115], 31 | [0.283229, 0.120777, 0.440584], 32 | [0.283187, 0.125848, 0.44496], 33 | [0.283072, 0.130895, 0.449241], 34 | [0.282884, 0.13592, 0.453427], 35 | [0.282623, 0.140926, 0.457517], 36 | [0.28229, 0.145912, 0.46151], 37 | [0.281887, 0.150881, 0.465405], 38 | [0.281412, 0.155834, 0.469201], 39 | [0.280868, 0.160771, 0.472899], 40 | [0.280255, 0.165693, 0.476498], 41 | [0.279574, 0.170599, 0.479997], 42 | [0.278826, 0.17549, 0.483397], 43 | [0.278012, 0.180367, 0.486697], 44 | [0.277134, 0.185228, 0.489898], 45 | [0.276194, 0.190074, 0.493001], 46 | [0.275191, 0.194905, 0.496005], 47 | [0.274128, 0.199721, 0.498911], 48 | [0.273006, 0.20452, 0.501721], 49 | [0.271828, 0.209303, 0.504434], 50 | [0.270595, 0.214069, 0.507052], 51 | [0.269308, 0.218818, 0.509577], 52 | [0.267968, 0.223549, 0.512008], 53 | [0.26658, 0.228262, 0.514349], 54 | [0.265145, 0.232956, 0.516599], 55 | [0.263663, 0.237631, 0.518762], 56 | [0.262138, 0.242286, 0.520837], 57 | [0.260571, 0.246922, 0.522828], 58 | [0.258965, 0.251537, 0.524736], 59 | [0.257322, 0.25613, 0.526563], 60 | [0.255645, 0.260703, 0.528312], 61 | [0.253935, 0.265254, 0.529983], 62 | [0.252194, 0.269783, 0.531579], 63 | [0.250425, 0.27429, 0.533103], 64 | [0.248629, 0.278775, 0.534556], 65 | [0.246811, 0.283237, 0.535941], 66 | [0.244972, 0.287675, 0.53726], 67 | [0.243113, 0.292092, 0.538516], 68 | [0.241237, 0.296485, 0.539709], 69 | [0.239346, 0.300855, 0.540844], 70 | [0.237441, 0.305202, 0.541921], 71 | [0.235526, 0.309527, 0.542944], 72 | [0.233603, 0.313828, 0.543914], 73 | [0.231674, 0.318106, 0.544834], 74 | [0.229739, 
0.322361, 0.545706], 75 | [0.227802, 0.326594, 0.546532], 76 | [0.225863, 0.330805, 0.547314], 77 | [0.223925, 0.334994, 0.548053], 78 | [0.221989, 0.339161, 0.548752], 79 | [0.220057, 0.343307, 0.549413], 80 | [0.21813, 0.347432, 0.550038], 81 | [0.21621, 0.351535, 0.550627], 82 | [0.214298, 0.355619, 0.551184], 83 | [0.212395, 0.359683, 0.55171], 84 | [0.210503, 0.363727, 0.552206], 85 | [0.208623, 0.367752, 0.552675], 86 | [0.206756, 0.371758, 0.553117], 87 | [0.204903, 0.375746, 0.553533], 88 | [0.203063, 0.379716, 0.553925], 89 | [0.201239, 0.38367, 0.554294], 90 | [0.19943, 0.387607, 0.554642], 91 | [0.197636, 0.391528, 0.554969], 92 | [0.19586, 0.395433, 0.555276], 93 | [0.1941, 0.399323, 0.555565], 94 | [0.192357, 0.403199, 0.555836], 95 | [0.190631, 0.407061, 0.556089], 96 | [0.188923, 0.41091, 0.556326], 97 | [0.187231, 0.414746, 0.556547], 98 | [0.185556, 0.41857, 0.556753], 99 | [0.183898, 0.422383, 0.556944], 100 | [0.182256, 0.426184, 0.55712], 101 | [0.180629, 0.429975, 0.557282], 102 | [0.179019, 0.433756, 0.55743], 103 | [0.177423, 0.437527, 0.557565], 104 | [0.175841, 0.44129, 0.557685], 105 | [0.174274, 0.445044, 0.557792], 106 | [0.172719, 0.448791, 0.557885], 107 | [0.171176, 0.45253, 0.557965], 108 | [0.169646, 0.456262, 0.55803], 109 | [0.168126, 0.459988, 0.558082], 110 | [0.166617, 0.463708, 0.558119], 111 | [0.165117, 0.467423, 0.558141], 112 | [0.163625, 0.471133, 0.558148], 113 | [0.162142, 0.474838, 0.55814], 114 | [0.160665, 0.47854, 0.558115], 115 | [0.159194, 0.482237, 0.558073], 116 | [0.157729, 0.485932, 0.558013], 117 | [0.15627, 0.489624, 0.557936], 118 | [0.154815, 0.493313, 0.55784], 119 | [0.153364, 0.497, 0.557724], 120 | [0.151918, 0.500685, 0.557587], 121 | [0.150476, 0.504369, 0.55743], 122 | [0.149039, 0.508051, 0.55725], 123 | [0.147607, 0.511733, 0.557049], 124 | [0.14618, 0.515413, 0.556823], 125 | [0.144759, 0.519093, 0.556572], 126 | [0.143343, 0.522773, 0.556295], 127 | [0.141935, 0.526453, 0.555991], 128 | 
[0.140536, 0.530132, 0.555659], 129 | [0.139147, 0.533812, 0.555298], 130 | [0.13777, 0.537492, 0.554906], 131 | [0.136408, 0.541173, 0.554483], 132 | [0.135066, 0.544853, 0.554029], 133 | [0.133743, 0.548535, 0.553541], 134 | [0.132444, 0.552216, 0.553018], 135 | [0.131172, 0.555899, 0.552459], 136 | [0.129933, 0.559582, 0.551864], 137 | [0.128729, 0.563265, 0.551229], 138 | [0.127568, 0.566949, 0.550556], 139 | [0.126453, 0.570633, 0.549841], 140 | [0.125394, 0.574318, 0.549086], 141 | [0.124395, 0.578002, 0.548287], 142 | [0.123463, 0.581687, 0.547445], 143 | [0.122606, 0.585371, 0.546557], 144 | [0.121831, 0.589055, 0.545623], 145 | [0.121148, 0.592739, 0.544641], 146 | [0.120565, 0.596422, 0.543611], 147 | [0.120092, 0.600104, 0.54253], 148 | [0.119738, 0.603785, 0.5414], 149 | [0.119512, 0.607464, 0.540218], 150 | [0.119423, 0.611141, 0.538982], 151 | [0.119483, 0.614817, 0.537692], 152 | [0.119699, 0.61849, 0.536347], 153 | [0.120081, 0.622161, 0.534946], 154 | [0.120638, 0.625828, 0.533488], 155 | [0.12138, 0.629492, 0.531973], 156 | [0.122312, 0.633153, 0.530398], 157 | [0.123444, 0.636809, 0.528763], 158 | [0.12478, 0.640461, 0.527068], 159 | [0.126326, 0.644107, 0.525311], 160 | [0.128087, 0.647749, 0.523491], 161 | [0.130067, 0.651384, 0.521608], 162 | [0.132268, 0.655014, 0.519661], 163 | [0.134692, 0.658636, 0.517649], 164 | [0.137339, 0.662252, 0.515571], 165 | [0.14021, 0.665859, 0.513427], 166 | [0.143303, 0.669459, 0.511215], 167 | [0.146616, 0.67305, 0.508936], 168 | [0.150148, 0.676631, 0.506589], 169 | [0.153894, 0.680203, 0.504172], 170 | [0.157851, 0.683765, 0.501686], 171 | [0.162016, 0.687316, 0.499129], 172 | [0.166383, 0.690856, 0.496502], 173 | [0.170948, 0.694384, 0.493803], 174 | [0.175707, 0.6979, 0.491033], 175 | [0.180653, 0.701402, 0.488189], 176 | [0.185783, 0.704891, 0.485273], 177 | [0.19109, 0.708366, 0.482284], 178 | [0.196571, 0.711827, 0.479221], 179 | [0.202219, 0.715272, 0.476084], 180 | [0.20803, 0.718701, 0.472873], 181 
| [0.214, 0.722114, 0.469588], 182 | [0.220124, 0.725509, 0.466226], 183 | [0.226397, 0.728888, 0.462789], 184 | [0.232815, 0.732247, 0.459277], 185 | [0.239374, 0.735588, 0.455688], 186 | [0.24607, 0.73891, 0.452024], 187 | [0.252899, 0.742211, 0.448284], 188 | [0.259857, 0.745492, 0.444467], 189 | [0.266941, 0.748751, 0.440573], 190 | [0.274149, 0.751988, 0.436601], 191 | [0.281477, 0.755203, 0.432552], 192 | [0.288921, 0.758394, 0.428426], 193 | [0.296479, 0.761561, 0.424223], 194 | [0.304148, 0.764704, 0.419943], 195 | [0.311925, 0.767822, 0.415586], 196 | [0.319809, 0.770914, 0.411152], 197 | [0.327796, 0.77398, 0.40664], 198 | [0.335885, 0.777018, 0.402049], 199 | [0.344074, 0.780029, 0.397381], 200 | [0.35236, 0.783011, 0.392636], 201 | [0.360741, 0.785964, 0.387814], 202 | [0.369214, 0.788888, 0.382914], 203 | [0.377779, 0.791781, 0.377939], 204 | [0.386433, 0.794644, 0.372886], 205 | [0.395174, 0.797475, 0.367757], 206 | [0.404001, 0.800275, 0.362552], 207 | [0.412913, 0.803041, 0.357269], 208 | [0.421908, 0.805774, 0.35191], 209 | [0.430983, 0.808473, 0.346476], 210 | [0.440137, 0.811138, 0.340967], 211 | [0.449368, 0.813768, 0.335384], 212 | [0.458674, 0.816363, 0.329727], 213 | [0.468053, 0.818921, 0.323998], 214 | [0.477504, 0.821444, 0.318195], 215 | [0.487026, 0.823929, 0.312321], 216 | [0.496615, 0.826376, 0.306377], 217 | [0.506271, 0.828786, 0.300362], 218 | [0.515992, 0.831158, 0.294279], 219 | [0.525776, 0.833491, 0.288127], 220 | [0.535621, 0.835785, 0.281908], 221 | [0.545524, 0.838039, 0.275626], 222 | [0.555484, 0.840254, 0.269281], 223 | [0.565498, 0.84243, 0.262877], 224 | [0.575563, 0.844566, 0.256415], 225 | [0.585678, 0.846661, 0.249897], 226 | [0.595839, 0.848717, 0.243329], 227 | [0.606045, 0.850733, 0.236712], 228 | [0.616293, 0.852709, 0.230052], 229 | [0.626579, 0.854645, 0.223353], 230 | [0.636902, 0.856542, 0.21662], 231 | [0.647257, 0.8584, 0.209861], 232 | [0.657642, 0.860219, 0.203082], 233 | [0.668054, 0.861999, 0.196293], 
# Based on the TypeScript implementation at
# https://github.com/alesgenova/colormap


def linear_scale(domain, out_range, value):  # pragma: no cover
    """
    Linearly maps a value from one interval onto another.

    @param domain: pair (d0, d1) describing the input interval
    @param out_range: pair (r0, r1) describing the output interval
    @param value: number or tensor to rescale; d0 maps to r0 and d1 to r1
    @return: the rescaled value

    The division by (d1 - d0) is not guarded; callers must make sure d0 != d1.
    """
    d0, d1 = domain
    r0, r1 = out_range

    return r0 + (r1 - r0) * ((value - d0) / (d1 - d0))


def linear_mixer(value, lower_node_value, upper_node_value):  # pragma: no cover
    """
    Computes the two linear interpolation weights of a value between two nodes.

    @param value: position to interpolate at
    @param lower_node_value: position of the lower node
    @param upper_node_value: position of the upper node
    @return: tuple (weight of the lower node, weight of the upper node); the
             weights sum to 1
    """
    frac = (value - lower_node_value) / (upper_node_value - lower_node_value)
    return 1. - frac, frac


def color_combination(a, X, b, Y):  # pragma: no cover
    """
    Blends two colors channel by channel.

    @param a: weight of color X
    @param X: first color, indexable with 0, 1 and 2
    @param b: weight of color Y
    @param Y: second color, indexable with 0, 1 and 2
    @return: list of three values a * X[c] + b * Y[c]
    """
    return [
        a * X[0] + b * Y[0],
        a * X[1] + b * Y[1],
        a * X[2] + b * Y[2]
    ]


def create_map_from_array(v, color_map):  # pragma: no cover
    """
    Converts a batch of single-channel arrays into three-channel color images.

    Every sample is min/max normalised on its own, scaled to the index range of
    the color map and colored by linear interpolation between the two
    neighbouring color map rows.

    @param v: rank-3 tensor; a channel axis is appended at position 3 and the
              minimum/maximum are reduced over axes 1 to 3, i.e. per entry of
              axis 0
    @param color_map: 2-D table with one row per color and three columns
    @return: tensor shaped like v with an additional last axis of size 3
    """
    channel_position = 3
    # rgb:
    v = tf.expand_dims(v, channel_position)  # add channel axis
    v = tf.repeat(v, 3, axis=channel_position)  # duplicate values to all channels

    v_min = tf.math.reduce_min(v, axis=(1, 2, 3), keepdims=True)
    v_max = tf.math.reduce_max(v, axis=(1, 2, 3), keepdims=True)

    # A constant sample has max == min; dividing by that range would yield
    # 0/0 = NaN for the whole sample. Use a range of 1 in that case so the
    # sample is scaled to 0 and therefore mapped to the first color.
    value_range = v_max - v_min
    value_range = tf.where(tf.math.equal(value_range, 0), tf.ones_like(value_range), value_range)

    # same result as linear_scale(domain=(v_min, v_max), out_range=(0, 1), value=v)
    # whenever the range is non-zero
    scaled_value = (v - v_min) / value_range
    vri_len = color_map.shape[0]
    index_float = (vri_len - 1) * scaled_value

    t1 = tf.math.less_equal(index_float, 0)
    t2 = tf.math.greater_equal(index_float, vri_len - 1)

    # Positions at or beyond the ends take the first/last color directly. All
    # other positions are marked with the sentinels -1 / -2 / -3 in the
    # red / green / blue channel and are filled with interpolated values below.
    result = tf.where(t1, color_map[0], tf.where(t2, color_map[vri_len - 1], [-1, -2, -3]))

    index = tf.math.floor(index_float)
    index2 = tf.where(index >= vri_len - 1, tf.cast(vri_len, dtype=tf.float32) - 1, index + 1.)

    coeff0, coeff1 = linear_mixer(value=index_float, lower_node_value=index, upper_node_value=index2)
    index = tf.cast(index, dtype=tf.int32)
    index2 = tf.cast(index2, dtype=tf.int32)

    # red mask
    v1_r = tf.gather(color_map[:, 0], indices=index)
    v2_r = tf.gather(color_map[:, 0], indices=index2)
    com_r = coeff0 * v1_r + coeff1 * v2_r
    # green mask
    v1_g = tf.gather(color_map[:, 1], indices=index)
    v2_g = tf.gather(color_map[:, 1], indices=index2)
    com_g = coeff0 * v1_g + coeff1 * v2_g
    # blue mask
    v1_b = tf.gather(color_map[:, 2], indices=index)
    v2_b = tf.gather(color_map[:, 2], indices=index2)
    com_b = coeff0 * v1_b + coeff1 * v2_b
    # apply color masks: replace each sentinel with its channel's interpolation
    result = tf.where(tf.math.equal(result, -1), com_r, result)
    result = tf.where(tf.math.equal(result, -2), com_g, result)
    result = tf.where(tf.math.equal(result, -3), com_b, result)

    return result
class HyperParameterList:
    """
    Grid of hyper-parameter values loaded from a JSON configuration file.

    The JSON document is expected to be an object that maps each
    hyper-parameter name to a list of candidate values. Every combination of
    one value per name can be addressed by an iteration number in
    [0, get_max_iteration()); combinations are ordered like
    itertools.product over the lists in file order (the last name varies
    fastest).
    """

    def __init__(self, config_file_name: str):
        """
        @param config_file_name: path of the JSON configuration file
        """
        # the context manager closes the file even if parsing fails
        with open(config_file_name, encoding="utf-8") as f:
            data = json.load(f)

        self._config = data
        self._param_list = {}

        self.load_configuration()

    def get_hparams(self) -> list:
        """
        @return: the HParam objects of all hyper-parameters in configuration order
        """
        return [self._param_list[key] for key in self._config]

    def load_configuration(self) -> None:
        """
        (Re)builds one discrete HParam per configured hyper-parameter.
        """
        self._param_list = {
            key: hp.HParam(key, hp.Discrete(values))
            for key, values in self._config.items()
        }

    def get_max_iteration(self) -> int:
        """
        @return: number of distinct value combinations (product of the list lengths)
        """
        count = 1

        for values in self._config.values():
            count = count * len(values)

        return count

    def _combination(self, iteration_no: int) -> list:
        """
        Decodes an iteration number into (name, value) pairs in configuration order.

        @param iteration_no: index of the combination, 0 <= iteration_no < get_max_iteration()
        @return: list of (hyper-parameter name, selected value)
        @raise ValueError: if iteration_no is outside the valid range
        """
        total = self.get_max_iteration()
        if not 0 <= iteration_no < total:
            raise ValueError(
                f'iteration_no must satisfy 0 <= iteration_no < {total}, got {iteration_no}')

        # Mixed-radix decoding: the last name is the fastest-varying digit.
        # This yields the same element as indexing the full Cartesian product
        # without having to materialise it.
        remainder = iteration_no
        picked = []
        for key in reversed(list(self._config)):
            values = self._config[key]
            remainder, position = divmod(remainder, len(values))
            picked.append((key, values[position]))
        picked.reverse()

        return picked

    def get_values_tensorboard(self, iteration_no: int) -> dict:
        """
        @param iteration_no: index of the combination
        @return: dict mapping each HParam object to its value in that combination
        @raise ValueError: if iteration_no is outside the valid range
        """
        return {self._param_list[key]: value for key, value in self._combination(iteration_no)}

    def get_values(self, iteration_no: int) -> dict:
        """
        @param iteration_no: index of the combination
        @return: dict mapping each hyper-parameter name to its value in that combination
        @raise ValueError: if iteration_no is outside the valid range
        """
        return dict(self._combination(iteration_no))
class TransferBaseModel(Model):
    """
    Transfer-learning model: a selectable backbone network followed by global
    average pooling, one dense layer with dropout and a dense output layer.

    Hyper-parameters read from self.hy_params: 'basemodel_name', 'weights',
    'num_units', 'dropout', optional 'output_activation', 'pre_epochs',
    'epochs', 'batch_size', 'finetune_layer', 'fine_learning_rate' and 'loss'.
    """
    # backbone network instance, set by create_model()
    base_model = None

    def retrain_model(self):
        """
        Prepares the model for fine-tuning and recompiles it.

        If 'finetune_layer' is positive, the backbone is made trainable and its
        first int(layer_count * finetune_layer) layers are frozen again. The
        model is then compiled with the optimizer's learning rate set to
        'fine_learning_rate'.
        """
        if self.hy_params['finetune_layer'] > 0:
            self.base_model.trainable = True

            layer_count = len(self.base_model.layers)
            keep_until = int(layer_count * self.hy_params['finetune_layer'])

            for layer in self.base_model.layers[:keep_until]:
                layer.trainable = False

        optimizer = self.get_optimizer_fn()
        optimizer._set_hyper('learning_rate', self.hy_params['fine_learning_rate'])
        self.get_model().compile(loss=self.hy_params['loss'], optimizer=optimizer,
                                 metrics=self._metrics)

        self.get_model().summary(print_fn=log.info)

    def create_model(self):
        """
        Builds and compiles the model described by the hyper-parameters.

        @raise ValueError: if 'basemodel_name' is not a supported backbone
        """
        hy_params = self.hy_params

        inputs = keras.Input(shape=self.input_shape, dtype=tf.float32)

        # an empty 'weights' entry means: no pre-trained weights
        use_pretrained = hy_params['weights'] != ''
        weights = hy_params['weights'] if use_pretrained else None

        available_models = {
            "vgg16":
                tf.keras.applications.vgg16.VGG16,
            "vgg19":
                tf.keras.applications.vgg19.VGG19,
            "resnet50":
                tf.keras.applications.resnet50.ResNet50,
            "xception":
                tf.keras.applications.xception.Xception,
            # must be the model constructor; the bare `inception_v3` module is
            # not callable
            "inception_v3":
                tf.keras.applications.inception_v3.InceptionV3,
            "densenet121":
                tf.keras.applications.densenet.DenseNet121,
            "densenet169":
                tf.keras.applications.densenet.DenseNet169,
            "densenet201":
                tf.keras.applications.densenet.DenseNet201,
            "mobilenet":
                tf.keras.applications.mobilenet.MobileNet,
            "mobilenet_v2":
                tf.keras.applications.mobilenet_v2.MobileNetV2,
            "nasnet_large":
                tf.keras.applications.nasnet.NASNetLarge,
            "nasnet_mobile":
                tf.keras.applications.nasnet.NASNetMobile,
            "inception_resnet_v2":
                tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
            "squeezenet_v1":
                SqueezeNet,
        }

        model_key = hy_params['basemodel_name']

        if model_key not in available_models:
            raise ValueError(model_key + ' is not implemented')
        self.base_model = available_models[model_key](weights=weights, include_top=False)

        # a backbone with pre-trained weights starts frozen and is called in
        # inference mode; a backbone without weights is trained from the start
        training = not use_pretrained

        self.base_model.trainable = training

        feature_batch_average = tf.keras.layers.GlobalAveragePooling2D()(self.base_model(inputs, training=training))
        flatten = keras.layers.Flatten()(feature_batch_average)
        dense_1 = keras.layers.Dense(hy_params['num_units'], activation=self.get_activation_fn())(flatten)
        dropout_1 = keras.layers.Dropout(rate=hy_params['dropout'])(dense_1)

        activation = 'softmax' if self.prediction_type == 'categorical' else 'linear'
        if 'output_activation' in hy_params:
            activation = hy_params['output_activation']
        predictions = tf.keras.layers.Dense(len(self.data_classes), activation=activation)(dropout_1)

        model = AugmentableModel(inputs=inputs, outputs=predictions, name=hy_params['basemodel_name'])
        model.set_hyper_parameters(hy_params=hy_params)
        self.model = model
        self.compile_model()

    def train(self, train_dataset: tf.data.Dataset, devel_dataset: tf.data.Dataset):
        """
        Trains the model with the given train data.

        The first phase runs for 'pre_epochs' epochs, or for
        'epochs' + 'pre_epochs' epochs when no pre-trained weights are used.
        If pre-trained weights are used and 'finetune_layer' is positive, the
        model is prepared by retrain_model() and trained for further 'epochs'
        epochs.

        @param train_dataset: dataset used for fitting
        @param devel_dataset: dataset used as validation data
        """
        epochs_first = self.hy_params['pre_epochs']
        if self.hy_params['weights'] == '':
            epochs_first = self.hy_params['epochs'] + self.hy_params['pre_epochs']

        history = self.get_model().fit(x=train_dataset, epochs=epochs_first,
                                       batch_size=self.hy_params['batch_size'],
                                       shuffle=True,
                                       validation_data=devel_dataset,
                                       callbacks=self.get_callbacks(), verbose=self.verbose)

        if self.hy_params['weights'] != '' and self.hy_params['finetune_layer'] > 0:
            self.retrain_model()

            # Continue the epoch numbering of the first phase. history.epoch is
            # empty when the first phase ran for zero epochs; start at 0 then
            # instead of failing with an IndexError.
            initial_epoch = history.epoch[-1] if history.epoch else 0

            self.get_model().fit(x=train_dataset, epochs=epochs_first + self.hy_params['epochs'],
                                 initial_epoch=initial_epoch,
                                 batch_size=self.hy_params['batch_size'],
                                 shuffle=True,
                                 validation_data=devel_dataset,
                                 callbacks=self.get_callbacks(), verbose=self.verbose)
# String identifiers for evaluation metrics and the loss value.
# The package's model/__init__.py star-imports this module, so these names are
# also available from the model package.
# NOTE(review): 'mae', 'rmse' and 'mse' presumably abbreviate mean absolute
# error, root mean squared error and mean squared error - confirm at the call sites.
METRIC_ACCURACY = 'accuracy'
METRIC_PRECISION = 'precision'
METRIC_RECALL = 'recall'
METRIC_F_SCORE = 'f1_score'
METRIC_MAE = 'mae'
METRIC_RMSE = 'rmse'
METRIC_MSE = 'mse'
METRIC_LOSS = 'loss'
def ARelu(x, alpha=0.9, beta=2.0):
    """
    Piecewise-linear activation with separate slopes for both signs.

    Positive inputs are multiplied by 1 + sigmoid(beta); negative inputs are
    multiplied by alpha clipped to [0.01, 0.99].

    @param x: input tensor
    @param alpha: slope for negative inputs before clipping
    @param beta: value passed through a sigmoid to obtain the positive gain
    @return: activated tensor
    """
    negative_slope = tf.clip_by_value(alpha, clip_value_min=0.01, clip_value_max=0.99)
    positive_gain = 1 + tf.math.sigmoid(beta)
    positive_part = tf.nn.relu(x) * positive_gain
    negative_part = tf.nn.relu(-x) * negative_slope
    return positive_part - negative_part


# code taken from https://github.com/kobiso/CBAM-tensorflow
# Author: Byung Soo Ko


def se_block(residual, name, ratio=8):
    """
    Squeeze-and-Excitation (SE) block, see https://arxiv.org/abs/1709.01507.

    The input is averaged over axes 1 and 2, passed through a bottleneck of
    two dense layers (channel // ratio units with ReLU, then channel units
    with sigmoid) and the result rescales the input.

    @param residual: feature map; its last axis is used as the channel count
    @param name: name scope for the created ops
    @param ratio: reduction factor of the bottleneck layer
    @return: the rescaled feature map

    NOTE(review): both dense layers use fixed layer names ('bottleneck_fc',
    'recover_fc'); confirm this block is not used more than once per model.
    """
    weight_init = tf.keras.initializers.VarianceScaling()
    zero_bias = tf.constant_initializer(value=0.0)

    with tf.name_scope(name):
        channel = residual.get_shape()[-1]

        # global average pooling
        pooled = tf.reduce_mean(residual, axis=[1, 2], keepdims=True)
        assert pooled.get_shape()[1:] == (1, 1, channel)

        bottleneck_layer = keras.layers.Dense(
            units=channel // ratio,
            activation=tf.nn.relu,
            kernel_initializer=weight_init,
            bias_initializer=zero_bias,
            name='bottleneck_fc')
        gate = bottleneck_layer(pooled)
        assert gate.get_shape()[1:] == (1, 1, channel // ratio)

        recover_layer = keras.layers.Dense(
            units=channel,
            activation=tf.nn.sigmoid,
            kernel_initializer=weight_init,
            bias_initializer=zero_bias,
            name='recover_fc')
        gate = recover_layer(gate)
        assert gate.get_shape()[1:] == (1, 1, channel)

        return residual * gate


def cbam_block(input_feature, name, ratio=8):
    """
    Convolutional Block Attention Module (CBAM), see https://arxiv.org/abs/1807.06521.

    Applies channel attention followed by spatial attention.

    @param input_feature: feature map to refine
    @param name: name scope for the created ops
    @param ratio: reduction factor handed to the channel attention
    @return: the refined feature map
    """
    with tf.name_scope(name):
        refined = channel_attention(input_feature, 'ch_at', ratio)
        refined = spatial_attention(refined, 'sp_at')
    return refined


def channel_attention(input_feature, name, ratio=8):
    """
    Rescales the channels of a feature map by a learned per-channel weight.

    The average and the maximum over axes 1 and 2 are each passed through two
    dense layers (channel // ratio units with ReLU, then channel units); the
    sigmoid of their sum scales the input.

    @param input_feature: feature map; its last axis is used as the channel count
    @param name: name scope for the created ops
    @param ratio: reduction factor of the hidden dense layers
    @return: the rescaled feature map

    NOTE(review): the max-pool branch creates its own two dense layers with
    default initializers, so the two branches do not share weights; the CBAM
    paper describes a shared MLP - confirm this is intended.
    """
    weight_init = tf.keras.initializers.VarianceScaling()
    zero_bias = tf.constant_initializer(value=0.0)

    with tf.name_scope(name):
        channel = input_feature.get_shape()[-1]

        # average-pool branch
        avg_branch = tf.reduce_mean(input_feature, axis=[1, 2], keepdims=True)
        assert avg_branch.get_shape()[1:] == (1, 1, channel)

        avg_hidden = keras.layers.Dense(
            units=channel // ratio,
            activation=tf.nn.relu,
            kernel_initializer=weight_init,
            bias_initializer=zero_bias)
        avg_branch = avg_hidden(avg_branch)
        assert avg_branch.get_shape()[1:] == (1, 1, channel // ratio)

        avg_out = keras.layers.Dense(
            units=channel,
            kernel_initializer=weight_init,
            bias_initializer=zero_bias)
        avg_branch = avg_out(avg_branch)
        assert avg_branch.get_shape()[1:] == (1, 1, channel)

        # max-pool branch
        max_branch = tf.reduce_max(input_feature, axis=[1, 2], keepdims=True)
        assert max_branch.get_shape()[1:] == (1, 1, channel)

        max_hidden = keras.layers.Dense(
            units=channel // ratio,
            activation=tf.nn.relu)
        max_branch = max_hidden(max_branch)
        assert max_branch.get_shape()[1:] == (1, 1, channel // ratio)

        max_out = keras.layers.Dense(
            units=channel)
        max_branch = max_out(max_branch)
        assert max_branch.get_shape()[1:] == (1, 1, channel)

        channel_weights = tf.sigmoid(avg_branch + max_branch, 'sigmoid')

    return input_feature * channel_weights


def spatial_attention(input_feature, name):
    """
    Rescales every position of a feature map by a learned weight.

    The mean and the maximum over axis 3 are concatenated, reduced to one
    channel by a 7x7 convolution without bias and passed through a sigmoid;
    the result scales the input.

    @param input_feature: feature map with the channels on axis 3
    @param name: name scope for the created ops
    @return: the rescaled feature map
    """
    kernel_size = 7
    weight_init = tf.keras.initializers.VarianceScaling()

    with tf.name_scope(name):
        channel_mean = tf.reduce_mean(input_feature, axis=[3], keepdims=True)
        assert channel_mean.get_shape()[-1] == 1
        channel_max = tf.reduce_max(input_feature, axis=[3], keepdims=True)
        assert channel_max.get_shape()[-1] == 1

        stacked = tf.concat([channel_mean, channel_max], 3)
        assert stacked.get_shape()[-1] == 2

        conv = keras.layers.Conv2D(
            filters=1,
            kernel_size=[kernel_size, kernel_size],
            strides=[1, 1],
            padding="same",
            activation=None,
            kernel_initializer=weight_init,
            use_bias=False)
        attention_map = conv(stacked)
        assert attention_map.get_shape()[-1] == 1

        attention_map = tf.sigmoid(attention_map, 'sigmoid')

    return input_feature * attention_map
class AugmentableModel(tf.keras.Model):
    """
    Keras model whose ``train_step`` applies a SapAugment data augmentation policy.

    For every batch the per-sample loss is computed first, samples are ranked by
    that loss and each sample receives an augmentation strength ``lambda`` in
    [0, 1] from a regularised incomplete beta function.  The strength then
    controls probability and magnitude of CutMix and SpecAugment.

    Hu, Ting-yao et al. "SapAugment: Learning A Sample Adaptive Policy for Data Augmentation" (2020).
    @see https://arxiv.org/abs/2011.01156
    """

    def __init__(self, *args, **kwargs):
        super(AugmentableModel, self).__init__(*args, **kwargs)
        # hyper parameter dictionary, filled by set_hyper_parameters()
        self.hy_params = {}
        # number of samples the python loops below iterate over
        self._batch_size = None
        # per-sample augmentation strength, recomputed in every train_step
        self.lambda_sap = None
        # SapAugment policy parameters a and s
        self._sap_augment_a = None
        self._sap_augment_s = None

    def set_hyper_parameters(self, hy_params):
        """
        Stores the hyper parameters and derives batch size and SapAugment settings.

        @param hy_params: dictionary; 'batch_size', 'sap_aug_a' and 'sap_aug_s' are read here,
                          the augmentation keys are read later during training
        """
        self.hy_params = hy_params

        self._batch_size = self.hy_params['batch_size']

        self.lambda_sap = tf.zeros(shape=(self._batch_size,), dtype=tf.float32, name="lamda_sap")

        self._sap_augment_a = self.hy_params['sap_aug_a']
        self._sap_augment_s = self.hy_params['sap_aug_s']

    def set_batch_size(self, batch_size):
        """
        Overrides the batch size used by the per-sample loops.

        @param batch_size: number of samples per batch
        """
        self._batch_size = batch_size

    def set_sap_augment(self, a, s):
        """
        Overrides the SapAugment policy parameters.

        @param a: policy parameter a (alpha = s * (1 - a), beta = s * a)
        @param s: policy parameter s
        """
        self._sap_augment_a = a
        self._sap_augment_s = s

    def train_step(self, data):
        """
        Runs one optimisation step with loss-rank dependent augmentation.

        @param data: batch as delivered by keras (x, y[, sample_weight])
        @return: dictionary mapping metric names to their current results
        """
        data = data_adapter.expand_1d(data)
        x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)

        with backprop.GradientTape() as tape:
            # first pass on the un-augmented batch is only used to rank the samples
            y_pred = self(x, training=False)
            loss_list = []
            # NOTE(review): iterates exactly self._batch_size samples - presumably the
            # input pipeline never yields a smaller (remainder) batch; confirm.
            for i in range(self._batch_size):
                # TODO sample_weight None
                instance_loss = self.compiled_loss(
                    y[i, :], y_pred[i, :], None, regularization_losses=self.losses)
                loss_list.append(instance_loss)

            # sorted_keys[r] is the batch index of the sample with the r-th smallest loss
            sorted_keys = tf.argsort(loss_list, axis=-1, direction='ASCENDING')

            a = self._sap_augment_a
            s = self._sap_augment_s

            alpha = s * (1 - a)
            beta = s * a
            new_lambda = tf.zeros(shape=(self._batch_size, 1))

            for i in range(self._batch_size):
                ranking_relative = (i + 1) / self._batch_size
                # low loss rank -> value close to 1, highest loss -> 0
                beta_inc = 1 - tf.math.betainc(alpha, beta, ranking_relative)  # D = 1 -> 0

                # scatter beta_inc into row j of a (batch_size, 1) column vector
                j = sorted_keys[i]
                p1 = tf.zeros(shape=(j, 1))
                p2 = tf.zeros(shape=(self._batch_size-1-j, 1))
                pf = tf.concat([p1, [(beta_inc,)]], 0)
                pf = tf.concat([pf, p2], 0)
                new_lambda = new_lambda + pf

            self.lambda_sap = new_lambda

            if self.hy_params['augment_cutmix']:
                x, y = self.cutmix(x, y)

            if self.hy_params['augment_specaug']:
                x, y = self.apply_spec_aug(x, y)

            y_pred = self(x, training=True)  # batch size, softmax output
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses)

        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}

    def cutmix(self, images, labels):
        """
        Implements cutmix data augmentation
        Yun, Sangdoo et al. "CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" (2019).

        For every sample a square patch (side length driven by lambda_sap) is
        replaced by the same region of a randomly chosen sample of the batch and
        the labels are mixed in proportion to the area that was actually replaced.

        @param images: batch of images
        @param labels: batch of labels (one hot encoded)
        @return: images, labels
        @see https://arxiv.org/abs/1905.04899
        """
        spectrogram_x = self.input_shape[1]
        spectrogram_y = self.input_shape[2]

        min_augment = self.hy_params['cutmix_min']
        # despite its name this is the range that is added on top of min_augment
        max_augment = self.hy_params['cutmix_max']-self.hy_params['cutmix_min']

        probability_min = self.hy_params['da_prob_min']
        probability_max = self.hy_params['da_prob_max']

        image_list = []
        label_list = []
        for j in range(self._batch_size):
            probability_threshold = probability_min + probability_max * tf.squeeze(self.lambda_sap[j])
            # P is 1 if the sample gets augmented and 0 otherwise
            P = tf.cast(tf.random.uniform([], 0, 1) <= probability_threshold, tf.int32)

            # works as we have square images only
            w = tf.cast(tf.round(min_augment * spectrogram_x + max_augment * spectrogram_x * tf.squeeze(self.lambda_sap[j])), tf.int32) * P

            # k: donor sample, (x, y): centre of the patch
            k = tf.cast(tf.random.uniform([], 0, self._batch_size), tf.int32)
            x = tf.cast(tf.random.uniform([], 0, spectrogram_x), tf.int32)
            y = tf.cast(tf.random.uniform([], 0, spectrogram_y), tf.int32)

            xa = tf.math.maximum(0, x - w // 2)  # xa denotes the start
            xb = tf.math.minimum(spectrogram_x, x + w // 2)  # xb denotes the end

            ya = tf.math.maximum(0, y - w // 2)  # ya denotes the start
            yb = tf.math.minimum(spectrogram_y, y + w // 2)  # yb denotes the end

            # rows xa:xb - left part of sample j, patch of sample k, right part of sample j
            piece = tf.concat([
                images[j, xa:xb, 0:ya, ],
                images[k, xa:xb, ya:yb, ],
                images[j, xa:xb, yb:spectrogram_y, ]
            ], axis=1)

            image = tf.concat([images[j, 0:xa, :, ],
                               piece,
                               images[j, xb:spectrogram_x, :, ]], axis=0)
            image_list.append(image)

            # Mix the labels by the area that was really replaced.  The patch is
            # clipped at the image borders and spans 2 * (w // 2) pixels, so
            # w ** 2 would overstate the share of the donor sample.
            patch_area = tf.cast((xb - xa) * (yb - ya), tf.float32)
            a = patch_area / float(spectrogram_x * spectrogram_y)

            label_1 = labels[j,]
            label_2 = labels[k,]
            label_list.append((1 - a) * label_1 + a * label_2)

        x = tf.reshape(tf.stack(image_list), (self._batch_size, spectrogram_x, spectrogram_y, self.input_shape[3]))
        y = tf.reshape(tf.stack(label_list), (self._batch_size, self.output_shape[1]))

        return x, y

    def apply_spec_aug(self, images, labels):
        """
        Implements SpecAugment
        Park, Daniel et al. "SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition" (2019).

        Frequency and time bands of every sample are zeroed; their width depends
        on lambda_sap.  Afterwards every pixel equal to 0 is replaced by the
        value returned by preprocess_scalar_zero for the configured base model.

        @param images: batch of images
        @param labels: batch of labels
        @return: images, labels
        @see https://arxiv.org/abs/1904.08779
        """
        with tf.name_scope("apply_spec_aug"):
            n = self.input_shape[1]  # x = time
            v = self.input_shape[2]  # y = freq

            frequency_mask_num = self.hy_params['specaug_freq_mask_num']
            time_mask_num = self.hy_params['specaug_time_mask_num']

            # the *_max_augment values hold the range on top of the respective minimum
            freq_min_augment = self.hy_params['specaug_freq_min']
            freq_max_augment = self.hy_params['specaug_freq_max']-self.hy_params['specaug_freq_min']
            time_min_augment = self.hy_params['specaug_time_min']
            time_max_augment = self.hy_params['specaug_time_max']-self.hy_params['specaug_time_min']

            probability_min = self.hy_params['da_prob_min']
            probability_max = self.hy_params['da_prob_max']

            output_mel_spectrogram = []

            for j in range(self._batch_size):
                probability_threshold = probability_min + probability_max * tf.squeeze(self.lambda_sap[j])
                # P is 1 if the sample gets augmented and 0 otherwise
                P = tf.cast(tf.random.uniform([], 0, 1) <= probability_threshold, tf.int32)

                # width of a single frequency / time mask; 0 when P == 0
                f = tf.cast(tf.round((freq_min_augment * v + freq_max_augment * v * tf.squeeze(self.lambda_sap[j])) / frequency_mask_num), tf.int32) * P
                t = tf.cast(tf.round((time_min_augment * n + time_max_augment * n * tf.squeeze(self.lambda_sap[j])) / time_mask_num), tf.int32) * P

                # only sample j is kept, so mask a one element batch instead of
                # multiplying the complete batch for every sample
                tmp = images[j:j + 1]

                for _ in range(frequency_mask_num):
                    f0 = tf.random.uniform([], minval=0, maxval=v - f, dtype=tf.int32)

                    # zeros cover the band [v - f0 - f, v - f0) on the frequency axis
                    mask = tf.concat((tf.ones(shape=(1, n, v - f0 - f, 1)),
                                      tf.zeros(shape=(1, n, f, 1)),
                                      tf.ones(shape=(1, n, f0, 1)),
                                      ), 2)
                    tmp = tmp * mask

                for _ in range(time_mask_num):
                    t0 = tf.random.uniform([], minval=0, maxval=n - t, dtype=tf.int32)

                    # zeros cover the band [n - t0 - t, n - t0) on the time axis
                    mask = tf.concat((tf.ones(shape=(1, n - t0 - t, v, 1)),
                                      tf.zeros(shape=(1, t, v, 1)),
                                      tf.ones(shape=(1, t0, v, 1)),
                                      ), 1)
                    tmp = tmp * mask

                output_mel_spectrogram.append(tmp[0])

            images = tf.stack(output_mel_spectrogram)
            # replace zero values by the mean
            # NOTE: this also replaces pixels that were exactly 0 before masking
            preprocessed_mask = preprocess_scalar_zero(self.hy_params['basemodel_name'])
            images = tf.where(images == 0.0, preprocessed_mask, images)

            return images, labels
# layer name fragments, combined as 'fire<id>/<fragment>'
sq1x1 = "squeeze1x1"
exp1x1 = "expand1x1"
exp3x3 = "expand3x3"
relu = "relu_"

WEIGHTS_PATH = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels.h5"
WEIGHTS_PATH_NO_TOP = "https://github.com/rcmalli/keras-squeezenet/releases/download/v1.0/squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5"

# source https://github.com/rcmalli/keras-squeezenet


# Modular function for Fire Node

def fire_module(x, fire_id, squeeze=16, expand=64):
    """Appends one fire module (1x1 squeeze, then parallel 1x1 and 3x3 expand) to `x`.

    # Arguments
        x: input tensor of the module.
        fire_id: number used as prefix ('fire<id>/') of all layer names.
        squeeze: number of filters of the 1x1 squeeze convolution.
        expand: number of filters of each of the two expand convolutions.

    # Returns
        Tensor holding both expand outputs concatenated on the channel axis.
    """
    s_id = 'fire' + str(fire_id) + '/'

    if K.image_data_format() == 'channels_first':
        channel_axis = 1
    else:
        channel_axis = 3

    x = Convolution2D(squeeze, (1, 1), padding='valid', name=s_id + sq1x1)(x)
    x = Activation('relu', name=s_id + relu + sq1x1)(x)

    left = Convolution2D(expand, (1, 1), padding='valid', name=s_id + exp1x1)(x)
    left = Activation('relu', name=s_id + relu + exp1x1)(left)

    right = Convolution2D(expand, (3, 3), padding='same', name=s_id + exp3x3)(x)
    right = Activation('relu', name=s_id + relu + exp3x3)(right)

    x = concatenate([left, right], axis=channel_axis, name=s_id + 'concat')
    return x


# Original SqueezeNet from paper.

def SqueezeNet(include_top=True, weights='imagenet',
               input_tensor=None, input_shape=None,
               pooling=None,
               classes=1000):
    """Instantiates the SqueezeNet architecture.

    # Arguments
        include_top: whether to append the classification head
            (dropout, 1x1 convolution with `classes` filters, global
            average pooling, softmax).
        weights: `None` (random initialization) or 'imagenet'
            (weights are downloaded via `get_file`).
        input_tensor: optional tensor to use as model input.
        input_shape: optional shape tuple, validated by `_obtain_input_shape`.
        pooling: only used when `include_top` is False; `None`, 'avg' or 'max'.
        classes: number of output classes; only used when `include_top` is True.

    # Returns
        A Keras `Model` named 'squeezenet'.

    # Raises
        ValueError: for an unknown `weights` or `pooling` value, for
            `classes != 1000` together with the imagenet head, or when the
            theano backend is active while loading imagenet weights.
    """

    if weights not in {'imagenet', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `imagenet` '
                         '(pre-training on ImageNet).')

    # `classes` only shapes the classification head, so it must not be
    # validated when the head is left out (as the message already states).
    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as imagenet with `include_top`'
                         ' as true, `classes` should be 1000')

    input_shape = _obtain_input_shape(input_shape,
                                      default_size=227,
                                      min_size=48,
                                      data_format=K.image_data_format(),
                                      require_flatten=include_top)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    x = Convolution2D(64, (3, 3), strides=(2, 2), padding='valid', name='conv1')(img_input)
    x = Activation('relu', name='relu_conv1')(x)
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool1')(x)

    x = fire_module(x, fire_id=2, squeeze=16, expand=64)
    x = fire_module(x, fire_id=3, squeeze=16, expand=64)
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool3')(x)

    x = fire_module(x, fire_id=4, squeeze=32, expand=128)
    x = fire_module(x, fire_id=5, squeeze=32, expand=128)
    x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), name='pool5')(x)

    x = fire_module(x, fire_id=6, squeeze=48, expand=192)
    x = fire_module(x, fire_id=7, squeeze=48, expand=192)
    x = fire_module(x, fire_id=8, squeeze=64, expand=256)
    x = fire_module(x, fire_id=9, squeeze=64, expand=256)

    if include_top:
        # It's not obvious where to cut the network...
        # Could do the 8th or 9th layer... some work recommends cutting earlier layers.

        x = Dropout(0.5, name='drop9')(x)

        x = Convolution2D(classes, (1, 1), padding='valid', name='conv10')(x)
        x = Activation('relu', name='relu_conv10')(x)
        x = GlobalAveragePooling2D()(x)
        x = Activation('softmax', name='loss')(x)
    else:
        if pooling == 'avg':
            x = GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D()(x)
        elif pooling is None:
            pass
        else:
            # str() keeps this a ValueError even for non-string arguments
            raise ValueError("Unknown argument for 'pooling'=" + str(pooling))

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input

    model = Model(inputs, x, name='squeezenet')

    # load weights
    if weights == 'imagenet':
        if include_top:
            weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels.h5',
                                    WEIGHTS_PATH,
                                    cache_subdir='models')
        else:
            weights_path = get_file('squeezenet_weights_tf_dim_ordering_tf_kernels_notop.h5',
                                    WEIGHTS_PATH_NO_TOP,
                                    cache_subdir='models')

        model.load_weights(weights_path)
        if K.backend() == 'theano':
            raise ValueError('theano backend not supported')

        if K.image_data_format() == 'channels_first':

            if K.backend() == 'tensorflow':
                print('You are using the TensorFlow backend, yet you '
                      'are using the Theano '
                      'image data format convention '
                      '(`image_data_format="channels_first"`). '
                      'For best performance, set '
                      '`image_data_format="channels_last"` in '
                      'your Keras config '
                      'at ~/.keras/keras.json.')
    return model
def amplitude_to_db(S, amin=1e-16, top_db=80.0):  # pragma: no cover
    """Convert an amplitude spectrogram to decibel (dB) units.

    The magnitude of ``S`` is squared and handed to ``power_to_db`` with the
    squared maximum magnitude as reference and ``amin ** 2`` as floor.
    """
    abs_spec = tf.abs(S)
    peak = tf.reduce_max(abs_spec)

    return power_to_db(tf.square(abs_spec), ref=peak ** 2, amin=amin ** 2, top_db=top_db)


def power_to_db(S, ref=1.0, amin=1e-16, top_db=80.0):  # pragma: no cover
    """Convert a power-spectrogram (magnitude squared) to decibel (dB) units.
    Computes ``10 * log10(S / ref)`` in a numerically stable way: both ``S``
    and ``|ref|`` are floored at ``amin`` before the logarithm is taken. If
    ``top_db`` is not None the result is clipped to at most ``top_db`` below
    its maximum.
    Based on:
    https://librosa.github.io/librosa/generated/librosa.core.power_to_db.html
    """

    def _ten_log10(value):
        # log10(x) = ln(x) / ln(10), evaluated in the dtype of the input
        natural_log = tf.math.log(value)
        ln_ten = tf.math.log(tf.constant(10, dtype=natural_log.dtype))
        return 10.0 * (natural_log / ln_ten)

    # Zeros in the output correspond to positions where S == ref.
    reference = tf.abs(ref)

    decibel = _ten_log10(tf.maximum(amin, S))
    decibel -= _ten_log10(tf.maximum(amin, reference))

    if top_db is None:
        return decibel

    return tf.maximum(decibel, tf.reduce_max(decibel) - top_db)
["imagenet"], 7 | "tb_experiment": ["cividis_exp"], 8 | "tb_run_id": ["cividis_run"], 9 | "num_units": [4], 10 | "dropout": [0.25], 11 | "optimizer": ["adam"], 12 | "learning_rate": [0.001], 13 | "fine_learning_rate": [0.0001], 14 | "finetune_layer": [0.7], 15 | "loss": ["categorical_crossentropy"], 16 | "activation": ["arelu"], 17 | "pre_epochs": [1], 18 | "epochs": [1], 19 | "batch_size": [1], 20 | 21 | "sample_rate": [16000], 22 | 23 | "chunk_size": [2.0], 24 | "chunk_hop_size": [1.0], 25 | "normalize_audio": [false], 26 | 27 | "stft_window_size": [0.128], 28 | "stft_hop_size": [0.064], 29 | "stft_fft_length": [0.128], 30 | 31 | "mel_scale": [true], 32 | "lower_edge_hertz": [0.0], 33 | "upper_edge_hertz": [8000.0], 34 | "num_mel_bins": [128], 35 | "num_mfccs": [0], 36 | "cep_lifter": [0], 37 | "db_scale": [true], 38 | "use_plot_images": [true], 39 | "color_map": ["cividis"], 40 | "image_width": [224], 41 | "image_height": [224], 42 | "resize_method": ["nearest"], 43 | "anti_alias": [false], 44 | 45 | "sap_aug_a": [0.5], 46 | "sap_aug_s": [10], 47 | "augment_cutmix": [false], 48 | "augment_specaug": [false], 49 | "da_prob_min": [0.1], 50 | "da_prob_max": [0.5], 51 | "cutmix_min": [0.075], 52 | "cutmix_max": [0.25], 53 | "specaug_freq_min": [0.1], 54 | "specaug_freq_max": [0.3], 55 | "specaug_time_min": [0.1], 56 | "specaug_time_max": [0.3], 57 | "specaug_freq_mask_num": [1], 58 | "specaug_time_mask_num": [1] 59 | } -------------------------------------------------------------------------------- /tests/cli/config/inferno_hp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"], 3 | "model_name": ["TransferBaseModel"], 4 | "prediction_type": ["categorical"], 5 | "basemodel_name": ["densenet121"], 6 | "weights": ["imagenet"], 7 | "tb_experiment": ["inferno_exp"], 8 | "tb_run_id": ["inferno_run"], 9 | "num_units": [4], 10 | "dropout": [0.25], 11 | 
"optimizer": ["adagrad"], 12 | "learning_rate": [0.001], 13 | "fine_learning_rate": [0.0001], 14 | "finetune_layer": [0.7], 15 | "loss": ["categorical_crossentropy"], 16 | "activation": ["arelu"], 17 | "pre_epochs": [1], 18 | "epochs": [1], 19 | "batch_size": [1], 20 | 21 | "sample_rate": [16000], 22 | 23 | "chunk_size": [2.0], 24 | "chunk_hop_size": [1.0], 25 | "normalize_audio": [false], 26 | 27 | "stft_window_size": [0.128], 28 | "stft_hop_size": [0.064], 29 | "stft_fft_length": [0.128], 30 | 31 | "mel_scale": [true], 32 | "lower_edge_hertz": [0.0], 33 | "upper_edge_hertz": [8000.0], 34 | "num_mel_bins": [128], 35 | "num_mfccs": [0], 36 | "cep_lifter": [0], 37 | "db_scale": [true], 38 | "use_plot_images": [true], 39 | "color_map": ["inferno"], 40 | "image_width": [224], 41 | "image_height": [224], 42 | "resize_method": ["nearest"], 43 | "anti_alias": [false], 44 | 45 | "sap_aug_a": [0.5], 46 | "sap_aug_s": [10], 47 | "augment_cutmix": [false], 48 | "augment_specaug": [false], 49 | "da_prob_min": [0.1], 50 | "da_prob_max": [0.5], 51 | "cutmix_min": [0.075], 52 | "cutmix_max": [0.25], 53 | "specaug_freq_min": [0.1], 54 | "specaug_freq_max": [0.3], 55 | "specaug_time_min": [0.1], 56 | "specaug_time_max": [0.3], 57 | "specaug_freq_mask_num": [1], 58 | "specaug_time_mask_num": [1] 59 | } -------------------------------------------------------------------------------- /tests/cli/config/magma_hp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"], 3 | "model_name": ["TransferBaseModel"], 4 | "prediction_type": ["categorical"], 5 | "basemodel_name": ["densenet121"], 6 | "weights": ["imagenet"], 7 | "tb_experiment": ["magma_exp"], 8 | "tb_run_id": ["magma_run"], 9 | "num_units": [4], 10 | "dropout": [0.25], 11 | "optimizer": ["nadam"], 12 | "learning_rate": [0.001], 13 | "fine_learning_rate": [0.0001], 14 | "finetune_layer": [0.7], 15 | "loss": 
["categorical_crossentropy"], 16 | "activation": ["arelu"], 17 | "pre_epochs": [1], 18 | "epochs": [1], 19 | "batch_size": [1], 20 | 21 | "sample_rate": [16000], 22 | 23 | "chunk_size": [2.0], 24 | "chunk_hop_size": [1.0], 25 | "normalize_audio": [false], 26 | 27 | "stft_window_size": [0.128], 28 | "stft_hop_size": [0.064], 29 | "stft_fft_length": [0.128], 30 | 31 | "mel_scale": [true], 32 | "lower_edge_hertz": [0.0], 33 | "upper_edge_hertz": [8000.0], 34 | "num_mel_bins": [128], 35 | "num_mfccs": [0], 36 | "cep_lifter": [0], 37 | "db_scale": [true], 38 | "use_plot_images": [true], 39 | "color_map": ["magma"], 40 | "image_width": [224], 41 | "image_height": [224], 42 | "resize_method": ["nearest"], 43 | "anti_alias": [false], 44 | 45 | "sap_aug_a": [0.5], 46 | "sap_aug_s": [10], 47 | "augment_cutmix": [false], 48 | "augment_specaug": [false], 49 | "da_prob_min": [0.1], 50 | "da_prob_max": [0.5], 51 | "cutmix_min": [0.075], 52 | "cutmix_max": [0.25], 53 | "specaug_freq_min": [0.1], 54 | "specaug_freq_max": [0.3], 55 | "specaug_time_min": [0.1], 56 | "specaug_time_max": [0.3], 57 | "specaug_freq_mask_num": [1], 58 | "specaug_time_mask_num": [1] 59 | } -------------------------------------------------------------------------------- /tests/cli/config/plasma_hp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"], 3 | "model_name": ["TransferBaseModel"], 4 | "prediction_type": ["categorical"], 5 | "basemodel_name": ["densenet121"], 6 | "weights": ["imagenet"], 7 | "tb_experiment": ["plasma_exp"], 8 | "tb_run_id": ["plasma_run"], 9 | "num_units": [4], 10 | "dropout": [0.25], 11 | "optimizer": ["rmsprop"], 12 | "learning_rate": [0.001], 13 | "fine_learning_rate": [0.0001], 14 | "finetune_layer": [0.7], 15 | "loss": ["categorical_crossentropy"], 16 | "activation": ["arelu"], 17 | "pre_epochs": [1], 18 | "epochs": [1], 19 | "batch_size": [1], 20 | 21 | 
"sample_rate": [16000], 22 | 23 | "chunk_size": [2.0], 24 | "chunk_hop_size": [1.0], 25 | "normalize_audio": [false], 26 | 27 | "stft_window_size": [0.128], 28 | "stft_hop_size": [0.064], 29 | "stft_fft_length": [0.128], 30 | 31 | "mel_scale": [true], 32 | "lower_edge_hertz": [0.0], 33 | "upper_edge_hertz": [8000.0], 34 | "num_mel_bins": [128], 35 | "num_mfccs": [0], 36 | "cep_lifter": [0], 37 | "db_scale": [true], 38 | "use_plot_images": [true], 39 | "color_map": ["plasma"], 40 | "image_width": [224], 41 | "image_height": [224], 42 | "resize_method": ["nearest"], 43 | "anti_alias": [false], 44 | 45 | "sap_aug_a": [0.5], 46 | "sap_aug_s": [10], 47 | "augment_cutmix": [false], 48 | "augment_specaug": [false], 49 | "da_prob_min": [0.1], 50 | "da_prob_max": [0.5], 51 | "cutmix_min": [0.075], 52 | "cutmix_max": [0.25], 53 | "specaug_freq_min": [0.1], 54 | "specaug_freq_max": [0.3], 55 | "specaug_time_min": [0.1], 56 | "specaug_time_max": [0.3], 57 | "specaug_freq_mask_num": [1], 58 | "specaug_time_mask_num": [1] 59 | } -------------------------------------------------------------------------------- /tests/cli/config/regression_class_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "r": "r" 3 | } -------------------------------------------------------------------------------- /tests/cli/config/regression_hp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "label_parser": ["../lib/data/parser/ComParEParser.py:ComParEParser"], 3 | "model_name": ["TransferBaseModel"], 4 | "prediction_type": ["regression"], 5 | "basemodel_name": ["squeezenet_v1"], 6 | "weights": ["imagenet"], 7 | "tb_experiment": ["regression_exp"], 8 | "tb_run_id": ["regression_run"], 9 | "num_units": [4], 10 | "dropout": [0.25], 11 | "optimizer": ["ftrl"], 12 | "learning_rate": [0.001], 13 | "fine_learning_rate": [0.0001], 14 | "finetune_layer": [0.7], 15 | "loss": ["mae"], 16 | "activation": 
def get_tmp_dir():
    """Create a fresh temporary working directory for the CLI tests.

    A new directory is made with ``tempfile.mkdtemp`` and a sub-directory
    ``pytest_test_train`` is created inside it.

    Returns:
        Path of the sub-directory, terminated by a path separator.
    """
    root = tempfile.mkdtemp()
    os.mkdir(os.path.join(root, "pytest_test_train"))
    return os.path.join(root, "pytest_test_train", "")
# Shared by all tests of this module: later tests read the model that
# test_train writes below this directory, so the tests rely on running in
# definition order.
temp_dir = get_tmp_dir()


def _run_cli(*args):
    """Invoke the deepspectrumlite CLI with ``-vv`` plus *args* and return the click result."""
    from deepspectrumlite.__main__ import cli
    from click.testing import CliRunner

    return CliRunner().invoke(cli, args=['-vv', *args], catch_exceptions=False)


def _examples_dir():
    """Return the ``examples`` directory located two levels above this file's directory."""
    cur_dir = os.path.dirname(__file__)
    return os.path.join(os.path.dirname(os.path.dirname(cur_dir)), 'examples')


def _model_file(name):
    """Return the path of *name* inside the model directory written by ``test_train``."""
    return os.path.join(temp_dir, 'output', 'models', 'densenet_exp',
                        'densenet121_run_config_0', name)


def test_train():
    """``train`` on the example data finishes with exit code 0."""
    examples = _examples_dir()
    result = _run_cli('train',
                      '-d', os.path.join(examples, 'audio'),
                      '-md', os.path.join(temp_dir, 'output'),
                      '-hc', os.path.join(examples, 'hp_config.json'),
                      '-cc', os.path.join(examples, 'class_config.json'),
                      '-l', os.path.join(examples, 'label.csv'))
    assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"


def test_inference():
    """``predict`` with the trained model on the test audio finishes with exit code 0."""
    examples = _examples_dir()
    result = _run_cli('predict',
                      '-d', os.path.join(os.path.dirname(__file__), 'audio'),
                      '-md', _model_file('model.h5'),
                      '-hc', os.path.join(examples, 'hp_config.json'),
                      '-cc', os.path.join(examples, 'class_config.json'))
    assert result.exit_code == 0, f"Exit code for predict is not 0 but {result.exit_code}"


def test_devel_test():
    """``devel-test`` with the trained model finishes with exit code 0."""
    examples = _examples_dir()
    result = _run_cli('devel-test',
                      '-d', os.path.join(examples, 'audio'),
                      '-md', _model_file('model.h5'),
                      '-hc', os.path.join(examples, 'hp_config.json'),
                      '-cc', os.path.join(examples, 'class_config.json'),
                      '-l', os.path.join(examples, 'label.csv'))
    assert result.exit_code == 0, f"Exit code for devel-test is not 0 but {result.exit_code}"


def test_stats():
    """``stats`` on the trained model finishes with exit code 0."""
    result = _run_cli('stats', '-md', _model_file('model.h5'))
    assert result.exit_code == 0, f"Exit code for stats is not 0 but {result.exit_code}"


def test_convert():
    """``convert`` turns the trained model into a tflite file and exits with 0."""
    result = _run_cli('convert',
                      '-s', _model_file('model.h5'),
                      '-d', _model_file('converted_model.tflite'))
    assert result.exit_code == 0, f"Exit code for convert is not 0 but {result.exit_code}"


def test_tflite_stats():
    """``tflite-stats`` on the converted model finishes with exit code 0."""
    result = _run_cli('tflite-stats', '-md', _model_file('converted_model.tflite'))
    assert result.exit_code == 0, f"Exit code for tflite-stats is not 0 but {result.exit_code}"


def test_create_preprocessor():
    """``create-preprocessor`` writes a tflite preprocessor and exits with 0."""
    result = _run_cli('create-preprocessor',
                      '-hc', os.path.join(_examples_dir(), 'hp_config.json'),
                      '-d', os.path.join(temp_dir, 'output', 'preprocessor.tflite'))
    assert result.exit_code == 0, f"Exit code for create_preprocessor is not 0 but {result.exit_code}"


def pytest_sessionfinish():
    """Remove the temporary directory tree created for this module.

    ``temp_dir`` is ``<mkdtemp root>/pytest_test_train/``; the whole root is
    removed so that no empty directory is left behind.
    """
    # NOTE(review): pytest collects hook implementations from conftest.py and
    # plugins, not from test modules - this function is presumably never
    # called automatically; consider moving it to conftest.py.
    shutil.rmtree(os.path.dirname(os.path.dirname(temp_dir)), ignore_errors=True)
def _assert_train_succeeds(tmpdir, hp_config, class_config=None, label_file=None):
    """Run ``train`` on the example audio and assert that it exits with code 0.

    Args:
        tmpdir: pytest ``tmpdir`` fixture; the model is written to ``<tmpdir>/output``.
        hp_config: file name of the hyper parameter config inside ``config/`` next to this file.
        class_config: path components relative to this file's directory, or
            ``None`` to use ``class_config.json`` of the examples directory.
        label_file: path components relative to this file's directory, or
            ``None`` to use ``label.csv`` of the examples directory.
    """
    from deepspectrumlite.__main__ import cli
    from click.testing import CliRunner
    from os.path import dirname, join

    cur_dir = dirname(__file__)
    examples = join(dirname(dirname(cur_dir)), 'examples')

    if class_config is None:
        class_config_path = join(examples, 'class_config.json')
    else:
        class_config_path = join(cur_dir, *class_config)

    if label_file is None:
        label_path = join(examples, 'label.csv')
    else:
        label_path = join(cur_dir, *label_file)

    result = CliRunner().invoke(cli,
                                args=[
                                    '-vv', 'train',
                                    '-d', join(examples, 'audio'),
                                    '-md', join(tmpdir, 'output'),
                                    '-hc', join(cur_dir, 'config', hp_config),
                                    '-cc', class_config_path,
                                    '-l', label_path
                                ],
                                catch_exceptions=False)
    assert result.exit_code == 0, f"Exit code for train is not 0 but {result.exit_code}"


def test_cividis(tmpdir):
    """Training with the cividis config succeeds."""
    _assert_train_succeeds(tmpdir, 'cividis_hp_config.json')


def test_plasma(tmpdir):
    """Training with the plasma config succeeds."""
    _assert_train_succeeds(tmpdir, 'plasma_hp_config.json')


def test_inferno(tmpdir):
    """Training with the inferno config succeeds."""
    _assert_train_succeeds(tmpdir, 'inferno_hp_config.json')


def test_magma(tmpdir):
    """Training with the magma config succeeds."""
    _assert_train_succeeds(tmpdir, 'magma_hp_config.json')


def test_regression(tmpdir):
    """Training with the regression config, class config and labels succeeds."""
    _assert_train_succeeds(tmpdir, 'regression_hp_config.json',
                           class_config=('config', 'regression_class_config.json'),
                           label_file=('label', 'regression_label.csv'))