├── .github └── workflows │ └── python-unit-tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── flatten ├── Dockerfile ├── __init__.py ├── common │ ├── dataset.py │ ├── errors.py │ ├── graphing.py │ ├── sampling.py │ ├── spectrum.py │ └── wavelet.py ├── dsp-server.py ├── dsp.py ├── parameters.json ├── requirements-blocks.txt └── third_party │ └── placeholder ├── image ├── Dockerfile ├── __init__.py ├── common │ ├── dataset.py │ ├── errors.py │ ├── graphing.py │ ├── sampling.py │ ├── spectrum.py │ └── wavelet.py ├── dsp-server.py ├── dsp.py ├── parameters.json ├── requirements-blocks.txt └── third_party │ └── placeholder ├── mfcc ├── Dockerfile ├── README.md ├── __init__.py ├── common │ ├── dataset.py │ ├── errors.py │ ├── graphing.py │ ├── sampling.py │ ├── spectrum.py │ └── wavelet.py ├── dsp-server.py ├── dsp.py ├── parameters.json ├── requirements-blocks.txt └── third_party │ └── speechpy │ ├── __init__.py │ ├── feature.py │ ├── functions.py │ └── processing.py ├── mfe ├── Dockerfile ├── __init__.py ├── common │ ├── dataset.py │ ├── errors.py │ ├── graphing.py │ ├── sampling.py │ ├── spectrum.py │ └── wavelet.py ├── dsp-server.py ├── dsp.py ├── parameters.json ├── requirements-blocks.txt └── third_party │ └── speechpy │ ├── __init__.py │ ├── feature.py │ ├── functions.py │ └── processing.py ├── raw ├── Dockerfile ├── __init__.py ├── common │ ├── dataset.py │ ├── errors.py │ ├── graphing.py │ ├── sampling.py │ ├── spectrum.py │ └── wavelet.py ├── dsp-server.py ├── dsp.py ├── parameters.json ├── requirements-blocks.txt └── third_party │ └── placeholder ├── requirements.txt ├── run_mfe.py ├── run_spectral_analysis_via_python.py ├── spectral_analysis ├── Dockerfile ├── __init__.py ├── common │ ├── dataset.py │ ├── errors.py │ ├── graphing.py │ ├── sampling.py │ ├── spectrum.py │ └── wavelet.py ├── dsp-server.py ├── dsp.py ├── parameters.json ├── requirements-blocks.txt └── third_party │ └── placeholder ├── spectrogram ├── Dockerfile ├── __init__.py ├── common │ 
├── dataset.py │ ├── errors.py │ ├── graphing.py │ ├── sampling.py │ ├── spectrum.py │ └── wavelet.py ├── dsp-server.py ├── dsp.py ├── parameters.json ├── requirements-blocks.txt └── third_party │ └── speechpy │ ├── __init__.py │ ├── feature.py │ ├── functions.py │ └── processing.py └── tests └── test_spectrogram.py /.github/workflows/python-unit-tests.yml: -------------------------------------------------------------------------------- 1 | name: Python application 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: [ master ] 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Remove unnecessary files 18 | run: | 19 | sudo rm -rf /usr/share/dotnet 20 | sudo rm -rf "$AGENT_TOOLSDIRECTORY" 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.9" 26 | - name: Cache Python 27 | id: cache-python 28 | uses: actions/cache/restore@v3 29 | with: 30 | path: ${{ env.pythonLocation }} 31 | key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }} 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install -r requirements.txt 36 | - name: Cache Python 37 | if: steps.cache-python.outputs.cache-hit != 'true' 38 | uses: actions/cache/save@v3 39 | with: 40 | path: ${{ env.pythonLocation }} 41 | key: ${{ steps.cache-python.outputs.cache-primary-key }} 42 | 43 | - name: Test with unittest 44 | run: | 45 | python -m unittest discover tests -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode 3 | .ei-block-config 4 | 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The Clear BSD License 2 
| 3 | Copyright (c) 2025 EdgeImpulse Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted (subject to the limitations in the disclaimer 8 | below) provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright 14 | notice, this list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 22 | THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 23 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 25 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 26 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 27 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 28 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 29 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER 30 | IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 31 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 
33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Edge Impulse processing blocks 2 | 3 | These are officially supported processing blocks in Edge Impulse. These blocks can be selected from any project, and can be used as inspiration for [Building custom processing blocks](https://docs.edgeimpulse.com/docs/custom-blocks). These blocks are meant to be run on a server. The corresponding C++ blocks are in the [C++ inferencing SDK](https://github.com/edgeimpulse/inferencing-sdk-cpp), or can be found in the edge-impulse-sdk folder of the zip exported from Studio. 4 | 5 | ## Contributing to this repository 6 | 7 | We welcome contributions to this repository: both improvements to our own processing blocks and new, well-tested processing blocks for other sensor data. To contribute, just open a pull request against this repository. Note that blocks require a corresponding implementation in the inferencing SDK. If you add options, change behavior, or add a new block, please either open a corresponding pull request in the inferencing SDK, or ask for some help from the Edge Impulse team. 8 | 9 | ### Testing your contributions 10 | 11 | The blocks in this repository are compatible with custom processing blocks. Follow the [Building custom processing blocks](https://docs.edgeimpulse.com/docs/custom-blocks) tutorial to learn how you can load modified processing blocks into Edge Impulse. 
12 | -------------------------------------------------------------------------------- /flatten/Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax = docker/dockerfile:experimental@sha256:3c244c0c6fc9d6aa3ddb73af4264b3a23597523ac553294218c13735a2c6cf79 2 | FROM ubuntu:20.04 3 | 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | 6 | WORKDIR /app 7 | 8 | # python3 and all dependencies for scipy 9 | RUN apt update && apt install -y python3 python3-pip libatlas-base-dev gfortran-9 libfreetype6-dev wget && \ 10 | ln -s $(which gfortran-9) /usr/bin/gfortran 11 | 12 | # Update pip 13 | RUN pip3 install -U pip==22.0.3 14 | 15 | # Cython and scikit-learn - it needs to be done in this order for some reason 16 | RUN pip3 --no-cache-dir install Cython==0.29.24 17 | 18 | # Rest of the dependencies 19 | COPY requirements-blocks.txt ./ 20 | RUN pip3 --no-cache-dir install -r requirements-blocks.txt 21 | 22 | COPY third_party /third_party 23 | COPY . ./ 24 | 25 | EXPOSE 4446 26 | 27 | ENTRYPOINT ["python3", "-u", "dsp-server.py"] 28 | -------------------------------------------------------------------------------- /flatten/__init__.py: -------------------------------------------------------------------------------- 1 | from .dsp import generate_features 2 | -------------------------------------------------------------------------------- /flatten/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | sys.path.append('/') 4 | from common.sampling import calc_resampled_size, calculate_freq, Resampler 5 | 6 | 7 | class Dataset: 8 | '''Create an iterable dataset when x data is flattened, handling reshaping and resampling''' 9 | 10 | def __init__(self, X_all, metadata, axis, returns_interval=True, resample_interval_ms=None): 11 | self.ix = 0 12 | self.returns_interval = returns_interval 13 | self.max_len = 0 14 | 15 | X_all_shaped = [] 16 | y_all = [] 17 | 
self.y_label_set = set() 18 | intervals_all = [] 19 | current_offset = 0 20 | 21 | if resample_interval_ms: 22 | self.fs = calculate_freq(resample_interval_ms) 23 | else: 24 | self.fs = None 25 | 26 | # Prepare for resampling data 27 | if resample_interval_ms is not None: 28 | resample_utility = Resampler(len(metadata)) 29 | target_freq = calculate_freq(resample_interval_ms) 30 | intervals_all.append(resample_interval_ms) 31 | 32 | # Reshape all samples 33 | for ix in range(len(metadata)): 34 | # Get x data using offset 35 | cur_len = metadata[ix] 36 | X_full = X_all[current_offset: current_offset + cur_len] 37 | current_offset = current_offset + cur_len 38 | 39 | # Split the interval, label from the features 40 | interval_ms = X_full[0] 41 | y = X_full[1] 42 | X = X_full[2:] 43 | 44 | if not self.fs: 45 | # if we didn't get a resampling rate from the caller, use the first sample's rate 46 | self.fs = calculate_freq(interval_ms) 47 | 48 | if not np.isnan(X).any(): 49 | # Reshape 50 | len_adjusted = cur_len - 2 51 | rows = int(len_adjusted / axis) 52 | # Data length is unexpected 53 | if not ((len_adjusted % axis) == 0): 54 | raise ValueError('Sample length is invalid, check the axis count.') 55 | 56 | X = np.reshape(X, (rows, axis)) 57 | 58 | # Resample data 59 | if resample_interval_ms is not None: 60 | # Work out the up and down factors using sample lengths 61 | original_length = X.shape[0] 62 | original_freq = calculate_freq(interval_ms) 63 | new_length = calc_resampled_size(original_freq, target_freq, original_length) 64 | 65 | # Resample 66 | X = resample_utility.resample(X, new_length, original_length) 67 | else: 68 | intervals_all.append(interval_ms) 69 | 70 | # Store the longest sample length 71 | self.max_len = max(self.max_len, X.shape[0]) 72 | X_all_shaped.append(X) 73 | y_all.append(y) 74 | self.y_label_set.add(y) 75 | 76 | self.X_all = X_all_shaped 77 | self.y_all = y_all 78 | self.intervals = intervals_all 79 | 80 | def reset(self): 81 | self.ix = 0 
def log(*msg, level='warn'):
    """Print one JSON-encoded log record to stdout.

    Args:
        *msg: Values to log; each is converted with ``str`` and the results
            are joined with single spaces.
        level (str): Value stored under the ``level`` key (default ``'warn'``).

    The record has the keys ``msg``, ``level`` and ``time``, in that order.
    ``time`` is an ISO-8601 timestamp with second resolution and a ``Z`` suffix.
    """
    msg_clean = ' '.join(str(part) for part in msg)
    # The 'Z' suffix designates UTC, so the timestamp has to come from the UTC
    # clock; a naive local-time now() would be mislabelled on any host whose
    # timezone is not UTC.
    utc_now = datetime.datetime.now(datetime.timezone.utc)
    # Drop tzinfo so isoformat() does not emit '+00:00' in front of the 'Z'.
    stamp = utc_now.replace(microsecond=0, tzinfo=None).isoformat() + 'Z'
    print(json.dumps({'msg': msg_clean, 'level': level, 'time': stamp}))
def create_mfe_graph(sampling_freq, frame_length, frame_stride, width, height, power_spectrum, freqs):
    """Draw an MFE spectrogram using a thinned-out list of frequency labels.

    All arguments are forwarded unchanged to ``create_sgram_graph`` except
    ``freqs``, which is reduced to its first entry, every fourth entry of the
    interior, and its last entry.

    Returns:
        Whatever ``create_sgram_graph`` returns for the reduced label list.
    """
    # Keep the end points and sample the interior so the y axis is not crowded
    tick_freqs = [freqs[0]]
    tick_freqs.extend(freqs[1:-1:4])
    tick_freqs.append(freqs[-1])
    return create_sgram_graph(sampling_freq, frame_length, frame_stride,
                              width, height, power_spectrum, tick_freqs)
def get_ratio_combo(r):
    """Split a total decimation ratio into its per-stage ratios.

    :param r: total decimation ratio; one of 1, 3, 10, 30, 100 or 1000
    :returns: list of stage ratios whose product equals ``r``
    :raises ValueError: if ``r`` is not one of the supported ratios
    """
    # Ratios that need no decomposition
    if r == 1:
        return [1]
    if r == 3 or r == 10:
        return [r]

    # Larger ratios are built from the base stages 3 and 10
    staged = (
        (30, (3, 10)),
        (100, (10, 10)),
        (1000, (10, 10, 10)),
    )
    for total, stages in staged:
        if r == total:
            return list(stages)

    raise ValueError("Invalid decimation ratio: {}".format(r))
def welch_max_hold(fx, sampling_freq, nfft, n_overlap):
    """Compute a max-hold power spectrum over overlapping frames of ``fx``.

    The signal is cut into frames of ``nfft`` samples whose start positions
    advance by ``nfft - n_overlap`` samples. For every frame the value
    ``|rfft|**2 / nfft`` is computed, and the element-wise maximum across all
    frames is kept.

    Args:
        fx: 1D signal.
        sampling_freq: Sampling frequency, used to build the frequency axis.
        nfft (int): FFT length; also the frame length.
        n_overlap: Number of samples shared by consecutive frames
            (truncated to int).

    Returns:
        tuple: ``(freqs, spec_powers)`` where ``freqs`` is
        ``np.fft.rfftfreq(nfft, 1 / sampling_freq)`` and ``spec_powers`` holds
        the ``nfft // 2 + 1`` max-hold values.

    Raises:
        ValueError: If ``n_overlap`` is not smaller than ``nfft``.
    """
    n_overlap = int(n_overlap)
    hop = nfft - n_overlap
    if hop <= 0:
        # A non-positive hop would never advance ix and the loop below would
        # spin forever.
        raise ValueError(
            'n_overlap ({}) must be smaller than nfft ({})'.format(n_overlap, nfft))

    spec_powers = np.zeros(nfft // 2 + 1)
    ix = 0
    # '<=' is kept from the original so that an empty signal still yields one
    # (all-zero) frame.
    while ix <= len(fx):
        # Slicing truncates if end_idx > len, and rfft will auto zero pad
        fft_out = np.abs(np.fft.rfft(fx[ix:ix + nfft], nfft))
        spec_powers = np.maximum(spec_powers, fft_out**2 / nfft)
        ix += hop
    return np.fft.rfftfreq(nfft, 1 / sampling_freq), spec_powers
def calculate_entropy(x):
    """Estimate the entropy of ``x`` from a 100-bin density histogram.

    Returns:
        dict: ``{'entropy': value}`` where ``value`` is ``scipy.stats.entropy``
        applied to the histogram densities.
    """
    # todo: try approximate entropy
    # todo: try Kozachenko and Leonenko
    densities, _bin_edges = np.histogram(x, bins=100, density=True)
    result = {'entropy': entropy(densities)}
    return result
def dwt_features(x, wav='db4', level=4, mode='stats'):
    """Extract features from a multilevel discrete wavelet decomposition.

    Args:
        x: Input signal, passed to ``pywt.wavedec``.
        wav (str): Wavelet name (default ``'db4'``).
        level (int): Decomposition level (default 4).
        mode (str): ``'raw'`` returns the decomposition coefficients
            themselves, flattened in the order ``pywt.wavedec`` returns them.
            Any other value (default ``'stats'``) returns the statistics from
            ``get_features`` for every coefficient array.

    Returns:
        tuple: ``(features, labels, y0)``; ``features`` and ``labels`` are
        lists of equal length and ``y0`` is the first array returned by
        ``pywt.wavedec``.
    """
    y = pywt.wavedec(x, wav, level=level)

    features = []
    labels = []
    if mode == 'raw':
        # Previously this branch built a local list that was never returned and
        # left `features` / `labels` unbound, so the return statement raised.
        # Labels follow the 'L<level index>-<name>' pattern of the stats branch,
        # using 'c<position>' as the name of each coefficient.
        for i, coeffs in enumerate(y):
            for j, value in enumerate(coeffs):
                features.append(value)
                labels.append('L' + str(i) + '-c' + str(j))
    else:
        # NOTE(review): get_features() ends in calculate_statistics(), which
        # sorts its argument in place, so in this branch the y[0] returned
        # below has been sorted - confirm callers expect that.
        for i, coeffs in enumerate(y):
            for k, v in get_features(coeffs).items():
                features.append(v)
                labels.append('L' + str(i) + '-' + k)

    return features, labels, y[0]
def single_req(self, fn, body):
    """Run the DSP function on one sample and write the result as JSON.

    Args:
        self: Request handler; its ``send_response``, ``send_header``,
            ``end_headers`` and ``wfile.write`` are used to send the reply.
        fn: Callable invoked with keyword arguments built from ``body``; it
            must return a dict with a ``'features'`` entry.
        body (dict): Decoded request body. Requires a non-empty ``features``
            plus ``params``, ``sampling_freq``, ``draw_graphs``, ``axes`` and
            ``implementation_version``.

    Raises:
        ValueError: If a required entry is missing (or ``features`` is empty).
    """
    # .get() so that an absent key is reported as ValueError like every other
    # missing field, rather than surfacing as a bare KeyError.
    if not body.get('features'):
        raise ValueError('Missing "features" in body')
    for required in ('params', 'sampling_freq', 'draw_graphs',
                     'axes', 'implementation_version'):
        if required not in body:
            raise ValueError('Missing "' + required + '" in body')

    args = {
        'draw_graphs': body['draw_graphs'],
        'raw_data': np.array(body['features']),
        'axes': np.array(body['axes']),
        'sampling_freq': body['sampling_freq'],
        'implementation_version': body['implementation_version']
    }

    # Block-specific parameters are passed straight through as keyword arguments
    for param_key in body['params'].keys():
        args[param_key] = body['params'][param_key]

    processed = fn(**args)
    # numpy arrays are not JSON serializable; send a flat list instead
    if isinstance(processed['features'], np.ndarray):
        processed['features'] = processed['features'].flatten().tolist()

    payload = json.dumps(processed)

    self.send_response(200)
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(payload.encode())
def tflite_req(self, fn, body):
    """Call ``fn`` with the request's settings and send its bytes as a download.

    Args:
        self: Request handler; its ``send_response``, ``send_header``,
            ``end_headers`` and ``wfile.write`` are used to send the reply.
        fn: Callable invoked with keyword arguments built from ``body``; its
            return value is written to the response unchanged.
        body (dict): Decoded request body. ``params`` and ``sampling_freq``
            are checked explicitly; ``axes``, ``implementation_version`` and
            ``input_shape`` are read as well.

    Raises:
        ValueError: If ``params`` or ``sampling_freq`` is missing from ``body``.
    """
    if 'params' not in body:
        raise ValueError('Missing "params" in body')
    if 'sampling_freq' not in body:
        raise ValueError('Missing "sampling_freq" in body')

    call_kwargs = {}
    call_kwargs['axes'] = np.array(body['axes'])
    call_kwargs['sampling_freq'] = body['sampling_freq']
    call_kwargs['implementation_version'] = body['implementation_version']
    call_kwargs['input_shape'] = body['input_shape']

    # Block-specific parameters are passed straight through as keyword arguments
    for name, value in body['params'].items():
        call_kwargs[name] = value

    model_bytes = fn(**call_kwargs)

    self.send_response(200)
    response_headers = (
        ('Content-type', 'application/octet-stream'),
        ('Content-Disposition', 'attachment; filename="dsp.tflite"'),
    )
    for header_name, header_value in response_headers:
        self.send_header(header_name, header_value)
    self.end_headers()
    self.wfile.write(model_bytes)
def generate_features(implementation_version, draw_graphs, raw_data, axes, sampling_freq, scale_axes,
                      average, minimum, maximum, rms, stdev, skewness, kurtosis):
    """Reduce every axis of a window to the selected summary statistics.

    ``raw_data`` is multiplied by ``scale_axes`` and reshaped to
    ``(len(raw_data) / len(axes), len(axes))``, so each column is one axis.
    For every axis the enabled statistics are appended in the fixed order
    average, minimum, maximum, RMS, standard deviation, skewness, kurtosis.

    Args:
        implementation_version: Must be 1.
        draw_graphs: Not used by this block.
        raw_data: Flat numpy array of samples.
        axes: Sequence of axis names; only its length is used.
        sampling_freq: Not used by this block.
        scale_axes: Factor applied to ``raw_data``.
        average, minimum, maximum, rms, stdev, skewness, kurtosis: Truthy to
            include the corresponding statistic.

    Returns:
        dict: ``features`` (floats, grouped per axis), ``labels`` (one entry
        per enabled statistic), empty ``graphs`` and ``fft_used`` lists, and
        an ``output_config`` describing a flat output of ``len(features)``.

    Raises:
        Exception: If ``implementation_version`` is not 1.
    """
    if implementation_version != 1:
        raise Exception('implementation_version should be 1')

    axis_count = len(axes)
    scaled = raw_data * scale_axes
    samples = scaled.reshape(int(len(scaled) / axis_count), axis_count)

    # (enabled, label, reducer) in the order the features are emitted
    catalogue = (
        (average, 'Average', np.average),
        (minimum, 'Minimum', np.min),
        (maximum, 'Maximum', np.max),
        (rms, 'RMS', lambda column: np.sqrt(np.mean(np.square(column)))),
        (stdev, 'Stdev', np.std),
        (skewness, 'Skewness', skew),
        (kurtosis, 'Kurtosis', calculateKurtosis),
    )
    selected = [(label, reducer) for enabled, label, reducer in catalogue if enabled]

    # Rows of the transpose are the per-axis columns
    features = [float(reducer(column)) for column in samples.T for _, reducer in selected]
    labels = [label for label, _ in selected]

    output_config = {'type': 'flat', 'shape': {'width': len(features)}}
    return {
        'features': features,
        'graphs': [],
        'labels': labels,
        'fft_used': [],
        'output_config': output_config
    }
['true','1', 'yes']), default=True, 79 | help='calculate skewness (default: true)') 80 | parser.add_argument('--kurtosis', type=lambda x: (str(x).lower() in ['true','1', 'yes']), default=True, 81 | help='calculate kurtosis (default: true)') 82 | parser.add_argument('--draw-graphs', type=lambda x: (str(x).lower() in ['true','1', 'yes']), required=True, 83 | help='Whether to draw graphs') 84 | 85 | args = parser.parse_args() 86 | 87 | raw_features = np.array([float(item.strip()) for item in args.features.split(',')]) 88 | raw_axes = args.axes.split(',') 89 | 90 | try: 91 | processed = generate_features(1, args.draw_graphs, raw_features, raw_axes, args.frequency, args.scale_axes, 92 | args.average, args.minimum, args.maximum, args.rms, args.stdev, args.skewness, args.kurtosis) 93 | 94 | print('Begin output') 95 | print(json.dumps(processed)) 96 | print('End output') 97 | except Exception as e: 98 | print(e, file=sys.stderr) 99 | exit(1) 100 | -------------------------------------------------------------------------------- /flatten/parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": { 3 | "title": "Flatten", 4 | "author": "Edge Impulse", 5 | "description": "Flatten an axis into a single value, useful for slow-moving averages like temperature data, in combination with other blocks.", 6 | "name": "Flatten", 7 | "preferConvolution": false, 8 | "cppType": "flatten_custom", 9 | "experimental": false, 10 | "latestImplementationVersion": 1, 11 | "hasFeatureImportance": true 12 | }, 13 | "parameters": [ 14 | { 15 | "group": "Scaling", 16 | "items": [ 17 | { 18 | "name": "Scale axes", 19 | "value": 1, 20 | "type": "float", 21 | "help": "Multiply axes by this number", 22 | "param": "scale-axes" 23 | } 24 | ] 25 | }, 26 | { 27 | "group": "Method", 28 | "items": [ 29 | { 30 | "name": "Average", 31 | "value": true, 32 | "type": "boolean", 33 | "help": "Calculates the average value for the window", 34 | "param": 
"average" 35 | }, 36 | { 37 | "name": "Minimum", 38 | "value": true, 39 | "type": "boolean", 40 | "help": "Calculates the minimum value in the window", 41 | "param": "minimum" 42 | }, 43 | { 44 | "name": "Maximum", 45 | "value": true, 46 | "type": "boolean", 47 | "help": "Calculates the maximum value in the window", 48 | "param": "maximum" 49 | }, 50 | { 51 | "name": "Root-mean square", 52 | "value": true, 53 | "type": "boolean", 54 | "help": "Calculates the RMS value of the window", 55 | "param": "rms" 56 | }, 57 | { 58 | "name": "Standard deviation", 59 | "value": true, 60 | "type": "boolean", 61 | "help": "Calculates the standard deviation of the window", 62 | "param": "stdev" 63 | }, 64 | { 65 | "name": "Skewness", 66 | "value": true, 67 | "type": "boolean", 68 | "help": "Calculates the skewness of the window", 69 | "param": "skewness" 70 | }, 71 | { 72 | "name": "Kurtosis", 73 | "value": true, 74 | "type": "boolean", 75 | "help": "Calculates the kurtosis of the window", 76 | "param": "kurtosis" 77 | } 78 | ] 79 | } 80 | ] 81 | } 82 | -------------------------------------------------------------------------------- /flatten/requirements-blocks.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | librosa==0.8.0 3 | matplotlib==3.5.1 4 | numpy==1.21.5 5 | PeakUtils==1.3.2 6 | Pillow==9.0.1 7 | requests==2.22.0 8 | requests-oauthlib==1.3.0 9 | requests-unixsocket==0.2.0 10 | scikit-learn==1.3.0 11 | scipy==1.7.3 12 | sklearn==0.0 13 | urllib3==1.24.2 14 | PyWavelets==1.3.0 -------------------------------------------------------------------------------- /flatten/third_party/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgeimpulse/processing-blocks/ba8108d8427ede9d8098808f283d95e0f7d610b8/flatten/third_party/placeholder -------------------------------------------------------------------------------- /image/Dockerfile: 
-------------------------------------------------------------------------------- 1 | # syntax = docker/dockerfile:experimental@sha256:3c244c0c6fc9d6aa3ddb73af4264b3a23597523ac553294218c13735a2c6cf79 2 | FROM ubuntu:20.04 3 | 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | 6 | WORKDIR /app 7 | 8 | # python3 and all dependencies for scipy 9 | RUN apt update && apt install -y python3 python3-pip libatlas-base-dev gfortran-9 libfreetype6-dev wget && \ 10 | ln -s $(which gfortran-9) /usr/bin/gfortran 11 | 12 | # Update pip 13 | RUN pip3 install -U pip==22.0.3 14 | 15 | # Cython and scikit-learn - it needs to be done in this order for some reason 16 | RUN pip3 --no-cache-dir install Cython==0.29.24 17 | 18 | # Rest of the dependencies 19 | COPY requirements-blocks.txt ./ 20 | RUN pip3 --no-cache-dir install -r requirements-blocks.txt 21 | 22 | COPY third_party /third_party 23 | COPY . ./ 24 | 25 | EXPOSE 4446 26 | 27 | ENTRYPOINT ["python3", "-u", "dsp-server.py"] 28 | -------------------------------------------------------------------------------- /image/__init__.py: -------------------------------------------------------------------------------- 1 | from .dsp import generate_features 2 | -------------------------------------------------------------------------------- /image/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | sys.path.append('/') 4 | from common.sampling import calc_resampled_size, calculate_freq, Resampler 5 | 6 | 7 | class Dataset: 8 | '''Create an iterable dataset when x data is flattened, handling reshaping and resampling''' 9 | 10 | def __init__(self, X_all, metadata, axis, returns_interval=True, resample_interval_ms=None): 11 | self.ix = 0 12 | self.returns_interval = returns_interval 13 | self.max_len = 0 14 | 15 | X_all_shaped = [] 16 | y_all = [] 17 | self.y_label_set = set() 18 | intervals_all = [] 19 | current_offset = 0 20 | 21 | if resample_interval_ms: 22 
| self.fs = calculate_freq(resample_interval_ms) 23 | else: 24 | self.fs = None 25 | 26 | # Prepare for resampling data 27 | if resample_interval_ms is not None: 28 | resample_utility = Resampler(len(metadata)) 29 | target_freq = calculate_freq(resample_interval_ms) 30 | intervals_all.append(resample_interval_ms) 31 | 32 | # Reshape all samples 33 | for ix in range(len(metadata)): 34 | # Get x data using offset 35 | cur_len = metadata[ix] 36 | X_full = X_all[current_offset: current_offset + cur_len] 37 | current_offset = current_offset + cur_len 38 | 39 | # Split the interval, label from the features 40 | interval_ms = X_full[0] 41 | y = X_full[1] 42 | X = X_full[2:] 43 | 44 | if not self.fs: 45 | # if we didn't get a resampling rate from the caller, use the first sample's rate 46 | self.fs = calculate_freq(interval_ms) 47 | 48 | if not np.isnan(X).any(): 49 | # Reshape 50 | len_adjusted = cur_len - 2 51 | rows = int(len_adjusted / axis) 52 | # Data length is unexpected 53 | if not ((len_adjusted % axis) == 0): 54 | raise ValueError('Sample length is invalid, check the axis count.') 55 | 56 | X = np.reshape(X, (rows, axis)) 57 | 58 | # Resample data 59 | if resample_interval_ms is not None: 60 | # Work out the up and down factors using sample lengths 61 | original_length = X.shape[0] 62 | original_freq = calculate_freq(interval_ms) 63 | new_length = calc_resampled_size(original_freq, target_freq, original_length) 64 | 65 | # Resample 66 | X = resample_utility.resample(X, new_length, original_length) 67 | else: 68 | intervals_all.append(interval_ms) 69 | 70 | # Store the longest sample length 71 | self.max_len = max(self.max_len, X.shape[0]) 72 | X_all_shaped.append(X) 73 | y_all.append(y) 74 | self.y_label_set.add(y) 75 | 76 | self.X_all = X_all_shaped 77 | self.y_all = y_all 78 | self.intervals = intervals_all 79 | 80 | def reset(self): 81 | self.ix = 0 82 | 83 | def __iter__(self): 84 | return self 85 | 86 | def __next__(self): 87 | if self.ix >= len(self.y_all): 
def create_sgram_graph(sampling_freq, frame_length, frame_stride, width, height, power_spectrum, freqs=None):
    """Render a spectrogram and return it as a base64-encoded SVG string.

    Args:
        sampling_freq: Sampling frequency; used to build the default y-axis
            labels (0 .. sampling_freq / 2) when `freqs` is not given.
        frame_length: Frame length, forwarded to set_x_axis_times().
        frame_stride: Frame stride, forwarded to set_x_axis_times().
        width: Number of columns of the image, used to place the x ticks.
        height: Number of rows of the image, used to place the y ticks.
        power_spectrum: 2-D data handed to imshow().
        freqs: Optional sequence of frequencies used as y-axis labels.
            None or an empty sequence selects 15 evenly spaced defaults.

    Returns:
        str: The SVG file contents, base64-encoded as ASCII.
    """
    matplotlib.use('Svg')
    fig, ax = plt.subplots()
    try:
        # `not freqs` raised "truth value is ambiguous" for non-empty numpy
        # arrays; test for None/empty explicitly instead
        if freqs is None or len(freqs) == 0:
            freqs = np.linspace(0, sampling_freq / 2, 15)
        plt.ylabel('Frequency [Hz]')
        ax.imshow(power_spectrum, interpolation='nearest',
                  cmap=matplotlib.cm.coolwarm, origin='lower')
        plt.yticks(np.linspace(0, height, len(freqs)), [math.ceil(x) for x in freqs])
        set_x_axis_times(frame_stride, frame_length, width)

        with io.BytesIO() as buf:
            plt.savefig(buf, format='svg', bbox_inches='tight', pad_inches=0)
            return base64.b64encode(buf.getvalue()).decode('ascii')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each rendered graph stays in memory
        plt.close(fig)
def calc_decimation_ratios(filter_type, filter_cutoff, fs):
    """Choose the decimation ratio to use for a filtered signal.

    Only low-pass filtering ('low') allows decimation; every other filter
    type yields 1. For a low-pass filter the candidates are tried from the
    largest to the smallest and the first ratio r for which
    fs / 2 / r * 0.9 is still above `filter_cutoff` is returned. If no
    candidate qualifies the result is 1.
    """
    if filter_type == 'low':
        # we support base ratios of 3 and 10 in SDK; largest ratio first
        for candidate in (1000, 100, 30, 10, 3):
            if fs / 2 / candidate * 0.9 > filter_cutoff:
                return candidate

    return 1
def welch_max_hold(fx, sampling_freq, nfft, n_overlap):
    """Compute a max-hold power spectrum over overlapping segments of `fx`.

    The signal is cut into segments of `nfft` samples that start every
    `nfft - n_overlap` samples. Each segment is transformed with a real FFT
    (short segments are zero padded by rfft) and, per frequency bin, the
    largest value of |FFT|**2 / nfft over all segments is kept.

    Args:
        fx: 1-D sequence of samples.
        sampling_freq: Sampling frequency, used for the frequency axis.
        nfft: FFT length (integer).
        n_overlap: Number of samples shared by consecutive segments;
            converted with int().

    Returns:
        tuple: (frequencies from rfftfreq, max-hold powers), both arrays of
        length nfft // 2 + 1.

    Raises:
        ValueError: If n_overlap is not smaller than nfft. The segment start
            would never advance, which previously looped forever.
    """
    n_overlap = int(n_overlap)
    step = nfft - n_overlap
    if step <= 0:
        raise ValueError('n_overlap must be smaller than nfft')

    spec_powers = np.zeros(nfft // 2 + 1)
    # A segment starting exactly at len(fx) would be pure zero padding and can
    # never raise the maximum, so it is skipped.
    for start in range(0, len(fx), step):
        # Slicing truncates if end_idx > len, and rfft will auto zero pad
        fft_out = np.abs(np.fft.rfft(fx[start:start + nfft], nfft))
        spec_powers = np.maximum(spec_powers, fft_out ** 2 / nfft)

    return np.fft.rfftfreq(nfft, 1 / sampling_freq), spec_powers
def calculate_entropy(x):
    """Estimate the entropy of `x` from a 100-bin density histogram.

    The histogram densities are passed to scipy.stats.entropy and the result
    is returned as a dict with the single key 'entropy'.
    """
    # todo: try approximate entropy
    # todo: try Kozachenko and Leonenko
    density, _bin_edges = np.histogram(x, bins=100, density=True)
    result = {'entropy': entropy(density)}
    return result
def dwt_features(x, wav='db4', level=4, mode='stats'):
    """Extract features from a discrete wavelet decomposition of `x`.

    Args:
        x: Input signal, handed to pywt.wavedec.
        wav: Wavelet name.
        level: Decomposition level.
        mode: 'raw' returns every wavelet coefficient as a feature; any other
            value returns the get_features() statistics of each coefficient
            array.

    Returns:
        tuple: (features, labels, y[0]) where y is the list returned by
        pywt.wavedec. Labels are 'L<array index>-<statistic name>' in stats
        mode and 'L<array index>-<coefficient index>' in raw mode.
    """
    y = pywt.wavedec(x, wav, level=level)

    features = []
    labels = []
    if mode == 'raw':
        # Raw mode used to build a flat coefficient list and then return the
        # undefined names `features`/`labels` (NameError). The coefficients
        # are now the features, each with a positional label.
        for i, coeffs in enumerate(y):
            for j, value in enumerate(coeffs):
                features.append(value)
                labels.append('L' + str(i) + '-' + str(j))
    else:
        for i, coeffs in enumerate(y):
            for k, v in get_features(coeffs).items():
                features.append(v)
                labels.append('L' + str(i) + '-' + k)

    # NOTE(review): in stats mode get_features() -> calculate_statistics()
    # sorts each coefficient array in place, so the y[0] returned here is
    # sorted rather than in signal order - confirm callers expect that.
    return features, labels, y[0]
def single_req(self, fn, body):
    """Run the DSP function on one sample and write the JSON response.

    Args:
        self: Request handler; provides send_response, send_header,
            end_headers and wfile.
        fn: DSP function, called with keyword arguments built from `body`.
        body: Decoded JSON request. Requires 'features', 'params',
            'sampling_freq', 'draw_graphs', 'axes' and
            'implementation_version'.

    Raises:
        ValueError: If a required entry is missing or 'features' is empty.
    """
    # .get() so that an absent key produces the intended message instead of
    # a bare KeyError
    if not body.get('features'):
        raise ValueError('Missing "features" in body')
    for required in ('params', 'sampling_freq', 'draw_graphs', 'axes', 'implementation_version'):
        if required not in body:
            raise ValueError('Missing "' + required + '" in body')

    args = {
        'draw_graphs': body['draw_graphs'],
        'raw_data': np.array(body['features']),
        'axes': np.array(body['axes']),
        'sampling_freq': body['sampling_freq'],
        'implementation_version': body['implementation_version']
    }
    # Block-specific parameters are passed straight through as keyword args
    args.update(body['params'])

    processed = fn(**args)
    # numpy arrays are not JSON serialisable
    if isinstance(processed['features'], np.ndarray):
        processed['features'] = processed['features'].flatten().tolist()

    payload = json.dumps(processed)

    self.send_response(200)
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(payload.encode())
def tflite_req(self, fn, body):
    """Call `fn` to build a TFLite model and send it as a file download.

    Args:
        self: Request handler; provides send_response, send_header,
            end_headers and wfile.
        fn: Function called with keyword arguments built from `body`; its
            return value is written to the response as-is.
        body: Decoded JSON request. Requires 'params', 'sampling_freq',
            'axes', 'implementation_version' and 'input_shape'.

    Raises:
        ValueError: If a required entry is missing from `body`.
    """
    # Validate everything that is read below so the error names the problem
    # instead of surfacing as a bare KeyError
    for required in ('params', 'sampling_freq', 'axes', 'implementation_version', 'input_shape'):
        if required not in body:
            raise ValueError('Missing "' + required + '" in body')

    args = {
        'axes': np.array(body['axes']),
        'sampling_freq': body['sampling_freq'],
        'implementation_version': body['implementation_version'],
        'input_shape': body['input_shape']
    }
    # Block-specific parameters are passed straight through as keyword args
    args.update(body['params'])

    tflite_byte_arr = fn(**args)

    self.send_response(200)
    self.send_header('Content-type', 'application/octet-stream')
    self.send_header('Content-Disposition', 'attachment; filename="dsp.tflite"')
    self.end_headers()
    self.wfile.write(tflite_byte_arr)
def generate_features(implementation_version, draw_graphs, raw_data, axes, sampling_freq, channels):
    """Convert packed image frames into normalised pixel features.

    Args:
        implementation_version: Must be 1; any other value is rejected.
        draw_graphs: When true, a PNG preview of every frame is returned.
        raw_data: 1-D numpy array: width, height, then one value per pixel
            whose bits 16-23, 8-15 and 0-7 hold red, green and blue.
        axes: Accepted for interface compatibility; not used here.
        sampling_freq: Accepted for interface compatibility; not used here.
        channels: 'Grayscale' emits one luma value per pixel; any other value
            emits red, green, blue per pixel.

    Returns:
        dict with 'features' (list of floats scaled by 1/255), 'graphs',
        'output_config' (image shape and frame count) and 'fft_used' (empty).

    Raises:
        Exception: If implementation_version is not 1.
        ValueError: If width or height is not positive.
    """
    if implementation_version != 1:
        raise Exception('implementation_version should be 1')

    raw_data = np.asarray(raw_data)
    # int() so that float input (e.g. values parsed from the command line)
    # yields usable sizes
    width = int(raw_data[0])
    height = int(raw_data[1])
    pixels_per_image = width * height
    if pixels_per_image <= 0:
        raise ValueError('Image width and height must be positive')

    # Unpack the colour components with shifts: independent of host byte order
    packed = raw_data[2:].astype(dtype=np.uint32)
    red = ((packed >> 16) & 0xFF).astype(np.float64)
    green = ((packed >> 8) & 0xFF).astype(np.float64)
    blue = (packed & 0xFF).astype(np.float64)

    grayscale = channels == 'Grayscale'
    if grayscale:
        # ITU-R 601-2 luma transform
        # see: https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.convert
        values = (0.299 / 255.0) * red + (0.587 / 255.0) * green + (0.114 / 255.0) * blue
        values_per_pixel = 1
    else:
        # Interleave as r, g, b per pixel
        values = np.stack((red, green, blue), axis=1).reshape(-1) / 255.0
        values_per_pixel = 3

    # A trailing partial frame still counts as a frame (ceiling division)
    frame_count = -(-len(packed) // pixels_per_image)
    expected_frame_count = len(packed) / pixels_per_image

    graphs = []
    if draw_graphs:
        values_per_frame = pixels_per_image * values_per_pixel
        for frame_ix in range(frame_count):
            frame = values[frame_ix * values_per_frame:(frame_ix + 1) * values_per_frame]
            if grayscale:
                im = Image.fromarray(np.uint8(frame * 255.0).reshape(height, width), mode='L')
            else:
                im = Image.fromarray(np.uint8(frame * 255.0).reshape(height, width, 3), mode='RGB')
            im = im.convert(mode='RGBA')
            with io.BytesIO() as buf:
                im.save(buf, format='PNG')
                image = base64.b64encode(buf.getvalue()).decode('ascii')

            name = 'Image'
            if expected_frame_count > 1:
                name = 'Frame ' + str(frame_ix + 1)

            graphs.append({
                'name': name,
                'image': image,
                'imageMimeType': 'image/png',
                'type': 'image'
            })

    num_channels = 1
    if channels == 'RGB':
        num_channels = 3

    image_config = { 'width': width, 'height': height, 'channels': num_channels, 'frames': frame_count }
    output_config = { 'type': 'image', 'shape': image_config }

    return { 'features': values.tolist(), 'graphs': graphs, 'output_config': output_config, 'fft_used': [] }
parser.add_argument('--draw-graphs', type=bool, required=True, 92 | help='Whether to draw graphs') 93 | 94 | args = parser.parse_args() 95 | 96 | raw_features = np.array([float(item.strip()) for item in args.features.split(',')]) 97 | raw_axes = args.axes.split(',') 98 | 99 | try: 100 | processed = generate_features(1, False, raw_features, args.axes, args.frequency, args.channels) 101 | 102 | print('Begin output') 103 | # print(json.dumps(processed)) 104 | print('End output') 105 | except Exception as e: 106 | print(e, file=sys.stderr) 107 | exit(1) 108 | -------------------------------------------------------------------------------- /image/parameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": { 3 | "title": "Image", 4 | "author": "Edge Impulse", 5 | "description": "Preprocess and normalize image data, and optionally reduce the color depth.", 6 | "name": "Image", 7 | "cppType": "image_custom", 8 | "visualization": "dimensionalityReduction", 9 | "experimental": false, 10 | "preferConvolution": false, 11 | "latestImplementationVersion": 1 12 | }, 13 | "parameters": [ 14 | { 15 | "group": "Image", 16 | "items": [ 17 | { 18 | "name": "Color depth", 19 | "value": "RGB", 20 | "help": "Color depth to use", 21 | "type": "select", 22 | "valid": [ "RGB", "Grayscale" ], 23 | "param": "channels" 24 | } 25 | ] 26 | } 27 | ] 28 | } 29 | -------------------------------------------------------------------------------- /image/requirements-blocks.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | librosa==0.8.0 3 | matplotlib==3.5.1 4 | numpy==1.21.5 5 | PeakUtils==1.3.2 6 | Pillow==9.0.1 7 | requests==2.22.0 8 | requests-oauthlib==1.3.0 9 | requests-unixsocket==0.2.0 10 | scikit-learn==1.3.0 11 | scipy==1.7.3 12 | sklearn==0.0 13 | urllib3==1.24.2 14 | PyWavelets==1.3.0 -------------------------------------------------------------------------------- 
/image/third_party/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgeimpulse/processing-blocks/ba8108d8427ede9d8098808f283d95e0f7d610b8/image/third_party/placeholder -------------------------------------------------------------------------------- /mfcc/Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax = docker/dockerfile:experimental@sha256:3c244c0c6fc9d6aa3ddb73af4264b3a23597523ac553294218c13735a2c6cf79 2 | FROM ubuntu:20.04 3 | 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | 6 | WORKDIR /app 7 | 8 | # python3 and all dependencies for scipy 9 | RUN apt update && apt install -y python3 python3-pip libatlas-base-dev gfortran-9 libfreetype6-dev wget && \ 10 | ln -s $(which gfortran-9) /usr/bin/gfortran 11 | 12 | # Update pip 13 | RUN pip3 install -U pip==22.0.3 14 | 15 | # Cython and scikit-learn - it needs to be done in this order for some reason 16 | RUN pip3 --no-cache-dir install Cython==0.29.24 17 | 18 | # Rest of the dependencies 19 | COPY requirements-blocks.txt ./ 20 | RUN pip3 --no-cache-dir install -r requirements-blocks.txt 21 | 22 | COPY third_party /third_party 23 | COPY . ./ 24 | 25 | EXPOSE 4446 26 | 27 | ENTRYPOINT ["python3", "-u", "dsp-server.py"] 28 | -------------------------------------------------------------------------------- /mfcc/README.md: -------------------------------------------------------------------------------- 1 | # Notes on parameters 2 | 3 | Summary: Number of coefficients, filter number (really should say “Number of mel filters”) and Window size are worth playing with 4 | 5 | Number of coefficients 6 | 6 ( “Six coefficients succeed in capturing most of the relevant information. 
The importance of the higher cepstrum coefficients appears to depend on the speaker” (1) ) 7 | 13, I see this default a lot ( also, “A set of 10 mel-frequency cepstrum coefficients computed…” (1) ) 8 | Cannot be larger than Filter number! 9 | 10 | Frame length 11 | 0.024 (24 ms) 12 | Related to the max length of a syllable in speech 13 | Not much wiggle room here. (3) mentions that if segments are too long, the sound will change too much from start to finish. If segments are too short, there will not be enough of the signal to get useful information. 14 | Was used in (1). 15 | Other sources mention 10-20 ms as ideal 16 | Set it and forget it 17 | 18 | Frame stride 19 | 0.006 (6 ms) (6.4 ms was used in (1) ) 20 | Let’s set it and forget it. 21 | 22 | Filter number 23 | 20, 24 | 26, 25 | 40 26 | (see (4)) 27 | 28 | FFT size 29 | Let frame_size = Frame_length * sample_rate 30 | Choose FFT size to be the smallest power of 2 >= frame_size 31 | 32 | Window size ( need to rename to “Normalization window size” ( # of cepstral frames ) ) 33 | 300 34 | 100 35 | 36 | Low frequency: 37 | 0, set it and forget it.
class Dataset:
    '''Iterable dataset built from one flat buffer of samples.

    Sample ``i`` occupies ``metadata[i]`` consecutive entries of ``X_all``:
    the sampling interval (ms), the label, then the values, which are
    reshaped to ``(rows, axis)``. Samples whose values contain NaN are
    dropped. When ``resample_interval_ms`` is given every kept sample is
    resampled to that interval.
    '''

    def __init__(self, X_all, metadata, axis, returns_interval=True, resample_interval_ms=None):
        self.ix = 0
        self.returns_interval = returns_interval
        self.max_len = 0
        self.y_label_set = set()

        # A truthy target interval fixes the dataset rate up front; otherwise
        # it is taken from the first sample seen in the loop below.
        self.fs = calculate_freq(resample_interval_ms) if resample_interval_ms else None

        samples = []
        labels = []
        intervals = []

        resampling = resample_interval_ms is not None
        if resampling:
            resampler = Resampler(len(metadata))
            target_freq = calculate_freq(resample_interval_ms)
            # All resampled samples share one interval, so it is stored once
            intervals.append(resample_interval_ms)

        start = 0
        for sample_len in metadata:
            # Cut this sample out of the flat buffer
            stop = start + sample_len
            packed = X_all[start:stop]
            start = stop

            # Layout: [interval_ms, label, value, value, ...]
            interval_ms = packed[0]
            y = packed[1]
            X = packed[2:]

            if not self.fs:
                # No resampling rate from the caller: use the first sample's rate
                self.fs = calculate_freq(interval_ms)

            if np.isnan(X).any():
                # Samples containing NaN are left out entirely
                continue

            value_count = sample_len - 2
            rows = int(value_count / axis)
            if value_count % axis != 0:
                raise ValueError('Sample length is invalid, check the axis count.')
            X = np.reshape(X, (rows, axis))

            if resampling:
                # Derive the new length from the ratio of the two rates
                source_len = X.shape[0]
                source_freq = calculate_freq(interval_ms)
                resampled_len = calc_resampled_size(source_freq, target_freq, source_len)
                X = resampler.resample(X, resampled_len, source_len)
            else:
                intervals.append(interval_ms)

            # Track the longest sample (in rows)
            self.max_len = max(self.max_len, X.shape[0])
            samples.append(X)
            labels.append(y)
            self.y_label_set.add(y)

        self.X_all = samples
        self.y_all = labels
        self.intervals = intervals

    def reset(self):
        '''Rewind the iterator to the first sample.'''
        self.ix = 0

    def __iter__(self):
        return self

    def __next__(self):
        '''Return ``(X, y, interval_ms)`` or ``(X, y)`` for the next sample.

        The position is rewound before StopIteration is raised, so the
        dataset can be iterated again.
        '''
        if self.ix >= len(self.y_all):
            self.reset()
            raise StopIteration

        position = self.ix
        X = self.X_all[position]
        y = self.y_all[position]
        # A single stored interval applies to every sample (resampled data
        # stores its shared interval only once)
        if len(self.intervals) == 1:
            interval_ms = self.intervals[0]
        else:
            interval_ms = self.intervals[position]

        self.ix += 1

        if self.returns_interval:
            return X, y, interval_ms
        return X, y
def log(*msg, level='warn'):
    """Print one JSON log record to stdout.

    The record has three keys: ``msg`` (all positional arguments converted
    with ``str`` and joined by single spaces), ``level`` and ``time``.

    :param msg: Values to include in the message.
    :param level: Severity label stored verbatim in the record.
    """
    msg_clean = ' '.join(str(part) for part in msg)
    # The trailing 'Z' designates UTC, so the timestamp has to be taken in
    # UTC; the previous naive local time was mislabelled on any host whose
    # timezone is not UTC. tzinfo is stripped so isoformat() does not emit
    # '+00:00' in front of the 'Z'.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0, tzinfo=None)
    print(json.dumps({
        'msg': msg_clean,
        'level': level,
        'time': now_utc.isoformat() + 'Z',
    }))
def get_ratio_combo(r):
    """Split a supported decimation ratio into its per-stage ratios.

    1, 3 and 10 are returned as a single stage; 30, 100 and 1000 are
    expressed as products of 3 and 10.

    :param r: Overall decimation ratio.
    :returns: List of stage ratios whose product is ``r``.
    :raises ValueError: If ``r`` is not one of the supported ratios.
    """
    if r in (3, 10):
        return [r]

    multi_stage = (
        (1, [1]),
        (30, [3, 10]),
        (100, [10, 10]),
        (1000, [10, 10, 10]),
    )
    for ratio, stages in multi_stage:
        if r == ratio:
            # Hand out a fresh list so callers can mutate the result
            return list(stages)

    raise ValueError("Invalid decimation ratio: {}".format(r))
class Resampler:
    """ Resamples windows one at a time and prints progress to stdout
    """

    def __init__(self, total_samples):
        self.total_samples = total_samples
        self.ix = 0
        # Timestamp (ms) of the last progress message; 0 means none yet
        self.last_message = 0

    def resample(self, sample, new_length, original_length):
        """ Resample one window from original_length to new_length

        The window is returned untouched when both lengths are equal.
        A progress line is printed when at least 3 seconds have passed
        since the previous one, and always for the final window.
        """
        # A window with a single row is resampled along axis 1, anything else along axis 0
        axis = 1 if sample.shape[0] == 1 else 0

        if new_length != original_length:
            sample = signal.resample_poly(
                sample, new_length, original_length, axis=axis)

        self.ix += 1
        now_ms = int(round(time.time() * 1000))
        is_last = self.ix == self.total_samples
        if (now_ms - self.last_message >= 3000) or is_last:
            # Right-align the counter to the width of the total
            counter = str(self.ix).rjust(len(str(self.total_samples)), ' ')
            print('[%s/%d] Resampling windows...' % (counter, self.total_samples))

            if is_last:
                print('Resampled %d windows\n' % self.total_samples)

            sys.stdout.flush()
            self.last_message = int(round(time.time() * 1000))

        return sample
def get_percentile_from_sorted(array, percentile):
    """Pick the element at the given percentile of an already sorted sequence.

    The fractional position ``(len - 1) * percentile / 100`` is rounded to the
    nearest index (halves round up) and that element is returned as-is; no
    interpolation is done.
    """
    last_index = len(array) - 1
    # Adding 0.5 before truncating gives round-half-up, mirroring a C-style cast
    position = (last_index * percentile / 100) + 0.5
    return array[int(position)]
def dwt_features(x, wav='db4', level=4, mode='stats'):
    """Compute features from a multilevel discrete wavelet decomposition.

    :param x: Input signal passed to ``pywt.wavedec``.
    :param wav: Wavelet name.
    :param level: Decomposition level.
    :param mode: ``'raw'`` returns every coefficient of every band as a
        feature; any other value returns the ``get_features`` statistics of
        each band.
    :returns: ``(features, labels, y[0])`` where ``features`` and ``labels``
        are parallel lists and ``y[0]`` is the first array returned by
        ``pywt.wavedec``. Labels are ``'L<band>-<name>'``; in raw mode
        ``<name>`` is the coefficient's index within its band.
    """
    y = pywt.wavedec(x, wav, level=level)

    features = []
    labels = []
    if mode == 'raw':
        # Previously this branch only built a local list and then fell through
        # to a return of undefined names, so raw mode always raised
        # UnboundLocalError. Emit the flattened coefficients with the same
        # (features, labels) shape as the stats mode instead.
        for i, coeffs in enumerate(y):
            for j, value in enumerate(coeffs):
                features.append(value)
                labels.append('L' + str(i) + '-' + str(j))
    else:
        for i, coeffs in enumerate(y):
            for k, v in get_features(coeffs).items():
                features.append(v)
                labels.append('L' + str(i) + '-' + k)

    # NOTE(review): in stats mode get_features() ends in calculate_statistics(),
    # which sorts its argument in place, so y[0] is returned sorted here -
    # confirm whether callers rely on that before changing it.
    return features, labels, y[0]
def batch_req(self, fn, body):
    """Run ``fn`` once per example in ``body['features']`` and send the combined JSON response.

    :param self: The request handler used to write the HTTP response.
    :param fn: Feature function; called with ``raw_data``, ``axes``,
        ``sampling_freq``, ``implementation_version``, ``draw_graphs=False``
        and every entry of ``body['params']`` as keyword arguments. It must
        return a dict with a ``'features'`` entry.
    :param body: Decoded JSON request body.
    :raises ValueError: If a required key is missing from ``body`` or
        ``features`` is empty.
    """
    # .get() so that an absent key produces the intended message instead of a
    # bare KeyError
    if not body.get('features'):
        raise ValueError('Missing "features" in body')
    # 'axes' and 'implementation_version' are read unconditionally below, so
    # they are validated the same way as the other required keys
    for required in ('params', 'sampling_freq', 'axes', 'implementation_version'):
        if required not in body:
            raise ValueError('Missing "%s" in body' % required)

    base_args = {
        'draw_graphs': False,
        'axes': np.array(body['axes']),
        'sampling_freq': body['sampling_freq'],
        'implementation_version': body['implementation_version'],
    }
    # Block-specific parameters are passed straight through as keyword arguments
    base_args.update(body['params'])

    features = []
    labels = []
    output_config = None

    for index, example in enumerate(body['features']):
        args = dict(base_args)
        args['raw_data'] = np.array(example)
        f = fn(**args)
        if isinstance(f['features'], np.ndarray):
            # ndarrays are not JSON serializable
            features.append(f['features'].flatten().tolist())
        else:
            features.append(f['features'])

        # Labels and output config are taken from the first example only
        if index == 0:
            if 'labels' in f:
                labels = f['labels']
            if 'output_config' in f:
                output_config = f['output_config']

    payload = json.dumps({
        'success': True,
        'features': features,
        'labels': labels,
        'output_config': output_config
    })

    self.send_response(200)
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(payload.encode())
class Handler(BaseHTTPRequestHandler):
    """HTTP front end of the DSP block.

    GET  /            -> plain-text banner built from parameters.json
    GET  /parameters  -> parameters.json with 'version' set to 1
    POST /run         -> single_req
    POST /batch       -> batch_req
    Any other path answers 404. An exception raised while handling a POST is
    reported as a 200 response carrying {'success': False, 'error': ...}.
    """

    def _begin(self, status, content_type):
        # Status line plus the single Content-Type header every response uses
        self.send_response(status)
        self.send_header('Content-Type', content_type)
        self.end_headers()

    def _not_found(self):
        self._begin(404, 'text/plain')
        self.wfile.write(b'Invalid path ' + self.path.encode() + b'\n')

    def _read_json(self):
        # Body length comes from the Content-Length request header
        content_len = int(self.headers.get('Content-Length'))
        post_body = self.rfile.read(content_len)
        return json.loads(post_body.decode('utf-8'))

    def do_GET(self):
        url = urlparse(self.path)
        # Parameters are loaded before routing, for every GET
        params = get_params(self)

        if url.path == '/':
            self._begin(200, 'text/plain')
            banner = 'Edge Impulse DSP block: ' + params['info']['title'] + ' by ' + params['info']['author']
            self.wfile.write(banner.encode())
            return

        if url.path == '/parameters':
            self._begin(200, 'application/json')
            params['version'] = 1
            self.wfile.write(json.dumps(params).encode())
            return

        self._not_found()

    def do_POST(self):
        url = urlparse(self.path)
        try:
            if url.path in ('/run', '/batch'):
                body = self._read_json()
                if url.path == '/run':
                    single_req(self, generate_features, body)
                else:
                    batch_req(self, generate_features, body)
            else:
                self._not_found()

        except Exception as e:
            print('Failed to handle request', e, traceback.format_exc())
            # Failures are reported in the JSON payload; the status stays 200
            self._begin(200, 'application/json')
            self.wfile.write(json.dumps({ 'success': False, 'error': str(e) }).encode())

    def log_message(self, format, *args):
        # Silence the default per-request logging
        return
"value": 32, 44 | "type": "int", 45 | "help": "The number of filters in the filterbank", 46 | "param": "num_filters" 47 | }, 48 | { 49 | "name": "FFT length", 50 | "value": 256, 51 | "type": "int", 52 | "help": "Number of FFT points", 53 | "param": "fft_length" 54 | }, 55 | { 56 | "name": "Normalization window size", 57 | "value": 101, 58 | "type": "int", 59 | "help": "The size of sliding window for local normalization. Set this to 0 to disable normalization.", 60 | "param": "win_size" 61 | }, 62 | { 63 | "name": "Low frequency", 64 | "value": 0, 65 | "type": "int", 66 | "help": "Lowest band edge of mel filters (in Hz)", 67 | "param": "low_frequency" 68 | }, 69 | { 70 | "name": "High frequency", 71 | "value": 0, 72 | "type": "int", 73 | "help": "Highest band edge of mel filters (in Hz). If not set (or set to 0) this is samplerate / 2", 74 | "param": "high_frequency", 75 | "optional": true 76 | } 77 | ] 78 | }, 79 | { 80 | "group": "Pre-emphasis", 81 | "items": [ 82 | { 83 | "name": "Coefficient", 84 | "value": 0.98, 85 | "type": "float", 86 | "help": "The pre-emphasizing coefficient to apply to the input signal (0 equals to no filtering)", 87 | "param": "pre_cof" 88 | }, 89 | { 90 | "name": "Shift", 91 | "value": 1, 92 | "type": "int", 93 | "help": "The pre-emphasis shift value to roll over the input signal", 94 | "param": "pre_shift", 95 | "showForImplementationVersion": [ 1, 2 ] 96 | } 97 | ] 98 | } 99 | ] 100 | } 101 | -------------------------------------------------------------------------------- /mfcc/requirements-blocks.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | librosa==0.8.0 3 | matplotlib==3.5.1 4 | numpy==1.21.5 5 | PeakUtils==1.3.2 6 | Pillow==9.0.1 7 | requests==2.22.0 8 | requests-oauthlib==1.3.0 9 | requests-unixsocket==0.2.0 10 | scikit-learn==1.3.0 11 | scipy==1.7.3 12 | sklearn==0.0 13 | urllib3==1.24.2 14 | PyWavelets==1.3.0 
def frequency_to_mel(f):
    """Convert frequency in Hz to the Mel scale.

    :param f: A single frequency or an array of frequencies, in Hz.
    :returns: The matching Mel value(s), ``1127 * ln(1 + f / 700)``.
    """
    ratio = 1 + f / 700.
    return 1127 * np.log(ratio)
def triangle(x, left, middle, right):
    """Evaluate a triangular filter on the points ``x``.

    The response rises linearly from 0 at ``left`` to 1 at ``middle`` and
    falls back to 0 at ``right``; it is 0 everywhere outside ``(left, right)``.

    :param x: Array of sample positions.
    :returns: Float array with the shape of ``x``.
    """
    # Everything starts at zero, which already covers x <= left and x >= right
    out = np.zeros(x.shape)

    rising = np.logical_and(x > left, x <= middle)
    out[rising] = (x[rising] - left) / (middle - left)

    # Written second, so this slope decides the value at x == middle
    falling = np.logical_and(x >= middle, x < right)
    out[falling] = (right - x[falling]) / (right - middle)

    return out
class Dataset:
    '''Iterable dataset built from one flat buffer of samples.

    Sample ``i`` occupies ``metadata[i]`` consecutive entries of ``X_all``:
    the sampling interval (ms), the label, then the values, which are
    reshaped to ``(rows, axis)``. Samples whose values contain NaN are
    dropped. When ``resample_interval_ms`` is given every kept sample is
    resampled to that interval.
    '''

    def __init__(self, X_all, metadata, axis, returns_interval=True, resample_interval_ms=None):
        self.ix = 0
        self.returns_interval = returns_interval
        self.max_len = 0
        self.y_label_set = set()

        # A truthy target interval fixes the dataset rate up front; otherwise
        # it is taken from the first sample seen in the loop below.
        self.fs = calculate_freq(resample_interval_ms) if resample_interval_ms else None

        samples = []
        labels = []
        intervals = []

        resampling = resample_interval_ms is not None
        if resampling:
            resampler = Resampler(len(metadata))
            target_freq = calculate_freq(resample_interval_ms)
            # All resampled samples share one interval, so it is stored once
            intervals.append(resample_interval_ms)

        start = 0
        for sample_len in metadata:
            # Cut this sample out of the flat buffer
            stop = start + sample_len
            packed = X_all[start:stop]
            start = stop

            # Layout: [interval_ms, label, value, value, ...]
            interval_ms = packed[0]
            y = packed[1]
            X = packed[2:]

            if not self.fs:
                # No resampling rate from the caller: use the first sample's rate
                self.fs = calculate_freq(interval_ms)

            if np.isnan(X).any():
                # Samples containing NaN are left out entirely
                continue

            value_count = sample_len - 2
            rows = int(value_count / axis)
            if value_count % axis != 0:
                raise ValueError('Sample length is invalid, check the axis count.')
            X = np.reshape(X, (rows, axis))

            if resampling:
                # Derive the new length from the ratio of the two rates
                source_len = X.shape[0]
                source_freq = calculate_freq(interval_ms)
                resampled_len = calc_resampled_size(source_freq, target_freq, source_len)
                X = resampler.resample(X, resampled_len, source_len)
            else:
                intervals.append(interval_ms)

            # Track the longest sample (in rows)
            self.max_len = max(self.max_len, X.shape[0])
            samples.append(X)
            labels.append(y)
            self.y_label_set.add(y)

        self.X_all = samples
        self.y_all = labels
        self.intervals = intervals

    def reset(self):
        '''Rewind the iterator to the first sample.'''
        self.ix = 0

    def __iter__(self):
        return self

    def __next__(self):
        '''Return ``(X, y, interval_ms)`` or ``(X, y)`` for the next sample.

        The position is rewound before StopIteration is raised, so the
        dataset can be iterated again.
        '''
        if self.ix >= len(self.y_all):
            self.reset()
            raise StopIteration

        position = self.ix
        X = self.X_all[position]
        y = self.y_all[position]
        # A single stored interval applies to every sample (resampled data
        # stores its shared interval only once)
        if len(self.intervals) == 1:
            interval_ms = self.intervals[0]
        else:
            interval_ms = self.intervals[position]

        self.ix += 1

        if self.returns_interval:
            return X, y, interval_ms
        return X, y
def create_sgram_graph(sampling_freq, frame_length, frame_stride, width, height, power_spectrum, freqs=None):
    """Render a spectrogram and return it as a base64-encoded SVG string.

    Args:
        sampling_freq: Sampling frequency; only used to build the default
            frequency labels (15 evenly spaced values from 0 to sampling_freq / 2).
        frame_length: Frame length, forwarded to set_x_axis_times for the time axis.
        frame_stride: Frame stride, forwarded to set_x_axis_times for the time axis.
        width: Extent over which the x ticks are spread.
        height: Extent over which the y ticks are spread.
        power_spectrum: 2D data handed to imshow.
        freqs: Optional sequence of y axis labels (list or numpy array). When
            None or empty, the default labels described above are used.

    Returns:
        str: The SVG image, base64 encoded as ASCII text.
    """
    matplotlib.use('Svg')
    fig, ax = plt.subplots()
    try:
        # `not freqs` raises for multi-element numpy arrays, so test explicitly
        if freqs is None or len(freqs) == 0:
            freqs = np.linspace(0, sampling_freq / 2, 15)
        plt.ylabel('Frequency [Hz]')
        ax.imshow(power_spectrum, interpolation='nearest',
                  cmap=matplotlib.cm.coolwarm, origin='lower')
        plt.yticks(np.linspace(0, height, len(freqs)), [math.ceil(x) for x in freqs])
        set_x_axis_times(frame_stride, frame_length, width)

        with io.BytesIO() as buf:
            plt.savefig(buf, format='svg', bbox_inches='tight', pad_inches=0)
            return base64.b64encode(buf.getvalue()).decode('ascii')
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call leaves one more figure behind
        plt.close(fig)
def calc_decimation_ratios(filter_type, filter_cutoff, fs):
    """Pick the largest supported decimation ratio for a low-pass filter.

    Only 'low' filters are decimated; any other filter type yields 1. A ratio
    qualifies when 90% of the decimated Nyquist frequency (fs / 2 / ratio)
    is still above the filter cutoff. Returns 1 when no ratio qualifies.
    """
    if filter_type != 'low':
        return 1

    # we support base ratios of 3 and 10 in SDK; try the largest combination first
    candidates = (1000, 100, 30, 10, 3)
    return next((ratio for ratio in candidates
                 if fs / 2 / ratio * 0.9 > filter_cutoff), 1)
def decimate_simple(x, ratio, export=False):
    """Low-pass filter a 1D signal and keep every `ratio`-th sample.

    Args:
        x: 1D numpy array to decimate.
        ratio: Decimation ratio. A ratio of 1 returns the signal unfiltered.
        export: When True (and ratio != 1), also return the filter that was used.

    Returns:
        The decimated signal, or ``(signal, sos, zi)`` when `export` is set and
        filtering took place.

    Raises:
        ValueError: If `x` is not one dimensional.
    """
    if x.ndim != 1:
        raise ValueError(f'x must be 1D {x.shape}')

    flat = x.reshape(x.shape[0])
    if ratio == 1:
        return flat

    sos, zi = create_decimate_filter(ratio)
    # Scale the initial filter state by the first sample to start the filter settled
    filtered, _ = signal.sosfilt(sos, flat, zi=zi * flat[0])
    decimated = filtered[::ratio]

    return (decimated, sos, zi) if export else decimated
def welch_max_hold(fx, sampling_freq, nfft, n_overlap):
    """Compute a max-hold power spectrum over overlapping FFT windows.

    The signal is cut into windows of `nfft` samples that start every
    `nfft - n_overlap` samples. For each frequency bin the largest power
    (|FFT|^2 / nfft) seen in any window is kept.

    Args:
        fx: 1D signal.
        sampling_freq: Sampling frequency, used to label the frequency bins.
        nfft: FFT length; shorter trailing windows are zero padded to this length.
        n_overlap: Number of samples shared by consecutive windows (cast to int).

    Returns:
        tuple: (frequencies, powers), each with nfft // 2 + 1 entries.

    Raises:
        ValueError: If n_overlap is not smaller than nfft; the window would
            never advance.
    """
    n_overlap = int(n_overlap)
    step = nfft - n_overlap
    if step <= 0:
        raise ValueError(
            f'n_overlap ({n_overlap}) must be smaller than nfft ({nfft})')

    spec_powers = np.zeros(nfft // 2 + 1)
    # Stop before len(fx): a window starting there holds only zero padding and
    # cannot raise any maximum
    for ix in range(0, len(fx), step):
        # Slicing truncates if end_idx > len, and rfft will auto zero pad
        fft_out = np.abs(np.fft.rfft(fx[ix:ix + nfft], nfft))
        spec_powers = np.maximum(spec_powers, fft_out ** 2 / nfft)
    return np.fft.rfftfreq(nfft, 1 / sampling_freq), spec_powers
def get_percentile_from_sorted(array, percentile):
    """Return the element of an already sorted sequence at the given percentile.

    The position is (len(array) - 1) * percentile / 100, rounded to the nearest
    index; no interpolation between neighbours is done.
    """
    last_index = len(array) - 1
    position = last_index * percentile / 100
    # adding 0.5 before the truncating cast turns flooring into rounding
    return array[int(position + 0.5)]
def dwt_features(x, wav='db4', level=4, mode='stats'):
    """Extract features from a discrete wavelet decomposition of a signal.

    Args:
        x: Input signal, passed to pywt.wavedec.
        wav: Wavelet name.
        level: Decomposition level.
        mode: 'raw' returns every wavelet coefficient as a feature; any other
            value returns the get_features() statistics of each coefficient array.

    Returns:
        tuple: (features, labels, y0) where y0 is the first array returned by
        pywt.wavedec. Labels have the form 'L<array index>-<name>'; in 'raw'
        mode <name> is the coefficient's position within its array.
    """
    y = pywt.wavedec(x, wav, level=level)

    features = []
    labels = []
    if mode == 'raw':
        # The flattened coefficients used to be computed but never returned,
        # and the return statement failed with NameError
        for i, coeffs in enumerate(y):
            for j, value in enumerate(coeffs):
                features.append(value)
                labels.append('L' + str(i) + '-' + str(j))
    else:
        for i, coeffs in enumerate(y):
            for k, v in get_features(coeffs).items():
                features.append(v)
                labels.append('L' + str(i) + '-' + k)

    return features, labels, y[0]
def batch_req(self, fn, body):
    """Run the DSP function over a batch of examples and reply with JSON.

    Args:
        self: The request handler used to send the response.
        fn: Feature function; called once per example with the shared arguments
            plus that example as `raw_data`. Must return a dict with 'features'.
        body: Parsed request body. Requires a non-empty 'features' list plus
            'params', 'sampling_freq', 'axes' and 'implementation_version'.

    Raises:
        ValueError: If a required entry is missing from `body`.

    The reply contains the features of every example; 'labels' and
    'output_config' are taken from the first example's result only.
    """
    # .get() so that an absent key yields the intended message, not a KeyError
    if not body.get('features'):
        raise ValueError('Missing "features" in body')
    for required in ('params', 'sampling_freq', 'axes', 'implementation_version'):
        if required not in body:
            raise ValueError('Missing "' + required + '" in body')

    base_args = {
        'draw_graphs': False,
        'axes': np.array(body['axes']),
        'sampling_freq': body['sampling_freq'],
        'implementation_version': body['implementation_version'],
    }
    base_args.update(body['params'])

    features = []
    labels = []
    output_config = None

    for index, example in enumerate(body['features']):
        result = fn(**dict(base_args, raw_data=np.array(example)))

        example_features = result['features']
        if isinstance(example_features, np.ndarray):
            example_features = example_features.flatten().tolist()
        features.append(example_features)

        if index == 0:
            if 'labels' in result:
                labels = result['labels']
            if 'output_config' in result:
                output_config = result['output_config']

    payload = json.dumps({
        'success': True,
        'features': features,
        'labels': labels,
        'output_config': output_config
    })

    self.send_response(200)
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(payload.encode())
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler of the DSP block.

    GET /            -> plain-text banner with the block title and author
    GET /parameters  -> contents of parameters.json with 'version' set to 1
    POST /run        -> single_req with generate_features
    POST /batch      -> batch_req with generate_features
    Any other path   -> 404
    """

    def _begin(self, status, content_type):
        # Send the status line and the only header, then close the header section
        self.send_response(status)
        self.send_header('Content-Type', content_type)
        self.end_headers()

    def _invalid_path(self):
        self._begin(404, 'text/plain')
        self.wfile.write(b'Invalid path ' + self.path.encode() + b'\n')

    def do_GET(self):
        url = urlparse(self.path)
        params = get_params(self)
        route = url.path

        if route == '/':
            self._begin(200, 'text/plain')
            info = params['info']
            self.wfile.write(('Edge Impulse DSP block: ' + info['title'] + ' by ' +
                              info['author']).encode())
            return

        if route == '/parameters':
            self._begin(200, 'application/json')
            params['version'] = 1
            self.wfile.write(json.dumps(params).encode())
            return

        self._invalid_path()

    def do_POST(self):
        url = urlparse(self.path)
        try:
            if url.path in ('/run', '/batch'):
                content_len = int(self.headers.get('Content-Length'))
                body = json.loads(self.rfile.read(content_len).decode('utf-8'))
                process = single_req if url.path == '/run' else batch_req
                process(self, generate_features, body)
            else:
                self._invalid_path()

        except Exception as e:
            # Failures are reported in the JSON body; the HTTP status stays 200
            print('Failed to handle request', e, traceback.format_exc())
            self._begin(200, 'application/json')
            self.wfile.write(json.dumps({ 'success': False, 'error': str(e) }).encode())

    def log_message(self, format, *args):
        """Suppress the base class's per-request logging."""
        pass
44 | "value": 256, 45 | "type": "int", 46 | "help": "Number of FFT points", 47 | "param": "fft_length" 48 | }, 49 | { 50 | "name": "Low frequency", 51 | "value": 0, 52 | "type": "int", 53 | "help": "Lowest band edge of mel filters (in Hz)", 54 | "param": "low_frequency" 55 | }, 56 | { 57 | "name": "High frequency", 58 | "value": 0, 59 | "type": "int", 60 | "help": "Highest band edge of mel filters (in Hz). If set to 0 this is samplerate / 2", 61 | "param": "high_frequency", 62 | "optional": true 63 | } 64 | ] 65 | }, 66 | { 67 | "group": "Normalization", 68 | "items": [ 69 | { 70 | "name": "Window size", 71 | "value": 101, 72 | "type": "int", 73 | "help": "The size of sliding window for local normalization", 74 | "param": "win_size", 75 | "showForImplementationVersion": [ 1, 2 ] 76 | }, 77 | { 78 | "name": "Noise floor (dB)", 79 | "value": -52, 80 | "type": "int", 81 | "help": "Everything less loud than the noise floor will be dropped", 82 | "param": "noise_floor_db", 83 | "showForImplementationVersion": [ 3, 4 ] 84 | } 85 | ] 86 | } 87 | ] 88 | } 89 | -------------------------------------------------------------------------------- /mfe/requirements-blocks.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | librosa==0.8.0 3 | matplotlib==3.5.1 4 | numpy==1.21.5 5 | PeakUtils==1.3.2 6 | Pillow==9.0.1 7 | requests==2.22.0 8 | requests-oauthlib==1.3.0 9 | requests-unixsocket==0.2.0 10 | scikit-learn==1.3.0 11 | scipy==1.7.3 12 | sklearn==0.0 13 | urllib3==1.24.2 14 | PyWavelets==1.3.0 -------------------------------------------------------------------------------- /mfe/third_party/speechpy/__init__.py: -------------------------------------------------------------------------------- 1 | from . import feature 2 | from . 
def frequency_to_mel(f):
    """Convert frequency in Hz to the Mel scale.

    :param f: A single frequency or an array of frequencies, in Hz.
    :returns: The matching Mel value(s), 1127 * ln(1 + f / 700).
    """
    scaled = f / 700.
    return 1127 * np.log(1 + scaled)
def triangle(x, left, middle, right):
    """Evaluate a triangular window at the points in `x`.

    The window rises linearly from 0 at `left` to 1 at `middle` and falls
    linearly back to 0 at `right`; it is 0 everywhere else.

    :param x: numpy array of positions.
    :returns: Array of the same shape as `x` holding the window values.
    """
    weights = np.zeros(x.shape)

    rising = (left < x) & (x <= middle)
    weights[rising] = (x[rising] - left) / (middle - left)

    falling = (middle <= x) & (x < right)
    weights[falling] = (right - x[falling]) / (right - middle)

    return weights
class Dataset:
    '''Iterable over labelled samples that are stored back-to-back in one flat array.

    Each sample occupies ``metadata[i]`` consecutive entries of ``X_all`` laid out as
    ``[interval_ms, label, value, value, ...]``. The values are reshaped to
    ``(rows, axis)`` and, when ``resample_interval_ms`` is given, resampled to that
    interval. Samples whose values contain NaN are left out of the dataset.

    Iterating yields ``(X, y, interval_ms)`` or, when ``returns_interval`` is False,
    ``(X, y)``. The read position is rewound once the iteration is exhausted.
    '''

    def __init__(self, X_all, metadata, axis, returns_interval=True, resample_interval_ms=None):
        self.ix = 0
        self.returns_interval = returns_interval
        self.max_len = 0
        self.y_label_set = set()

        # Sampling frequency: taken from the caller when given, otherwise filled in
        # from the first sample inside the loop below
        self.fs = calculate_freq(resample_interval_ms) if resample_interval_ms else None

        samples = []
        labels = []
        intervals = []

        resampling = resample_interval_ms is not None
        if resampling:
            resampler = Resampler(len(metadata))
            target_freq = calculate_freq(resample_interval_ms)
            # Every resampled sample shares this interval, so it is stored only once
            intervals.append(resample_interval_ms)

        start = 0
        for sample_len in metadata:
            # Cut this sample out of the flat array and advance the cursor
            chunk = X_all[start: start + sample_len]
            start = start + sample_len

            # First two entries are the interval and the label, the rest are values
            interval_ms = chunk[0]
            y = chunk[1]
            X = chunk[2:]

            if not self.fs:
                self.fs = calculate_freq(interval_ms)

            if np.isnan(X).any():
                continue

            value_count = sample_len - 2
            rows = int(value_count / axis)
            if (value_count % axis) != 0:
                raise ValueError('Sample length is invalid, check the axis count.')

            X = np.reshape(X, (rows, axis))

            if resampling:
                # Derive the new length from the ratio of target to source frequency
                source_len = X.shape[0]
                source_freq = calculate_freq(interval_ms)
                target_len = calc_resampled_size(source_freq, target_freq, source_len)
                X = resampler.resample(X, target_len, source_len)
            else:
                intervals.append(interval_ms)

            # Track the longest sample seen so far
            self.max_len = max(self.max_len, X.shape[0])
            samples.append(X)
            labels.append(y)
            self.y_label_set.add(y)

        self.X_all = samples
        self.y_all = labels
        self.intervals = intervals

    def reset(self):
        '''Rewind the iterator to the first sample.'''
        self.ix = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.ix >= len(self.y_all):
            self.reset()
            raise StopIteration

        pos = self.ix
        # A single stored interval applies to every sample (the resampled case)
        shared_interval = len(self.intervals) == 1
        interval_ms = self.intervals[0] if shared_interval else self.intervals[pos]
        item = (self.X_all[pos], self.y_all[pos])

        self.ix += 1

        if self.returns_interval:
            return item + (interval_ms,)
        return item
def create_sgram_graph(sampling_freq, frame_length, frame_stride, width, height, power_spectrum, freqs=None):
    """Render a spectrogram and return it as a base64-encoded SVG string.

    Args:
        sampling_freq: Sampling frequency; only used to build the default
            frequency labels (15 evenly spaced values from 0 to sampling_freq / 2).
        frame_length: Frame length, forwarded to set_x_axis_times for the time axis.
        frame_stride: Frame stride, forwarded to set_x_axis_times for the time axis.
        width: Extent over which the x ticks are spread.
        height: Extent over which the y ticks are spread.
        power_spectrum: 2D data handed to imshow.
        freqs: Optional sequence of y axis labels (list or numpy array). When
            None or empty, the default labels described above are used.

    Returns:
        str: The SVG image, base64 encoded as ASCII text.
    """
    matplotlib.use('Svg')
    fig, ax = plt.subplots()
    try:
        # `not freqs` raises for multi-element numpy arrays, so test explicitly
        if freqs is None or len(freqs) == 0:
            freqs = np.linspace(0, sampling_freq / 2, 15)
        plt.ylabel('Frequency [Hz]')
        ax.imshow(power_spectrum, interpolation='nearest',
                  cmap=matplotlib.cm.coolwarm, origin='lower')
        plt.yticks(np.linspace(0, height, len(freqs)), [math.ceil(x) for x in freqs])
        set_x_axis_times(frame_stride, frame_length, width)

        with io.BytesIO() as buf:
            plt.savefig(buf, format='svg', bbox_inches='tight', pad_inches=0)
            return base64.b64encode(buf.getvalue()).decode('ascii')
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call leaves one more figure behind
        plt.close(fig)
class Resampler:
    """Resamples windows one at a time and prints progress while doing so.

    Progress is printed at most once every 3 seconds, and always for the
    last window (when the call count reaches `total_samples`).
    """

    def __init__(self, total_samples):
        self.total_samples = total_samples
        self.ix = 0
        self.last_message = 0

    def resample(self, sample, new_length, original_length):
        """Resample `sample` from `original_length` to `new_length` points.

        The resampling runs along axis 0, or along axis 1 when the sample has
        a single row. The sample is returned untouched when both lengths match.
        """
        if original_length != new_length:
            resample_axis = 1 if sample.shape[0] == 1 else 0
            sample = signal.resample_poly(
                sample, new_length, original_length, axis=resample_axis)

        # Progress reporting
        self.ix += 1
        finished = self.ix == self.total_samples
        now_ms = int(round(time.time() * 1000))
        if now_ms - self.last_message >= 3000 or finished:
            counter_width = len(str(self.total_samples))
            print('[%s/%d] Resampling windows...' %
                  (str(self.ix).rjust(counter_width, ' '), self.total_samples))

            if finished:
                print('Resampled %d windows\n' % self.total_samples)

            sys.stdout.flush()
            self.last_message = int(round(time.time() * 1000))

        return sample
def cap_frame_stride(window_size_ms, frame_stride):
    """Limit the number of frames per window to 500 by widening the stride.

    Args:
        window_size_ms (int): The user's window size (in ms).
            If None or 0, the stride is returned untouched.
        frame_stride (float): The desired frame stride (in seconds)

    Returns:
        float: `frame_stride` unchanged when it yields at most 500 frames,
        otherwise the stride that yields exactly 500 frames.
    """
    # No window size known: nothing to cap against
    if not window_size_ms:
        return frame_stride

    window_size_s = window_size_ms / 1000
    too_many_frames = (window_size_s / frame_stride) > 500
    if not too_many_frames:
        return frame_stride

    print('WARNING: Your window size is too large for the ideal frame stride. '
          f'Set window size to {500 * frame_stride * 1000} ms, or smaller. '
          'Adjusting ideal frame stride to set number of frames to 500')
    return window_size_s / 500
def calculate_crossings(x):
    """Return the zero-crossing and mean-crossing rates of `x`.

    A crossing is counted wherever consecutive samples fall on different
    sides of the threshold (``> threshold`` versus not). Each count is
    divided by ``len(x)``. The mean ignores NaN values (``np.nanmean``).

    :param x: sequence of samples.
    :return: dict with keys 'zcross' and 'mcross'.
    """
    sample_count = len(x)

    def crossing_rate(threshold):
        # np.diff on a boolean array flags positions where the side changes
        side_changes = np.diff(np.array(x) > threshold)
        return int(np.count_nonzero(side_changes)) / sample_count

    zero_rate = crossing_rate(0)
    mean_rate = crossing_rate(np.nanmean(x))
    return {'zcross': zero_rate, 'mcross': mean_rate}
def get_wavefunc(wav, level):
    """Return the scaling function, wavelet function and x-grid of a wavelet.

    :param wav: wavelet name understood by ``pywt.Wavelet`` (e.g. 'db4').
    :param level: refinement level passed to ``Wavelet.wavefun``.
    :return: tuple ``(phi, psi, x)``. For wavelets whose ``wavefun`` yields
        five arrays, the first two arrays and the last one are returned.
    :raises ValueError: if ``wavefun`` yields neither 3 nor 5 arrays.
    """
    wavelet = pywt.Wavelet(wav)

    # Evaluate once and dispatch on the result size, instead of relying on a
    # bare `except:` around a failed 3-way unpack (which also swallowed
    # KeyboardInterrupt/SystemExit and recomputed the functions).
    approximations = wavelet.wavefun(level)
    if len(approximations) == 3:
        phi, psi, x = approximations
    else:
        phi, psi, _, _, x = approximations
    return phi, psi, x
def batch_req(self, fn, body):
    """Run the DSP function over a batch of examples and write a JSON reply.

    :param self: the HTTP request handler used to send the response
        (uses send_response / send_header / end_headers / wfile).
    :param fn: feature generator, called once per example with keyword
        arguments built from `body` plus ``raw_data``.
    :param body: decoded JSON request. Must contain a non-empty 'features'
        list, 'params' and 'sampling_freq'; 'axes' and
        'implementation_version' are read as well.
    :raises ValueError: if 'features', 'params' or 'sampling_freq' is missing.
    """
    # .get() so an absent key produces the intended message instead of a
    # bare KeyError('features'); an empty list is rejected the same way.
    if not body.get('features'):
        raise ValueError('Missing "features" in body')
    if 'params' not in body:
        raise ValueError('Missing "params" in body')
    if 'sampling_freq' not in body:
        raise ValueError('Missing "sampling_freq" in body')

    base_args = {
        'draw_graphs': False,
        'axes': np.array(body['axes']),
        'sampling_freq': body['sampling_freq'],
        'implementation_version': body['implementation_version']
    }

    # Block-specific parameters are passed through as keyword arguments
    for param_key in body['params'].keys():
        base_args[param_key] = body['params'][param_key]

    features = []
    labels = []
    output_config = None

    for example_ix, example in enumerate(body['features']):
        args = dict(base_args)
        args['raw_data'] = np.array(example)
        f = fn(**args)
        if isinstance(f['features'], np.ndarray):
            # ndarrays are not JSON serializable
            features.append(f['features'].flatten().tolist())
        else:
            features.append(f['features'])

        # Labels and output config are taken from the first example only
        if example_ix == 0:
            if 'labels' in f:
                labels = f['labels']
            if 'output_config' in f:
                output_config = f['output_config']

    body = json.dumps({
        'success': True,
        'features': features,
        'labels': labels,
        'output_config': output_config
    })

    self.send_response(200)
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(body.encode())
"sampling_freq" in body') 107 | 108 | args = { 109 | 'axes': np.array(body['axes']), 110 | 'sampling_freq': body['sampling_freq'], 111 | 'implementation_version': body['implementation_version'], 112 | 'input_shape': body['input_shape'] 113 | } 114 | 115 | for param_key in body['params'].keys(): 116 | args[param_key] = body['params'][param_key] 117 | 118 | tflite_byte_arr = fn(**args) 119 | 120 | self.send_response(200) 121 | self.send_header('Content-type', 'application/octet-stream') 122 | self.send_header('Content-Disposition', 'attachment; filename="dsp.tflite"') 123 | self.end_headers() 124 | self.wfile.write(tflite_byte_arr) 125 | 126 | class Handler(BaseHTTPRequestHandler): 127 | def do_GET(self): 128 | url = urlparse(self.path) 129 | params = get_params(self) 130 | 131 | if (url.path == '/'): 132 | self.send_response(200) 133 | self.send_header('Content-Type', 'text/plain') 134 | self.end_headers() 135 | self.wfile.write(('Edge Impulse DSP block: ' + params['info']['title'] + ' by ' + 136 | params['info']['author']).encode()) 137 | 138 | elif (url.path == '/parameters'): 139 | self.send_response(200) 140 | self.send_header('Content-Type', 'application/json') 141 | self.end_headers() 142 | params['version'] = 1 143 | self.wfile.write(json.dumps(params).encode()) 144 | 145 | else: 146 | self.send_response(404) 147 | self.send_header('Content-Type', 'text/plain') 148 | self.end_headers() 149 | self.wfile.write(b'Invalid path ' + self.path.encode() + b'\n') 150 | 151 | def do_POST(self): 152 | url = urlparse(self.path) 153 | try: 154 | if (url.path == '/run'): 155 | content_len = int(self.headers.get('Content-Length')) 156 | post_body = self.rfile.read(content_len) 157 | body = json.loads(post_body.decode('utf-8')) 158 | single_req(self, generate_features, body) 159 | 160 | elif (url.path == '/batch'): 161 | content_len = int(self.headers.get('Content-Length')) 162 | post_body = self.rfile.read(content_len) 163 | body = json.loads(post_body.decode('utf-8')) 164 
def generate_features(implementation_version, draw_graphs, raw_data, axes, sampling_freq, scale_axes):
    """Return the raw data, optionally multiplied by `scale_axes`.

    :param implementation_version: must be 1.
    :param draw_graphs: accepted for interface compatibility; not used here.
    :param raw_data: the input values; multiplied element-wise by
        `scale_axes` when that is not 1 (so it should be a NumPy array then).
    :param axes: accepted for interface compatibility; not used here.
    :param sampling_freq: accepted for interface compatibility; not used here.
    :param scale_axes: multiplier applied to `raw_data`.
    :return: dict with 'features', empty 'graphs' and 'fft_used' lists, and a
        flat 'output_config' whose width is ``len(raw_data)``.
    :raises Exception: if `implementation_version` is not 1.
    """
    if (implementation_version != 1):
        raise Exception('implementation_version should be 1')

    features = raw_data
    if (scale_axes != 1):
        features = raw_data * scale_axes

    return {
        'features': features,
        'graphs': [],
        'fft_used': [],
        'output_config': { 'type': 'flat', 'shape': { 'width': len(raw_data) } }
    }


def str_to_bool(value):
    """Parse a command-line boolean such as 'true'/'False'/'1'/'0'.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so an explicit parser is needed.

    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', '1', 'yes', 'y', 'on'):
        return True
    if text in ('false', '0', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Returns raw data')
    parser.add_argument('--features', type=str, required=True,
                        help='Axis data as a flattened array of x,y,z (pass as comma separated values)')
    parser.add_argument('--axes', type=str, required=True,
                        help='Names of the axis (pass as comma separated values)')
    parser.add_argument('--frequency', type=float, required=True,
                        help='Frequency in hz')
    parser.add_argument('--scale-axes', type=float, default=1,
                        help='scale axes (multiplies by this number, default: 1)')
    parser.add_argument('--draw-graphs', type=str_to_bool, required=True,
                        help='Whether to draw graphs')

    args = parser.parse_args()

    raw_features = np.array([float(item.strip()) for item in args.features.split(',')])
    raw_axes = args.axes.split(',')

    try:
        processed = generate_features(1, args.draw_graphs, raw_features, raw_axes, args.frequency, args.scale_axes)

        # ndarrays are not JSON serializable; convert before dumping
        if isinstance(processed['features'], np.ndarray):
            processed['features'] = processed['features'].flatten().tolist()

        print('Begin output')
        print(json.dumps(processed))
        print('End output')
    except Exception as e:
        print(e, file=sys.stderr)
        # sys.exit rather than the interactive-only `exit` helper
        sys.exit(1)
Useful if you want to use deep learning to learn features.", 6 | "name": "Raw data", 7 | "preferConvolution": true, 8 | "convolutionColumns": "axes", 9 | "convolutionKernelSize": 7, 10 | "cppType": "raw_custom", 11 | "visualization": "dimensionalityReduction", 12 | "experimental": false, 13 | "hasFeatureImportance": true, 14 | "latestImplementationVersion": 1 15 | }, 16 | "parameters": [ 17 | { 18 | "group": "Scaling", 19 | "items": [ 20 | { 21 | "name": "Scale axes", 22 | "value": 1, 23 | "type": "float", 24 | "help": "Multiplies axes by this number", 25 | "param": "scale-axes" 26 | } 27 | ] 28 | } 29 | ] 30 | } 31 | -------------------------------------------------------------------------------- /raw/requirements-blocks.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | librosa==0.8.0 3 | matplotlib==3.5.1 4 | numpy==1.21.5 5 | PeakUtils==1.3.2 6 | Pillow==9.0.1 7 | requests==2.22.0 8 | requests-oauthlib==1.3.0 9 | requests-unixsocket==0.2.0 10 | scikit-learn==1.3.0 11 | scipy==1.7.3 12 | sklearn==0.0 13 | urllib3==1.24.2 14 | PyWavelets==1.3.0 -------------------------------------------------------------------------------- /raw/third_party/placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edgeimpulse/processing-blocks/ba8108d8427ede9d8098808f283d95e0f7d610b8/raw/third_party/placeholder -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | librosa==0.8.0 3 | matplotlib==3.5.1 4 | numpy==1.21.5 5 | PeakUtils==1.3.2 6 | Pillow==9.0.1 7 | requests==2.22.0 8 | requests-oauthlib==1.3.0 9 | requests-unixsocket==0.2.0 10 | scikit-learn==1.3.0 11 | scipy==1.7.3 12 | sklearn==0.0 13 | urllib3==1.24.2 14 | PyWavelets==1.3.0 
-------------------------------------------------------------------------------- /run_spectral_analysis_via_python.py: -------------------------------------------------------------------------------- 1 | from spectral_analysis import generate_features 2 | import numpy as np 3 | 4 | # Parameters for generate_features 5 | 6 | ## The first section are parameters that apply for any DSP block 7 | # Version of the implementation. If you want the latest, look into parameters.json, and use the value of latestImplementationVersion 8 | implementation_version = 4 # 4 is latest versions 9 | draw_graphs = False # For testing from script, disable graphing to improve speed 10 | 11 | # This is where you want to paste in your test sample. This can be taken from studio 12 | # For example, the below sample is from https://studio.edgeimpulse.com/public/223682/latest/classification#load-sample-269491318 13 | # It was copied by clicking on the copy icon next to "Raw features" 14 | # It is 3 axis accelerometer data, with 62.5Hz sampling frequency 15 | # Data should be formatted as a single flat list, regardless of the number of axes/channels 16 | raw_data = np.array([ 2.2600, -1.2700, -1.5300, 1.9500, -1.7500, -1.1900, 1.7900, -2.8500, 0.6500, 1.9100, -2.9100, 2.3500, 1.9100, -2.9100, 2.3500, 1.9900, -2.4100, 3.5900, 1.2700, -0.3800, 2.5200, 1.5300, 0.9900, 2.7300, 2.0200, 0.3000, 4.1400, 1.2400, -0.9500, 5.8300, 0.7400, -1.2500, 6.8400, 0.7400, -1.2500, 6.8400, 0.3100, -0.4200, 6.1200, 1.1300, 0.2500, 8.0100, 1.9600, 0.7400, 10.2900, 1.4600, 0.8700, 9.3800, 0.1200, -0.3400, 9.1400, 0.5500, -1.2900, 12.6500, 0.5500, -1.2900, 12.6500, 1.1000, -1.0500, 14.3300, 0.8300, -0.4800, 12.2600, 0.1900, -0.7900, 11.1900, -0.0500, -0.8100, 12.9700, -0.4500, -0.3500, 17.0500, -0.1000, 0.8300, 17.1800, -0.1000, 0.8300, 17.1800, -0.6000, 0.4200, 14.1100, -0.9000, 0.9200, 12.1500, -0.6000, 1.4200, 14.1400, -0.6200, 1.3100, 15.6600, -0.4800, 1.8700, 14.5600, -0.4300, 1.5300, 13.2100, -0.4300, 1.5300, 
13.2100, -0.7800, 1.1400, 12.7500, -0.9100, 1.2100, 13.0900, -0.1600, 1.2200, 14.3900, -0.2900, 1.4000, 13.6000, -1.0200, 1.3800, 13.1400, -1.3400, 0.3400, 14.9300, -1.3400, 0.3400, 14.9300, -1.0000, -1.2300, 19.1200, -1.1200, -3.4000, 19.9800, -1.2800, -2.9000, 19.9800, -1.2400, -1.8300, 19.9800, -1.6200, -2.3900, 19.9800, -1.3900, -1.9000, 19.9800, -1.3900, -1.9000, 19.9800, -1.5300, -1.6200, 19.9800, -1.4700, -0.6200, 19.3400, -1.0400, 1.1900, 17.5700, -1.0400, 1.6700, 15.1700, -1.0600, 1.9200, 12.6000, -0.5100, 3.3900, 12.1800, -0.5100, 3.3900, 12.1800, 0.1600, 3.3500, 14.2200, -0.4200, 1.8600, 13.8700, -0.7000, 1.4700, 10.9600, 0.0000, 2.0200, 8.0200, 0.8800, 2.7000, 6.6000, 1.1700, 2.2400, 7.5300, 1.1700, 2.2400, 7.5300, 0.8600, 0.7700, 9.3000, 1.1300, 0.6500, 9.8600, 0.7000, 0.7600, 6.6400, 0.0000, 0.6500, 2.3700, 0.3300, 0.5600, 0.9900, 0.3300, 0.5600, 0.9900, 1.0200, 0.3000, 1.5100, 0.6200, -0.5900, 1.2600, 0.6800, -1.4500, 1.1200, 0.6900, -2.6900, -0.0900, 0.9100, -2.7700, -1.8600, 1.5000, -2.4300, -2.2100, 1.5000, -2.4300, -2.2100, 2.0400, -3.0300, 0.7100, 2.3600, -2.7900, 3.4400, 2.6200, -2.0500, 3.0800, 2.6000, -1.9000, 0.8300, 2.2700, -2.4700, -0.8600, 2.1600, -3.1600, -1.0300, 2.1600, -3.1600, -1.0300, 2.6200, -2.4300, 0.8500, 3.0000, -1.5500, 2.2800, 2.5900, -0.5700, 2.5000, 1.4900, -0.4300, 1.3700, 0.8600, -0.3800, 2.5600, 1.4900, 0.7300, 4.6100, 1.4900, 0.7300, 4.6100, 1.5300, 1.4200, 4.7200, 1.0800, 1.6100, 4.6600, -0.0700, 1.2600, 5.4800, 0.6200, 1.2700, 8.5500, 1.3700, 0.7600, 11.1200, 1.1400, 0.4600, 10.8000, 1.1400, 0.4600, 10.8000, 0.6400, 1.5400, 10.0600, 0.3400, 0.5300, 10.9900, 0.7600, 1.0000, 14.3200, 0.6400, 2.5200, 15.9600, 0.4300, 2.6400, 17.3400, -0.5500, 1.2500, 13.2400, -0.5500, 1.2500, 13.2400, 0.2300, 2.8200, 14.0200, 0.1300, 3.6500, 15.3500, 0.2300, 3.7600, 16.5700, -0.2900, 2.9500, 15.0900, -0.6000, 3.7800, 15.0900, -0.1100, 4.4300, 15.5600, -0.1100, 4.4300, 15.5600, -0.6600, 3.3100, 16.1400, -0.8000, 2.6700, 16.0700, 0.0900, 
3.9500, 16.8200, -0.5600, 4.1900, 16.9800, -0.3700, 3.6100, 19.5200, -0.0300, 1.5200, 19.9800, -0.0300, 1.5200, 19.9800, -0.4200, -0.2900, 19.9800, -0.5100, -0.1200, 19.3700, -0.2500, 0.4800, 19.3000, -0.2500, 0.4300, 19.5900, -0.2700, 0.5000, 19.3400, -0.2700, 0.5000, 17.6800, -0.3400, 1.2800, 17.6800, -0.3200, 2.7600, 15.9600, -0.2200, 3.3800, 14.8100 ]) 17 | 18 | axes = ['x', 'y', 'z'] # Axes names. Can be any labels, but the length naturally must match the number of channels in raw data 19 | sampling_freq = 62.5 # Sampling frequency of the data. Ignored for images 20 | 21 | # Below here are parameters that are specific to the spectral analysis DSP block. These are set to the defaults 22 | scale_axes = 1 # Scale your data if desired 23 | input_decimation_ratio = 1 # Decimation ratio. See /spectral_analysis/paramters.json:31 for valid ratios 24 | filter_type = 'none' # Filter type. String : low, high, or none 25 | filter_cutoff = 0 # Cutoff frequency if filtering is chosen. Ignored if filter_type is 'none' 26 | filter_order = 0 # Filter order. Ignored if filter_type is 'none'. 2, 4, 6, or 8 is valid otherwise 27 | analysis_type = 'FFT' # Analysis type. String : FFT, wavelet 28 | 29 | # The following parameters only apply to FFT analysis type. Even if you choose wavelet analysis, these parameters still need dummy values 30 | fft_length = 16 # Size of FFT to perform. Should be power of 2 >-= 16 and <= 4096 31 | 32 | # Deprecated parameters. Only applies to version 1, maintained for backwards compatibility 33 | spectral_peaks_count = 0 # Deprecated parameter. Only applies to version 1, maintained for backwards compatibility 34 | spectral_peaks_threshold = 0 # Deprecated parameter. Only applies to version 1, maintained for backwards compatibility 35 | spectral_power_edges = "0" # Deprecated parameter. 
Only applies to version 1, maintained for backwards compatibility 36 | 37 | # Current FFT parameters 38 | do_log = True # Take the log of the spectral powers from the FFT frames 39 | do_fft_overlap = True # Overlap FFT frames by 50%. If false, no overlap 40 | extra_low_freq = False # This will decimate the input window by 10 and perform another FFT on the decimated window. 41 | # This is useful to extract low frequency data. The features will be appended to the normal FFT features 42 | 43 | # These parameters only apply to Wavelet analysis type. Even if you choose FFT analysis, these parameters still need dummy values 44 | wavelet_level = 1 # Level of wavelet decomposition 45 | wavelet = "" # Wavelet kernel to use 46 | 47 | output = generate_features(implementation_version, draw_graphs, raw_data, axes, sampling_freq, scale_axes, input_decimation_ratio, 48 | filter_type, filter_cutoff, filter_order, analysis_type, fft_length, spectral_peaks_count, 49 | spectral_peaks_threshold, spectral_power_edges, do_log, do_fft_overlap, 50 | wavelet_level, wavelet, extra_low_freq) 51 | 52 | # Return dictionary, as defined in code 53 | # return { 54 | # 'features': List of output features 55 | # 'graphs': Dictionary of graphs 56 | # 'labels': Names of the features 57 | # 'fft_used': Array showing which FFT sizes were used. 
Helpful for optimzing embedded DSP code 58 | # 'output_config': information useful for correctly configuring the learn block in Studio 59 | # } 60 | 61 | print(f'Processed features are: ') 62 | print('Feature name, value') 63 | idx = 0 64 | for axis in axes: 65 | print(f'\nFeatures for axis: {axis}') 66 | for label in output['labels']: 67 | print(f'{label: <40}: {output["features"][idx]}') 68 | idx += 1 69 | -------------------------------------------------------------------------------- /spectral_analysis/Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax = docker/dockerfile:experimental@sha256:3c244c0c6fc9d6aa3ddb73af4264b3a23597523ac553294218c13735a2c6cf79 2 | FROM ubuntu:20.04 3 | 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | 6 | WORKDIR /app 7 | 8 | # python3 and all dependencies for scipy 9 | RUN apt update && apt install -y python3 python3-pip libatlas-base-dev gfortran-9 libfreetype6-dev wget && \ 10 | ln -s $(which gfortran-9) /usr/bin/gfortran 11 | 12 | # Update pip 13 | RUN pip3 install -U pip==22.0.3 14 | 15 | # Cython and scikit-learn - it needs to be done in this order for some reason 16 | RUN pip3 --no-cache-dir install Cython==0.29.24 17 | 18 | # Rest of the dependencies 19 | COPY requirements-blocks.txt ./ 20 | RUN pip3 --no-cache-dir install -r requirements-blocks.txt 21 | 22 | COPY third_party /third_party 23 | COPY . 
class Dataset:
    '''Create an iterable dataset when x data is flattened, handling reshaping and resampling.

    `X_all` is read as consecutive records, one per entry of `metadata`;
    `metadata[ix]` is the record length. Within a record the first value is
    the sampling interval (ms), the second is the label, and the remainder is
    the feature data, which is reshaped to ``(rows, axis)``.

    Iterating yields ``(X, y, interval_ms)`` tuples, or ``(X, y)`` when
    `returns_interval` is False. The iterator rewinds itself when exhausted.
    '''

    def __init__(self, X_all, metadata, axis, returns_interval=True, resample_interval_ms=None):
        # Iteration cursor used by __next__
        self.ix = 0
        self.returns_interval = returns_interval
        # Row count of the longest stored sample
        self.max_len = 0

        X_all_shaped = []
        y_all = []
        self.y_label_set = set()
        intervals_all = []
        # Read position inside the flattened X_all
        current_offset = 0

        if resample_interval_ms:
            self.fs = calculate_freq(resample_interval_ms)
        else:
            self.fs = None

        # Prepare for resampling data
        if resample_interval_ms is not None:
            resample_utility = Resampler(len(metadata))
            target_freq = calculate_freq(resample_interval_ms)
            # All resampled samples share one interval, so it is stored once
            intervals_all.append(resample_interval_ms)

        # Reshape all samples
        for ix in range(len(metadata)):
            # Get x data using offset
            cur_len = metadata[ix]
            X_full = X_all[current_offset: current_offset + cur_len]
            current_offset = current_offset + cur_len

            # Split the interval, label from the features
            interval_ms = X_full[0]
            y = X_full[1]
            X = X_full[2:]

            if not self.fs:
                # if we didn't get a resampling rate from the caller, use the first sample's rate
                self.fs = calculate_freq(interval_ms)

            # NOTE(review): indentation was reconstructed; as laid out here,
            # samples containing NaN are skipped entirely - confirm.
            if not np.isnan(X).any():
                # Reshape
                len_adjusted = cur_len - 2
                rows = int(len_adjusted / axis)
                # Data length is unexpected
                if not ((len_adjusted % axis) == 0):
                    raise ValueError('Sample length is invalid, check the axis count.')

                X = np.reshape(X, (rows, axis))

                # Resample data
                if resample_interval_ms is not None:
                    # Work out the up and down factors using sample lengths
                    original_length = X.shape[0]
                    original_freq = calculate_freq(interval_ms)
                    new_length = calc_resampled_size(original_freq, target_freq, original_length)

                    # Resample
                    X = resample_utility.resample(X, new_length, original_length)
                else:
                    # Without resampling every sample keeps its own interval
                    intervals_all.append(interval_ms)

                # Store the longest sample length
                self.max_len = max(self.max_len, X.shape[0])
                X_all_shaped.append(X)
                y_all.append(y)
                self.y_label_set.add(y)

        self.X_all = X_all_shaped
        self.y_all = y_all
        self.intervals = intervals_all

    def reset(self):
        # Rewind the iteration cursor to the first sample
        self.ix = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.ix >= len(self.y_all):
            # Rewind so the dataset can be iterated again
            self.reset()
            raise StopIteration

        X = self.X_all[self.ix]
        y = self.y_all[self.ix]
        if (len(self.intervals) == 1):
            # Resampled data has the same interval so we only store it once
            interval_ms = self.intervals[0]
        else:
            interval_ms = self.intervals[self.ix]

        self.ix += 1

        if (self.returns_interval):
            return X, y, interval_ms
        else:
            return X, y
def create_sgram_graph(sampling_freq, frame_length, frame_stride, width, height, power_spectrum, freqs=None):
    """Render a spectrogram as a base64-encoded SVG string.

    :param sampling_freq: sampling frequency (Hz); used for the default
        y-axis labels, which span 0 to ``sampling_freq / 2``.
    :param frame_length: frame length, forwarded to the x-axis labelling.
    :param frame_stride: frame stride, forwarded to the x-axis labelling.
    :param width: image width used to place the x ticks.
    :param height: image height used to place the y ticks.
    :param power_spectrum: 2D data drawn with ``imshow``.
    :param freqs: optional y-axis label values; when None or empty, 15
        evenly spaced frequencies are used.
    :return: the SVG image, base64 encoded as an ASCII string.
    """
    matplotlib.use('Svg')
    fig, ax = plt.subplots()
    try:
        # `not freqs` raises on NumPy arrays, so test for None/empty explicitly
        if freqs is None or len(freqs) == 0:
            freqs = np.linspace(0, sampling_freq / 2, 15)
        plt.ylabel('Frequency [Hz]')
        ax.imshow(power_spectrum, interpolation='nearest',
                  cmap=matplotlib.cm.coolwarm, origin='lower')
        plt.yticks(np.linspace(0, height, len(freqs)), [math.ceil(x) for x in freqs])
        set_x_axis_times(frame_stride, frame_length, width)

        with io.BytesIO() as buf:
            plt.savefig(buf, format='svg', bbox_inches='tight', pad_inches=0)
            image = base64.b64encode(buf.getvalue()).decode('ascii')
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # call leaks a figure in a long-running process
        plt.close(fig)
    return image
def get_ratio_combo(r):
    """Break a supported decimation ratio into its chain of base stages.

    The base stages are 3 and 10; composite ratios are expressed as a
    sequence of those (e.g. 30 -> [3, 10]). A ratio of 1 maps to [1].

    :param r: decimation ratio, one of 1, 3, 10, 30, 100, 1000.
    :return: list of stage ratios.
    :raises ValueError: for any other ratio.
    """
    # Base ratios are a single stage of themselves
    if r in (3, 10):
        return [r]

    stage_table = (
        (1, (1,)),
        (30, (3, 10)),
        (100, (10, 10)),
        (1000, (10, 10, 10)),
    )
    for ratio, stages in stage_table:
        if r == ratio:
            return list(stages)

    raise ValueError("Invalid decimation ratio: {}".format(r))
def welch_max_hold(fx, sampling_freq, nfft, n_overlap):
    """Max-hold power spectrum over overlapping FFT frames.

    The signal is cut into frames of `nfft` samples that start every
    ``nfft - n_overlap`` samples; for each frequency bin the largest value of
    ``|rfft(frame)|**2 / nfft`` over all frames is kept.

    :param fx: 1D signal.
    :param sampling_freq: sampling frequency (Hz), used for the bin frequencies.
    :param nfft: FFT length (int); short frames are zero padded to this length.
    :param n_overlap: number of samples shared by consecutive frames
        (truncated to int).
    :return: tuple ``(freqs, spec_powers)`` with ``nfft // 2 + 1`` entries each.
    :raises ValueError: if `n_overlap` is not smaller than `nfft`.
    """
    n_overlap = int(n_overlap)
    hop = nfft - n_overlap
    if hop <= 0:
        # A non-positive hop never advances through the signal; the previous
        # `while` loop spun forever in this case.
        raise ValueError(
            f'n_overlap ({n_overlap}) must be smaller than nfft ({nfft})')

    spec_powers = np.zeros(nfft // 2 + 1)
    # Frame starts run up to and including len(fx), as before; the trailing
    # frame is all zero padding and cannot raise the maximum.
    for start in range(0, len(fx) + 1, hop):
        # Slicing truncates if end_idx > len, and rfft will auto zero pad
        fft_out = np.abs(np.fft.rfft(fx[start:start + nfft], nfft))
        spec_powers = np.maximum(spec_powers, fft_out**2 / nfft)
    return np.fft.rfftfreq(nfft, 1 / sampling_freq), spec_powers
handle the issue with zero values if the are exposed 27 | to become an argument for any log function. 28 | :param x: The vector. 29 | :return: The vector with zeros substituted with epsilon values. 30 | """ 31 | return np.where(x == 0, 1e-10, x) 32 | 33 | 34 | def cap_frame_stride(window_size_ms, frame_stride): 35 | """Returns the frame stride passed in, 36 | or a stride that creates 500 frames if the window size is too large. 37 | 38 | Args: 39 | window_size_ms (int): The users window size (in ms). 40 | If none or 0, no capping is done. 41 | frame_stride (float): The desired frame stride 42 | 43 | Returns: 44 | float: Either the passed in frame_stride, or longer frame stride 45 | """ 46 | if window_size_ms: 47 | num_frames = (window_size_ms / 1000) / frame_stride 48 | if num_frames > 500: 49 | print('WARNING: Your window size is too large for the ideal frame stride. ' 50 | f'Set window size to {500 * frame_stride * 1000} ms, or smaller. ' 51 | 'Adjusting ideal frame stride to set number of frames to 500') 52 | frame_stride = (window_size_ms / 1000) / 500 53 | return frame_stride 54 | 55 | 56 | def audio_set_params(frame_length, fs): 57 | """Suggest parameters for audio processing (MFE/MFCC) 58 | 59 | Args: 60 | frame_length (float): The desired frame length (in seconds) 61 | fs (int): The sampling frequency (in Hz) 62 | 63 | Returns: 64 | fft_length: Recomended FFT length 65 | num_filters: Recomended number of filters 66 | """ 67 | DEFAULT_NUM_FILTERS = 40 68 | DEFAULT_NFFT = 256 # for 8kHz sampling rate 69 | 70 | fft_length = next_power_of_2(int(frame_length * fs)) 71 | num_filters = int(DEFAULT_NUM_FILTERS + np.log2(fft_length / DEFAULT_NFFT)) 72 | return fft_length, num_filters 73 | -------------------------------------------------------------------------------- /spectral_analysis/common/wavelet.py: -------------------------------------------------------------------------------- 1 | import pywt 2 | import numpy as np 3 | from scipy.stats import skew, entropy, 
def get_percentile_from_sorted(array, percentile):
    """Return the element of an ascending-sorted ``array`` at ``percentile``.

    :param array: sequence sorted in ascending order
    :param percentile: percentile in the range 0..100
    :returns: the element at the rounded percentile position (no interpolation)
    """
    # adding 0.5 is a trick to get rounding out of C flooring behavior during cast
    index = int(((len(array)-1) * percentile/100) + 0.5)
    return array[index]


def calculate_statistics(x):
    """Compute summary statistics of ``x`` without modifying it.

    :param x: 1D array of values
    :returns: dict with keys ``n5``, ``n25``, ``n75``, ``n95``, ``median``
        (nearest-rank percentiles), ``mean``, ``std`` (``np.std`` default),
        ``var`` (``ddof=1``), ``rms``, ``skew`` and ``kurtosis``
        (``scipy.stats`` defaults; both forced to 0 when ``rms`` is 0)
    """
    output = {}
    # np.sort returns a sorted copy. The previous x.sort() reordered the
    # caller's array in place, corrupting data the caller still needed.
    x = np.sort(x)
    output['n5'] = get_percentile_from_sorted(x, 5)
    output['n25'] = get_percentile_from_sorted(x, 25)
    output['n75'] = get_percentile_from_sorted(x, 75)
    output['n95'] = get_percentile_from_sorted(x, 95)
    output['median'] = get_percentile_from_sorted(x, 50)
    output['mean'] = np.mean(x)
    output['std'] = np.std(x)
    output['var'] = np.var(x, ddof=1)
    output['rms'] = np.sqrt(np.mean(x**2))
    # An all-zero signal has no spread; report 0 rather than the NaN that
    # skew/kurtosis would produce.
    output['skew'] = 0 if output['rms'] == 0 else skew(x)
    output['kurtosis'] = 0 if output['rms'] == 0 else kurtosis(x)
    return output
def get_wavefunc(wav, level):
    """Return the scaling function, wavelet function and x-grid of a wavelet.

    :param wav: wavelet name understood by ``pywt.Wavelet`` (e.g. ``'db4'``)
    :param level: refinement level passed to ``Wavelet.wavefun``
    :returns: tuple ``(phi, psi, x)``
    """
    wavelet = pywt.Wavelet(wav)
    funcs = wavelet.wavefun(level)
    # wavefun() yields either (phi, psi, x) or a five-element tuple with two
    # extra functions before x. Picking by position handles both shapes
    # without a bare except that would also swallow unrelated errors and
    # compute the functions a second time.
    return funcs[0], funcs[1], funcs[-1]
11 | 12 | # Update pip 13 | RUN pip3 install -U pip==22.0.3 14 | 15 | # Cython and scikit-learn - it needs to be done in this order for some reason 16 | RUN pip3 --no-cache-dir install Cython==0.29.24 17 | 18 | # Rest of the dependencies 19 | COPY requirements-blocks.txt ./ 20 | RUN pip3 --no-cache-dir install -r requirements-blocks.txt 21 | 22 | COPY third_party /third_party 23 | COPY . ./ 24 | 25 | EXPOSE 4446 26 | 27 | ENTRYPOINT ["python3", "-u", "dsp-server.py"] 28 | -------------------------------------------------------------------------------- /spectrogram/__init__.py: -------------------------------------------------------------------------------- 1 | from .dsp import generate_features 2 | -------------------------------------------------------------------------------- /spectrogram/common/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | sys.path.append('/') 4 | from common.sampling import calc_resampled_size, calculate_freq, Resampler 5 | 6 | 7 | class Dataset: 8 | '''Create an iterable dataset when x data is flattened, handling reshaping and resampling''' 9 | 10 | def __init__(self, X_all, metadata, axis, returns_interval=True, resample_interval_ms=None): 11 | self.ix = 0 12 | self.returns_interval = returns_interval 13 | self.max_len = 0 14 | 15 | X_all_shaped = [] 16 | y_all = [] 17 | self.y_label_set = set() 18 | intervals_all = [] 19 | current_offset = 0 20 | 21 | if resample_interval_ms: 22 | self.fs = calculate_freq(resample_interval_ms) 23 | else: 24 | self.fs = None 25 | 26 | # Prepare for resampling data 27 | if resample_interval_ms is not None: 28 | resample_utility = Resampler(len(metadata)) 29 | target_freq = calculate_freq(resample_interval_ms) 30 | intervals_all.append(resample_interval_ms) 31 | 32 | # Reshape all samples 33 | for ix in range(len(metadata)): 34 | # Get x data using offset 35 | cur_len = metadata[ix] 36 | X_full = X_all[current_offset: 
current_offset + cur_len] 37 | current_offset = current_offset + cur_len 38 | 39 | # Split the interval, label from the features 40 | interval_ms = X_full[0] 41 | y = X_full[1] 42 | X = X_full[2:] 43 | 44 | if not self.fs: 45 | # if we didn't get a resampling rate from the caller, use the first sample's rate 46 | self.fs = calculate_freq(interval_ms) 47 | 48 | if not np.isnan(X).any(): 49 | # Reshape 50 | len_adjusted = cur_len - 2 51 | rows = int(len_adjusted / axis) 52 | # Data length is unexpected 53 | if not ((len_adjusted % axis) == 0): 54 | raise ValueError('Sample length is invalid, check the axis count.') 55 | 56 | X = np.reshape(X, (rows, axis)) 57 | 58 | # Resample data 59 | if resample_interval_ms is not None: 60 | # Work out the up and down factors using sample lengths 61 | original_length = X.shape[0] 62 | original_freq = calculate_freq(interval_ms) 63 | new_length = calc_resampled_size(original_freq, target_freq, original_length) 64 | 65 | # Resample 66 | X = resample_utility.resample(X, new_length, original_length) 67 | else: 68 | intervals_all.append(interval_ms) 69 | 70 | # Store the longest sample length 71 | self.max_len = max(self.max_len, X.shape[0]) 72 | X_all_shaped.append(X) 73 | y_all.append(y) 74 | self.y_label_set.add(y) 75 | 76 | self.X_all = X_all_shaped 77 | self.y_all = y_all 78 | self.intervals = intervals_all 79 | 80 | def reset(self): 81 | self.ix = 0 82 | 83 | def __iter__(self): 84 | return self 85 | 86 | def __next__(self): 87 | if self.ix >= len(self.y_all): 88 | self.reset() 89 | raise StopIteration 90 | 91 | X = self.X_all[self.ix] 92 | y = self.y_all[self.ix] 93 | if (len(self.intervals) == 1): 94 | # Resampled data has the same interval so we only store it once 95 | interval_ms = self.intervals[0] 96 | else: 97 | interval_ms = self.intervals[self.ix] 98 | 99 | self.ix += 1 100 | 101 | if (self.returns_interval): 102 | return X, y, interval_ms 103 | else: 104 | return X, y 105 | 
def log(*msg, level='warn'):
    """Print a single-line JSON log record to stdout.

    :param msg: message parts; each is converted with ``str`` and the parts
        are joined with single spaces
    :param level: value stored in the record's ``level`` field
    """
    msg_clean = ' '.join([str(i) for i in msg])
    # The trailing 'Z' marks the timestamp as UTC, so the clock must be read
    # in UTC as well. A naive datetime.now() is local time and would be
    # mislabelled on any host that is not running in UTC.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    print(json.dumps(
        {'msg': msg_clean,
         'level': level,
         'time': now_utc.replace(microsecond=0, tzinfo=None).isoformat() + 'Z'}))
def create_decimate_filter(ratio):
    """Design the low-pass filter applied before decimating by ``ratio``.

    :param ratio: decimation ratio
    :returns: tuple ``(sos, zi)``: second-order sections of an 8th-order
        Chebyshev type I filter (0.05 dB ripple, critical frequency
        ``0.8 / ratio``) and its ``sosfilt_zi`` initial conditions
    """
    sos = signal.cheby1(8, 0.05, 0.8 / ratio, output='sos')
    zi = signal.sosfilt_zi(sos)
    return sos, zi


def decimate_simple(x, ratio, export=False):
    """Low-pass filter ``x`` and keep every ``ratio``-th sample.

    :param x: 1D numpy array
    :param ratio: decimation ratio; ``1`` returns ``x`` unchanged
    :param export: when true (and ``ratio != 1``) also return the filter
    :returns: the decimated signal, or ``(y, sos, zi)`` when ``export`` is set.
        With ``ratio == 1`` no filter exists and ``x`` alone is returned even
        if ``export`` is set.
    :raises ValueError: if ``x`` is not one-dimensional
    """
    if x.ndim != 1:
        raise ValueError(f'x must be 1D {x.shape}')
    if ratio == 1:
        return x
    sos, zi = create_decimate_filter(ratio)
    if x.size == 0:
        # Nothing to filter. Indexing x[0] below would raise IndexError.
        y = x
    else:
        # Scale the initial state by the first sample so the filter starts
        # as if the signal had been at x[0] before the first sample.
        y, _ = signal.sosfilt(sos, x, zi=zi * x[0])
        y = y[::ratio]
    if export:
        return y, sos, zi
    return y
def welch_max_hold(fx, sampling_freq, nfft, n_overlap):
    """Compute a max-hold power spectrum over overlapping frames of ``fx``.

    The signal is cut into frames of ``nfft`` samples that start every
    ``nfft - n_overlap`` samples. Each frame is transformed with a real FFT
    (short trailing frames are zero padded by ``rfft``) and, per frequency
    bin, the largest ``|FFT|**2 / nfft`` seen in any frame is kept.

    :param fx: 1D signal
    :param sampling_freq: sampling frequency of ``fx``
    :param nfft: FFT length (integer number of samples per frame)
    :param n_overlap: number of samples shared by consecutive frames
    :returns: tuple ``(freqs, spec_powers)``, each of length ``nfft // 2 + 1``
    :raises ValueError: if ``n_overlap`` is not smaller than ``nfft``
    """
    n_overlap = int(n_overlap)
    hop = nfft - n_overlap
    if hop <= 0:
        # A non-positive hop never advances through the signal; the previous
        # while-loop would spin forever in that case.
        raise ValueError(
            'n_overlap (%d) must be smaller than nfft (%d)' % (n_overlap, nfft))

    spec_powers = np.zeros(nfft // 2 + 1)
    # Frames start strictly inside the signal. A frame starting exactly at
    # len(fx) would be all padding and could never raise the maximum.
    for start in range(0, len(fx), hop):
        # Slicing truncates if the end index is past len(fx), and rfft zero
        # pads the frame back up to nfft samples.
        fft_out = np.abs(np.fft.rfft(fx[start:start + nfft], nfft))
        spec_powers = np.maximum(spec_powers, fft_out**2 / nfft)
    return np.fft.rfftfreq(nfft, 1 / sampling_freq), spec_powers
41 | frame_stride (float): The desired frame stride 42 | 43 | Returns: 44 | float: Either the passed in frame_stride, or longer frame stride 45 | """ 46 | if window_size_ms: 47 | num_frames = (window_size_ms / 1000) / frame_stride 48 | if num_frames > 500: 49 | print('WARNING: Your window size is too large for the ideal frame stride. ' 50 | f'Set window size to {500 * frame_stride * 1000} ms, or smaller. ' 51 | 'Adjusting ideal frame stride to set number of frames to 500') 52 | frame_stride = (window_size_ms / 1000) / 500 53 | return frame_stride 54 | 55 | 56 | def audio_set_params(frame_length, fs): 57 | """Suggest parameters for audio processing (MFE/MFCC) 58 | 59 | Args: 60 | frame_length (float): The desired frame length (in seconds) 61 | fs (int): The sampling frequency (in Hz) 62 | 63 | Returns: 64 | fft_length: Recomended FFT length 65 | num_filters: Recomended number of filters 66 | """ 67 | DEFAULT_NUM_FILTERS = 40 68 | DEFAULT_NFFT = 256 # for 8kHz sampling rate 69 | 70 | fft_length = next_power_of_2(int(frame_length * fs)) 71 | num_filters = int(DEFAULT_NUM_FILTERS + np.log2(fft_length / DEFAULT_NFFT)) 72 | return fft_length, num_filters 73 | -------------------------------------------------------------------------------- /spectrogram/common/wavelet.py: -------------------------------------------------------------------------------- 1 | import pywt 2 | import numpy as np 3 | from scipy.stats import skew, entropy, kurtosis 4 | 5 | 6 | def calculate_entropy(x): 7 | # todo: try approximate entropy 8 | # todo: try Kozachenko and Leonenko 9 | probabilities = np.histogram(x, bins=100, density=True)[0] 10 | return {'entropy': entropy(probabilities)} 11 | 12 | 13 | def get_percentile_from_sorted(array, percentile): 14 | # adding 0.5 is a trick to get rounding out of C flooring behavior during cast 15 | index = int(((len(array)-1) * percentile/100) + 0.5) 16 | return array[index] 17 | 18 | 19 | def calculate_statistics(x): 20 | output = {} 21 | x.sort() 22 | 
def dwt_features(x, wav='db4', level=4, mode='stats'):
    """Extract features from a multilevel discrete wavelet decomposition.

    :param x: input signal passed to ``pywt.wavedec``
    :param wav: wavelet name
    :param level: decomposition level
    :param mode: ``'raw'`` returns every coefficient as a feature; any other
        value returns the ``get_features`` statistics of each coefficient
        array
    :returns: tuple ``(features, labels, y0)`` where ``labels[i]`` names
        ``features[i]`` and ``y0`` is the first coefficient array returned by
        ``wavedec``, in its original order
    """
    y = pywt.wavedec(x, wav, level=level)

    features = []
    labels = []
    if mode == 'raw':
        # Previously this branch built a local list but never defined
        # features/labels, so the return statement raised UnboundLocalError.
        for i, coeffs in enumerate(y):
            for j, value in enumerate(coeffs):
                features.append(value)
                labels.append('L' + str(i) + '-c' + str(j))
    else:
        for i, coeffs in enumerate(y):
            # Work on a copy: the statistics helper sorts its input in place,
            # which would otherwise scramble the y[0] returned below.
            d = get_features(np.array(coeffs))
            for k, v in d.items():
                features.append(v)
                labels.append('L' + str(i) + '-' + k)

    return features, labels, y[0]
def single_req(self, fn, body):
    """Run the DSP function on one sample and write the JSON response.

    :param self: request handler providing ``send_response``, ``send_header``,
        ``end_headers`` and ``wfile``
    :param fn: feature generator, called with keyword arguments built from
        ``body`` plus every entry of ``body['params']``
    :param body: decoded JSON request body
    :raises ValueError: if a required field is missing from ``body`` or
        ``features`` is empty
    """
    # .get() so that an absent key reports the intended ValueError instead of
    # a bare KeyError.
    if not body.get('features'):
        raise ValueError('Missing "features" in body')
    # 'axes' and 'implementation_version' are read unconditionally below, so
    # validate them together with the other required fields.
    for key in ('params', 'sampling_freq', 'draw_graphs', 'axes',
                'implementation_version'):
        if key not in body:
            raise ValueError('Missing "%s" in body' % key)

    args = {
        'draw_graphs': body['draw_graphs'],
        'raw_data': np.array(body['features']),
        'axes': np.array(body['axes']),
        'sampling_freq': body['sampling_freq'],
        'implementation_version': body['implementation_version']
    }
    # Block-specific parameters are forwarded as keyword arguments.
    args.update(body['params'])

    processed = fn(**args)
    if isinstance(processed['features'], np.ndarray):
        # numpy arrays are not JSON serializable
        processed['features'] = processed['features'].flatten().tolist()

    # Serialize before sending any header so that a serialization error can
    # still be reported by the caller as a regular error response.
    payload = json.dumps(processed).encode()

    self.send_response(200)
    self.send_header('Content-Type', 'application/json')
    self.end_headers()
    self.wfile.write(payload)
raise ValueError('Missing "features" in body') 53 | if (not 'params' in body): 54 | raise ValueError('Missing "params" in body') 55 | if (not 'sampling_freq' in body): 56 | raise ValueError('Missing "sampling_freq" in body') 57 | 58 | base_args = { 59 | 'draw_graphs': False, 60 | 'axes': np.array(body['axes']), 61 | 'sampling_freq': body['sampling_freq'], 62 | 'implementation_version': body['implementation_version'] 63 | } 64 | 65 | for param_key in body['params'].keys(): 66 | base_args[param_key] = body['params'][param_key] 67 | 68 | total = 0 69 | features = [] 70 | labels = [] 71 | output_config = None 72 | 73 | for example in body['features']: 74 | args = dict(base_args) 75 | args['raw_data'] = np.array(example) 76 | f = fn(**args) 77 | if (isinstance(f['features'], np.ndarray)): 78 | features.append(f['features'].flatten().tolist()) 79 | else: 80 | features.append(f['features']) 81 | 82 | if total == 0: 83 | if ('labels' in f): 84 | labels = f['labels'] 85 | if ('output_config' in f): 86 | output_config = f['output_config'] 87 | 88 | total += 1 89 | 90 | body = json.dumps({ 91 | 'success': True, 92 | 'features': features, 93 | 'labels': labels, 94 | 'output_config': output_config 95 | }) 96 | 97 | self.send_response(200) 98 | self.send_header('Content-Type', 'application/json') 99 | self.end_headers() 100 | self.wfile.write(body.encode()) 101 | 102 | def tflite_req(self, fn, body): 103 | if (not 'params' in body): 104 | raise ValueError('Missing "params" in body') 105 | if (not 'sampling_freq' in body): 106 | raise ValueError('Missing "sampling_freq" in body') 107 | 108 | args = { 109 | 'axes': np.array(body['axes']), 110 | 'sampling_freq': body['sampling_freq'], 111 | 'implementation_version': body['implementation_version'], 112 | 'input_shape': body['input_shape'] 113 | } 114 | 115 | for param_key in body['params'].keys(): 116 | args[param_key] = body['params'][param_key] 117 | 118 | tflite_byte_arr = fn(**args) 119 | 120 | self.send_response(200) 121 | 
self.send_header('Content-type', 'application/octet-stream') 122 | self.send_header('Content-Disposition', 'attachment; filename="dsp.tflite"') 123 | self.end_headers() 124 | self.wfile.write(tflite_byte_arr) 125 | 126 | class Handler(BaseHTTPRequestHandler): 127 | def do_GET(self): 128 | url = urlparse(self.path) 129 | params = get_params(self) 130 | 131 | if (url.path == '/'): 132 | self.send_response(200) 133 | self.send_header('Content-Type', 'text/plain') 134 | self.end_headers() 135 | self.wfile.write(('Edge Impulse DSP block: ' + params['info']['title'] + ' by ' + 136 | params['info']['author']).encode()) 137 | 138 | elif (url.path == '/parameters'): 139 | self.send_response(200) 140 | self.send_header('Content-Type', 'application/json') 141 | self.end_headers() 142 | params['version'] = 1 143 | self.wfile.write(json.dumps(params).encode()) 144 | 145 | else: 146 | self.send_response(404) 147 | self.send_header('Content-Type', 'text/plain') 148 | self.end_headers() 149 | self.wfile.write(b'Invalid path ' + self.path.encode() + b'\n') 150 | 151 | def do_POST(self): 152 | url = urlparse(self.path) 153 | try: 154 | if (url.path == '/run'): 155 | content_len = int(self.headers.get('Content-Length')) 156 | post_body = self.rfile.read(content_len) 157 | body = json.loads(post_body.decode('utf-8')) 158 | single_req(self, generate_features, body) 159 | 160 | elif (url.path == '/batch'): 161 | content_len = int(self.headers.get('Content-Length')) 162 | post_body = self.rfile.read(content_len) 163 | body = json.loads(post_body.decode('utf-8')) 164 | batch_req(self, generate_features, body) 165 | 166 | else: 167 | self.send_response(404) 168 | self.send_header('Content-Type', 'text/plain') 169 | self.end_headers() 170 | self.wfile.write(b'Invalid path ' + self.path.encode() + b'\n') 171 | 172 | 173 | except Exception as e: 174 | print('Failed to handle request', e, traceback.format_exc()) 175 | self.send_response(200) 176 | self.send_header('Content-Type', 
'application/json') 177 | self.end_headers() 178 | self.wfile.write(json.dumps({ 'success': False, 'error': str(e) }).encode()) 179 | 180 | def log_message(self, format, *args): 181 | return 182 | 183 | class ThreadingSimpleServer(ThreadingMixIn, HTTPServer): 184 | pass 185 | 186 | def run(): 187 | host = '0.0.0.0' if not 'HOST' in os.environ else os.environ['HOST'] 188 | port = 4446 if not 'PORT' in os.environ else int(os.environ['PORT']) 189 | 190 | server = ThreadingSimpleServer((host, port), Handler) 191 | print('Listening on host', host, 'port', port) 192 | server.serve_forever() 193 | 194 | if __name__ == '__main__': 195 | run() 196 | -------------------------------------------------------------------------------- /spectrogram/dsp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import os, sys 5 | import math 6 | 7 | import pathlib 8 | ROOT = pathlib.Path(__file__).parent 9 | sys.path.append(str(ROOT / '..')) 10 | sys.path.append(str(object=ROOT )) 11 | from common import graphing 12 | from common.errors import ConfigurationError 13 | from third_party import speechpy 14 | 15 | 16 | def generate_features(implementation_version, draw_graphs, raw_data, axes, sampling_freq, 17 | frame_length, frame_stride, fft_length, 18 | show_axes, noise_floor_db): 19 | if (implementation_version >4): 20 | raise ConfigurationError('implementation_version should be <= 4') 21 | if (not math.log2(fft_length).is_integer()): 22 | raise ConfigurationError('FFT length must be a power of 2') 23 | if (len(axes) != 1): 24 | raise ConfigurationError('Spectrogram blocks only support a single axis, ' + 25 | 'create one spectrogram block per axis under **Create impulse**') 26 | if (len(raw_data) < 1): 27 | raise ConfigurationError('Input data must not be empty') 28 | if (frame_length < 4/sampling_freq): 29 | raise ConfigurationError('Frame length should be at least 4 samples') 30 | if 
(frame_stride < 4/sampling_freq): 31 | raise ConfigurationError('Frame stride should be at least 4 samples') 32 | 33 | fs = sampling_freq 34 | 35 | # reshape first 36 | raw_data = raw_data.reshape(int(len(raw_data) / len(axes)), len(axes)) 37 | 38 | features = [] 39 | graphs = [] 40 | 41 | width = 0 42 | height = 0 43 | 44 | for ax in range(0, len(axes)): 45 | signal = raw_data[:,ax] 46 | 47 | if implementation_version == 3: 48 | # Rescale to [-1, 1] 49 | if np.any((signal < -1) | (signal > 1)): 50 | signal = (signal / 2**15).astype(np.float32) 51 | 52 | sampling_frequency = fs 53 | 54 | s = np.array(signal).astype(float) 55 | 56 | numframes, _, __ = speechpy.processing.calculate_number_of_frames( 57 | s, 58 | implementation_version=implementation_version, 59 | sampling_frequency=sampling_frequency, 60 | frame_length=frame_length, 61 | frame_stride=frame_stride, 62 | zero_padding=False) 63 | 64 | if (numframes < 1): 65 | raise ConfigurationError('Frame length is larger than your window size') 66 | 67 | if (numframes > 500): 68 | raise ConfigurationError('Number of frames is larger than 500 (' + str(numframes) + '), ' + 69 | 'increase your frame stride or decrease your window size.') 70 | 71 | frames = speechpy.processing.stack_frames( 72 | s, 73 | implementation_version=implementation_version, 74 | sampling_frequency=sampling_frequency, 75 | frame_length=frame_length, 76 | frame_stride=frame_stride, 77 | filter=lambda x: np.ones((x,)), 78 | zero_padding=False) 79 | 80 | power_spectrum = speechpy.processing.power_spectrum(frames, fft_length) 81 | 82 | if implementation_version < 3: 83 | power_spectrum = (power_spectrum - np.min(power_spectrum)) / (np.max(power_spectrum) - np.min(power_spectrum)) 84 | power_spectrum[np.isnan(power_spectrum)] = 0 85 | else: 86 | # Clip to avoid zero values 87 | power_spectrum = np.clip(power_spectrum, 1e-30, None) 88 | # Convert to dB scale 89 | # log_mel_spec = 10 * log10(mel_spectrograms) 90 | power_spectrum = 10 * 
if __name__ == "__main__":
    # Command-line entry point: parse the block parameters, run
    # generate_features() on the supplied samples and print the result as
    # JSON between the 'Begin output' / 'End output' marker lines.

    def _parse_bool(value):
        """Return True for 'true', '1' or 'yes' (case-insensitive), else False."""
        return str(value).lower() in ['true', '1', 'yes']

    parser = argparse.ArgumentParser(description='Spectrogram from sensor data')
    parser.add_argument('--features', type=str, required=True,
                        help='Axis data as a flattened WAV file (pass as comma separated values)')
    parser.add_argument('--axes', type=str, required=True,
                        help='Names of the axis (pass as comma separated values)')
    parser.add_argument('--frequency', type=float, required=True,
                        help='Frequency in hz')
    parser.add_argument('--draw-graphs', type=_parse_bool, required=True,
                        help='Whether to draw graphs')
    parser.add_argument('--frame_length', type=float, default=0.02,
                        help='The length of each frame in seconds')
    parser.add_argument('--frame_stride', type=float, default=0.02,
                        help='The step between successive frames in seconds')
    parser.add_argument('--fft_length', type=int, default=256,
                        help='Number of FFT points')
    parser.add_argument('--noise_floor_db', type=int, default=-52,
                        help='Everything below this loudness will be dropped')
    parser.add_argument('--show-axes', type=_parse_bool, required=True,
                        help='Whether to show axes on the graph')

    args = parser.parse_args()

    # Both list-valued arguments arrive as comma separated strings.
    raw_features = np.array([float(item.strip()) for item in args.features.split(',')])
    raw_axes = args.axes.split(',')

    try:
        # The command line always runs implementation version 3.
        processed = generate_features(3, args.draw_graphs, raw_features, raw_axes, args.frequency,
                                      args.frame_length, args.frame_stride, args.fft_length,
                                      args.show_axes, args.noise_floor_db)

        print('Begin output')
        print(json.dumps(processed))
        print('End output')
    except Exception as e:
        # Top-level boundary: report the failure on stderr and signal it via
        # the exit status. sys.exit() is used instead of the exit() helper,
        # which is installed by the site module for interactive use and is
        # not guaranteed to exist in every interpreter start-up mode.
        print(e, file=sys.stderr)
        sys.exit(1)
"float", 32 | "help": "The step between successive frames in seconds", 33 | "param": "frame_stride" 34 | }, 35 | { 36 | "name": "FFT length", 37 | "value": 128, 38 | "type": "int", 39 | "help": "Number of frequency bands. Choose a power of two (e.g. 64, 128, 256) for optimal performance on device.", 40 | "param": "fft_length" 41 | } 42 | ] 43 | }, 44 | { 45 | "group": "Normalization", 46 | "items": [ 47 | { 48 | "name": "Noise floor (dB)", 49 | "value": -52, 50 | "type": "int", 51 | "help": "When the spectrogram is rescaled, all frequency bins below this noise floor will be clipped to zero", 52 | "param": "noise_floor_db", 53 | "showForImplementationVersion": [ 3, 4 ] 54 | }, 55 | { 56 | "name": "Show axes", 57 | "value": true, 58 | "type": "boolean", 59 | "help": "Show frequency / time axes in the graphs", 60 | "param": "show_axes", 61 | "showForImplementationVersion": [] 62 | } 63 | ] 64 | } 65 | ] 66 | } 67 | -------------------------------------------------------------------------------- /spectrogram/requirements-blocks.txt: -------------------------------------------------------------------------------- 1 | audioread==2.1.9 2 | librosa==0.8.0 3 | matplotlib==3.5.1 4 | numpy==1.21.5 5 | PeakUtils==1.3.2 6 | Pillow==9.0.1 7 | requests==2.22.0 8 | requests-oauthlib==1.3.0 9 | requests-unixsocket==0.2.0 10 | scikit-learn==1.3.0 11 | scipy==1.7.3 12 | sklearn==0.0 13 | urllib3==1.24.2 14 | PyWavelets==1.3.0 -------------------------------------------------------------------------------- /spectrogram/third_party/speechpy/__init__.py: -------------------------------------------------------------------------------- 1 | from . import feature 2 | from . import processing 3 | -------------------------------------------------------------------------------- /spectrogram/third_party/speechpy/functions.py: -------------------------------------------------------------------------------- 1 | """function module. 
def frequency_to_mel(f):
    """converting from frequency to Mel scale.

    :param f: The frequency values(or a single frequency) in Hz. Any
        array-like (scalar, list, tuple, ndarray) is accepted.
    :returns: The mel scale values(or a single mel).
    """
    # np.asarray lets plain Python sequences through; scalar and ndarray
    # inputs produce the same values as before.
    return 1127 * np.log(1 + np.asarray(f) / 700.)


def mel_to_frequency(mel):
    """converting from Mel scale to frequency.

    :param mel: The mel scale values(or a single mel). Any array-like
        (scalar, list, tuple, ndarray) is accepted.
    :returns: The frequency values(or a single frequency) in Hz.
    """
    return 700 * (np.exp(np.asarray(mel) / 1127.0) - 1)


def triangle(x, left, middle, right):
    """Evaluate a triangular window on the points ``x``.

    The window rises linearly from 0 at ``left`` to 1 at ``middle`` and falls
    linearly back to 0 at ``right``; it is 0 everywhere else.

    :param x: The points at which the triangle is evaluated (array-like).
    :param left: Position where the rising edge starts.
    :param middle: Position of the peak.
    :param right: Position where the falling edge ends.
    :returns: A float array with the same shape as ``x``.
    """
    x = np.asarray(x)
    # Everything outside (left, right) keeps the initial value of zero.
    out = np.zeros(x.shape)

    rising = np.logical_and(left < x, x <= middle)
    out[rising] = (x[rising] - left) / (middle - left)

    # Assigned second so that, as before, the falling-edge formula decides the
    # value at x == middle whenever middle < right.
    falling = np.logical_and(middle <= x, x < right)
    out[falling] = (right - x[falling]) / (right - middle)
    return out


def zero_handling(x):
    """
    This function handles the issue with zero values if they are exposed
    to become an argument for any log function.

    :param x: The vector (array-like).
    :return: The vector with zeros substituted with epsilon values.
    """
    # Without the conversion a plain list compares as ``list == 0`` -> False
    # and no element would be replaced.
    x = np.asarray(x)
    return np.where(x == 0, 1e-10, x)
class TestSpectrogram(TestCase):
    """Regression test for the spectrogram block's generate_features()."""

    def test_generate_features(self):
        """Compare the features of a known sine wave against recorded values."""
        # Test input is an 11 Hz sine wave sampled at 100 Hz, 2s in length
        # fmt: off
        raw_data = np.array([
            0.2470, -0.4304, -0.9923, -1.0102, -0.5687, 0.1191, 0.7667, 1.0614, 0.7940, 0.0991,
            -0.4816, -0.9524, -0.9578, -0.6449, 0.2309, 0.7219, 1.0711, 0.8023, 0.0542, -0.4745,
            -1.0562, -0.8612, -0.2869, 0.1635, 0.7713, 0.9633, 0.6966, 0.1055, -0.6854, -0.9501,
            -0.7795, -0.4141, 0.2724, 0.8313, 0.9308, 0.6764, -0.0987, -0.5782, -0.9305, -0.8633,
            -0.3703, 0.4105, 0.8242, 0.9785, 0.4415, -0.0183, -0.6195, -1.0339, -0.8468, -0.2340,
            0.3867, 0.9366, 0.9824, 0.6682, -0.1983, -0.8198, -0.9776, -0.7510, -0.1863, 0.5749,
            0.8727, 0.9547, 0.4211, -0.2702, -0.7578, -0.9896, -0.6869, -0.0612, 0.6043, 0.9324,
            0.9054, 0.3475, -0.1959, -0.8891, -1.0128, -0.6421, -0.0245, 0.5605, 1.0065, 0.8225,
            0.3661, -0.3759, -0.8418, -1.0468, -0.6020, 0.0375, 0.7299, 1.0559, 0.8837, 0.2379,
            -0.3551, -0.9804, -0.9239, -0.4617, -0.0571, 0.6450, 0.9617, 0.7370, 0.4257, -0.3916,
            -0.8378, -1.0119, -0.4930, 0.2098, 0.7428, 0.9908, 0.7492, 0.2576, -0.5474, -0.9514,
            -0.9384, -0.5748, 0.1881, 0.7462, 0.8847, 0.8915, 0.3235, -0.4997, -0.9000, -0.7715,
            -0.3475, 0.1775, 0.8012, 0.9421, 0.7048, 0.0897, -0.6061, -0.9405, -0.8668, -0.4114,
            0.2920, 0.8920, 1.0544, 0.6019, 0.0306, -0.5786, -0.9819, -0.8275, -0.4641, 0.3465,
            0.8204, 0.9212, 0.6782, -0.1161, -0.7370, -1.0199, -0.7616, -0.1773, 0.3927, 0.9146,
            0.9817, 0.4189, 0.0252, -0.7683, -1.0024, -0.7143, -0.2131, 0.5674, 0.9643, 0.9882,
            0.4813, -0.2866, -0.9017, -1.0339, -0.7639, -0.1421, 0.6644, 0.9378, 0.8162, 0.4840,
            -0.2752, -0.8332, -1.0306, -0.6984, -0.1740, 0.5063, 1.0416, 0.8551, 0.4422, -0.2983,
            -0.8524, -1.0125, -0.6084, -0.1480, 0.5300, 0.9781, 0.8452, 0.4319, -0.4447, -0.8232,
            -0.9779, -0.5672, 0.0490, 0.6888, 0.9783, 0.7655, 0.3662, -0.4595, -0.9232, -0.9419,
        ])
        # fmt: on

        # ---Here is an example. It's helpful to open up parameters.json for this block while debugging or experimenting
        # Always set to max, can look in parameters.json for 'latestImplementationVersion' for this value
        implementation_version = 4
        # Set to false, the graphs are for Studio
        draw_graphs = False
        # The actual axes names don't matter, just needs to be the correct dimension
        axes = ["x"]
        # Sampling frequency in Hz
        sampling_freq = 100
        # Frame length in seconds
        frame_length = 0.5
        # Frame stride in seconds
        frame_stride = 0.25
        # FFT length in samples
        fft_length = 64
        # The next param is deprecated and ignored
        show_axes = False
        # Noise floor in dB
        noise_floor_db = -100

        # Call the function
        output = generate_features(
            implementation_version,
            draw_graphs,
            raw_data,
            axes,
            sampling_freq,
            frame_length,
            frame_stride,
            fft_length,
            show_axes,
            noise_floor_db,
        )

        # Return signature:
        # return {
        #     'features': features.tolist(),
        #     'graphs': graphs,
        #     'fft_used': [ fft_length ],
        #     'output_config': {
        #         'type': 'spectrogram',
        #         'shape': {
        #             'width': width,
        #             'height': height
        #         }
        #     }
        # }

        # fmt: off
        expected = [
            0.8195, 0.7850, 0.5811, 0.7853, 0.8354, 0.8515, 0.7983, 0.9777, 0.9165, 0.8571, 0.7664,
            0.6709, 0.7731, 0.7659, 0.7097, 0.6908, 0.7016, 0.7639, 0.7129, 0.5968, 0.6873, 0.7006,
            0.6397, 0.6777, 0.7492, 0.6527, 0.6825, 0.7065, 0.6699, 0.6136, 0.6198, 0.7193, 0.4839,
            0.6780, 0.6322, 0.7368, 0.7727, 0.8103, 0.8045, 0.7375, 0.9778, 0.9237, 0.8647, 0.8136,
            0.6613, 0.7581, 0.7919, 0.7856, 0.7430, 0.6846, 0.7778, 0.7436, 0.7450, 0.6468, 0.6677,
            0.7298, 0.7332, 0.7494, 0.6736, 0.6998, 0.7409, 0.7344, 0.4676, 0.6534, 0.6759, 0.7638,
            0.8044, 0.7824, 0.6644, 0.7783, 0.8355, 0.8488, 0.8044, 0.9782, 0.9170, 0.8574, 0.7677,
            0.6678, 0.7431, 0.7740, 0.7315, 0.6559, 0.7158, 0.7139, 0.7017, 0.5952, 0.6497, 0.5169,
            0.6258, 0.7423, 0.7071, 0.7234, 0.7226, 0.6883, 0.7006, 0.6238, 0.6974, 0.7119, 0.7230,
            0.8028, 0.7622, 0.7143, 0.7863, 0.7969, 0.8432, 0.8097, 0.9779, 0.9147, 0.8645, 0.8004,
            0.6934, 0.7342, 0.7801, 0.7360, 0.7471, 0.6276, 0.7499, 0.7471, 0.6926, 0.7103, 0.6706,
            0.6509, 0.7475, 0.7017, 0.6659, 0.6599, 0.7369, 0.6544, 0.7136, 0.5444, 0.7089, 0.7517,
            0.7059, 0.7377, 0.7371, 0.7651, 0.8041, 0.8328, 0.7641, 0.9777, 0.9199, 0.8692, 0.8105,
            0.6651, 0.7809, 0.7708, 0.7691, 0.7490, 0.7249, 0.7014, 0.7588, 0.7547, 0.6342, 0.7113,
            0.7279, 0.7338, 0.6907, 0.5643, 0.7388, 0.7379, 0.7126, 0.6640, 0.6026, 0.7076, 0.7251,
            0.8158, 0.7901, 0.6783, 0.7729, 0.8418, 0.8571, 0.7927, 0.9780, 0.9202, 0.8503, 0.7343,
            0.6806, 0.7719, 0.7489, 0.7017, 0.5085, 0.6762, 0.7051, 0.7086, 0.7010, 0.6236, 0.6669,
            0.7057, 0.6899, 0.7028, 0.6533, 0.7391, 0.6289, 0.6906, 0.6240, 0.6509, 0.6542, 0.6031,
            0.7381, 0.6577, 0.7311, 0.7704, 0.8040, 0.8146, 0.7908, 0.9798, 0.9189, 0.8765, 0.8038,
            0.5722, 0.7309, 0.7969, 0.7829, 0.6880, 0.5730, 0.7496, 0.7476, 0.7322, 0.6958, 0.7110,
            0.6818, 0.7145, 0.7455, 0.6496, 0.6827, 0.7271, 0.7744, 0.7276, 0.6951, 0.7265, 0.7406,
        ]
        # fmt: on

        # With a single axis the flattened feature vector must hold exactly
        # width * height values, as reported in the output configuration.
        shape = output['output_config']['shape']
        self.assertEqual(len(output['features']), shape['width'] * shape['height'])

        # assert_allclose reports which elements differ (and rejects a shape
        # mismatch) instead of a bare "False is not true". rtol is set to
        # np.allclose's default so the tolerance is unchanged.
        np.testing.assert_allclose(output['features'], expected, rtol=1e-05, atol=0.001)