├── .gitignore ├── LICENSE ├── README.md ├── ultra_dual_path_compression.ipynb └── ultra_dual_path_compression.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | __pycache__/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ultra Dual-Path Compression and Decompression 2 | 3 | This is the repository for a Pytorch-based implementation of the compression and decompression module in "Ultra Dual-Path Compression For Joint Echo Cancellation And Noise Suppression". The ultra dual-path compression module can compress the input multi-track spectra with large numbers of frames and frequency (T-F) bins into feature maps with small numbers of T-F bins, facilitating the fast processing for dual-path models (e.g., fullsubnet, 2D-convolution network). The decompression module transforms the compressed feature map back to the shapes of spectra for further processing. 4 | 5 | The latest codes are recommended to be found in `ultra_dual_path_compression.ipynb`, including dual-path compression, PostNet, an example of a dual-path GRU module, an example of a whole front-end network, and some examples of usage. 
Note that the code for the dual-path GRU module and the whole front-end network is only for demonstration purposes and differs from what is in the article.
38 | -------------------------------------------------------------------------------- /ultra_dual_path_compression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Author: Hangting Chen\n", 10 | "# Copyright: Tencent AI Lab\n", 11 | "# Paper: Ultra Dual-Path Compression For Joint Echo Cancellation and Noise Suppression\n", 12 | "# This code give the source of the ultra dual-path compression module and an example of usage." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 4, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "import torch\n", 22 | "import torch.nn as nn\n", 23 | "import numpy as np\n", 24 | "import math" 25 | ] 26 | }, 27 | { 28 | "attachments": {}, 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "# Define neural network modules" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 5, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "class WAVSTFT(nn.Module):\n", 42 | " def __init__(self, win_size=320):\n", 43 | " super(WAVSTFT,self).__init__()\n", 44 | " window = torch.from_numpy(np.hanning(win_size).astype(np.float32))\n", 45 | " self.window_size = window.shape[-1]\n", 46 | " self.hop_length = self.window_size // 2\n", 47 | " window = window.unsqueeze(0).unsqueeze(-1)\n", 48 | " divisor = torch.ones(1,1,1,self.window_size*4)\n", 49 | " divisor = nn.functional.unfold(divisor,(1,self.window_size),stride=self.hop_length)\n", 50 | " divisor = divisor * window.pow(2.0)\n", 51 | " divisor = nn.functional.fold(divisor,(1,self.window_size*4),(1,self.window_size),stride=self.hop_length)[:,0,0,:]\n", 52 | " divisor = divisor[0,self.window_size:2*self.window_size].unsqueeze(0).unsqueeze(-1)\n", 53 | " self.register_buffer('window', window)\n", 54 | " self.register_buffer('divisor', 
divisor)\n", 55 | "\n", 56 | " def magphase(self, complex_tensor: torch.Tensor):\n", 57 | " mag = complex_tensor.pow(2.).sum(-1).pow(0.5 * 1.0)\n", 58 | " phase = torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])\n", 59 | " return mag, phase\n", 60 | "\n", 61 | " def add_window(self, x, divisor):\n", 62 | " out = x * self.window / divisor\n", 63 | " return out\n", 64 | "\n", 65 | " def frame(self,x):\n", 66 | " assert x.dim()==2, x.shape\n", 67 | " out = x.unsqueeze(1).unsqueeze(1)\n", 68 | " out = nn.functional.pad(out, (self.window_size, self.window_size), 'constant', 0)\n", 69 | " out = nn.functional.unfold(out,(1,self.window_size),\\\n", 70 | " stride=self.hop_length) # B N T\n", 71 | " return out\n", 72 | "\n", 73 | " def overlap_and_add(self,x,length):\n", 74 | " assert x.dim()==3, x.shape\n", 75 | " out = nn.functional.fold(x,(1,length+2*self.window_size),(1,self.window_size), \\\n", 76 | " stride=self.hop_length)[:,0,0,:]\n", 77 | " out = out[:,self.window_size:-self.window_size]\n", 78 | " return out\n", 79 | "\n", 80 | " def rfft(self, x):\n", 81 | " assert x.dim()==3, x.shape\n", 82 | " return torch.fft.rfft(x, dim=1)\n", 83 | "\n", 84 | " def irfft(self, x):\n", 85 | " assert x.dim()==3, x.shape\n", 86 | " return torch.fft.irfft(x, dim=1)\n", 87 | "\n", 88 | " def STFT(self, x, return_aux=False):\n", 89 | " assert x.dim()==2, x.shape\n", 90 | " out = self.frame(x)\n", 91 | " out = self.add_window(out, 1)\n", 92 | " out = self.rfft(out)\n", 93 | " if(return_aux):\n", 94 | " mag, phase = self.magphase(torch.view_as_real(out))\n", 95 | " lps = torch.log(mag**2 + 1e-8)\n", 96 | " return out, mag, phase, lps \n", 97 | " else:\n", 98 | " return out\n", 99 | "\n", 100 | " def iSTFT(self, x, length):\n", 101 | " assert x.dim()==3, x.shape\n", 102 | " out = self.irfft(x)\n", 103 | " out = self.add_window(out, self.divisor)\n", 104 | " out = self.overlap_and_add(out, length=length)\n", 105 | " return out" 106 | ] 107 | }, 108 | { 109 | "cell_type": 
"code", 110 | "execution_count": 6, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "class TimeCompression(nn.Module):\n", 115 | " def __init__(self, dim1, dim2, dim3, dim4, steps):\n", 116 | " super().__init__()\n", 117 | " self.dim1 = dim1 # input dim\n", 118 | " self.dim2 = dim2 # hidden dim\n", 119 | " self.dim3 = dim3 # input dim for decompression\n", 120 | " self.dim4 = dim4 # out dim\n", 121 | " self.steps = steps\n", 122 | " self.trans1 = nn.Conv2d(self.dim1 * self.steps, self.dim2, 1, bias = False)\n", 123 | " self.trans2 = nn.Conv2d(self.dim3, self.dim4, 1, bias = False)\n", 124 | "\n", 125 | " def forward(self, x, inverse):\n", 126 | " # x B C T F\n", 127 | " if(inverse):\n", 128 | " B, C, T, F = x.shape\n", 129 | " x = self.trans2(x).reshape(B, -1, 1, T, F).permute(0,1,3,2,4).contiguous() # B C T S F\n", 130 | " x = x.repeat(1,1,1,self.steps,1).reshape(B, -1, T*self.steps, F) # B C T*S F\n", 131 | " x = torch.nn.functional.pad(x,(0,0,self.steps-1,0),'constant',0)\n", 132 | " x = x[:,:,:-self.steps+1,:]\n", 133 | " if(self.pad > 0): x = x[:,:,self.pad:,:]\n", 134 | " return x\n", 135 | " else:\n", 136 | " B, C, T, F = x.shape\n", 137 | " if(x.shape[-2]%self.steps==0):\n", 138 | " self.pad = 0\n", 139 | " else:\n", 140 | " self.pad = self.steps - x.shape[-2]%self.steps\n", 141 | " x = torch.nn.functional.pad(x,(0,0,self.pad,0),'constant',0)\n", 142 | " x = x.reshape(B, C, -1, self.steps, F).permute(0,1,3,2,4).contiguous() # B C S T F\n", 143 | " x = x.reshape(B, C*self.steps, -1, F) # B C*S T F\n", 144 | " return self.trans1(x)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 7, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "class FreqCompression(nn.Module):\n", 154 | " def __init__(self, nfreq, nfilters, in_dim, hidden_dim, \\\n", 155 | " out_dim, sample_rate=16000):\n", 156 | " super().__init__()\n", 157 | " self.nfreq = nfreq\n", 158 | " self.nfilters = nfilters\n", 159 | " 
self.sample_rate = sample_rate\n", 160 | " self.in_dim = in_dim\n", 161 | " self.hidden_dim = hidden_dim\n", 162 | " self.out_dim = out_dim\n", 163 | "\n", 164 | " mel_scale = 'htk'\n", 165 | " \n", 166 | " all_freqs = torch.linspace(0, sample_rate // 2, nfreq)\n", 167 | " # calculate mel freq bins\n", 168 | " m_min = self._hz_to_mel(0, mel_scale=mel_scale)\n", 169 | " m_max = self._hz_to_mel(sample_rate/2.0, mel_scale=mel_scale)\n", 170 | "\n", 171 | " m_pts = torch.linspace(m_min, m_max, self.nfilters + 2)\n", 172 | " f_pts = self._mel_to_hz(m_pts, mel_scale=mel_scale)\n", 173 | " self.bounds = [0,]\n", 174 | " for freq_inx in range(1, len(f_pts)-1):\n", 175 | " self.bounds.append((all_freqs > f_pts[freq_inx]).float().argmax().item())\n", 176 | " self.bounds.append(nfreq)\n", 177 | " self.trans1 = nn.ModuleList()\n", 178 | " self.trans2 = nn.ModuleList()\n", 179 | " for freq_inx in range(self.nfilters):\n", 180 | " self.trans1.append(nn.Linear((self.bounds[freq_inx+2]-self.bounds[freq_inx])*self.in_dim, self.hidden_dim, bias=False))\n", 181 | " self.trans2.append(nn.Conv1d(self.hidden_dim, (self.bounds[freq_inx+2]-self.bounds[freq_inx])*self.out_dim, 1))\n", 182 | " \n", 183 | " def _hz_to_mel(self, freq: float, mel_scale: str = \"htk\") -> float:\n", 184 | " r\"\"\"\n", 185 | " Source: https://pytorch.org/audio/stable/_modules/torchaudio/functional/functional.html\n", 186 | " Convert Hz to Mels.\n", 187 | "\n", 188 | " Args:\n", 189 | " freqs (float): Frequencies in Hz\n", 190 | " mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. 
(Default: ``htk``)\n", 191 | "\n", 192 | " Returns:\n", 193 | " mels (float): Frequency in Mels\n", 194 | " \"\"\"\n", 195 | "\n", 196 | " if mel_scale not in [\"slaney\", \"htk\"]:\n", 197 | " raise ValueError('mel_scale should be one of \"htk\" or \"slaney\".')\n", 198 | "\n", 199 | " if mel_scale == \"htk\":\n", 200 | " return 2595.0 * math.log10(1.0 + (freq / 700.0))\n", 201 | "\n", 202 | " # Fill in the linear part\n", 203 | " f_min = 0.0\n", 204 | " f_sp = 200.0 / 3\n", 205 | "\n", 206 | " mels = (freq - f_min) / f_sp\n", 207 | "\n", 208 | " # Fill in the log-scale part\n", 209 | " min_log_hz = 1000.0\n", 210 | " min_log_mel = (min_log_hz - f_min) / f_sp\n", 211 | " logstep = math.log(6.4) / 27.0\n", 212 | "\n", 213 | " if freq >= min_log_hz:\n", 214 | " mels = min_log_mel + math.log(freq / min_log_hz) / logstep\n", 215 | "\n", 216 | " return mels\n", 217 | " \n", 218 | " def _mel_to_hz(self, mels: torch.Tensor, mel_scale: str = \"htk\") -> torch.Tensor:\n", 219 | " \"\"\"\n", 220 | " Source: https://pytorch.org/audio/stable/_modules/torchaudio/functional/functional.html\n", 221 | " Convert mel bin numbers to frequencies.\n", 222 | "\n", 223 | " Args:\n", 224 | " mels (torch.Tensor): Mel frequencies\n", 225 | " mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. 
(Default: ``htk``)\n", 226 | "\n", 227 | " Returns:\n", 228 | " freqs (torch.Tensor): Mels converted in Hz\n", 229 | " \"\"\"\n", 230 | "\n", 231 | " if mel_scale not in [\"slaney\", \"htk\"]:\n", 232 | " raise ValueError('mel_scale should be one of \"htk\" or \"slaney\".')\n", 233 | "\n", 234 | " if mel_scale == \"htk\":\n", 235 | " return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)\n", 236 | "\n", 237 | " # Fill in the linear scale\n", 238 | " f_min = 0.0\n", 239 | " f_sp = 200.0 / 3\n", 240 | " freqs = f_min + f_sp * mels\n", 241 | "\n", 242 | " # And now the nonlinear scale\n", 243 | " min_log_hz = 1000.0\n", 244 | " min_log_mel = (min_log_hz - f_min) / f_sp\n", 245 | " logstep = math.log(6.4) / 27.0\n", 246 | "\n", 247 | " log_t = mels >= min_log_mel\n", 248 | " freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))\n", 249 | "\n", 250 | " return freqs\n", 251 | " \n", 252 | " def forward(self, x, inverse):\n", 253 | " if(inverse):\n", 254 | " # B C T F\n", 255 | " out = torch.zeros([x.shape[0],self.out_dim,self.nfreq,x.shape[2]], dtype=x.dtype, layout=x.layout, device=x.device)\n", 256 | " for freq_inx in range(self.nfilters):\n", 257 | " out[:,:,self.bounds[freq_inx]:self.bounds[freq_inx+2],:] = out[:,:,self.bounds[freq_inx]:self.bounds[freq_inx+2],:] + \\\n", 258 | " self.trans2[freq_inx](x[:,:,:,freq_inx]).reshape(x.shape[0],self.out_dim,-1,x.shape[-2])\n", 259 | " out[:,:,self.bounds[1]:self.bounds[-2],:] = out[:,:,self.bounds[1]:self.bounds[-2],:] / 2.0\n", 260 | " out = out.permute(0,1,3,2).contiguous().tanh()\n", 261 | " return out\n", 262 | " else:\n", 263 | " x = x.reshape(x.shape[0],self.in_dim, *x.shape[-2:]) # B C T F\n", 264 | " x = x.permute(0,2,1,3).contiguous() # B T C F\n", 265 | " x = torch.stack([self.trans1[freq_inx](x[:,:,:,self.bounds[freq_inx]:self.bounds[freq_inx+2]].flatten(start_dim=2)) \\\n", 266 | " for freq_inx in range(self.nfilters)],-1) # B T C F\n", 267 | " x = x.permute(0,2,1,3).contiguous()\n", 268 | " 
return x" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 8, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "class UltraDualPath(nn.Module):\n", 278 | " def __init__(self, nfreq, in_dim, hidden_dim, out_dim, \\\n", 279 | " freq_cprs_ratio, time_cprs_ratio):\n", 280 | " super(UltraDualPath, self).__init__()\n", 281 | " self.nfreq = nfreq\n", 282 | " self.in_dim = in_dim\n", 283 | " self.hidden_dim = hidden_dim\n", 284 | " self.out_dim = out_dim\n", 285 | " self.freq_cprs_ratio = freq_cprs_ratio\n", 286 | " self.time_cprs_ratio = time_cprs_ratio\n", 287 | " self.compress_modules = [\n", 288 | " TimeCompression(self.in_dim, self.in_dim * 2, \\\n", 289 | " self.out_dim, self.out_dim, time_cprs_ratio), \\\n", 290 | " FreqCompression(self.nfreq, self.nfreq // self.freq_cprs_ratio, \\\n", 291 | " self.in_dim * 2, self.hidden_dim, self.out_dim), \\\n", 292 | " ]\n", 293 | " def forward(self, x, inverse):\n", 294 | " out = x\n", 295 | " # print(out.shape)\n", 296 | " if(inverse):\n", 297 | " for m in self.compress_modules[::-1]:\n", 298 | " out = m(out, inverse)\n", 299 | " # print(out.shape)\n", 300 | " else:\n", 301 | " for m in self.compress_modules:\n", 302 | " out = m(out, inverse)\n", 303 | " # print(out.shape)\n", 304 | " return out" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 30, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "class DynamicLayerNorm(nn.Module):\n", 314 | " def __init__(self):\n", 315 | " super(DynamicLayerNorm, self).__init__() \n", 316 | "\n", 317 | " def forward(self, x):\n", 318 | " return nn.functional.layer_norm(x,[x.shape[-1],])\n", 319 | " \n", 320 | "class GRUBlock(nn.Module):\n", 321 | " def __init__(self, hidden_size, causal):\n", 322 | " super(GRUBlock, self).__init__()\n", 323 | " self.gru = nn.GRU(hidden_size, hidden_size, 1, bidirectional=False if causal else True, batch_first=True)\n", 324 | " self.linear = nn.Linear(hidden_size if 
causal else hidden_size*2, hidden_size)\n", 325 | " self.norm = DynamicLayerNorm()\n", 326 | " self.activation = nn.ReLU()\n", 327 | "\n", 328 | " def forward(self, input):\n", 329 | " out, _ = self.gru(input) # T B E\n", 330 | " out = self.linear(self.activation(out))\n", 331 | " out = input + out\n", 332 | " out = self.norm(out)\n", 333 | " return out\n", 334 | "\n", 335 | "class ExampleDualPathBackbone(nn.Module):\n", 336 | " # An example of dual path module\n", 337 | " def __init__(self, hidden_size, num_layers, causal=True):\n", 338 | " super(ExampleDualPathBackbone, self).__init__()\n", 339 | " self.hidden_size = hidden_size\n", 340 | "\n", 341 | " # dual-path RNN\n", 342 | " self.row_trans = nn.ModuleList([])\n", 343 | " self.col_trans = nn.ModuleList([])\n", 344 | " for i in range(num_layers):\n", 345 | " self.row_trans.append(\n", 346 | " GRUBlock(hidden_size, False)\n", 347 | " )\n", 348 | " self.col_trans.append(\n", 349 | " GRUBlock(hidden_size, causal)\n", 350 | " )\n", 351 | "\n", 352 | " def forward(self, input):\n", 353 | " # input --- [b, c, num_frames, frame_size] --- [b, c, dim2, dim1]\n", 354 | " b, _, dim2, dim1 = input.shape\n", 355 | " output = input\n", 356 | " for i in range(len(self.row_trans)):\n", 357 | " row_input = output.permute(0, 2, 3, 1).contiguous().view(b*dim2, dim1, -1) # [b*dim2, dim1, c]\n", 358 | " output = self.row_trans[i](row_input) # [b*dim2, dim1, c]\n", 359 | " output = output.view(b, dim2, dim1, -1).permute(0, 3, 1, 2).contiguous() # [b, c, dim2, dim1]\n", 360 | "\n", 361 | " col_input = output.permute(0, 3, 2, 1).contiguous().view(b*dim1, dim2, -1) # [b*dim1, dim2, c]\n", 362 | " output = self.col_trans[i](col_input) # [b*dim1, dim2, c]\n", 363 | " output = output.view(b, dim1, dim2, -1).permute(0, 3, 2, 1).contiguous() # [b, c, dim2, dim1]\n", 364 | "\n", 365 | " return output" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 89, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 
374 | "class ExampleModel(nn.Module):\n", 375 | " def __init__(self):\n", 376 | " super(ExampleModel, self).__init__() \n", 377 | " win_size = 320 # 20ms for 16khz\n", 378 | " hidden_size = 48\n", 379 | " freq_cprs_ratio = 4\n", 380 | " time_cprs_ratio = 4\n", 381 | " num_in_channels = 3\n", 382 | " num_dual_path_blocks = 4\n", 383 | " hidden_size_post = 4\n", 384 | "\n", 385 | " self.wavSTFT = WAVSTFT(win_size) # 16khz, 20ms/10ms -> 320 samples\n", 386 | " self.ultraCompress = UltraDualPath(win_size//2 + 1, num_in_channels*2, hidden_size, hidden_size, freq_cprs_ratio, time_cprs_ratio) # 3ch*(real+imag)\n", 387 | " self.dualpath = ExampleDualPathBackbone(hidden_size, num_dual_path_blocks, True)\n", 388 | " self.conv2mask = nn.Conv2d(hidden_size, num_in_channels*2, 1, 1)\n", 389 | "\n", 390 | " self.postCompress = FreqCompression(win_size // 2 + 1, (win_size // 2 + 1) // 2, num_in_channels + 1, 1, hidden_size_post)\n", 391 | " self.postNet = nn.GRU((win_size // 2 + 1) // 2, (win_size // 2 + 1) // 2, 1, bidirectional=False, batch_first=True)\n", 392 | " self.postLinear2mask = nn.Sequential(\n", 393 | " nn.Linear(win_size // 2 + 1, win_size // 2 + 1),\n", 394 | " nn.Conv2d(hidden_size_post, hidden_size_post, 1),\n", 395 | " nn.Sigmoid(),\n", 396 | " nn.Linear(win_size // 2 + 1, win_size // 2 + 1),\n", 397 | " nn.Conv2d(hidden_size_post, 1, 1),\n", 398 | " nn.Sigmoid(),\n", 399 | " )\n", 400 | "\n", 401 | " def complex2logmag(self, real, imag):\n", 402 | " return (real.pow(2.0)+imag.pow(2.0)+1e-8).log()\n", 403 | "\n", 404 | " def forward(self, wav):\n", 405 | " # organize input\n", 406 | " B, C, N = wav.shape\n", 407 | " spec = self.wavSTFT.STFT(wav.reshape(B*C, N))\n", 408 | " spec = torch.view_as_real(spec.reshape(B, C, *spec.shape[-2:])) # B C F T 2\n", 409 | " spec = spec.permute(0,1,4,3,2) # B C 2 T F\n", 410 | " spec = spec.reshape(B, C*2, *spec.shape[-2:])\n", 411 | "\n", 412 | " # compress\n", 413 | " latent = self.ultraCompress(spec,0)\n", 414 | " # run 
dual-path network\n", 415 | " latent = self.dualpath(latent)\n", 416 | " # decompress\n", 417 | " latent = self.ultraCompress(latent,1)\n", 418 | " # conv2mask\n", 419 | " mask = self.conv2mask(latent).reshape(B, C, 2, *spec.shape[-2:])\n", 420 | " spec = spec.reshape(B, C, 2, *spec.shape[-2:])\n", 421 | " est_spec = torch.stack([\n", 422 | " spec[:,:,0] * mask[:,:,0] - spec[:,:,1] * mask[:,:,1], \\\n", 423 | " spec[:,:,0] * mask[:,:,1] + spec[:,:,1] * mask[:,:,0], \\\n", 424 | " ],2) # B C 2 T F\n", 425 | " est_spec = est_spec.sum(1, keepdim=True) # B 1 2 T F\n", 426 | "\n", 427 | " # organize input for postnet\n", 428 | " post_input = torch.cat([spec, est_spec], 1)\n", 429 | " post_input = self.complex2logmag(post_input[:,:,0], post_input[:,:,1]) # B C+1 T F\n", 430 | " # compress\n", 431 | " latent = self.postCompress(post_input, 0)\n", 432 | " # run postnet\n", 433 | " latent, _ = self.postNet(latent.squeeze(1))\n", 434 | " # decompress\n", 435 | " latent = self.postCompress(latent.unsqueeze(1), 1) # B 4 T F\n", 436 | " # conv2mask\n", 437 | " latent = self.postLinear2mask(latent) # B 1 T F\n", 438 | " est_spec = est_spec + spec[:,[0]] * latent.unsqueeze(2) \n", 439 | " ## we assume the first channel includes full information (e.g., the mic signal of AEC, the ref channel in multi-channel microphone)\n", 440 | " ## thus it is used to complement spec loss from the previous stage\n", 441 | "\n", 442 | " # recover to wav\n", 443 | " est_spec = torch.complex(est_spec.sum(1)[:,0], est_spec.sum(1)[:,1])\n", 444 | " ests_wav = self.wavSTFT.iSTFT(est_spec.permute(0,2,1), N)\n", 445 | " ests_wav = ests_wav.reshape(B, 1, N)\n", 446 | "\n", 447 | " return ests_wav" 448 | ] 449 | }, 450 | { 451 | "attachments": {}, 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "# An example of run" 456 | ] 457 | }, 458 | { 459 | "attachments": {}, 460 | "cell_type": "markdown", 461 | "metadata": {}, 462 | "source": [ 463 | "## Define model and input" 464 | ] 
465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": 66, 469 | "metadata": {}, 470 | "outputs": [], 471 | "source": [ 472 | "## Network configuration\n", 473 | "win_size = 320\n", 474 | "hidden_size = 48\n", 475 | "freq_cprs_ratio = 4\n", 476 | "time_cprs_ratio = 4\n", 477 | "num_in_channels = 3\n", 478 | "\n", 479 | "## Define network modules\n", 480 | "wavSTFT = WAVSTFT(win_size) # 16khz, 20ms/10ms -> 320 samples\n", 481 | "ultraCompress = UltraDualPath(win_size//2 + 1, num_in_channels*2, hidden_size, hidden_size, freq_cprs_ratio, time_cprs_ratio) # 3ch*(real+imag)\n", 482 | "exampleDualPathModule = ExampleDualPathBackbone(hidden_size, 2, True)\n", 483 | "\n", 484 | "## Define input\n", 485 | "wav = torch.rand(2,num_in_channels,16000) # nbatch=2, nchannels=3, nsamples=16000" 486 | ] 487 | }, 488 | { 489 | "attachments": {}, 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "## Run dual path compression" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 67, 499 | "metadata": {}, 500 | "outputs": [ 501 | { 502 | "name": "stdout", 503 | "output_type": "stream", 504 | "text": [ 505 | "torch.Size([2, 6, 103, 161])\n" 506 | ] 507 | } 508 | ], 509 | "source": [ 510 | "# organize input\n", 511 | "B, C, N = wav.shape\n", 512 | "spec = wavSTFT.STFT(wav.reshape(B*C, N))\n", 513 | "spec = torch.view_as_real(spec.reshape(B, C, *spec.shape[-2:])) # B C F T 2\n", 514 | "spec = spec.permute(0,1,4,3,2) # B C 2 T F\n", 515 | "spec = spec.reshape(B, C*2, *spec.shape[-2:])\n", 516 | "print(spec.shape) # nbatch, nchannels*2, nframes, nfreqs" 517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": 68, 522 | "metadata": {}, 523 | "outputs": [ 524 | { 525 | "name": "stdout", 526 | "output_type": "stream", 527 | "text": [ 528 | "torch.Size([2, 48, 26, 40])\n" 529 | ] 530 | } 531 | ], 532 | "source": [ 533 | "# compress\n", 534 | "latent = ultraCompress(spec,0)\n", 535 | "print(latent.shape) # 
nbatch, hidden_dim, nframes, nfreqs" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 69, 541 | "metadata": {}, 542 | "outputs": [ 543 | { 544 | "name": "stdout", 545 | "output_type": "stream", 546 | "text": [ 547 | "torch.Size([2, 48, 26, 40])\n" 548 | ] 549 | } 550 | ], 551 | "source": [ 552 | "# run dual-path network\n", 553 | "latent = exampleDualPathModule(latent)\n", 554 | "print(latent.shape) # nbatch, hidden_dim, nframes, nfreqs" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 70, 560 | "metadata": {}, 561 | "outputs": [ 562 | { 563 | "name": "stdout", 564 | "output_type": "stream", 565 | "text": [ 566 | "torch.Size([2, 48, 103, 161])\n" 567 | ] 568 | } 569 | ], 570 | "source": [ 571 | "# decompress\n", 572 | "output = ultraCompress(latent,1)\n", 573 | "print(output.shape) # nbatch, out_dim, nframes, nfreqs" 574 | ] 575 | }, 576 | { 577 | "attachments": {}, 578 | "cell_type": "markdown", 579 | "metadata": {}, 580 | "source": [ 581 | "## Run freq compression only" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": 71, 587 | "metadata": {}, 588 | "outputs": [ 589 | { 590 | "name": "stdout", 591 | "output_type": "stream", 592 | "text": [ 593 | "torch.Size([2, 48, 103, 40])\n", 594 | "torch.Size([2, 48, 103, 40])\n", 595 | "torch.Size([2, 48, 103, 161])\n" 596 | ] 597 | } 598 | ], 599 | "source": [ 600 | "# Frequcy only\n", 601 | "ultraCompress = FreqCompression(win_size // 2 + 1, (win_size // 2 + 1) // freq_cprs_ratio, num_in_channels * 2, hidden_size, hidden_size) # 4x compression ratio\n", 602 | "# compress\n", 603 | "latent = ultraCompress(spec, 0)\n", 604 | "print(latent.shape) # nbatch, hidden_dim, nframes, nfreqs\n", 605 | "# run dual path network\n", 606 | "latent = exampleDualPathModule(latent)\n", 607 | "print(latent.shape) # nbatch, hidden_dim, nframes, nfreqs\n", 608 | "# decompress\n", 609 | "output = ultraCompress(latent, 1)\n", 610 | "print(output.shape) # nbatch, 
out_dim, nframes, nfreqs" 611 | ] 612 | }, 613 | { 614 | "attachments": {}, 615 | "cell_type": "markdown", 616 | "metadata": {}, 617 | "source": [ 618 | "## Run time compression only" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": 72, 624 | "metadata": {}, 625 | "outputs": [ 626 | { 627 | "name": "stdout", 628 | "output_type": "stream", 629 | "text": [ 630 | "torch.Size([2, 48, 26, 161])\n", 631 | "torch.Size([2, 48, 26, 161])\n", 632 | "torch.Size([2, 48, 103, 161])\n" 633 | ] 634 | } 635 | ], 636 | "source": [ 637 | "# Frequcy only\n", 638 | "ultraCompress = TimeCompression(num_in_channels * 2, hidden_size, hidden_size, hidden_size, time_cprs_ratio) # 4x compression ratio\n", 639 | "# compress\n", 640 | "latent = ultraCompress(spec, 0)\n", 641 | "print(latent.shape) # nbatch, hidden_dim, nframes, nfreqs\n", 642 | "# run dual path network\n", 643 | "latent = exampleDualPathModule(latent)\n", 644 | "print(latent.shape) # nbatch, hidden_dim, nframes, nfreqs\n", 645 | "# decompress\n", 646 | "output = ultraCompress(latent, 1)\n", 647 | "print(output.shape) # nbatch, out_dim, nframes, nfreqs" 648 | ] 649 | }, 650 | { 651 | "attachments": {}, 652 | "cell_type": "markdown", 653 | "metadata": {}, 654 | "source": [ 655 | "# Run ExampleModel" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": 90, 661 | "metadata": {}, 662 | "outputs": [ 663 | { 664 | "name": "stdout", 665 | "output_type": "stream", 666 | "text": [ 667 | "torch.Size([2, 1, 16000])\n" 668 | ] 669 | } 670 | ], 671 | "source": [ 672 | "## Define input\n", 673 | "wav = torch.rand(2,num_in_channels,16000) # nbatch=2, nchannels=3, nsamples=16000\n", 674 | "\n", 675 | "## Define model\n", 676 | "model = ExampleModel()\n", 677 | "est_wav = model(wav)\n", 678 | "print(est_wav.shape)" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 92, 684 | "metadata": {}, 685 | "outputs": [ 686 | { 687 | "name": "stdout", 688 | "output_type": 
"stream", 689 | "text": [ 690 | "[INFO] Register count_gru() for .\n", 691 | "[INFO] Register count_linear() for .\n", 692 | "[INFO] Register zero_ops() for .\n", 693 | "[INFO] Register count_convNd() for .\n", 694 | "[INFO] Register count_convNd() for .\n", 695 | "[INFO] Register zero_ops() for .\n", 696 | "{'conv2mask': (7418880.0, 294.0, {}),\n", 697 | " 'dualpath': (321638400.0,\n", 698 | " 197376.0,\n", 699 | " {'col_trans': (107212800.0,\n", 700 | " 65856.0,\n", 701 | " {'0': (26803200.0,\n", 702 | " 16464.0,\n", 703 | " {'activation': (0.0, 0.0, {}),\n", 704 | " 'gru': (23116800.0, 14112.0, {}),\n", 705 | " 'linear': (3686400.0, 2352.0, {}),\n", 706 | " 'norm': (0.0, 0, {})}),\n", 707 | " '1': (26803200.0,\n", 708 | " 16464.0,\n", 709 | " {'activation': (0.0, 0.0, {}),\n", 710 | " 'gru': (23116800.0, 14112.0, {}),\n", 711 | " 'linear': (3686400.0, 2352.0, {}),\n", 712 | " 'norm': (0.0, 0, {})}),\n", 713 | " '2': (26803200.0,\n", 714 | " 16464.0,\n", 715 | " {'activation': (0.0, 0.0, {}),\n", 716 | " 'gru': (23116800.0, 14112.0, {}),\n", 717 | " 'linear': (3686400.0, 2352.0, {}),\n", 718 | " 'norm': (0.0, 0, {})}),\n", 719 | " '3': (26803200.0,\n", 720 | " 16464.0,\n", 721 | " {'activation': (0.0, 0.0, {}),\n", 722 | " 'gru': (23116800.0, 14112.0, {}),\n", 723 | " 'linear': (3686400.0, 2352.0, {}),\n", 724 | " 'norm': (0.0, 0, {})})}),\n", 725 | " 'row_trans': (214425600.0,\n", 726 | " 131520.0,\n", 727 | " {'0': (53606400.0,\n", 728 | " 32880.0,\n", 729 | " {'activation': (0.0, 0.0, {}),\n", 730 | " 'gru': (46233600.0, 28224.0, {}),\n", 731 | " 'linear': (7372800.0, 4656.0, {}),\n", 732 | " 'norm': (0.0, 0, {})}),\n", 733 | " '1': (53606400.0,\n", 734 | " 32880.0,\n", 735 | " {'activation': (0.0, 0.0, {}),\n", 736 | " 'gru': (46233600.0, 28224.0, {}),\n", 737 | " 'linear': (7372800.0, 4656.0, {}),\n", 738 | " 'norm': (0.0, 0, {})}),\n", 739 | " '2': (53606400.0,\n", 740 | " 32880.0,\n", 741 | " {'activation': (0.0, 0.0, {}),\n", 742 | " 'gru': (46233600.0, 
28224.0, {}),\n", 743 | " 'linear': (7372800.0, 4656.0, {}),\n", 744 | " 'norm': (0.0, 0, {})}),\n", 745 | " '3': (53606400.0,\n", 746 | " 32880.0,\n", 747 | " {'activation': (0.0, 0.0, {}),\n", 748 | " 'gru': (46233600.0, 28224.0, {}),\n", 749 | " 'linear': (7372800.0, 4656.0, {}),\n", 750 | " 'norm': (0.0, 0, {})})})}),\n", 751 | " 'postCompress': (403200.0,\n", 752 | " 3780.0,\n", 753 | " {'trans1': (201600.0,\n", 754 | " 1260.0,\n", 755 | " {'0': (640.0, 4.0, {}),\n", 756 | " '1': (640.0, 4.0, {}),\n", 757 | " '10': (640.0, 4.0, {}),\n", 758 | " '11': (640.0, 4.0, {}),\n", 759 | " '12': (640.0, 4.0, {}),\n", 760 | " '13': (1280.0, 8.0, {}),\n", 761 | " '14': (1280.0, 8.0, {}),\n", 762 | " '15': (640.0, 4.0, {}),\n", 763 | " '16': (640.0, 4.0, {}),\n", 764 | " '17': (1280.0, 8.0, {}),\n", 765 | " '18': (1280.0, 8.0, {}),\n", 766 | " '19': (640.0, 4.0, {}),\n", 767 | " '2': (640.0, 4.0, {}),\n", 768 | " '20': (640.0, 4.0, {}),\n", 769 | " '21': (1280.0, 8.0, {}),\n", 770 | " '22': (1280.0, 8.0, {}),\n", 771 | " '23': (1280.0, 8.0, {}),\n", 772 | " '24': (1280.0, 8.0, {}),\n", 773 | " '25': (1280.0, 8.0, {}),\n", 774 | " '26': (1280.0, 8.0, {}),\n", 775 | " '27': (1280.0, 8.0, {}),\n", 776 | " '28': (1280.0, 8.0, {}),\n", 777 | " '29': (1280.0, 8.0, {}),\n", 778 | " '3': (640.0, 4.0, {}),\n", 779 | " '30': (1280.0, 8.0, {}),\n", 780 | " '31': (1920.0, 12.0, {}),\n", 781 | " '32': (1920.0, 12.0, {}),\n", 782 | " '33': (1280.0, 8.0, {}),\n", 783 | " '34': (1280.0, 8.0, {}),\n", 784 | " '35': (1920.0, 12.0, {}),\n", 785 | " '36': (1920.0, 12.0, {}),\n", 786 | " '37': (1920.0, 12.0, {}),\n", 787 | " '38': (1920.0, 12.0, {}),\n", 788 | " '39': (1920.0, 12.0, {}),\n", 789 | " '4': (640.0, 4.0, {}),\n", 790 | " '40': (1920.0, 12.0, {}),\n", 791 | " '41': (1920.0, 12.0, {}),\n", 792 | " '42': (2560.0, 16.0, {}),\n", 793 | " '43': (1920.0, 12.0, {}),\n", 794 | " '44': (1920.0, 12.0, {}),\n", 795 | " '45': (2560.0, 16.0, {}),\n", 796 | " '46': (2560.0, 16.0, {}),\n", 797 | 
" '47': (2560.0, 16.0, {}),\n", 798 | " '48': (2560.0, 16.0, {}),\n", 799 | " '49': (2560.0, 16.0, {}),\n", 800 | " '5': (640.0, 4.0, {}),\n", 801 | " '50': (2560.0, 16.0, {}),\n", 802 | " '51': (2560.0, 16.0, {}),\n", 803 | " '52': (3200.0, 20.0, {}),\n", 804 | " '53': (3200.0, 20.0, {}),\n", 805 | " '54': (2560.0, 16.0, {}),\n", 806 | " '55': (3200.0, 20.0, {}),\n", 807 | " '56': (3840.0, 24.0, {}),\n", 808 | " '57': (3200.0, 20.0, {}),\n", 809 | " '58': (3200.0, 20.0, {}),\n", 810 | " '59': (3840.0, 24.0, {}),\n", 811 | " '6': (640.0, 4.0, {}),\n", 812 | " '60': (3840.0, 24.0, {}),\n", 813 | " '61': (3840.0, 24.0, {}),\n", 814 | " '62': (3840.0, 24.0, {}),\n", 815 | " '63': (3840.0, 24.0, {}),\n", 816 | " '64': (4480.0, 28.0, {}),\n", 817 | " '65': (4480.0, 28.0, {}),\n", 818 | " '66': (4480.0, 28.0, {}),\n", 819 | " '67': (4480.0, 28.0, {}),\n", 820 | " '68': (4480.0, 28.0, {}),\n", 821 | " '69': (5120.0, 32.0, {}),\n", 822 | " '7': (640.0, 4.0, {}),\n", 823 | " '70': (5120.0, 32.0, {}),\n", 824 | " '71': (5120.0, 32.0, {}),\n", 825 | " '72': (5120.0, 32.0, {}),\n", 826 | " '73': (5760.0, 36.0, {}),\n", 827 | " '74': (5760.0, 36.0, {}),\n", 828 | " '75': (5760.0, 36.0, {}),\n", 829 | " '76': (6400.0, 40.0, {}),\n", 830 | " '77': (6400.0, 40.0, {}),\n", 831 | " '78': (6400.0, 40.0, {}),\n", 832 | " '79': (7040.0, 44.0, {}),\n", 833 | " '8': (1280.0, 8.0, {}),\n", 834 | " '9': (640.0, 4.0, {})}),\n", 835 | " 'trans2': (201600.0,\n", 836 | " 2520.0,\n", 837 | " {'0': (640.0, 8.0, {}),\n", 838 | " '1': (640.0, 8.0, {}),\n", 839 | " '10': (640.0, 8.0, {}),\n", 840 | " '11': (640.0, 8.0, {}),\n", 841 | " '12': (640.0, 8.0, {}),\n", 842 | " '13': (1280.0, 16.0, {}),\n", 843 | " '14': (1280.0, 16.0, {}),\n", 844 | " '15': (640.0, 8.0, {}),\n", 845 | " '16': (640.0, 8.0, {}),\n", 846 | " '17': (1280.0, 16.0, {}),\n", 847 | " '18': (1280.0, 16.0, {}),\n", 848 | " '19': (640.0, 8.0, {}),\n", 849 | " '2': (640.0, 8.0, {}),\n", 850 | " '20': (640.0, 8.0, {}),\n", 851 | " 
'21': (1280.0, 16.0, {}),\n", 852 | " '22': (1280.0, 16.0, {}),\n", 853 | " '23': (1280.0, 16.0, {}),\n", 854 | " '24': (1280.0, 16.0, {}),\n", 855 | " '25': (1280.0, 16.0, {}),\n", 856 | " '26': (1280.0, 16.0, {}),\n", 857 | " '27': (1280.0, 16.0, {}),\n", 858 | " '28': (1280.0, 16.0, {}),\n", 859 | " '29': (1280.0, 16.0, {}),\n", 860 | " '3': (640.0, 8.0, {}),\n", 861 | " '30': (1280.0, 16.0, {}),\n", 862 | " '31': (1920.0, 24.0, {}),\n", 863 | " '32': (1920.0, 24.0, {}),\n", 864 | " '33': (1280.0, 16.0, {}),\n", 865 | " '34': (1280.0, 16.0, {}),\n", 866 | " '35': (1920.0, 24.0, {}),\n", 867 | " '36': (1920.0, 24.0, {}),\n", 868 | " '37': (1920.0, 24.0, {}),\n", 869 | " '38': (1920.0, 24.0, {}),\n", 870 | " '39': (1920.0, 24.0, {}),\n", 871 | " '4': (640.0, 8.0, {}),\n", 872 | " '40': (1920.0, 24.0, {}),\n", 873 | " '41': (1920.0, 24.0, {}),\n", 874 | " '42': (2560.0, 32.0, {}),\n", 875 | " '43': (1920.0, 24.0, {}),\n", 876 | " '44': (1920.0, 24.0, {}),\n", 877 | " '45': (2560.0, 32.0, {}),\n", 878 | " '46': (2560.0, 32.0, {}),\n", 879 | " '47': (2560.0, 32.0, {}),\n", 880 | " '48': (2560.0, 32.0, {}),\n", 881 | " '49': (2560.0, 32.0, {}),\n", 882 | " '5': (640.0, 8.0, {}),\n", 883 | " '50': (2560.0, 32.0, {}),\n", 884 | " '51': (2560.0, 32.0, {}),\n", 885 | " '52': (3200.0, 40.0, {}),\n", 886 | " '53': (3200.0, 40.0, {}),\n", 887 | " '54': (2560.0, 32.0, {}),\n", 888 | " '55': (3200.0, 40.0, {}),\n", 889 | " '56': (3840.0, 48.0, {}),\n", 890 | " '57': (3200.0, 40.0, {}),\n", 891 | " '58': (3200.0, 40.0, {}),\n", 892 | " '59': (3840.0, 48.0, {}),\n", 893 | " '6': (640.0, 8.0, {}),\n", 894 | " '60': (3840.0, 48.0, {}),\n", 895 | " '61': (3840.0, 48.0, {}),\n", 896 | " '62': (3840.0, 48.0, {}),\n", 897 | " '63': (3840.0, 48.0, {}),\n", 898 | " '64': (4480.0, 56.0, {}),\n", 899 | " '65': (4480.0, 56.0, {}),\n", 900 | " '66': (4480.0, 56.0, {}),\n", 901 | " '67': (4480.0, 56.0, {}),\n", 902 | " '68': (4480.0, 56.0, {}),\n", 903 | " '69': (5120.0, 64.0, {}),\n", 904 | 
" '7': (640.0, 8.0, {}),\n", 905 | " '70': (5120.0, 64.0, {}),\n", 906 | " '71': (5120.0, 64.0, {}),\n", 907 | " '72': (5120.0, 64.0, {}),\n", 908 | " '73': (5760.0, 72.0, {}),\n", 909 | " '74': (5760.0, 72.0, {}),\n", 910 | " '75': (5760.0, 72.0, {}),\n", 911 | " '76': (6400.0, 80.0, {}),\n", 912 | " '77': (6400.0, 80.0, {}),\n", 913 | " '78': (6400.0, 80.0, {}),\n", 914 | " '79': (7040.0, 88.0, {}),\n", 915 | " '8': (1280.0, 16.0, {}),\n", 916 | " '9': (640.0, 8.0, {})})}),\n", 917 | " 'postLinear2mask': (33694080.0,\n", 918 | " 52189.0,\n", 919 | " {'0': (16589440.0, 26082.0, {}),\n", 920 | " '1': (412160.0, 20.0, {}),\n", 921 | " '2': (0.0, 0, {}),\n", 922 | " '3': (16589440.0, 26082.0, {}),\n", 923 | " '4': (103040.0, 5.0, {}),\n", 924 | " '5': (0.0, 0, {})}),\n", 925 | " 'postNet': (6310400.0, 38880.0, {}),\n", 926 | " 'ultraCompress': (0.0, 0, {}),\n", 927 | " 'wavSTFT': (0.0, 0, {})}\n", 928 | "Params : 292.519K, Macs : 230.916\n" 929 | ] 930 | } 931 | ], 932 | "source": [ 933 | "from thop import profile, clever_format\n", 934 | "net_causal = ExampleModel()\n", 935 | "macs, params, ret_dict = profile(net_causal.to(torch.device('cpu')), inputs=(torch.rand(1,3,25120), ), ret_layer_info=True,)\n", 936 | "macs, params = clever_format([macs, params], \"%.3f\")\n", 937 | "from pprint import pprint\n", 938 | "pprint(ret_dict)\n", 939 | "print('Params : {}, Macs : {:.03f}'.format(params, float(macs[:-1])/1.6))" 940 | ] 941 | }, 942 | { 943 | "cell_type": "code", 944 | "execution_count": null, 945 | "metadata": {}, 946 | "outputs": [], 947 | "source": [] 948 | } 949 | ], 950 | "metadata": { 951 | "kernelspec": { 952 | "display_name": "torchcpu", 953 | "language": "python", 954 | "name": "python3" 955 | }, 956 | "language_info": { 957 | "codemirror_mode": { 958 | "name": "ipython", 959 | "version": 3 960 | }, 961 | "file_extension": ".py", 962 | "mimetype": "text/x-python", 963 | "name": "python", 964 | "nbconvert_exporter": "python", 965 | "pygments_lexer": 
"ipython3", 966 | "version": "3.11.5" 967 | }, 968 | "orig_nbformat": 4 969 | }, 970 | "nbformat": 4, 971 | "nbformat_minor": 2 972 | } 973 | -------------------------------------------------------------------------------- /ultra_dual_path_compression.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # text_representation: 5 | # extension: .py 6 | # format_name: light 7 | # format_version: '1.5' 8 | # jupytext_version: 1.15.2 9 | # kernelspec: 10 | # display_name: torchcpu 11 | # language: python 12 | # name: python3 13 | # --- 14 | 15 | # + 16 | # Author: Hangting Chen 17 | # Copyright: Tencent AI Lab 18 | # Paper: Ultra Dual-Path Compression For 19 | # Joint Echo Cancellation and Noise Suppression 20 | # This code give the source of the ultra dual-path 21 | # compression module and an example of usage. 22 | # - 23 | 24 | import torch 25 | import torch.nn as nn 26 | import numpy as np 27 | import math 28 | 29 | 30 | # # Define neural network modules 31 | 32 | 33 | class WAVSTFT(nn.Module): 34 | def __init__(self, win_size=320): 35 | super(WAVSTFT, self).__init__() 36 | window = torch.from_numpy(np.hanning(win_size).astype(np.float32)) 37 | self.window_size = window.shape[-1] 38 | self.hop_length = self.window_size // 2 39 | window = window.unsqueeze(0).unsqueeze(-1) 40 | divisor = torch.ones(1, 1, 1, self.window_size * 4) 41 | divisor = nn.functional.unfold( 42 | divisor, (1, self.window_size), stride=self.hop_length 43 | ) 44 | divisor = divisor * window.pow(2.0) 45 | divisor = nn.functional.fold( 46 | divisor, 47 | (1, self.window_size * 4), 48 | (1, self.window_size), 49 | stride=self.hop_length, 50 | )[:, 0, 0, :] 51 | divisor = ( 52 | divisor[0, self.window_size: 2 * self.window_size] 53 | .unsqueeze(0) 54 | .unsqueeze(-1) 55 | ) 56 | self.register_buffer("window", window) 57 | self.register_buffer("divisor", divisor) 58 | 59 | def magphase(self, complex_tensor: torch.Tensor): 60 
| mag = complex_tensor.pow(2.0).sum(-1).pow(0.5 * 1.0) 61 | phase = torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) 62 | return mag, phase 63 | 64 | def add_window(self, x, divisor): 65 | out = x * self.window / divisor 66 | return out 67 | 68 | def frame(self, x): 69 | assert x.dim() == 2, x.shape 70 | out = x.unsqueeze(1).unsqueeze(1) 71 | out = nn.functional.pad( 72 | out, (self.window_size, self.window_size), "constant", 0 73 | ) 74 | out = nn.functional.unfold( 75 | out, (1, self.window_size), stride=self.hop_length 76 | ) # B N T 77 | return out 78 | 79 | def overlap_and_add(self, x, length): 80 | assert x.dim() == 3, x.shape 81 | out = nn.functional.fold( 82 | x, 83 | (1, length + 2 * self.window_size), 84 | (1, self.window_size), 85 | stride=self.hop_length, 86 | )[:, 0, 0, :] 87 | out = out[:, self.window_size: -self.window_size] 88 | return out 89 | 90 | def rfft(self, x): 91 | assert x.dim() == 3, x.shape 92 | return torch.fft.rfft(x, dim=1) 93 | 94 | def irfft(self, x): 95 | assert x.dim() == 3, x.shape 96 | return torch.fft.irfft(x, dim=1) 97 | 98 | def STFT(self, x, return_aux=False): 99 | assert x.dim() == 2, x.shape 100 | out = self.frame(x) 101 | out = self.add_window(out, 1) 102 | out = self.rfft(out) 103 | if return_aux: 104 | mag, phase = self.magphase(torch.view_as_real(out)) 105 | lps = torch.log(mag**2 + 1e-8) 106 | return out, mag, phase, lps 107 | else: 108 | return out 109 | 110 | def iSTFT(self, x, length): 111 | assert x.dim() == 3, x.shape 112 | out = self.irfft(x) 113 | out = self.add_window(out, self.divisor) 114 | out = self.overlap_and_add(out, length=length) 115 | return out 116 | 117 | 118 | class TimeCompression(nn.Module): 119 | def __init__(self, dim1, dim2, dim3, dim4, steps): 120 | super().__init__() 121 | self.dim1 = dim1 122 | self.dim2 = dim2 123 | self.dim3 = dim3 124 | self.dim4 = dim4 125 | self.steps = steps 126 | self.trans1 = nn.Conv2d( 127 | self.dim1 * self.steps, self.dim2, 1, bias=False 128 | ) 129 | 
self.trans2 = nn.Conv2d(self.dim3, self.dim4, 1, bias=False) 130 | 131 | def forward(self, x, inverse): 132 | # x B C T F 133 | if inverse: 134 | B, C, T, F = x.shape 135 | x = ( 136 | self.trans2(x) 137 | .reshape(B, -1, 1, T, F) 138 | .permute(0, 1, 3, 2, 4) 139 | .contiguous() 140 | ) # B C T S F 141 | x = x.repeat(1, 1, 1, self.steps, 1).reshape( 142 | B, -1, T * self.steps, F 143 | ) # B C T*S F 144 | x = torch.nn.functional.pad( 145 | x, (0, 0, self.steps - 1, 0), "constant", 0 146 | ) 147 | x = x[:, :, : -self.steps + 1, :] 148 | if self.pad > 0: 149 | x = x[:, :, self.pad:, :] 150 | return x 151 | else: 152 | B, C, T, F = x.shape 153 | if x.shape[-2] % self.steps == 0: 154 | self.pad = 0 155 | else: 156 | self.pad = self.steps - x.shape[-2] % self.steps 157 | x = torch.nn.functional.pad( 158 | x, (0, 0, self.pad, 0), "constant", 0 159 | ) 160 | x = ( 161 | x.reshape(B, C, -1, self.steps, F) 162 | .permute(0, 1, 3, 2, 4) 163 | .contiguous() 164 | ) # B C S T F 165 | x = x.reshape(B, C * self.steps, -1, F) # B C*S T F 166 | return self.trans1(x) 167 | 168 | 169 | class FreqCompression(nn.Module): 170 | def __init__( 171 | self, nfreq, nfilters, in_dim, hidden_dim, out_dim, sample_rate=16000 172 | ): 173 | super().__init__() 174 | self.nfreq = nfreq 175 | self.nfilters = nfilters 176 | self.sample_rate = sample_rate 177 | self.in_dim = in_dim 178 | self.hidden_dim = hidden_dim 179 | self.out_dim = out_dim 180 | 181 | mel_scale = "htk" 182 | 183 | all_freqs = torch.linspace(0, sample_rate // 2, nfreq) 184 | # calculate mel freq bins 185 | m_min = self._hz_to_mel(0, mel_scale=mel_scale) 186 | m_max = self._hz_to_mel(sample_rate / 2.0, mel_scale=mel_scale) 187 | 188 | m_pts = torch.linspace(m_min, m_max, self.nfilters + 2) 189 | f_pts = self._mel_to_hz(m_pts, mel_scale=mel_scale) 190 | self.bounds = [ 191 | 0, 192 | ] 193 | for freq_inx in range(1, len(f_pts) - 1): 194 | self.bounds.append( 195 | (all_freqs > f_pts[freq_inx]).float().argmax().item() 196 | ) 197 | 
self.bounds.append(nfreq) 198 | self.trans1 = nn.ModuleList() 199 | self.trans2 = nn.ModuleList() 200 | for freq_inx in range(self.nfilters): 201 | self.trans1.append( 202 | nn.Linear( 203 | (self.bounds[freq_inx + 2] - self.bounds[freq_inx]) 204 | * self.in_dim, 205 | self.hidden_dim, 206 | bias=False, 207 | ) 208 | ) 209 | self.trans2.append( 210 | nn.Conv1d( 211 | self.hidden_dim, 212 | (self.bounds[freq_inx + 2] - self.bounds[freq_inx]) 213 | * self.out_dim, 214 | 1, 215 | ) 216 | ) 217 | 218 | def _hz_to_mel(self, freq: float, mel_scale: str = "htk") -> float: 219 | r""" 220 | Source: https://pytorch.org/audio/stable/ 221 | _modules/torchaudio/functional/functional.html 222 | Convert Hz to Mels. 223 | 224 | Args: 225 | freqs (float): Frequencies in Hz 226 | mel_scale (str, optional): Scale to use: 227 | ``htk`` or ``slaney``. (Default: ``htk``) 228 | 229 | Returns: 230 | mels (float): Frequency in Mels 231 | """ 232 | 233 | if mel_scale not in ["slaney", "htk"]: 234 | raise ValueError('mel_scale should be one of "htk" or "slaney".') 235 | 236 | if mel_scale == "htk": 237 | return 2595.0 * math.log10(1.0 + (freq / 700.0)) 238 | 239 | # Fill in the linear part 240 | f_min = 0.0 241 | f_sp = 200.0 / 3 242 | 243 | mels = (freq - f_min) / f_sp 244 | 245 | # Fill in the log-scale part 246 | min_log_hz = 1000.0 247 | min_log_mel = (min_log_hz - f_min) / f_sp 248 | logstep = math.log(6.4) / 27.0 249 | 250 | if freq >= min_log_hz: 251 | mels = min_log_mel + math.log(freq / min_log_hz) / logstep 252 | 253 | return mels 254 | 255 | def _mel_to_hz( 256 | self, mels: torch.Tensor, mel_scale: str = "htk" 257 | ) -> torch.Tensor: 258 | """ 259 | Source: https://pytorch.org/audio/stable/ 260 | _modules/torchaudio/functional/functional.html 261 | Convert mel bin numbers to frequencies. 262 | 263 | Args: 264 | mels (torch.Tensor): Mel frequencies 265 | mel_scale (str, optional): Scale to use: 266 | ``htk`` or ``slaney``. 
(Default: ``htk``) 267 | 268 | Returns: 269 | freqs (torch.Tensor): Mels converted in Hz 270 | """ 271 | 272 | if mel_scale not in ["slaney", "htk"]: 273 | raise ValueError('mel_scale should be one of "htk" or "slaney".') 274 | 275 | if mel_scale == "htk": 276 | return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) 277 | 278 | # Fill in the linear scale 279 | f_min = 0.0 280 | f_sp = 200.0 / 3 281 | freqs = f_min + f_sp * mels 282 | 283 | # And now the nonlinear scale 284 | min_log_hz = 1000.0 285 | min_log_mel = (min_log_hz - f_min) / f_sp 286 | logstep = math.log(6.4) / 27.0 287 | 288 | log_t = mels >= min_log_mel 289 | freqs[log_t] = min_log_hz * torch.exp( 290 | logstep * (mels[log_t] - min_log_mel) 291 | ) 292 | 293 | return freqs 294 | 295 | def forward(self, x, inverse): 296 | if inverse: 297 | # B C T F 298 | out = torch.zeros( 299 | [x.shape[0], self.out_dim, self.nfreq, x.shape[2]], 300 | dtype=x.dtype, 301 | layout=x.layout, 302 | device=x.device, 303 | ) 304 | for freq_inx in range(self.nfilters): 305 | out[ 306 | :, :, self.bounds[freq_inx]: self.bounds[freq_inx + 2], : 307 | ] = out[ 308 | :, :, self.bounds[freq_inx]: self.bounds[freq_inx + 2], : 309 | ] + self.trans2[ 310 | freq_inx 311 | ]( 312 | x[:, :, :, freq_inx] 313 | ).reshape( 314 | x.shape[0], self.out_dim, -1, x.shape[-2] 315 | ) 316 | out[:, :, self.bounds[1]: self.bounds[-2], :] = ( 317 | out[:, :, self.bounds[1]: self.bounds[-2], :] / 2.0 318 | ) 319 | out = out.permute(0, 1, 3, 2).contiguous().tanh() 320 | return out 321 | else: 322 | x = x.reshape(x.shape[0], self.in_dim, *x.shape[-2:]) # B C T F 323 | x = x.permute(0, 2, 1, 3).contiguous() # B T C F 324 | x = torch.stack( 325 | [ 326 | self.trans1[freq_inx]( 327 | x[ 328 | :, 329 | :, 330 | :, 331 | self.bounds[freq_inx]: self.bounds[freq_inx + 2], 332 | ].flatten(start_dim=2) 333 | ) 334 | for freq_inx in range(self.nfilters) 335 | ], 336 | -1, 337 | ) # B T C F 338 | x = x.permute(0, 2, 1, 3).contiguous() 339 | return x 340 | 341 | 342 | 
class UltraDualPath(nn.Module):
    """Ultra dual-path compression: time compression followed by frequency
    (mel-band) compression, plus the matching two-stage decompression.

    Args:
        nfreq: number of linear frequency bins (e.g. win_size // 2 + 1).
        in_dim: input channels (e.g. channels * 2 for real+imag).
        hidden_dim: latent channel count of the frequency stage.
        out_dim: channel count produced by decompression.
        freq_cprs_ratio: frequency-axis compression ratio.
        time_cprs_ratio: time-axis compression ratio.
    """

    def __init__(
        self,
        nfreq,
        in_dim,
        hidden_dim,
        out_dim,
        freq_cprs_ratio,
        time_cprs_ratio,
    ):
        super(UltraDualPath, self).__init__()
        self.nfreq = nfreq
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.freq_cprs_ratio = freq_cprs_ratio
        self.time_cprs_ratio = time_cprs_ratio
        # BUGFIX: use nn.ModuleList instead of a plain Python list.  A plain
        # list does not register the submodules, so their parameters were
        # invisible to .parameters(), .to(), state_dict() and optimizers
        # (the thop profile reporting 0 params for ultraCompress is a
        # symptom of exactly this).
        self.compress_modules = nn.ModuleList(
            [
                TimeCompression(
                    self.in_dim,
                    self.in_dim * 2,
                    self.out_dim,
                    self.out_dim,
                    time_cprs_ratio,
                ),
                FreqCompression(
                    self.nfreq,
                    self.nfreq // self.freq_cprs_ratio,
                    self.in_dim * 2,
                    self.hidden_dim,
                    self.out_dim,
                ),
            ]
        )

    def forward(self, x, inverse):
        """Apply the compression stages in order (inverse=0) or the
        decompression stages in reverse order (inverse=1)."""
        out = x
        if inverse:
            for m in self.compress_modules[::-1]:
                out = m(out, inverse)
        else:
            for m in self.compress_modules:
                out = m(out, inverse)
        return out


# # An example of run

# ## Define model and input

# ## Define network modules
win_size = 320
wavSTFT = WAVSTFT(win_size)  # 16khz, 20ms/10ms -> 320 samples
ultraCompress = UltraDualPath(
    win_size // 2 + 1, 3 * 2, 48, 48, 4, 4
)  # 3ch*(real+imag)
# ## Define input
wav = torch.rand(2, 3, 16000)  # nbatch=2, nchannels=3, nsamples=16000

# ## Run

# organize input
B, C, N = wav.shape
spec = wavSTFT.STFT(wav.reshape(B * C, N))
spec = torch.view_as_real(spec.reshape(B, C, *spec.shape[-2:]))  # B C F T 2
spec = spec.permute(0, 1, 4, 3, 2)  # B C 2 T F
spec = spec.reshape(B, C * 2, *spec.shape[-2:])
print(spec.shape)  # nbatch, nchannels*2, nframes, nfreqs

# compress
latent = ultraCompress(spec, 0)
print(latent.shape)  # nbatch, hidden_dim, nframes, nfreqs

# decompress
output = ultraCompress(latent, 1)
print(output.shape)  # nbatch, out_dim, nframes, nfreqs