├── .gitignore
├── README.md
├── audioset
│   ├── mel_features.py
│   ├── vggish_input.py
│   ├── vggish_params.py
│   ├── vggish_postprocess.py
│   └── vggish_slim.py
├── convert_to_pytorch.py
├── example_usage.py
└── vggish.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # pipenv
86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
89 | # install all needed dependencies.
90 | #Pipfile.lock
91 |
92 | # celery beat schedule file
93 | celerybeat-schedule
94 |
95 | # SageMath parsed files
96 | *.sage.py
97 |
98 | # Environments
99 | .env
100 | .venv
101 | env/
102 | venv/
103 | ENV/
104 | env.bak/
105 | venv.bak/
106 |
107 | # Spyder project settings
108 | .spyderproject
109 | .spyproject
110 |
111 | # Rope project settings
112 | .ropeproject
113 |
114 | # mkdocs documentation
115 | /site
116 |
117 | # mypy
118 | .mypy_cache/
119 | .dmypy.json
120 | dmypy.json
121 |
122 | # Pyre type checker
123 | .pyre/
124 |
125 | # PyCharm
126 | .idea/
127 |
128 | # Sandbox
129 | sandbox/
130 |
131 | # Media
132 | *.jpg
133 | *.png
134 | *.jpeg
135 | *.avi
136 | *.mp4
137 | *.mp3
138 | *.pkl
139 | *.pt
140 | *.log
141 | *.ckpt
142 | *.pth
143 | *.zip
144 | *.npz
145 |
146 | # MNIST
147 | *ubyte
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AudioSet VGGish in PyTorch
2 |
3 |
4 | ## Introduction
5 | This repository includes:
6 | - A script which converts the pretrained VGGish model provided in the AudioSet repository from TensorFlow to PyTorch
7 | (along with a basic smoke test).
8 | **Sourced from:** https://github.com/tensorflow/models/tree/master/research/audioset
9 | - The VGGish architecture defined in PyTorch.
10 | **Adapted from:** https://github.com/harritaylor/torchvggish
11 | - The converted weights found in the [Releases](https://github.com/tcvrick/audioset-vggish-tensorflow-to-pytorch/releases) section.
12 |
13 | Please note that the converted model does not produce exactly the same results as the original model, but should be
14 | close in most cases.
15 |
16 | ## Usage
17 | 1. Download the pretrained weights and PCA parameters from the [AudioSet](https://github.com/tensorflow/models/tree/master/research/audioset) repository and place them in the working directory.
18 | 2. Install any dependencies required by [AudioSet](https://github.com/tensorflow/models/tree/master/research/audioset) (e.g., resampy, numpy, TensorFlow).
19 | 3. Run **"convert_to_pytorch.py"** to generate the PyTorch-formatted weights for the VGGish model, or download
20 | the weights from the [Releases](https://github.com/tcvrick/audioset-vggish-tensorflow-to-pytorch/releases) section.
21 |
22 | ## Example Usage
23 | Please refer to the **"example_usage.py"** script. The output of the script should be as follows:
24 |
25 | ```
26 | Input Shape: (3, 1, 96, 64)
27 | Output Shape: (3, 128)
28 | Computed Embedding Mean and Standard Deviation: 0.13079901 0.23851949
29 | Expected Embedding Mean and Standard Deviation: 0.131 0.238
30 | Computed Post-processed Embedding Mean and Standard Deviation: 123.01041666666667 75.51479501722199
31 | Expected Post-processed Embedding Mean and Standard Deviation: 123.0 75.0
32 | ```
--------------------------------------------------------------------------------
/audioset/mel_features.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Defines routines to compute mel spectrogram features from audio waveform."""
17 |
18 | import numpy as np
19 |
20 |
21 | def frame(data, window_length, hop_length):
22 |   """Convert array into a sequence of successive possibly overlapping frames.
23 |
24 |   An n-dimensional array of shape (num_samples, ...) is converted into an
25 |   (n+1)-D array of shape (num_frames, window_length, ...), where each frame
26 |   starts hop_length points after the preceding one.
27 |
28 |   This is accomplished using stride_tricks, so the original data is not
29 |   copied. However, there is no zero-padding, so any incomplete frames at the
30 |   end are not included.
31 |
32 |   Args:
33 |     data: np.array of dimension N >= 1.
34 |     window_length: Number of samples in each frame.
35 |     hop_length: Advance (in samples) between each window.
36 |
37 |   Returns:
38 |     (N+1)-D np.array with as many rows as there are complete frames that can be
39 |     extracted.
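
  For example, frame(np.arange(11), window_length=4, hop_length=2) returns a
  (4, 4) array whose frames start at samples 0, 2, 4 and 6; the final sample
  cannot fill a complete frame and is dropped.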
40 | """ 41 | num_samples = data.shape[0] 42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) 43 | shape = (num_frames, window_length) + data.shape[1:] 44 | strides = (data.strides[0] * hop_length,) + data.strides 45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) 46 | 47 | 48 | def periodic_hann(window_length): 49 | """Calculate a "periodic" Hann window. 50 | 51 | The classic Hann window is defined as a raised cosine that starts and 52 | ends on zero, and where every value appears twice, except the middle 53 | point for an odd-length window. Matlab calls this a "symmetric" window 54 | and np.hanning() returns it. However, for Fourier analysis, this 55 | actually represents just over one cycle of a period N-1 cosine, and 56 | thus is not compactly expressed on a length-N Fourier basis. Instead, 57 | it's better to use a raised cosine that ends just before the final 58 | zero value - i.e. a complete cycle of a period-N cosine. Matlab 59 | calls this a "periodic" window. This routine calculates it. 60 | 61 | Args: 62 | window_length: The number of points in the returned window. 63 | 64 | Returns: 65 | A 1D np.array containing the periodic hann window. 66 | """ 67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * 68 | np.arange(window_length))) 69 | 70 | 71 | def stft_magnitude(signal, fft_length, 72 | hop_length=None, 73 | window_length=None): 74 | """Calculate the short-time Fourier transform magnitude. 75 | 76 | Args: 77 | signal: 1D np.array of the input time-domain signal. 78 | fft_length: Size of the FFT to apply. 79 | hop_length: Advance (in samples) between each frame passed to FFT. 80 | window_length: Length of each block of samples to pass to FFT. 81 | 82 | Returns: 83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1 84 | unique values of the FFT for the corresponding frame of input samples. 85 | """ 86 | frames = frame(signal, window_length, hop_length) 87 | # Apply frame window to each frame. We use a periodic Hann (cosine of period 88 | # window_length) instead of the symmetric Hann of np.hanning (period 89 | # window_length-1). 90 | window = periodic_hann(window_length) 91 | windowed_frames = frames * window 92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) 93 | 94 | 95 | # Mel spectrum constants and functions. 96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0 97 | _MEL_HIGH_FREQUENCY_Q = 1127.0 98 | 99 | 100 | def hertz_to_mel(frequencies_hertz): 101 | """Convert frequencies to mel scale using HTK formula. 102 | 103 | Args: 104 | frequencies_hertz: Scalar or np.array of frequencies in hertz. 105 | 106 | Returns: 107 | Object of same size as frequencies_hertz containing corresponding values 108 | on the mel scale. 109 | """ 110 | return _MEL_HIGH_FREQUENCY_Q * np.log( 111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) 112 | 113 | 114 | def spectrogram_to_mel_matrix(num_mel_bins=20, 115 | num_spectrogram_bins=129, 116 | audio_sample_rate=8000, 117 | lower_edge_hertz=125.0, 118 | upper_edge_hertz=3800.0): 119 | """Return a matrix that can post-multiply spectrogram rows to make mel. 120 | 121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of 122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a 123 | "mel spectrogram" M of frames x num_mel_bins. M = S A. 
124 | 125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands 126 | to multiply each FFT bin by only one mel weight, then add it, with positive 127 | and negative signs, to the two adjacent mel bands to which that bin 128 | contributes. Here, by expressing this operation as a matrix multiply, we go 129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around 130 | num_fft^2 multiplies and adds. However, because these are all presumably 131 | accomplished in a single call to np.dot(), it's not clear which approach is 132 | faster in Python. The matrix multiplication has the attraction of being more 133 | general and flexible, and much easier to read. 134 | 135 | Args: 136 | num_mel_bins: How many bands in the resulting mel spectrum. This is 137 | the number of columns in the output matrix. 138 | num_spectrogram_bins: How many bins there are in the source spectrogram 139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram 140 | only contains the nonredundant FFT bins. 141 | audio_sample_rate: Samples per second of the audio at the input to the 142 | spectrogram. We need this to figure out the actual frequencies for 143 | each spectrogram bin, which dictates how they are mapped into mel. 144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel 145 | spectrum. This corresponds to the lower edge of the lowest triangular 146 | band. 147 | upper_edge_hertz: The desired top edge of the highest frequency band. 148 | 149 | Returns: 150 | An np.array with shape (num_spectrogram_bins, num_mel_bins). 151 | 152 | Raises: 153 | ValueError: if frequency edges are incorrectly ordered or out of range. 154 | """ 155 | nyquist_hertz = audio_sample_rate / 2. 156 | if lower_edge_hertz < 0.0: 157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) 158 | if lower_edge_hertz >= upper_edge_hertz: 159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % 160 | (lower_edge_hertz, upper_edge_hertz)) 161 | if upper_edge_hertz > nyquist_hertz: 162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % 163 | (upper_edge_hertz, nyquist_hertz)) 164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) 165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) 166 | # The i'th mel band (starting from i=1) has center frequency 167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge 168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in 169 | # the band_edges_mel arrays. 170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), 171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) 172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins 173 | # of spectrogram values. 174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) 175 | for i in range(num_mel_bins): 176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] 177 | # Calculate lower and upper slopes for every spectrogram bin. 178 | # Line segments are linear in the *mel* domain, not hertz. 179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / 180 | (center_mel - lower_edge_mel)) 181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / 182 | (upper_edge_mel - center_mel)) 183 | # .. then intersect them with each other and zero. 
184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, 185 | upper_slope)) 186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero 187 | # coefficient. 188 | mel_weights_matrix[0, :] = 0.0 189 | return mel_weights_matrix 190 | 191 | 192 | def log_mel_spectrogram(data, 193 | audio_sample_rate=8000, 194 | log_offset=0.0, 195 | window_length_secs=0.025, 196 | hop_length_secs=0.010, 197 | **kwargs): 198 | """Convert waveform to a log magnitude mel-frequency spectrogram. 199 | 200 | Args: 201 | data: 1D np.array of waveform data. 202 | audio_sample_rate: The sampling rate of data. 203 | log_offset: Add this to values when taking log to avoid -Infs. 204 | window_length_secs: Duration of each window to analyze. 205 | hop_length_secs: Advance between successive analysis windows. 206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. 207 | 208 | Returns: 209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank 210 | magnitudes for successive frames. 211 | """ 212 | window_length_samples = int(round(audio_sample_rate * window_length_secs)) 213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) 214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) 215 | spectrogram = stft_magnitude( 216 | data, 217 | fft_length=fft_length, 218 | hop_length=hop_length_samples, 219 | window_length=window_length_samples) 220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( 221 | num_spectrogram_bins=spectrogram.shape[1], 222 | audio_sample_rate=audio_sample_rate, **kwargs)) 223 | return np.log(mel_spectrogram + log_offset) 224 | -------------------------------------------------------------------------------- /audioset/vggish_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute input examples for VGGish from audio waveform.""" 17 | 18 | import numpy as np 19 | import resampy 20 | 21 | from audioset import mel_features 22 | from audioset import vggish_params 23 | 24 | import soundfile as sf 25 | 26 | 27 | def waveform_to_examples(data, sample_rate): 28 | """Converts audio waveform into an array of examples for VGGish. 29 | 30 | Args: 31 | data: np.array of either one dimension (mono) or two dimensions 32 | (multi-channel, with the outer dimension representing channels). 33 | Each sample is generally expected to lie in the range [-1.0, +1.0], 34 | although this is not required. 35 | sample_rate: Sample rate of data. 
36 | 37 | Returns: 38 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents 39 | a sequence of examples, each of which contains a patch of log mel 40 | spectrogram, covering num_frames frames of audio and num_bands mel frequency 41 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. 42 | """ 43 | # Convert to mono. 44 | if len(data.shape) > 1: 45 | data = np.mean(data, axis=1) 46 | # Resample to the rate assumed by VGGish. 47 | if sample_rate != vggish_params.SAMPLE_RATE: 48 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) 49 | 50 | # Compute log mel spectrogram features. 51 | log_mel = mel_features.log_mel_spectrogram( 52 | data, 53 | audio_sample_rate=vggish_params.SAMPLE_RATE, 54 | log_offset=vggish_params.LOG_OFFSET, 55 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, 56 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, 57 | num_mel_bins=vggish_params.NUM_MEL_BINS, 58 | lower_edge_hertz=vggish_params.MEL_MIN_HZ, 59 | upper_edge_hertz=vggish_params.MEL_MAX_HZ) 60 | 61 | # Frame features into examples. 62 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS 63 | example_window_length = int(round( 64 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) 65 | example_hop_length = int(round( 66 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) 67 | log_mel_examples = mel_features.frame( 68 | log_mel, 69 | window_length=example_window_length, 70 | hop_length=example_hop_length) 71 | return log_mel_examples 72 | 73 | 74 | def wavfile_to_examples(wav_file): 75 | """Convenience wrapper around waveform_to_examples() for a common WAV format. 76 | 77 | Args: 78 | wav_file: String path to a file, or a file-like object. The file 79 | is assumed to contain WAV audio data with signed 16-bit PCM samples. 80 | 81 | Returns: 82 | See waveform_to_examples. 83 | """ 84 | wav_data, sr = sf.read(wav_file, dtype='int16') 85 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype 86 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] 87 | return waveform_to_examples(samples, sr) 88 | -------------------------------------------------------------------------------- /audioset/vggish_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Global parameters for the VGGish model. 17 | 18 | See vggish_slim.py for more information. 19 | """ 20 | 21 | # Architectural constants. 22 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. 23 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. 24 | EMBEDDING_SIZE = 128 # Size of embedding layer. 25 | 26 | # Hyperparameters used in feature and example generation. 
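# (Together, the settings below define the log-mel front end: audio is resampled to 16 kHz mono,
# analyzed with 25 ms windows hopped every 10 ms into 64 mel bands spanning 125-7500 Hz, and
# framed into non-overlapping 0.96 s examples, i.e. patches of 96 frames x 64 bands.)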
27 | SAMPLE_RATE = 16000 28 | STFT_WINDOW_LENGTH_SECONDS = 0.025 29 | STFT_HOP_LENGTH_SECONDS = 0.010 30 | NUM_MEL_BINS = NUM_BANDS 31 | MEL_MIN_HZ = 125 32 | MEL_MAX_HZ = 7500 33 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. 34 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames 35 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. 36 | 37 | # Parameters used for embedding postprocessing. 38 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' 39 | PCA_MEANS_NAME = 'pca_means' 40 | QUANTIZE_MIN_VAL = -2.0 41 | QUANTIZE_MAX_VAL = +2.0 42 | 43 | # Hyperparameters used in training. 44 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 45 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. 46 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 47 | 48 | # Names of ops, tensors, and features. 49 | INPUT_OP_NAME = 'vggish/input_features' 50 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' 51 | OUTPUT_OP_NAME = 'vggish/embedding' 52 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' 53 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' 54 | -------------------------------------------------------------------------------- /audioset/vggish_postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Post-process embeddings from VGGish.""" 17 | 18 | import numpy as np 19 | 20 | from audioset import vggish_params 21 | 22 | 23 | class Postprocessor(object): 24 | """Post-processes VGGish embeddings. 25 | 26 | The initial release of AudioSet included 128-D VGGish embeddings for each 27 | segment of AudioSet. These released embeddings were produced by applying 28 | a PCA transformation (technically, a whitening transform is included as well) 29 | and 8-bit quantization to the raw embedding output from VGGish, in order to 30 | stay compatible with the YouTube-8M project which provides visual embeddings 31 | in the same format for a large set of YouTube videos. This class implements 32 | the same PCA (with whitening) and quantization transformations. 33 | """ 34 | 35 | def __init__(self, pca_params_npz_path): 36 | """Constructs a postprocessor. 37 | 38 | Args: 39 | pca_params_npz_path: Path to a NumPy-format .npz file that 40 | contains the PCA parameters used in postprocessing. 41 | """ 42 | params = np.load(pca_params_npz_path) 43 | self._pca_matrix = params[vggish_params.PCA_EIGEN_VECTORS_NAME] 44 | # Load means into a column vector for easier broadcasting later. 
45 | self._pca_means = params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1) 46 | assert self._pca_matrix.shape == ( 47 | vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE), ( 48 | 'Bad PCA matrix shape: %r' % (self._pca_matrix.shape,)) 49 | assert self._pca_means.shape == (vggish_params.EMBEDDING_SIZE, 1), ( 50 | 'Bad PCA means shape: %r' % (self._pca_means.shape,)) 51 | 52 | def postprocess(self, embeddings_batch): 53 | """Applies postprocessing to a batch of embeddings. 54 | 55 | Args: 56 | embeddings_batch: An nparray of shape [batch_size, embedding_size] 57 | containing output from the embedding layer of VGGish. 58 | 59 | Returns: 60 | An nparray of the same shape as the input but of type uint8, 61 | containing the PCA-transformed and quantized version of the input. 62 | """ 63 | assert len(embeddings_batch.shape) == 2, ( 64 | 'Expected 2-d batch, got %r' % (embeddings_batch.shape,)) 65 | assert embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE, ( 66 | 'Bad batch shape: %r' % (embeddings_batch.shape,)) 67 | 68 | # Apply PCA. 69 | # - Embeddings come in as [batch_size, embedding_size]. 70 | # - Transpose to [embedding_size, batch_size]. 71 | # - Subtract pca_means column vector from each column. 72 | # - Premultiply by PCA matrix of shape [output_dims, input_dims] 73 | # where both are are equal to embedding_size in our case. 74 | # - Transpose result back to [batch_size, embedding_size]. 75 | pca_applied = np.dot(self._pca_matrix, 76 | (embeddings_batch.T - self._pca_means)).T 77 | 78 | # Quantize by: 79 | # - clipping to [min, max] range 80 | clipped_embeddings = np.clip( 81 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, 82 | vggish_params.QUANTIZE_MAX_VAL) 83 | # - convert to 8-bit in range [0.0, 255.0] 84 | quantized_embeddings = ( 85 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) * 86 | (255.0 / 87 | (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL))) 88 | # - cast 8-bit float to uint8 89 | quantized_embeddings = quantized_embeddings.astype(np.uint8) 90 | 91 | return quantized_embeddings 92 | -------------------------------------------------------------------------------- /audioset/vggish_slim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defines the 'VGGish' model used to generate AudioSet embedding features. 17 | 18 | The public AudioSet release (https://research.google.com/audioset/download.html) 19 | includes 128-D features extracted from the embedding layer of a VGG-like model 20 | that was trained on a large Google-internal YouTube dataset. Here we provide 21 | a TF-Slim definition of the same model, without any dependences on libraries 22 | internal to Google. We call it 'VGGish'. 
23 | 24 | Note that we only define the model up to the embedding layer, which is the 25 | penultimate layer before the final classifier layer. We also provide various 26 | hyperparameter values (in vggish_params.py) that were used to train this model 27 | internally. 28 | 29 | For comparison, here is TF-Slim's VGG definition: 30 | https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py 31 | """ 32 | 33 | import tensorflow as tf 34 | from audioset import vggish_params as params 35 | 36 | slim = tf.contrib.slim 37 | 38 | 39 | def define_vggish_slim(training=False): 40 | """Defines the VGGish TensorFlow model. 41 | 42 | All ops are created in the current default graph, under the scope 'vggish/'. 43 | 44 | The input is a placeholder named 'vggish/input_features' of type float32 and 45 | shape [batch_size, num_frames, num_bands] where batch_size is variable and 46 | num_frames and num_bands are constants, and [num_frames, num_bands] represents 47 | a log-mel-scale spectrogram patch covering num_bands frequency bands and 48 | num_frames time frames (where each frame step is usually 10ms). This is 49 | produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET). 50 | The output is an op named 'vggish/embedding' which produces the activations of 51 | a 128-D embedding layer, which is usually the penultimate layer when used as 52 | part of a full model with a final classifier layer. 53 | 54 | Args: 55 | training: If true, all parameters are marked trainable. 56 | 57 | Returns: 58 | The op 'vggish/embeddings'. 59 | """ 60 | # Defaults: 61 | # - All weights are initialized to N(0, INIT_STDDEV). 62 | # - All biases are initialized to 0. 63 | # - All activations are ReLU. 64 | # - All convolutions are 3x3 with stride 1 and SAME padding. 65 | # - All max-pools are 2x2 with stride 2 and SAME padding. 66 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 67 | weights_initializer=tf.truncated_normal_initializer( 68 | stddev=params.INIT_STDDEV), 69 | biases_initializer=tf.zeros_initializer(), 70 | activation_fn=tf.nn.relu, 71 | trainable=training), \ 72 | slim.arg_scope([slim.conv2d], 73 | kernel_size=[3, 3], stride=1, padding='SAME'), \ 74 | slim.arg_scope([slim.max_pool2d], 75 | kernel_size=[2, 2], stride=2, padding='SAME'), \ 76 | tf.variable_scope('vggish'): 77 | # Input: a batch of 2-D log-mel-spectrogram patches. 78 | features = tf.placeholder( 79 | tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS), 80 | name='input_features') 81 | # Reshape to 4-D so that we can convolve a batch with conv2d(). 82 | net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1]) 83 | 84 | # The VGG stack of alternating convolutions and max-pools. 85 | net = slim.conv2d(net, 64, scope='conv1') 86 | net = slim.max_pool2d(net, scope='pool1') 87 | net = slim.conv2d(net, 128, scope='conv2') 88 | net = slim.max_pool2d(net, scope='pool2') 89 | net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3') 90 | net = slim.max_pool2d(net, scope='pool3') 91 | net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4') 92 | net = slim.max_pool2d(net, scope='pool4') 93 | 94 | # Flatten before entering fully-connected layers 95 | net = slim.flatten(net) 96 | net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1') 97 | # The embedding layer. 
98 |     net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
99 |     return tf.identity(net, name='embedding')
100 |
101 |
102 | def load_vggish_slim_checkpoint(session, checkpoint_path):
103 |   """Loads a pre-trained VGGish-compatible checkpoint.
104 |
105 |   This function can be used as an initialization function (referred to as
106 |   init_fn in TensorFlow documentation) which is called in a Session after
107 |   initializing all variables. When used as an init_fn, this will load
108 |   a pre-trained checkpoint that is compatible with the VGGish model
109 |   definition. Only variables defined by VGGish will be loaded.
110 |
111 |   Args:
112 |     session: an active TensorFlow session.
113 |     checkpoint_path: path to a file containing a checkpoint that is
114 |       compatible with the VGGish model definition.
115 |   """
116 |   # Get the list of names of all VGGish variables that exist in
117 |   # the checkpoint (i.e., all inference-mode VGGish variables).
118 |   with tf.Graph().as_default():
119 |     define_vggish_slim(training=False)
120 |     vggish_var_names = [v.name for v in tf.global_variables()]
121 |
122 |   # Get the list of all currently existing variables that match
123 |   # the list of variable names we just computed.
124 |   vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names]
125 |
126 |   # Use a Saver to restore just the variables selected above.
127 |   saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained',
128 |                          write_version=1)
129 |   saver.restore(session, checkpoint_path)
130 |
--------------------------------------------------------------------------------
/convert_to_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from vggish import VGGish
6 | from audioset import vggish_params, vggish_slim, vggish_input
7 |
8 |
9 | """
10 | Script which converts the pretrained TensorFlow implementation of VGGish to a PyTorch equivalent, along with
11 | a basic smoke test to verify accuracy.
12 | """
13 |
14 |
15 | def main():
16 |     with tf.Graph().as_default(), tf.Session() as sess:
17 |         # -------------------
18 |         # Step 1
19 |         # -------------------
20 |         # Load the model.
21 |         vggish_slim.define_vggish_slim(training=False)
22 |         vggish_slim.load_vggish_slim_checkpoint(sess, 'vggish_model.ckpt')
23 |
24 |         # Get all of the variables, and use this to construct a dictionary which maps
25 |         # the name of the variables to their values.
26 |         variables = tf.all_variables()
27 |         variables = [x.name for x in variables]
28 |         variable_values = sess.run(variables)
29 |         variable_dict = dict(zip(variables, variable_values))
30 |
31 |         # Create a new state dictionary which maps the TensorFlow version of the weights
32 |         # to those in the new PyTorch model.
33 |         pytorch_model = VGGish()
34 |         pytorch_feature_dict = pytorch_model.features.state_dict()
35 |         pytorch_fc_dict = pytorch_model.fc.state_dict()
36 |
37 |         # -------------------
38 |         # Step 2
39 |         # -------------------
40 |         # There is a bias and weight vector for each convolution layer. The weights are not necessarily stored
41 |         # in the same format and order between the two frameworks; for the TensorFlow model, the 12 vectors for the
42 |         # convolution layers are first, followed by the 6 FC layers.
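        # (Note on layout: TensorFlow stores conv kernels as [height, width, in_channels, out_channels]
        # and fully-connected weights as [in_features, out_features], whereas PyTorch expects
        # [out_channels, in_channels, height, width] and [out_features, in_features]. The
        # to_pytorch_tensor() helper below transposes the axes accordingly.)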
43 | tf_feature_names = list(variable_dict.keys())[:-6] 44 | tf_fc_names = list(variable_dict.keys())[-6:] 45 | 46 | def to_pytorch_tensor(weights): 47 | if len(weights.shape) == 4: 48 | tensor = torch.from_numpy(weights.transpose(3, 2, 0, 1)).float() 49 | else: 50 | tensor = torch.from_numpy(weights.T).float() 51 | return tensor 52 | 53 | # Convert the weights for the convolution layers. 54 | for tf_name, pytorch_name in zip(tf_feature_names, pytorch_feature_dict.keys()): 55 | print(f'Converting [{tf_name}] ----------> [feature.{pytorch_name}]') 56 | pytorch_feature_dict[pytorch_name] = to_pytorch_tensor(variable_dict[tf_name]) 57 | 58 | # Convert the weights for the FC layers. 59 | for tf_name, pytorch_name in zip(tf_fc_names, pytorch_fc_dict.keys()): 60 | print(f'Converting [{tf_name}] ----------> [fc.{pytorch_name}]') 61 | pytorch_fc_dict[pytorch_name] = to_pytorch_tensor(variable_dict[tf_name]) 62 | 63 | # ------------------- 64 | # Step 3 65 | # ------------------- 66 | # Load the new state dictionaries into the PyTorch model. 67 | pytorch_model.features.load_state_dict(pytorch_feature_dict) 68 | pytorch_model.fc.load_state_dict(pytorch_fc_dict) 69 | 70 | # ------------------- 71 | # Step 4 72 | # ------------------- 73 | # Generate a sample input (as in the AudioSet repo smoke test). 74 | num_secs = 3 75 | freq = 1000 76 | sr = 44100 77 | t = np.linspace(0, num_secs, int(num_secs * sr)) 78 | x = np.sin(2 * np.pi * freq * t) 79 | 80 | # Produce a batch of log mel spectrogram examples. 81 | input_batch = vggish_input.waveform_to_examples(x, sr) 82 | 83 | # Run inference on the TensorFlow model. 84 | features_tensor = sess.graph.get_tensor_by_name( 85 | vggish_params.INPUT_TENSOR_NAME) 86 | embedding_tensor = sess.graph.get_tensor_by_name( 87 | vggish_params.OUTPUT_TENSOR_NAME) 88 | [tf_output] = sess.run([embedding_tensor], 89 | feed_dict={features_tensor: input_batch}) 90 | 91 | # Run on the PyTorch model. 92 | pytorch_model = pytorch_model.to('cpu') 93 | pytorch_output = pytorch_model(torch.from_numpy(input_batch).unsqueeze(dim=1).float()) 94 | pytorch_output = pytorch_output.detach().numpy() 95 | 96 | # ------------------- 97 | # Step 5 98 | # ------------------- 99 | # Compare the difference between the outputs. 100 | diff = np.linalg.norm(pytorch_output - tf_output) ** 2 101 | print(f'Distance between TensorFlow and PyTorch outputs: [{diff}]') 102 | assert diff < 1e-6 103 | 104 | # Run a smoke test. 105 | expected_embedding_mean = 0.131 106 | expected_embedding_std = 0.238 107 | 108 | # Verify the TF output. 109 | np.testing.assert_allclose( 110 | [np.mean(tf_output), np.std(tf_output)], 111 | [expected_embedding_mean, expected_embedding_std], 112 | rtol=0.001) 113 | 114 | # Verify the PyTorch output. 115 | np.testing.assert_allclose( 116 | [np.mean(pytorch_output), np.std(pytorch_output)], 117 | [expected_embedding_mean, expected_embedding_std], 118 | rtol=0.001) 119 | 120 | # ------------------- 121 | # Step 6 122 | # ------------------- 123 | print('Smoke test passed! 
Saving PyTorch weights to "pytorch_vggish.pth".') 124 | torch.save(pytorch_model.state_dict(), 'pytorch_vggish.pth') 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /example_usage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from vggish import VGGish 5 | from audioset import vggish_input, vggish_postprocess 6 | 7 | 8 | def main(): 9 | # Initialize the PyTorch model. 10 | device = 'cuda:0' 11 | pytorch_model = VGGish() 12 | pytorch_model.load_state_dict(torch.load('pytorch_vggish.pth')) 13 | pytorch_model = pytorch_model.to(device) 14 | 15 | # Generate a sample input (as in the AudioSet repo smoke test). 16 | num_secs = 3 17 | freq = 1000 18 | sr = 44100 19 | t = np.linspace(0, num_secs, int(num_secs * sr)) 20 | x = np.sin(2 * np.pi * freq * t) 21 | 22 | # Produce a batch of log mel spectrogram examples. 23 | input_batch = vggish_input.waveform_to_examples(x, sr) 24 | input_batch = torch.from_numpy(input_batch).unsqueeze(dim=1) 25 | input_batch = input_batch.float().to(device) 26 | 27 | # Run the PyTorch model. 28 | pytorch_output = pytorch_model(input_batch) 29 | pytorch_output = pytorch_output.detach().cpu().numpy() 30 | print('Input Shape:', tuple(input_batch.shape)) 31 | print('Output Shape:', tuple(pytorch_output.shape)) 32 | 33 | expected_embedding_mean = 0.131 34 | expected_embedding_std = 0.238 35 | print('Computed Embedding Mean and Standard Deviation:', np.mean(pytorch_output), np.std(pytorch_output)) 36 | print('Expected Embedding Mean and Standard Deviation:', expected_embedding_mean, expected_embedding_std) 37 | 38 | # Post-processing. 39 | post_processor = vggish_postprocess.Postprocessor('vggish_pca_params.npz') 40 | postprocessed_output = post_processor.postprocess(pytorch_output) 41 | expected_postprocessed_mean = 123.0 42 | expected_postprocessed_std = 75.0 43 | print('Computed Post-processed Embedding Mean and Standard Deviation:', np.mean(postprocessed_output), 44 | np.std(postprocessed_output)) 45 | print('Expected Post-processed Embedding Mean and Standard Deviation:', expected_postprocessed_mean, 46 | expected_postprocessed_std) 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /vggish.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class VGGish(nn.Module): 5 | """ 6 | PyTorch implementation of the VGGish model. 7 | 8 | Adapted from: https://github.com/harritaylor/torch-vggish 9 | The following modifications were made: (i) correction for the missing ReLU layers, (ii) correction for the 10 | improperly formatted data when transitioning from NHWC --> NCHW in the fully-connected layers, and (iii) 11 | correction for flattening in the fully-connected layers. 
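
    The model expects batches of log mel spectrogram patches shaped
    (batch, 1, 96, 64) and returns (batch, 128) embeddings (see example_usage.py).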
12 | """ 13 | 14 | def __init__(self): 15 | super(VGGish, self).__init__() 16 | self.features = nn.Sequential( 17 | nn.Conv2d(1, 64, 3, stride=1, padding=1), 18 | nn.ReLU(inplace=True), 19 | nn.MaxPool2d(2, stride=2), 20 | 21 | nn.Conv2d(64, 128, 3, stride=1, padding=1), 22 | nn.ReLU(inplace=True), 23 | nn.MaxPool2d(2, stride=2), 24 | 25 | nn.Conv2d(128, 256, 3, stride=1, padding=1), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(256, 256, 3, stride=1, padding=1), 28 | nn.ReLU(inplace=True), 29 | nn.MaxPool2d(2, stride=2), 30 | 31 | nn.Conv2d(256, 512, 3, stride=1, padding=1), 32 | nn.ReLU(inplace=True), 33 | nn.Conv2d(512, 512, 3, stride=1, padding=1), 34 | nn.ReLU(inplace=True), 35 | nn.MaxPool2d(2, stride=2) 36 | ) 37 | self.fc = nn.Sequential( 38 | nn.Linear(512 * 24, 4096), 39 | nn.ReLU(inplace=True), 40 | nn.Linear(4096, 4096), 41 | nn.ReLU(inplace=True), 42 | nn.Linear(4096, 128), 43 | nn.ReLU(inplace=True), 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.features(x).permute(0, 2, 3, 1).contiguous() 48 | x = x.view(x.size(0), -1) 49 | x = self.fc(x) 50 | return x 51 | 52 | 53 | def main(): 54 | pass 55 | 56 | 57 | if __name__ == '__main__': 58 | main() 59 | --------------------------------------------------------------------------------