├── .bumpversion.cfg ├── .git-blame-ignore-revs ├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── .pylintrc ├── LICENSE.md ├── README.md ├── app.py ├── batdetect2 ├── __init__.py ├── api.py ├── cli.py ├── detector │ ├── __init__.py │ ├── compute_features.py │ ├── model_helpers.py │ ├── models.py │ ├── parameters.py │ └── post_process.py ├── evaluate │ ├── __init__.py │ ├── evaluate_models.py │ └── readme.md ├── finetune │ ├── __init__.py │ ├── finetune_model.py │ ├── prep_data_finetune.py │ └── readme.md ├── models │ ├── Net2DFast_UK_same.pth.tar │ └── readme.md ├── plot.py ├── train │ ├── __init__.py │ ├── audio_dataloader.py │ ├── evaluate.py │ ├── losses.py │ ├── readme.md │ ├── train_model.py │ ├── train_split.py │ └── train_utils.py ├── types.py └── utils │ ├── __init__.py │ ├── audio_utils.py │ ├── detector_utils.py │ ├── plot_utils.py │ ├── visualize.py │ └── wavfile.py ├── batdetect2_notebook.ipynb ├── environment.yml ├── example_data └── audio │ ├── 20170701_213954-MYOMYS-LR_0_0.5.wav │ ├── 20180530_213516-EPTSER-LR_0_0.5.wav │ └── 20180627_215323-RHIFER-LR_0_0.5.wav ├── faq.md ├── ims └── bat_icon.png ├── pyproject.toml ├── requirements.txt ├── run_batdetect.py ├── scripts ├── README.md ├── gen_dataset_summary_image.py ├── gen_spec_image.py ├── gen_spec_video.py └── viz_helpers.py ├── tests ├── __init__.py ├── conftest.py ├── data │ ├── 20230322_172000_selec2.wav │ └── contrib │ │ ├── jeff37 │ │ ├── 0166_20240531_223911.wav │ │ ├── 0166_20240602_225340.wav │ │ ├── 0166_20240603_033731.wav │ │ ├── 0166_20240603_033937.wav │ │ └── 0166_20240604_233500.wav │ │ └── padpadpadpad │ │ ├── Audiomoth.WAV │ │ ├── AudiomothNoBatCalls.WAV │ │ └── Echometer.wav ├── test_api.py ├── test_audio_utils.py ├── test_cli.py ├── test_contrib.py ├── test_detections.py ├── test_features.py └── test_model.py └── uv.lock /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.3.0 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:batdetect2/__init__.py] 7 | 8 | [bumpversion:file:pyproject.toml] 9 | -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Format code with Black and isort 2 | 3c17a2337166245de8df778fe174aad997e14e8f 3 | 9cb6b20949c7c31ee21ed2b800e8b691f1be32a7 4 | 53100f51e083cf4d900ed325ae0543cc754a26cc 5 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.9", "3.10", "3.11", "3.12"] 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Install uv 20 | uses: astral-sh/setup-uv@v3 21 | with: 22 | enable-cache: true 23 | cache-dependency-glob: "uv.lock" 24 | - name: Set up Python ${{ matrix.python-version }} 25 | run: uv python install ${{ matrix.python-version }} 26 | - name: Install the project 27 | run: uv sync --all-extras --dev 28 | - name: Test with pytest 29 | run: uv run pytest 30 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: 
-------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v3 18 | with: 19 | python-version: "3.x" 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build 24 | - name: Build package 25 | run: python -m build 26 | - name: Publish package 27 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 28 | with: 29 | user: __token__ 30 | password: ${{ secrets.PYPI_API_TOKEN }} 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .nox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | *.py,cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | cover/ 50 | 51 | # Sphinx documentation 52 | docs/_build/ 53 | 54 | # PyBuilder 55 | .pybuilder/ 56 | target/ 57 | 58 | # IPython 59 | profile_default/ 60 | ipython_config.py 61 | 62 | # pdm 63 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 64 | #pdm.lock 65 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 66 | # in version control. 67 | # https://pdm.fming.dev/#use-with-ide 68 | .pdm-python 69 | 70 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 71 | __pypackages__/ 72 | 73 | # Environments 74 | .env 75 | .venv 76 | env/ 77 | venv/ 78 | ENV/ 79 | env.bak/ 80 | venv.bak/ 81 | 82 | # Rope project settings 83 | .ropeproject/ 84 | 85 | # mypy 86 | .mypy_cache/ 87 | .dmypy.json 88 | dmypy.json 89 | 90 | # Model artifacts 91 | *.png 92 | *.jpg 93 | *.wav 94 | *.tar 95 | *.json 96 | plots/* 97 | 98 | # Model experiments 99 | experiments/* 100 | 101 | # Jupyter notebooks 102 | .virtual_documents 103 | .ipynb_checkpoints 104 | *.ipynb 105 | 106 | # DO Include 107 | !batdetect2_notebook.ipynb 108 | !batdetect2/models/*.pth.tar 109 | !tests/data/*.wav 110 | !tests/data/**/*.wav 111 | notebooks/lightning_logs 112 | example_data/preprocessed 113 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [TYPECHECK] 2 | 3 | # List of members which are set dynamically and missed by Pylint inference 4 | # system, and so shouldn't trigger E1101 when accessed.
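# e.g. attributes defined in torch's C extensions (torch.float32, torch.from_numpy) are invisible to static analysis and would otherwise be flagged.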
5 | generated-members=torch.* 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BatDetect2 2 | Code for detecting and classifying bat echolocation calls in high frequency audio recordings. 3 | 4 | ## Getting started 5 | ### Python Environment 6 | 7 | We recommend using an isolated Python environment to avoid dependency issues. Choose one 8 | of the following options: 9 | 10 | * Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it: 11 | 12 | ```bash 13 | conda create -y --name batdetect2 python==3.10 14 | conda activate batdetect2 15 | ``` 16 | 17 | * If you already have Python installed (version >= 3.8, < 3.11) and prefer using virtual environments then: 18 | 19 | ```bash 20 | python -m venv .venv 21 | source .venv/bin/activate 22 | ``` 23 | 24 | ### Installing BatDetect2 25 | You can use pip to install `batdetect2`: 26 | 27 | ```bash 28 | pip install batdetect2 29 | ``` 30 | 31 | Alternatively, download this code from the repository (by clicking on the green button in the top right) and unzip it. 32 | Once unzipped, run this from the extracted folder. 33 | 34 | ```bash 35 | pip install . 36 | ``` 37 | 38 | Make sure you have the environment activated before installing `batdetect2`. 39 | 40 | 41 | ## Try the model 42 | 1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2). 43 | 44 | 2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally. 45 | 46 | 47 | ## Running the model on your own data 48 | 49 | After following the above steps to install the code you can run the model on your own data. 50 | 51 | 52 | ### Using the command line 53 | 54 | You can run the model by opening the command line and typing: 55 | ```bash 56 | batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD 57 | ``` 58 | e.g. 59 | ```bash 60 | batdetect2 detect example_data/audio/ example_data/anns/ 0.3 61 | ``` 62 | 63 | `AUDIO_DIR` is the path on your computer to the audio wav files of interest. 64 | `ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file. 65 | `DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes. 66 | 67 | There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files: 68 | `batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features` 69 | 70 | You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data, e.g. 71 | `batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar` 72 | 73 | 74 | ### Using the Python API 75 | 76 | If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.
77 | 78 | ```python 79 | from batdetect2 import api 80 | 81 | AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav" 82 | 83 | # Process a whole file 84 | results = api.process_file(AUDIO_FILE) 85 | 86 | # Or, load audio and compute spectrograms 87 | audio = api.load_audio(AUDIO_FILE) 88 | spec = api.generate_spectrogram(audio) 89 | 90 | # And process the audio or the spectrogram with the model 91 | detections, features, spec = api.process_audio(audio) 92 | detections, features = api.process_spectrogram(spec) 93 | 94 | # Do something else ... 95 | ``` 96 | 97 | You can integrate the detections or the extracted features into your custom analysis pipeline. 98 | 99 | #### Using the Python API with HTTP 100 | 101 | ```python 102 | from batdetect2 import api 103 | import io 104 | import requests 105 | 106 | AUDIO_URL = "" 107 | 108 | # Process a whole file from a URL 109 | results = api.process_url(AUDIO_URL) 110 | 111 | # Or, load audio and compute spectrograms 112 | # 'requests.get(AUDIO_URL).content' fetches the raw bytes; any other source of raw bytes works just as well 113 | audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content)) 114 | spec = api.generate_spectrogram(audio) 115 | 116 | # And process the audio or the spectrogram with the model 117 | detections, features, spec = api.process_audio(audio) 118 | detections, features = api.process_spectrogram(spec) 119 | ``` 120 | 121 | ## Training the model on your own data 122 | Take a look at the steps outlined in the finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model. 123 | 124 | 125 | ## Data and annotations 126 | The raw audio data and annotations used to train the models in the paper will be added soon. 127 | The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI). 128 | 129 | 130 | ## Warning 131 | The models developed and shared as part of this repository should be used with caution. 132 | While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment. 133 | Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted. 134 | 135 | 136 | ## FAQ 137 | For more information please consult our [FAQ](faq.md). 138 | 139 | 140 | ## Reference 141 | If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1): 142 | ``` 143 | @article{batdetect2_2022, 144 | title = {Towards a General Approach for Bat Echolocation Detection and Classification}, 145 | author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataud, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.}, 146 | journal = {bioRxiv}, 147 | year = {2022} 148 | } 149 | ``` 150 | 151 | ## Acknowledgements 152 | Thanks to all the contributors who spent time collecting and annotating audio data.
153 | 154 | 155 | ### TODOs 156 | - [x] Release the code and pretrained model 157 | - [ ] Release the datasets and annotations used in the experiments in the paper 158 | - [ ] Add the scripts used to generate the tables and figures from the paper 159 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import matplotlib 3 | 4 | matplotlib.use("Agg") 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from batdetect2 import api, plot 11 | 12 | MAX_DURATION = 2 13 | DETECTION_THRESHOLD = 0.3 14 | 15 | 16 | examples = [ 17 | [ 18 | "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav", 19 | DETECTION_THRESHOLD, 20 | ], 21 | [ 22 | "example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav", 23 | DETECTION_THRESHOLD, 24 | ], 25 | [ 26 | "example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav", 27 | DETECTION_THRESHOLD, 28 | ], 29 | ] 30 | 31 | 32 | def make_prediction(file_name, detection_threshold=DETECTION_THRESHOLD): 33 | # configure the model run 34 | run_config = api.get_config( 35 | detection_threshold=detection_threshold, 36 | max_duration=MAX_DURATION, 37 | ) 38 | 39 | # process the file to generate predictions 40 | results = api.process_file(file_name, config=run_config) 41 | 42 | # extract the detections 43 | detections = results["pred_dict"]["annotation"] 44 | 45 | # create a dataframe of the predictions 46 | df = pd.DataFrame( 47 | [ 48 | { 49 | "species": pred["class"], 50 | "time": pred["start_time"], 51 | "detection_prob": pred["det_prob"], 52 | "species_prob": pred["class_prob"], 53 | } 54 | for pred in detections 55 | ] 56 | ) 57 | im = generate_results_image(file_name, detections, run_config) 58 | 59 | return im, df 60 | 61 | 62 | def generate_results_image(file_name, detections, config): 63 | audio = api.load_audio( 64 | file_name, 65 | max_duration=config["max_duration"], 66 | time_exp_fact=config["time_expansion"], 67 | target_samp_rate=config["target_samp_rate"], 68 | ) 69 | 70 | spec = api.generate_spectrogram(audio, config=config) 71 | 72 | # create fig 73 | plt.close("all") 74 | fig = plt.figure( 75 | 1, 76 | figsize=(15, 4), 77 | dpi=100, 78 | frameon=False, 79 | ) 80 | ax = fig.add_subplot(111) 81 | plot.spectrogram_with_detections(spec, detections, ax=ax) 82 | plt.tight_layout() 83 | 84 | # convert fig to image 85 | fig.canvas.draw() 86 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 87 | w, h = fig.canvas.get_width_height() 88 | im = data.reshape((int(h), int(w), -1)) 89 | return im 90 | 91 | 92 | descr_txt = ( 93 | "Demo of BatDetect2 deep learning-based bat echolocation call detection. " 94 | "
This model is only trained on bat species from the UK. If the input " 95 | "file is longer than 2 seconds, only the first 2 seconds will be processed." 96 | "
Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)." 97 | ) 98 | 99 | gr.Interface( 100 | fn=make_prediction, 101 | inputs=[ 102 | gr.Audio( 103 | source="upload", 104 | type="filepath", 105 | label="Audio File", 106 | info="Upload an audio file to be processed.", 107 | ), 108 | gr.Slider( 109 | minimum=0, 110 | maximum=1, 111 | value=DETECTION_THRESHOLD, 112 | label="Detection Threshold", 113 | step=0.1, 114 | info=( 115 | "All detections with a detection probability below this " 116 | "threshold will be ignored." 117 | ), 118 | ), 119 | ], 120 | live=True, 121 | outputs=[ 122 | gr.Image(label="Visualisation"), 123 | gr.Dataframe( 124 | headers=["species", "time", "detection_prob", "species_prob"], 125 | datatype=["str", "number", "number", "number"], 126 | row_count=1, 127 | col_count=(4, "fixed"), 128 | label="Predictions", 129 | ), 130 | ], 131 | theme="huggingface", 132 | title="BatDetect2 Demo", 133 | description=descr_txt, 134 | examples=examples, 135 | allow_flagging="never", 136 | ).launch() 137 | -------------------------------------------------------------------------------- /batdetect2/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | numba_logger = logging.getLogger("numba") 4 | numba_logger.setLevel(logging.WARNING) 5 | 6 | __version__ = "1.3.0" 7 | -------------------------------------------------------------------------------- /batdetect2/cli.py: -------------------------------------------------------------------------------- 1 | """BatDetect2 command line interface.""" 2 | 3 | import os 4 | 5 | import click 6 | 7 | from batdetect2 import api 8 | from batdetect2.detector.parameters import DEFAULT_MODEL_PATH 9 | from batdetect2.types import ProcessingConfiguration 10 | from batdetect2.utils.detector_utils import save_results_to_file 11 | 12 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | 15 | INFO_STR = """ 16 | BatDetect2 - Detection and Classification 17 | Assumes audio files are mono, not stereo. 18 | Spaces in the input paths will throw an error. Wrap in quotes. 19 | Input files should be short in duration e.g. < 30 seconds. 20 | """ 21 | 22 | 23 | @click.group() 24 | def cli(): 25 | """BatDetect2 - Bat Call Detection and Classification.""" 26 | click.echo(INFO_STR) 27 | 28 | 29 | @cli.command() 30 | @click.argument( 31 | "audio_dir", 32 | type=click.Path(exists=True), 33 | ) 34 | @click.argument( 35 | "ann_dir", 36 | type=click.Path(exists=False), 37 | ) 38 | @click.argument( 39 | "detection_threshold", 40 | type=float, 41 | ) 42 | @click.option( 43 | "--cnn_features", 44 | is_flag=True, 45 | default=False, 46 | help="Extracts CNN call features", 47 | ) 48 | @click.option( 49 | "--chunk_size", 50 | type=float, 51 | default=2, 52 | help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. 
Larger chunks increase computation time and memory usage but may provide more contextual information for inference.", 53 | ) 54 | @click.option( 55 | "--spec_features", 56 | is_flag=True, 57 | default=False, 58 | help="Extracts low level call features", 59 | ) 60 | @click.option( 61 | "--time_expansion_factor", 62 | type=int, 63 | default=1, 64 | help="The time expansion factor used for all files (default is 1)", 65 | ) 66 | @click.option( 67 | "--quiet", 68 | is_flag=True, 69 | default=False, 70 | help="Minimize output printing", 71 | ) 72 | @click.option( 73 | "--save_preds_if_empty", 74 | is_flag=True, 75 | default=False, 76 | help="Save empty annotation file if no detections made.", 77 | ) 78 | @click.option( 79 | "--model_path", 80 | type=str, 81 | default=DEFAULT_MODEL_PATH, 82 | help="Path to trained BatDetect2 model", 83 | ) 84 | def detect( 85 | audio_dir: str, 86 | ann_dir: str, 87 | detection_threshold: float, 88 | time_expansion_factor: int, 89 | chunk_size: float, 90 | **args, 91 | ): 92 | """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR. 93 | 94 | DETECTION_THRESHOLD is the detection threshold. All predictions with a 95 | score below this threshold will be discarded. Values between 0 and 1. 96 | 97 | Assumes audio files are mono, not stereo. 98 | 99 | Spaces in the input paths will throw an error. Wrap in quotes. 100 | 101 | Input files should be short in duration e.g. < 30 seconds. 102 | """ 103 | click.echo(f"Loading model: {args['model_path']}") 104 | model, params = api.load_model(args["model_path"]) 105 | 106 | click.echo(f"\nInput directory: {audio_dir}") 107 | files = api.list_audio_files(audio_dir) 108 | 109 | click.echo(f"Number of audio files: {len(files)}") 110 | click.echo(f"\nSaving results to: {ann_dir}") 111 | 112 | config = api.get_config( 113 | **{ 114 | **params, 115 | **args, 116 | "time_expansion": time_expansion_factor, 117 | "spec_slices": False, 118 | "chunk_size": chunk_size, 119 | "detection_threshold": detection_threshold, 120 | } 121 | ) 122 | 123 | if not args["quiet"]: 124 | print_config(config) 125 | 126 | # process files 127 | error_files = [] 128 | for index, audio_file in enumerate(files): 129 | try: 130 | if not args["quiet"]: 131 | click.echo(f"\n{index} {audio_file}") 132 | 133 | results = api.process_file(audio_file, model, config=config) 134 | 135 | if args["save_preds_if_empty"] or ( 136 | len(results["pred_dict"]["annotation"]) > 0 137 | ): 138 | results_path = audio_file.replace(audio_dir, ann_dir) 139 | save_results_to_file(results, results_path) 140 | except (RuntimeError, ValueError, LookupError, EOFError) as err: 141 | error_files.append(audio_file) 142 | click.secho(f"Error processing file {audio_file}: {err}", fg="red") 143 | 144 | click.echo(f"\nResults saved to: {ann_dir}") 145 | 146 | if len(error_files) > 0: 147 | click.secho("\nUnable to process the following files:", fg="red") 148 | for err in error_files: 149 | click.echo(f" {err}") 150 | 151 | 152 | def print_config(config: ProcessingConfiguration): 153 | """Print the processing configuration.""" 154 | click.echo("\nProcessing Configuration:") 155 | click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") 156 | click.echo(f"Detection Threshold: {config.get('detection_threshold')}") 157 | click.echo(f"Chunk Size: {config.get('chunk_size')}s") 158 | 159 | 160 | if __name__ == "__main__": 161 | cli() 162 | -------------------------------------------------------------------------------- /batdetect2/detector/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/batdetect2/detector/__init__.py -------------------------------------------------------------------------------- /batdetect2/detector/compute_features.py: -------------------------------------------------------------------------------- 1 | """Functions to compute features from predictions.""" 2 | from typing import Dict, Optional 3 | 4 | import numpy as np 5 | 6 | from batdetect2 import types 7 | from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ 8 | 9 | 10 | def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq): 11 | """Convert spectrogram index to frequency in Hz.""" 12 | spec_ind = spec_height - spec_ind 13 | return round( 14 | (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 15 | ) 16 | 17 | 18 | def extract_spec_slices(spec, pred_nms): 19 | """Extract spectrogram slices from spectrogram. 20 | 21 | The slices are extracted based on detected call locations. 22 | """ 23 | x_pos = pred_nms["x_pos"] 24 | bb_width = pred_nms["bb_width"] 25 | slices = [] 26 | 27 | # add 20% padding on either side of the call 28 | pad = bb_width * 0.2 29 | x_pos_pad = x_pos - pad 30 | bb_width_pad = bb_width + 2 * pad 31 | 32 | for ff in range(len(pred_nms["det_probs"])): 33 | x_start = int(np.maximum(0, x_pos_pad[ff])) 34 | x_end = int( 35 | np.minimum( 36 | spec.shape[1] - 1, np.round(x_pos_pad[ff] + bb_width_pad[ff]) 37 | ) 38 | ) 39 | slices.append(spec[:, x_start:x_end].astype(np.float16)) 40 | return slices 41 | 42 | 43 | def compute_duration( 44 | prediction: types.Prediction, 45 | **_, 46 | ) -> float: 47 | """Compute duration of call in seconds.""" 48 | return round(prediction["end_time"] - prediction["start_time"], 5) 49 | 50 | 51 | def compute_low_freq( 52 | prediction: types.Prediction, 53 | **_, 54 | ) -> float: 55 | """Compute lowest frequency in call in Hz.""" 56 | return int(prediction["low_freq"]) 57 | 58 | 59 | def compute_high_freq( 60 | prediction: types.Prediction, 61 | **_, 62 | ) -> float: 63 | """Compute highest frequency in call in Hz.""" 64 | return int(prediction["high_freq"]) 65 | 66 | 67 | def compute_bandwidth( 68 | prediction: types.Prediction, 69 | **_, 70 | ) -> float: 71 | """Compute bandwidth of call in Hz.""" 72 | return int(prediction["high_freq"] - prediction["low_freq"]) 73 | 74 | 75 | def compute_max_power_bb( 76 | prediction: types.Prediction, 77 | spec: Optional[np.ndarray] = None, 78 | min_freq: int = MIN_FREQ_HZ, 79 | max_freq: int = MAX_FREQ_HZ, 80 | **_, 81 | ) -> float: 82 | """Compute frequency with maximum power in call in Hz. 83 | 84 | This is the frequency with the maximum power in the bounding box of the 85 | call. 86 | """ 87 | if spec is None: 88 | return np.nan 89 | 90 | x_start = max(0, prediction["x_pos"]) 91 | x_end = min( 92 | spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] 93 | ) 94 | 95 | # y low is the lowest freq but it will have a higher value due to array 96 | # starting at 0 at top 97 | y_low = min(spec.shape[0] - 1, prediction["y_pos"]) 98 | y_high = max(0, prediction["y_pos"] - prediction["bb_height"]) 99 | 100 | spec_bb = spec[y_high:y_low, x_start:x_end] 101 | power_per_freq_band = np.sum(spec_bb, axis=1) 102 | 103 | try: 104 | max_power_ind = np.argmax(power_per_freq_band) 105 | except ValueError: 106 | # If the bounding box is empty (e.g. it has zero height), the power 107 | # array has no entries and argmax fails. In this case, return NaN.
108 | return np.nan 109 | 110 | return int( 111 | convert_int_to_freq( 112 | y_high + max_power_ind, 113 | spec.shape[0], 114 | min_freq, 115 | max_freq, 116 | ) 117 | ) 118 | 119 | 120 | def compute_max_power( 121 | prediction: types.Prediction, 122 | spec: Optional[np.ndarray] = None, 123 | min_freq: int = MIN_FREQ_HZ, 124 | max_freq: int = MAX_FREQ_HZ, 125 | **_, 126 | ) -> float: 127 | """Compute frequency with maximum power during the call in Hz.""" 128 | if spec is None: 129 | return np.nan 130 | 131 | x_start = max(0, prediction["x_pos"]) 132 | x_end = min( 133 | spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] 134 | ) 135 | spec_call = spec[:, x_start:x_end] 136 | power_per_freq_band = np.sum(spec_call, axis=1) 137 | max_power_ind = np.argmax(power_per_freq_band) 138 | return int( 139 | convert_int_to_freq( 140 | max_power_ind, 141 | spec.shape[0], 142 | min_freq, 143 | max_freq, 144 | ) 145 | ) 146 | 147 | 148 | def compute_max_power_first( 149 | prediction: types.Prediction, 150 | spec: Optional[np.ndarray] = None, 151 | min_freq: int = MIN_FREQ_HZ, 152 | max_freq: int = MAX_FREQ_HZ, 153 | **_, 154 | ) -> float: 155 | """Compute frequency with maximum power in first half of call in Hz.""" 156 | if spec is None: 157 | return np.nan 158 | 159 | x_start = max(0, prediction["x_pos"]) 160 | x_end = min( 161 | spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] 162 | ) 163 | spec_call = spec[:, x_start:x_end] 164 | first_half = spec_call[:, : int(spec_call.shape[1] / 2)] 165 | power_per_freq_band = np.sum(first_half, axis=1) 166 | max_power_ind = np.argmax(power_per_freq_band) 167 | return int( 168 | convert_int_to_freq( 169 | max_power_ind, 170 | spec.shape[0], 171 | min_freq, 172 | max_freq, 173 | ) 174 | ) 175 | 176 | 177 | def compute_max_power_second( 178 | prediction: types.Prediction, 179 | spec: Optional[np.ndarray] = None, 180 | min_freq: int = MIN_FREQ_HZ, 181 | max_freq: int = MAX_FREQ_HZ, 182 | **_, 183 | ) -> float: 184 | """Compute frequency with maximum power in second half of call in Hz.""" 185 | if spec is None: 186 | return np.nan 187 | 188 | x_start = max(0, prediction["x_pos"]) 189 | x_end = min( 190 | spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] 191 | ) 192 | spec_call = spec[:, x_start:x_end] 193 | second_half = spec_call[:, int(spec_call.shape[1] / 2) :] 194 | power_per_freq_band = np.sum(second_half, axis=1) 195 | max_power_ind = np.argmax(power_per_freq_band) 196 | return int( 197 | convert_int_to_freq( 198 | max_power_ind, 199 | spec.shape[0], 200 | min_freq, 201 | max_freq, 202 | ) 203 | ) 204 | 205 | 206 | def compute_call_interval( 207 | prediction: types.Prediction, 208 | previous: Optional[types.Prediction] = None, 209 | **_, 210 | ) -> float: 211 | """Compute time between this call and the previous call in seconds.""" 212 | if previous is None: 213 | return np.nan 214 | return round(prediction["start_time"] - previous["end_time"], 5) 215 | 216 | 217 | # NOTE: The order of the features in this dictionary is important. The 218 | # features are extracted in this order and the order of the columns in the 219 | # output csv file is determined by this order. In order to avoid breaking 220 | # changes in the output csv file, new features should be added to the end of 221 | # this dictionary.
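# As an illustration only, a hypothetical new feature would be written with
# the same signature as the extractors above and registered after the
# existing entries, e.g.:
#
#     def compute_mid_time(prediction: types.Prediction, **_) -> float:
#         """Midpoint of the call in seconds (hypothetical feature)."""
#         return round(
#             0.5 * (prediction["start_time"] + prediction["end_time"]), 5
#         )
#
#     FEATURES["mid_time"] = compute_mid_time  # appended last, keeping order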
222 | FEATURES: Dict[str, types.FeatureExtractor] = { 223 | "duration": compute_duration, 224 | "low_freq_bb": compute_low_freq, 225 | "high_freq_bb": compute_high_freq, 226 | "bandwidth": compute_bandwidth, 227 | "max_power_bb": compute_max_power_bb, 228 | "max_power": compute_max_power, 229 | "max_power_first": compute_max_power_first, 230 | "max_power_second": compute_max_power_second, 231 | "call_interval": compute_call_interval, 232 | } 233 | 234 | 235 | def get_feats( 236 | spec: np.ndarray, 237 | pred_nms: types.PredictionResults, 238 | params: types.FeatureExtractionParameters, 239 | ): 240 | """Extract features from spectrogram based on detected call locations. 241 | 242 | The features extracted are: 243 | 244 | - duration: duration of call in seconds 245 | - low_freq_bb: lowest frequency in call in Hz 246 | - high_freq_bb: highest frequency in call in Hz 247 | - bandwidth: high_freq_bb - low_freq_bb 248 | - max_power_bb: frequency with maximum power in call in Hz 249 | - max_power: frequency with maximum power in spectrogram in Hz 250 | - max_power_first: frequency with maximum power in first half of call in 251 | Hz. 252 | - max_power_second: frequency with maximum power in second half of call in 253 | Hz. 254 | - call_interval: time between this call and the previous call in seconds 255 | 256 | Consider re-extracting the spectrogram for this to get better temporal 257 | resolution. 258 | 259 | For more possible features check out: 260 | https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt 261 | 262 | Parameters 263 | ---------- 264 | spec : np.ndarray 265 | Spectrogram from which to extract features. 266 | 267 | pred_nms : types.PredictionResults 268 | Information about detected calls from which to extract features. 269 | 270 | params : types.FeatureExtractionParameters 271 | Parameters for feature extraction. 272 | 273 | Returns 274 | ------- 275 | features : np.ndarray 276 | Extracted features for each detected call. Shape is 277 | (num_detections, num_features).
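    Examples
    --------
    A minimal sketch with hypothetical inputs; it assumes ``spec`` and
    ``pred_nms`` were computed from the same audio file, and that ``params``
    carries the frequency range of the spectrogram in Hz:

    >>> feats = get_feats(spec, pred_nms, {"min_freq": 10000, "max_freq": 120000})
    >>> feats.shape == (len(pred_nms["det_probs"]), len(FEATURES))
    True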
278 | """ 279 | num_detections = len(pred_nms["det_probs"]) 280 | features = np.empty((num_detections, len(FEATURES)), dtype=np.float32) 281 | previous = None 282 | 283 | for row in range(num_detections): 284 | prediction: types.Prediction = { 285 | "det_prob": float(pred_nms["det_probs"][row]), 286 | "class_prob": pred_nms["class_probs"][:, row], 287 | "start_time": float(pred_nms["start_times"][row]), 288 | "end_time": float(pred_nms["end_times"][row]), 289 | "low_freq": float(pred_nms["low_freqs"][row]), 290 | "high_freq": float(pred_nms["high_freqs"][row]), 291 | "x_pos": int(pred_nms["x_pos"][row]), 292 | "y_pos": int(pred_nms["y_pos"][row]), 293 | "bb_width": int(pred_nms["bb_width"][row]), 294 | "bb_height": int(pred_nms["bb_height"][row]), 295 | } 296 | 297 | for col, feature in enumerate(FEATURES.values()): 298 | features[row, col] = feature( 299 | prediction, 300 | previous=previous, 301 | spec=spec, 302 | **params, 303 | ) 304 | 305 | previous = prediction 306 | 307 | return features 308 | 309 | 310 | def get_feature_names(): 311 | """Get names of features in the order they are extracted.""" 312 | return list(FEATURES.keys()) 313 | -------------------------------------------------------------------------------- /batdetect2/detector/model_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | __all__ = [ 6 | "SelfAttention", 7 | "ConvBlockDownCoordF", 8 | "ConvBlockDownStandard", 9 | "ConvBlockUpF", 10 | "ConvBlockUpStandard", 11 | ] 12 | 13 | 14 | class SelfAttention(nn.Module): 15 | def __init__(self, ip_dim, att_dim): 16 | super(SelfAttention, self).__init__() 17 | # Note, does not encode position information (absolute or relative) 18 | self.temperature = 1.0 19 | self.att_dim = att_dim 20 | self.key_fun = nn.Linear(ip_dim, att_dim) 21 | self.val_fun = nn.Linear(ip_dim, att_dim) 22 | self.que_fun = nn.Linear(ip_dim, att_dim) 23 | self.pro_fun = nn.Linear(att_dim, ip_dim) 24 | 25 | def forward(self, x): 26 | x = x.squeeze(2).permute(0, 2, 1) 27 | 28 | kk = torch.matmul( 29 | x, self.key_fun.weight.T 30 | ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0) 31 | qq = torch.matmul( 32 | x, self.que_fun.weight.T 33 | ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0) 34 | vv = torch.matmul( 35 | x, self.val_fun.weight.T 36 | ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0) 37 | 38 | kk_qq = torch.bmm(kk, qq.permute(0, 2, 1)) / ( 39 | self.temperature * self.att_dim 40 | ) 41 | att_weights = F.softmax( 42 | kk_qq, 1 43 | ) # each col of each attention matrix sums to 1 44 | att = torch.bmm(vv.permute(0, 2, 1), att_weights) 45 | 46 | op = torch.matmul( 47 | att.permute(0, 2, 1), self.pro_fun.weight.T 48 | ) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0) 49 | op = op.permute(0, 2, 1).unsqueeze(2) 50 | 51 | return op 52 | 53 | 54 | class ConvBlockDownCoordF(nn.Module): 55 | def __init__( 56 | self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1 57 | ): 58 | super(ConvBlockDownCoordF, self).__init__() 59 | self.coords = nn.Parameter( 60 | torch.linspace(-1, 1, ip_height)[None, None, ..., None], 61 | requires_grad=False, 62 | ) 63 | self.conv = nn.Conv2d( 64 | in_chn + 1, 65 | out_chn, 66 | kernel_size=k_size, 67 | padding=pad_size, 68 | stride=stride, 69 | ) 70 | self.conv_bn = nn.BatchNorm2d(out_chn) 71 | 72 | def forward(self, x): 73 | freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3]) 74 | x = torch.cat((x, freq_info), 1) 75 | x = 
F.max_pool2d(self.conv(x), 2, 2) 76 | x = F.relu(self.conv_bn(x), inplace=True) 77 | return x 78 | 79 | 80 | class ConvBlockDownStandard(nn.Module): 81 | def __init__( 82 | self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1 83 | ): 84 | super(ConvBlockDownStandard, self).__init__() 85 | self.conv = nn.Conv2d( 86 | in_chn, 87 | out_chn, 88 | kernel_size=k_size, 89 | padding=pad_size, 90 | stride=stride, 91 | ) 92 | self.conv_bn = nn.BatchNorm2d(out_chn) 93 | 94 | def forward(self, x): 95 | x = F.max_pool2d(self.conv(x), 2, 2) 96 | x = F.relu(self.conv_bn(x), inplace=True) 97 | return x 98 | 99 | 100 | class ConvBlockUpF(nn.Module): 101 | def __init__( 102 | self, 103 | in_chn, 104 | out_chn, 105 | ip_height, 106 | k_size=3, 107 | pad_size=1, 108 | up_mode="bilinear", 109 | up_scale=(2, 2), 110 | ): 111 | super(ConvBlockUpF, self).__init__() 112 | self.up_scale = up_scale 113 | self.up_mode = up_mode 114 | self.coords = nn.Parameter( 115 | torch.linspace(-1, 1, ip_height * up_scale[0])[ 116 | None, None, ..., None 117 | ], 118 | requires_grad=False, 119 | ) 120 | self.conv = nn.Conv2d( 121 | in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size 122 | ) 123 | self.conv_bn = nn.BatchNorm2d(out_chn) 124 | 125 | def forward(self, x): 126 | op = F.interpolate( 127 | x, 128 | size=( 129 | x.shape[-2] * self.up_scale[0], 130 | x.shape[-1] * self.up_scale[1], 131 | ), 132 | mode=self.up_mode, 133 | align_corners=False, 134 | ) 135 | freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3]) 136 | op = torch.cat((op, freq_info), 1) 137 | op = self.conv(op) 138 | op = F.relu(self.conv_bn(op), inplace=True) 139 | return op 140 | 141 | 142 | class ConvBlockUpStandard(nn.Module): 143 | def __init__( 144 | self, 145 | in_chn, 146 | out_chn, 147 | ip_height=None, 148 | k_size=3, 149 | pad_size=1, 150 | up_mode="bilinear", 151 | up_scale=(2, 2), 152 | ): 153 | super(ConvBlockUpStandard, self).__init__() 154 | self.up_scale = up_scale 155 | self.up_mode = up_mode 156 | self.conv = nn.Conv2d( 157 | in_chn, out_chn, kernel_size=k_size, padding=pad_size 158 | ) 159 | self.conv_bn = nn.BatchNorm2d(out_chn) 160 | 161 | def forward(self, x): 162 | op = F.interpolate( 163 | x, 164 | size=( 165 | x.shape[-2] * self.up_scale[0], 166 | x.shape[-1] * self.up_scale[1], 167 | ), 168 | mode=self.up_mode, 169 | align_corners=False, 170 | ) 171 | op = self.conv(op) 172 | op = F.relu(self.conv_bn(op), inplace=True) 173 | return op 174 | -------------------------------------------------------------------------------- /batdetect2/detector/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.fft 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from batdetect2.detector.model_helpers import ( 7 | ConvBlockDownCoordF, 8 | ConvBlockDownStandard, 9 | ConvBlockUpF, 10 | ConvBlockUpStandard, 11 | SelfAttention, 12 | ) 13 | from batdetect2.types import ModelOutput 14 | 15 | __all__ = [ 16 | "Net2DFast", 17 | "Net2DFastNoAttn", 18 | "Net2DFastNoCoordConv", 19 | ] 20 | 21 | 22 | class Net2DFast(nn.Module): 23 | def __init__( 24 | self, 25 | num_filts, 26 | num_classes=0, 27 | emb_dim=0, 28 | ip_height=128, 29 | resize_factor=0.5, 30 | ): 31 | super().__init__() 32 | self.num_classes = num_classes 33 | self.emb_dim = emb_dim 34 | self.num_filts = num_filts 35 | self.resize_factor = resize_factor 36 | self.ip_height_rs = ip_height 37 | self.bneck_height = self.ip_height_rs // 32 38 | 39 | # encoder 40 | self.conv_dn_0 = 
ConvBlockDownCoordF( 41 | 1, 42 | num_filts // 4, 43 | self.ip_height_rs, 44 | k_size=3, 45 | pad_size=1, 46 | stride=1, 47 | ) 48 | self.conv_dn_1 = ConvBlockDownCoordF( 49 | num_filts // 4, 50 | num_filts // 2, 51 | self.ip_height_rs // 2, 52 | k_size=3, 53 | pad_size=1, 54 | stride=1, 55 | ) 56 | self.conv_dn_2 = ConvBlockDownCoordF( 57 | num_filts // 2, 58 | num_filts, 59 | self.ip_height_rs // 4, 60 | k_size=3, 61 | pad_size=1, 62 | stride=1, 63 | ) 64 | self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1) 65 | self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2) 66 | 67 | # bottleneck 68 | self.conv_1d = nn.Conv2d( 69 | num_filts * 2, 70 | num_filts * 2, 71 | (self.ip_height_rs // 8, 1), 72 | padding=0, 73 | ) 74 | self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2) 75 | self.att = SelfAttention(num_filts * 2, num_filts * 2) 76 | 77 | # decoder 78 | self.conv_up_2 = ConvBlockUpF( 79 | num_filts * 2, num_filts // 2, self.ip_height_rs // 8 80 | ) 81 | self.conv_up_3 = ConvBlockUpF( 82 | num_filts // 2, num_filts // 4, self.ip_height_rs // 4 83 | ) 84 | self.conv_up_4 = ConvBlockUpF( 85 | num_filts // 4, num_filts // 4, self.ip_height_rs // 2 86 | ) 87 | 88 | # output 89 | # +1 to include background class for class output 90 | self.conv_op = nn.Conv2d( 91 | num_filts // 4, num_filts // 4, kernel_size=3, padding=1 92 | ) 93 | self.conv_op_bn = nn.BatchNorm2d(num_filts // 4) 94 | self.conv_size_op = nn.Conv2d( 95 | num_filts // 4, 2, kernel_size=1, padding=0 96 | ) 97 | self.conv_classes_op = nn.Conv2d( 98 | num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0 99 | ) 100 | 101 | if self.emb_dim > 0: 102 | self.conv_emb = nn.Conv2d( 103 | num_filts, self.emb_dim, kernel_size=1, padding=0 104 | ) 105 | 106 | def forward(self, ip, return_feats=False) -> ModelOutput: 107 | # encoder 108 | x1 = self.conv_dn_0(ip) 109 | x2 = self.conv_dn_1(x1) 110 | x3 = self.conv_dn_2(x2) 111 | x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) 112 | 113 | # bottleneck 114 | x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) 115 | x = self.att(x) 116 | x = x.repeat([1, 1, self.bneck_height * 4, 1]) 117 | 118 | # decoder 119 | x = self.conv_up_2(x + x3) 120 | x = self.conv_up_3(x + x2) 121 | x = self.conv_up_4(x + x1) 122 | 123 | # output 124 | x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) 125 | cls = self.conv_classes_op(x) 126 | comb = torch.softmax(cls, 1) 127 | 128 | return ModelOutput( 129 | pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), 130 | pred_size=F.relu(self.conv_size_op(x), inplace=True), 131 | pred_class=comb, 132 | pred_class_un_norm=cls, 133 | features=x, 134 | ) 135 | 136 | 137 | class Net2DFastNoAttn(nn.Module): 138 | def __init__( 139 | self, 140 | num_filts, 141 | num_classes=0, 142 | emb_dim=0, 143 | ip_height=128, 144 | resize_factor=0.5, 145 | ): 146 | super().__init__() 147 | 148 | self.num_classes = num_classes 149 | self.emb_dim = emb_dim 150 | self.num_filts = num_filts 151 | self.resize_factor = resize_factor 152 | self.ip_height_rs = ip_height 153 | self.bneck_height = self.ip_height_rs // 32 154 | 155 | self.conv_dn_0 = ConvBlockDownCoordF( 156 | 1, 157 | num_filts // 4, 158 | self.ip_height_rs, 159 | k_size=3, 160 | pad_size=1, 161 | stride=1, 162 | ) 163 | self.conv_dn_1 = ConvBlockDownCoordF( 164 | num_filts // 4, 165 | num_filts // 2, 166 | self.ip_height_rs // 2, 167 | k_size=3, 168 | pad_size=1, 169 | stride=1, 170 | ) 171 | self.conv_dn_2 = ConvBlockDownCoordF( 172 | num_filts // 2, 173 | num_filts, 174 | 
self.ip_height_rs // 4, 175 | k_size=3, 176 | pad_size=1, 177 | stride=1, 178 | ) 179 | self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1) 180 | self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2) 181 | 182 | self.conv_1d = nn.Conv2d( 183 | num_filts * 2, 184 | num_filts * 2, 185 | (self.ip_height_rs // 8, 1), 186 | padding=0, 187 | ) 188 | self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2) 189 | 190 | self.conv_up_2 = ConvBlockUpF( 191 | num_filts * 2, num_filts // 2, self.ip_height_rs // 8 192 | ) 193 | self.conv_up_3 = ConvBlockUpF( 194 | num_filts // 2, num_filts // 4, self.ip_height_rs // 4 195 | ) 196 | self.conv_up_4 = ConvBlockUpF( 197 | num_filts // 4, num_filts // 4, self.ip_height_rs // 2 198 | ) 199 | 200 | # output 201 | # +1 to include background class for class output 202 | self.conv_op = nn.Conv2d( 203 | num_filts // 4, num_filts // 4, kernel_size=3, padding=1 204 | ) 205 | self.conv_op_bn = nn.BatchNorm2d(num_filts // 4) 206 | self.conv_size_op = nn.Conv2d( 207 | num_filts // 4, 2, kernel_size=1, padding=0 208 | ) 209 | self.conv_classes_op = nn.Conv2d( 210 | num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0 211 | ) 212 | 213 | if self.emb_dim > 0: 214 | self.conv_emb = nn.Conv2d( 215 | num_filts, self.emb_dim, kernel_size=1, padding=0 216 | ) 217 | 218 | def forward(self, ip, return_feats=False) -> ModelOutput: 219 | x1 = self.conv_dn_0(ip) 220 | x2 = self.conv_dn_1(x1) 221 | x3 = self.conv_dn_2(x2) 222 | x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) 223 | 224 | x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) 225 | x = x.repeat([1, 1, self.bneck_height * 4, 1]) 226 | 227 | x = self.conv_up_2(x + x3) 228 | x = self.conv_up_3(x + x2) 229 | x = self.conv_up_4(x + x1) 230 | 231 | x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) 232 | cls = self.conv_classes_op(x) 233 | comb = torch.softmax(cls, 1) 234 | 235 | return ModelOutput( 236 | pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), 237 | pred_size=F.relu(self.conv_size_op(x), inplace=True), 238 | pred_class=comb, 239 | pred_class_un_norm=cls, 240 | features=x, 241 | ) 242 | 243 | 244 | class Net2DFastNoCoordConv(nn.Module): 245 | def __init__( 246 | self, 247 | num_filts, 248 | num_classes=0, 249 | emb_dim=0, 250 | ip_height=128, 251 | resize_factor=0.5, 252 | ): 253 | super().__init__() 254 | 255 | self.num_classes = num_classes 256 | self.emb_dim = emb_dim 257 | self.num_filts = num_filts 258 | self.resize_factor = resize_factor 259 | self.ip_height_rs = ip_height 260 | self.bneck_height = self.ip_height_rs // 32 261 | 262 | self.conv_dn_0 = ConvBlockDownStandard( 263 | 1, 264 | num_filts // 4, 265 | self.ip_height_rs, 266 | k_size=3, 267 | pad_size=1, 268 | stride=1, 269 | ) 270 | self.conv_dn_1 = ConvBlockDownStandard( 271 | num_filts // 4, 272 | num_filts // 2, 273 | self.ip_height_rs // 2, 274 | k_size=3, 275 | pad_size=1, 276 | stride=1, 277 | ) 278 | self.conv_dn_2 = ConvBlockDownStandard( 279 | num_filts // 2, 280 | num_filts, 281 | self.ip_height_rs // 4, 282 | k_size=3, 283 | pad_size=1, 284 | stride=1, 285 | ) 286 | self.conv_dn_3 = nn.Conv2d(num_filts, num_filts * 2, 3, padding=1) 287 | self.conv_dn_3_bn = nn.BatchNorm2d(num_filts * 2) 288 | 289 | self.conv_1d = nn.Conv2d( 290 | num_filts * 2, 291 | num_filts * 2, 292 | (self.ip_height_rs // 8, 1), 293 | padding=0, 294 | ) 295 | self.conv_1d_bn = nn.BatchNorm2d(num_filts * 2) 296 | 297 | self.att = SelfAttention(num_filts * 2, num_filts * 2) 298 | 299 | self.conv_up_2 = ConvBlockUpStandard( 300 | 
num_filts * 2, num_filts // 2, self.ip_height_rs // 8 301 | ) 302 | self.conv_up_3 = ConvBlockUpStandard( 303 | num_filts // 2, num_filts // 4, self.ip_height_rs // 4 304 | ) 305 | self.conv_up_4 = ConvBlockUpStandard( 306 | num_filts // 4, num_filts // 4, self.ip_height_rs // 2 307 | ) 308 | 309 | # output 310 | # +1 to include background class for class output 311 | self.conv_op = nn.Conv2d( 312 | num_filts // 4, num_filts // 4, kernel_size=3, padding=1 313 | ) 314 | self.conv_op_bn = nn.BatchNorm2d(num_filts // 4) 315 | self.conv_size_op = nn.Conv2d( 316 | num_filts // 4, 2, kernel_size=1, padding=0 317 | ) 318 | self.conv_classes_op = nn.Conv2d( 319 | num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0 320 | ) 321 | 322 | if self.emb_dim > 0: 323 | self.conv_emb = nn.Conv2d( 324 | num_filts, self.emb_dim, kernel_size=1, padding=0 325 | ) 326 | 327 | def forward(self, ip, return_feats=False) -> ModelOutput: 328 | x1 = self.conv_dn_0(ip) 329 | x2 = self.conv_dn_1(x1) 330 | x3 = self.conv_dn_2(x2) 331 | x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True) 332 | 333 | x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True) 334 | x = self.att(x) 335 | x = x.repeat([1, 1, self.bneck_height * 4, 1]) 336 | 337 | x = self.conv_up_2(x + x3) 338 | x = self.conv_up_3(x + x2) 339 | x = self.conv_up_4(x + x1) 340 | 341 | x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True) 342 | cls = self.conv_classes_op(x) 343 | comb = torch.softmax(cls, 1) 344 | 345 | pred_emb = self.conv_emb(x) if self.emb_dim > 0 else None  # computed but not yet returned in ModelOutput 346 | 347 | return ModelOutput( 348 | pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), 349 | pred_size=F.relu(self.conv_size_op(x), inplace=True), 350 | pred_class=comb, 351 | pred_class_un_norm=cls, 352 | features=x, 353 | ) 354 | -------------------------------------------------------------------------------- /batdetect2/detector/parameters.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | from batdetect2.types import ProcessingConfiguration, SpectrogramParameters 5 | 6 | TARGET_SAMPLERATE_HZ = 256000 7 | FFT_WIN_LENGTH_S = 512 / 256000.0 8 | FFT_OVERLAP = 0.75 9 | MAX_FREQ_HZ = 120000 10 | MIN_FREQ_HZ = 10000 11 | RESIZE_FACTOR = 0.5 12 | SPEC_DIVIDE_FACTOR = 32 13 | SPEC_HEIGHT = 256 14 | SCALE_RAW_AUDIO = False 15 | DETECTION_THRESHOLD = 0.01 16 | NMS_KERNEL_SIZE = 9 17 | NMS_TOP_K_PER_SEC = 200 18 | SPEC_SCALE = "pcen" 19 | DENOISE_SPEC_AVG = True 20 | MAX_SCALE_SPEC = False 21 | 22 | 23 | DEFAULT_MODEL_PATH = os.path.join( 24 | os.path.dirname(os.path.dirname(__file__)), 25 | "models", 26 | "Net2DFast_UK_same.pth.tar", 27 | ) 28 | 29 | 30 | DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = { 31 | "fft_win_length": FFT_WIN_LENGTH_S, 32 | "fft_overlap": FFT_OVERLAP, 33 | "spec_height": SPEC_HEIGHT, 34 | "resize_factor": RESIZE_FACTOR, 35 | "spec_divide_factor": SPEC_DIVIDE_FACTOR, 36 | "max_freq": MAX_FREQ_HZ, 37 | "min_freq": MIN_FREQ_HZ, 38 | "spec_scale": SPEC_SCALE, 39 | "denoise_spec_avg": DENOISE_SPEC_AVG, 40 | "max_scale_spec": MAX_SCALE_SPEC, 41 | } 42 | 43 | 44 | DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = { 45 | "detection_threshold": DETECTION_THRESHOLD, 46 | "spec_slices": False, 47 | "chunk_size": 3, 48 | "spec_features": False, 49 | "cnn_features": False, 50 | "quiet": True, 51 | "target_samp_rate": TARGET_SAMPLERATE_HZ, 52 | "fft_win_length": FFT_WIN_LENGTH_S, 53 | "fft_overlap": FFT_OVERLAP, 54 | "resize_factor": RESIZE_FACTOR, 55 |
"spec_divide_factor": SPEC_DIVIDE_FACTOR, 56 | "spec_height": SPEC_HEIGHT, 57 | "scale_raw_audio": SCALE_RAW_AUDIO, 58 | "class_names": [], 59 | "time_expansion": 1, 60 | "top_n": 3, 61 | "return_raw_preds": False, 62 | "max_duration": None, 63 | "nms_kernel_size": NMS_KERNEL_SIZE, 64 | "max_freq": MAX_FREQ_HZ, 65 | "min_freq": MIN_FREQ_HZ, 66 | "nms_top_k_per_sec": NMS_TOP_K_PER_SEC, 67 | "spec_scale": SPEC_SCALE, 68 | "denoise_spec_avg": DENOISE_SPEC_AVG, 69 | "max_scale_spec": MAX_SCALE_SPEC, 70 | } 71 | 72 | 73 | def mk_dir(path): 74 | if not os.path.isdir(path): 75 | os.makedirs(path) 76 | 77 | 78 | def get_params(make_dirs=False, exps_dir="../../experiments/"): 79 | params = {} 80 | 81 | params[ 82 | "model_name" 83 | ] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN 84 | params["num_filters"] = 128 85 | 86 | now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S") 87 | model_name = now_str + ".pth.tar" 88 | params["experiment"] = os.path.join(exps_dir, now_str, "") 89 | params["model_file_name"] = os.path.join(params["experiment"], model_name) 90 | params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "") 91 | params["op_im_dir_test"] = os.path.join( 92 | params["experiment"], "op_ims_test", "" 93 | ) 94 | # params['notes'] = '' # can save notes about an experiment here 95 | 96 | # spec parameters 97 | params[ 98 | "target_samp_rate" 99 | ] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate 100 | params[ 101 | "fft_win_length" 102 | ] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step 103 | params["fft_overlap"] = FFT_OVERLAP # stft window overlap 104 | 105 | params[ 106 | "max_freq" 107 | ] = MAX_FREQ_HZ # in Hz, everything above this will be discarded 108 | params[ 109 | "min_freq" 110 | ] = MIN_FREQ_HZ # in Hz, everything below this will be discarded 111 | 112 | params[ 113 | "resize_factor" 114 | ] = RESIZE_FACTOR # resize so the spectrogram at the input of the network 115 | params[ 116 | "spec_height" 117 | ] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed) 118 | params[ 119 | "spec_train_width" 120 | ] = 512 # units are number of time steps (before resizing is performed) 121 | params[ 122 | "spec_divide_factor" 123 | ] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height 124 | 125 | # spec processing params 126 | params[ 127 | "denoise_spec_avg" 128 | ] = DENOISE_SPEC_AVG # removes the mean for each frequency band 129 | params[ 130 | "scale_raw_audio" 131 | ] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1] 132 | params[ 133 | "max_scale_spec" 134 | ] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1 135 | params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none' 136 | 137 | # detection params 138 | params[ 139 | "detection_overlap" 140 | ] = 0.01 # has to be within this number of ms to count as detection 141 | params[ 142 | "ignore_start_end" 143 | ] = 0.01 # if start of GT calls are within this time from the start/end of file ignore 144 | params[ 145 | "detection_threshold" 146 | ] = DETECTION_THRESHOLD # the smaller this is the better the recall will be 147 | params[ 148 | "nms_kernel_size" 149 | ] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression 150 | params[ 151 | "nms_top_k_per_sec" 152 | ] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio 153 | params["target_sigma"] = 2.0 154 | 155 | # augmentation params 156 | params[ 157 | "aug_prob" 158 | ] = 0.20 # 
augmentations will be performed with this probability 159 | params["augment_at_train"] = True 160 | params["augment_at_train_combine"] = True 161 | params[ 162 | "echo_max_delay" 163 | ] = 0.005 # simulate echo by adding copy of raw audio 164 | params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec 165 | params[ 166 | "mask_max_time_perc" 167 | ] = 0.05 # max mask size - here percentage, not ideal 168 | params[ 169 | "mask_max_freq_perc" 170 | ] = 0.10 # max mask size - here percentage, not ideal 171 | params[ 172 | "spec_amp_scaling" 173 | ] = 2.0 # multiply the "volume" by 0:X times current amount 174 | params["aug_sampling_rates"] = [ 175 | 220500, 176 | 256000, 177 | 300000, 178 | 312500, 179 | 384000, 180 | 441000, 181 | 500000, 182 | ] 183 | 184 | # loss params 185 | params["train_loss"] = "focal" # mse or focal 186 | params["det_loss_weight"] = 1.0 # weight for the detection part of the loss 187 | params["size_loss_weight"] = 0.1 # weight for the bbox size loss 188 | params["class_loss_weight"] = 2.0 # weight for the classification loss 189 | params["individual_loss_weight"] = 0.0 # not used 190 | if params["individual_loss_weight"] == 0.0: 191 | params[ 192 | "emb_dim" 193 | ] = 0 # number of dimensions used for individual id embedding 194 | else: 195 | params["emb_dim"] = 3 196 | 197 | # train params 198 | params["lr"] = 0.001 199 | params["batch_size"] = 8 200 | params["num_workers"] = 4 201 | params["num_epochs"] = 200 202 | params["num_eval_epochs"] = 5 # run evaluation every X epochs 203 | params["device"] = "cuda" 204 | params["save_test_image_during_train"] = False 205 | params["save_test_image_after_train"] = True 206 | 207 | params["convert_to_genus"] = False 208 | params["genus_mapping"] = [] 209 | params["class_names"] = [] 210 | params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"] 211 | params["generic_class"] = ["Bat"] 212 | params["events_of_interest"] = [ 213 | "Echolocation" 214 | ] # will ignore all other types of events e.g. social calls 215 | 216 | # the classes in this list are standardized during training so that the same low and high freq are used 217 | params["standardize_classs_names"] = [] 218 | 219 | # create directories 220 | if make_dirs: 221 | print("Model name : " + params["model_name"]) 222 | print("Model file : " + params["model_file_name"]) 223 | print("Experiment : " + params["experiment"]) 224 | 225 | mk_dir(params["experiment"]) 226 | if params["save_test_image_during_train"]: 227 | mk_dir(params["op_im_dir"]) 228 | if params["save_test_image_after_train"]: 229 | mk_dir(params["op_im_dir_test"]) 230 | mk_dir(os.path.dirname(params["model_file_name"])) 231 | 232 | return params 233 | -------------------------------------------------------------------------------- /batdetect2/detector/post_process.py: -------------------------------------------------------------------------------- 1 | """Post-processing of the output of the model.""" 2 | from typing import List, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | from batdetect2.detector.models import ModelOutput 9 | from batdetect2.types import NonMaximumSuppressionConfig, PredictionResults 10 | 11 | np.seterr(divide="ignore", invalid="ignore") 12 | 13 | 14 | def x_coords_to_time( 15 | x_pos: float, 16 | sampling_rate: int, 17 | fft_win_length: float, 18 | fft_overlap: float, 19 | ) -> float: 20 | """Convert x coordinates of spectrogram to time in seconds. 
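    Each spectrogram column advances the audio by ``nfft - noverlap`` samples,
    where ``nfft = fft_win_length * sampling_rate`` and
    ``noverlap = fft_overlap * nfft``, so column ``x_pos`` maps to
    ``((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate`` seconds.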
21 | 22 | Parameters 23 | ---------- 24 | x_pos: X position of the detection in pixels. 25 | sampling_rate: Sampling rate of the audio in Hz. 26 | fft_win_length: Length of the FFT window in seconds. 27 | fft_overlap: Overlap of the FFT windows in seconds. 28 | 29 | Returns 30 | ------- 31 | Time in seconds. 32 | """ 33 | nfft = int(fft_win_length * sampling_rate) 34 | noverlap = int(fft_overlap * nfft) 35 | return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate 36 | 37 | 38 | def overall_class_pred(det_prob, class_prob): 39 | weighted_pred = (class_prob * det_prob).sum(1) 40 | return weighted_pred / weighted_pred.sum() 41 | 42 | 43 | def run_nms( 44 | outputs: ModelOutput, 45 | params: NonMaximumSuppressionConfig, 46 | sampling_rate: np.ndarray, 47 | ) -> Tuple[List[PredictionResults], List[np.ndarray]]: 48 | """Run non-maximum suppression on the output of the model. 49 | 50 | Model outputs processed are expected to have a batch dimension. 51 | Each element of the batch is processed independently. The 52 | result is a pair of lists, one for the predictions and one for 53 | the features. Each element of the lists corresponds to one 54 | element of the batch. 55 | """ 56 | pred_det, pred_size, pred_class, _, features = outputs 57 | 58 | pred_det_nms = non_max_suppression(pred_det, params["nms_kernel_size"]) 59 | freq_rescale = (params["max_freq"] - params["min_freq"]) / pred_det.shape[ 60 | -2 61 | ] 62 | 63 | # NOTE: there will be small differences depending on which sampling rate 64 | # is chosen as we are choosing the same sampling rate for the entire batch 65 | duration = x_coords_to_time( 66 | pred_det.shape[-1], 67 | int(sampling_rate[0].item()), 68 | params["fft_win_length"], 69 | params["fft_overlap"], 70 | ) 71 | top_k = int(duration * params["nms_top_k_per_sec"]) 72 | scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) 73 | 74 | # loop over batch to save outputs 75 | preds: List[PredictionResults] = [] 76 | feats: List[np.ndarray] = [] 77 | for num_detection in range(pred_det_nms.shape[0]): 78 | # get valid indices 79 | inds_ord = torch.argsort(x_pos[num_detection, :]) 80 | valid_inds = ( 81 | scores[num_detection, inds_ord] > params["detection_threshold"] 82 | ) 83 | valid_inds = inds_ord[valid_inds] 84 | 85 | # create result dictionary 86 | pred = {} 87 | pred["det_probs"] = scores[num_detection, valid_inds] 88 | pred["x_pos"] = x_pos[num_detection, valid_inds] 89 | pred["y_pos"] = y_pos[num_detection, valid_inds] 90 | pred["bb_width"] = pred_size[ 91 | num_detection, 92 | 0, 93 | pred["y_pos"], 94 | pred["x_pos"], 95 | ] 96 | pred["bb_height"] = pred_size[ 97 | num_detection, 98 | 1, 99 | pred["y_pos"], 100 | pred["x_pos"], 101 | ] 102 | pred["start_times"] = x_coords_to_time( 103 | pred["x_pos"].float() / params["resize_factor"], 104 | int(sampling_rate[num_detection].item()), 105 | params["fft_win_length"], 106 | params["fft_overlap"], 107 | ) 108 | pred["end_times"] = x_coords_to_time( 109 | (pred["x_pos"].float() + pred["bb_width"]) 110 | / params["resize_factor"], 111 | int(sampling_rate[num_detection].item()), 112 | params["fft_win_length"], 113 | params["fft_overlap"], 114 | ) 115 | pred["low_freqs"] = ( 116 | pred_size[num_detection].shape[1] - pred["y_pos"].float() 117 | ) * freq_rescale + params["min_freq"] 118 | pred["high_freqs"] = ( 119 | pred["low_freqs"] + pred["bb_height"] * freq_rescale 120 | ) 121 | 122 | # extract the per class votes 123 | if pred_class is not None: 124 | pred["class_probs"] = pred_class[ 125 | num_detection, 126 | :, 127 | 
y_pos[num_detection, valid_inds],
128 | x_pos[num_detection, valid_inds],
129 | ]
130 | 
131 | # extract the model features
132 | if features is not None:
133 | feat = features[
134 | num_detection,
135 | :,
136 | y_pos[num_detection, valid_inds],
137 | x_pos[num_detection, valid_inds],
138 | ].transpose(0, 1)
139 | feat = feat.detach().cpu().numpy().astype(np.float32)
140 | feats.append(feat)
141 | 
142 | # convert to numpy
143 | for key, value in pred.items():
144 | pred[key] = value.detach().cpu().numpy().astype(np.float32)
145 | 
146 | preds.append(pred) # type: ignore
147 | 
148 | return preds, feats
149 | 
150 | 
151 | def non_max_suppression(
152 | heat: torch.Tensor,
153 | kernel_size: Union[int, Tuple[int, int]],
154 | ):
155 | # kernel can be an int or list/tuple
156 | if isinstance(kernel_size, int):
157 | kernel_size_h = kernel_size
158 | kernel_size_w = kernel_size
159 | else:
160 | kernel_size_h, kernel_size_w = kernel_size
161 | 
162 | pad_h = (kernel_size_h - 1) // 2
163 | pad_w = (kernel_size_w - 1) // 2
164 | 
165 | hmax = nn.functional.max_pool2d(
166 | heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w)
167 | )
168 | keep = (hmax == heat).float()
169 | 
170 | return heat * keep
171 | 
172 | 
173 | def get_topk_scores(scores, K):
174 | # expects input of size: batch x 1 x height x width
175 | batch, _, height, width = scores.size()
176 | 
177 | topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
178 | topk_inds = topk_inds % (height * width)
179 | topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
180 | topk_xs = (topk_inds % width).long()
181 | 
182 | return topk_scores, topk_ys, topk_xs
183 | -------------------------------------------------------------------------------- /batdetect2/evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/batdetect2/evaluate/__init__.py -------------------------------------------------------------------------------- /batdetect2/evaluate/readme.md: -------------------------------------------------------------------------------- 1 | # Evaluating BatDetect2
2 | 
3 | > **Warning**
4 | > This code is currently broken. Will fix soon, stay tuned.
5 | 
6 | This script evaluates a trained model and outputs several plots summarizing the performance. It is used as follows:
7 | `python evaluate_models.py path_to_store_images/ path_to_audio_files/ path_to_annotation_file/ path_to_trained_model/`
8 | 
9 | e.g.
10 | `python evaluate_models.py ../../plots/results_compare_yuc/ /data1/bat_data/data/yucatan/audio/ /data1/bat_data/annotations/anns_finetune/ ../../experiments/2021_12_17__15_58_43/2021_12_17__15_58_43.pth.tar`
11 | 
12 | By default this will just evaluate the set of test files that are already specified in the model at training time. However, you can also specify a single set of annotations to evaluate using the `--test_file` flag. These must be stored in one annotation file, containing a list of the individual files.
13 | 
14 | e.g.
15 | `python evaluate_models.py ../../plots/results_compare_yuc/ /data1/bat_data/data/yucatan/audio/ /data1/bat_data/annotations/anns_finetune/ ../../experiments/2021_12_17__15_58_43/2021_12_17__15_58_43.pth.tar --test_file yucatan_TEST.json`
16 | 
17 | You can also specify whether the plots are saved as a .png or .pdf using `--file_type` and you can set title text for a plot using `--title_text`, e.g. 
`--file_type pdf --title_text "My Dataset Name"` 18 | 19 | 20 | 21 | 22 | ### Comparing to Tadarida-D 23 | It is also possible to compare to Tadarida-D. For Tadarida-D the following steps are performed: 24 | - Matches Tadarida's detections to manually annotated calls 25 | - Trains a RandomForest classifier using Tadarida call features 26 | - Evaluate the classifier on a held out set 27 | 28 | Uses precomputed binaries for Tadarida-D from: 29 | `https://github.com/YvesBas/Tadarida-D/archive/master.zip` 30 | 31 | Needs to be run using the following arguments: 32 | `./TadaridaD -t 4 -x 1 ip_dir/` 33 | -t 4 means 4 threads 34 | -x 1 means time expansions of 1 35 | 36 | This will generate a folder called `txt` which contains a corresponding `.ta` file for each input audio file. Example usage is as follows: 37 | `python evaluate_models.py ../../plots/results_compare_yuc/ /data1/bat_data/data/yucatan/audio/ /data1/bat_data/annotations/anns_finetune/ ../../experiments/2021_12_17__15_58_43/2021_12_17__15_58_43.pth.tar --td_ip_dir /data1/bat_data/baselines/tadarida_D/` 38 | -------------------------------------------------------------------------------- /batdetect2/finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/batdetect2/finetune/__init__.py -------------------------------------------------------------------------------- /batdetect2/finetune/finetune_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import os 5 | import sys 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.optim.lr_scheduler import CosineAnnealingLR 12 | 13 | import batdetect2.detector.models as models 14 | import batdetect2.detector.parameters as parameters 15 | import batdetect2.detector.post_process as pp 16 | import batdetect2.train.audio_dataloader as adl 17 | import batdetect2.train.evaluate as evl 18 | import batdetect2.train.losses as losses 19 | import batdetect2.train.train_model as tm 20 | import batdetect2.train.train_utils as tu 21 | import batdetect2.utils.detector_utils as du 22 | import batdetect2.utils.plot_utils as pu 23 | 24 | if __name__ == "__main__": 25 | info_str = "\nBatDetect - Finetune Model\n" 26 | 27 | print(info_str) 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | "audio_path", type=str, help="Input directory for audio" 31 | ) 32 | parser.add_argument( 33 | "train_ann_path", 34 | type=str, 35 | help="Path to where train annotation file is stored", 36 | ) 37 | parser.add_argument( 38 | "test_ann_path", 39 | type=str, 40 | help="Path to where test annotation file is stored", 41 | ) 42 | parser.add_argument("model_path", type=str, help="Path to pretrained model") 43 | parser.add_argument( 44 | "--op_model_name", 45 | type=str, 46 | default="", 47 | help="Path and name for finetuned model", 48 | ) 49 | parser.add_argument( 50 | "--num_epochs", 51 | type=int, 52 | default=200, 53 | dest="num_epochs", 54 | help="Number of finetuning epochs", 55 | ) 56 | parser.add_argument( 57 | "--finetune_only_last_layer", 58 | action="store_true", 59 | help="Only train final layers", 60 | ) 61 | parser.add_argument( 62 | "--train_from_scratch", 63 | action="store_true", 64 | help="Do not use pretrained weights", 65 | ) 66 | parser.add_argument( 67 | "--do_not_save_images", 68 
| action="store_false", 69 | help="Do not save images at the end of training", 70 | ) 71 | parser.add_argument( 72 | "--notes", type=str, default="", help="Notes to save in text file" 73 | ) 74 | args = vars(parser.parse_args()) 75 | 76 | params = parameters.get_params(True, "../../experiments/") 77 | if torch.cuda.is_available(): 78 | params["device"] = "cuda" 79 | else: 80 | params["device"] = "cpu" 81 | print( 82 | "\nNote, this will be a lot faster if you use computer with a GPU.\n" 83 | ) 84 | 85 | print("\nAudio directory: " + args["audio_path"]) 86 | print("Train file: " + args["train_ann_path"]) 87 | print("Test file: " + args["test_ann_path"]) 88 | print("Loading model: " + args["model_path"]) 89 | 90 | dataset_name = ( 91 | os.path.basename(args["train_ann_path"]) 92 | .replace(".json", "") 93 | .replace("_TRAIN", "") 94 | ) 95 | 96 | if args["train_from_scratch"]: 97 | print("\nTraining model from scratch i.e. not using pretrained weights") 98 | model, params_train = du.load_model(args["model_path"], False) 99 | else: 100 | model, params_train = du.load_model(args["model_path"], True) 101 | model.to(params["device"]) 102 | 103 | params["num_epochs"] = args["num_epochs"] 104 | if args["op_model_name"] != "": 105 | params["model_file_name"] = args["op_model_name"] 106 | classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] 107 | 108 | # save notes file 109 | params["notes"] = args["notes"] 110 | if args["notes"] != "": 111 | tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"]) 112 | 113 | # load train annotations 114 | train_sets = [] 115 | train_sets.append( 116 | tu.get_blank_dataset_dict( 117 | dataset_name, False, args["train_ann_path"], args["audio_path"] 118 | ) 119 | ) 120 | params["train_sets"] = [ 121 | tu.get_blank_dataset_dict( 122 | dataset_name, 123 | False, 124 | os.path.basename(args["train_ann_path"]), 125 | args["audio_path"], 126 | ) 127 | ] 128 | 129 | print("\nTrain set:") 130 | ( 131 | data_train, 132 | params["class_names"], 133 | params["class_inv_freq"], 134 | ) = tu.load_set_of_anns( 135 | train_sets, classes_to_ignore, params["events_of_interest"] 136 | ) 137 | print("Number of files", len(data_train)) 138 | 139 | params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping( 140 | params["class_names"] 141 | ) 142 | params["class_names_short"] = tu.get_short_class_names( 143 | params["class_names"] 144 | ) 145 | 146 | # load test annotations 147 | test_sets = [] 148 | test_sets.append( 149 | tu.get_blank_dataset_dict( 150 | dataset_name, True, args["test_ann_path"], args["audio_path"] 151 | ) 152 | ) 153 | params["test_sets"] = [ 154 | tu.get_blank_dataset_dict( 155 | dataset_name, 156 | True, 157 | os.path.basename(args["test_ann_path"]), 158 | args["audio_path"], 159 | ) 160 | ] 161 | 162 | print("\nTest set:") 163 | data_test, _, _ = tu.load_set_of_anns( 164 | test_sets, classes_to_ignore, params["events_of_interest"] 165 | ) 166 | print("Number of files", len(data_test)) 167 | 168 | # train loader 169 | train_dataset = adl.AudioLoader(data_train, params, is_train=True) 170 | train_loader = torch.utils.data.DataLoader( 171 | train_dataset, 172 | batch_size=params["batch_size"], 173 | shuffle=True, 174 | num_workers=params["num_workers"], 175 | pin_memory=True, 176 | ) 177 | 178 | # test loader - batch size of one because of variable file length 179 | test_dataset = adl.AudioLoader(data_test, params, is_train=False) 180 | test_loader = torch.utils.data.DataLoader( 181 | test_dataset, 182 | batch_size=1, 183 
| shuffle=False, 184 | num_workers=params["num_workers"], 185 | pin_memory=True, 186 | ) 187 | 188 | inputs_train = next(iter(train_loader)) 189 | params["ip_height"] = inputs_train["spec"].shape[2] 190 | print("\ntrain batch size :", inputs_train["spec"].shape) 191 | 192 | assert params_train["model_name"] == "Net2DFast" 193 | print( 194 | "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n" 195 | ) 196 | 197 | # set the number of output classes 198 | num_filts = model.conv_classes_op.in_channels 199 | k_size = model.conv_classes_op.kernel_size 200 | pad = model.conv_classes_op.padding 201 | model.conv_classes_op = torch.nn.Conv2d( 202 | num_filts, 203 | len(params["class_names"]) + 1, 204 | kernel_size=k_size, 205 | padding=pad, 206 | ) 207 | model.conv_classes_op.to(params["device"]) 208 | 209 | if args["finetune_only_last_layer"]: 210 | print("\nOnly finetuning the final layers.\n") 211 | train_layers_i = [ 212 | "conv_classes", 213 | "conv_classes_op", 214 | "conv_size", 215 | "conv_size_op", 216 | ] 217 | train_layers = [tt + ".weight" for tt in train_layers_i] + [ 218 | tt + ".bias" for tt in train_layers_i 219 | ] 220 | for name, param in model.named_parameters(): 221 | if name in train_layers: 222 | param.requires_grad = True 223 | else: 224 | param.requires_grad = False 225 | 226 | optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"]) 227 | scheduler = CosineAnnealingLR( 228 | optimizer, params["num_epochs"] * len(train_loader) 229 | ) 230 | if params["train_loss"] == "mse": 231 | det_criterion = losses.mse_loss 232 | elif params["train_loss"] == "focal": 233 | det_criterion = losses.focal_loss 234 | 235 | # plotting 236 | train_plt_ls = pu.LossPlotter( 237 | params["experiment"] + "train_loss.png", 238 | params["num_epochs"] + 1, 239 | ["train_loss"], 240 | None, 241 | None, 242 | ["epoch", "train_loss"], 243 | logy=True, 244 | ) 245 | test_plt_ls = pu.LossPlotter( 246 | params["experiment"] + "test_loss.png", 247 | params["num_epochs"] + 1, 248 | ["test_loss"], 249 | None, 250 | None, 251 | ["epoch", "test_loss"], 252 | logy=True, 253 | ) 254 | test_plt = pu.LossPlotter( 255 | params["experiment"] + "test.png", 256 | params["num_epochs"] + 1, 257 | ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"], 258 | [0, 1], 259 | None, 260 | ["epoch", ""], 261 | ) 262 | test_plt_class = pu.LossPlotter( 263 | params["experiment"] + "test_avg_prec.png", 264 | params["num_epochs"] + 1, 265 | params["class_names_short"], 266 | [0, 1], 267 | params["class_names_short"], 268 | ["epoch", "avg_prec"], 269 | ) 270 | 271 | # main train loop 272 | for epoch in range(0, params["num_epochs"] + 1): 273 | train_loss = tm.train( 274 | model, 275 | epoch, 276 | train_loader, 277 | det_criterion, 278 | optimizer, 279 | scheduler, 280 | params, 281 | ) 282 | train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]]) 283 | 284 | if epoch % params["num_eval_epochs"] == 0: 285 | # detection accuracy on test set 286 | test_res, test_loss = tm.test( 287 | model, epoch, test_loader, det_criterion, params 288 | ) 289 | test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]]) 290 | test_plt.update_and_save( 291 | epoch, 292 | [ 293 | test_res["avg_prec"], 294 | test_res["rec_at_x"], 295 | test_res["avg_prec_class"], 296 | test_res["file_acc"], 297 | test_res["top_class"]["avg_prec"], 298 | ], 299 | ) 300 | test_plt_class.update_and_save( 301 | epoch, [rs["avg_prec"] for rs in test_res["class_pr"]] 302 | ) 303 | 
pu.plot_pr_curve_class( 304 | params["experiment"], "test_pr", "test_pr", test_res 305 | ) 306 | 307 | # save finetuned model 308 | print("saving model to: " + params["model_file_name"]) 309 | op_state = { 310 | "epoch": epoch + 1, 311 | "state_dict": model.state_dict(), 312 | "params": params, 313 | } 314 | torch.save(op_state, params["model_file_name"]) 315 | 316 | # save an image with associated prediction for each batch in the test set 317 | if not args["do_not_save_images"]: 318 | tm.save_images_batch(model, test_loader, params) 319 | -------------------------------------------------------------------------------- /batdetect2/finetune/prep_data_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | 7 | import batdetect2.train.train_utils as tu 8 | 9 | 10 | def print_dataset_stats(data, split_name, classes_to_ignore): 11 | print("\nSplit:", split_name) 12 | print("Num files:", len(data)) 13 | 14 | class_cnts = {} 15 | for dd in data: 16 | for aa in dd["annotation"]: 17 | if aa["class"] not in classes_to_ignore: 18 | if aa["class"] in class_cnts: 19 | class_cnts[aa["class"]] += 1 20 | else: 21 | class_cnts[aa["class"]] = 1 22 | 23 | if len(class_cnts) == 0: 24 | class_names = [] 25 | else: 26 | class_names = np.sort([*class_cnts]).tolist() 27 | print("Class count:") 28 | str_len = np.max([len(cc) for cc in class_names]) + 5 29 | 30 | for ii, cc in enumerate(class_names): 31 | print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc])) 32 | 33 | return class_names 34 | 35 | 36 | def load_file_names(file_name): 37 | if os.path.isfile(file_name): 38 | with open(file_name) as da: 39 | files = [line.rstrip() for line in da.readlines()] 40 | for ff in files: 41 | if ff.lower()[-3:] != "wav": 42 | print("Error: Filenames need to end in .wav - ", ff) 43 | assert False 44 | else: 45 | print("Error: Input file not found - ", file_name) 46 | assert False 47 | 48 | return files 49 | 50 | 51 | if __name__ == "__main__": 52 | info_str = "\nBatDetect - Prepare Data for Finetuning\n" 53 | 54 | print(info_str) 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument( 57 | "dataset_name", type=str, help="Name to call your dataset" 58 | ) 59 | parser.add_argument("audio_dir", type=str, help="Input directory for audio") 60 | parser.add_argument( 61 | "ann_dir", 62 | type=str, 63 | help="Input directory for where the audio annotations are stored", 64 | ) 65 | parser.add_argument( 66 | "op_dir", 67 | type=str, 68 | help="Path where the train and test splits will be stored", 69 | ) 70 | parser.add_argument( 71 | "--percent_val", 72 | type=float, 73 | default=0.20, 74 | help="Hold out this much data for validation. Should be number between 0 and 1", 75 | ) 76 | parser.add_argument( 77 | "--rand_seed", 78 | type=int, 79 | default=2001, 80 | help="Random seed used for creating the validation split", 81 | ) 82 | parser.add_argument( 83 | "--train_file", 84 | type=str, 85 | default="", 86 | help="Text file where each line is a wav file in train split", 87 | ) 88 | parser.add_argument( 89 | "--test_file", 90 | type=str, 91 | default="", 92 | help="Text file where each line is a wav file in test split", 93 | ) 94 | parser.add_argument( 95 | "--input_class_names", 96 | type=str, 97 | default="", 98 | help='Specify names of classes that you want to change. 
Separate with ";"', 99 | ) 100 | parser.add_argument( 101 | "--output_class_names", 102 | type=str, 103 | default="", 104 | help='New class names to use instead. One to one mapping with "--input_class_names". \ 105 | Separate with ";"', 106 | ) 107 | args = vars(parser.parse_args()) 108 | 109 | np.random.seed(args["rand_seed"]) 110 | 111 | classes_to_ignore = ["", " ", "Unknown", "Not Bat"] 112 | generic_class = ["Bat"] 113 | events_of_interest = ["Echolocation"] 114 | 115 | if args["input_class_names"] != "" and args["output_class_names"] != "": 116 | # change the names of the classes 117 | ip_names = args["input_class_names"].split(";") 118 | op_names = args["output_class_names"].split(";") 119 | name_dict = dict(zip(ip_names, op_names)) 120 | else: 121 | name_dict = False 122 | 123 | # load annotations 124 | data_all, _, _ = tu.load_set_of_anns( 125 | {"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]}, 126 | classes_to_ignore, 127 | events_of_interest, 128 | False, 129 | False, 130 | list_of_anns=True, 131 | filter_issues=True, 132 | name_replace=name_dict, 133 | ) 134 | 135 | print("Dataset name: " + args["dataset_name"]) 136 | print("Audio directory: " + args["audio_dir"]) 137 | print("Annotation directory: " + args["ann_dir"]) 138 | print("Ouput directory: " + args["op_dir"]) 139 | print("Num annotated files: " + str(len(data_all))) 140 | 141 | if args["train_file"] != "" and args["test_file"] != "": 142 | # user has specifed the train / test split 143 | train_files = load_file_names(args["train_file"]) 144 | test_files = load_file_names(args["test_file"]) 145 | file_names_all = [dd["id"] for dd in data_all] 146 | train_inds = [ 147 | file_names_all.index(ff) 148 | for ff in train_files 149 | if ff in file_names_all 150 | ] 151 | test_inds = [ 152 | file_names_all.index(ff) 153 | for ff in test_files 154 | if ff in file_names_all 155 | ] 156 | 157 | else: 158 | # split the data into train and test at the file level 159 | num_exs = len(data_all) 160 | test_inds = np.random.choice( 161 | np.arange(num_exs), 162 | int(num_exs * args["percent_val"]), 163 | replace=False, 164 | ) 165 | test_inds = np.sort(test_inds) 166 | train_inds = np.setdiff1d(np.arange(num_exs), test_inds) 167 | 168 | data_train = [data_all[ii] for ii in train_inds] 169 | data_test = [data_all[ii] for ii in test_inds] 170 | 171 | if not os.path.isdir(args["op_dir"]): 172 | os.makedirs(args["op_dir"]) 173 | op_name = os.path.join(args["op_dir"], args["dataset_name"]) 174 | op_name_train = op_name + "_TRAIN.json" 175 | op_name_test = op_name + "_TEST.json" 176 | 177 | class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore) 178 | class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore) 179 | 180 | if len(data_train) > 0 and len(data_test) > 0: 181 | if class_un_train != class_un_test: 182 | print( 183 | '\nError: some classes are not in both the training and test sets.\ 184 | \nTry a different random seed "--rand_seed".' 
185 | )
186 | assert False
187 | 
188 | print("\n")
189 | if len(data_train) == 0:
190 | print("No train annotations to save")
191 | else:
192 | print("Saving: ", op_name_train)
193 | with open(op_name_train, "w") as da:
194 | json.dump(data_train, da, indent=2)
195 | 
196 | if len(data_test) == 0:
197 | print("No test annotations to save")
198 | else:
199 | print("Saving: ", op_name_test)
200 | with open(op_name_test, "w") as da:
201 | json.dump(data_test, da, indent=2)
202 | -------------------------------------------------------------------------------- /batdetect2/finetune/readme.md: -------------------------------------------------------------------------------- 1 | # Finetuning the BatDetect2 model on your own data
2 | 
3 | > **Warning**
4 | > This code is currently broken. Will fix soon, stay tuned.
5 | 
6 | Main steps:
7 | 1. Annotate your data using the annotation GUI.
8 | 2. Run `prep_data_finetune.py` to create a training and validation split for your data.
9 | 3. Run `finetune_model.py` to finetune a model on your data.
10 | 
11 | 
12 | ## 1. Annotate calls of interest in audio data
13 | Use the annotation tools provided [here](https://github.com/macaodha/batdetect2_GUI) to manually identify where the events of interest (e.g. bat echolocation calls) are in your files.
14 | This will result in a directory of audio files and a directory of annotation files, where each audio file will have a corresponding `.json` annotation file.
15 | Make sure to annotate all instances of a bat call.
16 | If unsure of the species, just label the call as `Bat`.
17 | 
18 | 
19 | ## 2. Split data into train and test sets
20 | After performing the previous step you should have a directory of annotation files saved as JSONs, one for each audio file you have annotated.
21 | * The next step is to split these into training and testing subsets.
22 | Run `prep_data_finetune.py` to split the data into train and test sets. This will result in two separate files, a train and a test one, i.e.
23 | `python prep_data_finetune.py dataset_name path_to_audio/ path_to_annotations/ path_to_output_anns/`
24 | This may result in an error if it does not generate output files containing the same set of species in the train and test splits. You can try different random seeds if this is an issue e.g. `--rand_seed 123456`.
25 | 
26 | * You can also load the train and test split using text files, where each line of the text file is the name of a `wav` file (without the file path) e.g.
27 | `python prep_data_finetune.py dataset_name path_to_audio/ path_to_annotations/ path_to_output/ --train_file path_to_file/list_of_train_files.txt --test_file path_to_file/list_of_test_files.txt`
28 | 
29 | 
30 | * You can also replace class names. This can be helpful if you don't think you have enough calls/files for a given species. Use semi-colons to separate, without spaces between them e.g.
31 | `python prep_data_finetune.py dataset_name path_to_audio/audio/ path_to_annotations/anns/ path_to_output/ --input_class_names "Histiotus;Molossidae;Lasiurus;Myotis;Rhogeesa;Vespertilionidae" --output_class_names "Group One;Group One;Group One;Group Two;Group Two;Group Three"`
32 | 
33 | 
34 | ## 3. Finetuning the model
35 | Finally, you can finetune the model using your data i.e.
36 | `python finetune_model.py path_to_audio/ path_to_train/TRAIN.json path_to_train/TEST.json ../../models/Net2DFast_UK_same.pth.tar`
37 | Here, `TRAIN.json` and `TEST.json` are the splits created in the previous steps.
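Optional flags let you adjust the run, e.g. the number of epochs, training only the final layers, or where the finetuned model is written (all defined in `finetune_model.py`; the paths below are placeholders):
`python finetune_model.py path_to_audio/ path_to_train/TRAIN.json path_to_train/TEST.json ../../models/Net2DFast_UK_same.pth.tar --num_epochs 100 --finetune_only_last_layer --op_model_name path_to_models/my_finetuned_model.pth.tar`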
38 | 39 | 40 | #### Additional notes 41 | * For the first step it is better to cut the files into less than 5 second audio clips and make sure to annotate them exhaustively (i.e. all bat calls should be annotated). 42 | * You can train the model for longer, by setting the `--num_epochs` flag to a larger number e.g. `--num_epochs 400`. The default is `200`. 43 | * If you do not want to finetune the model, but instead want to train it from scratch, you can set the `--train_from_scratch` flag. 44 | -------------------------------------------------------------------------------- /batdetect2/models/Net2DFast_UK_same.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/batdetect2/models/Net2DFast_UK_same.pth.tar -------------------------------------------------------------------------------- /batdetect2/models/readme.md: -------------------------------------------------------------------------------- 1 | Trained models go here. 2 | -------------------------------------------------------------------------------- /batdetect2/plot.py: -------------------------------------------------------------------------------- 1 | """Plot functions to visualize detections and spectrograms.""" 2 | 3 | from typing import List, Optional, Tuple, Union, cast 4 | 5 | import numpy as np 6 | import torch 7 | from matplotlib import axes, patches 8 | import matplotlib.ticker as tick 9 | from matplotlib import pyplot as plt 10 | 11 | from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS 12 | from batdetect2.types import ( 13 | Annotation, 14 | ProcessingConfiguration, 15 | SpectrogramParameters, 16 | ) 17 | 18 | __all__ = [ 19 | "spectrogram_with_detections", 20 | "detection", 21 | "detections", 22 | "spectrogram", 23 | ] 24 | 25 | 26 | def spectrogram( 27 | spec: Union[torch.Tensor, np.ndarray], 28 | config: Optional[ProcessingConfiguration] = None, 29 | ax: Optional[axes.Axes] = None, 30 | figsize: Optional[Tuple[int, int]] = None, 31 | cmap: str = "plasma", 32 | start_time: float = 0, 33 | ) -> axes.Axes: 34 | """Plot a spectrogram. 35 | 36 | Parameters 37 | ---------- 38 | spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. 39 | config (Optional[ProcessingConfiguration], optional): Configuration 40 | used to compute the spectrogram. Defaults to None. If None, 41 | the default configuration will be used. 42 | ax (Optional[axes.Axes], optional): Matplotlib axes object. 43 | Defaults to None. if provided, the spectrogram will be plotted 44 | on this axes. 45 | figsize (Optional[Tuple[int, int]], optional): Figure size. 46 | Defaults to None. If `ax` is None, this will be used to create 47 | a new figure of the given size. 48 | cmap (str, optional): Colormap to use. Defaults to "plasma". 49 | start_time (float, optional): Start time of the spectrogram. 50 | Defaults to 0. This is useful if plotting a spectrogram 51 | of a segment of a longer audio file. 52 | 53 | Returns 54 | ------- 55 | axes.Axes: Matplotlib axes object. 
56 | 57 | Raises 58 | ------ 59 | ValueError: If the spectrogram is not of 60 | shape (1, T, F), (1, 1, T, F) or (T, F) 61 | """ 62 | # Convert to numpy array if needed 63 | if isinstance(spec, torch.Tensor): 64 | spec = spec.detach().cpu().numpy() 65 | 66 | # Remove batch and channel dimensions if present 67 | spec = spec.squeeze() 68 | 69 | if spec.ndim != 2: 70 | raise ValueError( 71 | f"Expected a 2D tensor, got {spec.ndim}D tensor instead." 72 | ) 73 | 74 | # Get config 75 | if config is None: 76 | config = DEFAULT_PROCESSING_CONFIGURATIONS.copy() 77 | 78 | # Frequency axis is reversed 79 | spec = spec[::-1, :] 80 | 81 | if ax is None: 82 | # Using cast to fix typing. pyplot subplots is not 83 | # correctly typed. 84 | ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1]) 85 | 86 | # compute extent 87 | extent = _compute_spec_extent(spec.shape, config) 88 | 89 | # add start time 90 | extent = (extent[0] + start_time, extent[1] + start_time, *extent[2:]) 91 | 92 | ax.imshow(spec, aspect="auto", origin="lower", cmap=cmap, extent=extent) 93 | 94 | ax.set_xlabel("Time (s)") 95 | ax.set_ylabel("Frequency (kHz)") 96 | 97 | def y_fmt(x, _): 98 | return f"{x / 1000:.0f}" 99 | 100 | ax.yaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) 101 | 102 | return ax 103 | 104 | 105 | def spectrogram_with_detections( 106 | spec: Union[torch.Tensor, np.ndarray], 107 | dets: List[Annotation], 108 | config: Optional[ProcessingConfiguration] = None, 109 | ax: Optional[axes.Axes] = None, 110 | figsize: Optional[Tuple[int, int]] = None, 111 | cmap: str = "plasma", 112 | with_names: bool = True, 113 | start_time: float = 0, 114 | **kwargs, 115 | ) -> axes.Axes: 116 | """Plot a spectrogram with detections. 117 | 118 | Parameters 119 | ---------- 120 | spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. 121 | detections (List[Annotation]): List of detections. 122 | config (Optional[ProcessingConfiguration], optional): Configuration 123 | used to compute the spectrogram. Defaults to None. If None, 124 | the default configuration will be used. 125 | ax (Optional[axes.Axes], optional): Matplotlib axes object. 126 | Defaults to None. if provided, the spectrogram will be plotted 127 | on this axes. 128 | figsize (Optional[Tuple[int, int]], optional): Figure size. 129 | Defaults to None. If `ax` is None, this will be used to create 130 | a new figure of the given size. 131 | cmap (str, optional): Colormap to use. Defaults to "plasma". 132 | with_names (bool, optional): Whether to plot the name of the 133 | predicted class next to the detection. Defaults to True. 134 | start_time (float, optional): Start time of the spectrogram. 135 | Defaults to 0. This is useful if plotting a spectrogram 136 | of a segment of a longer audio file. 137 | **kwargs: Additional keyword arguments to pass to the 138 | `plot.detections` function. 139 | 140 | Returns 141 | ------- 142 | axes.Axes: Matplotlib axes object. 143 | 144 | Raises 145 | ------ 146 | ValueError: If the spectrogram is not of shape (1, F, T), 147 | (1, 1, F, T) or (F, T). 
148 | """ 149 | ax = spectrogram( 150 | spec, 151 | start_time=start_time, 152 | config=config, 153 | cmap=cmap, 154 | ax=ax, 155 | figsize=figsize, 156 | ) 157 | 158 | ax = detections( 159 | dets, 160 | ax=ax, 161 | figsize=figsize, 162 | with_names=with_names, 163 | **kwargs, 164 | ) 165 | 166 | return ax 167 | 168 | 169 | def detections( 170 | dets: List[Annotation], 171 | ax: Optional[axes.Axes] = None, 172 | figsize: Optional[Tuple[int, int]] = None, 173 | with_names: bool = True, 174 | **kwargs, 175 | ) -> axes.Axes: 176 | """Plot a list of detections. 177 | 178 | Parameters 179 | ---------- 180 | dets (List[Annotation]): List of detections. 181 | ax (Optional[axes.Axes], optional): Matplotlib axes object. 182 | Defaults to None. if provided, the spectrogram will be plotted 183 | on this axes. 184 | figsize (Optional[Tuple[int, int]], optional): Figure size. 185 | Defaults to None. If `ax` is None, this will be used to create 186 | a new figure of the given size. 187 | with_names (bool, optional): Whether to plot the name of the 188 | predicted class next to the detection. Defaults to True. 189 | **kwargs: Additional keyword arguments to pass to the 190 | `plot.detection` function. 191 | 192 | Returns 193 | ------- 194 | axes.Axes: Matplotlib axes object on which the detections 195 | were plotted. 196 | """ 197 | if ax is None: 198 | # Using cast to fix typing. pyplot subplots is not 199 | # correctly typed. 200 | ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1]) 201 | 202 | for det in dets: 203 | ax = detection( 204 | det, 205 | ax=ax, 206 | figsize=figsize, 207 | with_name=with_names, 208 | **kwargs, 209 | ) 210 | 211 | return ax 212 | 213 | 214 | def detection( 215 | det: Annotation, 216 | ax: Optional[axes.Axes] = None, 217 | figsize: Optional[Tuple[int, int]] = None, 218 | linewidth: float = 1, 219 | edgecolor: str = "w", 220 | facecolor: str = "none", 221 | with_name: bool = True, 222 | ) -> axes.Axes: 223 | """Plot a single detection. 224 | 225 | Parameters 226 | ---------- 227 | det (Annotation): Detection to plot. 228 | ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults 229 | to None. If provided, the spectrogram will be plotted on this axes. 230 | figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults 231 | to None. If `ax` is None, this will be used to create a new figure 232 | of the given size. 233 | linewidth (float, optional): Line width of the detection. 234 | Defaults to 1. 235 | edgecolor (str, optional): Edge color of the detection. 236 | Defaults to "w", i.e. white. 237 | facecolor (str, optional): Face color of the detection. 238 | Defaults to "none", i.e. transparent. 239 | with_name (bool, optional): Whether to plot the name of the 240 | predicted class next to the detection. Defaults to True. 241 | 242 | Returns 243 | ------- 244 | axes.Axes: Matplotlib axes object on which the detection 245 | was plotted. 246 | """ 247 | if ax is None: 248 | # Using cast to fix typing. pyplot subplots is not 249 | # correctly typed. 
250 | ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1]) 251 | 252 | # Plot detection 253 | rect = patches.Rectangle( 254 | (det["start_time"], det["low_freq"]), 255 | det["end_time"] - det["start_time"], 256 | det["high_freq"] - det["low_freq"], 257 | linewidth=linewidth, 258 | edgecolor=edgecolor, 259 | facecolor=facecolor, 260 | alpha=det.get("det_prob", 1), 261 | ) 262 | ax.add_patch(rect) 263 | 264 | if with_name: 265 | # Add class label 266 | txt = " ".join([sp[:3] for sp in det["class"].split(" ")]) 267 | font_info = { 268 | "color": edgecolor, 269 | "size": 10, 270 | "weight": "bold", 271 | "alpha": rect.get_alpha(), 272 | } 273 | y_pos = rect.get_xy()[1] + rect.get_height() 274 | ax.text(rect.get_xy()[0], y_pos, txt, fontdict=font_info) 275 | 276 | return ax 277 | 278 | 279 | def _compute_spec_extent( 280 | shape: Tuple[int, int], 281 | params: SpectrogramParameters, 282 | ) -> Tuple[float, float, float, float]: 283 | """Compute the extent of a spectrogram. 284 | 285 | Parameters 286 | ---------- 287 | shape (Tuple[int, int]): Shape of the spectrogram. 288 | The first dimension is the frequency axis and the second 289 | dimension is the time axis. 290 | params (SpectrogramParameters): Spectrogram parameters. 291 | Should be the same as the ones used to compute the spectrogram. 292 | 293 | Returns 294 | ------- 295 | Tuple[float, float, float, float]: Extent of the spectrogram. 296 | The first two values are the minimum and maximum time values, 297 | the last two values are the minimum and maximum frequency values. 298 | """ 299 | fft_win_length = params["fft_win_length"] 300 | fft_overlap = params["fft_overlap"] 301 | max_freq = params["max_freq"] 302 | min_freq = params["min_freq"] 303 | 304 | # compute duration based on spectrogram parameters 305 | duration = (shape[1] + 1) * (fft_win_length * (1 - fft_overlap)) 306 | 307 | # If the spectrogram is not resized, the duration is correct 308 | # but if it is resized, the duration needs to be adjusted 309 | resize_factor = params["resize_factor"] 310 | spec_height = params["spec_height"] 311 | if spec_height * resize_factor == shape[0]: 312 | duration = duration / resize_factor 313 | 314 | return 0, duration, min_freq, max_freq 315 | -------------------------------------------------------------------------------- /batdetect2/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/batdetect2/train/__init__.py -------------------------------------------------------------------------------- /batdetect2/train/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def bbox_size_loss(pred_size, gt_size): 6 | """ 7 | Bounding box size loss. Only compute loss where there is a bounding box. 
8 | """ 9 | gt_size_mask = (gt_size > 0).float() 10 | return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / ( 11 | gt_size_mask.sum() + 1e-5 12 | ) 13 | 14 | 15 | def focal_loss(pred, gt, weights=None, valid_mask=None): 16 | """ 17 | Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints 18 | pred (batch x c x h x w) 19 | gt (batch x c x h x w) 20 | """ 21 | eps = 1e-5 22 | beta = 4 23 | alpha = 2 24 | 25 | pos_inds = gt.eq(1).float() 26 | neg_inds = gt.lt(1).float() 27 | 28 | pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds 29 | neg_loss = ( 30 | torch.log(1 - pred + eps) 31 | * torch.pow(pred, alpha) 32 | * torch.pow(1 - gt, beta) 33 | * neg_inds 34 | ) 35 | 36 | if weights is not None: 37 | pos_loss = pos_loss * weights 38 | # neg_loss = neg_loss*weights 39 | 40 | if valid_mask is not None: 41 | pos_loss = pos_loss * valid_mask 42 | neg_loss = neg_loss * valid_mask 43 | 44 | pos_loss = pos_loss.sum() 45 | neg_loss = neg_loss.sum() 46 | 47 | num_pos = pos_inds.float().sum() 48 | if num_pos == 0: 49 | loss = -neg_loss 50 | else: 51 | loss = -(pos_loss + neg_loss) / num_pos 52 | return loss 53 | 54 | 55 | def mse_loss(pred, gt, weights=None, valid_mask=None): 56 | """ 57 | Mean squared error loss. 58 | """ 59 | if valid_mask is None: 60 | op = ((gt - pred) ** 2).mean() 61 | else: 62 | op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum() 63 | return op 64 | -------------------------------------------------------------------------------- /batdetect2/train/readme.md: -------------------------------------------------------------------------------- 1 | ## How to train a model from scratch 2 | `python train_model.py data_dir annotation_dir` e.g. 3 | `python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/` 4 | 5 | More comprehensive instructions are provided in the finetune directory. 6 | 7 | 8 | ## Training on your own data 9 | You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory. 10 | 11 | Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on. 12 | 13 | Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task. 14 | 15 | ## Additional notes 16 | Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`). 17 | 18 | Training will be slow without a GPU. 19 | -------------------------------------------------------------------------------- /batdetect2/train/train_split.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run scripts/extract_anns.py to generate these json files. 
3 | """ 4 | 5 | 6 | def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True): 7 | if split_name == "diff": 8 | train_sets, test_sets = split_diff(ann_dir, wav_dir, load_extra) 9 | elif split_name == "same": 10 | train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra) 11 | else: 12 | print("Split not defined") 13 | assert False 14 | 15 | return train_sets, test_sets 16 | 17 | 18 | def split_diff(ann_dir, wav_dir, load_extra=True): 19 | 20 | train_sets = [] 21 | if load_extra: 22 | train_sets.append( 23 | { 24 | "dataset_name": "BatDetective", 25 | "is_test": False, 26 | "is_binary": True, # just a bat / not bat dataset ie no classes 27 | "ann_path": ann_dir 28 | + "train_set_bulgaria_batdetective_with_bbs.json", 29 | "wav_path": wav_dir + "batdetect2ive/audio/", 30 | } 31 | ) 32 | train_sets.append( 33 | { 34 | "dataset_name": "bat_logger_qeop_empty", 35 | "is_test": False, 36 | "is_binary": True, 37 | "ann_path": ann_dir + "bat_logger_qeop_empty.json", 38 | "wav_path": wav_dir + "bat_logger_qeop_empty/audio/", 39 | } 40 | ) 41 | train_sets.append( 42 | { 43 | "dataset_name": "bat_logger_2016_empty", 44 | "is_test": False, 45 | "is_binary": True, 46 | "ann_path": ann_dir + "train_set_bat_logger_2016_empty.json", 47 | "wav_path": wav_dir + "bat_logger_2016/audio/", 48 | } 49 | ) 50 | # train_sets.append({'dataset_name': 'brazil_data_binary', 51 | # 'is_test': False, 52 | # 'ann_path': ann_dir + 'brazil_data_binary.json', 53 | # 'wav_path': wav_dir + 'brazil_data/audio/'}) 54 | 55 | train_sets.append( 56 | { 57 | "dataset_name": "echobank", 58 | "is_test": False, 59 | "is_binary": False, 60 | "ann_path": ann_dir + "Echobank_train_expert.json", 61 | "wav_path": wav_dir + "echobank/audio/", 62 | } 63 | ) 64 | train_sets.append( 65 | { 66 | "dataset_name": "sn_scot_nor", 67 | "is_test": False, 68 | "is_binary": False, 69 | "ann_path": ann_dir + "sn_scot_nor_0.5_expert.json", 70 | "wav_path": wav_dir + "sn_scot_nor/audio/", 71 | } 72 | ) 73 | train_sets.append( 74 | { 75 | "dataset_name": "BCT_1_sec", 76 | "is_test": False, 77 | "is_binary": False, 78 | "ann_path": ann_dir + "BCT_1_sec_train_expert.json", 79 | "wav_path": wav_dir + "BCT_1_sec/audio/", 80 | } 81 | ) 82 | train_sets.append( 83 | { 84 | "dataset_name": "bcireland", 85 | "is_test": False, 86 | "is_binary": False, 87 | "ann_path": ann_dir + "bcireland_expert.json", 88 | "wav_path": wav_dir + "bcireland/audio/", 89 | } 90 | ) 91 | train_sets.append( 92 | { 93 | "dataset_name": "rhinolophus_steve_BCT", 94 | "is_test": False, 95 | "is_binary": False, 96 | "ann_path": ann_dir + "rhinolophus_steve_BCT_expert.json", 97 | "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/", 98 | } 99 | ) 100 | 101 | test_sets = [] 102 | test_sets.append( 103 | { 104 | "dataset_name": "bat_data_martyn_2018", 105 | "is_test": True, 106 | "is_binary": False, 107 | "ann_path": ann_dir 108 | + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json", 109 | "wav_path": wav_dir + "bat_data_martyn_2018/audio/", 110 | } 111 | ) 112 | test_sets.append( 113 | { 114 | "dataset_name": "bat_data_martyn_2018_test", 115 | "is_test": True, 116 | "is_binary": False, 117 | "ann_path": ann_dir 118 | + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json", 119 | "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/", 120 | } 121 | ) 122 | test_sets.append( 123 | { 124 | "dataset_name": "bat_data_martyn_2019", 125 | "is_test": True, 126 | "is_binary": False, 127 | "ann_path": ann_dir 128 | + 
"BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json", 129 | "wav_path": wav_dir + "bat_data_martyn_2019/audio/", 130 | } 131 | ) 132 | test_sets.append( 133 | { 134 | "dataset_name": "bat_data_martyn_2019_test", 135 | "is_test": True, 136 | "is_binary": False, 137 | "ann_path": ann_dir 138 | + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json", 139 | "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/", 140 | } 141 | ) 142 | 143 | return train_sets, test_sets 144 | 145 | 146 | def split_same(ann_dir, wav_dir, load_extra=True): 147 | 148 | train_sets = [] 149 | if load_extra: 150 | train_sets.append( 151 | { 152 | "dataset_name": "BatDetective", 153 | "is_test": False, 154 | "is_binary": True, 155 | "ann_path": ann_dir 156 | + "train_set_bulgaria_batdetective_with_bbs.json", 157 | "wav_path": wav_dir + "batdetect2ive/audio/", 158 | } 159 | ) 160 | train_sets.append( 161 | { 162 | "dataset_name": "bat_logger_qeop_empty", 163 | "is_test": False, 164 | "is_binary": True, 165 | "ann_path": ann_dir + "bat_logger_qeop_empty.json", 166 | "wav_path": wav_dir + "bat_logger_qeop_empty/audio/", 167 | } 168 | ) 169 | train_sets.append( 170 | { 171 | "dataset_name": "bat_logger_2016_empty", 172 | "is_test": False, 173 | "is_binary": True, 174 | "ann_path": ann_dir + "train_set_bat_logger_2016_empty.json", 175 | "wav_path": wav_dir + "bat_logger_2016/audio/", 176 | } 177 | ) 178 | # train_sets.append({'dataset_name': 'brazil_data_binary', 179 | # 'is_test': False, 180 | # 'ann_path': ann_dir + 'brazil_data_binary.json', 181 | # 'wav_path': wav_dir + 'brazil_data/audio/'}) 182 | 183 | train_sets.append( 184 | { 185 | "dataset_name": "echobank", 186 | "is_test": False, 187 | "is_binary": False, 188 | "ann_path": ann_dir + "Echobank_train_expert_TRAIN.json", 189 | "wav_path": wav_dir + "echobank/audio/", 190 | } 191 | ) 192 | train_sets.append( 193 | { 194 | "dataset_name": "sn_scot_nor", 195 | "is_test": False, 196 | "is_binary": False, 197 | "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TRAIN.json", 198 | "wav_path": wav_dir + "sn_scot_nor/audio/", 199 | } 200 | ) 201 | train_sets.append( 202 | { 203 | "dataset_name": "BCT_1_sec", 204 | "is_test": False, 205 | "is_binary": False, 206 | "ann_path": ann_dir + "BCT_1_sec_train_expert_TRAIN.json", 207 | "wav_path": wav_dir + "BCT_1_sec/audio/", 208 | } 209 | ) 210 | train_sets.append( 211 | { 212 | "dataset_name": "bcireland", 213 | "is_test": False, 214 | "is_binary": False, 215 | "ann_path": ann_dir + "bcireland_expert_TRAIN.json", 216 | "wav_path": wav_dir + "bcireland/audio/", 217 | } 218 | ) 219 | train_sets.append( 220 | { 221 | "dataset_name": "rhinolophus_steve_BCT", 222 | "is_test": False, 223 | "is_binary": False, 224 | "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TRAIN.json", 225 | "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/", 226 | } 227 | ) 228 | train_sets.append( 229 | { 230 | "dataset_name": "bat_data_martyn_2018", 231 | "is_test": False, 232 | "is_binary": False, 233 | "ann_path": ann_dir 234 | + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json", 235 | "wav_path": wav_dir + "bat_data_martyn_2018/audio/", 236 | } 237 | ) 238 | train_sets.append( 239 | { 240 | "dataset_name": "bat_data_martyn_2018_test", 241 | "is_test": False, 242 | "is_binary": False, 243 | "ann_path": ann_dir 244 | + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json", 245 | "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/", 246 | } 247 | ) 248 | train_sets.append( 249 | { 250 | "dataset_name": 
"bat_data_martyn_2019", 251 | "is_test": False, 252 | "is_binary": False, 253 | "ann_path": ann_dir 254 | + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json", 255 | "wav_path": wav_dir + "bat_data_martyn_2019/audio/", 256 | } 257 | ) 258 | train_sets.append( 259 | { 260 | "dataset_name": "bat_data_martyn_2019_test", 261 | "is_test": False, 262 | "is_binary": False, 263 | "ann_path": ann_dir 264 | + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json", 265 | "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/", 266 | } 267 | ) 268 | 269 | # train_sets.append({'dataset_name': 'bat_data_martyn_2021_train', 270 | # 'is_test': False, 271 | # 'is_binary': False, 272 | # 'ann_path': ann_dir + 'bat_data_martyn_2021_TRAIN.json', 273 | # 'wav_path': wav_dir + 'bat_data_martyn_2021/audio/'}) 274 | # train_sets.append({'dataset_name': 'volunteers_2021_train', 275 | # 'is_test': False, 276 | # 'is_binary': False, 277 | # 'ann_path': ann_dir + 'volunteers_2021_TRAIN.json', 278 | # 'wav_path': wav_dir + 'volunteers_2021/audio/'}) 279 | 280 | test_sets = [] 281 | test_sets.append( 282 | { 283 | "dataset_name": "echobank", 284 | "is_test": True, 285 | "is_binary": False, 286 | "ann_path": ann_dir + "Echobank_train_expert_TEST.json", 287 | "wav_path": wav_dir + "echobank/audio/", 288 | } 289 | ) 290 | test_sets.append( 291 | { 292 | "dataset_name": "sn_scot_nor", 293 | "is_test": True, 294 | "is_binary": False, 295 | "ann_path": ann_dir + "sn_scot_nor_0.5_expert_TEST.json", 296 | "wav_path": wav_dir + "sn_scot_nor/audio/", 297 | } 298 | ) 299 | test_sets.append( 300 | { 301 | "dataset_name": "BCT_1_sec", 302 | "is_test": True, 303 | "is_binary": False, 304 | "ann_path": ann_dir + "BCT_1_sec_train_expert_TEST.json", 305 | "wav_path": wav_dir + "BCT_1_sec/audio/", 306 | } 307 | ) 308 | test_sets.append( 309 | { 310 | "dataset_name": "bcireland", 311 | "is_test": True, 312 | "is_binary": False, 313 | "ann_path": ann_dir + "bcireland_expert_TEST.json", 314 | "wav_path": wav_dir + "bcireland/audio/", 315 | } 316 | ) 317 | test_sets.append( 318 | { 319 | "dataset_name": "rhinolophus_steve_BCT", 320 | "is_test": True, 321 | "is_binary": False, 322 | "ann_path": ann_dir + "rhinolophus_steve_BCT_expert_TEST.json", 323 | "wav_path": wav_dir + "rhinolophus_steve_BCT/audio/", 324 | } 325 | ) 326 | test_sets.append( 327 | { 328 | "dataset_name": "bat_data_martyn_2018", 329 | "is_test": True, 330 | "is_binary": False, 331 | "ann_path": ann_dir 332 | + "BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json", 333 | "wav_path": wav_dir + "bat_data_martyn_2018/audio/", 334 | } 335 | ) 336 | test_sets.append( 337 | { 338 | "dataset_name": "bat_data_martyn_2018_test", 339 | "is_test": True, 340 | "is_binary": False, 341 | "ann_path": ann_dir 342 | + "BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json", 343 | "wav_path": wav_dir + "bat_data_martyn_2018_test/audio/", 344 | } 345 | ) 346 | test_sets.append( 347 | { 348 | "dataset_name": "bat_data_martyn_2019", 349 | "is_test": True, 350 | "is_binary": False, 351 | "ann_path": ann_dir 352 | + "BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json", 353 | "wav_path": wav_dir + "bat_data_martyn_2019/audio/", 354 | } 355 | ) 356 | test_sets.append( 357 | { 358 | "dataset_name": "bat_data_martyn_2019_test", 359 | "is_test": True, 360 | "is_binary": False, 361 | "ann_path": ann_dir 362 | + "BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json", 363 | "wav_path": wav_dir + "bat_data_martyn_2019_test/audio/", 364 | } 365 | ) 
366 | 
367 | # test_sets.append({'dataset_name': 'bat_data_martyn_2021_test',
368 | # 'is_test': True,
369 | # 'is_binary': False,
370 | # 'ann_path': ann_dir + 'bat_data_martyn_2021_TEST.json',
371 | # 'wav_path': wav_dir + 'bat_data_martyn_2021/audio/'})
372 | # test_sets.append({'dataset_name': 'volunteers_2021_test',
373 | # 'is_test': True,
374 | # 'is_binary': False,
375 | # 'ann_path': ann_dir + 'volunteers_2021_TEST.json',
376 | # 'wav_path': wav_dir + 'volunteers_2021/audio/'})
377 | 
378 | return train_sets, test_sets
379 | -------------------------------------------------------------------------------- /batdetect2/train/train_utils.py: -------------------------------------------------------------------------------- 1 | import glob
2 | import json
3 | 
4 | import numpy as np
5 | 
6 | 
7 | def write_notes_file(file_name, text):
8 | with open(file_name, "a") as da:
9 | da.write(text + "\n")
10 | 
11 | 
12 | def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
13 | ddict = {
14 | "dataset_name": dataset_name,
15 | "is_test": is_test,
16 | "is_binary": False,
17 | "ann_path": ann_path,
18 | "wav_path": wav_path,
19 | }
20 | return ddict
21 | 
22 | 
23 | def get_short_class_names(class_names, str_len=3):
24 | class_names_short = []
25 | for cc in class_names:
26 | class_names_short.append(
27 | " ".join([sp[:str_len] for sp in cc.split(" ")])
28 | )
29 | return class_names_short
30 | 
31 | 
32 | def remove_dupes(data_train, data_test):
33 | test_ids = [dd["id"] for dd in data_test]
34 | data_train_prune = []
35 | for aa in data_train:
36 | if aa["id"] not in test_ids:
37 | data_train_prune.append(aa)
38 | diff = len(data_train) - len(data_train_prune)
39 | if diff != 0:
40 | print(diff, "items removed from train set")
41 | return data_train_prune
42 | 
43 | 
44 | def get_genus_mapping(class_names):
45 | genus_names, genus_mapping = np.unique(
46 | [cc.split(" ")[0] for cc in class_names], return_inverse=True
47 | )
48 | return genus_names.tolist(), genus_mapping.tolist()
49 | 
50 | 
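# Example (illustrative): for class_names
#   ['Myotis daubentonii', 'Myotis nattereri', 'Plecotus auritus']
# get_genus_mapping returns genus_names == ['Myotis', 'Plecotus'] and
# genus_mapping == [0, 0, 1], i.e. the index in genus_names of each
# class's genus (np.unique sorts the names and return_inverse gives
# the position of each input in that sorted list).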
102 |         for dd in data:
103 |             anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
104 | 
105 |     # discard unannotated files
106 |     anns = [aa for aa in anns if aa["annotated"] is True]
107 | 
108 |     # filter files that have annotation issues - if the input is a dictionary of
109 |     # datasets, this will likely have already been done
110 |     if filter_issues:
111 |         anns = [aa for aa in anns if aa["issues"] is False]
112 | 
113 |     # check for some basic formatting errors with class names
114 |     for ann in anns:
115 |         for aa in ann["annotation"]:
116 |             aa["class"] = aa["class"].strip()
117 | 
118 |     # only load specified events - i.e. types of calls
119 |     if events_of_interest is not None:
120 |         for ann in anns:
121 |             filtered_events = []
122 |             for aa in ann["annotation"]:
123 |                 if aa["event"] in events_of_interest:
124 |                     filtered_events.append(aa)
125 |             ann["annotation"] = filtered_events
126 | 
127 |     # change class names
128 |     # name_replace should be a dictionary mapping input names to output names
129 |     if isinstance(name_replace, dict):
130 |         for ann in anns:
131 |             for aa in ann["annotation"]:
132 |                 if aa["class"] in name_replace:
133 |                     aa["class"] = name_replace[aa["class"]]
134 | 
135 |     # convert everything to genus name
136 |     if convert_to_genus:
137 |         for ann in anns:
138 |             for aa in ann["annotation"]:
139 |                 aa["class"] = aa["class"].split(" ")[0]
140 | 
141 |     # get unique class names
142 |     class_names_all = []
143 |     for ann in anns:
144 |         for aa in ann["annotation"]:
145 |             if aa["class"] not in classes_to_ignore:
146 |                 class_names_all.append(aa["class"])
147 | 
148 |     class_names, class_cnts = np.unique(class_names_all, return_counts=True)
149 |     class_inv_freq = class_cnts.sum() / (
150 |         len(class_names) * class_cnts.astype(np.float32)
151 |     )
152 | 
153 |     if verbose:
154 |         print("Class count:")
155 |         str_len = np.max([len(cc) for cc in class_names]) + 5
156 |         for cc in range(len(class_names)):
157 |             print(
158 |                 str(cc).ljust(5)
159 |                 + class_names[cc].ljust(str_len)
160 |                 + str(class_cnts[cc])
161 |             )
162 | 
163 |     if len(classes_to_ignore) == 0:
164 |         return anns
165 |     else:
166 |         return anns, class_names.tolist(), class_inv_freq.tolist()
167 | 
168 | 
169 | def load_anns(ann_file_name, raw_audio_dir):
170 |     with open(ann_file_name) as da:
171 |         anns = json.load(da)
172 | 
173 |     for aa in anns:
174 |         aa["file_path"] = raw_audio_dir + aa["id"]
175 | 
176 |     return anns
177 | 
178 | 
179 | def load_anns_from_path(ann_file_dir, raw_audio_dir):
180 |     files = glob.glob(ann_file_dir + "*.json")
181 |     anns = []
182 |     for ff in files:
183 |         with open(ff) as da:
184 |             ann = json.load(da)
185 |             ann["file_path"] = raw_audio_dir + ann["id"]
186 |             anns.append(ann)
187 | 
188 |     return anns
189 | 
190 | 
191 | class AverageMeter(object):
192 |     """Computes and stores the average and current value"""
193 | 
194 |     def __init__(self):
195 |         self.reset()
196 | 
197 |     def reset(self):
198 |         self.val = 0
199 |         self.avg = 0
200 |         self.sum = 0
201 |         self.count = 0
202 | 
203 |     def update(self, val, n=1):
204 |         self.val = val
205 |         self.sum += val * n
206 |         self.count += n
207 |         self.avg = self.sum / self.count
208 | 
--------------------------------------------------------------------------------
/batdetect2/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/batdetect2/utils/__init__.py
--------------------------------------------------------------------------------
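For context, a minimal sketch of how the `AverageMeter` defined in `train_utils.py` above is typically used during training; the loop and loss values here are illustrative stand-ins only:

```python
# Hypothetical usage of AverageMeter inside a training loop.
from batdetect2.train.train_utils import AverageMeter

loss_meter = AverageMeter()
for step in range(10):
    batch_size = 8
    loss = 1.0 / (step + 1)  # stand-in for a real per-batch loss value
    # weight the running mean by the number of items in the batch
    loss_meter.update(loss, n=batch_size)

print(loss_meter.val)  # loss from the most recent update
print(loss_meter.avg)  # running mean over all items seen so far
```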
/batdetect2/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from matplotlib import patches
4 | from matplotlib.axes._axes import _log as matplotlib_axes_logger
5 | from sklearn.svm import LinearSVC
6 | 
7 | matplotlib_axes_logger.setLevel("ERROR")
8 | 
9 | 
10 | colors = [
11 |     "#e6194B",
12 |     "#3cb44b",
13 |     "#ffe119",
14 |     "#4363d8",
15 |     "#f58231",
16 |     "#911eb4",
17 |     "#42d4f4",
18 |     "#f032e6",
19 |     "#bfef45",
20 |     "#fabebe",
21 |     "#469990",
22 |     "#e6beff",
23 |     "#9A6324",
24 |     "#fffac8",
25 |     "#800000",
26 |     "#aaffc3",
27 |     "#808000",
28 |     "#ffd8b1",
29 |     "#000075",
30 |     "#a9a9a9",
31 | ]
32 | 
33 | 
34 | class InteractivePlotter:
35 |     def __init__(
36 |         self,
37 |         feats_ds,
38 |         feats,
39 |         spec_slices,
40 |         call_info,
41 |         freq_lims,
42 |         allow_training,
43 |     ):
44 |         """
45 |         Plots 2D low dimensional features on the left and the corresponding
46 |         spectrograms on the right.
47 |         """
48 |         self.feats_ds = feats_ds
49 |         self.feats = feats
50 |         self.clf = None
51 | 
52 |         self.spec_slices = spec_slices
53 |         self.call_info = call_info
54 |         # _, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
55 |         self.labels = np.zeros(len(call_info), dtype=int)
56 |         self.annotated = np.zeros(
57 |             self.labels.shape[0], dtype=int
58 |         )  # can populate this with 1's where we have labels
59 |         self.labels_cols = [
60 |             colors[self.labels[ii]] for ii in range(len(self.labels))
61 |         ]
62 |         self.freq_lims = freq_lims
63 | 
64 |         self.allow_training = allow_training
65 |         self.pt_size = 5.0
66 |         self.spec_pad = (
67 |             0.2  # this much padding has been applied to the spec slices
68 |         )
69 |         self.fig_width = 12
70 |         self.fig_height = 8
71 | 
72 |         self.current_id = 0
73 |         max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices])
74 |         self.max_width = self.spec_slices[max_ind].shape[1]
75 |         self.blank_spec = np.zeros(
76 |             (self.spec_slices[0].shape[0], self.max_width)
77 |         )
78 | 
79 |     def plot(self, fig_id):
80 |         self.fig, self.ax = plt.subplots(
81 |             nrows=1,
82 |             ncols=2,
83 |             num=fig_id,
84 |             figsize=(self.fig_width, self.fig_height),
85 |             gridspec_kw={"width_ratios": [2, 1]},
86 |         )
87 |         plt.tight_layout()
88 | 
89 |         # plot 2D t-SNE features
90 |         self.low_dim_plt = self.ax[0].scatter(
91 |             self.feats_ds[:, 0],
92 |             self.feats_ds[:, 1],
93 |             c=self.labels_cols,
94 |             s=self.pt_size,
95 |             picker=5,
96 |         )
97 |         self.ax[0].set_title("t-SNE of Call Features")
98 |         self.ax[0].set_xticks([])
99 |         self.ax[0].set_yticks([])
100 | 
101 |         # plot clip from spectrogram
102 |         spec_min_max = (
103 |             0,
104 |             self.blank_spec.shape[1],
105 |             self.freq_lims[0],
106 |             self.freq_lims[1],
107 |         )
108 |         self.ax[1].imshow(
109 |             self.blank_spec, extent=spec_min_max, cmap="plasma", aspect="auto"
110 |         )
111 |         self.spec_im = self.ax[1].get_images()[0]
112 |         self.ax[1].set_title("Spectrogram")
113 |         self.ax[1].grid(color="w", linewidth=0.5)
114 |         self.ax[1].set_xticks([])
115 |         self.ax[1].set_ylabel("kHz")
116 | 
117 |         bbox_orig = patches.Rectangle(
118 |             (0, 0), 0, 0, edgecolor="w", linewidth=0, fill=False
119 |         )
120 |         self.ax[1].add_patch(bbox_orig)
121 | 
122 |         self.annot = self.ax[0].annotate(
123 |             "",
124 |             xy=(0, 0),
125 |             xytext=(20, 20),
126 |             textcoords="offset points",
127 |             bbox=dict(boxstyle="round", fc="w"),
128 |             arrowprops=dict(arrowstyle="->"),
129 |         )
130 |         self.annot.set_visible(False)
131 | 
132 |         self.fig.canvas.mpl_connect("motion_notify_event", self.mouse_hover)
133 |
self.fig.canvas.mpl_connect("key_press_event", self.key_press) 134 | 135 | def mouse_hover(self, event): 136 | vis = self.annot.get_visible() 137 | if event.inaxes == self.ax[0]: 138 | cont, ind = self.low_dim_plt.contains(event) 139 | if cont: 140 | self.current_id = ind["ind"][0] 141 | 142 | # copy spec into full window - probably a better way of doing this 143 | new_spec = self.blank_spec.copy() 144 | w_diff = ( 145 | self.blank_spec.shape[1] 146 | - self.spec_slices[self.current_id].shape[1] 147 | ) // 2 148 | new_spec[ 149 | :, 150 | w_diff : self.spec_slices[self.current_id].shape[1] 151 | + w_diff, 152 | ] = self.spec_slices[self.current_id] 153 | self.spec_im.set_data(new_spec) 154 | self.spec_im.set_clim(vmin=0, vmax=new_spec.max()) 155 | 156 | # draw bounding box around call 157 | self.ax[1].patches[0].remove() 158 | spec_width_orig = self.spec_slices[self.current_id].shape[1] / ( 159 | 1.0 + 2.0 * self.spec_pad 160 | ) 161 | xx = w_diff + self.spec_pad * spec_width_orig 162 | ww = spec_width_orig 163 | yy = self.call_info[self.current_id]["low_freq"] / 1000 164 | hh = ( 165 | self.call_info[self.current_id]["high_freq"] 166 | - self.call_info[self.current_id]["low_freq"] 167 | ) / 1000 168 | bbox = patches.Rectangle( 169 | (xx, yy), ww, hh, edgecolor="r", linewidth=0.5, fill=False 170 | ) 171 | self.ax[1].add_patch(bbox) 172 | 173 | # update annotation arrow 174 | pos = self.low_dim_plt.get_offsets()[self.current_id] 175 | self.annot.xy = pos 176 | self.annot.set_visible(True) 177 | 178 | # write call info 179 | info_str = ( 180 | self.call_info[self.current_id]["file_name"] 181 | + ", time=" 182 | + str( 183 | round(self.call_info[self.current_id]["start_time"], 3) 184 | ) 185 | + ", prob=" 186 | + str(round(self.call_info[self.current_id]["det_prob"], 3)) 187 | ) 188 | self.ax[0].set_xlabel(info_str) 189 | 190 | # redraw 191 | self.fig.canvas.draw_idle() 192 | 193 | def key_press(self, event): 194 | if event.key.isdigit(): 195 | self.labels_cols[self.current_id] = colors[int(event.key)] 196 | self.labels[self.current_id] = int(event.key) 197 | self.annotated[self.current_id] = 1 198 | elif event.key == "enter" and self.allow_training: 199 | self.train_classifier() 200 | elif event.key == "x" and self.allow_training: 201 | self.get_classifier_params() 202 | 203 | self.ax[0].scatter( 204 | self.feats_ds[:, 0], 205 | self.feats_ds[:, 1], 206 | c=self.labels_cols, 207 | s=self.pt_size, 208 | ) 209 | self.fig.canvas.draw_idle() 210 | 211 | def train_classifier(self): 212 | # TODO maybe it's better to classify in 2D space - but then can't be linear ... 
213 | inds = np.where(self.annotated == 1)[0] 214 | labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True) 215 | 216 | if labs_un.shape[0] > 1: # needs at least 2 classes 217 | self.clf = LinearSVC( 218 | C=1.0, 219 | penalty="l2", 220 | loss="squared_hinge", 221 | tol=0.0001, 222 | intercept_scaling=1.0, 223 | max_iter=2000, 224 | ) 225 | 226 | self.clf.fit(self.feats[inds, :], self.labels[inds]) 227 | 228 | # update labels 229 | inds_unlab = np.where(self.annotated == 0)[0] 230 | self.labels[inds_unlab] = self.clf.predict(self.feats[inds_unlab]) 231 | for ii in inds_unlab: 232 | self.labels_cols[ii] = colors[self.labels[ii]] 233 | else: 234 | print("Not enough data - please label more classes.") 235 | 236 | def get_classifier_params(self): 237 | res = {} 238 | if self.clf is None: 239 | print("Model not trained!") 240 | else: 241 | res["weights"] = self.clf.coef_.astype(np.float32) 242 | res["biases"] = self.clf.intercept_.astype(np.float32) 243 | return res 244 | -------------------------------------------------------------------------------- /batdetect2/utils/wavfile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to read / write wav files using numpy arrays 3 | 4 | Functions 5 | --------- 6 | `read`: Return the sample rate (in samples/sec) and data from a WAV file. 7 | 8 | `write`: Write a numpy array as a WAV file. 9 | 10 | """ 11 | from __future__ import absolute_import, division, print_function 12 | 13 | import os 14 | import struct 15 | import sys 16 | import warnings 17 | 18 | import numpy 19 | 20 | 21 | class WavFileWarning(UserWarning): 22 | pass 23 | 24 | 25 | _big_endian = False 26 | 27 | WAVE_FORMAT_PCM = 0x0001 28 | WAVE_FORMAT_IEEE_FLOAT = 0x0003 29 | WAVE_FORMAT_EXTENSIBLE = 0xFFFE 30 | KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT) 31 | 32 | # assumes file pointer is immediately 33 | # after the 'fmt ' id 34 | 35 | 36 | def _read_fmt_chunk(fid): 37 | if _big_endian: 38 | fmt = ">" 39 | else: 40 | fmt = "<" 41 | res = struct.unpack(fmt + "iHHIIHH", fid.read(20)) 42 | size, comp, noc, rate, sbytes, ba, bits = res 43 | if comp not in KNOWN_WAVE_FORMATS or size > 16: 44 | comp = WAVE_FORMAT_PCM 45 | warnings.warn("Unknown wave file format", WavFileWarning) 46 | if size > 16: 47 | fid.read(size - 16) 48 | 49 | return size, comp, noc, rate, sbytes, ba, bits 50 | 51 | 52 | # assumes file pointer is immediately 53 | # after the 'data' id 54 | def _read_data_chunk(fid, comp, noc, bits, mmap=False): 55 | if _big_endian: 56 | fmt = ">i" 57 | else: 58 | fmt = " 1: 83 | data = data.reshape(-1, noc) 84 | return data 85 | 86 | 87 | def _skip_unknown_chunk(fid): 88 | if _big_endian: 89 | fmt = ">i" 90 | else: 91 | fmt = "" or ( 280 | data.dtype.byteorder == "=" and sys.byteorder == "big" 281 | ): 282 | data = data.byteswap() 283 | _array_tofile(fid, data) 284 | 285 | # Determine file size and place it in correct 286 | # position at start of the file (replacing the 4 bytes of zeros) 287 | size = fid.tell() 288 | fid.seek(4) 289 | fid.write(struct.pack("= 3: 299 | 300 | def _array_tofile(fid, data): 301 | # ravel gives a c-contiguous buffer 302 | fid.write(data.ravel().view("b").data) 303 | 304 | else: 305 | 306 | def _array_tofile(fid, data): 307 | fid.write(data.tostring()) 308 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: batdetect2 2 | channels: 3 | - 
defaults 4 | - conda-forge 5 | - pytorch 6 | - nvidia 7 | dependencies: 8 | - python==3.10 9 | - matplotlib 10 | - pandas 11 | - scikit-learn 12 | - numpy 13 | - pytorch 14 | - scipy 15 | - torchvision 16 | - librosa 17 | - torchaudio 18 | -------------------------------------------------------------------------------- /example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav -------------------------------------------------------------------------------- /example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav -------------------------------------------------------------------------------- /example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav -------------------------------------------------------------------------------- /faq.md: -------------------------------------------------------------------------------- 1 | # BatDetect2 - FAQ 2 | 3 | ## Installation 4 | 5 | #### Do I need to know Python to be able to use this? 6 | No. To simply run the code on your own data you do not need any knowledge of Python. However, a small bit of familiarity with the terminal (i.e. command line) in Windows/Linux/OSX may make things easier. 7 | 8 | 9 | #### Are there any plans for an R version? 10 | Currently no. All the scripts export simple `.csv` files that can be read using any programming language of choice. 11 | 12 | 13 | #### How do I install the code? 14 | The codebase has been tested under Windows 10, Ubuntu, and OSX. Read the instructions in the main readme to get started. If you are having problems getting it working and you feel like you have tried everything (e.g. confirming that your Anaconda Python distribution is correctly installed) feel free to open an issue on GitHub. 15 | 16 | 17 | ## Performance 18 | 19 | #### The model does not work very well on my data? 20 | Our model is based on a machine learning approach and as such if your data is very different from our training set it may not work as well. Feel free to use our annotation tools to label some of your own data and retrain the model. Even better, if you have large quantities of audio data with reliable species data that you are willing to share with the community please get in touch. 21 | 22 | 23 | #### The model is incorrectly classifying insects/noise/... as bats? 24 | Fine-tuning the model on your data can make a big difference. See previous answer. 25 | 26 | 27 | #### The model fails to correctly detect feeding buzzes and social calls? 28 | This is a limitation of our current training data. If you have such data or would be willing to label some for us please get in touch. 29 | 30 | 31 | #### Calls that are clearly belonging to the same call sequence are being predicted as coming from different species? 32 | Currently we do not do any sophisticated post processing on the results output by the model. 
We return a probability associated with each species for each call. You can use these predictions to clean up the noisy predictions for sequences of calls. 33 | 34 | 35 | #### Can I trust the model outputs? 36 | The models developed and shared as part of this repository should be used with caution. While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment. Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted. 37 | 38 | 39 | #### The code works well but it is slow? 40 | Try a different/faster computer. On a reasonably recent desktop it takes about 13 seconds (on the GPU) or 1.3 minutes (on the CPU) to process 7.5 minutes of audio. In general, we observe a factor of ~5-10 speed up using recent Nvidia GPUs compared to CPU only systems. 41 | 42 | 43 | #### My audio files are very big and as a result the code is slow. 44 | If your audio files are very long in duration (i.e. multiple minutes) it might be better to split them up into several smaller files. Have a look at the instructions and scripts in our annotation GUI codebase for how to crop your files into shorter ones - see [here](https://github.com/macaodha/batdetect2_GUI). 45 | 46 | 47 | ## Training a new model 48 | 49 | #### Can I train a model on my own bat data with different species? 50 | Yes. You just need to provide annotations in the correct format. 51 | 52 | 53 | #### Will this work for frequency-division or zero-crossing recordings? 54 | No. The code assumes that we can convert the input audio into a spectrogram. 55 | 56 | 57 | #### Will this code work for non-bat audio data e.g. insects or birds? 58 | In principle yes, however you may need to change some of the training hyper-parameters to ignore high frequency information when you re-train. Please open an issue on GitHub if you have a specific request. 59 | 60 | 61 | 62 | ## Usage 63 | 64 | #### Can I use the code for commercial purposes or incorporate raw source code or trained models into my commercial system? 65 | No. This codebase is currently only for non-commercial use. See the license. 66 | -------------------------------------------------------------------------------- /ims/bat_icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/ims/bat_icon.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "batdetect2" 3 | version = "1.3.0" 4 | description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings." 
5 | authors = [ 6 | { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" }, 7 | { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }, 8 | ] 9 | dependencies = [ 10 | "click>=8.1.7", 11 | "librosa>=0.10.1", 12 | "matplotlib>=3.7.1", 13 | "numpy>=1.23.5", 14 | "pandas>=1.5.3", 15 | "scikit-learn>=1.2.2", 16 | "scipy>=1.10.1", 17 | "torch>=1.13.1,<2.5.0", 18 | "torchaudio>=1.13.1,<2.5.0", 19 | "torchvision>=0.14.0", 20 | ] 21 | requires-python = ">=3.9,<3.13" 22 | readme = "README.md" 23 | license = { text = "CC-by-nc-4" } 24 | classifiers = [ 25 | "Development Status :: 4 - Beta", 26 | "Intended Audience :: Science/Research", 27 | "Natural Language :: English", 28 | "Operating System :: OS Independent", 29 | "Programming Language :: Python :: 3.9", 30 | "Programming Language :: Python :: 3.10", 31 | "Programming Language :: Python :: 3.11", 32 | "Programming Language :: Python :: 3.12", 33 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 34 | "Topic :: Software Development :: Libraries :: Python Modules", 35 | "Topic :: Multimedia :: Sound/Audio :: Analysis", 36 | ] 37 | keywords = [ 38 | "bat", 39 | "echolocation", 40 | "deep learning", 41 | "audio", 42 | "machine learning", 43 | "classification", 44 | "detection", 45 | ] 46 | 47 | [build-system] 48 | requires = ["hatchling"] 49 | build-backend = "hatchling.build" 50 | 51 | [project.scripts] 52 | batdetect2 = "batdetect2.cli:cli" 53 | 54 | [tool.uv] 55 | dev-dependencies = [ 56 | "debugpy>=1.8.8", 57 | "hypothesis>=6.118.7", 58 | "pyright>=1.1.388", 59 | "pytest>=7.2.2", 60 | "ruff>=0.7.3", 61 | ] 62 | 63 | [tool.ruff] 64 | line-length = 79 65 | target-version = "py39" 66 | 67 | [tool.ruff.format] 68 | docstring-code-format = true 69 | docstring-code-line-length = 79 70 | 71 | [tool.ruff.lint] 72 | select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"] 73 | 74 | [tool.ruff.lint.pydocstyle] 75 | convention = "numpy" 76 | 77 | [tool.pyright] 78 | include = ["batdetect2", "tests"] 79 | venvPath = "." 80 | venv = ".venv" 81 | pythonVersion = "3.9" 82 | pythonPlatform = "All" 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.9.2 2 | matplotlib==3.6.2 3 | numpy==1.23.4 4 | pandas==1.5.2 5 | scikit_learn==1.2.0 6 | scipy==1.9.3 7 | torch==1.13.0 8 | torchaudio==0.13.0 9 | torchvision==0.14.0 10 | click 11 | -------------------------------------------------------------------------------- /run_batdetect.py: -------------------------------------------------------------------------------- 1 | """Run batdetect2.command.main() from the command line.""" 2 | from batdetect2.cli import detect 3 | 4 | if __name__ == "__main__": 5 | detect() 6 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | This directory contains some scripts for visualizing the raw data and model outputs. 2 | 3 | 4 | `gen_spec_image.py`: saves the model predictions on a spectrogram of the input audio file. 5 | e.g. 6 | `python gen_spec_image.py ../example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav ../models/Net2DFast_UK_same.pth.tar` 7 | 8 | 9 | `gen_spec_video.py`: generates a video showing the model predictions for a file. 10 | e.g. 
11 | `python gen_spec_video.py ../example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav ../models/Net2DFast_UK_same.pth.tar`
12 | 
13 | 
14 | 
15 | `gen_dataset_summary_image.py`: generates an image displaying the mean spectrogram for each class in a specified dataset.
16 | e.g.
17 | `python gen_dataset_summary_image.py --ann_file PATH_TO_ANN/australia_TRAIN.json PATH_TO_AUDIO/audio/ ../plots/australia/`
18 | 
--------------------------------------------------------------------------------
/scripts/gen_dataset_summary_image.py:
--------------------------------------------------------------------------------
1 | """
2 | Loads a set of annotations corresponding to a dataset and saves an image which
3 | is the mean spectrogram for each class.
4 | """
5 | 
6 | import argparse
7 | import os
8 | 
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import viz_helpers as vz
12 | 
13 | import batdetect2.detector.parameters as parameters
14 | import batdetect2.train.train_split as ts
15 | import batdetect2.train.train_utils as tu
16 | import batdetect2.utils.audio_utils as au
17 | 
18 | if __name__ == "__main__":
19 | 
20 |     parser = argparse.ArgumentParser()
21 |     parser.add_argument(
22 |         "audio_path", type=str, help="Input directory for audio"
23 |     )
24 |     parser.add_argument(
25 |         "op_dir",
26 |         type=str,
27 |         help="Output directory for the generated summary image",
28 |     )
29 |     parser.add_argument(
30 |         "--ann_file",
31 |         type=str,
32 |         help="Path to where single annotation json file is stored",
33 |     )
34 |     parser.add_argument(
35 |         "--uk_split", type=str, default="", help="Set as: diff or same"
36 |     )
37 |     parser.add_argument(
38 |         "--file_type",
39 |         type=str,
40 |         default="png",
41 |         help="Type of image to save png or pdf",
42 |     )
43 |     args = vars(parser.parse_args())
44 | 
45 |     if not os.path.isdir(args["op_dir"]):
46 |         os.makedirs(args["op_dir"])
47 | 
48 |     params = parameters.get_params(False)
49 |     params["smooth_spec"] = False
50 |     params["spec_width"] = 48
51 |     params["norm_type"] = "log"  # log, pcen
52 |     params["aud_pad"] = 0.005
53 |     classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
54 | 
55 |     # load train annotations
56 |     if args["uk_split"] == "":
57 |         print("\nLoading:", args["ann_file"], "\n")
58 |         dataset_name = os.path.basename(args["ann_file"]).replace(".json", "")
59 |         datasets = []
60 |         datasets.append(
61 |             tu.get_blank_dataset_dict(
62 |                 dataset_name, False, args["ann_file"], args["audio_path"]
63 |             )
64 |         )
65 |     else:
66 |         # load uk data - special case
67 |         print("\nLoading:", args["uk_split"], "\n")
68 |         dataset_name = "uk_" + args["uk_split"]  # should be uk_diff, or uk_same
69 |         datasets, _ = ts.get_train_test_data(
70 |             args["ann_file"],
71 |             args["audio_path"],
72 |             args["uk_split"],
73 |             load_extra=False,
74 |         )
75 | 
76 |     anns, class_names, _ = tu.load_set_of_anns(
77 |         datasets, classes_to_ignore, params["events_of_interest"]
78 |     )
79 |     class_names_order = range(len(class_names))
80 | 
81 |     x_train, y_train = vz.load_data(
82 |         anns,
83 |         params,
84 |         class_names,
85 |         smooth_spec=params["smooth_spec"],
86 |         norm_type=params["norm_type"],
87 |     )
88 | 
89 |     op_file_name = os.path.join(
90 |         args["op_dir"], dataset_name + "." + args["file_type"]
+ args["file_type"] 91 | ) 92 | vz.save_summary_image( 93 | x_train, y_train, class_names, params, op_file_name, class_names_order 94 | ) 95 | print("\nImage saved to:", op_file_name) 96 | -------------------------------------------------------------------------------- /scripts/gen_spec_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualize predctions on top of spectrogram. 3 | 4 | Will save images with: 5 | 1) raw spectrogram 6 | 2) spectrogram with GT boxes 7 | 3) spectrogram with predicted boxes 8 | """ 9 | 10 | import argparse 11 | import json 12 | import os 13 | import sys 14 | 15 | import torch 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | 19 | import batdetect2.evaluate.evaluate_models as evlm 20 | import batdetect2.utils.audio_utils as au 21 | import batdetect2.utils.detector_utils as du 22 | import batdetect2.utils.plot_utils as viz 23 | 24 | 25 | def filter_anns(anns, start_time, stop_time): 26 | anns_op = [] 27 | for aa in anns: 28 | if (aa["start_time"] >= start_time) and ( 29 | aa["start_time"] < stop_time - 0.02 30 | ): 31 | anns_op.append(aa) 32 | return anns_op 33 | 34 | 35 | if __name__ == "__main__": 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("audio_file", type=str, help="Path to audio file") 39 | parser.add_argument("model_path", type=str, help="Path to BatDetect model") 40 | parser.add_argument( 41 | "--ann_file", type=str, default="", help="Path to annotation file" 42 | ) 43 | parser.add_argument( 44 | "--op_dir", 45 | type=str, 46 | default="plots/", 47 | help="Output directory for plots", 48 | ) 49 | parser.add_argument( 50 | "--file_type", 51 | type=str, 52 | default="png", 53 | help="Type of image to save png or pdf", 54 | ) 55 | parser.add_argument( 56 | "--title_text", 57 | type=str, 58 | default="", 59 | help="Text to add as title of plots", 60 | ) 61 | parser.add_argument( 62 | "--detection_threshold", 63 | type=float, 64 | default=0.2, 65 | help="Threshold for output detections", 66 | ) 67 | parser.add_argument( 68 | "--start_time", 69 | type=float, 70 | default=0.0, 71 | help="Start time for cropped file", 72 | ) 73 | parser.add_argument( 74 | "--stop_time", 75 | type=float, 76 | default=0.5, 77 | help="End time for cropped file", 78 | ) 79 | parser.add_argument( 80 | "--time_expansion_factor", 81 | type=int, 82 | default=1, 83 | help="Time expansion factor", 84 | ) 85 | 86 | args_cmd = vars(parser.parse_args()) 87 | 88 | # load the model 89 | bd_args = du.get_default_bd_args() 90 | model, params_bd = du.load_model(args_cmd["model_path"]) 91 | bd_args["detection_threshold"] = args_cmd["detection_threshold"] 92 | bd_args["time_expansion_factor"] = args_cmd["time_expansion_factor"] 93 | 94 | # load the annotation if it exists 95 | gt_present = False 96 | if args_cmd["ann_file"] != "": 97 | if os.path.isfile(args_cmd["ann_file"]): 98 | with open(args_cmd["ann_file"]) as da: 99 | gt_anns = json.load(da) 100 | gt_anns = filter_anns( 101 | gt_anns["annotation"], 102 | args_cmd["start_time"], 103 | args_cmd["stop_time"], 104 | ) 105 | gt_present = True 106 | else: 107 | print("Annotation file not found: ", args_cmd["ann_file"]) 108 | 109 | # load the audio file 110 | if not os.path.isfile(args_cmd["audio_file"]): 111 | print("Audio file not found: ", args_cmd["audio_file"]) 112 | sys.exit() 113 | 114 | # load audio and crop 115 | print("\nProcessing: " + os.path.basename(args_cmd["audio_file"])) 116 | print("\nOutput directory: " + args_cmd["op_dir"]) 117 | 
118 |         args_cmd["audio_file"],
119 |         args_cmd["time_expansion_factor"],
120 |         params_bd["target_samp_rate"],
121 |         params_bd["scale_raw_audio"],
122 |     )
123 |     st_samp = int(sampling_rate * args_cmd["start_time"])
124 |     en_samp = int(sampling_rate * args_cmd["stop_time"])
125 |     if en_samp > audio.shape[0]:
126 |         audio = np.hstack(
127 |             (audio, np.zeros((en_samp) - audio.shape[0], dtype=audio.dtype))
128 |         )
129 |     audio = audio[st_samp:en_samp]
130 | 
131 |     duration = audio.shape[0] / sampling_rate
132 |     print("File duration: {} seconds".format(duration))
133 | 
134 |     # create spec for viz
135 |     spec, _ = au.generate_spectrogram(
136 |         audio, sampling_rate, params_bd, True, False
137 |     )
138 | 
139 |     run_config = {
140 |         **params_bd,
141 |         **bd_args,
142 |     }
143 | 
144 |     # run model and filter detections so only keep ones in relevant time range
145 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
146 |     results = du.process_file(args_cmd["audio_file"], model, run_config, device)
147 |     pred_anns = filter_anns(
148 |         results["pred_dict"]["annotation"],
149 |         args_cmd["start_time"],
150 |         args_cmd["stop_time"],
151 |     )
152 |     print(len(pred_anns), "Detections")
153 | 
154 |     # save output
155 |     if not os.path.isdir(args_cmd["op_dir"]):
156 |         os.makedirs(args_cmd["op_dir"])
157 | 
158 |     # create output file names
159 |     op_path_clean = (
160 |         os.path.basename(args_cmd["audio_file"])[:-4]
161 |         + "_clean."
162 |         + args_cmd["file_type"]
163 |     )
164 |     op_path_clean = os.path.join(args_cmd["op_dir"], op_path_clean)
165 |     op_path_pred = (
166 |         os.path.basename(args_cmd["audio_file"])[:-4]
167 |         + "_pred."
168 |         + args_cmd["file_type"]
169 |     )
170 |     op_path_pred = os.path.join(args_cmd["op_dir"], op_path_pred)
171 | 
172 |     # create and save images
173 |     viz.save_ann_spec(
174 |         op_path_clean,
175 |         spec,
176 |         params_bd["min_freq"],
177 |         params_bd["max_freq"],
178 |         duration,
179 |         args_cmd["start_time"],
180 |         "",
181 |         None,
182 |     )
183 |     viz.save_ann_spec(
184 |         op_path_pred,
185 |         spec,
186 |         params_bd["min_freq"],
187 |         params_bd["max_freq"],
188 |         duration,
189 |         args_cmd["start_time"],
190 |         "",
191 |         pred_anns,
192 |     )
193 | 
194 |     if gt_present:
195 |         op_path_gt = (
196 |             os.path.basename(args_cmd["audio_file"])[:-4]
197 |             + "_gt."
198 |             + args_cmd["file_type"]
199 |         )
200 |         op_path_gt = os.path.join(args_cmd["op_dir"], op_path_gt)
201 |         viz.save_ann_spec(
202 |             op_path_gt,
203 |             spec,
204 |             params_bd["min_freq"],
205 |             params_bd["max_freq"],
206 |             duration,
207 |             args_cmd["start_time"],
208 |             "",
209 |             gt_anns,
210 |         )
211 | 
--------------------------------------------------------------------------------
/scripts/gen_spec_video.py:
--------------------------------------------------------------------------------
1 | """
2 | This script takes an audio file as input, runs the detector, and makes a video output.
3 | 
4 | Notes:
5 |     It needs ffmpeg installed to make the videos.
6 |     Sometimes conda can overwrite the default ffmpeg path; set this to use the system one.
7 |     Check which one is being used with `which ffmpeg`. If it is the conda version, it can throw an error.
8 |     Best to use the system one - see ffmpeg_path.
9 | """ 10 | 11 | import argparse 12 | import os 13 | import shutil 14 | import sys 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import torch 19 | from scipy.io import wavfile 20 | 21 | import batdetect2.detector.parameters as parameters 22 | import batdetect2.utils.audio_utils as au 23 | import batdetect2.utils.detector_utils as du 24 | import batdetect2.utils.plot_utils as viz 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("audio_file", type=str, help="Path to input audio file") 29 | parser.add_argument( 30 | "model_path", type=str, help="Path to trained BatDetect model" 31 | ) 32 | parser.add_argument( 33 | "--op_dir", 34 | type=str, 35 | default="generated_vids/", 36 | help="Path to output directory", 37 | ) 38 | parser.add_argument( 39 | "--no_detector", action="store_true", help="Do not run detector" 40 | ) 41 | parser.add_argument( 42 | "--plot_class_names_off", 43 | action="store_true", 44 | help="Do not plot class names", 45 | ) 46 | parser.add_argument( 47 | "--disable_axis", action="store_true", help="Do not plot axis" 48 | ) 49 | parser.add_argument( 50 | "--detection_threshold", 51 | type=float, 52 | default=0.2, 53 | help="Cut-off probability for detector", 54 | ) 55 | parser.add_argument( 56 | "--time_expansion_factor", 57 | type=int, 58 | default=1, 59 | dest="time_expansion_factor", 60 | help="The time expansion factor used for all files (default is 1)", 61 | ) 62 | args_cmd = vars(parser.parse_args()) 63 | 64 | # file of interest 65 | audio_file = args_cmd["audio_file"] 66 | op_dir = args_cmd["op_dir"] 67 | op_str = "_output" 68 | ffmpeg_path = "/usr/bin/" 69 | 70 | if not os.path.isfile(audio_file): 71 | print("Audio file not found: ", audio_file) 72 | sys.exit() 73 | 74 | if not os.path.isfile(args_cmd["model_path"]): 75 | print("Model not found: ", args_cmd["model_path"]) 76 | sys.exit() 77 | 78 | start_time = 0.0 79 | duration = 0.5 80 | reveal_boxes = True # makes the boxes appear one at a time 81 | fps = 24 82 | dpi = 100 83 | 84 | op_dir_tmp = os.path.join(op_dir, "op_tmp_vids", "") 85 | if not os.path.isdir(op_dir_tmp): 86 | os.makedirs(op_dir_tmp) 87 | if not os.path.isdir(op_dir): 88 | os.makedirs(op_dir) 89 | 90 | params = parameters.get_params(False) 91 | args = du.get_default_bd_args() 92 | args["time_expansion_factor"] = args_cmd["time_expansion_factor"] 93 | args["detection_threshold"] = args_cmd["detection_threshold"] 94 | 95 | # load audio file 96 | print("\nProcessing: " + os.path.basename(audio_file)) 97 | print("\nOutput directory: " + op_dir) 98 | sampling_rate, audio = au.load_audio( 99 | audio_file, args["time_expansion_factor"], params["target_samp_rate"] 100 | ) 101 | audio = audio[ 102 | int(sampling_rate * start_time) : int( 103 | sampling_rate * start_time + sampling_rate * duration 104 | ) 105 | ] 106 | audio_orig = audio.copy() 107 | audio = au.pad_audio( 108 | audio, 109 | sampling_rate, 110 | params["fft_win_length"], 111 | params["fft_overlap"], 112 | params["resize_factor"], 113 | params["spec_divide_factor"], 114 | ) 115 | 116 | # generate spectrogram 117 | spec, _ = au.generate_spectrogram(audio, sampling_rate, params, True) 118 | max_val = spec.max() * 1.1 119 | 120 | if not args_cmd["no_detector"]: 121 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 122 | 123 | print(" Loading model and running detector on entire file ...") 124 | model, det_params = du.load_model(args_cmd["model_path"]) 125 | det_params["detection_threshold"] = 
args["detection_threshold"] 126 | 127 | run_config = { 128 | **det_params, 129 | **args, 130 | } 131 | results = du.process_file( 132 | audio_file, 133 | model, 134 | run_config, 135 | device, 136 | ) 137 | 138 | print(" Processing detections and plotting ...") 139 | detections = [] 140 | for bb in results["pred_dict"]["annotation"]: 141 | if (bb["start_time"] >= start_time) and ( 142 | bb["end_time"] < start_time + duration 143 | ): 144 | detections.append(bb) 145 | 146 | # plot boxes 147 | fig = plt.figure( 148 | 1, figsize=(spec.shape[1] / dpi, spec.shape[0] / dpi), dpi=dpi 149 | ) 150 | duration = au.x_coords_to_time( 151 | spec.shape[1], 152 | sampling_rate, 153 | params["fft_win_length"], 154 | params["fft_overlap"], 155 | ) 156 | viz.create_box_image( 157 | spec, 158 | fig, 159 | detections, 160 | start_time, 161 | start_time + duration, 162 | duration, 163 | params, 164 | max_val, 165 | plot_class_names=not args_cmd["plot_class_names_off"], 166 | ) 167 | op_im_file_boxes = os.path.join( 168 | op_dir, os.path.basename(audio_file)[:-4] + op_str + "_boxes.png" 169 | ) 170 | fig.savefig(op_im_file_boxes, dpi=dpi) 171 | plt.close(1) 172 | spec_with_boxes = plt.imread(op_im_file_boxes) 173 | 174 | print(" Saving audio file ...") 175 | if args["time_expansion_factor"] == 1: 176 | sampling_rate_op = int(sampling_rate / 10.0) 177 | else: 178 | sampling_rate_op = sampling_rate 179 | op_audio_file = os.path.join( 180 | op_dir, os.path.basename(audio_file)[:-4] + op_str + ".wav" 181 | ) 182 | wavfile.write(op_audio_file, sampling_rate_op, audio_orig) 183 | 184 | print(" Saving image ...") 185 | op_im_file = os.path.join( 186 | op_dir, os.path.basename(audio_file)[:-4] + op_str + ".png" 187 | ) 188 | plt.imsave(op_im_file, spec, vmin=0, vmax=max_val, cmap="plasma") 189 | spec_blank = plt.imread(op_im_file) 190 | 191 | # create figure 192 | freq_scale = 1000 # turn Hz to kHz 193 | min_freq = params["min_freq"] // freq_scale 194 | max_freq = params["max_freq"] // freq_scale 195 | y_extent = [0, duration, min_freq, max_freq] 196 | 197 | print(" Saving video frames ...") 198 | # save images that will be combined into video 199 | # will either plot with or without boxes 200 | for ii, col in enumerate( 201 | np.linspace(0, spec.shape[1] - 1, int(fps * duration * 10)) 202 | ): 203 | if not args_cmd["no_detector"]: 204 | spec_op = spec_with_boxes.copy() 205 | if ii > 0: 206 | spec_op[:, int(col), :] = 1.0 207 | if reveal_boxes: 208 | spec_op[:, int(col) + 1 :, :] = spec_blank[ 209 | :, int(col) + 1 :, : 210 | ] 211 | elif ii == 0 and reveal_boxes: 212 | spec_op = spec_blank 213 | 214 | if not args_cmd["disable_axis"]: 215 | plt.close("all") 216 | fig = plt.figure( 217 | ii, 218 | figsize=( 219 | 1.2 * (spec_op.shape[1] / dpi), 220 | 1.5 * (spec_op.shape[0] / dpi), 221 | ), 222 | dpi=dpi, 223 | ) 224 | plt.xlabel("Time - seconds") 225 | plt.ylabel("Frequency - kHz") 226 | plt.imshow( 227 | spec_op, 228 | vmin=0, 229 | vmax=1.0, 230 | cmap="plasma", 231 | extent=y_extent, 232 | aspect="auto", 233 | ) 234 | plt.tight_layout() 235 | fig.savefig(op_dir_tmp + str(ii).zfill(4) + ".png", dpi=dpi) 236 | else: 237 | plt.imsave( 238 | op_dir_tmp + str(ii).zfill(4) + ".png", 239 | spec_op, 240 | vmin=0, 241 | vmax=1.0, 242 | cmap="plasma", 243 | ) 244 | else: 245 | spec_op = spec.copy() 246 | if ii > 0: 247 | spec_op[:, int(col)] = max_val 248 | plt.imsave( 249 | op_dir_tmp + str(ii).zfill(4) + ".png", 250 | spec_op, 251 | vmin=0, 252 | vmax=max_val, 253 | cmap="plasma", 254 | ) 255 | 256 | print(" Creating video 
...") 257 | op_vid_file = os.path.join( 258 | op_dir, os.path.basename(audio_file)[:-4] + op_str + ".avi" 259 | ) 260 | ffmpeg_cmd = ( 261 | "ffmpeg -hide_banner -loglevel panic -y -r {} -f image2 -s {}x{} -i {}%04d.png -i {} -vcodec libx264 " 262 | "-crf 25 -pix_fmt yuv420p -acodec copy {}".format( 263 | fps, 264 | spec.shape[1], 265 | spec.shape[0], 266 | op_dir_tmp, 267 | op_audio_file, 268 | op_vid_file, 269 | ) 270 | ) 271 | ffmpeg_cmd = ffmpeg_path + ffmpeg_cmd 272 | os.system(ffmpeg_cmd) 273 | 274 | print(" Deleting temporary files ...") 275 | if os.path.isdir(op_dir_tmp): 276 | shutil.rmtree(op_dir_tmp) 277 | -------------------------------------------------------------------------------- /scripts/viz_helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from scipy import ndimage 7 | 8 | sys.path.append(os.path.join("..")) 9 | 10 | import batdetect2.utils.audio_utils as au 11 | 12 | 13 | def generate_spectrogram_data( 14 | audio, sampling_rate, params, norm_type="log", smooth_spec=False 15 | ): 16 | max_freq = round(params["max_freq"] * params["fft_win_length"]) 17 | min_freq = round(params["min_freq"] * params["fft_win_length"]) 18 | 19 | # create spectrogram - numpy 20 | spec = au.gen_mag_spectrogram( 21 | audio, sampling_rate, params["fft_win_length"], params["fft_overlap"] 22 | ) 23 | # spec = au.gen_mag_spectrogram_pt(audio, sampling_rate, params['fft_win_length'], params['fft_overlap']).numpy() 24 | if spec.shape[0] < max_freq: 25 | freq_pad = max_freq - spec.shape[0] 26 | spec = np.vstack( 27 | (np.zeros((freq_pad, spec.shape[1]), dtype=np.float32), spec) 28 | ) 29 | spec = spec[-max_freq : spec.shape[0] - min_freq, :] 30 | 31 | if norm_type == "log": 32 | log_scaling = ( 33 | 2.0 34 | * (1.0 / sampling_rate) 35 | * ( 36 | 1.0 37 | / ( 38 | np.abs( 39 | np.hanning( 40 | int(params["fft_win_length"] * sampling_rate) 41 | ) 42 | ) 43 | ** 2 44 | ).sum() 45 | ) 46 | ) 47 | ##log_scaling = 0.01 48 | spec = np.log(1.0 + log_scaling * spec).astype(np.float32) 49 | elif norm_type == "pcen": 50 | spec = au.pcen(spec, sampling_rate) 51 | else: 52 | pass 53 | 54 | if smooth_spec: 55 | spec = ndimage.gaussian_filter(spec, 1) 56 | 57 | return spec 58 | 59 | 60 | def load_data( 61 | anns, 62 | params, 63 | class_names, 64 | smooth_spec=False, 65 | norm_type="log", 66 | extract_bg=False, 67 | ): 68 | specs = [] 69 | labels = [] 70 | coords = [] 71 | audios = [] 72 | sampling_rates = [] 73 | file_names = [] 74 | for cur_file in anns: 75 | sampling_rate, audio_orig = au.load_audio( 76 | cur_file["file_path"], 77 | cur_file["time_exp"], 78 | params["target_samp_rate"], 79 | params["scale_raw_audio"], 80 | ) 81 | 82 | for ann in cur_file["annotation"]: 83 | if ( 84 | ann["class"] not in params["classes_to_ignore"] 85 | and ann["class"] in class_names 86 | ): 87 | # clip out of bounds 88 | if ann["low_freq"] < params["min_freq"]: 89 | ann["low_freq"] = params["min_freq"] 90 | if ann["high_freq"] > params["max_freq"]: 91 | ann["high_freq"] = params["max_freq"] 92 | 93 | # load cropped audio 94 | start_samp_diff = int(sampling_rate * ann["start_time"]) - int( 95 | sampling_rate * params["aud_pad"] 96 | ) 97 | start_samp = np.maximum(0, start_samp_diff) 98 | end_samp = np.minimum( 99 | audio_orig.shape[0], 100 | int(sampling_rate * ann["end_time"]) * 2 101 | + int(sampling_rate * params["aud_pad"]), 102 | ) 103 | audio = audio_orig[start_samp:end_samp] 104 | if 
start_samp_diff < 0:
105 |                     # need to pad at start if the call is at the very beginning
106 |                     audio = np.hstack(
107 |                         (np.zeros(-start_samp_diff, dtype=np.float32), audio)
108 |                     )
109 | 
110 |                 nfft = int(params["fft_win_length"] * sampling_rate)
111 |                 noverlap = int(params["fft_overlap"] * nfft)
112 |                 max_samps = params["spec_width"] * (nfft - noverlap) + noverlap
113 | 
114 |                 if max_samps > audio.shape[0]:
115 |                     audio = np.hstack(
116 |                         (audio, np.zeros(max_samps - audio.shape[0]))
117 |                     )
118 |                 audio = audio[:max_samps].astype(np.float32)
119 | 
120 |                 audio = au.pad_audio(
121 |                     audio,
122 |                     sampling_rate,
123 |                     params["fft_win_length"],
124 |                     params["fft_overlap"],
125 |                     params["resize_factor"],
126 |                     params["spec_divide_factor"],
127 |                 )
128 | 
129 |                 # generate spectrogram
130 |                 spec = generate_spectrogram_data(
131 |                     audio, sampling_rate, params, norm_type, smooth_spec
132 |                 )[:, : params["spec_width"]]
133 | 
134 |                 specs.append(spec[np.newaxis, ...])
135 |                 labels.append(ann["class"])
136 | 
137 |                 audios.append(audio)
138 |                 sampling_rates.append(sampling_rate)
139 |                 file_names.append(cur_file["file_path"])
140 | 
141 |                 # position in crop
142 |                 x1 = int(
143 |                     au.time_to_x_coords(
144 |                         np.array(params["aud_pad"]),
145 |                         sampling_rate,
146 |                         params["fft_win_length"],
147 |                         params["fft_overlap"],
148 |                     )
149 |                 )
150 |                 y1 = (ann["low_freq"] - params["min_freq"]) * params[
151 |                     "fft_win_length"
152 |                 ]
153 |                 coords.append((y1, x1))
154 | 
155 |     _, file_ids = np.unique(file_names, return_inverse=True)
156 |     labels = np.array([class_names.index(ll) for ll in labels])
157 | 
158 |     # return np.vstack(specs), labels, coords, audios, sampling_rates, file_ids, file_names
159 |     return np.vstack(specs), labels
160 | 
161 | 
162 | def save_summary_image(
163 |     specs,
164 |     labels,
165 |     species_names,
166 |     params,
167 |     op_file_name="plots/all_species.png",
168 |     order=None,
169 | ):
170 |     # takes the mean for each class and plots it on a grid
171 |     mean_specs = []
172 |     max_band = []
173 |     for ii in range(len(species_names)):
174 |         inds = np.where(labels == ii)[0]
175 |         mu = specs[inds, :].mean(0)
176 |         max_band.append(np.argmax(mu.sum(1)))
177 |         mean_specs.append(mu)
178 | 
179 |     # control the order in which classes are printed
180 |     if order is None:
181 |         order = np.arange(len(species_names))
182 | 
183 |     max_cols = 6
184 |     nrows = int(np.ceil(len(species_names) / max_cols))
185 |     ncols = np.minimum(len(species_names), max_cols)
186 | 
187 |     fig, ax = plt.subplots(
188 |         nrows=nrows,
189 |         ncols=ncols,
190 |         figsize=(ncols * 3.3, nrows * 6),
191 |         gridspec_kw={"wspace": 0, "hspace": 0.2},
192 |     )
193 |     spec_min_max = (
194 |         0,
195 |         mean_specs[0].shape[1],
196 |         params["min_freq"] / 1000,
197 |         params["max_freq"] / 1000,
198 |     )
199 |     ii = 0
200 |     for row in ax:
201 | 
202 |         if not isinstance(row, np.ndarray):
203 |             row = np.array([row])
204 | 
205 |         for col in row:
206 |             if ii >= len(species_names):
207 |                 col.axis("off")
208 |             else:
209 |                 inds = np.where(labels == order[ii])[0]
210 |                 col.imshow(
211 |                     mean_specs[order[ii]],
212 |                     extent=spec_min_max,
213 |                     cmap="plasma",
214 |                     aspect="equal",
215 |                 )
216 |                 col.grid(color="w", alpha=0.3, linewidth=0.3)
217 |                 col.set_xticks([])
218 |                 col.title.set_text(str(ii + 1) + " " + species_names[order[ii]])
219 |                 col.tick_params(axis="both", which="major", labelsize=7)
220 |             ii += 1
221 | 
222 |     # plt.tight_layout()
223 |     # plt.show()
224 |     plt.savefig(op_file_name)
225 |     plt.close("all")
226 | 
--------------------------------------------------------------------------------
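To make the frame arithmetic in `load_data` above concrete, a short worked example; the parameter values mirror defaults that appear elsewhere in this repo (target sample rate 256 kHz, 2 ms windows, 75% overlap, `spec_width` of 48) but are shown here purely for illustration:

```python
# Worked example of the frame arithmetic used in viz_helpers.load_data.
sampling_rate = 256000   # samples per second (assumed target rate)
fft_win_length = 0.002   # window length in seconds
fft_overlap = 0.75       # fraction of each window that overlaps the next
spec_width = 48          # number of spectrogram columns to keep

nfft = int(fft_win_length * sampling_rate)  # 512 samples per window
noverlap = int(fft_overlap * nfft)          # 384 overlapping samples
hop = nfft - noverlap                       # 128 samples between frames

# samples needed so the spectrogram has exactly spec_width columns
max_samps = spec_width * hop + noverlap     # 48 * 128 + 384 = 6528
print(nfft, noverlap, hop, max_samps)
```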
/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | 4 | import pytest 5 | 6 | 7 | @pytest.fixture 8 | def example_data_dir() -> Path: 9 | pkg_dir = Path(__file__).parent.parent 10 | example_data_dir = pkg_dir / "example_data" 11 | assert example_data_dir.exists() 12 | return example_data_dir 13 | 14 | 15 | @pytest.fixture 16 | def example_audio_dir(example_data_dir: Path) -> Path: 17 | example_audio_dir = example_data_dir / "audio" 18 | assert example_audio_dir.exists() 19 | return example_audio_dir 20 | 21 | 22 | @pytest.fixture 23 | def example_audio_files(example_audio_dir: Path) -> List[Path]: 24 | audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]")) 25 | assert len(audio_files) == 3 26 | return audio_files 27 | 28 | 29 | @pytest.fixture 30 | def data_dir() -> Path: 31 | dir = Path(__file__).parent / "data" 32 | assert dir.exists() 33 | return dir 34 | 35 | 36 | @pytest.fixture 37 | def contrib_dir(data_dir) -> Path: 38 | dir = data_dir / "contrib" 39 | assert dir.exists() 40 | return dir 41 | -------------------------------------------------------------------------------- /tests/data/20230322_172000_selec2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/20230322_172000_selec2.wav -------------------------------------------------------------------------------- /tests/data/contrib/jeff37/0166_20240531_223911.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/jeff37/0166_20240531_223911.wav -------------------------------------------------------------------------------- /tests/data/contrib/jeff37/0166_20240602_225340.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/jeff37/0166_20240602_225340.wav -------------------------------------------------------------------------------- /tests/data/contrib/jeff37/0166_20240603_033731.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/jeff37/0166_20240603_033731.wav -------------------------------------------------------------------------------- /tests/data/contrib/jeff37/0166_20240603_033937.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/jeff37/0166_20240603_033937.wav -------------------------------------------------------------------------------- /tests/data/contrib/jeff37/0166_20240604_233500.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/jeff37/0166_20240604_233500.wav -------------------------------------------------------------------------------- /tests/data/contrib/padpadpadpad/Audiomoth.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/padpadpadpad/Audiomoth.WAV -------------------------------------------------------------------------------- /tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV -------------------------------------------------------------------------------- /tests/data/contrib/padpadpadpad/Echometer.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/batdetect2/4cd71497e7d126e018aafcb71b4e5e17f5ee95e6/tests/data/contrib/padpadpadpad/Echometer.wav -------------------------------------------------------------------------------- /tests/test_api.py: -------------------------------------------------------------------------------- 1 | """Test bat detect module API.""" 2 | 3 | import os 4 | from glob import glob 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import soundfile as sf 9 | import torch 10 | from torch import nn 11 | 12 | from batdetect2 import api 13 | import io 14 | 15 | PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 16 | TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") 17 | TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav")) 18 | 19 | DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 20 | 21 | def test_load_model_with_default_params(): 22 | """Test loading model with default parameters.""" 23 | model, params = api.load_model() 24 | 25 | assert model is not None 26 | assert isinstance(model, nn.Module) 27 | 28 | assert params is not None 29 | assert isinstance(params, dict) 30 | 31 | assert "model_name" in params 32 | assert "num_filters" in params 33 | assert "emb_dim" in params 34 | assert "ip_height" in params 35 | 36 | assert params["model_name"] == "Net2DFast" 37 | assert params["num_filters"] == 128 38 | assert params["emb_dim"] == 0 39 | assert params["ip_height"] == 128 40 | assert params["resize_factor"] == 0.5 41 | assert len(params["class_names"]) == 17 42 | 43 | 44 | def test_list_audio_files(): 45 | """Test listing audio files.""" 46 | audio_files = api.list_audio_files(TEST_DATA_DIR) 47 | 48 | assert len(audio_files) == 3 49 | assert all(path.endswith((".wav", ".WAV")) for path in audio_files) 50 | 51 | 52 | def test_load_audio(): 53 | """Test loading audio.""" 54 | audio = api.load_audio(TEST_DATA[0]) 55 | 56 | assert audio is not None 57 | assert isinstance(audio, np.ndarray) 58 | assert audio.shape == (128000,) 59 | 60 | 61 | def test_generate_spectrogram(): 62 | """Test generating spectrogram.""" 63 | audio = api.load_audio(TEST_DATA[0]) 64 | spectrogram = api.generate_spectrogram(audio) 65 | 66 | assert spectrogram is not None 67 | assert isinstance(spectrogram, torch.Tensor) 68 | assert spectrogram.shape == (1, 1, 128, 512) 69 | 70 | 71 | def test_get_default_config(): 72 | """Test getting default configuration.""" 73 | config = api.get_config() 74 | 
75 | assert config is not None 76 | assert isinstance(config, dict) 77 | 78 | assert config["target_samp_rate"] == 256000 79 | assert config["fft_win_length"] == 0.002 80 | assert config["fft_overlap"] == 0.75 81 | assert config["resize_factor"] == 0.5 82 | assert config["spec_divide_factor"] == 32 83 | assert config["spec_height"] == 256 84 | assert config["spec_scale"] == "pcen" 85 | assert config["denoise_spec_avg"] is True 86 | assert config["max_scale_spec"] is False 87 | assert config["scale_raw_audio"] is False 88 | assert len(config["class_names"]) == 17 89 | assert config["detection_threshold"] == 0.01 90 | assert config["time_expansion"] == 1 91 | assert config["top_n"] == 3 92 | assert config["return_raw_preds"] is False 93 | assert config["max_duration"] is None 94 | assert config["nms_kernel_size"] == 9 95 | assert config["max_freq"] == 120000 96 | assert config["min_freq"] == 10000 97 | assert config["nms_top_k_per_sec"] == 200 98 | assert config["quiet"] is True 99 | assert config["chunk_size"] == 3 100 | assert config["cnn_features"] is False 101 | assert config["spec_features"] is False 102 | assert config["spec_slices"] is False 103 | 104 | 105 | def test_api_exposes_default_model(): 106 | """Test that API exposes default model.""" 107 | assert hasattr(api, "model") 108 | assert isinstance(api.model, nn.Module) 109 | assert type(api.model).__name__ == "Net2DFast" 110 | 111 | # Check that model has expected attributes 112 | assert api.model.num_classes == 17 113 | assert api.model.num_filts == 128 114 | assert api.model.emb_dim == 0 115 | assert api.model.ip_height_rs == 128 116 | assert api.model.resize_factor == 0.5 117 | 118 | 119 | def test_api_exposes_default_config(): 120 | """Test that API exposes default configuration.""" 121 | assert hasattr(api, "config") 122 | assert isinstance(api.config, dict) 123 | 124 | assert api.config["target_samp_rate"] == 256000 125 | assert api.config["fft_win_length"] == 0.002 126 | assert api.config["fft_overlap"] == 0.75 127 | assert api.config["resize_factor"] == 0.5 128 | assert api.config["spec_divide_factor"] == 32 129 | assert api.config["spec_height"] == 256 130 | assert api.config["spec_scale"] == "pcen" 131 | assert api.config["denoise_spec_avg"] is True 132 | assert api.config["max_scale_spec"] is False 133 | assert api.config["scale_raw_audio"] is False 134 | assert len(api.config["class_names"]) == 17 135 | assert api.config["detection_threshold"] == 0.01 136 | assert api.config["time_expansion"] == 1 137 | assert api.config["top_n"] == 3 138 | assert api.config["return_raw_preds"] is False 139 | assert api.config["max_duration"] is None 140 | assert api.config["nms_kernel_size"] == 9 141 | assert api.config["max_freq"] == 120000 142 | assert api.config["min_freq"] == 10000 143 | assert api.config["nms_top_k_per_sec"] == 200 144 | assert api.config["quiet"] is True 145 | assert api.config["chunk_size"] == 3 146 | assert api.config["cnn_features"] is False 147 | assert api.config["spec_features"] is False 148 | assert api.config["spec_slices"] is False 149 | 150 | 151 | def test_process_file_with_default_model(): 152 | """Test processing file with model.""" 153 | predictions = api.process_file(TEST_DATA[0]) 154 | 155 | assert predictions is not None 156 | assert isinstance(predictions, dict) 157 | 158 | assert "pred_dict" in predictions 159 | 160 | # By default will not return other features 161 | assert "spec_feats" not in predictions 162 | assert "spec_feat_names" not in predictions 163 | assert "cnn_feats" not in 
predictions 164 | assert "cnn_feat_names" not in predictions 165 | assert "spec_slices" not in predictions 166 | 167 | # Check that predictions are returned 168 | assert isinstance(predictions["pred_dict"], dict) 169 | pred_dict = predictions["pred_dict"] 170 | assert pred_dict["id"] == os.path.basename(TEST_DATA[0]) 171 | assert pred_dict["annotated"] is False 172 | assert pred_dict["issues"] is False 173 | assert pred_dict["notes"] == "Automatically generated." 174 | assert pred_dict["time_exp"] == 1 175 | assert pred_dict["duration"] == 0.5 176 | assert pred_dict["class_name"] is not None 177 | assert len(pred_dict["annotation"]) > 0 178 | 179 | 180 | def test_process_spectrogram_with_default_model(): 181 | """Test processing spectrogram with model.""" 182 | audio = api.load_audio(TEST_DATA[0]) 183 | spectrogram = api.generate_spectrogram(audio) 184 | predictions, features = api.process_spectrogram(spectrogram) 185 | 186 | assert predictions is not None 187 | assert isinstance(predictions, list) 188 | assert len(predictions) > 0 189 | sample_pred = predictions[0] 190 | assert isinstance(sample_pred, dict) 191 | assert "class" in sample_pred 192 | assert "class_prob" in sample_pred 193 | assert "det_prob" in sample_pred 194 | assert "start_time" in sample_pred 195 | assert "end_time" in sample_pred 196 | assert "low_freq" in sample_pred 197 | assert "high_freq" in sample_pred 198 | 199 | assert features is not None 200 | assert isinstance(features, np.ndarray) 201 | assert len(features) == len(predictions) 202 | 203 | 204 | def test_process_audio_with_default_model(): 205 | """Test processing audio with model.""" 206 | audio = api.load_audio(TEST_DATA[0]) 207 | predictions, features, spec = api.process_audio(audio) 208 | 209 | assert predictions is not None 210 | assert isinstance(predictions, list) 211 | assert len(predictions) > 0 212 | sample_pred = predictions[0] 213 | assert isinstance(sample_pred, dict) 214 | assert "class" in sample_pred 215 | assert "class_prob" in sample_pred 216 | assert "det_prob" in sample_pred 217 | assert "start_time" in sample_pred 218 | assert "end_time" in sample_pred 219 | assert "low_freq" in sample_pred 220 | assert "high_freq" in sample_pred 221 | 222 | assert features is not None 223 | assert isinstance(features, np.ndarray) 224 | assert len(features) == len(predictions) 225 | 226 | assert spec is not None 227 | assert isinstance(spec, torch.Tensor) 228 | assert spec.shape == (1, 1, 128, 512) 229 | 230 | 231 | def test_postprocess_model_outputs(): 232 | """Test postprocessing model outputs.""" 233 | # Load model outputs 234 | audio = api.load_audio(TEST_DATA[1]) 235 | spec = api.generate_spectrogram(audio) 236 | model_outputs = api.model(spec) 237 | 238 | # Postprocess outputs 239 | predictions, features = api.postprocess(model_outputs) 240 | 241 | assert predictions is not None 242 | assert isinstance(predictions, list) 243 | assert len(predictions) > 0 244 | sample_pred = predictions[0] 245 | assert isinstance(sample_pred, dict) 246 | assert "class" in sample_pred 247 | assert "class_prob" in sample_pred 248 | assert "det_prob" in sample_pred 249 | assert "start_time" in sample_pred 250 | assert "end_time" in sample_pred 251 | assert "low_freq" in sample_pred 252 | assert "high_freq" in sample_pred 253 | 254 | assert features is not None 255 | assert isinstance(features, np.ndarray) 256 | assert features.shape[0] == len(predictions) 257 | assert features.shape[1] == 32 258 | 259 | 260 | def test_process_file_with_spec_slices(): 261 | """Test 
process file returns spec slices.""" 262 | config = api.get_config(spec_slices=True) 263 | results = api.process_file(TEST_DATA[0], config=config) 264 | detections = results["pred_dict"]["annotation"] 265 | 266 | assert "spec_slices" in results 267 | assert isinstance(results["spec_slices"], list) 268 | assert len(results["spec_slices"]) == len(detections) 269 | 270 | 271 | def test_process_file_with_empty_predictions_does_not_fail( 272 | tmp_path: Path, 273 | ): 274 | """Test process file with empty predictions does not fail.""" 275 | # Create empty file 276 | empty_file = tmp_path / "empty.wav" 277 | empty_wav = np.zeros((0, 1), dtype=np.float32) 278 | sf.write(empty_file, empty_wav, 256000) 279 | 280 | # Process file 281 | results = api.process_file(str(empty_file)) 282 | 283 | assert results is not None 284 | assert len(results["pred_dict"]["annotation"]) == 0 285 | 286 | def test_process_file_file_id_defaults_to_basename(): 287 | """Test that process_file assigns the basename as an id if no file_id is provided.""" 288 | # Recording donated by @kdarras 289 | basename = "20230322_172000_selec2.wav" 290 | path = os.path.join(DATA_DIR, basename) 291 | 292 | output = api.process_file(path) 293 | predictions = output["pred_dict"] 294 | id = predictions["id"] 295 | assert id == basename 296 | 297 | def test_bytesio_file_id_defaults_to_md5(): 298 | """Test that process_file assigns an MD5 sum as an id if no file_id is provided when using binary data.""" 299 | # Recording donated by @kdarras 300 | basename = "20230322_172000_selec2.wav" 301 | path = os.path.join(DATA_DIR, basename) 302 | 303 | with open(path, "rb") as f: 304 | data = io.BytesIO(f.read()) 305 | 306 | output = api.process_file(data) 307 | predictions = output["pred_dict"] 308 | id = predictions["id"] 309 | assert id == "7ade9ebf1a9fe5477ff3a2dc57001929" 310 | -------------------------------------------------------------------------------- /tests/test_audio_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from hypothesis import given 5 | from hypothesis import strategies as st 6 | 7 | from batdetect2.detector import parameters 8 | from batdetect2.utils import audio_utils, detector_utils 9 | import io 10 | import os 11 | 12 | DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 13 | 14 | @given(duration=st.floats(min_value=0.1, max_value=2)) 15 | def test_can_compute_correct_spectrogram_width(duration: float): 16 | samplerate = parameters.TARGET_SAMPLERATE_HZ 17 | params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS 18 | 19 | length = int(duration * samplerate) 20 | audio = np.random.rand(length) 21 | 22 | spectrogram, _ = audio_utils.generate_spectrogram( 23 | audio, 24 | samplerate, 25 | params, 26 | ) 27 | 28 | # convert to pytorch 29 | spectrogram = torch.from_numpy(spectrogram) 30 | 31 | # add batch and channel dimensions 32 | spectrogram = spectrogram.unsqueeze(0).unsqueeze(0) 33 | 34 | # resize the spec 35 | resize_factor = params["resize_factor"] 36 | spec_op_shape = ( 37 | int(params["spec_height"] * resize_factor), 38 | int(spectrogram.shape[-1] * resize_factor), 39 | ) 40 | spectrogram = F.interpolate( 41 | spectrogram, 42 | size=spec_op_shape, 43 | mode="bilinear", 44 | align_corners=False, 45 | ) 46 | 47 | expected_width = audio_utils.compute_spectrogram_width( 48 | length, 49 | samplerate=parameters.TARGET_SAMPLERATE_HZ, 50 | window_duration=params["fft_win_length"], 51 |
window_overlap=params["fft_overlap"], 52 | resize_factor=params["resize_factor"], 53 | ) 54 | 55 | assert spectrogram.shape[-1] == expected_width 56 | 57 | 58 | @given(duration=st.floats(min_value=0.1, max_value=2)) 59 | def test_pad_audio_without_fixed_size(duration: float): 60 | # Test the pad_audio function. 61 | # This function is used to pad audio with zeros to a specific length. 62 | # It is used in the generate_spectrogram function. 63 | # Here it is exercised with simple random audio of varying duration. 64 | samplerate = parameters.TARGET_SAMPLERATE_HZ 65 | params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS 66 | 67 | length = int(duration * samplerate) 68 | audio = np.random.rand(length) 69 | 70 | # pad the audio to be divisible by divide factor 71 | padded_audio = audio_utils.pad_audio( 72 | audio, 73 | samplerate=samplerate, 74 | window_duration=params["fft_win_length"], 75 | window_overlap=params["fft_overlap"], 76 | resize_factor=params["resize_factor"], 77 | divide_factor=params["spec_divide_factor"], 78 | ) 79 | 80 | # check that the padded audio is divisible by the divide factor 81 | expected_width = audio_utils.compute_spectrogram_width( 82 | len(padded_audio), 83 | samplerate=parameters.TARGET_SAMPLERATE_HZ, 84 | window_duration=params["fft_win_length"], 85 | window_overlap=params["fft_overlap"], 86 | resize_factor=params["resize_factor"], 87 | ) 88 | 89 | assert expected_width % params["spec_divide_factor"] == 0 90 | 91 | 92 | @given(duration=st.floats(min_value=0.1, max_value=2)) 93 | def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor( 94 | duration: float, 95 | ): 96 | samplerate = parameters.TARGET_SAMPLERATE_HZ 97 | params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS 98 | length = int(duration * samplerate) 99 | audio = np.random.rand(length) 100 | _, spectrogram, _ = detector_utils.compute_spectrogram( 101 | audio, 102 | samplerate, 103 | params, 104 | torch.device("cpu"), 105 | ) 106 | assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0 107 | 108 | 109 | @given( 110 | duration=st.floats(min_value=0.1, max_value=2), 111 | width=st.integers(min_value=128, max_value=1024), 112 | ) 113 | def test_pad_audio_with_fixed_width(duration: float, width: int): 114 | samplerate = parameters.TARGET_SAMPLERATE_HZ 115 | params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS 116 | 117 | length = int(duration * samplerate) 118 | audio = np.random.rand(length) 119 | 120 | # pad the audio so the spectrogram reaches the requested fixed width 121 | padded_audio = audio_utils.pad_audio( 122 | audio, 123 | samplerate=samplerate, 124 | window_duration=params["fft_win_length"], 125 | window_overlap=params["fft_overlap"], 126 | resize_factor=params["resize_factor"], 127 | divide_factor=params["spec_divide_factor"], 128 | fixed_width=width, 129 | ) 130 | 131 | # check that the padded audio yields a spectrogram of the requested width 132 | expected_width = audio_utils.compute_spectrogram_width( 133 | len(padded_audio), 134 | samplerate=parameters.TARGET_SAMPLERATE_HZ, 135 | window_duration=params["fft_win_length"], 136 | window_overlap=params["fft_overlap"], 137 | resize_factor=params["resize_factor"], 138 | ) 139 | assert expected_width == width 140 | 141 | 142 | def test_load_audio_using_bytesio(): 143 | basename = "20230322_172000_selec2.wav" 144 | path = os.path.join(DATA_DIR, basename) 145 | 146 | with open(path, "rb") as f: 147 | data = io.BytesIO(f.read()) 148 | 149 | sample_rate, audio_data, file_sample_rate = audio_utils.load_audio_and_samplerate(data, time_exp_fact=1,
target_samp_rate=parameters.TARGET_SAMPLERATE_HZ) 150 | 151 | expected_sample_rate, expected_audio_data, exp_file_sample_rate = audio_utils.load_audio_and_samplerate(path, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ) 152 | 153 | assert expected_sample_rate == sample_rate 154 | assert exp_file_sample_rate == file_sample_rate 155 | 156 | assert np.array_equal(audio_data, expected_audio_data) -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | """Test the command line interface.""" 2 | 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | from click.testing import CliRunner 7 | 8 | from batdetect2.cli import cli 9 | 10 | runner = CliRunner() 11 | 12 | 13 | def test_cli_base_command(): 14 | """Test the base command.""" 15 | result = runner.invoke(cli, ["--help"]) 16 | assert result.exit_code == 0 17 | assert ( 18 | "BatDetect2 - Bat Call Detection and Classification" in result.output 19 | ) 20 | 21 | 22 | def test_cli_detect_command_help(): 23 | """Test the detect command help.""" 24 | result = runner.invoke(cli, ["detect", "--help"]) 25 | assert result.exit_code == 0 26 | assert "Detect bat calls in files in AUDIO_DIR" in result.output 27 | 28 | 29 | def test_cli_detect_command_on_test_audio(tmp_path): 30 | """Test the detect command on test audio.""" 31 | results_dir = tmp_path / "results" 32 | 33 | # Remove results dir if it exists 34 | if results_dir.exists(): 35 | results_dir.rmdir() 36 | 37 | result = runner.invoke( 38 | cli, 39 | [ 40 | "detect", 41 | "example_data/audio", 42 | str(results_dir), 43 | "0.3", 44 | ], 45 | ) 46 | assert result.exit_code == 0 47 | assert results_dir.exists() 48 | assert len(list(results_dir.glob("*.csv"))) == 3 49 | assert len(list(results_dir.glob("*.json"))) == 3 50 | 51 | 52 | def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path): 53 | """Test the detect command with a non-trivial time expansion factor.""" 54 | results_dir = tmp_path / "results" 55 | 56 | # Remove results dir if it exists 57 | if results_dir.exists(): 58 | results_dir.rmdir() 59 | 60 | result = runner.invoke( 61 | cli, 62 | [ 63 | "detect", 64 | "example_data/audio", 65 | str(results_dir), 66 | "0.3", 67 | "--time_expansion_factor", 68 | "10", 69 | ], 70 | ) 71 | 72 | assert result.exit_code == 0 73 | assert "Time Expansion Factor: 10" in result.stdout 74 | 75 | 76 | def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path): 77 | """Test the detect command with the spec feature flag.""" 78 | results_dir = tmp_path / "results" 79 | 80 | # Remove results dir if it exists 81 | if results_dir.exists(): 82 | results_dir.rmdir() 83 | 84 | result = runner.invoke( 85 | cli, 86 | [ 87 | "detect", 88 | "example_data/audio", 89 | str(results_dir), 90 | "0.3", 91 | "--spec_features", 92 | ], 93 | ) 94 | assert result.exit_code == 0 95 | assert results_dir.exists() 96 | 97 | csv_files = [path.name for path in results_dir.glob("*.csv")] 98 | 99 | expected_files = [ 100 | "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv", 101 | "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv", 102 | "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv", 103 | ] 104 | 105 | for expected_file in expected_files: 106 | assert expected_file in csv_files 107 | 108 | df = pd.read_csv(results_dir / expected_file) 109 | assert not (df.duration == -1).any() 110 | 111 | 112 | def 
test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path): 113 | results_dir = tmp_path / "results" 114 | target = tmp_path / "audio" 115 | target.mkdir() 116 | 117 | # Create an empty file with the .wav extension 118 | empty_file = target / "empty.wav" 119 | empty_file.touch() 120 | 121 | result = runner.invoke( 122 | cli, 123 | args=[ 124 | "detect", 125 | str(target), 126 | str(results_dir), 127 | "0.3", 128 | "--spec_features", 129 | ], 130 | ) 131 | assert result.exit_code == 0 132 | assert f"Error processing file {empty_file}" in result.output 133 | 134 | 135 | def test_can_set_chunk_size(tmp_path: Path): 136 | results_dir = tmp_path / "results" 137 | 138 | # Remove results dir if it exists 139 | if results_dir.exists(): 140 | results_dir.rmdir() 141 | 142 | result = runner.invoke( 143 | cli, 144 | [ 145 | "detect", 146 | "example_data/audio", 147 | str(results_dir), 148 | "0.3", 149 | "--chunk_size", 150 | "1", 151 | ], 152 | ) 153 | 154 | assert "Chunk Size: 1.0s" in result.output 155 | assert result.exit_code == 0 156 | assert results_dir.exists() 157 | assert len(list(results_dir.glob("*.csv"))) == 3 158 | assert len(list(results_dir.glob("*.json"))) == 3 159 | -------------------------------------------------------------------------------- /tests/test_contrib.py: -------------------------------------------------------------------------------- 1 | """Test suite to ensure user provided files are correctly processed.""" 2 | 3 | from pathlib import Path 4 | 5 | from click.testing import CliRunner 6 | 7 | from batdetect2.cli import cli 8 | 9 | runner = CliRunner() 10 | 11 | 12 | def test_can_process_jeff37_files( 13 | contrib_dir: Path, 14 | tmp_path: Path, 15 | ): 16 | """This test stems from issue #31. 17 | 18 | A user provided a set of files with which the batdetect2 cli failed and 19 | generated the following error message: 20 | 21 | [2272] "Error processing file!: negative dimensions are not allowed" 22 | 23 | This test ensures that the error message is not generated when running 24 | the batdetect2 cli on the same set of files. 25 | """ 26 | path = contrib_dir / "jeff37" 27 | assert path.exists() 28 | 29 | results_dir = tmp_path / "results" 30 | result = runner.invoke( 31 | cli, 32 | [ 33 | "detect", 34 | str(path), 35 | str(results_dir), 36 | "0.3", 37 | ], 38 | ) 39 | assert result.exit_code == 0 40 | assert results_dir.exists() 41 | assert len(list(results_dir.glob("*.csv"))) == 5 42 | assert len(list(results_dir.glob("*.json"))) == 5 43 | 44 | 45 | def test_can_process_padpadpadpad_files( 46 | contrib_dir: Path, 47 | tmp_path: Path, 48 | ): 49 | """This test stems from issue #29. 50 | 51 | The batdetect2 cli failed on the files provided by the user @padpadpadpad 52 | with the following error message: 53 | 54 | AttributeError: module 'numpy' has no attribute 'AxisError' 55 | 56 | This test ensures that the files are processed without any error.
57 | """ 58 | path = contrib_dir / "padpadpadpad" 59 | assert path.exists() 60 | results_dir = tmp_path / "results" 61 | result = runner.invoke( 62 | cli, 63 | [ 64 | "detect", 65 | str(path), 66 | str(results_dir), 67 | "0.3", 68 | ], 69 | ) 70 | assert result.exit_code == 0 71 | assert results_dir.exists() 72 | assert len(list(results_dir.glob("*.csv"))) == 2 73 | assert len(list(results_dir.glob("*.json"))) == 2 74 | -------------------------------------------------------------------------------- /tests/test_detections.py: -------------------------------------------------------------------------------- 1 | """Test suite to ensure that model detections are not incorrect.""" 2 | 3 | import os 4 | 5 | from batdetect2 import api 6 | 7 | DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 8 | 9 | 10 | def test_no_detections_above_nyquist(): 11 | """Test that no detections are made above the nyquist frequency.""" 12 | # Recording donated by @@kdarras 13 | path = os.path.join(DATA_DIR, "20230322_172000_selec2.wav") 14 | 15 | # This recording has a sampling rate of 192 kHz 16 | nyquist = 192_000 / 2 17 | 18 | output = api.process_file(path) 19 | predictions = output["pred_dict"] 20 | assert len(predictions["annotation"]) != 0 21 | assert all( 22 | pred["high_freq"] < nyquist for pred in predictions["annotation"] 23 | ) 24 | -------------------------------------------------------------------------------- /tests/test_features.py: -------------------------------------------------------------------------------- 1 | """Test suite for feature extraction functions.""" 2 | 3 | import logging 4 | 5 | import librosa 6 | import numpy as np 7 | import pytest 8 | 9 | import batdetect2.detector.compute_features as feats 10 | from batdetect2 import api, types 11 | from batdetect2.utils import audio_utils as au 12 | 13 | numba_logger = logging.getLogger("numba") 14 | numba_logger.setLevel(logging.WARNING) 15 | 16 | 17 | def index_to_freq( 18 | index: int, 19 | spec_height: int, 20 | min_freq: int, 21 | max_freq: int, 22 | ) -> float: 23 | """Convert spectrogram index to frequency in Hz.""" 24 | index = spec_height - index 25 | return round( 26 | (index / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 27 | ) 28 | 29 | 30 | def index_to_time( 31 | index: int, 32 | spec_width: int, 33 | spec_duration: float, 34 | ) -> float: 35 | """Convert spectrogram index to time in seconds.""" 36 | return round((index / float(spec_width)) * spec_duration, 2) 37 | 38 | 39 | def test_get_feats_function_with_empty_spectrogram(): 40 | """Test get_feats function with empty spectrogram. 41 | 42 | This tests that the overall flow of the function works, even if the 43 | spectrogram is empty. 
44 | """ 45 | spec_duration = 3 46 | spec_width = 100 47 | spec_height = 100 48 | min_freq = 10_000 49 | max_freq = 120_000 50 | spectrogram = np.zeros((spec_height, spec_width)) 51 | 52 | x_pos = 20 53 | y_pos = 80 54 | bb_width = 20 55 | bb_height = 20 56 | 57 | start_time = index_to_time(x_pos, spec_width, spec_duration) 58 | end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration) 59 | low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq) 60 | high_freq = index_to_freq( 61 | y_pos - bb_height, spec_height, min_freq, max_freq 62 | ) 63 | 64 | pred_nms: types.PredictionResults = { 65 | "det_probs": np.array([1]), 66 | "class_probs": np.array([[1]]), 67 | "x_pos": np.array([x_pos]), 68 | "y_pos": np.array([y_pos]), 69 | "bb_width": np.array([bb_width]), 70 | "bb_height": np.array([bb_height]), 71 | "start_times": np.array([start_time]), 72 | "end_times": np.array([end_time]), 73 | "low_freqs": np.array([low_freq]), 74 | "high_freqs": np.array([high_freq]), 75 | } 76 | 77 | params: types.FeatureExtractionParameters = { 78 | "min_freq": min_freq, 79 | "max_freq": max_freq, 80 | } 81 | 82 | features = feats.get_feats(spectrogram, pred_nms, params) 83 | assert low_freq < high_freq 84 | assert isinstance(features, np.ndarray) 85 | assert features.shape == (len(pred_nms["det_probs"]), 9) 86 | assert np.isclose( 87 | features[0], 88 | np.array( 89 | [ 90 | end_time - start_time, 91 | low_freq, 92 | high_freq, 93 | high_freq - low_freq, 94 | high_freq, 95 | max_freq, 96 | max_freq, 97 | max_freq, 98 | np.nan, 99 | ] 100 | ), 101 | equal_nan=True, 102 | ).all() 103 | 104 | 105 | @pytest.mark.parametrize( 106 | "max_power", 107 | [ 108 | 30_000, 109 | 31_000, 110 | 32_000, 111 | 33_000, 112 | 34_000, 113 | 35_000, 114 | 36_000, 115 | 37_000, 116 | 38_000, 117 | 39_000, 118 | 40_000, 119 | ], 120 | ) 121 | def test_compute_max_power_bb(max_power: int): 122 | """Test compute_max_power_bb function.""" 123 | duration = 1 124 | samplerate = 256_000 125 | min_freq = 0 126 | max_freq = 128_000 127 | 128 | start_time = 0.3 129 | end_time = 0.6 130 | low_freq = 30_000 131 | high_freq = 40_000 132 | 133 | audio = np.zeros((int(duration * samplerate),)) 134 | 135 | # Add a signal during the time and frequency range of interest 136 | audio[ 137 | int(start_time * samplerate) : int(end_time * samplerate) 138 | ] = 0.5 * librosa.tone( 139 | max_power, sr=samplerate, duration=end_time - start_time 140 | ) 141 | 142 | # Add a more powerful signal outside frequency range of interest 143 | audio[ 144 | int(start_time * samplerate) : int(end_time * samplerate) 145 | ] += 2 * librosa.tone( 146 | 80_000, sr=samplerate, duration=end_time - start_time 147 | ) 148 | 149 | params = api.get_config( 150 | min_freq=min_freq, 151 | max_freq=max_freq, 152 | target_samp_rate=samplerate, 153 | ) 154 | 155 | spec, _ = au.generate_spectrogram( 156 | audio, 157 | samplerate, 158 | params, 159 | ) 160 | 161 | x_start = int( 162 | au.time_to_x_coords( 163 | start_time, 164 | samplerate, 165 | params["fft_win_length"], 166 | params["fft_overlap"], 167 | ) 168 | ) 169 | 170 | x_end = int( 171 | au.time_to_x_coords( 172 | end_time, 173 | samplerate, 174 | params["fft_win_length"], 175 | params["fft_overlap"], 176 | ) 177 | ) 178 | 179 | num_freq_bins = spec.shape[0] 180 | y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq) 181 | y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq) 182 | 183 | prediction: types.Prediction = { 184 | "det_prob": 1, 185 | "class_prob": np.ones((1,)), 186 | 
"x_pos": x_start, 187 | "y_pos": int(y_low), 188 | "bb_width": int(x_end - x_start), 189 | "bb_height": int(y_low - y_high), 190 | "start_time": start_time, 191 | "end_time": end_time, 192 | "low_freq": low_freq, 193 | "high_freq": high_freq, 194 | } 195 | 196 | print(prediction) 197 | 198 | max_power_bb = feats.compute_max_power_bb( 199 | prediction, 200 | spec, 201 | min_freq=min_freq, 202 | max_freq=max_freq, 203 | ) 204 | 205 | assert abs(max_power_bb - max_power) <= 500 206 | 207 | 208 | def test_compute_max_power(): 209 | """Test compute_max_power_bb function.""" 210 | duration = 3 211 | samplerate = 16_000 212 | min_freq = 0 213 | max_freq = 8_000 214 | 215 | start_time = 1 216 | end_time = 2 217 | low_freq = 3_000 218 | high_freq = 4_000 219 | max_power = 5_000 220 | 221 | audio = np.zeros((int(duration * samplerate),)) 222 | 223 | # Add a signal during the time and frequency range of interest 224 | audio[ 225 | int(start_time * samplerate) : int(end_time * samplerate) 226 | ] = 0.5 * librosa.tone( 227 | 3_500, sr=samplerate, duration=end_time - start_time 228 | ) 229 | 230 | # Add a more powerful signal outside frequency range of interest 231 | audio[ 232 | int(start_time * samplerate) : int(end_time * samplerate) 233 | ] += 2 * librosa.tone( 234 | max_power, sr=samplerate, duration=end_time - start_time 235 | ) 236 | 237 | params = api.get_config( 238 | min_freq=min_freq, 239 | max_freq=max_freq, 240 | target_samp_rate=samplerate, 241 | ) 242 | 243 | spec, _ = au.generate_spectrogram( 244 | audio, 245 | samplerate, 246 | params, 247 | ) 248 | 249 | x_start = int( 250 | au.time_to_x_coords( 251 | start_time, 252 | samplerate, 253 | params["fft_win_length"], 254 | params["fft_overlap"], 255 | ) 256 | ) 257 | 258 | x_end = int( 259 | au.time_to_x_coords( 260 | end_time, 261 | samplerate, 262 | params["fft_win_length"], 263 | params["fft_overlap"], 264 | ) 265 | ) 266 | 267 | num_freq_bins = spec.shape[0] 268 | y_low = int(num_freq_bins * low_freq / max_freq) 269 | y_high = int(num_freq_bins * high_freq / max_freq) 270 | 271 | prediction: types.Prediction = { 272 | "det_prob": 1, 273 | "class_prob": np.ones((1,)), 274 | "x_pos": x_start, 275 | "y_pos": int(y_high), 276 | "bb_width": int(x_end - x_start), 277 | "bb_height": int(y_high - y_low), 278 | "start_time": start_time, 279 | "end_time": end_time, 280 | "low_freq": low_freq, 281 | "high_freq": high_freq, 282 | } 283 | 284 | computed_max_power = feats.compute_max_power( 285 | prediction, 286 | spec, 287 | min_freq=min_freq, 288 | max_freq=max_freq, 289 | ) 290 | 291 | assert abs(computed_max_power - max_power) < 100 292 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | """Test suite for model functions.""" 2 | 3 | import warnings 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import numpy as np 8 | from hypothesis import given, settings 9 | from hypothesis import strategies as st 10 | 11 | from batdetect2 import api 12 | from batdetect2.detector import parameters 13 | 14 | 15 | def test_can_import_model_without_warnings(): 16 | with warnings.catch_warnings(): 17 | warnings.simplefilter("error") 18 | api.load_model() 19 | 20 | 21 | @settings(deadline=None, max_examples=5) 22 | @given(duration=st.floats(min_value=0.1, max_value=2)) 23 | def test_can_import_model_without_pickle(duration: float): 24 | # NOTE: remove this test once no other issues are found This is a temporary 25 
| # test to check that the change in model loading did not impact model behaviour 26 | # in any way. 27 | 28 | samplerate = parameters.TARGET_SAMPLERATE_HZ 29 | audio = np.random.rand(int(duration * samplerate)) 30 | 31 | model_without_pickle, model_params_without_pickle = api.load_model( 32 | weights_only=True 33 | ) 34 | model_with_pickle, model_params_with_pickle = api.load_model( 35 | weights_only=False 36 | ) 37 | 38 | assert model_params_without_pickle == model_params_with_pickle 39 | 40 | predictions_without_pickle, _, _ = api.process_audio( 41 | audio, 42 | model=model_without_pickle, 43 | ) 44 | predictions_with_pickle, _, _ = api.process_audio( 45 | audio, 46 | model=model_with_pickle, 47 | ) 48 | 49 | assert predictions_without_pickle == predictions_with_pickle 50 | 51 | 52 | def test_can_import_model_without_pickle_on_test_data( 53 | example_audio_files: List[Path], 54 | ): 55 | # NOTE: remove this test once no other issues are found. This is a temporary 56 | # test to check that the change in model loading did not impact model behaviour 57 | # in any way. 58 | 59 | model_without_pickle, model_params_without_pickle = api.load_model( 60 | weights_only=True 61 | ) 62 | model_with_pickle, model_params_with_pickle = api.load_model( 63 | weights_only=False 64 | ) 65 | 66 | assert model_params_without_pickle == model_params_with_pickle 67 | 68 | for audio_file in example_audio_files: 69 | audio = api.load_audio(str(audio_file)) 70 | predictions_without_pickle, _, _ = api.process_audio( 71 | audio, 72 | model=model_without_pickle, 73 | ) 74 | predictions_with_pickle, _, _ = api.process_audio( 75 | audio, 76 | model=model_with_pickle, 77 | ) 78 | assert predictions_without_pickle == predictions_with_pickle 79 | --------------------------------------------------------------------------------
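Taken together, these tests document the public API end to end. The sketch below distills the calls they exercise into one minimal workflow. It is a hedged illustration rather than official usage: it relies only on the function names, signatures, and return shapes asserted in the tests above, and it assumes the bundled example_data/audio recording is reachable from the working directory.

from batdetect2 import api

AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"

# Load the default model and its parameter dict
# (test_load_model_with_default_params).
model, params = api.load_model()

# One-call pipeline: process_file returns a dict whose "pred_dict" entry
# carries per-file metadata and an "annotation" list of detections
# (test_process_file_with_default_model, test_no_detections_above_nyquist).
results = api.process_file(AUDIO_FILE)
for pred in results["pred_dict"]["annotation"]:
    print(pred["high_freq"])

# Step-by-step pipeline (test_process_spectrogram_with_default_model):
audio = api.load_audio(AUDIO_FILE)        # numpy array of samples
spec = api.generate_spectrogram(audio)    # torch.Tensor of shape (1, 1, H, W)
detections, features = api.process_spectrogram(spec)

# Keyword overrides flow through get_config, e.g. to request per-detection
# spectrogram slices (test_process_file_with_spec_slices):
config = api.get_config(spec_slices=True)
results = api.process_file(AUDIO_FILE, config=config)
assert len(results["spec_slices"]) == len(results["pred_dict"]["annotation"])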