├── steme
│   ├── viz.py
│   ├── __init__.py
│   ├── paths.py
│   ├── metrics.py
│   ├── data_augmentation.py
│   ├── calibration.py
│   ├── loader.py
│   ├── audio.py
│   ├── utils.py
│   ├── models.py
│   └── dataset.py
├── notebooks
│   ├── center_dict.pkl
│   ├── results_dict.pkl
│   ├── tempo_range.ipynb
│   ├── distributions.ipynb
│   ├── calibrate_and_evaluate.ipynb
│   ├── tempogram_types.ipynb
│   └── data_augmentation.ipynb
├── scripts
│   ├── helpers
│   │   ├── crop.sh
│   │   └── create_gtzan_augmented.sh
│   ├── generate_dataset.py
│   ├── tempogram_data_augmentation.py
│   ├── audio_data_augmentation.py
│   ├── generate_predictions.py
│   ├── evaluate_model.py
│   ├── train_model.py
│   ├── calibrate_model.py
│   └── test_model.py
├── .gitignore
├── setup.py
└── README.md

--------------------------------------------------------------------------------
/steme/viz.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/steme/__init__.py:
--------------------------------------------------------------------------------
from . import audio, calibration, dataset, loader, metrics, models, paths, utils

--------------------------------------------------------------------------------
/notebooks/center_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giovana-morais/steme/HEAD/notebooks/center_dict.pkl

--------------------------------------------------------------------------------
/notebooks/results_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/giovana-morais/steme/HEAD/notebooks/results_dict.pkl

--------------------------------------------------------------------------------
/scripts/helpers/crop.sh:
--------------------------------------------------------------------------------
#!/bin/sh

# keep 30 s of audio (from 0:05 to 0:35)
for i in *_augmented.wav; do
    newfile="${i%.*}_cropped.wav";
    echo "cropping $i to $newfile";
    ffmpeg -i "$i" -ss 00:00:05 -to 00:00:35 -acodec copy "$newfile";
done

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# ignoring installation files
steme.egg-info/

# ignoring caches
*/*__pycache__
*/.ipynb_checkpoints/

# ignoring generated images
*.png
*.svg

# ignoring .sbatch files
jobs/

# models and data content
data/*
models/*

--------------------------------------------------------------------------------
/steme/paths.py:
--------------------------------------------------------------------------------
import os

PROJECT_FOLDER = os.getcwd()
DATA_FOLDER = os.path.join(PROJECT_FOLDER, "data")
DATASET_FOLDER = os.path.join(os.environ["HOME"], "datasets")
FIG_FOLDER = os.path.join(PROJECT_FOLDER, "figures")
LOG_FOLDER = os.path.join(PROJECT_FOLDER, "../logs")
MODEL_FOLDER = os.path.join(PROJECT_FOLDER, "models")
# DATASET_FOLDER = "/scratch/mf3734/share/datasets"
# NOTE: this hardcoded cluster path overrides the $HOME-based default above
DATASET_FOLDER = "/scratch/gv2167/datasets"
PRE_COMPUTED_DATA_FOLDER = os.path.join(DATA_FOLDER, "pre_computed_tempograms")
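The hardcoded `/scratch/...` override above ties `paths.py` to one specific machine. A minimal sketch of an environment-variable fallback instead; the `STEME_DATASET_FOLDER` variable name is hypothetical, not something the repo defines:

```python
import os

# Hypothetical override: prefer an environment variable when it is set,
# otherwise fall back to ~/datasets (the default used in paths.py above).
DATASET_FOLDER = os.environ.get(
    "STEME_DATASET_FOLDER",
    os.path.join(os.environ["HOME"], "datasets"),
)
```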
name = "steme", 11 | version = 0.1, 12 | description = "", 13 | url = "", 14 | author = "Giovana Morais", 15 | author_email = "giovana.vmorais@gmail.com", 16 | license = "", 17 | packages = ["steme"], 18 | install_requires = [ 19 | "h5py>=3.7", 20 | "librosa>=0.8.0", 21 | "mirdata>=0.3.7", 22 | "numpy>=1.19.2", 23 | "pandas>=2.0.0", 24 | "scipy>=1.9.0", 25 | "tensorflow>=2.0" 26 | ], 27 | zip_safe = False 28 | ) 29 | -------------------------------------------------------------------------------- /scripts/helpers/create_gtzan_augmented.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | 3 | foldername=gtzan_augmented 4 | audiofolder=$foldername/audio 5 | tempofolder=$foldername/annotations/tempo 6 | rm -rf foldername 7 | 8 | # create folders 9 | mkdir $audiofolder -p 10 | mkdir $tempofolder -p 11 | 12 | basefolder=gtzan_genre/gtzan_genre/genres 13 | 14 | # copy audio content content there 15 | echo "copying audio data to $audiofolder"; 16 | for i in gtzan_genre/gtzan_genre/genres/*; do 17 | # echo "copying $i to $audiofolder" 18 | cp $i/* $audiofolder; 19 | done 20 | 21 | # copy annotations 22 | echo "copying bpm data to $tempofolder" 23 | for i in gtzan_genre/gtzan_tempo_beat-main/tempo/*; do 24 | cp $i $tempofolder; 25 | done 26 | 27 | cd $tempofolder; 28 | 29 | # format names 30 | echo "formatting audio data" 31 | for i in gtzan_*; do 32 | remove_gtzan=${i:6}; 33 | new_name=${remove_gtzan/_/.}; 34 | mv $i $new_name; 35 | done 36 | -------------------------------------------------------------------------------- /scripts/generate_dataset.py: -------------------------------------------------------------------------------- 1 | import steme.audio as audio 2 | import steme.dataset as dataset 3 | from steme.paths import * 4 | 5 | 6 | def main(dataset_name, dataset_type, synthetic, tmin, bins_per_octave, n_bins, 7 | t_type, **kwargs): 8 | theta = dataset.variables_non_linear(tmin=tmin, 9 | bins_per_octave=bins_per_octave, 10 | n_bins=n_bins) 11 | 12 | if not synthetic: 13 | dataset.generate_dataset( 14 | dataset_name=dataset_name, 15 | dataset_type=dataset_type, 16 | theta=theta, 17 | t_type=t_type) 18 | else: 19 | dataset.generate_synthetic_dataset( 20 | dataset_name=dataset_name, 21 | dataset_type=dataset_type, 22 | theta=theta, 23 | t_type=t_type, 24 | **kwargs) 25 | 26 | return 27 | 28 | 29 | if __name__ == "__main__": 30 | import fire 31 | fire.Fire(main) 32 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # steme 2 | Self supervised TEMpo Estimation 3 | 4 | Source code for the paper ["Tempo vs. Pitch: understanding self-supervised tempo 5 | estimation"](https://arxiv.org/abs/2304.06868). 6 | 7 | ![steme_workflow](https://github.com/giovana-morais/steme/assets/12520431/0ce4c9dc-f2eb-4749-a113-e79759712ae9) 8 | 9 | 10 | ## Installation from the source 11 | 12 | It is not possible (yet!) to install the package directly, such as `pip install 13 | steme`. Therefore, you need to clone the repo, and then run 14 | 15 | ```bash 16 | pip install -e . 17 | ``` 18 | 19 | ## Cite 20 | 21 | ``` 22 | @inproceedings{morais2023tempovspitch, 23 | author={Morais, Giovana and Davies, Matthew E. P. and Queiroz, Marcelo and Fuentes, Magdalena}, 24 | booktitle={ICASSP 2023 - 2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 25 | title={Tempo vs. 
--------------------------------------------------------------------------------
/steme/metrics.py:
--------------------------------------------------------------------------------
"""
Compute tempo metrics
"""
import numpy as np


def acc1(reference_tempo, estimated_tempo, tolerance=0.04, factor=1.0):
    """Accuracy 1: estimate within `tolerance` of (factor * reference)."""
    return np.abs(reference_tempo * factor - estimated_tempo)\
        <= (reference_tempo * factor * tolerance)


def acc2(reference_tempo, estimated_tempo, tolerance=0.04):
    """Accuracy 2: like acc1, but also accepts octave errors (2x, 3x, 1/2, 1/3)."""
    return (
        (acc1(reference_tempo, estimated_tempo, tolerance, 1.0))
        | (acc1(reference_tempo, estimated_tempo, tolerance, 2.0))
        | (acc1(reference_tempo, estimated_tempo, tolerance, 3.0))
        | (acc1(reference_tempo, estimated_tempo, tolerance, 1.0 / 2.0))
        | (acc1(reference_tempo, estimated_tempo, tolerance, 1.0 / 3.0))
    )


def oe1(reference_tempo, estimated_tempo, octave_factor=1.0):
    """Octave error 1: signed estimation error, in octaves."""
    return np.log2((estimated_tempo * octave_factor) / reference_tempo)


def oe2(reference_tempo, estimated_tempo):
    """Octave error 2: the octave error with the smallest magnitude over the
    allowed octave factors."""
    factors = [1 / 3, 1 / 2, 1, 2, 3]
    oe = np.zeros_like(factors)

    for idx, factor in enumerate(factors):
        oe[idx] = oe1(reference_tempo, estimated_tempo, factor)

    # pick the error closest to zero, keeping its sign
    return oe[np.argmin(np.abs(oe))]


def aoe1(reference_tempo, estimated_tempo):
    return np.abs(oe1(reference_tempo, estimated_tempo))


def aoe2(reference_tempo, estimated_tempo):
    return np.abs(oe2(reference_tempo, estimated_tempo))
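A quick sanity check of the metrics above (the BPM values are chosen purely for illustration): Acc1 accepts estimates within ±4% of the reference tempo, Acc2 additionally accepts octave errors, and OE1 measures the signed error in octaves.

```python
from steme.metrics import acc1, acc2, oe1

reference, estimate = 120.0, 60.0   # estimate is exactly half the tempo

print(acc1(reference, estimate))    # False: 60 BPM is outside 120 BPM +/- 4%
print(acc2(reference, estimate))    # True: the 1/2 octave factor matches
print(oe1(reference, estimate))     # -1.0: exactly one octave below
```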
--------------------------------------------------------------------------------
/scripts/tempogram_data_augmentation.py:
--------------------------------------------------------------------------------
"""
Create new dataset doing augmentations directly in the tempogram domain
"""
import os

import h5py
import numpy as np

import steme.audio as audio
from steme import dataset, paths
from steme.data_augmentation import get_even_rows


if __name__ == "__main__":
    # 1. load GTZAN
    gtzan, tracks, tempi = dataset.gtzan_data()

    main_file = "gtzan_tempogram_aug"

    main_filepath = os.path.join(paths.DATA_FOLDER, f"{main_file}.h5")

    linear_theta = np.arange(30, 670, 1)

    with h5py.File(main_filepath, "a") as hf:
        for track_id in tracks:
            x, sr = gtzan.track(track_id).audio

            # 2. compute all tempograms with twice the tempo range
            T, times, freqs = audio.tempogram(
                x, sr, window_size_seconds=10,
                t_type="fourier", theta=linear_theta)

            # 3. downsample the tempo axis of every tempogram
            aug_T = get_even_rows(T)
            aug_bpm = gtzan.track(track_id).tempo / 2

            # 4. save the reduced tempograms plus the new tempo (which is half
            # of the original)
            hf.create_dataset(f"{track_id}_augmented", data=aug_T)

            # save tempo
            with open(os.path.join(paths.DATA_FOLDER,
                      f"{track_id}_augmented.bpm"), "w") as f:
                f.write(str(aug_bpm))
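Why keeping only the even rows halves the tempo: the script computes tempograms on an axis where the row index grows linearly with BPM, so a peak sitting at row 2j before the reduction lands at row j afterwards, and the script stores the halved BPM annotation to match. A toy positional check of what `get_even_rows` does (shapes and the peak position are illustrative):

```python
import numpy as np

# toy tempogram: rows form a linear tempo axis, columns are time frames
T = np.zeros((640, 4))
T[240, :] = 1.0                  # synthetic peak at row 240

aug_T = T[::2, :].copy()         # the get_even_rows() reduction
print(aug_T.shape)               # (320, 4): half as many tempo rows
print(np.argmax(aug_T[:, 0]))    # 120: the peak moved from row 240 to 240/2
```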
--------------------------------------------------------------------------------
/steme/data_augmentation.py:
--------------------------------------------------------------------------------
"""
Augmentation helper functions
"""

import logging
from typing import Dict, List, Optional, Tuple

import numpy as np
import numpy.typing as npt

logger = logging.getLogger(__name__)


def key_boundaries(key: str) -> List[float]:
    """
    transform a string key of type '60, 65' into a list of floats [60.0, 65.0]
    (the bin edges may be non-integer when they come from a log axis)
    """
    return [float(i) for i in key.split(", ")]


def check_missing_tracks(
        transformation_dict: Dict[str, float]) -> Optional[Tuple[str, float]]:
    """
    return the first occurrence of a boundary that needs to be filled with
    data, or None if no boundary needs tracks.
    """
    for k, v in transformation_dict.items():
        if v > 0:
            return k, v


def create_transformation_dict(diff_tempi,
                               finer_bins,
                               verbose: bool = True) -> Dict[str, float]:
    """
    create dictionary with boundaries and the augmentation needed in
    that boundary. e.g. {"60, 65": -2} means that 2 tracks need to be
    removed from the interval [60, 65].

    `diff_tempi` is the per-bin difference between the target histogram and
    the dataset histogram; `finer_bins` holds the bin edges. (both used to be
    free, undefined names and are now explicit parameters.)
    """
    removals = 0
    additions = 0

    transformation_dict = {}

    for idx, value in enumerate(diff_tempi):
        transformation_dict[f"{finer_bins[idx]}, {finer_bins[idx+1]}"] = value

        if value < 0:
            message = f"remove {value} samples"
            removals += np.abs(value)
        elif value > 0:
            message = f"add {value} samples"
            additions += value
        else:
            message = "do nothing"

        if verbose:
            logger.debug(f"{finer_bins[idx]} - {finer_bins[idx+1]}: {message}")

    logger.info(f"total removals = {removals}, total additions = {additions}")

    return transformation_dict


def reset_transformation_dict(diff_tempi, finer_bins) -> Dict[str, float]:
    """
    return a new transformation dict
    """
    return create_transformation_dict(diff_tempi, finer_bins, verbose=False)


def tempogram_augmentation():
    """
    create new tempogram with half the tempo of the original

    NOTE: unimplemented stub

    parameters
    ---
    T : np.array (2D)
        tempogram matrix

    return
    ---
    augmented_T : np.array (2D)

    """
    return


# TODO: rename this function, the current name is terrible
def get_even_rows(T: npt.ArrayLike) -> npt.ArrayLike:
    """
    create tempogram with only the even lines of the input tempogram

    parameters
    ---
    T : np.array (2D)
        tempogram matrix

    return
    ---
    augmented_T : np.array (2D)

    """

    if T.shape[0] % 2 != 0:
        raise ValueError(
            "Tempogram shape is not suitable for this reduction. "
            f"Expected an even number of rows, but got {T.shape}")

    augmented_T = T[::2, :].copy()

    return augmented_T


def create_helper_dict(bins: list) -> Dict[str, List[str]]:
    # create a dictionary mapping interval -> track_ids, e.g.
    # "30, 40": ["classical.0000", "blues.0010"]
    helper_dict = {}
    for idx in range(len(bins) - 1):
        helper_dict[f"{bins[idx]}, {bins[idx+1]}"] = []

    return helper_dict
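How the helpers above fit together, using the parameterized signature adopted here (the original referenced `diff_tempi` and `finer_bins` as free names): the per-bin difference between a target tempo histogram and the dataset's histogram says how many tracks each tempo interval must gain or lose.

```python
import numpy as np

from steme.data_augmentation import (check_missing_tracks,
                                     create_transformation_dict)

finer_bins = np.array([60, 80, 100, 120])   # toy tempo bin edges
target_hist = np.array([10, 20, 10])        # desired number of tracks per bin
dataset_hist = np.array([12, 15, 13])       # what the dataset currently has

diff_tempi = (target_hist - dataset_hist).tolist()   # [-2, 5, -3]
tdict = create_transformation_dict(diff_tempi, finer_bins)
print(tdict)                        # {'60, 80': -2, '80, 100': 5, '100, 120': -3}
print(check_missing_tracks(tdict))  # ('80, 100', 5): first bin needing tracks
```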
--------------------------------------------------------------------------------
/scripts/audio_data_augmentation.py:
--------------------------------------------------------------------------------
"""
Augment dataset in audio domain
"""

import logging
import os
import random

import numpy as np
import pyrubberband as pyrb
import soundfile as sf

import steme.data_augmentation as aug
import steme.dataset as dataset
import steme.paths as paths

logger = logging.getLogger(__name__)

DATASET_PATH = os.path.join(paths.DATASET_FOLDER, "gtzan_augmented_log")


def remove_tracks(to_remove):
    for track_id in to_remove:
        try:
            logger.debug(f"removing {track_id}")
            os.remove(os.path.join(DATASET_PATH, f"audio/{track_id}.wav"))
            os.remove(os.path.join(DATASET_PATH,
                                   f"annotations/tempo/{track_id}.bpm"))
        except FileNotFoundError:
            logger.debug("already removed")
            continue
    return


if __name__ == "__main__":
    gtzan, tracks, tempi = dataset.gtzan_data()
    dist_low = dataset.lognormal70()
    theta = dataset.variables_non_linear(25, 40, 190)
    log_bins = theta[(theta > 30) & (theta < 370)][::2]
    bins = log_bins
    # NOTE: finer_bins, dist_low_hist and gtzan_dist were referenced but never
    # defined in the original script; the lines below are a plausible
    # reconstruction (target distribution minus the GTZAN distribution,
    # both histogrammed over the same bins).
    finer_bins = bins
    dist_low_hist = np.histogram(dist_low, bins=finer_bins)
    gtzan_dist = np.histogram(tempi, bins=finer_bins)
    diff_tempi = dist_low_hist[0] - gtzan_dist[0]

    helper_dict = aug.create_helper_dict(finer_bins)
    gtzan_mapping = {}

    for i in tracks:
        tempo = gtzan.track(i).tempo

        boundaries = np.digitize(tempo, finer_bins)
        gtzan_mapping[i] = (
            tempo, f"{finer_bins[boundaries-1]}, {finer_bins[boundaries]}")
        helper_dict[
            f"{finer_bins[boundaries-1]}, {finer_bins[boundaries]}"].append(i)

    transformation_dict = aug.reset_transformation_dict(diff_tempi, finer_bins)

    to_remove = []
    j = 0
    for key, val in list(transformation_dict.items())[::-1]:
        if val < 0:
            logger.info(f"augmenting tracks from {key}")
            for track_id in helper_dict[key]:
                logger.debug(track_id)
                original_tempo = gtzan.track(track_id).tempo
                original_boundaries = gtzan_mapping[track_id][1]

                str_boundaries = aug.check_missing_tracks(transformation_dict)

                if str_boundaries is None or key == str_boundaries[0]:
                    logger.debug(transformation_dict)
                    break

                new_tempo_boundaries = aug.key_boundaries(str_boundaries[0])

                new_tempo = random.uniform(float(new_tempo_boundaries[0]),
                                           float(new_tempo_boundaries[1]))

                tempo_rate = new_tempo / original_tempo

                x, fs = gtzan.track(track_id).audio
                to_remove.append(track_id)

                logger.debug(f"original_tempo {original_tempo}, "
                             f"new_tempo {new_tempo}, tempo_rate {tempo_rate}")

                # pyrubberband options: "-3" selects the finest algorithm,
                # therefore the best audio quality
                rbargs = {"-3": ""}
                x_stretch = pyrb.time_stretch(x, fs, tempo_rate, rbargs=rbargs)

                # update dicts
                transformation_dict[str_boundaries[0]] -= 1
                transformation_dict[original_boundaries] += 1

                # save audio
                sf.write(os.path.join(DATASET_PATH,
                                      f"audio/{track_id}_augmented.wav"),
                         x_stretch, fs, subtype="PCM_24")

                # save tempo
                with open(os.path.join(
                        DATASET_PATH,
                        f"annotations/tempo/{track_id}_augmented.bpm"),
                        "w") as f:
                    f.write(str(new_tempo))

--------------------------------------------------------------------------------
/notebooks/tempo_range.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a271e0d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from steme import dataset "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebf92d59",
   "metadata": {},
   "source": [
    "# Discussion on tempogram ranges\n",
    "\n",
    "Tempograms are a time-tempo representation of an audio clip: they show the most salient BPMs of a song at each point in time. When building a tempogram we usually have a linear axis with a resolution of 1 BPM, but for this work we decided to change the axis of the representation, choosing a nonlinear axis that follows the formula `bpms[k] = tmin * 2**(k/Q)`, where `Q` is the number of bins per octave and `tmin` is the lowest tempo on the axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1085efc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "theta1 = dataset.variables_non_linear(tmin=30, bins_per_octave=40, n_bins=190)\n",
    "\n",
    "print(\"$k_min \\in U(0,8)$\")\n",
    "print(theta1[0:128][0], theta1[0:128][-1])\n",
    "print(theta1[8:128+8][0], theta1[8:128+8][-1])\n",
    "\n",
    "print(\"\\n$k_min \\in U(11,18)$\")\n",
    "print(theta1[11:128+11][0], theta1[11:128+11][-1])\n",
    "print(theta1[18:128+18][0], theta1[18:128+18][-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a59ecc54",
   "metadata": {},
   "outputs": [],
   "source": [
    "theta2 = dataset.variables_non_linear(tmin=25, bins_per_octave=40, n_bins=190)\n",
    "\n",
    "print(\"$k_min \\in U(0,8)$\")\n",
    "print(theta2[0:128][0], theta2[0:128][-1])\n",
    "print(theta2[8:128+8][0], theta2[8:128+8][-1])\n",
    "\n",
    "print(\"\\n$k_min \\in U(11,18)$\")\n",
    "print(theta2[11:128+11][0], theta2[11:128+11][-1])\n",
    "print(theta2[18:128+18][0], theta2[18:128+18][-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0c750ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(theta1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93b2bf76",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(theta2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85bed39",
   "metadata": {},
   "outputs": [],
   "source": [
    "theta1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "868453bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "theta2[0], theta2[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a78ac02a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
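The nonlinear axis the notebook discusses can be written down directly. A sketch mirroring the formula above (the in-repo implementation lives in `dataset.variables_non_linear`, whose internals are not shown in this listing, so details may differ):

```python
import numpy as np

def log_tempo_axis(tmin=25.0, bins_per_octave=40, n_bins=190):
    """BPM axis with geometrically spaced bins: theta[k] = tmin * 2**(k/Q)."""
    k = np.arange(n_bins)
    return tmin * 2.0 ** (k / bins_per_octave)

theta = log_tempo_axis()
print(theta[0], theta[40], theta[80])  # 25.0 50.0 100.0: one doubling per 40 bins
```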
"language_info": { 115 | "codemirror_mode": { 116 | "name": "ipython", 117 | "version": 3 118 | }, 119 | "file_extension": ".py", 120 | "mimetype": "text/x-python", 121 | "name": "python", 122 | "nbconvert_exporter": "python", 123 | "pygments_lexer": "ipython3", 124 | "version": "3.10.6" 125 | } 126 | }, 127 | "nbformat": 4, 128 | "nbformat_minor": 5 129 | } 130 | -------------------------------------------------------------------------------- /scripts/generate_predictions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | 5 | import h5py 6 | import numpy as np 7 | import pandas as pd 8 | import tensorflow as tf 9 | 10 | import steme.audio 11 | import steme.dataset as dt 12 | import steme.loader 13 | import steme.metrics 14 | import steme.utils 15 | from steme.paths import * 16 | 17 | 18 | def generate_predictions(mirdata_dataset, mirdata_dataset_data_folder, 19 | model_name, kmin, kmax, track_file): 20 | 21 | model = tf.keras.models.load_model(os.path.join(MODEL_FOLDER, model_name)) 22 | 23 | predictions_folder = os.path.join(DATA_FOLDER, f"predictions/{model_name}") 24 | print(f"predictions_folder {predictions_folder}") 25 | 26 | predictions_folder = os.path.join( 27 | DATA_FOLDER, f"predictions/{model_name}_fixed_shift") 28 | 29 | if not os.path.isdir(predictions_folder): 30 | os.mkdir(predictions_folder) 31 | 32 | with h5py.File(track_file, "r") as hf: 33 | total = len(hf.keys()) 34 | 35 | for idx, track_id in enumerate(hf.keys()): 36 | print(f"processing {track_id}. {idx}/{total}") 37 | dest_path = os.path.join(predictions_folder, track_id) 38 | predictions = [] 39 | 40 | if os.path.isfile(f"{dest_path}.npz"): 41 | print(f"{dest_path} exists") 42 | continue 43 | 44 | track_data = np.load( 45 | os.path.join( 46 | mirdata_dataset_data_folder, 47 | f"{track_id}.npz")) 48 | T = track_data["T"] 49 | t = track_data["t"] 50 | freqs = track_data["freqs"] 51 | 52 | for i in range(T.shape[1]): 53 | s1, sh1, _, _, _ = dt.get_tempogram_slices( 54 | T, slice_idx=i, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0) 55 | 56 | s1 = s1[np.newaxis, :] 57 | 58 | xhat1, xhat2, y1, y2 = model.predict( 59 | [s1, s1, sh1, sh1], verbose=0) 60 | 61 | predictions.append(y1[0][0]) 62 | 63 | baseline_tempo = np.take(freqs, np.argmax(T, axis=-2)) 64 | 65 | np.savez( 66 | dest_path, 67 | baseline_tempo=baseline_tempo, 68 | prediction=np.array(predictions) 69 | ) 70 | 71 | return 72 | 73 | 74 | def main( 75 | model_name, 76 | dataset_name, 77 | kmin, 78 | kmax, 79 | tmin, 80 | n_bins, 81 | bins_per_octave, 82 | **kwargs): 83 | 84 | theta = dt.variables_non_linear(tmin, bins_per_octave, n_bins) 85 | 86 | response = dt.read_dataset_info(dataset_name) 87 | 88 | main_file = response["main_file"] 89 | train_file = response["train_file"] 90 | validation_file = response["validation_file"] 91 | main_filepath = response["main_filepath"] 92 | train_filepath = response["train_filepath"] 93 | validation_filepath = response["validation_filepath"] 94 | 95 | distribution = response["distribution"] 96 | 97 | # generate predictions 98 | gtzan, _, _ = dt.gtzan_data() 99 | gtzan_data_folder = os.path.join( 100 | PRE_COMPUTED_DATA_FOLDER, 101 | f"gtzan_{tmin}_{n_bins}_{bins_per_octave}_fourier") 102 | # ballroom = loader.custom_dataset_loader(path=DATASET_FOLDER, dataset_name="ballroom", folder="") 103 | # ballroom_data_folder = os.path.join(PRE_COMPUTED_DATA_FOLDER, f"ballroom_{tmin}_{n_bins}_{bins_per_octave}") 104 | 
print(f"ballroom_data_folder: {gtzan_data_folder}") 105 | generate_predictions( 106 | gtzan, 107 | gtzan_data_folder, 108 | model_name, 109 | kmin, 110 | kmax, 111 | validation_filepath) 112 | return 113 | 114 | 115 | if __name__ == "__main__": 116 | import fire 117 | fire.Fire(main) 118 | -------------------------------------------------------------------------------- /scripts/evaluate_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import random 4 | 5 | import h5py 6 | import numpy as np 7 | import pandas as pd 8 | import tensorflow as tf 9 | 10 | import steme.audio 11 | import steme.dataset as dt 12 | import steme.loader 13 | import steme.metrics 14 | import steme.utils 15 | from steme.paths import * 16 | 17 | 18 | def evaluate_model(ballroom, evaluation_file, model, kmin, kmax): 19 | print("Evaluating model") 20 | if not os.path.isfile(evaluation_file): 21 | print(f"Generating {evaluation_file}") 22 | with h5py.File(evaluation_file, "a") as whf: 23 | for track_id in ballroom.track_ids: 24 | fixed_shift_predictions = [] 25 | random_shift_predictions = [] 26 | 27 | theta = dt.variables_non_linear() 28 | x, sr = ballroom.track(track_id).audio 29 | T, t, freqs = audio.tempogram( 30 | x, sr, window_size_seconds=10, t_type="hybrid", theta=theta) 31 | reference_tempo = ballroom.track(track_id).tempo 32 | 33 | for i in range(T.shape[1]): 34 | s1, sh1, s2, sh2, _ = dt.get_tempogram_slices( 35 | T, slice_idx=i, kmin=kmin, kmax=kmax) 36 | 37 | fixed_s1, fixed_sh1, _, _, _ = dt.get_tempogram_slices( 38 | T, slice_idx=i, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0) 39 | 40 | s1 = s1[np.newaxis, :] 41 | fixed_s1 = fixed_s1[np.newaxis, :] 42 | 43 | xhat1, xhat2, y1, y2 = model.predict( 44 | [s1, s1, sh1, sh1], verbose=0) 45 | fixed_xhat1, fixed_xhat2, fixed_y1, fixed_y2 = model.predict( 46 | [fixed_s1, fixed_s1, fixed_sh1, fixed_sh1], verbose=0) 47 | 48 | random_shift_predictions.append(y1[0][0]) 49 | fixed_shift_predictions.append(fixed_y1[0][0]) 50 | 51 | # predicted_tempo_linear = np.array(predictions)*a+b 52 | # predicted_tempo_quadratic = quad(np.array(predictions)) 53 | baseline_tempo = np.take(freqs, np.argmax(T, axis=-2)) 54 | 55 | g = whf.create_group(track_id) 56 | g["reference_tempo"] = reference_tempo 57 | g["fixed_shift_model_output"] = fixed_shift_predictions 58 | # g["random_shift_model_output"] = random_shift_predictions 59 | # g["predicted_tempo_linear"] = predicted_tempo_linear 60 | # g["predicted_tempo_quadratic"] = predicted_tempo_quadratic 61 | g["baseline_tempo"] = baseline_tempo 62 | g["T"] = T.copy() 63 | g["t"] = t 64 | g["freqs"] = freqs 65 | else: 66 | print(f"{evaluation_file} already exists") 67 | 68 | return 69 | 70 | 71 | def main( 72 | model_name, 73 | dataset_name, 74 | dataset_type, 75 | synthetic, 76 | n_predictions, 77 | kmin, 78 | kmax, 79 | tmin, 80 | n_bins, 81 | bins_per_octave, 82 | **kwargs): 83 | 84 | theta = dt.variables_non_linear(tmin, bins_per_octave, n_bins) 85 | 86 | response = dt.read_dataset_info(dataset_name) 87 | 88 | main_file = response["main_file"] 89 | train_file = response["train_file"] 90 | validation_file = response["validation_file"] 91 | main_filepath = response["main_filepath"] 92 | train_filepath = response["train_filepath"] 93 | validation_filepath = response["validation_filepath"] 94 | 95 | tmin = response["tmin"] 96 | tmax = response["tmax"] 97 | distribution = response["distribution"] 98 | 99 | model_path = os.path.join(MODEL_FOLDER, model_name) 
100 | model = tf.keras.models.load_model(model_path) 101 | 102 | # evaluate on ballroom 103 | ballroom = loader.custom_dataset_loader( 104 | path=DATASET_FOLDER, dataset_name="ballroom", folder="") 105 | evaluation_file = os.path.join(DATA_FOLDER, f"{model_name}_evaluation.h5") 106 | main_file = os.path.join(DATA_FOLDER, "ballroom.h5") 107 | evaluate_model(ballroom, evaluation_file, model, kmin, kmax) 108 | 109 | return 110 | 111 | 112 | if __name__ == "__main__": 113 | import fire 114 | fire.Fire(main) 115 | -------------------------------------------------------------------------------- /scripts/train_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import h5py 5 | import keras 6 | import numpy as np 7 | import tensorboard 8 | import tensorflow as tf 9 | 10 | import steme.dataset as dt 11 | import steme.models as models 12 | from steme.paths import * 13 | 14 | LEARNING_RATE = 10e-4 15 | 16 | 17 | def train_model( 18 | model, 19 | model_name, 20 | train_data, 21 | validation_data, 22 | epochs, 23 | early_stopping): 24 | 25 | log_dir = os.path.join(LOG_FOLDER, f"{model_name}") # _{TIMESTAMP}") 26 | tensorboard_callback = tf.keras.callbacks.TensorBoard( 27 | log_dir=log_dir, histogram_freq=1) 28 | model_path = os.path.join(MODEL_FOLDER, f"{model_name}") # _{TIMESTAMP}") 29 | 30 | if not os.path.isdir(model_path): 31 | model_checkpoint = tf.keras.callbacks.ModelCheckpoint( 32 | model_path, 33 | monitor="val_loss", 34 | verbose=0, 35 | save_best_only=True, 36 | save_weights_only=False, 37 | mode="auto", 38 | save_freq="epoch", 39 | initial_value_threshold=None, 40 | ) 41 | 42 | callbacks = [tensorboard_callback, model_checkpoint] 43 | 44 | if early_stopping: 45 | early_stopping_callback = tf.keras.callbacks.EarlyStopping( 46 | monitor="val_loss", patience=10) 47 | callbacks.append(early_stopping_callback) 48 | 49 | model.compile(optimizer=tf.keras.optimizers.Adam(LEARNING_RATE)) 50 | model.fit( 51 | train_data, 52 | validation_data=validation_data, 53 | epochs=epochs, 54 | callbacks=callbacks 55 | ) 56 | else: 57 | print(f"Model exists. 
Loading {model_name}")
        model = tf.keras.models.load_model(model_path)

    return model


def main(model_name, epochs, early_stopping, main_file,
         kmin, kmax, tmin, n_bins, bins_per_octave, sigma_type, model_type,
         w_tempo=10e4, w_recon=1, **kwargs):

    main_filepath = os.path.join(DATA_FOLDER, f"{main_file}.h5")

    response = dt.read_dataset_info(main_file)

    train_filepath = response["train_filepath"]
    validation_filepath = response["validation_filepath"]
    train_setsize = response["train_setsize"]
    validation_setsize = response["validation_setsize"]

    training_tmin = response["tmin"]
    training_tmax = response["tmax"]

    sigma = dt.sigma(training_tmin, training_tmax, bins_per_octave)

    print(f"sigma = {sigma}")

    if model_type == "spice":
        model = models.spice(sigma, w_tempo, w_recon)
    else:
        model = models.convolutional_autoencoder(sigma, w_tempo, w_recon)

    output_signature = (
        # input shapes
        (
            tf.TensorSpec(shape=(128, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(128, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(1), dtype=tf.float32),
            tf.TensorSpec(shape=(1), dtype=tf.float32)
        ),
        # output shapes
        (
            tf.TensorSpec(shape=(128, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(128, 1), dtype=tf.float32),
            tf.TensorSpec(shape=(1), dtype=tf.float32),
            tf.TensorSpec(shape=(1), dtype=tf.float32)
        )
    )

    train_dataset = tf.data.Dataset.from_generator(
        lambda: dt.tempo_data_generator(train_filepath,
                                        set_size=train_setsize,
                                        kmin=kmin,
                                        kmax=kmax
                                        ),
        output_signature=output_signature
    )
    train_dataset = train_dataset.batch(64)

    validation_dataset = tf.data.Dataset.from_generator(
        lambda: dt.tempo_data_generator(validation_filepath,
                                        set_size=validation_setsize,  # <- synthetic data
                                        kmin=kmin,
                                        kmax=kmax
                                        ),
        output_signature=output_signature
    )
    validation_dataset = validation_dataset.batch(64)

    model = train_model(
        model=model,
        model_name=model_name,
        train_data=train_dataset,
        validation_data=validation_dataset,
        epochs=epochs,
        early_stopping=early_stopping)
    return


if __name__ == "__main__":
    import fire
    fire.Fire(main)
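`calibration.py` below maps raw model outputs to BPM by fitting a line through (model output, known BPM) pairs via `utils.get_slope`, whose implementation is not shown in this listing. A minimal stand-in under that assumption, using an ordinary least-squares fit:

```python
import numpy as np

def fit_linear_calibration(model_outputs, reference_bpms):
    """Least-squares fit of bpm ~ a * output + b; returns (a, b)."""
    a, b = np.polyfit(model_outputs, reference_bpms, deg=1)
    return a, b

# illustrative numbers: raw outputs grow roughly linearly with tempo
outs = np.array([0.10, 0.21, 0.39, 0.52])
bpms = np.array([60.0, 90.0, 150.0, 180.0])
a, b = fit_linear_calibration(outs, bpms)
print(np.round(a * outs + b))   # approximately [ 60.  92. 145. 183.]
```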
--------------------------------------------------------------------------------
/steme/calibration.py:
--------------------------------------------------------------------------------
import random

import h5py
import numpy as np
import tensorflow as tf

import steme.audio as audio
import steme.dataset as dataset
import steme.utils as utils


def load_calibration_tracks(filename):
    with h5py.File(filename, "r") as hf:
        tracks = [key for key in hf.keys()]
        bpm_dict = {}

        for key, value in hf.items():
            reference_tempo = value["reference_tempo"][()]
            bpm_dict[reference_tempo] = {}

            bpm_dict[reference_tempo]["T"] = value["T"][:]
            bpm_dict[reference_tempo]["t"] = value["t"][:]
            bpm_dict[reference_tempo]["freqs"] = value["freqs"][:]
            bpm_dict[reference_tempo]["audio"] = value["audio"][:]
            bpm_dict[reference_tempo]["reference_tempo"] = value["reference_tempo"][:]

    return bpm_dict


def choose_calibration_candidates(calibration_range):
    if not isinstance(calibration_range, list):
        raise TypeError(
            "calibration_range should be a list or an array, "
            f"but it is {type(calibration_range)}")
    gtzan, tracks, tempi = dataset.gtzan_data()

    gtzan_info = {}
    for i in tracks:
        gtzan_info[i] = gtzan.track(i).tempo

    gtzan_info = {
        k: v for k,
        v in sorted(
            gtzan_info.items(),
            key=lambda item: item[1])}

    calibration_tracks = {}
    for i in calibration_range:
        candidates = {
            k: v for k,
            v in gtzan_info.items() if np.abs(
                v - i) <= 1}

        random_candidate = random.choice(list(candidates.keys()))
        calibration_tracks.update(
            {random_candidate: candidates[random_candidate]})

    return calibration_tracks


def synthetic_tracks(theta, step):
    # tracks = theta[(theta > 30) & (theta < 300)][::step].copy()
    tracks = np.arange(30, 340, 10)
    bpm_dict = {}

    for i in tracks:
        bpm_dict[i] = {}
        bpm_dict[i]["sr"] = 22050
        bpm_dict[i]["audio"] = audio.click_track(bpm=i, sr=22050)

    return bpm_dict


def calibration_synthetic(
        model_name,
        model_path,
        tmin,
        n_bins,
        bins_per_octave,
        n_predictions,
        distribution,
        t_type,
        kmin=11,
        kmax=19,
        step=2):
    # kmin, kmax and step used to be free, undefined names; they are explicit
    # parameters now, with defaults matching the values used elsewhere
    theta = dataset.variables_non_linear(
        tmin, n_bins=n_bins, bins_per_octave=bins_per_octave)
    bpm_dict = synthetic_tracks(theta, step)

    model = tf.keras.models.load_model(model_path)

    for bpm, val in bpm_dict.items():
        T, t, bpms = audio.tempogram(
            val["audio"], val["sr"], window_size_seconds=10,
            t_type=t_type, theta=theta)

        bpm_dict[bpm] = {}
        bpm_dict[bpm]["T"] = T
        bpm_dict[bpm]["t"] = t
        bpm_dict[bpm]["freqs"] = bpms

    bpm_dict, a, b = _calibrate(
        bpm_dict, model, kmin, kmax, n_predictions
    )
    bpm_preds = [
        v["predictions"] for k, v in bpm_dict.items()
    ]
    bpm_tracks = np.round(list(bpm_dict.keys()), 2)

    return bpm_preds


def _calibrate(bpm_dict, model, kmin, kmax, n_predictions=100):
    print("Calibrating model")
    model_output = np.zeros(len(bpm_dict.keys()))
    j = 0
    for bpm in bpm_dict.keys():
        T = bpm_dict[bpm]["T"]

        preds = np.zeros(n_predictions)
        step = T.shape[1] // n_predictions

        for i in range(n_predictions):
            slice_idx = i * step
            s1, sh1, s2, sh2, _ = dataset.get_tempogram_slices(
                T=T, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0,
                slice_idx=slice_idx)
            s1 = s1[np.newaxis, :]
            # s2 = s2[np.newaxis, :]

            # xhat1, xhat2, y1, y2 = model.predict([s1, s2, sh1, sh2], verbose=0)
            xhat1, xhat2, y1, y2 = model.predict([s1, s1, sh1, sh1], verbose=0)
            preds[i] = y1[0][0]
            # preds[i] = (y1[0][0]+y2[0][0])/2
            # preds[i] = y2[0][0]

        bpm_dict[bpm]["slice"] = s1[0, :, 0]
        bpm_dict[bpm]["shift"] = sh1
        bpm_dict[bpm]["estimation"] = xhat1[0, :, 0]
        bpm_dict[bpm]["predictions"] = np.array(preds)
        # print(f"Predictions for {bpm} = {np.array(preds)}")

        model_output[j] = np.median(np.array(preds))
        j += 1

    # quad = np.poly1d(np.polyfit(model_output, list(bpm_dict.keys()), 2))
    a, b = utils.get_slope(model_output, list(bpm_dict.keys()))

    return bpm_dict, a, b  # , quad

--------------------------------------------------------------------------------
/steme/loader.py:
--------------------------------------------------------------------------------
import os
import types
3 | import librosa 4 | import numpy as np 5 | from mirdata import annotations 6 | from mirdata.core import cached_property 7 | from typing import BinaryIO, Optional, TextIO, Tuple 8 | 9 | MAX_STR_LEN = 100 10 | 11 | 12 | def custom_loader(path, dataset_name, folder='datasets'): 13 | print(f'Loading {dataset_name} through custom loader') 14 | datasetdir = os.path.join(path, folder, dataset_name) 15 | dataset = Dataset( 16 | data_home=os.path.join( 17 | datasetdir, 'audio'), annotations_home=os.path.join( 18 | datasetdir, 'annotations')) 19 | tracks = dataset.load_tracks() 20 | return tracks 21 | 22 | 23 | def custom_dataset_loader(path, dataset_name, folder='datasets'): 24 | print(f'Loading {dataset_name} through custom loader') 25 | datasetdir = os.path.join(path, folder, dataset_name) 26 | dataset = Dataset( 27 | data_home=os.path.join( 28 | datasetdir, 'audio'), annotations_home=os.path.join( 29 | datasetdir, 'annotations')) 30 | return dataset 31 | 32 | 33 | class Track: 34 | def __init__(self, track_id, dataset_name, index): 35 | self.track_id = track_id 36 | self._dataset_name = dataset_name 37 | self._track_paths = index[track_id] 38 | 39 | self.audio_path = self.get_path("audio") 40 | self.beats_path = self.get_path("beats") 41 | self.tempo_path = self.get_path("tempo") 42 | 43 | @cached_property 44 | def beats(self) -> Optional[annotations.BeatData]: 45 | return load_beats(self.beats_path) 46 | 47 | @cached_property 48 | def tempo(self) -> Optional[float]: 49 | return load_tempo(self.tempo_path) 50 | 51 | @cached_property 52 | def audio(self) -> Optional[Tuple[np.ndarray, float]]: 53 | return load_audio(self.audio_path) 54 | 55 | def get_path(self, key): 56 | if self._track_paths[key] is None: 57 | return None 58 | else: 59 | return self._track_paths[key] 60 | 61 | def __repr__(self): 62 | properties = [v for v in dir(self.__class__) if not v.startswith("_")] 63 | attributes = [v for v in dir(self) if not v.startswith( 64 | "_") and v not in properties] 65 | 66 | repr_str = "Track(\n" 67 | 68 | for attr in attributes: 69 | val = getattr(self, attr) 70 | if isinstance(val, str): 71 | if len(val) > MAX_STR_LEN: 72 | val = "...{}".format(val[-MAX_STR_LEN:]) 73 | val = '"{}"'.format(val) 74 | repr_str += " {}={},\n".format(attr, val) 75 | 76 | for prop in properties: 77 | val = getattr(self.__class__, prop) 78 | if isinstance(val, types.FunctionType): 79 | continue 80 | 81 | if val.__doc__ is None: 82 | doc = "" 83 | else: 84 | doc = val.__doc__ 85 | 86 | val_type_str = doc.split(":")[0] 87 | repr_str += " {}: {},\n".format(prop, val_type_str) 88 | 89 | repr_str += ")" 90 | return repr_str 91 | 92 | 93 | def load_beats(fhandle: TextIO) -> annotations.BeatData: 94 | beats = np.loadtxt(fhandle) 95 | times = beats[:, 0] 96 | positions = beats[:, 1] 97 | beat_data = annotations.BeatData( 98 | times=times, 99 | time_unit="s", 100 | positions=positions, 101 | position_unit="bar_index") 102 | return beat_data 103 | 104 | 105 | def load_tempo(fhandle: TextIO) -> float: 106 | tempo = np.loadtxt(fhandle) 107 | return float(tempo) 108 | 109 | 110 | def load_audio(fhandle: BinaryIO) -> Tuple[np.ndarray, float]: 111 | audio, sr = librosa.load(fhandle, sr=44100, mono=True) 112 | return audio, sr 113 | 114 | 115 | def brid_indexing_function(filename): 116 | return filename[1:5] 117 | 118 | 119 | def candombe_indexing_function(filename): 120 | return os.path.splitext(filename)[0] 121 | 122 | 123 | def indexing_function(filename): 124 | return os.path.splitext(filename)[0] 125 | 126 | 127 | class 
Dataset: 128 | def __init__( 129 | self, 130 | data_home='./datasets/brid/audio', 131 | annotations_home='./datasets/brid/', 132 | dataset_name='brid', 133 | indexing_function=indexing_function): 134 | self.dataset_name = dataset_name 135 | self.data_home = data_home 136 | self.annotations_home = annotations_home 137 | 138 | self._index = {} 139 | beats_home = os.path.join(self.annotations_home, 'beats') 140 | tempo_home = os.path.join(self.annotations_home, 'tempo') 141 | for root, dirs, files in os.walk(self.data_home): 142 | for name in files: 143 | if not name == '.DS_Store': 144 | aux_dict = { 145 | 'audio': os.path.join( 146 | root, name), 'beats': os.path.join( 147 | beats_home, name.replace( 148 | '.wav', '.beats')), 'tempo': os.path.join( 149 | tempo_home, name.replace( 150 | '.wav', '.bpm'))} 151 | file_code = indexing_function(name) 152 | self._index[file_code] = aux_dict 153 | 154 | def track(self, track_id): 155 | return Track(track_id, self.dataset_name, self._index) 156 | 157 | def load_tracks(self): 158 | return {track_id: self.track(track_id) for track_id in self.track_ids} 159 | 160 | @property 161 | def track_ids(self): 162 | return list(self._index.keys()) 163 | 164 | @property 165 | def name(self): 166 | return self.dataset_name 167 | -------------------------------------------------------------------------------- /scripts/calibrate_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | 5 | import h5py 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | import steme.audio as audio 10 | import steme.dataset as dataset 11 | import steme.calibration as calibration 12 | 13 | def default_variables(): 14 | return { 15 | "tmin": 25, 16 | "n_bins": 190, 17 | "bins_per_octave": 40, 18 | "kmin": 11, 19 | "kmax": 19 20 | } 21 | def read_dataset_info(main_file): 22 | dataset_metadata = os.path.join("/home/gigibs/Documents/steme/data", f"{main_file}_metadata.h5") 23 | print(f"Reading metadata file {dataset_metadata}") 24 | response = {} 25 | 26 | with h5py.File(dataset_metadata, "r") as hf: 27 | response["main_file"] = hf.get("main_file")[()].decode("UTF-8") 28 | response["validation_file"] = hf.get("validation_file")[()].decode("UTF-8") 29 | response["train_file"] = hf.get("train_file")[()].decode("UTF-8") 30 | response["main_filepath"] = hf.get("main_filepath")[()].decode("UTF-8") 31 | response["validation_filepath"] = hf.get("validation_filepath")[()].decode("UTF-8") 32 | response["train_filepath"] = hf.get("train_filepath")[()].decode("UTF-8") 33 | response["distribution"] = hf.get("distribution")[:] 34 | response["validation_setsize"] = hf.get("validation_setsize")[()] 35 | response["train_setsize"] = hf.get("train_setsize")[()] 36 | response["tmin"] = hf.get("tmin")[()] 37 | response["tmax"] = hf.get("tmax")[()] 38 | 39 | return response 40 | 41 | # def get_center_bins(step, offset): 42 | # # defining the center bins for the random tracks in calibration 43 | # # step = 5 44 | # # offset = 5 45 | # left = theta[(theta > 30) & (theta < 350)][::step] 46 | # center = theta[(theta > 30) & (theta < 350)][offset::step] 47 | # right = theta[(theta > 30) & (theta < 350)][offset::step] 48 | 49 | # bins = [] 50 | # for i, j, k in zip(left, center, right): 51 | # print(f"boundaries for {np.round(j,2)}: [{np.round(np.sqrt(i*j),2)}, {np.round(np.sqrt(j*k))}]") 52 | # bins.append(i) 53 | # bins.append(j) 54 | # return bins 55 | 56 | def create_center_dict(n_predictions, tracks_per_bin): 57 | # 
n_predictions = 2
    # tracks_per_bin = 50

    center_dict = {}
    for idx, val in enumerate(center):
        left_boundary = np.sqrt(left[idx] * center[idx])
        right_boundary = np.sqrt(center[idx] * right[idx])

        center_dict[val] = np.random.uniform(
            left_boundary, right_boundary, size=tracks_per_bin)
    return center_dict


def calibration_results(dists, t_types, variation, n_predictions):
    results_dict = {}
    for dist_name in dists:
        results_dict[dist_name] = {}
        for t_type in t_types:
            print(dist_name, t_type)
            dataset_name = f"{dist_name}_{t_type}"

            response = read_dataset_info(dataset_name)
            distribution = response["distribution"]

            results_dict[dist_name][t_type] = {}

            model_name = f"{dataset_name}_15_default"
            model_path = f"../models/{variation}/{model_name}"

            model = tf.keras.models.load_model(model_path)

            for idx, val in enumerate(center):
                results_dict[dist_name][t_type][val] = {}
                sr = 22050
                preds = np.zeros(n_predictions * tracks_per_bin)

                j = 0

                for bpm in center_dict[val]:
                    x = audio.click_track(bpm=bpm, sr=sr)
                    T, t, bpms = audio.tempogram(
                        x, sr, window_size_seconds=10,
                        t_type=t_type, theta=theta)

                    step = T.shape[1] // n_predictions

                    for i in range(n_predictions):
                        slice_idx = i * step
                        s1, sh1, s2, sh2, _ = dataset.get_tempogram_slices(
                            T=T, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0,
                            slice_idx=slice_idx
                        )

                        s1 = s1[np.newaxis, :]

                        xhat1, xhat2, y1, y2 = model.predict(
                            [s1, s1, sh1, sh1], verbose=0)
                        preds[j] = y1[0][0]
                        j += 1
                results_dict[dist_name][t_type][val]["predictions"] = np.array(preds)
                # results_dict[dist_name][t_type][val]["tracks"] = bpm_tracks

            del model
    return results_dict


if __name__ == "__main__":
    variables = default_variables()
    tmin = variables["tmin"]
    n_bins = variables["n_bins"]
    bins_per_octave = variables["bins_per_octave"]
    kmin, kmax = variables["kmin"], variables["kmax"]
    theta = dataset.variables_non_linear(
        tmin, n_bins=n_bins, bins_per_octave=bins_per_octave)

    step = 5
    offset = 5
    left = theta[(theta > 30) & (theta < 350)][::step]
    center = theta[(theta > 30) & (theta < 350)][offset::step]
    right = theta[(theta > 30) & (theta < 350)][offset::step]
    # bins = get_center_bins(5, 5)

    dists = [
        "gtzan_augmented_log_25_190_40",
        "gtzan_augmented_log_cropped_25_190_40",
        "log_uniform_25_190_40",
        "synthetic_lognorm_0.7_30_50_1000_25_190_40",
        "synthetic_lognorm_0.7_70_50_1000_25_190_40",
        "synthetic_lognorm_0.7_120_50_1000_25_190_40",
        "gtzan_25_190_40"
    ]

    t_types = ["fourier", "autocorrelation", "hybrid"]

    variations = ["early_stopping", "wo_early_stopping"]
    n_predictions = 1
    tracks_per_bin = 1

    center_dict = create_center_dict(n_predictions, tracks_per_bin)
    with open('center_dict_aug_full.pkl', 'wb') as f:
        pickle.dump(center_dict, f)

    for v in variations:
        results_dict = calibration_results(dists, t_types, v, n_predictions)
        with open(f'results_dict_aug_{v}.pkl', 'wb') as f:
            pickle.dump(results_dict, f)
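The calibration bins above use geometric midpoints, which are the natural bin edges on a logarithmic tempo axis: the boundary between neighboring centers c1 and c2 is sqrt(c1 * c2), i.e. the halfway point in log2 space. A small sketch of that rule:

```python
import numpy as np

def geometric_edges(centers):
    """Bin edges halfway (in log2 space) between consecutive centers."""
    c = np.asarray(centers, dtype=float)
    return np.sqrt(c[:-1] * c[1:])

print(geometric_edges([60.0, 120.0, 240.0]))   # [ 84.85... 169.70...]
```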
--------------------------------------------------------------------------------
/scripts/test_model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os
import random

import h5py
import numpy as np
import pandas as pd
import tensorflow as tf

import steme.audio as audio
import steme.dataset as dt
import steme.loader as loader
import steme.metrics as metrics
import steme.utils as utils
from steme.paths import *


def calibrate_model(
        model,
        kmin,
        kmax,
        mirdata_dataset,
        n_predictions=100,
        shift=False):
    tempi = [mirdata_dataset.track(i).tempo for i in mirdata_dataset.track_ids]
    track = [i for i in mirdata_dataset.track_ids]
    tracks = []
    bpm_tracks = []

    if mirdata_dataset.name == "gtzan_genre":
        tempi.remove(None)
        track.remove("reggae.00086")

    tempi = np.array(tempi)
    track = np.array(track)

    ordered_indexes = np.argsort(tempi)
    tempi = tempi[ordered_indexes]
    track = track[ordered_indexes]

    track_dict = {i: j for i, j in zip(tempi, track)}

    intervals = np.arange(30, 350, 10)
    bpm_dict = {}

    theta = dt.variables_non_linear()

    # TODO: instead of calculating tempograms every time, we could simply load
    # the file from the dataset name
    print("Sampling tracks for calibration")
    for idx in range(len(intervals) - 1):
        try:
            interval = tempi[(tempi > intervals[idx]) &
                             (tempi < intervals[idx + 1])]
            bpm = random.choice(interval)
            bpm_dict[bpm] = {}
            x, sr = mirdata_dataset.track(track_dict[bpm]).audio
            T, t, bpms = audio.tempogram(
                x, sr, window_size_seconds=10, t_type="hybrid", theta=theta)

            bpm_dict[bpm]["audio"] = x
            bpm_dict[bpm]["T"] = T
            bpm_dict[bpm]["t"] = t
            bpm_dict[bpm]["freqs"] = bpms

            tracks.append(track_dict[bpm])
            bpm_tracks.append(bpm)
            print(f"[{intervals[idx]}, {intervals[idx+1]}] - {bpm}")
        except IndexError:
            print(
                f"no index for interval [{intervals[idx]}, {intervals[idx+1]}]")

    print("Calibrating model")
    model_output = np.zeros(len(bpm_dict.keys()))
    j = 0
    for bpm in bpm_dict.keys():
        T = bpm_dict[bpm]["T"]

        preds = np.zeros(n_predictions)

        for i in range(n_predictions):
            s1, sh1, s2, sh2, _ = dt.get_tempogram_slices(
                T, kmin=kmin, kmax=kmax)
            s1 = s1[np.newaxis, :]
            s2 = s2[np.newaxis, :]

            if not shift:
                s2 = s1
                sh2 = sh1

            xhat1, xhat2, y1, y2 = model.predict([s1, s2, sh1, sh2], verbose=0)
            preds[i] = y1[0][0]

        model_output[j] = np.median(np.array(preds))
        j += 1

    quad = np.poly1d(np.polyfit(model_output, list(bpm_dict.keys()), 2))
    a, b = utils.get_slope(model_output, list(bpm_dict.keys()))

    return a, b, quad
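# --- illustrative aside (not part of the original file) --------------------
# evaluate_model() below consumes the two calibration curves returned above:
# the linear map a * y + b and the quadratic np.poly1d fit. This hypothetical
# helper just makes that mapping explicit; it relies on the module-level
# numpy import.
def _apply_calibrations(raw_outputs, a, b, quad):
    """Map raw model outputs to BPM with both calibration curves."""
    raw = np.asarray(raw_outputs)
    return raw * a + b, quad(raw)
# ----------------------------------------------------------------------------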
def evaluate_model(ballroom, evaluation_file, model, a, b, quad):
    with h5py.File(evaluation_file, "a") as whf:
        for track_id in ballroom.track_ids:
            predictions = []

            theta = dt.variables_non_linear()
            x, sr = ballroom.track(track_id).audio
            T, t, freqs = audio.tempogram(
                x, sr, window_size_seconds=10, t_type="hybrid", theta=theta)
            reference_tempo = ballroom.track(track_id).tempo

            for i in range(T.shape[1]):
                s1, sh1, s2, sh2, _ = dt.get_tempogram_slices(
                    T, slice_idx=i, kmin=11, kmax=19)
                # normalize to the range [0, 1]
                s1 /= s1.max()
                s2 /= s2.max()

                s1 = s1[np.newaxis, :]

                xhat1, xhat2, y1, y2 = model.predict(
                    [s1, s1, sh1, sh1], verbose=0)
                predictions.append(y1[0][0])

            predicted_tempo_linear = np.array(predictions) * a + b
            predicted_tempo_quadratic = quad(np.array(predictions))
            baseline_tempo = np.take(freqs, np.argmax(T, axis=-2))

            g = whf.create_group(track_id)
            g["reference_tempo"] = reference_tempo
            g["predicted_tempo_linear"] = predicted_tempo_linear
            g["predicted_tempo_quadratic"] = predicted_tempo_quadratic
            g["baseline_tempo"] = baseline_tempo
            g["T"] = T.copy()
            g["t"] = t
            g["freqs"] = freqs

    return


def get_metrics(evaluation_file):
    baseline_metrics = {}
    predicted_metrics = {}

    with h5py.File(evaluation_file, "r") as hf:
        for key, value in hf.items():
            baseline_tempo = value["baseline_tempo"][:]
            reference_tempo = value["reference_tempo"][()]
            predicted_tempo_linear = value["predicted_tempo_linear"][:]
            T = value["T"][:]
            # t = value["t"][:]
            # freqs = value["freqs"][:]

            baseline_acc1 = metrics.acc1(
                reference_tempo, np.median(baseline_tempo))
            baseline_acc2 = metrics.acc2(
                reference_tempo, np.median(baseline_tempo))

            predicted_acc1 = metrics.acc1(
                reference_tempo, np.median(predicted_tempo_linear))
            predicted_acc2 = metrics.acc2(
                reference_tempo, np.median(predicted_tempo_linear))

            baseline_metrics[key] = {}
            baseline_metrics[key]["acc1"] = baseline_acc1
            baseline_metrics[key]["acc2"] = baseline_acc2

            predicted_metrics[key] = {}
            predicted_metrics[key]["acc1"] = predicted_acc1
            predicted_metrics[key]["acc2"] = predicted_acc2

    baseline_df = pd.DataFrame.from_dict(baseline_metrics, orient="index")
    predicted_df = pd.DataFrame.from_dict(predicted_metrics, orient="index")

    df = baseline_df.merge(
        predicted_df,
        left_index=True, right_index=True,
        suffixes=("_baseline", "_predicted")
    )

    df["acc1_baseline"] = df["acc1_baseline"].astype(float)
    df["acc1_predicted"] = df["acc1_predicted"].astype(float)
    df["acc2_baseline"] = df["acc2_baseline"].astype(float)
    df["acc2_predicted"] = df["acc2_predicted"].astype(float)

    print("saving .csv file")
    df.to_csv(os.path.join(DATA_FOLDER, "ballroom_metrics.csv"), index=False)

    return


def main(model_path, kmin, kmax):
    model = tf.keras.models.load_model(model_path)
    ballroom = loader.custom_dataset_loader(
        path=DATASET_FOLDER, dataset_name="ballroom", folder="")

    # TODO:
    # 0. check if file with tempograms exists so we don't need to calculate it
    # 1. 
calibrate model 208 | a, b, quad = calibrate_model(model, kmin, kmax, ballroom) 209 | 210 | # evaluation_file = os.path.join(DATA_FOLDER, "evaluation", "ballroom_metrics.h5") 211 | evaluation_file = os.path.join(DATA_FOLDER, "ballroom_evaluation.h5") 212 | main_file = os.path.join(DATA_FOLDER, "ballroom.h5") 213 | 214 | evaluate_model(ballroom, evaluation_file, model, a, b, quad) 215 | 216 | get_metrics(evaluation_file) 217 | 218 | return 219 | 220 | 221 | if __name__ == "__main__": 222 | import fire 223 | fire.Fire(main) 224 | -------------------------------------------------------------------------------- /steme/audio.py: -------------------------------------------------------------------------------- 1 | from scipy.interpolate import interp1d 2 | import librosa 3 | import numpy as np 4 | 5 | def spectral_flux( 6 | x, 7 | sr, 8 | n_fft=1024, 9 | hop_length=256, 10 | gamma=100.0, 11 | avg_window=10, 12 | norm=True): 13 | """ 14 | Compute the spectral flux of a signal and apply logarithmic compression. 15 | 16 | Parameters 17 | --------- 18 | x : np.ndarray 19 | audio signal 20 | sr : int 21 | sampling rate 22 | n_fft : int, optional 23 | fft window size 24 | hop_length : int, optional 25 | step between fft windows 26 | gamma : float, optional 27 | logarithmic compression factor 28 | avg_window : int, optional 29 | window size (in samples) to compute local average 30 | norm : bool, optional 31 | boolean flag to normalize or not the novelty function 32 | Return 33 | ------ 34 | novelty : np.ndarray 35 | the novelty function 36 | sr_novelty : float 37 | the sampling rate of the novelty function. defined as (sampling 38 | rate)/hop length 39 | """ 40 | X = librosa.stft( 41 | x, 42 | n_fft=n_fft, 43 | hop_length=hop_length, 44 | win_length=n_fft, 45 | window="hann") 46 | sr_novelty = sr / hop_length 47 | 48 | Y = np.log(1 + gamma * np.abs(X)) 49 | 50 | Y_diff = np.diff(Y) 51 | Y_diff[Y_diff < 0] = 0 52 | 53 | novelty = np.sum(Y_diff, axis=0) 54 | novelty = np.concatenate((novelty, np.array([0.0]))) 55 | 56 | # subtract local avg 57 | if avg_window > 0: 58 | L = len(novelty) 59 | local_avg = np.zeros(L) 60 | for m in range(L): 61 | init = max(m - avg_window, 0) 62 | end = min(m + avg_window + 1, L) 63 | local_avg[m] = np.sum(novelty[init:end]) * \ 64 | (1 / (1 + 2 * avg_window)) 65 | novelty = novelty - local_avg 66 | novelty[novelty < 0] = 0.0 67 | 68 | if norm: 69 | max_value = max(novelty) 70 | if max_value > 0: 71 | novelty /= max_value 72 | 73 | return novelty, sr_novelty 74 | 75 | 76 | def fourier_tempogram(novelty, sr_novelty, window_size, hop_size, theta): 77 | """ 78 | Compute Fourier tempogram 79 | 80 | Parameters 81 | ---------- 82 | novelty : np.ndarray 83 | novelty function 84 | sr_novelty : np.float 85 | sampling rate of the novelty function 86 | window_size : int 87 | window size in frames. 
1000 corresponds to 10 s in a signal sampled at 100 Hz
    hop_size : int
        hop size
    theta : np.ndarray
        range of BPM to cover
    """
    window = np.hanning(window_size)
    pad_size = int(window_size // 2)

    L = novelty.shape[0] + 2 * pad_size

    novelty_pad = np.concatenate(
        (np.zeros(pad_size), novelty, np.zeros(pad_size)))
    t_pad = np.arange(L)

    M = int(np.floor((L - window_size) / hop_size)) + 1
    K = len(theta)
    # builtin complex dtype (np.complex_ was removed in NumPy 2.0)
    X = np.zeros((K, M), dtype=complex)

    for k in range(K):
        omega = (theta[k] / 60) / sr_novelty

        exponential = np.exp(-2 * np.pi * 1j * omega * t_pad)
        x_exp = novelty_pad * exponential

        for n in range(M):
            t_0 = n * hop_size
            t_1 = t_0 + window_size
            X[k, n] = np.sum(window * x_exp[t_0:t_1])

    times = np.arange(M) * hop_size / sr_novelty
    tempi = theta

    return np.abs(X), times, tempi


def tempogram(x, sr, window_size_seconds, t_type, theta):
    """
    x : np.ndarray
        signal
    sr : float64
        sampling rate
    window_size_seconds : int
        size in seconds of the tempogram window, e.g. 5
    t_type : string
        tempogram type. accepted values are "fourier", "autocorrelation",
        "hybrid"
    theta : np.ndarray
        tempi interval (BPM), e.g. np.arange(30, 300, 1), i.e. from 30 to
        300 BPM, 1 at a time
    """

    if not isinstance(theta, np.ndarray):
        raise ValueError(
            f"theta type incorrect. it should be np.ndarray, but is {type(theta)}")

    novelty, sr_novelty = spectral_flux(x, sr, n_fft=2048, hop_length=512)

    window_size_frames = int(window_size_seconds * sr_novelty)
    hop_size = 1

    if t_type == "fourier":
        T, t, bpm = fourier_tempogram(
            novelty,
            sr_novelty,
            window_size=window_size_frames,
            hop_size=hop_size,
            theta=theta
        )
    elif t_type == "autocorrelation":
        T, t, bpm, _, _ = autocorrelation_tempogram(
            novelty, sr_novelty, window_size=window_size_frames,
            hop_size=hop_size, theta=theta)
    elif t_type == "hybrid":
        ft, t, bpm = fourier_tempogram(
            novelty, sr_novelty, window_size=window_size_frames,
            hop_size=hop_size, theta=theta)
        at, ta, freqsa, _, _ = autocorrelation_tempogram(
            novelty, sr_novelty, window_size=window_size_frames,
            hop_size=hop_size, theta=theta)

        # hybrid tempogram: element-wise product of both representations
        T = ft * at
    else:
        raise ValueError("tempogram_type incorrect. 
accepted values are \
169 |             ['fourier', 'autocorrelation', 'hybrid']")
170 | 
171 |     return T, t, bpm
172 | 
173 | 
174 | def click_track(bpm, sr=22050, duration=60):
175 |     """
176 |     Generates a click track with the desired BPM and duration
177 | 
178 |     Parameters
179 |     ----------
180 |     bpm : int
181 |         desired tempo
182 |     sr : int, optional
183 |         sampling rate
184 |     duration : int, optional
185 |         duration in seconds. default is 60
186 |     """
187 | 
188 |     step = 60 / bpm
189 | 
190 |     times = np.arange(0, duration, step)
191 | 
192 |     return librosa.clicks(times=times, sr=sr)
193 | 
194 | 
195 | def local_autocorrelation(x, sr, N, H):
196 |     """Compute local autocorrelation [FMP, Section 6.2.3]
197 | 
198 |     Notebook: C6/C6S2_TempogramAutocorrelation.ipynb
199 | 
200 |     Args:
201 |         x (np.ndarray): Input signal
202 |         sr (scalar): Sampling rate
203 |         N (int): Window length
204 |         H (int): Hop size
205 | 
206 |     Returns:
207 |         A (np.ndarray): Time-lag representation
208 |         times (np.ndarray): Time axis (seconds)
209 |         lags (np.ndarray): Lag axis
210 |     """
211 |     L_left = round(N / 2)
212 |     L_right = L_left
213 |     x_pad = np.concatenate((np.zeros(L_left), x, np.zeros(L_right)))
214 |     L_pad = len(x_pad)
215 |     M = int(np.floor((L_pad - N) / H)) + 1
216 |     A = np.zeros((N, M))
217 |     win = np.ones(N)
218 |     for n in range(M):
219 |         t_0 = n * H
220 |         t_1 = t_0 + N
221 |         x_local = win * x_pad[t_0:t_1]
222 |         r_xx = np.correlate(x_local, x_local, mode='full')
223 |         r_xx = r_xx[N - 1:]
224 |         A[:, n] = r_xx
225 |     sr_A = sr / H
226 |     times = np.arange(A.shape[1]) / sr_A
227 |     lags = np.arange(N) / sr
228 |     return A, times, lags
229 | 
230 | 
231 | def autocorrelation_tempogram(
232 |         novelty, sr_novelty, window_size, hop_size, theta):
233 |     """
234 |     Compute autocorrelation-based tempogram
235 | 
236 |     Parameters
237 |     ----------
238 |     novelty : np.ndarray
239 |         input novelty function
240 |     sr_novelty : float64
241 |         sampling rate
242 |     window_size : int
243 |         window length in frames
244 |     hop_size : int
245 |         hop size
246 |     theta : np.ndarray
247 |         array with BPM values onto which the autocorrelation is interpolated
248 | 
249 |     Return
250 |     ------
251 |     tempogram : np.ndarray
252 |         autocorrelation tempogram
253 |     times : np.ndarray
254 |         time axis (seconds)
255 |     bpms : np.ndarray
256 |         tempo axis (BPM)
257 |     A_cut : np.ndarray
258 |         time-lag representation A_cut (cut according to theta)
259 |     lags_cut : np.ndarray
260 |         Lag axis lags_cut
261 |     """
262 |     tempo_min = theta[0]
263 |     tempo_max = theta[-1]
264 |     lag_min = int(np.ceil(sr_novelty * 60 / tempo_max))
265 |     lag_max = int(np.ceil(sr_novelty * 60 / tempo_min))
266 | 
267 |     A, times, lags = local_autocorrelation(novelty, sr_novelty,
268 |                                            window_size, hop_size)
269 |     # getting the min/max lag interval to use in the interpolation
270 |     A_cut = A[lag_min:lag_max + 1, :]
271 | 
272 |     # "cut" the frequencies out of the max/min
273 |     lags_cut = lags[lag_min:lag_max + 1]
274 | 
275 |     # translate to BPM
276 |     bpms_cut = 60 / lags_cut
277 |     bpms = theta
278 | 
279 |     # interpolate
280 |     axis_interpolation = interp1d(
281 |         bpms_cut,
282 |         A_cut,
283 |         kind='linear',
284 |         axis=0,
285 |         fill_value='extrapolate')
286 | 
287 |     tempogram = axis_interpolation(bpms)
288 |     return tempogram, times, bpms, A_cut, lags_cut
289 | 
--------------------------------------------------------------------------------
/notebooks/distributions.ipynb:
--------------------------------------------------------------------------------
  1 | {
  2 |  "cells": [
  3 |   {
  4 |    "cell_type": "code",
  5 |    "execution_count": null,
  6 |    "id": "e2297721",
  7 | 
"metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "import matplotlib.ticker as ticker\n", 13 | "\n", 14 | "from scipy.stats import lognorm, uniform\n", 15 | "\n", 16 | "import steme.dataset as dataset\n", 17 | "import steme.loader as loader" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "41cd8ee0", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import matplotlib\n", 28 | "matplotlib.rc('xtick', labelsize=18) \n", 29 | "matplotlib.rc('ytick', labelsize=18) \n", 30 | "matplotlib.rc('axes', labelsize=18)\n", 31 | "matplotlib.rc('legend', fontsize=16)\n", 32 | "matplotlib.rc('figure', titlesize=16)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "eba0eab5", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "def gtzan_data():\n", 43 | " import mirdata\n", 44 | " gtzan = mirdata.initialize(\"gtzan_genre\",\n", 45 | " data_home=\"../../datasets/gtzan_genre\",\n", 46 | " version=\"default\")\n", 47 | " tracks = gtzan.track_ids\n", 48 | " tracks.remove(\"reggae.00086\")\n", 49 | " tempi = [gtzan.track(track_id).tempo for track_id in tracks]\n", 50 | "\n", 51 | " return gtzan, tracks, tempi\n", 52 | "\n", 53 | "dist_low = lognorm.rvs(0.25, loc=30, scale=50, size=1000, random_state=42)\n", 54 | "dist_medium = lognorm.rvs(0.25, loc=70, scale=50, size=1000, random_state=42)\n", 55 | "dist_high = lognorm.rvs(0.25, loc=120, scale=50, size=1000, random_state=42)\n", 56 | "dist_uniform = uniform.rvs(30, scale=210,size=1000, random_state=42)\n", 57 | "dist_log_uniform = 30*np.e**(np.random.rand(1000)*np.log(240/30))\n", 58 | "_, _, dist_gtzan = gtzan_data()\n", 59 | "dist_gtzan = np.array(dist_gtzan)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "6562d664", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "from collections import Counter" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "5fc37205", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "theta = dataset.variables_non_linear(25, 40, 190)\n", 80 | "bins = theta[(theta > 30) & (theta < 370)][::2]" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "0df5957a", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "init = lambda x: 25 * 2.0 ** (x / 40)\n", 91 | "end = lambda y: 25 * 2.0 ** ((128+y-1) / 40)\n", 92 | "init(11), end(11)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "c5e11618", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "# test = dist_gtzan[(dist_gtzan > 90) & (dist_gtzan < 240)]\n", 103 | "\n", 104 | "# theta[(theta > 30)]" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "2a01ab2a", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "colors = plt.rcParams[\"axes.prop_cycle\"]()" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "fd3ddc5d", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "cmap = matplotlib.cm.get_cmap('tab10')" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "8b795188", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "cmap" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": 
"09044c97", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "fig, ax = plt.subplots(2,1, figsize=(15,8))\n", 145 | "\n", 146 | "kwargs = {\n", 147 | " \"alpha\": 0.7,\n", 148 | " \"histtype\": \"stepfilled\"\n", 149 | "}\n", 150 | "\n", 151 | "ax[0].hist(dist_low, bins=50, label=\"lognorm @ 70\", edgecolor=\"black\", color=cmap.colors[0], **kwargs)\n", 152 | "ax[0].hist(dist_medium, bins=50, label=\"lognorm @ 120\", edgecolor=\"black\",color=cmap.colors[2], **kwargs)\n", 153 | "ax[0].hist(dist_high, bins=50, label=\"lognorm @ 170\", edgecolor=\"black\",color=cmap.colors[4],**kwargs)\n", 154 | "ax[0].hist(dist_log_uniform, bins=50, label=\"log uniform\", edgecolor=\"black\", color=cmap.colors[3],**kwargs)\n", 155 | "ax[0].grid(True, axis=\"x\", alpha=0.7)\n", 156 | "ax[0].set_xticks(np.arange(30, 340, 20))\n", 157 | "ax[0].title.set_text(\"Linear axis\")\n", 158 | "ax[0].title.set_fontsize(20)\n", 159 | "ax[0].set_xlim(20, 340)\n", 160 | "# ax[0].legend(loc=\"upper right\")\n", 161 | "plt.setp(ax[0], xticklabels=[])\n", 162 | "\n", 163 | "# ax[1].hist(dist_gtzan, bins=50, label=\"GTZAN\", edgecolor=\"black\", color=cmap.colors[8],**kwargs)\n", 164 | "# ax[1].set_xticks(np.arange(30, 340, 20))\n", 165 | "# ax[1].grid(True, axis=\"x\", alpha=0.7)\n", 166 | "# ax[1].set_xlim(20, 340)\n", 167 | "\n", 168 | "# ax[1].title.set_text(\"Linear axis\")\n", 169 | "\n", 170 | "ax[1].hist(dist_low, bins=bins, label=\"lognorm @ 70\", edgecolor=\"black\", color=cmap.colors[0],**kwargs)\n", 171 | "ax[1].hist(dist_medium, bins=bins, label=\"lognorm @ 120\", edgecolor=\"black\",color=cmap.colors[2], **kwargs)\n", 172 | "ax[1].hist(dist_high, bins=bins, label=\"lognorm @ 170\", edgecolor=\"black\",color=cmap.colors[4], **kwargs)\n", 173 | "ax[1].hist(dist_log_uniform, bins=bins, label=\"log uniform\", edgecolor=\"black\", color=cmap.colors[3],**kwargs)\n", 174 | "# ax[2].hist(dist_gtzan, bins=bins, label=\"GTZAN\", edgecolor=\"black\", color=cmap.colors[8], **kwargs)\n", 175 | "ax[1].title.set_text(\"Logarithmic axis\")\n", 176 | "ax[1].title.set_fontsize(20)\n", 177 | "ax[1].grid(True, axis=\"x\", alpha=0.7)\n", 178 | "plt.xscale('log')\n", 179 | "\n", 180 | "ax = plt.gca()\n", 181 | "handles, labels = ax.get_legend_handles_labels()\n", 182 | "fig.legend(handles, labels, loc=\"upper right\", bbox_to_anchor=(0.4955, 0.465, 0.5, 0.5), framealpha=1)\n", 183 | "ax.set_xticks([], [])\n", 184 | "ax.set_xticks(np.round(bins[::4]))\n", 185 | "ax.xaxis.set_major_formatter(ticker.ScalarFormatter())\n", 186 | "ax.set_xlabel(\"BPM\")\n", 187 | "ax.set_xlim(28, 360)\n", 188 | "\n", 189 | "\n", 190 | "# plt.xscale(\"log\")\n", 191 | "plt.set_cmap(\"Accent\")\n", 192 | "plt.tight_layout()\n", 193 | "plt.savefig(\"distributions.svg\", dpi='figure', format=\"svg\", metadata=None,\n", 194 | " bbox_inches=None, pad_inches=0.1,\n", 195 | " facecolor='auto', edgecolor='auto',\n", 196 | " backend=None\n", 197 | ")" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "e86d5533", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "linear_bins = np.arange(20, 360, 10)\n", 208 | "ballroom, b_tracks, b_tempi = dataset.ballroom_data()\n", 209 | "giant_steps, gs_tracks, gs_tempi = dataset.giant_steps_data()" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "6cda21c2", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "fig, ax = plt.subplots(1,3, figsize=(15,8))\n", 220 | "\n", 221 | 
"ax[0].hist(dist_gtzan, linear_bins, label=\"gtzan\", color=\"red\")\n", 222 | "ax[0].hist(dist_low, linear_bins, label=\"lognorm @ 70\", color=\"orange\", alpha=0.6)\n", 223 | "ax[0].title.set_text(\"GTZAN (1000 tracks)\")\n", 224 | "ax[0].legend()\n", 225 | "\n", 226 | "ax[1].hist(b_tempi, linear_bins, label=\"ballroom\", color=\"blue\")\n", 227 | "ax[1].hist(dist_low, linear_bins, label=\"lognorm @ 70\", color=\"orange\", alpha=0.6)\n", 228 | "ax[1].title.set_text(\"Ballroom (698 tracks)\")\n", 229 | "ax[1].legend()\n", 230 | "\n", 231 | "ax[2].hist(gs_tempi, linear_bins, label=\"ballroom\", color=\"green\")\n", 232 | "ax[2].hist(dist_low, linear_bins, label=\"lognorm @ 70\", color=\"orange\", alpha=0.6)\n", 233 | "ax[2].title.set_text(\"Giant Steps (659 tracks)\")\n", 234 | "ax[2].legend()\n", 235 | "# plt.tight_layout()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "455caee8", 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "combined_hist = []\n", 246 | "\n", 247 | "for i in dist_gtzan:\n", 248 | " combined_hist.append(i)\n", 249 | " \n", 250 | "for i in b_tempi:\n", 251 | " combined_hist.append(i)\n", 252 | " \n", 253 | "for i in gs_tempi:\n", 254 | " combined_hist.append(i)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "id": "f603eb32", 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "plt.hist(combined_hist, linear_bins, label=\"combined_datasets\")\n", 265 | "plt.hist(dist_low, linear_bins, alpha=0.6, label=\"lognormal @ 70\")\n", 266 | "plt.title(\"Combined datasets\")\n", 267 | "plt.legend()" 268 | ] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "Python 3 (ipykernel)", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.10.6" 288 | } 289 | }, 290 | "nbformat": 4, 291 | "nbformat_minor": 5 292 | } 293 | -------------------------------------------------------------------------------- /steme/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import matplotlib.ticker as ticker 4 | import numpy as np 5 | 6 | import steme.paths as paths 7 | 8 | 9 | def plot_calibration(tracks, predictions, filename): 10 | fig, ax = plt.subplots(1, 1, figsize=(20, 5)) 11 | ax = [ax] 12 | ax[0].boxplot(predictions, vert=1) 13 | ax[0].xaxis.set_major_locator(ticker.FixedLocator( 14 | np.arange(1, len(tracks) + 1) 15 | )) 16 | ax[0].plot() 17 | ax[0].xaxis.set_major_formatter(ticker.FixedFormatter(tracks)) 18 | ax[0].set_xlabel("BPM") 19 | ax[0].set_ylabel("Model output") 20 | ax[0].grid(True) 21 | # ax[0].title.set_text(f"Prediction with fixed shift. 
a = {np.round(a_fixed, 2)}, b = {np.round(b_fixed, 2)}") 22 | 23 | # fig.suptitle(model_name, fontsize=16) 24 | 25 | plt.savefig(os.path.join(paths.FIG_FOLDER, f"{filename}_synthetic.png")) 26 | 27 | return fig, ax 28 | 29 | 30 | def plot_reconstructions(bpm_tracks, bpm_dict, main_file, theta): 31 | fig, ax = plt.subplots(len(bpm_tracks), 1, figsize=(5, 45)) 32 | idx = 0 33 | 34 | for i, key in enumerate(bpm_tracks): 35 | ax[idx].plot(bpm_dict[key]["slice"], label="input") 36 | ax[idx].plot(bpm_dict[key]["estimation"], label="estimation") 37 | shift = bpm_dict[key]["shift"][0] 38 | ax[idx].title.set_text( 39 | f"BPM = {key}, shift {shift} estimation: {np.median(bpm_dict[key]['predictions']):.4f}") 40 | 41 | ax[idx].set_xticks(np.arange(0, 128, 20)) 42 | ax[idx].set_xticklabels(np.round(theta[0 + shift:128 + shift:20], 2)) 43 | ax[idx].legend() 44 | idx += 1 45 | 46 | plt.tight_layout() 47 | plt.grid(True) 48 | plt.savefig( 49 | os.path.join( 50 | paths.DATA_FOLDER, 51 | f"imgs/{main_file}_reconstructions.png")) 52 | return 53 | 54 | 55 | def plot_tempogram(T, t, freqs, title=None): 56 | figsize = (10, 5) 57 | fig, ax = plt.subplots(1, 1, figsize=figsize) 58 | kwargs = _tempogram_kwargs(t, freqs) 59 | 60 | ax.imshow(T, **kwargs) 61 | 62 | xlim = (t[0], t[-1]) 63 | ylim = (freqs[0], freqs[-1]) 64 | 65 | plt.setp(ax, xlim=xlim, ylim=ylim) 66 | 67 | if title is not None: 68 | fig.suptitle(title, fontsize=16) 69 | 70 | ax.set_xlabel("Time (s)") 71 | ax.set_ylabel("Tempo (BPM)") 72 | 73 | return fig, ax 74 | 75 | def plot_tempogram_comparison(T, t, freqs, subplot_titles=None, fig_title=None, figsize=None): 76 | num_tempograms = len(T) 77 | 78 | if figsize is None: 79 | figsize = (5*num_tempograms, 5) 80 | fig, ax = plt.subplots(1, num_tempograms, figsize=figsize) 81 | 82 | for idx in range(num_tempograms): 83 | kwargs = _tempogram_kwargs(t[idx], freqs[idx]) 84 | 85 | ax[idx].imshow(T[idx], **kwargs) 86 | 87 | xlim = (t[idx][0], t[idx][-1]) 88 | ylim = (freqs[idx][0], freqs[idx][-1]) 89 | 90 | plt.setp(ax, xlim=xlim, ylim=ylim) 91 | 92 | if fig_title is not None: 93 | fig.suptitle(fig_title, fontsize=16) 94 | 95 | ax[idx].set_xlabel("Time (s)") 96 | ax[idx].set_ylabel("Tempo (BPM)") 97 | 98 | if subplot_titles is not None: 99 | ax[idx].title.set_text(subplot_titles[idx]) 100 | 101 | return fig, ax 102 | 103 | 104 | def plot_comparison( 105 | T, 106 | t, 107 | freqs, 108 | reference_tempo, 109 | predicted_tempo, 110 | share_plot=False, 111 | xlim=None, 112 | ylim=None, 113 | title=None): 114 | """ 115 | Plot tempogram comparison 116 | 117 | Parameters 118 | --------- 119 | T : np.ndarray(t, freqs) 120 | tempogram matrix 121 | t : np.ndarray 122 | array with time values 123 | freqs : np.ndarray 124 | array with bpm covered 125 | reference_tempo : float64 126 | ground truth tempo annotation 127 | predicted_tempo : np.ndarray 128 | array with tempo predictions 129 | share_plot : bool, optional 130 | if True, ground truth and predictions are plotted together. otherwise, 131 | each one is in a different plot. 132 | xlim : tuple, optional 133 | limit for xaxis. if None, xlim is defined by t values 134 | ylim : tuple, optional 135 | limit for yaxis. if None, ylim is defined by freqs values *or* by the 136 | reference_tempo value + 10 BPM for visualization purposes. 
137 | 138 | Return 139 | ----- 140 | fig, ax 141 | """ 142 | fig = None 143 | figsize = (10, 5) 144 | 145 | if share_plot: 146 | fig, ax = plt.subplots(1, 1, figsize=figsize) 147 | ax = [ax] 148 | prediction_plot_idx = 0 149 | else: 150 | fig, ax = plt.subplots(1, 2, figsize=figsize) 151 | prediction_plot_idx = 1 152 | 153 | kwargs = _tempogram_kwargs(t, freqs) 154 | 155 | # plot tempogram and ground_truth tempo 156 | ax[0].imshow(T, **kwargs) 157 | ax[0].hlines(reference_tempo, 158 | xmin=t[0], 159 | xmax=t[-1], 160 | label=f"ground_truth: {reference_tempo} bpm", 161 | color="r", 162 | linestyle="-.") 163 | ax[0].legend() 164 | 165 | # plot tempogram and tempo prediction 166 | median_prediction = np.median(predicted_tempo) 167 | ax[prediction_plot_idx].imshow(T, **kwargs) 168 | ax[prediction_plot_idx].hlines( 169 | median_prediction, 170 | xmin=t[0], xmax=t[-1], 171 | label=f"median(predictions): {median_prediction:.2f} BPM", 172 | color="r" 173 | ) 174 | ax[prediction_plot_idx].plot(t, predicted_tempo, color="orange", alpha=0.4) 175 | ax[prediction_plot_idx].scatter( 176 | t, 177 | predicted_tempo, 178 | label="predictions", 179 | s=6, 180 | color="orange", 181 | alpha=0.7) 182 | ax[prediction_plot_idx].legend() 183 | 184 | if xlim is None: 185 | xlim = (t[0], t[-1]) 186 | if ylim is None: 187 | ylim = (freqs[0], freqs[-1]) 188 | 189 | plt.setp(ax, xlim=xlim, ylim=ylim) 190 | 191 | if title is not None: 192 | fig.suptitle(title, fontsize=16) 193 | 194 | return fig, ax 195 | 196 | 197 | def plot_experiment_results( 198 | results, 199 | n_plots=10, 200 | ylim=None, 201 | theta=np.arange(30, 301, 1)): 202 | """ 203 | Plot experiment results 204 | 205 | Parameters 206 | --------- 207 | results : dict 208 | Dictionary with track_id, tempogram, times, tempi range, reference_tempo 209 | and predicted_tempo values. baseline_tempo is an optional key for the 210 | dictionary. 211 | n_plots : int 212 | Number of plots one desires to check 213 | """ 214 | if n_plots > len(results): 215 | raise ValueError( 216 | f"You're trying to plot {n_plots} samples, but there are {len(results)} available") 217 | 218 | if ylim is None: 219 | ylim = (30, 300) 220 | 221 | n_rows = int(n_plots // 2) 222 | n_cols = 2 223 | fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 25)) 224 | 225 | row_idx, col_idx = 0, 0 226 | plots = 0 227 | 228 | for track_id, values in results.items(): 229 | if plots >= n_plots: 230 | break 231 | 232 | T = values["T"] 233 | t = values["t"] 234 | freqs = values["freqs"] 235 | reference_tempo = values["reference_tempo"] 236 | predicted_tempo = values["predicted_tempo"] 237 | baseline_tempo = values.get("baseline_tempo", None) 238 | 239 | kwargs = _tempogram_kwargs(t, freqs) 240 | 241 | ax[row_idx][col_idx].imshow(T, **kwargs) 242 | ax[row_idx][col_idx].hlines( 243 | reference_tempo, 244 | xmin=t[0], xmax=t[-1], 245 | label=f"ground_truth: {reference_tempo} bpm", 246 | color="r", 247 | linestyle="-." 
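            # dash-dot line marks the annotated ground-truth tempo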
248 |         )
249 | 
250 |         if baseline_tempo is not None:
251 |             median_baseline = np.median(baseline_tempo)
252 |             ax[row_idx][col_idx].hlines(
253 |                 median_baseline,
254 |                 xmin=t[0], xmax=t[-1],
255 |                 label=f"median(baseline): {median_baseline:.2f} BPM",
256 |                 color="b",
257 |                 linestyle="--"
258 |             )
259 |             ax[row_idx][col_idx].scatter(
260 |                 t,
261 |                 baseline_tempo,
262 |                 label="baseline by frame",
263 |                 s=6,
264 |                 color="blue",
265 |                 alpha=0.7)
266 | 
267 |         median_prediction = np.median(predicted_tempo)
268 |         ax[row_idx][col_idx].hlines(
269 |             median_prediction,
270 |             xmin=t[0], xmax=t[-1],
271 |             label=f"median(predictions): {median_prediction:.2f} BPM",
272 |             color="green"
273 |         )
274 |         # ax[row_idx][col_idx].plot(t, predicted_tempo, color="orange", alpha=0.4)
275 |         ax[row_idx][col_idx].scatter(
276 |             t,
277 |             predicted_tempo,
278 |             label="prediction by frame",
279 |             s=6,
280 |             color="green",
281 |             alpha=0.7)
282 | 
283 |         ax[row_idx][col_idx].title.set_text(track_id)
284 |         ax[row_idx][col_idx].legend()
285 | 
286 |         col_idx = (col_idx + 1) % n_cols
287 |         if col_idx == 0:
288 |             row_idx = (row_idx + 1) % n_rows
289 | 
290 |         plots += 1
291 | 
292 |     plt.setp(ax, ylim=ylim)
293 |     plt.tight_layout()
294 | 
295 |     return fig, ax
296 | 
297 | 
298 | def _tempogram_kwargs(t, freqs):
299 |     kwargs = {}
300 |     x_ext1 = (t[1] - t[0]) / 2
301 |     x_ext2 = (t[-1] - t[-2]) / 2
302 |     y_ext1 = (freqs[1] - freqs[0]) / 2
303 |     y_ext2 = (freqs[-1] - freqs[-2]) / 2
304 | 
305 |     kwargs["extent"] = [t[0] - x_ext1, t[-1] +
306 |                         x_ext2, freqs[0] - y_ext1, freqs[-1] + y_ext2]
307 |     kwargs["cmap"] = "gray_r"
308 |     kwargs["aspect"] = "auto"
309 |     kwargs["origin"] = "lower"
310 |     kwargs["interpolation"] = "nearest"
311 | 
312 |     return kwargs
313 | 
314 | 
315 | def plot_slice(tempo_slice, freqs, tempo, harmonics=False):
316 |     ylim = tempo_slice.max() + 1
317 |     plt.vlines(
318 |         tempo,
319 |         ymin=0,
320 |         ymax=ylim,
321 |         linestyle="--",
322 |         colors="r",
323 |         label=f"{tempo} bpm")
324 | 
325 |     if harmonics:
326 |         plt.vlines(tempo * 2, ymin=0, ymax=ylim, linestyle="--", colors="r")
327 |         plt.vlines(tempo * 4, ymin=0, ymax=ylim, linestyle="--", colors="r",
328 |                    alpha=0.6)
329 | 
330 |     plt.plot(freqs, tempo_slice, label="tempogram slice")
331 |     plt.xlim(30, 300)
332 |     plt.legend()
333 | 
334 | 
335 | def get_slope(model_output, frequencies):
336 |     """
337 |     Return values to translate model output to BPM frequencies
338 |     """
339 |     if not isinstance(model_output, np.ndarray):
340 |         raise TypeError("model_output should be an array")
341 | 
342 |     x = model_output
343 |     y = frequencies
344 |     n = np.size(x)
345 | 
346 |     x_mean = np.mean(x)
347 |     y_mean = np.mean(y)
348 | 
349 |     Sxy = np.sum(x * y) - n * x_mean * y_mean
350 |     Sxx = np.sum(x * x) - n * x_mean * x_mean
351 | 
352 |     a = Sxy / Sxx
353 |     b = y_mean - a * x_mean
354 | 
355 |     return a, b
356 | 
--------------------------------------------------------------------------------
/steme/models.py:
--------------------------------------------------------------------------------
  1 | import keras.backend as K
  2 | import tensorflow as tf
  3 | 
  4 | from tensorflow.keras import layers
  5 | 
  6 | 
  7 | def spice(sigma, w_tempo, w_recon):
  8 |     encoder_filter = 64
  9 |     decoder_filter = 32
 10 |     kernel_size = 3
 11 |     strides = 1
 12 |     padding = "same"
 13 | 
 14 |     x1 = layers.Input(shape=(128, 1))
 15 |     x2 = layers.Input(shape=(128, 1))
 16 |     k1 = layers.Input(shape=(1))
 17 |     k2 = layers.Input(shape=(1))
 18 | 
 19 |     conv1 = layers.Conv1D(
 20 |         encoder_filter,
 21 |         kernel_size,
 22 |         strides,
 23 |         padding=padding)
 24 |     bn1 = 
layers.BatchNormalization() 25 | relu1 = layers.ReLU() 26 | mp1 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 27 | conv2 = layers.Conv1D( 28 | encoder_filter * 2, 29 | kernel_size, 30 | strides, 31 | padding=padding) 32 | bn2 = layers.BatchNormalization() 33 | relu2 = layers.ReLU() 34 | mp2 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 35 | conv3 = layers.Conv1D( 36 | encoder_filter * 4, 37 | kernel_size, 38 | strides, 39 | padding=padding) 40 | bn3 = layers.BatchNormalization() 41 | relu3 = layers.ReLU() 42 | mp3 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 43 | conv4 = layers.Conv1D( 44 | encoder_filter * 8, 45 | kernel_size, 46 | strides, 47 | padding=padding) 48 | bn4 = layers.BatchNormalization() 49 | relu4 = layers.ReLU() 50 | mp4 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 51 | conv5 = layers.Conv1D( 52 | encoder_filter * 8, 53 | kernel_size, 54 | strides, 55 | padding=padding) 56 | bn5 = layers.BatchNormalization() 57 | relu5 = layers.ReLU() 58 | mp5 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 59 | conv6 = layers.Conv1D( 60 | encoder_filter * 8, 61 | kernel_size, 62 | strides, 63 | padding=padding) 64 | bn6 = layers.BatchNormalization() 65 | relu6 = layers.ReLU() 66 | mp6 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 67 | flatten = layers.Flatten() 68 | 69 | phead_dense = layers.Dense(48) 70 | phead_y = layers.Dense(1, activation="sigmoid") 71 | 72 | dense48 = layers.Dense(48) 73 | reshape = layers.Reshape((48, 1)) 74 | 75 | dtconv1 = layers.Conv1DTranspose( 76 | decoder_filter * 8, 77 | kernel_size, 78 | strides, 79 | padding=padding) 80 | dbn1 = layers.BatchNormalization() 81 | drelu1 = layers.ReLU() 82 | dmp1 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 83 | dtconv2 = layers.Conv1DTranspose( 84 | decoder_filter * 8, 85 | kernel_size, 86 | strides, 87 | padding=padding) 88 | dbn2 = layers.BatchNormalization() 89 | drelu2 = layers.ReLU() 90 | dmp2 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 91 | dtconv3 = layers.Conv1DTranspose( 92 | decoder_filter * 8, 93 | kernel_size, 94 | strides, 95 | padding=padding) 96 | dbn3 = layers.BatchNormalization() 97 | drelu3 = layers.ReLU() 98 | dmp3 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 99 | dtconv4 = layers.Conv1DTranspose( 100 | decoder_filter * 4, 101 | kernel_size, 102 | strides, 103 | padding=padding) 104 | dbn4 = layers.BatchNormalization() 105 | drelu4 = layers.ReLU() 106 | dmp4 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 107 | dtconv5 = layers.Conv1DTranspose( 108 | decoder_filter * 2, 109 | kernel_size, 110 | strides, 111 | padding=padding) 112 | dbn5 = layers.BatchNormalization() 113 | drelu5 = layers.ReLU() 114 | dmp5 = layers.MaxPool1D(pool_size=3, strides=2, padding=padding) 115 | 116 | dreshape = layers.Reshape((128, 1)) 117 | 118 | embd1 = conv1(x1) 119 | embd1 = relu1(embd1) 120 | embd1 = bn1(embd1) 121 | embd1 = mp1(embd1) 122 | embd1 = conv2(embd1) 123 | embd1 = relu2(embd1) 124 | embd1 = bn2(embd1) 125 | embd1 = mp2(embd1) 126 | embd1 = conv3(embd1) 127 | embd1 = relu3(embd1) 128 | embd1 = bn3(embd1) 129 | embd1 = mp3(embd1) 130 | embd1 = conv4(embd1) 131 | embd1 = relu4(embd1) 132 | embd1 = bn4(embd1) 133 | embd1 = mp4(embd1) 134 | embd1 = conv5(embd1) 135 | embd1 = relu5(embd1) 136 | embd1 = bn5(embd1) 137 | embd1 = mp5(embd1) 138 | embd1 = conv6(embd1) 139 | embd1 = relu6(embd1) 140 | embd1 = bn6(embd1) 141 | embd1 = mp6(embd1) 142 | embd1 = flatten(embd1) 143 | 144 | 
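    # NOTE: the second input is pushed through the *same* layer objects
    # created above, so the two encoder branches share all weights
    # (a Siamese encoder); only the input slices x1/x2 differ.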
embd2 = conv1(x2) 145 | embd2 = relu1(embd2) 146 | embd2 = bn1(embd2) 147 | embd2 = mp1(embd2) 148 | embd2 = conv2(embd2) 149 | embd2 = relu2(embd2) 150 | embd2 = bn2(embd2) 151 | embd2 = mp2(embd2) 152 | embd2 = conv3(embd2) 153 | embd2 = relu3(embd2) 154 | embd2 = bn3(embd2) 155 | embd2 = mp3(embd2) 156 | embd2 = conv4(embd2) 157 | embd2 = relu4(embd2) 158 | embd2 = bn4(embd2) 159 | embd2 = mp4(embd2) 160 | embd2 = conv5(embd2) 161 | embd2 = relu5(embd2) 162 | embd2 = bn5(embd2) 163 | embd2 = mp5(embd2) 164 | embd2 = conv6(embd2) 165 | embd2 = relu6(embd2) 166 | embd2 = bn6(embd2) 167 | embd2 = mp6(embd2) 168 | embd2 = flatten(embd2) 169 | 170 | y1 = phead_dense(embd1) 171 | y1 = phead_y(y1) 172 | y2 = phead_dense(embd2) 173 | y2 = phead_y(y2) 174 | 175 | xhat1 = dense48(y1) 176 | xhat1 = reshape(xhat1) 177 | xhat1 = dtconv1(xhat1) 178 | xhat1 = drelu1(xhat1) 179 | xhat1 = dbn1(xhat1) 180 | xhat1 = dmp1(xhat1) 181 | xhat1 = dtconv2(xhat1) 182 | xhat1 = drelu2(xhat1) 183 | xhat1 = dbn2(xhat1) 184 | xhat1 = dmp2(xhat1) 185 | xhat1 = dtconv3(xhat1) 186 | xhat1 = drelu3(xhat1) 187 | xhat1 = dbn3(xhat1) 188 | xhat1 = dmp3(xhat1) 189 | xhat1 = dtconv4(xhat1) 190 | xhat1 = drelu4(xhat1) 191 | xhat1 = dbn4(xhat1) 192 | xhat1 = dmp4(xhat1) 193 | xhat1 = dtconv5(xhat1) 194 | xhat1 = drelu5(xhat1) 195 | xhat1 = dbn5(xhat1) 196 | xhat1 = dmp5(xhat1) 197 | xhat1 = dreshape(xhat1) 198 | 199 | xhat2 = dense48(y2) 200 | xhat2 = reshape(xhat2) 201 | xhat2 = dtconv1(xhat2) 202 | xhat2 = drelu1(xhat2) 203 | xhat2 = dbn1(xhat2) 204 | xhat2 = dmp1(xhat2) 205 | xhat2 = dtconv2(xhat2) 206 | xhat2 = drelu2(xhat2) 207 | xhat2 = dbn2(xhat2) 208 | xhat2 = dmp2(xhat2) 209 | xhat2 = dtconv3(xhat2) 210 | xhat2 = drelu3(xhat2) 211 | xhat2 = dbn3(xhat2) 212 | xhat2 = dmp3(xhat2) 213 | xhat2 = dtconv4(xhat2) 214 | xhat2 = drelu4(xhat2) 215 | xhat2 = dbn4(xhat2) 216 | xhat2 = dmp4(xhat2) 217 | xhat2 = dtconv5(xhat2) 218 | xhat2 = drelu5(xhat2) 219 | xhat2 = dbn5(xhat2) 220 | xhat2 = dmp5(xhat2) 221 | xhat2 = dreshape(xhat2) 222 | 223 | model = tf.keras.Model([x1, x2, k1, k2], [xhat1, xhat2, y1, y2]) 224 | 225 | h = tf.keras.losses.Huber( 226 | delta=0.25 * sigma, 227 | reduction="sum_over_batch_size") 228 | 229 | e_t = K.abs((y1 - y2) - sigma * (k2 - k1)) 230 | 231 | loss_tempo = h(e_t, 0) * w_tempo 232 | # https://math.stackexchange.com/questions/2690199/should-the-2-in-l-2-norm-notation-be-a-subscript-or-superscript 233 | loss_recon = K.mean(K.mean(K.square(x1 - xhat1) + 234 | K.square(x2 - xhat2), axis=1)) * w_recon 235 | 236 | model.add_loss(loss_tempo) 237 | model.add_loss(loss_recon) 238 | 239 | model.add_metric(loss_tempo, name="tempo_loss") 240 | model.add_metric(loss_recon, name="reconstruction_loss") 241 | 242 | return model 243 | 244 | 245 | def convolutional_autoencoder(sigma, w_tempo, w_recon): 246 | encoder_filter = 64 247 | decoder_filter = 64 248 | kernel_size = 3 249 | strides = 2 250 | padding = "same" 251 | activation = "relu" 252 | 253 | x1 = layers.Input(shape=(128, 1)) 254 | x2 = layers.Input(shape=(128, 1)) 255 | k1 = layers.Input(shape=(1)) 256 | k2 = layers.Input(shape=(1)) 257 | 258 | # encoder 259 | conv1 = layers.Conv1D( 260 | encoder_filter, 261 | kernel_size, 262 | strides, 263 | padding=padding, 264 | activation=activation) 265 | conv2 = layers.Conv1D( 266 | encoder_filter * 2, 267 | kernel_size, 268 | strides, 269 | padding=padding, 270 | activation=activation) 271 | conv3 = layers.Conv1D( 272 | encoder_filter * 4, 273 | kernel_size, 274 | strides, 275 | padding=padding, 276 | 
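        # stride-2 convolutions (strides set above) halve the sequence
        # length at each encoder stage, replacing the explicit MaxPool1D
        # stages used in spice()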
activation=activation) 277 | conv4 = layers.Conv1D( 278 | encoder_filter * 8, 279 | kernel_size, 280 | strides, 281 | padding=padding, 282 | activation=activation) 283 | conv5 = layers.Conv1D( 284 | encoder_filter * 8, 285 | kernel_size, 286 | strides, 287 | padding=padding, 288 | activation=activation) 289 | conv6 = layers.Conv1D( 290 | encoder_filter * 8, 291 | kernel_size, 292 | strides, 293 | padding=padding, 294 | activation=activation) 295 | flatten = layers.Flatten() 296 | 297 | # tempo head 298 | phead_dense = layers.Dense(48) 299 | phead_y = layers.Dense(1, activation="sigmoid") 300 | 301 | # decoder 302 | reshape = layers.Reshape((1, 1)) 303 | 304 | dtconv1 = layers.Conv1DTranspose( 305 | decoder_filter * 8, 306 | kernel_size, 307 | strides, 308 | padding=padding, 309 | activation=activation) 310 | dtconv2 = layers.Conv1DTranspose( 311 | decoder_filter * 8, 312 | kernel_size, 313 | strides, 314 | padding=padding, 315 | activation=activation) 316 | dtconv3 = layers.Conv1DTranspose( 317 | decoder_filter * 8, 318 | kernel_size, 319 | strides, 320 | padding=padding, 321 | activation=activation) 322 | dtconv4 = layers.Conv1DTranspose( 323 | decoder_filter * 4, 324 | kernel_size, 325 | strides, 326 | padding=padding, 327 | activation=activation) 328 | dtconv5 = layers.Conv1DTranspose( 329 | decoder_filter * 2, 330 | kernel_size, 331 | strides, 332 | padding=padding, 333 | activation=activation) 334 | dtconv6 = layers.Conv1DTranspose( 335 | decoder_filter, 336 | kernel_size, 337 | strides, 338 | padding=padding, 339 | activation=activation) 340 | dreshape = layers.Conv1DTranspose( 341 | 1, 342 | kernel_size, 343 | strides, 344 | padding=padding, 345 | activation=activation) 346 | 347 | # encoder 1 348 | embd1 = conv1(x1) 349 | embd1 = conv2(embd1) 350 | embd1 = conv3(embd1) 351 | embd1 = conv4(embd1) 352 | embd1 = conv5(embd1) 353 | embd1 = conv6(embd1) 354 | embd1 = flatten(embd1) 355 | 356 | y1 = phead_dense(embd1) 357 | y1 = phead_y(y1) 358 | 359 | xhat1 = reshape(y1) 360 | xhat1 = dtconv1(xhat1) 361 | xhat1 = dtconv2(xhat1) 362 | xhat1 = dtconv3(xhat1) 363 | xhat1 = dtconv4(xhat1) 364 | xhat1 = dtconv5(xhat1) 365 | xhat1 = dtconv6(xhat1) 366 | xhat1 = dreshape(xhat1) 367 | 368 | embd2 = conv1(x2) 369 | embd2 = conv2(embd2) 370 | embd2 = conv3(embd2) 371 | embd2 = conv4(embd2) 372 | embd2 = conv5(embd2) 373 | embd2 = conv6(embd2) 374 | embd2 = flatten(embd2) 375 | 376 | y2 = phead_dense(embd2) 377 | y2 = phead_y(y2) 378 | 379 | xhat2 = reshape(y2) 380 | xhat2 = dtconv1(xhat2) 381 | xhat2 = dtconv2(xhat2) 382 | xhat2 = dtconv3(xhat2) 383 | xhat2 = dtconv4(xhat2) 384 | xhat2 = dtconv5(xhat2) 385 | xhat2 = dtconv6(xhat2) 386 | xhat2 = dreshape(xhat2) 387 | 388 | model = tf.keras.Model([x1, x2, k1, k2], [xhat1, xhat2, y1, y2]) 389 | 390 | h = tf.keras.losses.Huber( 391 | delta=0.25 * sigma, 392 | reduction="sum_over_batch_size") 393 | 394 | e_t = K.abs((y1 - y2) - sigma * (k2 - k1)) 395 | 396 | loss_tempo = h(e_t, 0) * w_tempo 397 | # https://math.stackexchange.com/questions/2690199/should-the-2-in-l-2-norm-notation-be-a-subscript-or-superscript 398 | loss_recon = K.mean(K.mean(K.square(x1 - xhat1) + 399 | K.square(x2 - xhat2), axis=1)) * w_recon 400 | 401 | model.add_loss(loss_tempo) 402 | model.add_loss(loss_recon) 403 | 404 | model.add_metric(loss_tempo, name="tempo_loss") 405 | model.add_metric(loss_recon, name="reconstruction_loss") 406 | 407 | return model 408 | 409 | -------------------------------------------------------------------------------- /steme/dataset.py: 
--------------------------------------------------------------------------------
  1 | import os
  2 | import random
  3 | from collections import Counter
  4 | 
  5 | import h5py
  6 | import numpy as np
  7 | import mirdata
  8 | from scipy.stats import tukeylambda, lognorm, uniform as st_uniform
  9 | 
 10 | import steme.audio as audio
 11 | import steme.loader as loader
 12 | import steme.paths as paths
 13 | 
 14 | 
 15 | def generate_biased_data(main_file, distribution, theta, t_type):
 16 |     """
 17 |     Generates synthetic data (click tracks) following a distribution.
 18 | 
 19 |     main_file : str
 20 |         filename that will be used to store the data
 21 |     distribution : dict
 22 |         mapping from BPM value to its frequency. each click track is
 23 |         duplicated according to the frequency of its BPM.
 24 |     theta : np.ndarray
 25 |         tempi axis used to compute the tempogram
 26 |     t_type : str
 27 |         tempogram type. options are "fourier", "autocorrelation", "hybrid"
 28 |     """
 29 |     with h5py.File(f"data/{main_file}.h5", "w") as hf:
 30 |         for track_id in distribution:
 31 |             sr = 22050
 32 |             x = audio.click_track(track_id, sr)
 33 | 
 34 |             if distribution[track_id] > 1:
 35 |                 # tile (concatenate copies of) the click track; repeating
 36 |                 # each sample would stretch the audio and change its tempo
 37 |                 x = np.tile(x, distribution[track_id])
 38 | 
 39 |             T, t, bpm = audio.tempogram(x, sr, window_size_seconds=10,
 40 |                                         t_type=t_type, theta=theta)
 41 | 
 42 |             hf.create_dataset(str(track_id), data=T)
 43 | 
 44 |     return
 45 | 
 46 | 
 47 | def remove_out_of_bound_data(distribution):
 48 |     """
 49 |     Removes data below 30 BPM and above 350 BPM
 50 |     """
 51 |     key_to_remove = [k for k, v in distribution.items() if k < 30 or k > 350]
 52 |     new_dist = distribution.copy()
 53 | 
 54 |     for k in key_to_remove:
 55 |         new_dist.pop(k)
 56 | 
 57 |     return new_dist
 58 | 
 59 | 
 60 | def copy_data(main_file, new_file, ids):
 61 |     """
 62 |     Copy a specific list of ids to a new file
 63 |     """
 64 |     samples = 0
 65 |     with h5py.File(f"data/{main_file}.h5", "r") as rhf:
 66 |         with h5py.File(f"data/{new_file}.h5", "w") as whf:
 67 |             for track_id in ids:
 68 |                 tempogram = rhf.get(str(track_id))
 69 |                 samples += tempogram.shape[1]
 70 | 
 71 |                 whf.create_dataset(str(track_id), data=tempogram)
 72 |     return samples
 73 | 
 74 | 
 75 | def gtzan_data():
 76 |     gtzan = mirdata.initialize(
 77 |         "gtzan_genre",
 78 |         data_home=os.path.join(
 79 |             paths.DATASET_FOLDER,
 80 |             "gtzan_genre"),
 81 |         version="default")
 82 |     tracks = gtzan.track_ids
 83 | 
 84 |     # remove tracks with no tempo annotation
 85 |     tracks.remove("reggae.00086")
 86 |     tempi = [gtzan.track(track_id).tempo for track_id in tracks]
 87 | 
 88 |     return gtzan, tracks, tempi
 89 | 
 90 | 
 91 | def giant_steps_data():
 92 |     gs = loader.custom_dataset_loader(
 93 |         path=paths.DATASET_FOLDER,
 94 |         dataset_name="giantsteps-tempo-dataset",
 95 |         folder=""
 96 |     )
 97 |     tracks = gs.track_ids
 98 |     # remove tracks with no tempo annotation
 99 |     tracks.remove("3041381.LOFI")
100 |     tracks.remove("3041383.LOFI")
101 |     tracks.remove("1327052.LOFI")
102 | 
103 |     tempi = [gs.track(track_id).tempo for track_id in tracks]
104 | 
105 |     return gs, tracks, tempi
106 | 
107 | 
108 | def ballroom_data():
109 |     ballroom = loader.custom_dataset_loader(
110 |         path=paths.DATASET_FOLDER,
111 |         dataset_name="ballroom",
112 |         folder=""
113 |     )
114 | 
115 |     tempi = [ballroom.track(i).tempo for i in ballroom.track_ids]
116 |     tracks = [i for i in ballroom.track_ids]
117 |     return ballroom, tracks, tempi
118 | 
119 | def gtzan_augmented_data():
120 |     gtzan_augmented = loader.custom_dataset_loader(
121 |         path=paths.DATASET_FOLDER,
122 |         dataset_name="gtzan_augmented",
123 |         folder="",
124 |     )
125 | 
126 |     tracks = gtzan_augmented.track_ids
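    # remove the track with no tempo annotation, as in gtzan_data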
tracks.remove("reggae.00086") 127 | tempi = [gtzan_augmented.track(track_id).tempo for track_id 128 | in tracks] 129 | 130 | return gtzan_augmented, tracks, tempi 131 | 132 | 133 | def gtzan_augmented_log_data(): 134 | gtzan_augmented = loader.custom_dataset_loader( 135 | path=paths.DATASET_FOLDER, 136 | dataset_name="gtzan_augmented_log", 137 | folder="", 138 | ) 139 | 140 | tracks = gtzan_augmented.track_ids 141 | tracks.remove("reggae.00086") 142 | tempi = [gtzan_augmented.track(track_id).tempo for track_id 143 | in tracks] 144 | 145 | return gtzan_augmented, tracks, tempi 146 | 147 | def gtzan_augmented_log_cropped_data(): 148 | gtzan_augmented = loader.custom_dataset_loader( 149 | path=paths.DATASET_FOLDER, 150 | dataset_name="gtzan_augmented_log_cropped", 151 | folder="", 152 | ) 153 | 154 | tracks = gtzan_augmented.track_ids 155 | tracks.remove("reggae.00086") 156 | tempi = [gtzan_augmented.track(track_id).tempo for track_id 157 | in tracks] 158 | 159 | return gtzan_augmented, tracks, tempi 160 | 161 | def brid_data(): 162 | brid = loader.custom_dataset_loader( 163 | path=paths.DATASET_FOLDER, 164 | dataset_name="brid", 165 | folder="" 166 | ) 167 | 168 | tempi = [brid.track(i).tempo for i in brid.track_ids] 169 | tracks = [i for i in brid.track_ids] 170 | return brid, tracks, tempi 171 | 172 | 173 | def get_metadata_file(main_file): 174 | return os.path.join(paths.DATA_FOLDER, f"{main_file}_metadata.h5") 175 | 176 | 177 | def read_dataset_info(main_file): 178 | dataset_metadata = get_metadata_file(main_file) 179 | print(f"Reading metadata file {dataset_metadata}") 180 | response = {} 181 | 182 | with h5py.File(dataset_metadata, "r") as hf: 183 | response["main_file"] = hf.get("main_file")[()].decode("UTF-8") 184 | response["validation_file"] = hf.get("validation_file")[ 185 | ()].decode("UTF-8") 186 | response["train_file"] = hf.get("train_file")[()].decode("UTF-8") 187 | response["main_filepath"] = hf.get("main_filepath")[()].decode("UTF-8") 188 | response["validation_filepath"] = hf.get("validation_filepath")[ 189 | ()].decode("UTF-8") 190 | response["train_filepath"] = hf.get("train_filepath")[ 191 | ()].decode("UTF-8") 192 | response["distribution"] = hf.get("distribution")[:] 193 | response["validation_setsize"] = hf.get("validation_setsize")[()] 194 | response["train_setsize"] = hf.get("train_setsize")[()] 195 | response["tmin"] = hf.get("tmin")[()] 196 | response["tmax"] = hf.get("tmax")[()] 197 | 198 | return response 199 | 200 | 201 | def write_dataset_info(main_file, response): 202 | dataset_metadata = get_metadata_file(main_file) 203 | print(f"Creating metadata file: {dataset_metadata}") 204 | 205 | with h5py.File(dataset_metadata, "w") as hf: 206 | for k, v in response.items(): 207 | hf.create_dataset(k, data=v) 208 | 209 | return response 210 | 211 | 212 | def generate_synthetic_dataset( 213 | dataset_name, 214 | dataset_type, 215 | theta, 216 | t_type, 217 | lam, 218 | loc, 219 | scale, 220 | size): 221 | """ 222 | Creates synthetic datasets that will follow a distribution 223 | """ 224 | 225 | if dataset_type == "tukey_lambda": 226 | dist = tukeylambda.rvs(lam, loc=loc, scale=scale, size=size, 227 | random_state=42) 228 | elif dataset_type == "lognorm": 229 | dist = lognorm.rvs( 230 | lam, 231 | loc=loc, 232 | scale=scale, 233 | size=size, 234 | random_state=42) 235 | elif dataset_type == "uniform": 236 | dist = st_uniform.rvs(loc=loc, scale=scale, size=1000, random_state=42) 237 | elif dataset_type == "gtzan_synthetic": 238 | _, _, dist = gtzan_data() 239 | elif 
dataset_type == "log_uniform": 240 | # tmin = theta[0] 241 | # tmax = theta[-1] 242 | tmin = 30 243 | tmax = 240 244 | 245 | dist = tmin * np.e**(np.random.rand(size) * np.log(tmax / tmin)) 246 | 247 | main_file = dataset_name 248 | print(f"dataset_name = {dataset_name}") 249 | 250 | dist_counter = Counter(dist) 251 | dist_counter = remove_out_of_bound_data(dist_counter) 252 | 253 | train_file = f"{main_file}_train" 254 | validation_file = f"{main_file}_validation" 255 | 256 | main_filepath = os.path.join(paths.DATA_FOLDER, f"{main_file}.h5") 257 | train_filepath = os.path.join(paths.DATA_FOLDER, f"{main_file}_train.h5") 258 | validation_filepath = os.path.join( 259 | paths.DATA_FOLDER, f"{main_file}_validation.h5") 260 | 261 | keys = [k for k in dist_counter.keys()] 262 | tmin = min(keys) 263 | tmax = max(keys) 264 | 265 | print("Generating biased files") 266 | if not os.path.isfile(main_filepath): 267 | generate_biased_data(main_file, dist_counter, theta, t_type) 268 | random.shuffle(keys) 269 | 270 | train_split = int(len(keys) * 0.8) 271 | train_ids = keys[:train_split] 272 | validation_ids = keys[train_split:] 273 | 274 | # create train and validation files 275 | train_samples = copy_data(main_file, train_file, train_ids) 276 | validation_samples = copy_data( 277 | main_file, validation_file, validation_ids) 278 | 279 | print(f"total train samples: {train_samples}") 280 | print(f"total validation samples: {validation_samples}") 281 | 282 | response = { 283 | "distribution": dist, 284 | "main_file": main_file, 285 | "train_file": train_file, 286 | "validation_file": validation_file, 287 | "main_filepath": main_filepath, 288 | "train_filepath": train_filepath, 289 | "validation_filepath": validation_filepath, 290 | "train_setsize": train_samples, 291 | "validation_setsize": validation_samples, 292 | "tmin": tmin, 293 | "tmax": tmax, 294 | } 295 | 296 | write_dataset_info(main_file, response) 297 | else: 298 | response = read_dataset_info(main_file) 299 | 300 | return response 301 | 302 | def lognormal70(): 303 | return lognorm.rvs(0.25, loc=30, scale=50, size=1000, random_state=42) 304 | 305 | def lognormal150(): 306 | return lognorm.rvs(0.25, loc=70, scale=50, size=1000, random_state=42) 307 | 308 | def lognormal170(): 309 | return lognorm.rvs(0.25, loc=120, scale=50, size=1000, random_state=42) 310 | 311 | def log_uniform(): 312 | return 30*np.e**(np.random.rand(1000)*np.log(240/30)) 313 | 314 | def uniform(): 315 | return st_uniform.rvs(30, scale=210,size=1000, random_state=42) 316 | 317 | def generate_dataset(dataset_name, dataset_type, theta, t_type): 318 | if dataset_type == "gtzan": 319 | gtzan, tracks, tempi = gtzan_data() 320 | elif dataset_type == "gtzan_augmented": 321 | gtzan, tracks, tempi = gtzan_augmented_data() 322 | elif dataset_type == "gtzan_augmented_log": 323 | gtzan, tracks, tempi = gtzan_augmented_log_data() 324 | elif dataset_type == "gtzan_augmented_log_cropped": 325 | gtzan, tracks, tempi = gtzan_augmented_log_cropped_data() 326 | elif dataset_type == "giant_steps": 327 | gs, tracks, tempi = giant_steps_data() 328 | elif dataset_type == "ballroom": 329 | ballroom, tracks, tempi = ballroom_data() 330 | elif dataset_type == "gtzan_giant_steps": 331 | gtzan, gtzan_tracks, gtzan_tempi = gtzan_data() 332 | gs, gs_tracks, gs_tempi = giant_steps_data() 333 | tracks = gtzan_tracks + gs_tracks 334 | tempi = gtzan_tempi + gs_tempi 335 | elif dataset_type == "brid": 336 | brid, tracks, tempi = brid_data() 337 | elif dataset_type == "gtzan+giant_steps": 338 | gtzan, 
gtzan_tracks, gtzan_tempi = gtzan_data() 339 | gs, gs_tracks, gs_tempi = giant_steps_data() 340 | 341 | tracks = gtzan_tracks + gs_tracks 342 | tempi = gtzan_tempi + gs_tempi 343 | 344 | main_file = f"{dataset_name}" 345 | train_file = f"{main_file}_train" 346 | validation_file = f"{main_file}_validation" 347 | 348 | main_filepath = os.path.join(paths.DATA_FOLDER, f"{main_file}.h5") 349 | train_filepath = os.path.join(paths.DATA_FOLDER, f"{main_file}_train.h5") 350 | validation_filepath = os.path.join( 351 | paths.DATA_FOLDER, f"{main_file}_validation.h5") 352 | 353 | tmin = min(tempi) 354 | tmax = max(tempi) 355 | 356 | print(f"Generating tempogram files: {main_file}.h5") 357 | if not os.path.isfile(main_filepath): 358 | with h5py.File(main_filepath, "w") as hf: 359 | for track_id in tracks: 360 | if "LOFI" in track_id: 361 | x, sr = gs.track(track_id).audio 362 | else: 363 | x, sr = gtzan.track(track_id).audio 364 | 365 | T, t, bpm = audio.tempogram( 366 | x, sr, window_size_seconds=10, t_type=t_type, theta=theta) 367 | 368 | hf.create_dataset(str(track_id), data=T) 369 | 370 | random.shuffle(tracks) 371 | 372 | train_split = int(len(tracks) * 0.8) 373 | train_ids = tracks[:train_split] 374 | validation_ids = tracks[train_split:] 375 | 376 | # create train and validation files 377 | train_samples = copy_data(main_file, train_file, train_ids) 378 | validation_samples = copy_data( 379 | main_file, validation_file, validation_ids) 380 | 381 | response = { 382 | "distribution": tempi, 383 | "main_file": main_file, 384 | "train_file": train_file, 385 | "validation_file": validation_file, 386 | "main_filepath": main_filepath, 387 | "train_filepath": train_filepath, 388 | "validation_filepath": validation_filepath, 389 | "train_setsize": train_samples, 390 | "validation_setsize": validation_samples, 391 | "tmin": tmin, 392 | "tmax": tmax, 393 | } 394 | write_dataset_info(main_file, response) 395 | else: 396 | response = read_dataset_info(main_file) 397 | 398 | return response 399 | 400 | 401 | def sigma(tmin, tmax, bins_per_octave): 402 | """ 403 | Calculates sigma for a given shift interval. 404 | """ 405 | if tmin > tmax: 406 | raise ValueError(f"tmin > tmax. {tmin, tmax}") 407 | 408 | sigma = 1 / (bins_per_octave * np.log2(tmax / tmin)) 409 | 410 | return sigma 411 | 412 | 413 | def sigma_diff(kmin, kmax): 414 | return 1 / (kmax - kmin) 415 | 416 | 417 | def get_tempogram_slices( 418 | T, 419 | F=128, 420 | slice_idx=None, 421 | kmin=0, 422 | kmax=8, 423 | shift_1=None, 424 | shift_2=None): 425 | """ 426 | Return a F-dimension slice from the tempogram 427 | 428 | Parameters 429 | --------- 430 | T : np.array 431 | Tempogram 432 | F : int, optional 433 | Size of the slice returned. Default is 128. 434 | If F > T.shape[0], raises an exception. 435 | slice_idx : int, optional 436 | Slice position you want to return. Default is None and returns a random 437 | slice. If slice_idx != None, returns specifically the slice_idx 438 | position. 439 | Return 440 | ------ 441 | tempo_sample_1, tempo_sample_2 : np.ndarray(1,F) 442 | shift_1, shift_2 : int 443 | Integers representing the shifts 444 | 445 | """ 446 | if T.shape[0] < F: 447 | raise ValueError( 448 | f"Dimensions mismatch. 
It is not possible to retrieve a {F}-slice from a {T.shape} matrix") 449 | 450 | if shift_1 is None and shift_2 is None: 451 | shift_1 = random.randrange(kmin, kmax) 452 | shift_2 = random.randrange(kmin, kmax) 453 | 454 | if slice_idx is None: 455 | slice_idx = random.randint(0, T.shape[1] - 1) 456 | 457 | tempo_sample_1 = T[shift_1:shift_1 + F, slice_idx].copy() 458 | tempo_sample_2 = T[shift_2:shift_2 + F, slice_idx].copy() 459 | 460 | tempo_sample_1 = tempo_sample_1 / (tempo_sample_1.max() + 1e-6) 461 | tempo_sample_2 = tempo_sample_2 / (tempo_sample_2.max() + 1e-6) 462 | 463 | # correct array dimension for training 464 | tempo_sample_1 = tempo_sample_1[:, np.newaxis] 465 | tempo_sample_2 = tempo_sample_2[:, np.newaxis] 466 | 467 | shift_1 = np.array([shift_1]) 468 | shift_2 = np.array([shift_2]) 469 | 470 | return tempo_sample_1, shift_1, tempo_sample_2, shift_2, slice_idx 471 | 472 | 473 | def tempo_data_generator(filename, set_size=12000, **kwargs): 474 | """ 475 | Parameters 476 | ---------- 477 | filename : str 478 | The file path 479 | set_size : int, optional 480 | Total number of samples in training/test set 481 | **kwargs , optional 482 | Parameters of get_tempogram_slices function 483 | 484 | Returns 485 | ------- 486 | (sample_1, sample_2, shift_1, shift_2), (sample_1, sample_2, shift_1, shift_2) 487 | Two tuples with both inputs and outputs. 488 | sample_1, sample_2 : np.array 489 | np.array containing tempogram from the test file 490 | shift_1, shift_2 : int 491 | The shift made in the representation to calculate the tempo 492 | """ 493 | 494 | if not os.path.isfile(filename): 495 | raise ValueError(f"File '{filename}' does not exist") 496 | 497 | if not h5py.is_hdf5(filename): 498 | raise ValueError(f"File '{filename}' is not an HDF5 file") 499 | 500 | if set_size <= 0: 501 | raise ValueError(f"Invalid set size {set_size}") 502 | 503 | with h5py.File(filename, "r") as hf: 504 | track_ids = [key for key in hf.keys()] 505 | for i in range(set_size): 506 | track_id = random.choice(track_ids) 507 | tempogram = hf.get(track_id) 508 | 509 | tempogram_1, shift_1, tempogram_2, shift_2, slice_idx = get_tempogram_slices( 510 | tempogram, **kwargs) 511 | 512 | yield (tempogram_1, tempogram_2, shift_1, shift_2), (tempogram_1, tempogram_2, shift_1, shift_2) 513 | 514 | 515 | def variables_2bpm(): 516 | theta = np.arange(30, 350, 2) 517 | kmin = 0 518 | kmax = 16 519 | return theta, kmin, kmax 520 | 521 | 522 | def variables_non_linear(tmin=25, bins_per_octave=40, n_bins=190): 523 | frequencies = 2.0 ** (np.arange(0, n_bins, dtype=float) / bins_per_octave) 524 | theta = tmin * frequencies 525 | 526 | return theta 527 | -------------------------------------------------------------------------------- /notebooks/calibrate_and_evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "8b6ccdf2", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import pickle\n", 12 | "import random\n", 13 | "\n", 14 | "import h5py\n", 15 | "import numpy as np\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import matplotlib.ticker as ticker\n", 18 | "import mirdata\n", 19 | "import pandas as pd\n", 20 | "import tensorflow as tf\n", 21 | "\n", 22 | "import steme.audio as audio\n", 23 | "import steme.dataset as dataset\n", 24 | "import steme.loader as loader\n", 25 | "import steme.metrics as metrics\n", 26 | "import steme.paths as paths\n", 27 | 
"import steme.utils as utils" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "a6cb0223", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "def plot_calibration(predictions, bins, model_name):\n", 38 | " fig, ax = plt.subplots(1, 1, figsize=(20,5))\n", 39 | " ax = [ax]\n", 40 | " ax[0].boxplot(predictions, vert=1)\n", 41 | " ax[0].xaxis.set_major_locator(ticker.FixedLocator(\n", 42 | " np.arange(1, len(bins)+1)\n", 43 | " ))\n", 44 | " ax[0].plot()\n", 45 | " ax[0].set_xlabel(\"BPM\")\n", 46 | " ax[0].set_ylabel(\"Model output\")\n", 47 | " ax[0].grid(True)\n", 48 | "# ax[0].title.set_text(model_name)#f\"Prediction with fixed shift. a = {np.round(a_fixed, 2)}, b = {np.round(b_fixed, 2)}\")\n", 49 | "\n", 50 | " fig.suptitle(model_name, fontsize=16)\n", 51 | " \n", 52 | " \n", 53 | "def _calibrate(bpm_dict, model, kmin, kmax, n_predictions=100, fixed=False):\n", 54 | " print(\"Calibrating model\")\n", 55 | " model_output = np.zeros(len(bpm_dict.keys()))\n", 56 | " j = 0\n", 57 | " for bpm in bpm_dict.keys():\n", 58 | " T = bpm_dict[bpm][\"T\"]\n", 59 | "\n", 60 | " preds = np.zeros(n_predictions)\n", 61 | " step = T.shape[1]//n_predictions\n", 62 | " \n", 63 | " for i in range(n_predictions):\n", 64 | " slice_idx = i*step\n", 65 | " s1, sh1, s2, sh2, _ = dataset.get_tempogram_slices(\n", 66 | " T=T, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0, slice_idx=slice_idx\n", 67 | " )\n", 68 | " \n", 69 | " s1 = s1[np.newaxis, :]\n", 70 | "\n", 71 | " xhat1, xhat2, y1, y2 = model.predict([s1, s1, sh1, sh1], verbose=0)\n", 72 | " preds[i] = y1[0][0]\n", 73 | "\n", 74 | " bpm_dict[bpm][\"slice\"] = s1[0,:,0]\n", 75 | " bpm_dict[bpm][\"shift\"] = sh1\n", 76 | " bpm_dict[bpm][\"estimation\"] = xhat1[0,:,0]\n", 77 | " bpm_dict[bpm][\"predictions\"] = np.array(preds)\n", 78 | "\n", 79 | " model_output[j] = np.median(np.array(preds))\n", 80 | " j += 1\n", 81 | "\n", 82 | " quad = np.poly1d(np.polyfit(model_output, list(bpm_dict.keys()), 2))\n", 83 | " a, b = utils.get_slope(model_output, list(bpm_dict.keys()))\n", 84 | "\n", 85 | " return bpm_dict, a, b, quad\n", 86 | " \n", 87 | "def read_dataset_info(main_file):\n", 88 | " dataset_metadata = os.path.join(\"/home/gigibs/Documents/steme/data\", f\"{main_file}_metadata.h5\")\n", 89 | " print(f\"Reading metadata file {dataset_metadata}\")\n", 90 | " response = {}\n", 91 | "\n", 92 | " with h5py.File(dataset_metadata, \"r\") as hf:\n", 93 | " response[\"main_file\"] = hf.get(\"main_file\")[()].decode(\"UTF-8\")\n", 94 | " response[\"validation_file\"] = hf.get(\"validation_file\")[()].decode(\"UTF-8\")\n", 95 | " response[\"train_file\"] = hf.get(\"train_file\")[()].decode(\"UTF-8\")\n", 96 | " response[\"main_filepath\"] = hf.get(\"main_filepath\")[()].decode(\"UTF-8\")\n", 97 | " response[\"validation_filepath\"] = hf.get(\"validation_filepath\")[()].decode(\"UTF-8\")\n", 98 | " response[\"train_filepath\"] = hf.get(\"train_filepath\")[()].decode(\"UTF-8\")\n", 99 | " response[\"distribution\"] = hf.get(\"distribution\")[:]\n", 100 | " response[\"validation_setsize\"] = hf.get(\"validation_setsize\")[()]\n", 101 | " response[\"train_setsize\"] = hf.get(\"train_setsize\")[()]\n", 102 | " response[\"tmin\"] = hf.get(\"tmin\")[()]\n", 103 | " response[\"tmax\"] = hf.get(\"tmax\")[()]\n", 104 | "\n", 105 | " return response\n", 106 | "\n", 107 | "def default_variables():\n", 108 | " return {\n", 109 | " \"tmin\": 25,\n", 110 | " \"n_bins\": 190,\n", 111 | " \"bins_per_octave\": 40,\n", 112 | " \"kmin\": 11, 
\n", 113 | " \"kmax\": 19\n", 114 | " }\n", 115 | "\n", 116 | "def wider_tempi_variables():\n", 117 | " return {\n", 118 | " \"tmin\": 20,\n", 119 | " \"n_bins\": 190,\n", 120 | " \"bins_per_octave\": 30,\n", 121 | " \"kmin\": 0, \n", 122 | " \"kmax\": 8\n", 123 | " }\n", 124 | "\n", 125 | "def calibration_results(dists, t_types):\n", 126 | " results_dict = {}\n", 127 | " for dist_name in dists:\n", 128 | " results_dict[dist_name] = {}\n", 129 | " for t_type in t_types:\n", 130 | " print(dist_name, t_type)\n", 131 | " dataset_name = f\"{dist_name}_{t_type}\"\n", 132 | "\n", 133 | " response = read_dataset_info(dataset_name)\n", 134 | " distribution = response[\"distribution\"]\n", 135 | "\n", 136 | " results_dict[dist_name][t_type] = {}\n", 137 | "\n", 138 | " model_name = f\"{dataset_name}_15_default\"\n", 139 | " model_path = f\"../models/{model_name}\"\n", 140 | "\n", 141 | " model = tf.keras.models.load_model(model_path)\n", 142 | "\n", 143 | " for idx, val in enumerate(center):\n", 144 | " results_dict[dist_name][t_type][val] = {}\n", 145 | " sr = 22050\n", 146 | " preds = np.zeros(n_predictions*tracks_per_bin)\n", 147 | "\n", 148 | " j = 0\n", 149 | "\n", 150 | " for bpm in center_dict[val]:\n", 151 | " x = audio.click_track(bpm=bpm, sr=sr)\n", 152 | " T, t, bpms = audio.tempogram(x, sr, window_size_seconds=10, t_type=t_type, theta=theta)\n", 153 | "\n", 154 | " step = T.shape[1]//n_predictions\n", 155 | "\n", 156 | " for i in range(n_predictions):\n", 157 | " slice_idx = i*step\n", 158 | " s1, sh1, s2, sh2, _ = dataset.get_tempogram_slices(\n", 159 | " T=T, kmin=kmin, kmax=kmax, shift_1=0, shift_2=0, slice_idx=slice_idx\n", 160 | " )\n", 161 | "\n", 162 | " s1 = s1[np.newaxis, :]\n", 163 | "\n", 164 | " xhat1, xhat2, y1, y2 = model.predict([s1, s1, sh1, sh1], verbose=0)\n", 165 | " preds[j] = y1[0][0]\n", 166 | " j += 1\n", 167 | " results_dict[dist_name][t_type][val][\"predictions\"] = np.array(preds)\n", 168 | "# results_dict[dist_name][t_type][val][\"tracks\"] = bpm_tracks\n", 169 | "\n", 170 | " del model\n", 171 | " return results_dict" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "ecf00f51", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "id": "06b16a62", 185 | "metadata": {}, 186 | "source": [ 187 | "## Calibrate model with synthetic data\n" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "id": "6fd0de06", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "variables = default_variables()\n", 198 | "tmin = variables[\"tmin\"]\n", 199 | "n_bins = variables[\"n_bins\"]\n", 200 | "bins_per_octave = variables[\"bins_per_octave\"]\n", 201 | "kmin, kmax = variables[\"kmin\"], variables[\"kmax\"]\n", 202 | "theta = dataset.variables_non_linear(tmin, n_bins=n_bins, bins_per_octave=bins_per_octave)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "8a06b203", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "step = 8\n", 213 | "offset = 5\n", 214 | "left = theta[(theta > 30) & (theta < 350)][::step]\n", 215 | "center = theta[(theta > 30) & (theta < 350)][offset::step]\n", 216 | "right = theta[(theta > 30) & (theta < 350)][offset::step]\n", 217 | "\n", 218 | "bins_tmp = []\n", 219 | "for i, j, k in zip(left, center, right):\n", 220 | " print(f\"boundaries for {np.round(j,2)}: [{np.round(np.sqrt(i*j),2)}, {np.round(np.sqrt(j*k))}]\")\n", 221 | " 
bins_tmp.append(i)\n", 222 | " bins_tmp.append(j)\n", 223 | "\n", 224 | "# len(bins_tmp)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "f0e9e21e", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "dists = [\n", 235 | " \"gtzan_augmented_log_25_190_40\"\n", 236 | "# \"log_uniform_25_190_40\",\n", 237 | "# \"synthetic_lognorm_0.7_30_50_1000_25_190_40\", \n", 238 | "# \"synthetic_lognorm_0.7_70_50_1000_25_190_40\",\n", 239 | "# \"synthetic_lognorm_0.7_120_50_1000_25_190_40\", \n", 240 | "# \"gtzan_25_190_40\",\n", 241 | "]\n", 242 | "\n", 243 | "t_types = [\"fourier\", \"autocorrelation\", \"hybrid\"]" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "id": "2e427437", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "n_predictions = 2\n", 254 | "tracks_per_bin = 50\n", 255 | "\n", 256 | "center_dict = {}\n", 257 | "for idx, val in enumerate(center):\n", 258 | " left_boundary = np.sqrt(left[idx]*center[idx])\n", 259 | " right_boundary = np.sqrt(center[idx]*right[idx])\n", 260 | " \n", 261 | " center_dict[val] = np.random.uniform(left_boundary, right_boundary, size=tracks_per_bin)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "06af1cfb", 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "calculate_results = False\n", 272 | "calculate_center_bins = False\n", 273 | "\n", 274 | "try:\n", 275 | " with open(\"results_dict_aug.pkl\", \"rb\") as pickle_file:\n", 276 | " results_dict = pickle.load(pickle_file)\n", 277 | "except:\n", 278 | " calculate_results = True \n", 279 | " \n", 280 | "try:\n", 281 | " with open('center_dict_aug.pkl', 'rb') as f:\n", 282 | " center_dict = pickle.load(f)\n", 283 | "except:\n", 284 | " calculate_center_bins = True" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "id": "7eef0627", 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "calculate_results, calculate_center_bins" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": null, 300 | "id": "3ebc4ec1", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "if calculate_results:\n", 305 | " results_dict = calibration_results(dists, t_types)\n", 306 | " \n", 307 | " with open('results_dict_aug.pkl', 'wb') as f:\n", 308 | " pickle.dump(results_dict, f)\n", 309 | " \n", 310 | " with open('center_dict_aug.pkl', 'wb') as f:\n", 311 | " pickle.dump(center_dict, f)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "id": "02c13011", 318 | "metadata": { 319 | "scrolled": true 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "# for dist_name in dists:\n", 324 | "# for t_type in t_types:\n", 325 | "# model_name = f\"{dist_name}_{t_type}\"\n", 326 | "# res = results_dict[dist_name][t_type]\n", 327 | "# predictions = [v[\"predictions\"] for k, v in res.items()]\n", 328 | "# plot_calibration(predictions, np.round(center, 2), model_name)" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "id": "067d22b7", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "def plot_calibration_ax(predictions, bins, model_name, ax):\n", 339 | " ax.boxplot(predictions, vert=1)\n", 340 | " ax.xaxis.set_major_locator(ticker.FixedLocator(\n", 341 | " np.arange(1, len(bins)+1, 3)\n", 342 | " ))\n", 343 | "# ax.xaxis.set_xticks()\n", 344 | " 
ax.xaxis.set_major_formatter(ticker.FixedFormatter(bins[::3]))\n", 345 | " ax.plot()\n", 346 | " ax.grid(True, alpha=0.8)\n", 347 | "# ax.set_ylim(0, 1)\n", 348 | "# ax.title.set_text(model_name)#f\"Prediction with fixed shift. a = {np.round(a_fixed, 2)}, b = {np.round(b_fixed, 2)}\")\n", 349 | "\n", 350 | "# fig.suptitle(model_name, fontsize=16)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "id": "d58837dd", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "model_dict = {\n", 361 | "# \"synthetic_lognorm_0.7_30_50_1000_25_190_40\": \"lognorm @ 70\", \n", 362 | "# \"synthetic_lognorm_0.7_70_50_1000_25_190_40\": \"lognorm @ 120\",\n", 363 | "# \"synthetic_lognorm_0.7_120_50_1000_25_190_40\": \"lognorm @ 170\",\n", 364 | "# \"gtzan_25_190_40\": \"GTZAN\",\n", 365 | "# \"log_uniform_25_190_40\": \"log uniform\",\n", 366 | " \"gtzan_augmented_log_25_190_40\": \"gtzan_augmented\"\n", 367 | "}" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "id": "258f12b6", 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "plt.rc(\"axes\", labelsize=18)\n", 378 | "plt.rc(\"xtick\", labelsize=15)\n", 379 | "plt.rc(\"ytick\", labelsize=15)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "id": "f3b5aa67", 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "import matplotlib\n", 390 | "from scipy.stats import lognorm, uniform\n", 391 | "\n", 392 | "def gtzan_data():\n", 393 | " import mirdata\n", 394 | " gtzan = mirdata.initialize(\"gtzan_genre\",\n", 395 | " data_home=\"../../datasets/gtzan_genre\",\n", 396 | " version=\"default\")\n", 397 | " tracks = gtzan.track_ids\n", 398 | " tracks.remove(\"reggae.00086\")\n", 399 | " tempi = [gtzan.track(track_id).tempo for track_id in tracks]\n", 400 | "\n", 401 | " return gtzan, tracks, tempi\n", 402 | "\n", 403 | "theta = dataset.variables_non_linear(25, 40, 190)\n", 404 | "bins = theta[(theta > 30) & (theta < 370)][::2]\n", 405 | "cmap = matplotlib.cm.get_cmap('tab10')\n", 406 | "dist_low = lognorm.rvs(0.25, loc=30, scale=50, size=1000, random_state=42)\n", 407 | "dist_medium = lognorm.rvs(0.25, loc=70, scale=50, size=1000, random_state=42)\n", 408 | "dist_high = lognorm.rvs(0.25, loc=120, scale=50, size=1000, random_state=42)\n", 409 | "dist_uniform = uniform.rvs(30, scale=210,size=1000, random_state=42)\n", 410 | "dist_log_uniform = 30*np.e**(np.random.rand(1000)*np.log(240/30))\n", 411 | "_, _, dist_gtzan = gtzan_data()\n", 412 | "dist_gtzan = np.array(dist_gtzan)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "id": "54ff536a", 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [ 422 | "# add_subplot example\n", 423 | "# https://towardsdatascience.com/customizing-multiple-subplots-in-matplotlib-a3e1c2e099bc\n", 424 | "# https://python-course.eu/numerical-programming/creating-subplots-in-matplotlib.php\n", 425 | "\n", 426 | "fig = plt.figure(figsize=(20, 10))\n", 427 | "plt.subplots_adjust(wspace= 0.25, hspace= 0.25)\n", 428 | "\n", 429 | "kwargs = {\n", 430 | " \"alpha\": 0.7,\n", 431 | " \"histtype\": \"stepfilled\"\n", 432 | "}\n", 433 | "\n", 434 | "dist_bins = np.round(bins)\n", 435 | "\n", 436 | "p0 = fig.add_subplot(4, 5, 1)\n", 437 | "p0.hist(dist_log_uniform, bins=dist_bins, label=\"log uniform\", edgecolor=\"black\", color=cmap.colors[3],**kwargs)\n", 438 | "p0.set_xscale(\"log\")\n", 439 | "p0.set_xticks([], [])\n", 440 | 
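"# set_xticks([], []) above clears the automatically generated log-scale\n",
"# ticks; the next two lines pin BPM labels at every third bin centre, with\n",
"# ScalarFormatter rendering them as plain numbers instead of powers of ten\n",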
"p0.set_xticks(center.astype(int)[::3])\n", 441 | "p0.xaxis.set_major_formatter(ticker.ScalarFormatter())\n", 442 | "p0.set_xlim(28, 360)\n", 443 | "p0.set_ylim(0, 200)\n", 444 | "p0.set_ylabel(\"Distribution\")\n", 445 | "p0.set_title(\"log uniform\", fontsize=18)\n", 446 | "\n", 447 | "p2 = fig.add_subplot(4, 5, 2, sharex=p0, sharey=p0)\n", 448 | "p2.hist(dist_low, bins=dist_bins, label=\"lognorm @ 70\", edgecolor=\"black\", color=cmap.colors[0], **kwargs)\n", 449 | "p2.set_title(\"lognorm @ 70\", fontsize=18)\n", 450 | "\n", 451 | "p3 = fig.add_subplot(4, 5, 3, sharex=p0, sharey=p0)\n", 452 | "p3.hist(dist_medium, bins=dist_bins, label=\"lognorm @ 120\", edgecolor=\"black\",color=cmap.colors[2], **kwargs)\n", 453 | "p3.set_title(\"lognorm @ 120\", fontsize=18)\n", 454 | "\n", 455 | "p4 = fig.add_subplot(4, 5, 4, sharex=p0, sharey=p0)\n", 456 | "p4.hist(dist_high, bins=dist_bins, label=\"lognorm @ 170\", edgecolor=\"black\",color=cmap.colors[4], **kwargs)\n", 457 | "p4.set_title(\"lognorm @ 170\", fontsize=18)\n", 458 | "\n", 459 | "p5 = fig.add_subplot(4, 5, 5, sharex=p0, sharey=p0)\n", 460 | "p5.hist(dist_gtzan, bins=dist_bins, label=\"GTZAN\", edgecolor=\"black\", color=cmap.colors[8], **kwargs)\n", 461 | "p5.set_title(\"GTZAN\", fontsize=18)\n", 462 | "\n", 463 | "calibration_bins = center[::2].astype(int)\n", 464 | "plot_index = 6\n", 465 | "for t_type in t_types:\n", 466 | " for dist_name in dists:\n", 467 | " model_name = model_dict[dist_name]\n", 468 | " res = results_dict[dist_name][t_type]\n", 469 | " predictions = [v[\"predictions\"] for k, v in res.items()]\n", 470 | " \n", 471 | " p = fig.add_subplot(4, 5, plot_index)\n", 472 | " plot_calibration_ax(predictions, center.astype(int), model_name, p)\n", 473 | " \n", 474 | " if plot_index > 15:\n", 475 | " p.set_xlabel(\"BPM\")\n", 476 | " if plot_index == 6:\n", 477 | " p.set_ylabel(\"Fourier\")\n", 478 | " if plot_index == 11:\n", 479 | " p.set_ylabel(\"Autocorrelation\")\n", 480 | " if plot_index == 16:\n", 481 | " p.set_ylabel(\"Hybrid\") \n", 482 | " \n", 483 | " plot_index += 1\n", 484 | "\n", 485 | "#plt.savefig(\"calibration_with_dists.png\", format=\"png\")" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "id": "9bc0dc37", 492 | "metadata": {}, 493 | "outputs": [], 494 | "source": [ 495 | "fig = plt.figure(figsize=(20, 10))\n", 496 | "plot_index = 1\n", 497 | "for t_type in t_types:\n", 498 | " for dist_name in dists:\n", 499 | " model_name = model_dict[dist_name]\n", 500 | " res = results_dict[dist_name][t_type]\n", 501 | " predictions = [v[\"predictions\"] for k, v in res.items()]\n", 502 | " \n", 503 | " p = fig.add_subplot(3, 1, plot_index)\n", 504 | " p.set_title(f\"{model_name}_{t_type}\", fontsize=18)\n", 505 | "\n", 506 | " \n", 507 | " plot_calibration_ax(predictions, center.astype(int), model_name, p)\n", 508 | " plot_index += 1\n", 509 | " \n", 510 | "plt.tight_layout()" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "id": "efee07e9", 517 | "metadata": {}, 518 | "outputs": [], 519 | "source": [] 520 | } 521 | ], 522 | "metadata": { 523 | "kernelspec": { 524 | "display_name": "Python 3 (ipykernel)", 525 | "language": "python", 526 | "name": "python3" 527 | }, 528 | "language_info": { 529 | "codemirror_mode": { 530 | "name": "ipython", 531 | "version": 3 532 | }, 533 | "file_extension": ".py", 534 | "mimetype": "text/x-python", 535 | "name": "python", 536 | "nbconvert_exporter": "python", 537 | "pygments_lexer": "ipython3", 538 | 
"version": "3.10.6" 539 | } 540 | }, 541 | "nbformat": 4, 542 | "nbformat_minor": 5 543 | } 544 | -------------------------------------------------------------------------------- /notebooks/tempogram_types.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d156313a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import librosa \n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "\n", 14 | "import steme as st\n", 15 | "\n", 16 | "import IPython.display as ipd" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "37c72f1d", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "librosa.util.list_examples()" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": null, 32 | "id": "dccc3a8b", 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "# x, fs = librosa.load(librosa.example(\"brahms\"))\n", 37 | "# you can explore further examples if you want to:\n", 38 | "x, fs = librosa.load(librosa.example(\"choice\"))" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "b2e1caf1", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "ipd.Audio(x, rate=fs)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "409ff09d", 54 | "metadata": {}, 55 | "source": [ 56 | "# linear axis versus log axis" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "afb1b77b", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "linear_axis = np.arange(30,350,1)\n", 67 | "log_axis = st.dataset.variables_non_linear()\n", 68 | "log_axis = log_axis[log_axis < 350]" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "6d462512", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "def plot_comparison(T, t, freqs, ttypes, subplot_titles, fig_title=None):\n", 79 | " \"\"\"\n", 80 | " helper function to plot tempograms side-by-side.\n", 81 | " \"\"\"\n", 82 | " figsize = (15, 5)\n", 83 | " num_tempograms = len(T)\n", 84 | " fig, ax = plt.subplots(1, num_tempograms, figsize=figsize)\n", 85 | "\n", 86 | " for idx in range(num_tempograms):\n", 87 | " kwargs = st.utils._tempogram_kwargs(t[idx], freqs[idx])\n", 88 | "\n", 89 | " ax[idx].imshow(T[idx], **kwargs)\n", 90 | "\n", 91 | " xlim = (t[idx][0], t[idx][-1])\n", 92 | " ylim = (freqs[idx][0], freqs[idx][-1])\n", 93 | "\n", 94 | " plt.setp(ax, xlim=xlim, ylim=ylim)\n", 95 | " \n", 96 | " if ttypes[idx] == \"log\":\n", 97 | " labels = [item.get_text() for item in ax[0].get_yticklabels()]\n", 98 | " new_labels = np.rint(log_axis[::20]).astype(int)\n", 99 | " ax[idx].set_yticklabels(new_labels)\n", 100 | "\n", 101 | " if fig_title is not None:\n", 102 | " fig.suptitle(fig_title, fontsize=16)\n", 103 | "\n", 104 | " ax[idx].set_xlabel(\"Time (s)\")\n", 105 | " ax[idx].set_ylabel(\"Tempo (BPM)\")\n", 106 | " ax[idx].title.set_text(subplot_titles[idx])\n", 107 | " return fig, ax" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "5f36d1b4", 113 | "metadata": {}, 114 | "source": [ 115 | "# tmp" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "1eed8093", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "figsize = (15, 5)\n", 126 | "fig, ax = plt.subplots(1, 1, figsize=figsize)\n", 127 | "\n", 128 | 
"tempo_estimation = librosa.feature.tempo(y=x, sr=fs)\n", 129 | "\n", 130 | "kwargs = st.utils._tempogram_kwargs(linear_ft, linear_ffreqs)\n", 131 | "ax.imshow(linear_fT, **kwargs)\n", 132 | "ax.hlines(tempo_estimation[0], xmin=0, xmax=1000, color=\"red\", label=\"estimated time\")\n", 133 | "\n", 134 | "xlim = (linear_ft[0], linear_ft[-1])\n", 135 | "ylim = (linear_ffreqs[0], linear_ffreqs[-1])\n", 136 | "\n", 137 | "plt.setp(ax, xlim=xlim, ylim=ylim)\n", 138 | "\n", 139 | "fig.suptitle(f\"Fourier Tempogram (~{np.round(tempo_estimation,2)[0]} BPM)\", fontsize=16)\n", 140 | "\n", 141 | "ax.set_xlabel(\"Time (s)\")\n", 142 | "ax.set_ylabel(\"Tempo (BPM)\")" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "f7ca1eb4", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "from matplotlib.patches import Rectangle" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "851454b8", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "full_slice_data = linear_fT[:,805]\n", 163 | "plt.plot(linear_ffreqs, full_slice_data)\n", 164 | "plt.vlines(tempo_estimation, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 165 | "plt.title(\"full slice (320-dimensional array)\")\n", 166 | "plt.xlabel(\"BPM\")\n", 167 | "plt.xlim(30, 350)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "59fc3ef8", 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "fig, ax = plt.subplots(1,1)\n", 178 | "plt.plot(linear_ffreqs, full_slice_data)\n", 179 | "# ax.vlines(tempo_estimation*2, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 180 | "# ax.vlines(tempo_estimation, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 181 | "ax.vlines(170, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 182 | "\n", 183 | "# ax.title.set_text(\"full slice (320 samples)\")\n", 184 | "# plt.title(f\"How the model is supposed to get THIS tempo value right?\")\n", 185 | "ax.set_xlabel(\"BPM\")\n", 186 | "\n", 187 | "# plt.title(f\"What if the tempo was {np.round(tempo_estimation[0])*2} BPM?\")\n", 188 | "# plt.title(\"full slice (320-dimensional array)\")\n", 189 | "\n", 190 | "#add rectangle to plot\n", 191 | "ax.add_patch(Rectangle((30, 0), 128, 12, fill=False, color=\"green\"))\n", 192 | "\n", 193 | "# ax.add_patch(Rectangle((38, 0), 128, 12, fill=False, color=\"green\", linestyle=\"--\"))\n", 194 | "\n", 195 | "ax.add_patch(Rectangle((180, 0), 128, 12, fill=False, color=\"green\", linestyle=\"--\"))" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "ebbf9a22", 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "plt.plot(full_slice_data[0:128])\n", 206 | "# plt.plot(linear_ffreqs[0:128], full_slice_data[0:128])\n", 207 | "\n", 208 | "# plt.vlines(tempo_estimation, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 209 | "plt.title(\"0-shift slice (128 samples), covering from 30 BPM to 158 BPM\")\n", 210 | "# plt.xlabel(\"BPM\")\n", 211 | "plt.xlim(0,128)\n", 212 | "# plt.xlim(linear_ffreqs[0],linear_ffreqs[128+0])" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "6b9d45a5", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "# plt.plot(linear_ffreqs[8:128+8], full_slice_data[8:128+8])\n", 223 | "plt.plot(full_slice_data[8:128+8])\n", 224 | "# plt.vlines(tempo_estimation, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 225 | "plt.title(\"8-shift slice (128 
samples), covering from 38 BPM to 162 BPM\")\n", 226 | "# plt.xlabel(\"BPM\")\n", 227 | "plt.xlim(0,128)\n", 228 | "# plt.xlim(linear_ffreqs[8],linear_ffreqs[128+8])" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "c92ec4f3", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "figsize = (15, 5)\n", 239 | "fig, ax = plt.subplots(1, 1, figsize=figsize)\n", 240 | "\n", 241 | "tempo_estimation = librosa.feature.tempo(y=x, sr=fs)\n", 242 | "\n", 243 | "kwargs = st.utils._tempogram_kwargs(ft, ffreqs)\n", 244 | "ax.imshow(fT, **kwargs)\n", 245 | "ax.hlines(tempo_estimation[0], xmin=0, xmax=1000, color=\"red\", label=\"estimated time\")\n", 246 | "\n", 247 | "xlim = (ft[0], ft[-1])\n", 248 | "ylim = (ffreqs[0], ffreqs[-1])\n", 249 | "\n", 250 | "plt.setp(ax, xlim=xlim, ylim=ylim)\n", 251 | "\n", 252 | "fig.suptitle(f\"Fourier Tempogram (~{np.round(tempo_estimation,2)[0]} BPM)\", fontsize=16)\n", 253 | "\n", 254 | "ax.set_xlabel(\"Time (s)\")\n", 255 | "ax.set_ylabel(\"Tempo (BPM)\")" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "id": "edf64220", 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "fT[:,805].shape" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "5718fe0d", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "# everything but now for log\n", 276 | "full_slice_data = fT[:,805]\n", 277 | "plt.plot(ffreqs, full_slice_data)\n", 278 | "plt.xscale(\"log\")\n", 279 | "plt.vlines(tempo_estimation, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 280 | "plt.title(\"full slice (153-dimensional array)\")\n", 281 | "plt.xlabel(\"BPM\")\n", 282 | "plt.xlim(30, 350)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "id": "53e445e3", 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "import matplotlib.ticker as ticker\n" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "id": "9e8443f6", 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "fig, ax = plt.subplots(1,1)\n", 303 | "plt.plot(ffreqs, full_slice_data)\n", 304 | "# ax.vlines(tempo_estimation*2, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 305 | "ax.vlines(tempo_estimation, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 306 | "# ax.vlines(170, ymin=0, ymax=12, color=\"red\", alpha=0.8)\n", 307 | "ax.set_xscale(\"log\")\n", 308 | "ax.set_xticks([], [])\n", 309 | "ax.set_xticks(ffreqs.astype(int)[::15])\n", 310 | "ax.xaxis.set_major_formatter(ticker.ScalarFormatter())\n", 311 | "\n", 312 | "# ax.title.set_text(\"full slice (320 samples)\")\n", 313 | "# plt.title(f\"How the model is supposed to get THIS tempo value right?\")\n", 314 | "ax.set_xlabel(\"BPM\")\n", 315 | "\n", 316 | "# plt.title(f\"What if the tempo was {np.round(tempo_estimation[0])*2} BPM?\")\n", 317 | "plt.title(\"full slice (153-dimensional array)\")\n", 318 | "\n", 319 | "#add rectangle to plot\n", 320 | "ax.add_patch(Rectangle((30, 0), ffreqs[11:128+11][-1]-ffreqs[11:128+11][0], 12, fill=False, color=\"green\"))\n", 321 | "\n", 322 | "ax.add_patch(Rectangle((38, 0), ffreqs[18:128+18][-1]-ffreqs[18:128+18][0], 12, fill=False, color=\"green\", linestyle=\"--\"))\n", 323 | "\n", 324 | "# ax.add_patch(Rectangle((180, 0), 128, 12, fill=False, color=\"green\", linestyle=\"--\"))" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": null, 330 | "id": "c30778b2", 
331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "ffreqs[18:128+18]" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": null, 340 | "id": "e7fb6dc9", 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "plt.plot(full_slice_data[11:128+11])\n", 345 | "plt.title(\"11-shift slice (128 samples), covering from 30 BPM to 273 BPM\")\n", 346 | "plt.xlim(0,128)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "id": "ad12df9a", 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "plt.plot(full_slice_data[18:128+18])\n", 357 | "plt.title(\"18-shift slice (128 samples), covering from 34 BPM to 308 BPM\")\n", 358 | "plt.xlim(0,128)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "id": "b25b8247", 364 | "metadata": {}, 365 | "source": [ 366 | "# end tmp" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "id": "597a61b1", 372 | "metadata": {}, 373 | "source": [ 374 | "# tmp 2" 375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "id": "c9ca5ce8", 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "figsize = (15, 5)\n", 385 | "fig, ax = plt.subplots(1, 1, figsize=figsize)\n", 386 | "\n", 387 | "increasing_tempo = np.array([])\n", 388 | "sr = 22500\n", 389 | "for i in range(1,8):\n", 390 | " tmp = st.audio.click_track(bpm=50*i, sr=sr, duration=10)\n", 391 | " increasing_tempo = np.append(increasing_tempo, tmp)\n", 392 | " \n", 393 | "x_fT, x_ft, x_ffreqs = st.audio.tempogram(x=increasing_tempo, sr=sr, window_size_seconds=10, t_type=\"hybrid\", theta=linear_axis)\n", 394 | "\n", 395 | " \n", 396 | "kwargs = st.utils._tempogram_kwargs(x_ft, x_ffreqs)\n", 397 | "ax.imshow(x_fT, **kwargs)\n", 398 | "\n", 399 | "xlim = (x_ft[0], x_ft[-1])\n", 400 | "ylim = (x_ffreqs[0], x_ffreqs[-1])\n", 401 | "\n", 402 | "plt.setp(ax, xlim=xlim, ylim=ylim)\n", 403 | "\n", 404 | "fig.suptitle(f\"Hybrid Tempogram\", fontsize=16)\n", 405 | "\n", 406 | "ax.set_xlabel(\"Time (s)\")\n", 407 | "ax.set_ylabel(\"Tempo (BPM)\")" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "id": "2f468ec7", 413 | "metadata": {}, 414 | "source": [ 415 | "# end tmp 2" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "id": "0bddd749", 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "linear_fT, linear_ft, linear_ffreqs = st.audio.tempogram(x=x, sr=fs, window_size_seconds=10, t_type=\"fourier\", theta=linear_axis)\n", 426 | "fT, ft, ffreqs = st.audio.tempogram(x=x, sr=fs, window_size_seconds=10, t_type=\"fourier\", theta=log_axis)\n", 427 | "\n", 428 | "fig, ax = plot_comparison(\n", 429 | " T=[linear_fT, fT], \n", 430 | " t=[linear_ft, ft], \n", 431 | " freqs=[linear_ffreqs, ffreqs], \n", 432 | " subplot_titles=[\"linear axis\", \"logarithmic axis\"],\n", 433 | " ttypes=[\"linear\", \"log\"],\n", 434 | " fig_title=\"Fourier tempogram\"\n", 435 | ")" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "id": "2d467719", 442 | "metadata": {}, 443 | "outputs": [], 444 | "source": [ 445 | "linear_aT, linear_at, linear_afreqs = st.audio.tempogram(x=x, sr=fs, window_size_seconds=10, t_type=\"autocorrelation\", theta=linear_axis)\n", 446 | "aT, at, afreqs = st.audio.tempogram(x=x, sr=fs, window_size_seconds=10, t_type=\"autocorrelation\", theta=log_axis)\n", 447 | "\n", 448 | "fig, ax = plot_comparison(\n", 449 | " T=[linear_aT, aT], \n", 450 | " t=[linear_at, 
at], \n", 451 | " freqs=[linear_afreqs, afreqs], \n", 452 | " subplot_titles=[\"linear axis\", \"logarithmic axis\"],\n", 453 | " ttypes=[\"linear\", \"log\"],\n", 454 | " fig_title=\"Autocorrelation tempogram\"\n", 455 | ")" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "id": "3804eb40", 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [ 465 | "linear_hT, linear_ht, linear_hfreqs = st.audio.tempogram(x=x, sr=fs, window_size_seconds=10, t_type=\"autocorrelation\", theta=linear_axis)\n", 466 | "hT, ht, hfreqs = st.audio.tempogram(x=x, sr=fs, window_size_seconds=10, t_type=\"hybrid\", theta=log_axis)\n", 467 | "\n", 468 | "fig, ax = plot_comparison(\n", 469 | " T=[linear_hT, hT], \n", 470 | " t=[linear_ht, ht], \n", 471 | " freqs=[linear_hfreqs, hfreqs], \n", 472 | " subplot_titles=[\"linear axis\", \"logarithmic axis\"],\n", 473 | " ttypes=[\"linear\", \"log\"],\n", 474 | " fig_title=\"Hybrid tempogram\"\n", 475 | ")" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "id": "8e0cdfb5", 481 | "metadata": {}, 482 | "source": [ 483 | "# same representation for steady and changing tempo" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "id": "9c8009d3", 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [ 493 | "x, x_fs = librosa.load(librosa.example(\"brahms\"))\n", 494 | "y, y_fs = librosa.load(librosa.example(\"choice\"))" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": null, 500 | "id": "2030d46f", 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "x_fT, x_ft, x_ffreqs = st.audio.tempogram(x=x, sr=x_fs, window_size_seconds=10, t_type=\"fourier\", theta=log_axis)\n", 505 | "y_fT, y_ft, y_ffreqs = st.audio.tempogram(x=y, sr=y_fs, window_size_seconds=10, t_type=\"fourier\", theta=log_axis)\n", 506 | "\n", 507 | "fig, ax = plot_comparison(\n", 508 | " T=[x_fT, y_fT], \n", 509 | " t=[x_ft, y_ft], \n", 510 | " freqs=[x_ffreqs, y_ffreqs], \n", 511 | " ttypes=[\"log\", \"log\"],\n", 512 | " subplot_titles=[\"changing tempo (brahms)\", \"steady tempo (choice)\"],\n", 513 | " fig_title=\"Fourier tempogram\"\n", 514 | ")" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "id": "5e392dfd", 521 | "metadata": {}, 522 | "outputs": [], 523 | "source": [ 524 | "x_fT, x_ft, x_ffreqs = st.audio.tempogram(x=x, sr=x_fs, window_size_seconds=10, t_type=\"autocorrelation\", theta=log_axis)\n", 525 | "y_fT, y_ft, y_ffreqs = st.audio.tempogram(x=y, sr=y_fs, window_size_seconds=10, t_type=\"autocorrelation\", theta=log_axis)\n", 526 | "\n", 527 | "fig, ax = plot_comparison(\n", 528 | " T=[x_fT, y_fT], \n", 529 | " t=[x_ft, y_ft], \n", 530 | " freqs=[x_ffreqs, y_ffreqs], \n", 531 | " ttypes=[\"log\", \"log\"],\n", 532 | " subplot_titles=[\"changing tempo (brahms)\", \"steady tempo (choice)\"],\n", 533 | " fig_title=\"Autocorrelation tempogram\"\n", 534 | ")" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "id": "e962ba6b", 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "x_fT, x_ft, x_ffreqs = st.audio.tempogram(x=x, sr=x_fs, window_size_seconds=10, t_type=\"hybrid\", theta=log_axis)\n", 545 | "y_fT, y_ft, y_ffreqs = st.audio.tempogram(x=y, sr=y_fs, window_size_seconds=10, t_type=\"hybrid\", theta=log_axis)\n", 546 | "\n", 547 | "fig, ax = plot_comparison(\n", 548 | " T=[x_fT, y_fT], \n", 549 | " t=[x_ft, y_ft], \n", 550 | " freqs=[x_ffreqs, y_ffreqs],\n", 551 | " 
ttypes=[\"log\", \"log\"],\n", 552 | " subplot_titles=[\"changing tempo (brahms)\", \"steady tempo (choice)\"],\n", 553 | " fig_title=\"Hybrid tempogram\"\n", 554 | ")\n" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": null, 560 | "id": "d9fe7dd3", 561 | "metadata": {}, 562 | "outputs": [], 563 | "source": [ 564 | "increasing_tempo = np.array([])\n", 565 | "sr = 22500\n", 566 | "for i in range(1,8):\n", 567 | " tmp = st.audio.click_track(bpm=50*i, sr=sr, duration=10)\n", 568 | " increasing_tempo = np.append(increasing_tempo, tmp)\n", 569 | " \n", 570 | "x_fT, x_ft, x_ffreqs = st.audio.tempogram(x=increasing_tempo, sr=sr, window_size_seconds=10, t_type=\"fourier\", theta=linear_axis)\n", 571 | "y_fT, y_ft, y_ffreqs = st.audio.tempogram(x=increasing_tempo, sr=sr, window_size_seconds=10, t_type=\"fourier\", theta=log_axis)\n", 572 | "\n", 573 | "fig, ax = plot_comparison(\n", 574 | " T=[x_fT, y_fT], \n", 575 | " t=[x_ft, y_ft], \n", 576 | " freqs=[x_ffreqs, y_ffreqs], \n", 577 | " subplot_titles=[\"linear\", \"log\"],\n", 578 | " ttypes=[\"linear\", \"log\"],\n", 579 | " fig_title=\"fourier tempogram\"\n", 580 | ")\n" 581 | ] 582 | }, 583 | { 584 | "cell_type": "markdown", 585 | "id": "5207ef70", 586 | "metadata": {}, 587 | "source": [ 588 | "# Interactive view" 589 | ] 590 | }, 591 | { 592 | "cell_type": "code", 593 | "execution_count": null, 594 | "id": "9f83abf1", 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [ 598 | "import holoviews as hv \n", 599 | "import panel as pn\n", 600 | "hv.extension(\"bokeh\", logo=False)" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": null, 606 | "id": "7f74e366", 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "increasing_tempo = np.array([]) #np.zeros([fs*7*3])\n", 611 | "sr = 22500\n", 612 | "for i in range(1,8):\n", 613 | " tmp = audio.click_track(bpm=50*i, sr=sr, duration=3)\n", 614 | " increasing_tempo = np.append(increasing_tempo, tmp)" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": null, 620 | "id": "7f7c66b9", 621 | "metadata": {}, 622 | "outputs": [], 623 | "source": [ 624 | "audio_data = np.int16(increasing_tempo * 32767)" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "id": "237b7a8a", 631 | "metadata": {}, 632 | "outputs": [], 633 | "source": [ 634 | "ipd.Audio(increasing_tempo, rate=sr)" 635 | ] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "id": "077b797d", 640 | "metadata": {}, 641 | "source": [ 642 | "## Fourier" 643 | ] 644 | }, 645 | { 646 | "cell_type": "code", 647 | "execution_count": null, 648 | "id": "3e760760", 649 | "metadata": {}, 650 | "outputs": [], 651 | "source": [ 652 | "fT, ft, ffreqs = audio.tempogram(x=increasing_tempo, sr=sr, window_size_seconds=10, t_type=\"fourier\", theta=log_axis)\n", 653 | "st.utils.plot_tempogram(fT, ft, ffreqs)" 654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": null, 659 | "id": "faf8ef60", 660 | "metadata": {}, 661 | "outputs": [], 662 | "source": [ 663 | "fT, ft, ffreqs = st.audio.tempogram(x=increasing_tempo, sr=sr, window_size_seconds=10, t_type=\"fourier\", theta=linear_axis)\n", 664 | "\n", 665 | "spec_gram = hv.Image((ft, ffreqs, fT), [\"Time (s)\", \"Tempo (BPM)\"]).opts(width=600)\n", 666 | "audio = pn.pane.Audio(audio_data, sample_rate=sr, name='Audio', throttle=500)\n", 667 | "\n", 668 | "def update_playhead(x,y,t):\n", 669 | " if x is None:\n", 670 | " return hv.VLine(t)\n", 671 | " else:\n", 672 
| " audio.time = x\n", 673 | " return hv.VLine(x)\n", 674 | "\n", 675 | "tap_stream = hv.streams.SingleTap(transient=True)\n", 676 | "time_play_stream = hv.streams.Params(parameters=[audio.param.time], rename={'time': 't'})\n", 677 | "dmap_time = hv.DynamicMap(update_playhead, streams=[time_play_stream, tap_stream])\n", 678 | "out = pn.Column( audio, \n", 679 | " (spec_gram * dmap_time))" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": null, 685 | "id": "bec52d19", 686 | "metadata": {}, 687 | "outputs": [], 688 | "source": [ 689 | "out" 690 | ] 691 | }, 692 | { 693 | "cell_type": "markdown", 694 | "id": "f4b57874", 695 | "metadata": {}, 696 | "source": [ 697 | "## Autocorrelation" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": null, 703 | "id": "569c9751", 704 | "metadata": {}, 705 | "outputs": [], 706 | "source": [ 707 | "spec_gram = hv.Image((at, afreqs, aT), [\"Time (s)\", \"Tempo (BPM)\"]).opts(width=600)\n", 708 | "audio = pn.pane.Audio(audio_data, sample_rate=fs, name='Audio', throttle=500)\n", 709 | "\n", 710 | "def update_playhead(x,y,t):\n", 711 | " if x is None:\n", 712 | " return hv.VLine(t)\n", 713 | " else:\n", 714 | " audio.time = x\n", 715 | " return hv.VLine(x)\n", 716 | "\n", 717 | "tap_stream = hv.streams.SingleTap(transient=True)\n", 718 | "time_play_stream = hv.streams.Params(parameters=[audio.param.time], rename={'time': 't'})\n", 719 | "dmap_time = hv.DynamicMap(update_playhead, streams=[time_play_stream, tap_stream])\n", 720 | "out = pn.Column( audio, \n", 721 | " (spec_gram * dmap_time))" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": null, 727 | "id": "22f9410d", 728 | "metadata": {}, 729 | "outputs": [], 730 | "source": [ 731 | "out" 732 | ] 733 | }, 734 | { 735 | "cell_type": "markdown", 736 | "id": "2edfc8cb", 737 | "metadata": {}, 738 | "source": [ 739 | "## Hybrid" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "id": "2a43a1fe", 746 | "metadata": {}, 747 | "outputs": [], 748 | "source": [ 749 | "spec_gram = hv.Image((ht, hfreqs, hT), [\"Time (s)\", \"Tempo (BPM)\"]).opts(width=600)\n", 750 | "audio = pn.pane.Audio(audio_data, sample_rate=fs, name='Audio', throttle=500)\n", 751 | "\n", 752 | "def update_playhead(x,y,t):\n", 753 | " if x is None:\n", 754 | " return hv.VLine(t)\n", 755 | " else:\n", 756 | " audio.time = x\n", 757 | " return hv.VLine(x)\n", 758 | "\n", 759 | "tap_stream = hv.streams.SingleTap(transient=True)\n", 760 | "time_play_stream = hv.streams.Params(parameters=[audio.param.time], rename={'time': 't'})\n", 761 | "dmap_time = hv.DynamicMap(update_playhead, streams=[time_play_stream, tap_stream])\n", 762 | "out = pn.Column( audio, \n", 763 | " (spec_gram * dmap_time))" 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "execution_count": null, 769 | "id": "176ff00b", 770 | "metadata": {}, 771 | "outputs": [], 772 | "source": [ 773 | "out" 774 | ] 775 | } 776 | ], 777 | "metadata": { 778 | "kernelspec": { 779 | "display_name": "Python 3 (ipykernel)", 780 | "language": "python", 781 | "name": "python3" 782 | }, 783 | "language_info": { 784 | "codemirror_mode": { 785 | "name": "ipython", 786 | "version": 3 787 | }, 788 | "file_extension": ".py", 789 | "mimetype": "text/x-python", 790 | "name": "python", 791 | "nbconvert_exporter": "python", 792 | "pygments_lexer": "ipython3", 793 | "version": "3.10.6" 794 | } 795 | }, 796 | "nbformat": 4, 797 | "nbformat_minor": 5 798 | } 799 | 
-------------------------------------------------------------------------------- /notebooks/data_augmentation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "f98a9c54", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import random\n", 12 | "from typing import List, Dict, Tuple\n", 13 | "\n", 14 | "import librosa\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "import matplotlib.ticker as ticker\n", 17 | "import numpy as np\n", 18 | "import pyrubberband as pyrb\n", 19 | "import soundfile as sf\n", 20 | "from scipy.stats import lognorm, uniform\n", 21 | "\n", 22 | "import steme.audio as audio\n", 23 | "import steme.dataset as dataset\n", 24 | "import steme.utils as utils" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "f438a30c", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import IPython.display as ipd" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "394d2bb8", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "DATASET_PATH = \"/home/gigibs/Documents/datasets/gtzan_augmented_log\"" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "f82ee15e", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "gtzan, tracks, tempi = dataset.gtzan_data()\n", 55 | "giant_steps, gs_tracks, gs_tempi = dataset.giant_steps_data()\n", 56 | "ballroom, b_tracks, b_tempi = dataset.ballroom_data()\n", 57 | "\n", 58 | "dist_low = dataset.lognormal70()\n", 59 | "\n", 60 | "theta = dataset.variables_non_linear(25, 40, 190)\n", 61 | "log_bins = theta[(theta > 30) & (theta < 370)][::2]\n", 62 | "# linear_bins = np.arange(30, 350, 10)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "ef7c95c2", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "bins = log_bins" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "c4792c3c", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "print(f\"gtzan size: {len(tracks)}\")\n", 83 | "print(f\"giant_steps size: {len(gs_tracks)}\")\n", 84 | "print(f\"ballroom size: {len(b_tracks)}\")\n", 85 | "print(f\"lognorm @ 70 size: {len(dist_low)}\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "c8a7b91d", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "fig, ax = plt.subplots(1,3, figsize=(10,5))\n", 96 | "\n", 97 | "ax[0].hist(tempi, bins=bins, color=\"orange\", alpha=0.7, label=\"gtzan\")\n", 98 | "ax[0].title.set_text(\"GTZAN (999 tracks)\")\n", 99 | "ax[0].set_xlabel(\"BPM\")\n", 100 | "ax[0].set_ylabel(\"# tracks\")\n", 101 | "ax[1].hist(gs_tempi, bins=bins, color=\"red\", alpha=0.7, label=\"giant_steps\")\n", 102 | "ax[1].title.set_text(\"Giant Steps (659 tracks)\")\n", 103 | "ax[1].set_xlabel(\"BPM\")\n", 104 | "ax[1].set_ylabel(\"# tracks\")\n", 105 | "ax[2].hist(b_tempi, bins=bins, color=\"blue\", alpha=0.7, label=\"ballroom\")\n", 106 | "ax[2].title.set_text(\"Ballroom (698 tracks)\")\n", 107 | "ax[2].set_xlabel(\"BPM\")\n", 108 | "ax[2].set_ylabel(\"# tracks\")\n", 109 | "\n", 110 | "#ax.hist(dist_low, bins=bins, color=\"green\", alpha=0.7, label=\"lognorm @ 70\")\n", 111 | "\n", 112 | "plt.tight_layout()\n", 113 | "# plt.savefig(\"datasets_tempo_distribution.svg\")" 114 | ] 115 | }, 116 | { 117 | 
"cell_type": "markdown", 118 | "id": "770ce0dc", 119 | "metadata": {}, 120 | "source": [ 121 | "## Augmenting GTZAN" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "id": "5144bdb6", 127 | "metadata": {}, 128 | "source": [ 129 | "# Approach 1: time streching only GTZAN" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "6d9bb79c", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "finer_bins = log_bins[::2]" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "id": "a0394063", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "plt.hist([tempi, dist_low], bins=finer_bins, color=[\"red\", \"orange\"], label=[\"gtzan\", \"lognorm@70\"])\n", 150 | "plt.legend()" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "a901884a", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "dist_low_hist = np.histogram(dist_low, bins=finer_bins)\n", 161 | "gtzan_dist = np.histogram(tempi, bins=finer_bins)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "id": "f71cd59e", 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "diff_tempi = dist_low_hist[0] - gtzan_dist[0]" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "b58800b5", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "fig, ax = plt.subplots(2,1)\n", 182 | "ax[0].hist([tempi, dist_low], finer_bins, alpha=0.7, label=[\"gtzan\", \"lognorm@70\"], color=[\"red\", \"orange\"], \n", 183 | " stacked=False)\n", 184 | "ax[0].legend()\n", 185 | "#plt.hist(dist_low, finer_bins, alpha=0.7, label=\"lognorm@70\", color=\"orange\")\n", 186 | "ax[1].bar(finer_bins[1:], diff_tempi, 2.5, alpha=0.5, label=\"diff\", color=\"blue\")\n", 187 | "\n", 188 | "ax[1].legend()" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "d8e6e9a5", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "# gtzan_info = {bin: qtd de faixas no bin}\n", 199 | "# transformation_dict = {bin: transformação}\n", 200 | "# se transformation_dict[bin] <= 0 e gtzan_info >= 0, faz a transformação pra faixa necessária\n", 201 | "# se transformation_dict[bin] >= 0 e gtzan_info >= 0, pula pro próximo bin" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "a19ae8c9", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "def create_transformation_dict(verbose=True):\n", 212 | " removals = 0\n", 213 | " additions = 0\n", 214 | "\n", 215 | " transformation_dict = {}\n", 216 | "\n", 217 | " for idx, value in enumerate(diff_tempi): \n", 218 | " transformation_dict[f\"{finer_bins[idx]}, {finer_bins[idx+1]}\"] = value\n", 219 | "\n", 220 | " if value < 0:\n", 221 | " message = f\"remove {value} samples\"\n", 222 | " removals += np.abs(value)\n", 223 | " elif value > 0:\n", 224 | " message = f\"add {value} samples\"\n", 225 | " additions += value\n", 226 | " else:\n", 227 | " message = \"do nothing\"\n", 228 | " \n", 229 | " if verbose:\n", 230 | " print(f\"{finer_bins[idx]} - {finer_bins[idx+1]}: {message}\")\n", 231 | " \n", 232 | " if verbose:\n", 233 | " print(f\"total removals = {removals}, total additions = {additions}\")\n", 234 | " return transformation_dict\n", 235 | "\n", 236 | "def reset_transformation_dict():\n", 237 | " return create_transformation_dict(verbose=False)" 238 | ] 239 
| }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "36e2f469", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "def create_helper_dict(bins):\n", 248 | "    # create a dictionary mapping each interval to its track_ids, e.g.\n", 249 | "    # [30,40]: [\"classical.0000\", \"blues.0010\"]\n", 250 | "    helper_dict = {}\n", 251 | "    for idx in range(len(bins)-1):\n", 252 | "        helper_dict[f\"{bins[idx]}, {bins[idx+1]}\"] = []\n", 253 | "    return helper_dict" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "255a09d0", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "helper_dict = create_helper_dict(finer_bins)\n", 264 | "gtzan_mapping = {}\n", 265 | "\n", 266 | "for i in tracks:\n", 267 | "    tempo = gtzan.track(i).tempo\n", 268 | "    \n", 269 | "    boundaries = np.digitize(tempo, finer_bins)\n", 270 | "    gtzan_mapping[i] = (tempo, f\"{finer_bins[boundaries-1]}, {finer_bins[boundaries]}\")\n", 271 | "    helper_dict[f\"{finer_bins[boundaries-1]}, {finer_bins[boundaries]}\"].append(i)" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "id": "94fb3a86", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "def check_missing_tracks(transformation_dict):\n", 282 | "    for k, v in list(transformation_dict.items())[::-1]:\n", 283 | "#     for k, v in transformation_dict.items():\n", 284 | "        if v > 0:\n", 285 | "            return k, v" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "id": "01c9b10f", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "def key_boundaries(key):\n", 296 | "    return [float(i) for i in key.split(\", \")]" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "id": "f70f1654", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "transformation_dict = reset_transformation_dict()\n", 307 | "augmented_dict = transformation_dict.copy()" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "id": "62a9ffbb", 314 | "metadata": {}, 315 | "outputs": [], 316 | "source": [ 317 | "to_remove = []\n", 318 | "j = 0\n", 319 | "for key, val in list(transformation_dict.items())[::-1]:\n", 320 | "# for key, val in transformation_dict.items():\n", 321 | "    if val < 0:\n", 322 | "        print(f\"augmenting tracks from {key}\")\n", 323 | "        for track_id in helper_dict[key]:\n", 324 | "            print(track_id)\n", 325 | "            original_tempo = gtzan.track(track_id).tempo\n", 326 | "            original_boundaries = gtzan_mapping[track_id][1]\n", 327 | "\n", 328 | "            str_boundaries = check_missing_tracks(transformation_dict)\n", 329 | "\n", 330 | "            if str_boundaries is None:\n", 331 | "                # print(transformation_dict)\n", 332 | "                # we're done then!\n", 333 | "                break \n", 334 | "\n", 335 | "            new_tempo_boundaries = key_boundaries(str_boundaries[0])\n", 336 | "            \n", 337 | "            if key == str_boundaries[0]:\n", 338 | "                print(f\"we will not transform {key} into {str_boundaries[0]}\")\n", 339 | "#                 transformation_dict[str_boundaries[0]] -= 1\n", 340 | "                break\n", 341 | "            \n", 342 | "            new_tempo = random.uniform(float(new_tempo_boundaries[0]), float(new_tempo_boundaries[1]))\n", 343 | "\n", 344 | "#             print(f\"transforming tracks from {key} to {new_tempo_boundaries}\")\n", 345 | "\n", 346 | "            tempo_rate = new_tempo/original_tempo\n", 347 | "\n", 348 | "            x, fs = gtzan.track(track_id).audio\n", 349 | "            to_remove.append(track_id)\n", 350 | "\n", 351 | "#             print(f\"original_tempo {original_tempo}, 
new_tempo {new_tempo}, tempo_rate {tempo_rate}\")\n", 352 | "\n", 353 | "            # pyrubberband parameters\n", 354 | "            rbags = {\"-2\": \"\"}  # rubberband flags meant to pick a finer, better-quality algorithm; note this dict is never passed on (pyrb.time_stretch takes an optional rbargs argument)\n", 355 | "            x_stretch = pyrb.time_stretch(x, fs, tempo_rate)\n", 356 | "\n", 357 | "            # print(f\"augmented one track from {original_boundaries} to {str_boundaries[0]}\")\n", 358 | "            transformation_dict[str_boundaries[0]] -= 1\n", 359 | "            transformation_dict[original_boundaries] += 1\n", 360 | "            augmented_dict[str_boundaries[0]] -= 1\n", 361 | "            augmented_dict[original_boundaries] += 1\n", 362 | "            \n", 363 | "            # save audio\n", 364 | "            sf.write(os.path.join(DATASET_PATH, f\"audio/{track_id}_augmented.wav\"), x_stretch, fs, subtype=\"PCM_24\")\n", 365 | "            # save tempo \n", 366 | "            with open(os.path.join(DATASET_PATH, f\"annotations/tempo/{track_id}_augmented.bpm\"), \"w\") as f:\n", 367 | "                f.write(str(new_tempo))" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "id": "420ece0e", 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "for track_id in to_remove:\n", 378 | "    try:\n", 379 | "#         print(f\"removing {track_id}\")\n", 380 | "        os.remove(os.path.join(DATASET_PATH, f\"audio/{track_id}.wav\"))\n", 381 | "        os.remove(os.path.join(DATASET_PATH, f\"annotations/tempo/{track_id}.bpm\"))\n", 382 | "    except FileNotFoundError:\n", 383 | "#         print(\"already removed\")\n", 384 | "        continue" 385 | ] 386 | }, 387 | { 388 | "cell_type": "code", 389 | "execution_count": null, 390 | "id": "fe1a28e5", 391 | "metadata": {}, 392 | "outputs": [], 393 | "source": [ 394 | "import steme.loader as loader" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "id": "8a23a2b3", 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "gtzan_augmented = loader.custom_dataset_loader(\n", 405 | "    path=DATASET_PATH,\n", 406 | "    dataset_name=\"\",\n", 407 | "    folder=\"\",\n", 408 | ")" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "id": "dbd1dc04", 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "gtzan_augmented_tracks = gtzan_augmented.track_ids\n", 419 | "gtzan_augmented_tracks.remove(\"reggae.00086\")\n", 420 | "gtzan_augmented_tempi = [gtzan_augmented.track(track_id).tempo for track_id in gtzan_augmented_tracks]" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "id": "7aad6314", 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "plt.hist(gtzan_augmented_tempi, bins=finer_bins, color=\"red\", label=\"gtzan_augmented\")\n", 431 | "plt.legend()" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "id": "07fa94ec", 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [ 441 | "plt.hist(\n", 442 | "    [gtzan_augmented_tempi, dist_low], \n", 443 | "    bins=np.arange(30,200,10), \n", 444 | "    color=[\"blue\", \"orange\"], \n", 445 | "    label=[\"gtzan_augmented\", \"lognorm@70\"]\n", 446 | ")\n", 447 | "plt.legend()" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "id": "5fdefebc", 454 | "metadata": {}, 455 | "outputs": [], 456 | "source": [ 457 | "len(gtzan_augmented_tracks)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "id": "8f2dff3a", 463 | "metadata": {}, 464 | "source": [ 465 | "# Quality comparison" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": null, 471 | "id": "6ca608b6", 472 | "metadata": {}, 473 | "outputs": [], 474 
| "source": [ 475 | "# load original track\n", 476 | "orig_x, orig_fs = gtzan.track(\"blues.00002\").audio" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "id": "21078c71", 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "ipd.Audio(orig_x, rate=orig_fs)" 487 | ] 488 | }, 489 | { 490 | "cell_type": "code", 491 | "execution_count": null, 492 | "id": "7987dc41", 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "aug_x, aug_fs = gtzan_augmented.track(\"blues.00002_augmented\").audio" 497 | ] 498 | }, 499 | { 500 | "cell_type": "code", 501 | "execution_count": null, 502 | "id": "bfc29848", 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "ipd.Audio(aug_x, rate=aug_fs)" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": null, 512 | "id": "d7925f55", 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "orig_nov, _ = audio.spectral_flux(orig_x, orig_fs, n_fft=2048, hop_length=512)\n", 517 | "orig_frame_time = librosa.frames_to_time(np.arange(len(orig_nov)),\n", 518 | " sr=orig_fs,\n", 519 | " hop_length=512)\n", 520 | "\n", 521 | "aug_nov, _ = audio.spectral_flux(aug_x[:30*aug_fs], aug_fs, n_fft=2048, hop_length=512)\n", 522 | "aug_frame_time = librosa.frames_to_time(np.arange(len(aug_nov)),\n", 523 | " sr=aug_fs,\n", 524 | " hop_length=512)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "id": "c9f6ebbc", 531 | "metadata": {}, 532 | "outputs": [], 533 | "source": [ 534 | "plt.plot(orig_frame_time, orig_nov, color=\"red\", label=\"original audio\")\n", 535 | "plt.plot(aug_frame_time, aug_nov, color=\"blue\", label=\"augmented audio\")\n", 536 | "plt.legend()" 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": null, 542 | "id": "f641328c", 543 | "metadata": {}, 544 | "outputs": [], 545 | "source": [ 546 | "linear_theta = np.arange(30,350,1)" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "id": "0b01c043", 553 | "metadata": {}, 554 | "outputs": [], 555 | "source": [ 556 | "orig_T, orig_fT, orig_times = audio.tempogram(orig_x, orig_fs, 10, \"fourier\", linear_theta)\n", 557 | "aug_T, aug_fT, aug_times = audio.tempogram(aug_x, aug_fs, 10, \"fourier\", linear_theta)" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "id": "796dc79c", 564 | "metadata": {}, 565 | "outputs": [], 566 | "source": [ 567 | "def plot_comparison(T, t, freqs, ttypes, subplot_titles, fig_title=None):\n", 568 | " \"\"\"\n", 569 | " helper function to plot tempograms side-by-side.\n", 570 | " \"\"\"\n", 571 | " figsize = (15, 5)\n", 572 | " num_tempograms = len(T)\n", 573 | " fig, ax = plt.subplots(1, num_tempograms, figsize=figsize)\n", 574 | "\n", 575 | " for idx in range(num_tempograms):\n", 576 | " kwargs = utils._tempogram_kwargs(t[idx], freqs[idx])\n", 577 | "\n", 578 | " ax[idx].imshow(T[idx], **kwargs)\n", 579 | "\n", 580 | " xlim = (t[idx][0], t[idx][-1])\n", 581 | " ylim = (freqs[idx][0], freqs[idx][-1])\n", 582 | "\n", 583 | " #plt.setp(ax, xlim=xlim, ylim=ylim)\n", 584 | " \n", 585 | " if ttypes[idx] == \"log\":\n", 586 | " labels = [item.get_text() for item in ax[0].get_yticklabels()]\n", 587 | " new_labels = np.rint(log_axis[::20]).astype(int)\n", 588 | " ax[idx].set_yticklabels(new_labels)\n", 589 | "\n", 590 | " if fig_title is not None:\n", 591 | " fig.suptitle(fig_title, fontsize=16)\n", 592 | "\n", 593 | " ax[idx].set_xlabel(\"Time 
(s)\")\n", 594 | " ax[idx].set_ylabel(\"Tempo (BPM)\")\n", 595 | " ax[idx].title.set_text(subplot_titles[idx])\n", 596 | " return fig, ax" 597 | ] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "id": "0177054b", 603 | "metadata": {}, 604 | "outputs": [], 605 | "source": [ 606 | "plot_comparison([orig_T, aug_T], [orig_fT, aug_fT], [orig_times, aug_times], subplot_titles=[f\"orig {orig_bpm}\", f\"aug {aug_bpm}\"], ttypes=\"linear\")" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": null, 612 | "id": "89d90d83", 613 | "metadata": {}, 614 | "outputs": [], 615 | "source": [ 616 | "orig_bpm = gtzan.track(\"blues.00002\").tempo\n", 617 | "#utils.plot_tempogram(orig_T, orig_fT, orig_times, title=f\"Original audio ({orig_bpm} BPM))\")" 618 | ] 619 | }, 620 | { 621 | "cell_type": "code", 622 | "execution_count": null, 623 | "id": "72019405", 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "aug_bpm = gtzan_augmented.track(\"blues.00002_augmented\").tempo\n", 628 | "utils.plot_tempogram(aug_T, aug_fT, aug_times, title=f\"Augmented ({aug_bpm} BPM)\")" 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "id": "2d1a68b6", 634 | "metadata": {}, 635 | "source": [ 636 | "## Approach 2: Augmentation in the tempogram domain" 637 | ] 638 | }, 639 | { 640 | "cell_type": "code", 641 | "execution_count": null, 642 | "id": "cfc3fe46", 643 | "metadata": {}, 644 | "outputs": [], 645 | "source": [ 646 | "# first looking at the linear scenario" 647 | ] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "execution_count": null, 652 | "id": "6537e17d", 653 | "metadata": {}, 654 | "outputs": [], 655 | "source": [ 656 | "orig_T, orig_fT, orig_times = audio.tempogram(orig_x, orig_fs, 10, \"fourier\", linear_theta)" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": null, 662 | "id": "7273000b", 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "#utils.plot_tempogram(orig_T, orig_fT, orig_times, title=\"Original audio\")" 667 | ] 668 | }, 669 | { 670 | "cell_type": "code", 671 | "execution_count": null, 672 | "id": "d9299c79", 673 | "metadata": {}, 674 | "outputs": [], 675 | "source": [ 676 | "larger_orig_T, larger_orig_fT, larger_orig_times = audio.tempogram(orig_x, orig_fs, 10, \"fourier\", np.arange(30, 670))" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": null, 682 | "id": "7c5193b0", 683 | "metadata": {}, 684 | "outputs": [], 685 | "source": [ 686 | "#utils.plot_tempogram(larger_orig_T, larger_orig_fT, larger_orig_times, title=\"Original audio\")" 687 | ] 688 | }, 689 | { 690 | "cell_type": "code", 691 | "execution_count": null, 692 | "id": "494a1e58", 693 | "metadata": {}, 694 | "outputs": [], 695 | "source": [ 696 | "aug_T.shape, larger_orig_T.shape" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": null, 702 | "id": "73ce870e", 703 | "metadata": {}, 704 | "outputs": [], 705 | "source": [ 706 | "larger_orig_T.shape[1]" 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": null, 712 | "id": "4f5b7164", 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "raw_aug = np.zeros(orig_T.shape)\n", 717 | "## dumb way of doing it\n", 718 | "# average every 2 lines, copy the result to the new array\n", 719 | "large_idx = 0\n", 720 | "idx = 0\n", 721 | "while idx < 320:\n", 722 | " # we have to use +2 because np slicing is [start, end), instead of [start, end]\n", 723 | " avg_lines = 
np.mean(larger_orig_T[large_idx:large_idx+2, :], axis=0)\n", 724 | " \n", 725 | " raw_aug[idx,:] = avg_lines\n", 726 | " \n", 727 | " idx += 1\n", 728 | " large_idx += 2\n", 729 | " " 730 | ] 731 | }, 732 | { 733 | "cell_type": "code", 734 | "execution_count": null, 735 | "id": "182229c6", 736 | "metadata": {}, 737 | "outputs": [], 738 | "source": [ 739 | "larger_orig_T[::2, :].shape" 740 | ] 741 | }, 742 | { 743 | "cell_type": "code", 744 | "execution_count": null, 745 | "id": "09cc8e41", 746 | "metadata": {}, 747 | "outputs": [], 748 | "source": [ 749 | "raw_aug = larger_orig_T[::2, :].copy()" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": null, 755 | "id": "85e32c3b", 756 | "metadata": {}, 757 | "outputs": [], 758 | "source": [ 759 | "plot_comparison([raw_aug, aug_T], [orig_fT, aug_fT], [orig_times, aug_times], subplot_titles=[\"tempogram_aug\", \"audio_aug\"], ttypes=\"linear\")" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": null, 765 | "id": "7d961df7", 766 | "metadata": {}, 767 | "outputs": [], 768 | "source": [ 769 | "utils.plot_tempogram(aug_T, aug_fT, aug_times, title=f\"Augmented ({aug_bpm} BPM)\")" 770 | ] 771 | }, 772 | { 773 | "cell_type": "code", 774 | "execution_count": null, 775 | "id": "ef990a0d", 776 | "metadata": { 777 | "scrolled": true 778 | }, 779 | "outputs": [], 780 | "source": [ 781 | "utils.plot_tempogram(orig_T, orig_fT, orig_times, title=f\"Original audio ({orig_bpm} BPM)\")" 782 | ] 783 | }, 784 | { 785 | "cell_type": "code", 786 | "execution_count": null, 787 | "id": "b12eed61", 788 | "metadata": {}, 789 | "outputs": [], 790 | "source": [ 791 | "idx = 0\n", 792 | "tmp = np.mean(larger_orig_T[idx:idx+2, :], axis=0)" 793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": null, 798 | "id": "b5bc5e6a", 799 | "metadata": {}, 800 | "outputs": [], 801 | "source": [ 802 | "tmp.shape" 803 | ] 804 | }, 805 | { 806 | "cell_type": "code", 807 | "execution_count": null, 808 | "id": "d1e5f1ab", 809 | "metadata": {}, 810 | "outputs": [], 811 | "source": [ 812 | "tempi_array = np.asarray(tempi)" 813 | ] 814 | }, 815 | { 816 | "cell_type": "code", 817 | "execution_count": null, 818 | "id": "753b59d5", 819 | "metadata": {}, 820 | "outputs": [], 821 | "source": [ 822 | "plt.hist(tempi_array, bins=50)" 823 | ] 824 | }, 825 | { 826 | "cell_type": "code", 827 | "execution_count": null, 828 | "id": "eef9d4cb", 829 | "metadata": {}, 830 | "outputs": [], 831 | "source": [ 832 | "plt.hist(\n", 833 | " [np.append(tempi_array,tempi_array/2), dist_low], \n", 834 | " bins=50, \n", 835 | " color=[\"red\", \"orange\"], \n", 836 | " label=[\"gtzan + gtzan/2\", \"lognorm@70\"]\n", 837 | ")\n", 838 | "plt.legend()" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": null, 844 | "id": "ccace1d5", 845 | "metadata": {}, 846 | "outputs": [], 847 | "source": [ 848 | "np.append(tempi_array,tempi_array/2)" 849 | ] 850 | } 851 | ], 852 | "metadata": { 853 | "kernelspec": { 854 | "display_name": "Python 3 (ipykernel)", 855 | "language": "python", 856 | "name": "python3" 857 | }, 858 | "language_info": { 859 | "codemirror_mode": { 860 | "name": "ipython", 861 | "version": 3 862 | }, 863 | "file_extension": ".py", 864 | "mimetype": "text/x-python", 865 | "name": "python", 866 | "nbconvert_exporter": "python", 867 | "pygments_lexer": "ipython3", 868 | "version": "3.10.6" 869 | } 870 | }, 871 | "nbformat": 4, 872 | "nbformat_minor": 5 873 | } 874 | 
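
The time-stretching loop in Approach 1 reduces, per track, to a single pyrubberband call whose rate is the ratio of the sampled target tempo to the annotated tempo. A minimal stand-alone sketch of that step (file names and tempo values are illustrative placeholders; `pyrb.time_stretch` also accepts an optional `rbargs` dict for forwarding extra rubberband CLI flags, which is presumably what the unused `rbags` variable in the loop was meant for):

```python
# stretch one track so its annotated tempo lands on a sampled target tempo
import pyrubberband as pyrb
import soundfile as sf

y, sr = sf.read("blues.00002.wav")  # hypothetical input file
original_bpm = 82.0                 # annotated tempo (placeholder value)
target_bpm = 70.0                   # e.g. drawn from the lognorm @ 70 bins

# time_stretch speeds the track up for rate > 1 and slows it down for
# rate < 1, so the stretched tempo is original_bpm * rate == target_bpm
rate = target_bpm / original_bpm
y_stretch = pyrb.time_stretch(y, sr, rate)

sf.write("blues.00002_augmented.wav", y_stretch, sr, subtype="PCM_24")
with open("blues.00002_augmented.bpm", "w") as f:
    f.write(str(target_bpm))  # keep the annotation in sync with the audio
```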
--------------------------------------------------------------------------------