├── README.md
├── LICENSE
└── spectral_ops.py

/README.md:
--------------------------------------------------------------------------------
# GEDLoss_pytorch
A PyTorch implementation of Google's GED loss (a multi-scale spectrogram-based generalized energy distance).

The full paper is available on [arXiv](https://arxiv.org/abs/2008.01160).

The original TensorFlow code is available at [GED_TTS](https://github.com/google-research/google-research/tree/68c738421186ce85339bfee16bf3ca2ea3ec16e4/ged_tts).

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/spectral_ops.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Library of spectral operations."""
import librosa
import numpy as np
# import tensorflow.compat.v2 as tf
import torch

EPSILON = 1e-8  # Small constant to avoid division by zero.

# Mel spectrum constants.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0

### This is a translation of Google's TensorFlow implementation to PyTorch.
### Each original TensorFlow line is kept as a comment above its PyTorch
### counterpart.

# TODO: test code
# This PyTorch port has not been tested end to end, so parameter type errors
# or gradient problems may remain. Run a test before you use it.
# It needs a PyTorch release that provides torch.fft.rfft and torch.complex.


def torch_aligned_random_crop(waves, frame_length):
  """Get aligned random crops from batches of input waves."""
  n, t = waves[0].shape
  crop_t = frame_length * (t // frame_length - 1)
  # offsets = [tf.random.uniform(shape=(), minval=0,
  #                              maxval=t-crop_t, dtype=tf.int32)
  #            for _ in range(n)]
  # One offset per batch element, shared by every wave so the crops align.
  # As with tf.random.uniform, the upper bound is exclusive.
  offsets = torch.randint(low=0, high=t - crop_t, size=(n,)).tolist()

  # waves_unbatched = [tf.split(w, n, axis=0) for w in waves]
  # torch.split takes a chunk *size*, so split into n chunks of size 1.
  waves_unbatched = [torch.split(w, 1, dim=0) for w in waves]

  # wave_crops = [[tf.slice(w, begin=[0, o], size=[1, crop_t])
  #                for w, o in zip(ws, offsets)] for ws in waves_unbatched]
  wave_crops = [[w[:, o:o + crop_t]
                 for w, o in zip(ws, offsets)] for ws in waves_unbatched]

  # wave_crops = [tf.concat(wc, axis=0) for wc in wave_crops]
  wave_crops = [torch.cat(wc, dim=0) for wc in wave_crops]

  return wave_crops


def torch_mel_to_hertz(frequencies_mel):
  """Converts frequencies in `frequencies_mel` from mel to Hertz scale."""
  # return _MEL_BREAK_FREQUENCY_HERTZ * (
  #     tf.math.exp(frequencies_mel / _MEL_HIGH_FREQUENCY_Q) - 1.)
  frequencies_mel = torch.as_tensor(frequencies_mel, dtype=torch.float32)
  return _MEL_BREAK_FREQUENCY_HERTZ * (
      torch.exp(frequencies_mel / _MEL_HIGH_FREQUENCY_Q) - 1.)


def torch_hertz_to_mel(frequencies_hertz):
  """Converts frequencies in `frequencies_hertz` in Hertz to the mel scale."""
  # return _MEL_HIGH_FREQUENCY_Q * tf.math.log(
  #     1. + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
  frequencies_hertz = torch.as_tensor(frequencies_hertz, dtype=torch.float32)
  return _MEL_HIGH_FREQUENCY_Q * torch.log(
      1. + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))


def torch_get_spectral_matrix(n, num_spec_bins=256, use_mel_scale=True,
                              sample_rate=24000):
  """DFT matrix in overcomplete basis returned as a complex PyTorch tensor.

  Args:
    n: Int. Frame length for the spectral matrix.
    num_spec_bins: Int. Number of bins to use in the spectrogram.
    use_mel_scale: Bool. Equally spaced on Mel-scale or Hertz-scale?
    sample_rate: Int. Sample rate of the waveform audio.

  Returns:
    Constructed spectral matrix of shape [n, num_spec_bins] (complex64).
  """
  sample_rate = float(sample_rate)
  upper_edge_hertz = sample_rate / 2.
  lower_edge_hertz = sample_rate / n

  if use_mel_scale:
    # upper_edge_mel = hertz_to_mel(upper_edge_hertz)
    # lower_edge_mel = hertz_to_mel(lower_edge_hertz)
    # mel_frequencies = tf.linspace(lower_edge_mel, upper_edge_mel, num_spec_bins)
    # hertz_frequencies = mel_to_hertz(mel_frequencies)
    upper_edge_mel = float(torch_hertz_to_mel(upper_edge_hertz))
    lower_edge_mel = float(torch_hertz_to_mel(lower_edge_hertz))
    # As in TF, the third argument of linspace is the number of points.
    mel_frequencies = torch.linspace(lower_edge_mel, upper_edge_mel,
                                     num_spec_bins)
    hertz_frequencies = torch_mel_to_hertz(mel_frequencies)
  else:
    # hertz_frequencies = tf.linspace(lower_edge_hertz, upper_edge_hertz,
    #                                 num_spec_bins)
    hertz_frequencies = torch.linspace(lower_edge_hertz, upper_edge_hertz,
                                       num_spec_bins)

  # time_col_vec = (tf.reshape(tf.range(n, dtype=tf.float32), [n, 1])
  #                 * np.cast[np.float32](2. * np.pi / sample_rate))
  # torch.arange(n) yields n samples (torch.range(0, n) would yield n + 1).
  time_col_vec = (torch.reshape(torch.arange(n, dtype=torch.float32), [n, 1])
                  * (2. * np.pi / sample_rate))
  tmat = torch.reshape(hertz_frequencies, [1, num_spec_bins]) * time_col_vec
  dct_mat = torch.cos(tmat)
  dst_mat = torch.sin(tmat)
  # dft_mat = tf.complex(real=dct_mat, imag=-dst_mat)
  dft_mat = torch.complex(dct_mat, -dst_mat)

  return dft_mat


def torch_matmul_real_with_complex(real_input, complex_matrix):
  """Multiplies a real tensor by a complex matrix; returns a complex tensor."""
  # .to(real_input) matches the dtype and device of the input frames.
  real_part = torch.matmul(real_input, complex_matrix.real.to(real_input))
  imag_part = torch.matmul(real_input, complex_matrix.imag.to(real_input))
  # return tf.complex(real_part, imag_part)
  return torch.complex(real_part, imag_part)


def torch_build_mel_basis(
    num_mel_bins,
    n_fft,
    sample_rate,
    lower_edge_hertz,
    upper_edge_hertz,
    dtype=torch.float32
):
  """Mel filterbank of shape [n_fft // 2 + 1, num_mel_bins] built with librosa.

  Stands in for tf.signal.linear_to_mel_weight_matrix. htk=True selects the
  same mel formula as the constants above and norm=None disables area
  normalization, but the filters are not guaranteed to match TF exactly.
  """
  assert upper_edge_hertz <= sample_rate / 2
  mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mel_bins,
                            fmin=lower_edge_hertz, fmax=upper_edge_hertz,
                            htk=True, norm=None)
  # librosa returns [num_mel_bins, n_fft // 2 + 1]; transpose to TF's layout.
  return torch.tensor(mel.T, dtype=dtype)


def torch_calc_spectrograms(waves, window_lengths, spectral_diffs=(0, 1),
                            window_name='hann', use_mel_scale=True,
                            proj_method='matmul', num_spec_bins=256,
                            random_crop=True):
  """Calculate spectrograms with multiple window sizes for list of input waves.

  Args:
    waves: List of float tensors of shape [batch, length] or [batch, length, 1].
    window_lengths: List of Int. Window sizes (frame lengths) to use for
      computing the spectrograms.
    spectral_diffs: Tuple of Int. Orders of the finite differences to take
      before computing the spectrograms.
    window_name: Str. Name of the window to use when computing the spectrograms.
      Supports 'hann' and None.
    use_mel_scale: Bool. Whether or not to project to mel-scale frequencies.
    proj_method: Str. Spectral projection method implementation to use.
      Supported are 'fft' and 'matmul'.
    num_spec_bins: Int. Number of bins in the spectrogram.
    random_crop: Bool. Take random crop or not.

  Returns:
    List of tuples of magnitude spectrograms, with output[i][j] being the
    spectrogram for input wave i, computed for the j-th
    (spectral diff, window length) combination.
  """
  # waves = [tf.squeeze(w, axis=-1) for w in waves]
  waves = [torch.squeeze(w, dim=-1) for w in waves]

  if window_name == 'hann':
    # windows = [tf.reshape(tf.signal.hann_window(wl, periodic=False), [1, 1, -1])
    #            for wl in window_lengths]
    windows = [torch.reshape(torch.hann_window(wl, periodic=False,
                                               dtype=waves[0].dtype,
                                               device=waves[0].device),
                             [1, 1, -1])
               for wl in window_lengths]
  elif window_name is None:
    windows = [None] * len(window_lengths)
  else:
    raise ValueError('Unknown window function (%s).' % window_name)

  spec_len_wave = []
  for d in spectral_diffs:
    for length, window in zip(window_lengths, windows):

      wave_crops = waves
      for _ in range(d):
        wave_crops = [w[:, 1:] - w[:, :-1] for w in wave_crops]

      if random_crop:
        # wave_crops = aligned_random_crop(wave_crops, length)
        wave_crops = torch_aligned_random_crop(wave_crops, length)

      # frames = [tf.signal.frame(wc, length, length // 2) for wc in wave_crops]
      # Tensor.unfold returns [batch, num_frames, length] like tf.signal.frame
      # without end padding and, unlike librosa.util.frame, keeps gradients.
      frames = [wc.unfold(1, length, length // 2) for wc in wave_crops]
      if window is not None:
        frames = [f * window for f in frames]

      if proj_method == 'fft':
        # ffts = [tf.signal.rfft(f)[:, :, 1:] for f in frames]
        ffts = [torch.fft.rfft(f)[:, :, 1:] for f in frames]
      elif proj_method == 'matmul':
        # mat = get_spectral_matrix(length, num_spec_bins=num_spec_bins,
        #                           use_mel_scale=use_mel_scale)
        # ffts = [matmul_real_with_complex(f, mat) for f in frames]
        mat = torch_get_spectral_matrix(length, num_spec_bins=num_spec_bins,
                                        use_mel_scale=use_mel_scale)
        ffts = [torch_matmul_real_with_complex(f, mat) for f in frames]
      else:
        raise ValueError('Unknown projection method (%s).' % proj_method)

      # sq_mag = lambda x: tf.square(tf.math.real(x)) + tf.square(tf.math.imag(x))
      sq_mag = lambda x: x.real ** 2 + x.imag ** 2
      specs_sq = [sq_mag(f) for f in ffts]

      if use_mel_scale and proj_method == 'fft':
        sample_rate = 24000
        upper_edge_hertz = sample_rate / 2.
        lower_edge_hertz = sample_rate / length
        # lin_to_mel = tf.signal.linear_to_mel_weight_matrix(
        #     num_mel_bins=num_spec_bins,
        #     num_spectrogram_bins=length // 2 + 1,
        #     sample_rate=sample_rate,
        #     lower_edge_hertz=lower_edge_hertz,
        #     upper_edge_hertz=upper_edge_hertz,
        #     dtype=tf.dtypes.float32)[1:]
        # specs_sq = [tf.matmul(s, lin_to_mel) for s in specs_sq]
        # [1:] drops the DC row, matching the [:, :, 1:] slice of the FFT.
        lin_to_mel = torch_build_mel_basis(
            num_mel_bins=num_spec_bins,
            n_fft=length,
            sample_rate=sample_rate,
            lower_edge_hertz=lower_edge_hertz,
            upper_edge_hertz=upper_edge_hertz,
            dtype=torch.float32)[1:]
        specs_sq = [torch.matmul(s, lin_to_mel.to(s)) for s in specs_sq]

      # specs = [tf.sqrt(s+EPSILON) for s in specs_sq]
      specs = [torch.sqrt(s + EPSILON) for s in specs_sq]

      spec_len_wave.append(specs)

  spec_wave_len = list(zip(*spec_len_wave))
  return spec_wave_len


def torch_sum_spectral_dist(specs1, specs2, add_log_l2=True):
  """Sum over distances in frequency space for different window sizes.

  Args:
    specs1: List of float tensors of shape [batch, frames, frequencies].
      Spectrograms of the first wave to compute the distance for.
    specs2: List of float tensors of shape [batch, frames, frequencies].
      Spectrograms of the second wave to compute the distance for.
    add_log_l2: Bool. Whether or not to add L2 in log space to L1 distances.

  Returns:
    Tensor of shape [batch] with sum of L1 distances over input spectrograms.
  """

  # l1_distances = [tf.reduce_mean(abs(s1 - s2), axis=[1, 2])
  #                 for s1, s2 in zip(specs1, specs2)]
  # sum_dist = tf.math.accumulate_n(l1_distances)
  l1_distances = [torch.mean(torch.abs(s1 - s2), dim=[1, 2])
                  for s1, s2 in zip(specs1, specs2)]
  sum_dist = torch.stack(l1_distances, dim=0).sum(dim=0)

  if add_log_l2:
    # log_deltas = [tf.math.squared_difference(
    #     tf.math.log(s1 + EPSILON), tf.math.log(s2 + EPSILON))  # pylint: disable=bad-continuation
    #               for s1, s2 in zip(specs1, specs2)]
    # log_l2_norms = [tf.reduce_mean(
    #     tf.sqrt(tf.reduce_mean(ld, axis=-1) + EPSILON), axis=-1)
    #                 for ld in log_deltas]
    # sum_log_l2 = tf.math.accumulate_n(log_l2_norms)
    log_deltas = [(torch.log(s1 + EPSILON) - torch.log(s2 + EPSILON)) ** 2
                  for s1, s2 in zip(specs1, specs2)]
    log_l2_norms = [torch.mean(
        torch.sqrt(torch.mean(ld, dim=-1) + EPSILON), dim=-1)
                    for ld in log_deltas]
    sum_log_l2 = torch.stack(log_l2_norms, dim=0).sum(dim=0)

    # Out-of-place add keeps the autograd graph safe.
    sum_dist = sum_dist + sum_log_l2

  return sum_dist


def torch_ged(wav_fake1, wav_fake2, wav_real):
  """Multi-scale spectrogram-based generalized energy distance.

  Args:
    wav_fake1: Float tensor of shape [batch, time, 1].
      Generated audio samples conditional on a set of linguistic features.
    wav_fake2: Float tensor of shape [batch, time, 1].
      Second set of samples conditional on same features, but using new noise.
    wav_real: Float tensor of shape [batch, time, 1].
      Real (data) audio samples corresponding to the same features.

  Returns:
    Tensor of shape [batch] with the GED values.
  """

  specs_fake1, specs_fake2, specs_real = torch_calc_spectrograms(
      waves=[wav_fake1, wav_fake2, wav_real],
      window_lengths=[2**i for i in range(6, 12)])

  dist_real_fake1 = torch_sum_spectral_dist(specs_real, specs_fake1)
  dist_real_fake2 = torch_sum_spectral_dist(specs_real, specs_fake2)
  dist_fake_fake = torch_sum_spectral_dist(specs_fake1, specs_fake2)

  return dist_real_fake1 + dist_real_fake2 - dist_fake_fake

# TODO: Run it!
if __name__ == '__main__':
  # The functions above assume 24 kHz audio (sample_rate=24000 in
  # torch_get_spectral_matrix and torch_calc_spectrograms), so load at 24 kHz.
  sample_rate = 24000
  paths = ['wav_fake1_path', 'wav_fake2_path', 'wav_real_path']
  waves = [librosa.load(p, sr=sample_rate)[0] for p in paths]
  # The aligned random crop needs equal lengths, so truncate to the shortest.
  min_len = min(len(w) for w in waves)
  # Convert each 1-D numpy array to a float tensor of shape [1, time, 1].
  wav_fake1, wav_fake2, wav_real = [
      torch.from_numpy(w[:min_len]).float()[None, :, None] for w in waves]

  GEDLoss = torch_ged(wav_fake1, wav_fake2, wav_real)

  print(GEDLoss)

--------------------------------------------------------------------------------
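Since `spectral_ops.py` notes that the port is untested, a small independent reference can help with sanity checks. The sketch below mirrors the structure of `torch_sum_spectral_dist` and `torch_ged` (mean L1 plus log-L2 per scale, then d(real, fake1) + d(real, fake2) - d(fake1, fake2)) but swaps in simplifications that are assumptions, not the original GED_TTS method: linear-frequency magnitudes from `torch.stft` instead of the mel `matmul` projection, no random crop, no finite-difference spectrograms, and three window sizes instead of 64 to 2048. The helper names `multi_scale_specs`, `spectral_dist` and `ged_sketch` are hypothetical.

```python
import torch

EPS = 1e-8  # plays the role of EPSILON in spectral_ops.py


def multi_scale_specs(wave, window_lengths=(64, 128, 256)):
  """Magnitude spectrograms [batch, frames, bins], one per window length.

  Assumed simplification: linear-frequency STFT magnitudes with hop =
  window // 2, a symmetric Hann window, no padding, and the DC bin dropped.
  """
  specs = []
  for n in window_lengths:
    window = torch.hann_window(n, periodic=False, dtype=wave.dtype,
                               device=wave.device)
    stft = torch.stft(wave, n_fft=n, hop_length=n // 2, window=window,
                      center=False, return_complex=True)  # [batch, bins, frames]
    mag = torch.sqrt(stft.real ** 2 + stft.imag ** 2 + EPS)
    specs.append(mag[:, 1:, :].transpose(1, 2))  # drop DC, -> [batch, frames, bins]
  return specs


def spectral_dist(specs1, specs2):
  """Per-example sum over scales of mean L1 distance plus log-L2 distance."""
  total = 0.0
  for s1, s2 in zip(specs1, specs2):
    l1 = (s1 - s2).abs().mean(dim=(1, 2))
    log_delta = (torch.log(s1 + EPS) - torch.log(s2 + EPS)) ** 2
    log_l2 = torch.sqrt(log_delta.mean(dim=-1) + EPS).mean(dim=-1)
    total = total + l1 + log_l2
  return total  # shape [batch]


def ged_sketch(fake1, fake2, real):
  """d(real, fake1) + d(real, fake2) - d(fake1, fake2) for [batch, time] waves."""
  s_f1, s_f2, s_r = [multi_scale_specs(w) for w in (fake1, fake2, real)]
  return (spectral_dist(s_r, s_f1) + spectral_dist(s_r, s_f2)
          - spectral_dist(s_f1, s_f2))


torch.manual_seed(0)
real = torch.randn(2, 4000)
fake1 = real + 0.1 * torch.randn(2, 4000)
fake2 = real + 0.1 * torch.randn(2, 4000)
print(ged_sketch(fake1, fake2, real).shape)  # torch.Size([2])
```

Because GED subtracts the fake-fake distance, passing the same sample as both fakes reduces it to twice the real-fake distance minus a tiny epsilon term, which makes the repulsive term easy to observe when comparing against `torch_ged`.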