├── .gitignore ├── README.md ├── bpm_detection └── bpm_detection.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | build/ 3 | dist/ 4 | env/ 5 | htmlcov/ 6 | .coverage 7 | *.DS_Store 8 | *.pyc 9 | *.bak 10 | *.mp3 11 | *.m4a 12 | *.tmp 13 | *.wav 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | BPM Detector in Python 2 | ======================= 3 | Implementation of a Beats Per Minute (BPM) detection algorithm, as presented in the paper of G. Tzanetakis, G. Essl and P. Cook titled: "Audio Analysis using the Discrete Wavelet Transform". 4 | 5 | You can find it here: http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.63.5712 6 | 7 | Based on the work done in the MATLAB code located at github.com/panagiop/the-BPM-detector-python. 8 | 9 | Process .wav file to determine the Beats Per Minute. 10 | 11 | ## Requirements 12 | Tested with Python 3.10. Key Dependencies: scipy, numpy, pywavelets, matplotlib. See requirements.txt 13 | -------------------------------------------------------------------------------- /bpm_detection/bpm_detection.py: -------------------------------------------------------------------------------- 1 | # Copyright 2012 Free Software Foundation, Inc. 2 | # 3 | # This file is part of The BPM Detector Python 4 | # 5 | # The BPM Detector Python is free software; you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation; either version 3, or (at your option) 8 | # any later version. 9 | # 10 | # The BPM Detector Python is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with The BPM Detector Python; see the file COPYING.  If not, write to
# the Free Software Foundation, Inc., 51 Franklin Street,
# Boston, MA  02110-1301, USA.

"""Estimate the tempo (beats per minute) of a .wav file.

The signal is split into fixed-size windows; each window is decomposed with a
4-level ``db4`` discrete wavelet transform, the per-level envelopes are summed
and autocorrelated, and the strongest autocorrelation lag between 40 and 220
BPM gives the window's tempo.  The script reports the median over all windows.

``pywt`` and ``matplotlib`` are imported inside the functions that use them so
that the WAV-reading and peak-finding helpers can be imported without them.
"""

import argparse
import math
import sys
import wave

import numpy
from scipy import signal


def _decode_pcm(raw, sampwidth):
    """Decode integer PCM bytes from ``Wave_read.readframes`` into a 1-D array.

    ``raw`` must contain a whole number of samples of ``sampwidth`` bytes.
    Raises ValueError for sample widths other than 1, 2, 3 or 4 bytes.
    """
    if sampwidth == 1:
        # 8-bit WAV samples are unsigned with silence at 128; recentre on zero.
        return numpy.frombuffer(raw, dtype=numpy.uint8).astype(numpy.int16) - 128
    if sampwidth == 2:
        # CPython's wave module byte-swaps frames to native order on
        # big-endian hosts, so native-endian dtypes are correct here.
        return numpy.frombuffer(raw, dtype=numpy.int16)
    if sampwidth == 3:
        octets = numpy.frombuffer(raw, dtype=numpy.uint8).reshape(-1, 3)
        if sys.byteorder == "big":
            octets = octets[:, ::-1]  # put the least significant byte first
        octets = octets.astype(numpy.int32)
        values = octets[:, 0] | (octets[:, 1] << 8) | (octets[:, 2] << 16)
        # Sign-extend from 24 to 32 bits.
        return numpy.where(values >= (1 << 23), values - (1 << 24), values)
    if sampwidth == 4:
        return numpy.frombuffer(raw, dtype=numpy.int32)
    raise ValueError("unsupported sample width: %d bytes" % sampwidth)


def read_wav(filename):
    """Read a PCM .wav file and return ``(samps, fs)``.

    ``samps`` is a list with one value per frame: the integer samples of a
    mono file, or the per-frame mean of all channels for a multi-channel
    file.  ``fs`` is the frame rate in Hz.

    If the file cannot be opened or parsed as WAV, the error is printed and
    ``None`` is returned.  Raises ValueError if the file holds no frames, has
    a non-positive frame rate, or uses an unsupported sample width.
    """
    try:
        wf = wave.open(filename, "rb")
    except (OSError, EOFError, wave.Error) as e:
        print(e)
        return None

    # The context manager guarantees the underlying file is closed.
    with wf:
        nsamps = wf.getnframes()
        fs = wf.getframerate()
        nchannels = wf.getnchannels()
        sampwidth = wf.getsampwidth()
        raw = wf.readframes(nsamps)

    if nsamps <= 0:
        raise ValueError("no audio frames in %r" % (filename,))
    if fs <= 0:
        raise ValueError("invalid frame rate %r in %r" % (fs, filename))

    # Drop a trailing partial frame (truncated file) so decoding cannot fail
    # on a byte count that is not a multiple of the frame size.
    frame_size = sampwidth * nchannels
    usable = len(raw) - (len(raw) % frame_size)
    samples = _decode_pcm(raw[:usable], sampwidth)

    if nchannels > 1:
        # Frames are interleaved channel by channel; average them to mono so
        # that one list entry corresponds to one frame at rate ``fs``.
        samples = samples.reshape(-1, nchannels).mean(axis=1)

    samps = samples.tolist()
    if nsamps != len(samps):
        print(nsamps, "not equal to", len(samps))

    return samps, fs


# print an error when no data can be found
def no_audio_data():
    """Print a notice and return the ``(None, None)`` "skip this window" result."""
    print("No audio data for sample, skipping...")
    return None, None


# simple peak detection
def peak_detect(data):
    """Locate the sample(s) of largest magnitude in a 1-D sequence.

    Returns the ``numpy.where`` result (a one-element tuple holding an index
    array) for the positions equal to the maximum absolute value, or, when no
    sample equals it, for the positions equal to its negation.  Empty input
    yields an empty index array.
    """
    data = numpy.asarray(data)
    if data.size == 0:
        return (numpy.array([], dtype=numpy.intp),)

    max_val = numpy.amax(numpy.abs(data))
    peak_ndx = numpy.where(data == max_val)
    if len(peak_ndx[0]) == 0:  # if nothing found then the max must be negative
        peak_ndx = numpy.where(data == -max_val)
    return peak_ndx


def bpm_detector(data, fs):
    """Estimate the tempo of one window of audio.

    ``data`` is a sequence of samples and ``fs`` its sample rate in Hz.
    Returns ``(bpm, correl)`` where ``bpm`` is a float and ``correl`` the full
    autocorrelation of the summed wavelet envelopes, or ``(None, None)`` (via
    ``no_audio_data``) when the window is empty or silent, is too short to
    cover the 40-220 BPM lag range, or has no single strongest peak.
    """
    import pywt  # deferred: see the module docstring

    if len(data) == 0:
        return no_audio_data()

    levels = 4
    max_decimation = 2 ** (levels - 1)
    decimated_fs = fs / max_decimation
    # Autocorrelation lags (in decimated samples) for 220 BPM and 40 BPM.
    min_ndx = math.floor(60.0 / 220 * decimated_fs)
    max_ndx = math.floor(60.0 / 40 * decimated_fs)

    cA = data
    cD_minlen = 0
    cD_sum = None
    for loop in range(levels):
        # 1) DWT of the previous level's approximation (the raw data first).
        cA, cD = pywt.dwt(cA, "db4")
        if loop == 0:
            # Length of the level-1 details after decimation, the shortest of
            # all bands.  The previous floor(len / 8 + 1) was one element too
            # long whenever len was a multiple of 8 and broke the sum below.
            cD_minlen = math.ceil(len(cD) / max_decimation)
            cD_sum = numpy.zeros(cD_minlen)

        # 2) Filter
        # NOTE(review): with a single denominator coefficient this is a gain
        # of 0.01 / (1 - 0.99) ~= 1, not a smoothing filter; a one-pole
        # low-pass would be lfilter([0.01], [1, -0.99], ...).  Left unchanged
        # so existing estimates stay the same -- confirm the intent.
        cD = signal.lfilter([0.01], [1 - 0.99], cD)

        # 3) Decimate every band to the same rate, rectify, remove the mean.
        cD = abs(cD[:: 2 ** (levels - loop - 1)])
        cD = cD - numpy.mean(cD)

        # 4) Accumulate the band envelopes before the ACF.
        cD_sum = cD[:cD_minlen] + cD_sum

    if not numpy.any(cA):
        return no_audio_data()

    # Adding in the approximate data as well...
    cA = signal.lfilter([0.01], [1 - 0.99], cA)
    cA = abs(cA)
    cA = cA - numpy.mean(cA)
    cD_sum = cA[:cD_minlen] + cD_sum

    # ACF; only the non-negative lags (second half) are searched.
    correl = numpy.correlate(cD_sum, cD_sum, "full")
    midpoint = len(correl) // 2
    search = correl[midpoint:][min_ndx:max_ndx]
    if search.size == 0:
        # Window too short to contain any lag in the tempo range.
        return no_audio_data()

    # peak_detect returns numpy.where's tuple; the index array is element 0.
    # The old test, len(peak_ndx) > 1, measured the tuple and never fired.
    peaks = peak_detect(search)[0]
    if peaks.size != 1:
        return no_audio_data()

    lag = int(peaks[0]) + min_ndx
    if lag <= 0:
        return no_audio_data()

    bpm = 60.0 / lag * decimated_fs
    print(bpm)
    return bpm, correl


def main(argv=None):
    """Command-line entry point; returns a process exit status.

    Parses ``--filename`` and ``--window``, estimates the tempo of every full
    window of the file, prints the median estimate and plots the
    autocorrelation of the last analysed window.
    """
    parser = argparse.ArgumentParser(description="Process .wav file to determine the Beats Per Minute.")
    parser.add_argument("--filename", required=True, help=".wav file for processing")
    parser.add_argument(
        "--window",
        type=float,
        default=3,
        help="Size of the window (seconds) that will be scanned to determine the bpm. Typically less than 10 seconds. [3]",
    )
    args = parser.parse_args(argv)

    wav = read_wav(args.filename)
    if wav is None:
        # read_wav has already printed the reason.
        return 1
    samps, fs = wav

    window_samps = int(args.window * fs)
    if window_samps <= 0:
        parser.error("--window must cover at least one sample")
    max_window_ndx = len(samps) // window_samps

    bpms = []
    correl = []
    for window_ndx in range(max_window_ndx):
        # Derive the offset from the window index so that a skipped window
        # can never stall the scan on the same samples.
        start = window_ndx * window_samps
        bpm, correl_temp = bpm_detector(samps[start : start + window_samps], fs)
        if bpm is None:
            continue
        bpms.append(bpm)
        correl = correl_temp

    if not bpms:
        print("No window produced a BPM estimate (file shorter than the window, or silent).")
        return 1

    # Median over the windows that produced an estimate; skipped windows no
    # longer count as 0 BPM.
    bpm = numpy.median(bpms)
    print("Completed! Estimated Beats Per Minute:", bpm)

    import matplotlib.pyplot as plt  # deferred: only needed for the plot

    plt.plot(range(len(correl)), abs(correl))
    plt.show(block=True)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

# --------------------------------------------------------------------------------
# /requirements.txt (contents of the original dump, kept verbatim as comments):
#   certifi @ file:///private/var/folders/sy/f16zz6x50xz3113nwtb9bvq00000gp/T/abs_477u68wvzm/croot/certifi_1671487773341/work/certifi
#   contourpy==1.0.7
#   cycler==0.11.0
#   fonttools==4.38.0
#   kiwisolver==1.4.4
#   matplotlib==3.6.3
#   numpy==1.24.2
#   packaging==23.0
#   Pillow==9.4.0
#   pyparsing==3.0.9
#   python-dateutil==2.8.2
#   PyWavelets==1.4.1
#   scipy==1.10.0
#   six==1.16.0
# --------------------------------------------------------------------------------