├── LICENSE ├── LLaSE-G1.png ├── README.md ├── ckpt ├── codec_ckpt │ └── hub │ │ ├── models--facebook--w2v-bert-2.0 │ │ ├── config.json │ │ └── preprocessor_config.json │ │ └── version.txt ├── download.sh └── download_ckpt.py ├── config └── test.yml ├── inference.py ├── inference.sh ├── loader ├── __pycache__ │ └── datareader_fe.cpython-310.pyc ├── datareader.py ├── datareader_aec.py └── datareader_tse.py ├── nnet ├── WavLM.py ├── __pycache__ │ ├── WavLM.cpython-310.pyc │ ├── embedding.cpython-310.pyc │ ├── llama.cpython-310.pyc │ └── modules.cpython-310.pyc ├── llase.py └── modules.py ├── requirements.txt └── vq ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── __init__.cpython-311.pyc ├── __init__.cpython-312.pyc ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── __init__.cpython-39.pyc ├── activations.cpython-310.pyc ├── activations.cpython-311.pyc ├── activations.cpython-312.pyc ├── activations.cpython-37.pyc ├── activations.cpython-38.pyc ├── activations.cpython-39.pyc ├── blocks.cpython-310.pyc ├── blocks.cpython-39.pyc ├── bs_roformer5.cpython-310.pyc ├── bs_roformer5.cpython-37.pyc ├── bs_roformer5.cpython-38.pyc ├── bs_roformer5.cpython-39.pyc ├── codec_decoder.cpython-310.pyc ├── codec_decoder.cpython-311.pyc ├── codec_decoder.cpython-312.pyc ├── codec_decoder.cpython-39.pyc ├── codec_decoder_vocos.cpython-310.pyc ├── codec_decoder_vocos.cpython-311.pyc ├── codec_decoder_vocos.cpython-312.pyc ├── codec_decoder_vocos.cpython-39.pyc ├── codec_encoder.cpython-310.pyc ├── codec_encoder.cpython-311.pyc ├── codec_encoder.cpython-312.pyc ├── codec_encoder.cpython-37.pyc ├── codec_encoder.cpython-38.pyc ├── codec_encoder.cpython-39.pyc ├── factorized_vector_quantize.cpython-310.pyc ├── factorized_vector_quantize.cpython-311.pyc ├── factorized_vector_quantize.cpython-312.pyc ├── factorized_vector_quantize.cpython-39.pyc ├── module.cpython-310.pyc ├── module.cpython-311.pyc ├── module.cpython-312.pyc ├── module.cpython-37.pyc ├── module.cpython-38.pyc ├── module.cpython-39.pyc ├── residual_vq.cpython-310.pyc ├── residual_vq.cpython-311.pyc ├── residual_vq.cpython-312.pyc ├── residual_vq.cpython-39.pyc ├── unet.cpython-312.pyc └── unet.cpython-39.pyc ├── activations.py ├── alias_free_torch ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-312.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── act.cpython-310.pyc │ ├── act.cpython-311.pyc │ ├── act.cpython-312.pyc │ ├── act.cpython-37.pyc │ ├── act.cpython-38.pyc │ ├── act.cpython-39.pyc │ ├── filter.cpython-310.pyc │ ├── filter.cpython-311.pyc │ ├── filter.cpython-312.pyc │ ├── filter.cpython-37.pyc │ ├── filter.cpython-38.pyc │ ├── filter.cpython-39.pyc │ ├── resample.cpython-310.pyc │ ├── resample.cpython-311.pyc │ ├── resample.cpython-312.pyc │ ├── resample.cpython-37.pyc │ ├── resample.cpython-38.pyc │ └── resample.cpython-39.pyc ├── act.py ├── filter.py └── resample.py ├── blocks.py ├── bs_roformer5.py ├── codec_decoder.py ├── codec_decoder_vocos.py ├── codec_encoder.py ├── factorized_vector_quantize.py ├── module.py ├── residual_vq.py └── unet.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LLaSE-G1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/LLaSE-G1.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLaSE-G1: Incentivizing Generalization Capability for LLaMA-based Speech Enhancement 2 | 3 |

4 | 5 | Paper 6 | 7 | 8 | Demo 9 | 10 | 11 | Hugging Face 12 | 13 |

14 | 15 | ![LLaSE-G1](LLaSE-G1.png) 16 | 17 | 18 | ## Introduction 19 | 20 | LLaSE-G1 is a unified speech enhancement model that handles multiple tasks without extra task prompts, including: 21 | 22 | - **Noise Suppression (SE)** 23 | - **Target Speaker Extraction (TSE)** 24 | - **Packet Loss Concealment (PLC)** 25 | - **Acoustic Echo Cancellation (AEC)** 26 | - **Speech Separation (SS)** 27 | 28 | To mitigate acoustic inconsistency, LLaSE-G1 takes continuous representations from **WavLM** as input and predicts speech tokens with **X-Codec2**, maximizing acoustic preservation. The model surpasses prior task-specific discriminative and generative speech enhancement models, demonstrates scaling effects at test time, and shows emerging capabilities on unseen speech enhancement tasks. 29 | 30 | For more details, refer to our paper: [LLaSE-G1 Paper](https://arxiv.org/abs/2503.00493) 31 | 32 | ## Demo 33 | 34 | You can listen to the enhancement results on our [Demo Page](https://submission-papers.github.io/LLaSE-G1-demo-page/). 35 | 36 | ## Installation 37 | 38 | Checkpoints are available on [Hugging Face](https://huggingface.co/ASLP-lab/LLaSE-G1). 39 | 40 | ### 1. Clone the repository 41 | 42 | ```bash 43 | git clone https://github.com/Kevin-naticl/LLaSE-G1.git 44 | cd LLaSE-G1 45 | ``` 46 | 47 | ### 2. Create a Conda environment and install dependencies 48 | 49 | ```bash 50 | conda create -n llase python=3.10 51 | conda activate llase 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | ### 3. Download Pretrained Models 56 | 57 | LLaSE-G1 requires three additional pre-trained models (w2v-bert-2.0, X-Codec2, and WavLM-Large) as well as the checkpoint of the middle LM. Three of them can be downloaded from Hugging Face with the provided shell script: 58 | 59 | ```bash 60 | cd ckpt 61 | bash download.sh 62 | ``` 63 | Additionally, download `WavLM-Large.pt` from this [URL](https://drive.google.com/file/d/12-cB34qCTvByWT-QtOcZaqwwO21FLSqU/view) and place it at `./ckpt/WavLM-Large.pt`. 64 | 65 | Alternatively, you can download all of them manually and place them in the `./ckpt/` directory. 66 | 67 | After downloading, the directory tree should look like this: 68 | 69 | ```bash 70 | ├── ckpt 71 | │ ├── codec_ckpt 72 | │ │ ├── epoch=4-step=1400000.ckpt 73 | │ │ └── hub 74 | │ │ ├── models--facebook--w2v-bert-2.0 75 | │ │ │ ├── config.json 76 | │ │ │ ├── model.safetensors 77 | │ │ │ └── preprocessor_config.json 78 | │ │ └── version.txt 79 | │ ├── download_ckpt.py 80 | │ ├── download.sh 81 | │ ├── model.pt.tar 82 | │ └── WavLM-Large.pt 83 | ``` 84 | 85 | ## Inference 86 | 87 | The main inference script is **`inference.py`**. The inference process consists of two stages: 88 | 89 | 1. Extract 6th-layer features from WavLM. 90 | 2. Use the language model (LM) to predict clean speech tokens, then decode them into audio with **X-Codec2**. 91 | 92 | ### Running Inference 93 | 94 | To run inference, configure the parameters in `./config/test.yml`: 95 | 96 | | Parameter | Description | 97 | | ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 98 | | `infer_feat_too` | Whether to extract WavLM features during inference. | 99 | | `inference_time` | Number of inference passes; the output of one pass is fed back as the input of the next. | 100 | | `feat_dir` | Directory where extracted features are saved. | 101 | | `wav_dir` | Directory where processed (output) audio files are saved. 
| 102 | | `task` | Task type: `SE` (Noise Suppression), `TSE` (Target Speaker Extraction), `PLC` (Packet Loss Concealment), `AEC` (Acoustic Echo Cancellation), `SS` (Speech Separation). | 103 | | `filename` | It should be the path of a text file, which contains the paths of the audio files you want to process. For example: /home/0.wav | 104 | 105 | Command to run inference: 106 | 107 | ```bash 108 | bash inference.sh 109 | ``` 110 | 111 | ## Results 112 | 113 | Samples processed by LLaSE-G1 can be found on our [Demo Page](https://submission-papers.github.io/LLaSE-G1-demo-page/). 114 | 115 | ## Model Checkpoints 116 | 117 | Our pretrained model is available on [Hugging Face](https://huggingface.co/ASLP-lab/LLaSE-G1). 118 | 119 | ## Hints 120 | 121 | Our approach focuses on leveraging the LLM's comprehension capabilities to enable autonomous determination of task types, though this may exhibit instability in certain scenarios. A more stable and robust iteration will be released in the upcoming version. 122 | 123 | ## Citation 124 | 125 | ``` 126 | @misc{kang2025llaseg1incentivizinggeneralizationcapability, 127 | title={LLaSE-G1: Incentivizing Generalization Capability for LLaMA-based Speech Enhancement}, 128 | author={Boyi Kang and Xinfa Zhu and Zihan Zhang and Zhen Ye and Mingshuai Liu and Ziqian Wang and Yike Zhu and Guobin Ma and Jun Chen and Longshuai Xiao and Chao Weng and Wei Xue and Lei Xie}, 129 | year={2025}, 130 | eprint={2503.00493}, 131 | archivePrefix={arXiv}, 132 | primaryClass={eess.AS}, 133 | url={https://arxiv.org/abs/2503.00493}, 134 | } 135 | ``` 136 | 137 | 138 | ## Contact 139 | 140 | For any questions, please contact: `beaukang02@gmail.com` 141 | ![image](https://github.com/user-attachments/assets/5c1b4f78-9906-4020-a686-b61d27f1716e) 142 | -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_dropout": 0.0, 3 | "adapter_act": "relu", 4 | "adapter_kernel_size": 3, 5 | "adapter_stride": 2, 6 | "add_adapter": false, 7 | "apply_spec_augment": false, 8 | "architectures": [ 9 | "Wav2Vec2BertModel" 10 | ], 11 | "attention_dropout": 0.0, 12 | "bos_token_id": 1, 13 | "classifier_proj_size": 768, 14 | "codevector_dim": 768, 15 | "conformer_conv_dropout": 0.1, 16 | "contrastive_logits_temperature": 0.1, 17 | "conv_depthwise_kernel_size": 31, 18 | "ctc_loss_reduction": "sum", 19 | "ctc_zero_infinity": false, 20 | "diversity_loss_weight": 0.1, 21 | "eos_token_id": 2, 22 | "feat_proj_dropout": 0.0, 23 | "feat_quantizer_dropout": 0.0, 24 | "feature_projection_input_dim": 160, 25 | "final_dropout": 0.1, 26 | "hidden_act": "swish", 27 | "hidden_dropout": 0.0, 28 | "hidden_size": 1024, 29 | "initializer_range": 0.02, 30 | "intermediate_size": 4096, 31 | "layer_norm_eps": 1e-05, 32 | "layerdrop": 0.1, 33 | "left_max_position_embeddings": 64, 34 | "mask_feature_length": 10, 35 | "mask_feature_min_masks": 0, 36 | "mask_feature_prob": 0.0, 37 | "mask_time_length": 10, 38 | "mask_time_min_masks": 2, 39 | "mask_time_prob": 0.05, 40 | "max_source_positions": 5000, 41 | "model_type": "wav2vec2-bert", 42 | "num_adapter_layers": 1, 43 | "num_attention_heads": 16, 44 | "num_codevector_groups": 2, 45 | "num_codevectors_per_group": 320, 46 | "num_hidden_layers": 24, 47 | "num_negatives": 100, 48 | "output_hidden_size": 1024, 49 | "pad_token_id": 0, 50 | "position_embeddings_type": "relative_key", 51 | 
"proj_codevector_dim": 768, 52 | "right_max_position_embeddings": 8, 53 | "rotary_embedding_base": 10000, 54 | "tdnn_dilation": [ 55 | 1, 56 | 2, 57 | 3, 58 | 1, 59 | 1 60 | ], 61 | "tdnn_dim": [ 62 | 512, 63 | 512, 64 | 512, 65 | 512, 66 | 1500 67 | ], 68 | "tdnn_kernel": [ 69 | 5, 70 | 3, 71 | 3, 72 | 1, 73 | 1 74 | ], 75 | "torch_dtype": "float32", 76 | "transformers_version": "4.37.0.dev0", 77 | "use_intermediate_ffn_before_adapter": false, 78 | "use_weighted_layer_sum": false, 79 | "vocab_size": null, 80 | "xvector_output_dim": 512 81 | } 82 | -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "feature_extractor_type": "SeamlessM4TFeatureExtractor", 3 | "feature_size": 80, 4 | "num_mel_bins": 80, 5 | "padding_side": "right", 6 | "padding_value": 1, 7 | "processor_class": "Wav2Vec2BertProcessor", 8 | "return_attention_mask": true, 9 | "sampling_rate": 16000, 10 | "stride": 2 11 | } 12 | -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/version.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /ckpt/download.sh: -------------------------------------------------------------------------------- 1 | python download_script.py \ 2 | --source hf \ 3 | --repo_id facebook/w2v-bert-2.0 \ 4 | --filename model.safetensors \ 5 | --save_path \ 6 | ./codec_ckpt/hub/models--facebook--w2v-bert-2.0/model.safetensors 7 | 8 | python download_script.py \ 9 | --source hf \ 10 | --repo_id HKUSTAudio/xcodec2 \ 11 | --filename ckpt/epoch=4-step=1400000.ckpt \ 12 | --save_path ./codec_ckpt/epoch=4-step=1400000.ckpt 13 | 14 | python download_script.py \ 15 | --source hf \ 16 | --repo_id ASLP-lab/LLaSE-G1 \ 17 | --filename ckpt/model.pt.tar \ 18 | --save_path ./model.pt.tar 19 | -------------------------------------------------------------------------------- /ckpt/download_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import argparse 4 | from huggingface_hub import hf_hub_download 5 | from tqdm import tqdm 6 | 7 | def download_from_url(url, save_path): 8 | """Download a file from a given URL and save it locally.""" 9 | response = requests.get(url, stream=True) 10 | total_size = int(response.headers.get("content-length", 0)) 11 | block_size = 1024 # 1 KB 12 | progress_bar = tqdm(total=total_size, unit="B", unit_scale=True) 13 | 14 | with open(save_path, "wb") as file: 15 | for data in response.iter_content(block_size): 16 | progress_bar.update(len(data)) 17 | file.write(data) 18 | progress_bar.close() 19 | 20 | if total_size != 0 and progress_bar.n != total_size: 21 | print("Download failed!") 22 | else: 23 | print(f"File downloaded to: {save_path}") 24 | 25 | def download_from_hf(repo_id, filename, save_path): 26 | """Download a file from Hugging Face Hub.""" 27 | print(f"Downloading from Hugging Face Hub: {repo_id}/{filename}") 28 | try: 29 | hf_hub_download(repo_id=repo_id, filename=filename, local_dir=os.path.dirname(save_path), local_dir_use_symlinks=False) 30 | print(f"File downloaded to: {save_path}") 31 | except Exception as e: 32 | print(f"Download failed: {e}") 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser(description="Automatically download 
model checkpoints") 36 | parser.add_argument("--source", type=str, required=True, choices=["hf", "url"], help="Download source: hf (Hugging Face Hub) or url (custom URL)") 37 | parser.add_argument("--repo_id", type=str, help="Hugging Face model repository ID (e.g., google/bert-base-uncased)") 38 | parser.add_argument("--filename", type=str, help="Filename in the Hugging Face repository") 39 | parser.add_argument("--url", type=str, help="Custom download URL") 40 | parser.add_argument("--save_path", type=str, required=True, help="Path to save the file (including filename)") 41 | args = parser.parse_args() 42 | 43 | # Ensure the save directory exists 44 | os.makedirs(os.path.dirname(args.save_path), exist_ok=True) 45 | 46 | if args.source == "hf": 47 | if not args.repo_id or not args.filename: 48 | print("Please provide a Hugging Face repository ID and filename!") 49 | return 50 | download_from_hf(args.repo_id, args.filename, args.save_path) 51 | elif args.source == "url": 52 | if not args.url: 53 | print("Please provide a download URL!") 54 | return 55 | download_from_url(args.url, args.save_path) 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /config/test.yml: -------------------------------------------------------------------------------- 1 | test: 2 | checkpoint: ./ckpt/model.pt.tar 3 | use_cuda: True 4 | infer_feat_too: True 5 | inference_time: 1 6 | 7 | save: 8 | feat_dir: ./decode/feat/se 9 | wav_dir: ./decode/wav/se 10 | 11 | task: SE #PLC,AEC,SS,TSE 12 | 13 | # LLaSE config 14 | nnet_conf: 15 | d_model: 1024 16 | nhead: 16 17 | num_layers: 16 18 | 19 | datareader: 20 | sample_rate: 16000 21 | filename: /path/to/your/filelist -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import librosa 5 | import yaml 6 | import joblib 7 | import argparse 8 | 9 | import soundfile as sf 10 | import numpy as np 11 | 12 | from pathlib import Path 13 | from collections import defaultdict 14 | from typing import Optional 15 | from tqdm import tqdm 16 | 17 | sys.path.append(os.path.dirname(__file__)) 18 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 19 | 20 | # Torch 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | import torch.distributed as dist 25 | from torch.nn.parallel import DistributedDataParallel 26 | 27 | # WavLM 28 | from nnet.WavLM import WavLM, WavLMConfig 29 | 30 | # Xcodec2 31 | from vq.codec_encoder import CodecEncoder_Transformer 32 | from vq.codec_decoder_vocos import CodecDecoderVocos 33 | from vq.module import SemanticEncoder 34 | from transformers import AutoFeatureExtractor, Wav2Vec2BertModel 35 | from collections import OrderedDict 36 | 37 | # Dataloader 38 | from loader.datareader import DataReader 39 | from loader.datareader_aec import DataReaderAEC 40 | from loader.datareader_tse import DataReaderTSE 41 | 42 | # LLaSE 43 | from nnet.llase import LLM_AR as model 44 | 45 | class Encodec(): 46 | ''' 47 | Load Xcodec2 48 | ''' 49 | def __init__(self,device="cpu") -> None: 50 | self.device=device 51 | ckpt = "./ckpt/codec_ckpt/epoch=4-step=1400000.ckpt" 52 | # ckpt = '/home/bykang/codec_ckpt/epoch=4-step=1400000.ckpt' 53 | ckpt = torch.load(ckpt, map_location='cpu') 54 | state_dict = ckpt['state_dict'] 55 | filtered_state_dict_codec = OrderedDict() 56 | 
filtered_state_dict_semantic_encoder = OrderedDict() 57 | filtered_state_dict_gen = OrderedDict() 58 | filtered_state_dict_fc_post_a = OrderedDict() 59 | filtered_state_dict_fc_prior = OrderedDict() 60 | for key, value in state_dict.items(): 61 | if key.startswith('CodecEnc.'): 62 | new_key = key[len('CodecEnc.'):] 63 | filtered_state_dict_codec[new_key] = value 64 | elif key.startswith('generator.'): 65 | new_key = key[len('generator.'):] 66 | filtered_state_dict_gen[new_key] = value 67 | elif key.startswith('fc_post_a.'): 68 | new_key = key[len('fc_post_a.'):] 69 | filtered_state_dict_fc_post_a[new_key] = value 70 | elif key.startswith('SemanticEncoder_module.'): 71 | new_key = key[len('SemanticEncoder_module.'):] 72 | filtered_state_dict_semantic_encoder[new_key] = value 73 | elif key.startswith('fc_prior.'): 74 | new_key = key[len('fc_prior.'):] 75 | filtered_state_dict_fc_prior[new_key] = value 76 | 77 | self.semantic_model = Wav2Vec2BertModel.from_pretrained( 78 | "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0", 79 | # "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b", 80 | output_hidden_states=True) 81 | self.semantic_model=self.semantic_model.eval().to(self.device) 82 | 83 | self.SemanticEncoder_module = SemanticEncoder(1024,1024,1024) 84 | self.SemanticEncoder_module.load_state_dict(filtered_state_dict_semantic_encoder) 85 | self.SemanticEncoder_module = self.SemanticEncoder_module.eval().to(self.device) 86 | 87 | self.encoder = CodecEncoder_Transformer() 88 | self.encoder.load_state_dict(filtered_state_dict_codec) 89 | self.encoder = self.encoder.eval().to(self.device) 90 | 91 | self.decoder = CodecDecoderVocos() 92 | self.decoder.load_state_dict(filtered_state_dict_gen) 93 | self.decoder = self.decoder.eval().to(self.device) 94 | 95 | self.fc_post_a = nn.Linear( 2048, 1024 ) 96 | self.fc_post_a.load_state_dict(filtered_state_dict_fc_post_a) 97 | self.fc_post_a = self.fc_post_a.eval().to(self.device) 98 | 99 | self.fc_prior = nn.Linear( 2048, 2048 ) 100 | self.fc_prior.load_state_dict(filtered_state_dict_fc_prior) 101 | self.fc_prior = self.fc_prior.eval().to(self.device) 102 | 103 | self.feature_extractor = AutoFeatureExtractor.from_pretrained( 104 | "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0") 105 | # "/home/bykang/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b") 106 | 107 | def get_feat(self, wav_batch, pad=None): 108 | 109 | if len(wav_batch.shape) != 2: 110 | return self.feature_extractor(F.pad(wav_batch, pad), sampling_rate=16000, return_tensors="pt") .data['input_features'] 111 | 112 | padded_wavs = torch.stack([F.pad(wav, pad) for wav in wav_batch]) 113 | batch_feats = [] 114 | 115 | for wav in padded_wavs: 116 | feat = self.feature_extractor( 117 | wav, 118 | sampling_rate=16000, 119 | return_tensors="pt" 120 | ).data['input_features'] 121 | 122 | batch_feats.append(feat) 123 | feat_batch = torch.concat(batch_feats, dim=0).to(self.device) 124 | return feat_batch 125 | 126 | def get_embedding(self, wav_cpu): 127 | wav_cpu = wav_cpu.cpu() 128 | feat = self.get_feat(wav_cpu,pad=(160,160)) 129 | feat = feat.to(self.device) 130 | 131 | if(len(wav_cpu.shape)==1): 132 | wav = wav_cpu.unsqueeze(0).to(self.device) 133 | else: 134 | wav = wav_cpu.to(self.device) 135 | 136 | wav = torch.nn.functional.pad(wav, (0, (200 - (wav.shape[1] % 200)))) 137 | with torch.no_grad(): 138 | vq_emb = self.encoder(wav.unsqueeze(1)) 139 | vq_emb = vq_emb.transpose(1, 2) 
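            # The acoustic branch (codec encoder frames) and the semantic branch
            # (w2v-bert features) must agree in frame count before being
            # concatenated below; if the padded features do not line up with the
            # encoder output, the semantic features are recomputed from the
            # unpadded waveform.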
140 | 141 | if vq_emb.shape[2]!=feat.shape[1]: 142 | feat = self.get_feat(wav_cpu) 143 | feat = feat.to(self.device) 144 | 145 | semantic_target = self.semantic_model(feat[:, :,:]) 146 | semantic_target = semantic_target.hidden_states[16] 147 | semantic_target = semantic_target.transpose(1, 2) 148 | semantic_target = self.SemanticEncoder_module(semantic_target) 149 | 150 | vq_emb = torch.cat([semantic_target, vq_emb], dim=1) 151 | 152 | return vq_emb 153 | 154 | def emb2token(self, emb): 155 | emb.to(self.device) 156 | emb = self.fc_prior(emb.transpose(1, 2)).transpose(1, 2) 157 | _, vq_code, _ = self.decoder(emb, vq=True) 158 | return vq_code 159 | 160 | def token2wav(self, vq_code): 161 | vq_code.to(self.device) 162 | vq_post_emb = self.decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2)) 163 | vq_post_emb = vq_post_emb.transpose(1, 2) 164 | vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1,2)).transpose(1,2) 165 | recon = self.decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze() 166 | # if write the wav, add .squeeze().detach().cpu().numpy() 167 | # if need gradient use the config right now 168 | return recon 169 | 170 | class WavLM_feat(object): 171 | ''' 172 | Load WavLM 173 | ''' 174 | def __init__(self, device): 175 | self.wavlm = self._reload_wavLM_large(device=device) 176 | 177 | def __call__(self, wav): 178 | T = wav.shape[-1] 179 | wav = wav.reshape(-1, T) 180 | with torch.no_grad(): 181 | feat = self.wavlm.extract_features(wav, output_layer=6, ret_layer_results=False)[0] 182 | B, T, D = feat.shape 183 | feat = torch.reshape(feat, (-1, D)) 184 | 185 | return feat 186 | 187 | def _reload_wavLM_large(self, path="./ckpt/WavLM-Large.pt", device: Optional[torch.device] = None): 188 | cpt = torch.load(path, map_location="cpu") 189 | cfg = WavLMConfig(cpt['cfg']) 190 | wavLM = WavLM(cfg) 191 | wavLM.load_state_dict(cpt['model']) 192 | wavLM.eval() 193 | if device != None: 194 | wavLM = wavLM.to(device) 195 | for p in wavLM.parameters(): 196 | p.requires_grad = False 197 | print('successful to reload wavLM', path) 198 | return wavLM 199 | 200 | def get_firstchannel_read(path, fs=16000): 201 | ''' 202 | Get first channel of the wav 203 | ''' 204 | wave_data, sr = sf.read(path) 205 | if sr != fs: 206 | if len(wave_data.shape) != 1: 207 | wave_data = wave_data.transpose((1, 0)) 208 | wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) 209 | if len(wave_data.shape) != 1: 210 | wave_data = wave_data.transpose((1, 0)) 211 | if len(wave_data.shape) > 1: 212 | wave_data = wave_data[:, 0] 213 | return wave_data 214 | 215 | def load_obj(obj, device): 216 | ''' 217 | Offload tensor object in obj to cuda device 218 | ''' 219 | def cuda(obj): 220 | return obj.to(device) if isinstance(obj, torch.Tensor) else obj 221 | 222 | if isinstance(obj, dict): 223 | return {key: load_obj(obj[key], device) for key in obj} 224 | elif isinstance(obj, list): 225 | return [load_obj(val, device) for val in obj] 226 | else: 227 | return cuda(obj) 228 | 229 | def run(args): 230 | LOCAL_RANK = int(os.environ['LOCAL_RANK']) 231 | WORLD_SIZE = int(os.environ['WORLD_SIZE']) 232 | WORLD_RANK = int(os.environ['RANK']) 233 | dist.init_process_group(args.backend, rank=WORLD_RANK, world_size=WORLD_SIZE) 234 | torch.cuda.set_device(LOCAL_RANK) 235 | device = torch.device('cuda', LOCAL_RANK) 236 | print(f"[{os.getpid()}] using device: {device}", torch.cuda.current_device(), "local rank", LOCAL_RANK) 237 | 238 | with open(args.conf, "r") as f: 239 | conf = yaml.load(f, Loader=yaml.FullLoader) 
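    # Overall flow (the two stages described in the README): build the task-specific
    # data reader, load X-Codec2 and WavLM together with the LLaSE LM, then
    # (1) extract layer-6 WavLM features and cache them as .npy files, and
    # (2) let the LM predict clean speech tokens (frame-wise argmax) and decode
    # them back to waveforms with X-Codec2. Every extra `inference_time` round
    # feeds the previous round's output audio through the same pipeline again.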
240 | 241 | # Dataloader 242 | if conf["task"]=="AEC": 243 | data_reader = DataReaderAEC(**conf["datareader"]) 244 | elif conf["task"]=="TSE": 245 | data_reader = DataReaderTSE(**conf["datareader"]) 246 | else: 247 | data_reader = DataReader(**conf["datareader"]) 248 | 249 | # Load WavLM and XCodec2 250 | codec = Encodec(device) 251 | wavlm_feat = WavLM_feat(device) 252 | 253 | # Load LLaSE 254 | nnet = model(**conf["nnet_conf"]) 255 | cpt_fname = Path(conf["test"]["checkpoint"]) 256 | cpt = torch.load(cpt_fname, map_location="cpu") 257 | 258 | nnet = nnet.to(device) 259 | nnet = DistributedDataParallel(nnet, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, find_unused_parameters=True) 260 | nnet.load_state_dict(cpt["model_state_dict"]) 261 | nnet.eval() 262 | 263 | # Make sure the dir exists 264 | if conf["task"]=="AEC": 265 | if not os.path.exists(conf["save"]["feat_dir"]+"/mic"): 266 | os.makedirs(conf["save"]["feat_dir"]+"/mic") 267 | if not os.path.exists(conf["save"]["feat_dir"]+"/ref"): 268 | os.makedirs(conf["save"]["feat_dir"]+"/ref") 269 | elif conf["task"]=="TSE": 270 | if not os.path.exists(conf["save"]["feat_dir"]+"/mic"): 271 | os.makedirs(conf["save"]["feat_dir"]+"/mic") 272 | if not os.path.exists(conf["save"]["feat_dir"]+"/ref"): 273 | os.makedirs(conf["save"]["feat_dir"]+"/ref") 274 | else: 275 | if not os.path.exists(conf["save"]["feat_dir"]): 276 | os.makedirs(conf["save"]["feat_dir"]) 277 | 278 | if not os.path.exists(conf["save"]["wav_dir"]): 279 | os.makedirs(conf["save"]["wav_dir"]) 280 | 281 | # Main of inference 282 | if_feat_too = conf["test"]["infer_feat_too"] 283 | 284 | origin_feat_dir = conf["save"]["feat_dir"] 285 | origin_wav_dir = conf["save"]["wav_dir"] 286 | 287 | last_feat_dir = origin_feat_dir 288 | last_wav_dir = origin_wav_dir 289 | 290 | for inference_time in range(conf["test"]["inference_time"]): 291 | # For multi-inference 292 | if inference_time > 0: 293 | feat_dir = origin_feat_dir + "inference" + str(inference_time) 294 | wav_dir = origin_wav_dir + "inference" + str(inference_time) 295 | else: 296 | feat_dir = origin_feat_dir 297 | wav_dir = origin_wav_dir 298 | 299 | if not os.path.exists(feat_dir): 300 | os.makedirs(feat_dir) 301 | if not os.path.exists(wav_dir): 302 | os.makedirs(wav_dir) 303 | 304 | with torch.no_grad(): 305 | # Extract WavLM features 306 | if if_feat_too ==True or inference_time>0: 307 | for egs in tqdm(data_reader): 308 | egs = load_obj(egs, device) 309 | 310 | if conf["task"]=="AEC" or conf["task"]=="TSE": 311 | if inference_time > 0: 312 | mic_path = last_wav_dir + '/' + egs["mic_name"] + ".wav" 313 | egs["mic"] = torch.from_numpy(get_firstchannel_read(mic_path).astype(np.float32)).unsqueeze(0).to(device) 314 | else: 315 | egs["mic"]=egs["mic"].contiguous() 316 | egs["ref"]=egs["ref"].contiguous() 317 | 318 | feat_mic = wavlm_feat(egs["mic"]) 319 | out_mic = feat_mic.detach().squeeze(0).cpu().numpy() 320 | 321 | if not os.path.exists(os.path.join(feat_dir, "mic")): 322 | os.makedirs(os.path.join(feat_dir, "mic")) 323 | np.save(os.path.join(feat_dir, "mic", egs["mic_name"]), out_mic) 324 | 325 | # For AEC and TSE, reference audio only need to extract feats at first time 326 | if inference_time == 0: 327 | feat_ref = wavlm_feat(egs["ref"]) 328 | out_ref = feat_ref.detach().squeeze(0).cpu().numpy() 329 | np.save(os.path.join(origin_feat_dir, "ref", egs["ref_name"]), out_ref) 330 | 331 | torch.cuda.empty_cache() 332 | 333 | else: 334 | if inference_time > 0: 335 | mix_path = last_wav_dir + '/' + egs["name"] + ".wav" 336 | 
egs["mix"] = torch.from_numpy(get_firstchannel_read(mix_path).astype(np.float32)).unsqueeze(0).to(device) 337 | else: 338 | egs["mix"]=egs["mix"].contiguous() 339 | 340 | feat = wavlm_feat(egs["mix"]) 341 | out = feat.detach().squeeze(0).cpu().numpy() 342 | np.save(os.path.join(feat_dir, egs["name"]), out) 343 | 344 | # Predict the clean tokens and token2wav 345 | for egs in tqdm(data_reader): 346 | egs = load_obj(egs, device) 347 | sr = 16000 348 | 349 | if conf["task"] == "AEC": 350 | # Get feat 351 | feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy" 352 | feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy" 353 | 354 | feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0) 355 | feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0) 356 | 357 | # For multi-inference 358 | if inference_time > 0: 359 | est = nnet(feat_mic) 360 | else: 361 | est = nnet(feat_mic, feat_ref) 362 | 363 | # Get tokens and token2wav 364 | max, max_indices_1 = torch.max(est[1], dim=1) 365 | recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() 366 | 367 | # Save the wav 368 | target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav") 369 | print(target_path) 370 | sf.write(target_path , recon_1, sr) 371 | 372 | elif conf["task"] == "TSE" : 373 | # Get feat 374 | feat_path_mic = os.path.join(feat_dir, "mic", egs["mic_name"]) + ".npy" 375 | feat_path_ref = os.path.join(origin_feat_dir, "ref", egs["ref_name"]) + ".npy" 376 | 377 | feat_mic = torch.from_numpy(np.load(feat_path_mic)).unsqueeze(0) 378 | feat_ref = torch.from_numpy(np.load(feat_path_ref)).unsqueeze(0) 379 | 380 | # Choose if keep the enroallment audio while multi-inference 381 | if_keep_ref = True 382 | 383 | if inference_time>0 and if_keep_ref== False: 384 | est = nnet(feat_mic) 385 | else: 386 | est = nnet(feat_mic, feat_ref) 387 | 388 | # Get tokens and token2wav 389 | max, max_indices_1 = torch.max(est[0], dim=1) 390 | recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() 391 | 392 | # Save the wav 393 | target_path = os.path.join(wav_dir, egs["mic_name"] + ".wav") 394 | print(target_path) 395 | sf.write(target_path , recon_1, sr) 396 | 397 | elif conf["task"] == "PLC": 398 | # Get feat 399 | feat_path = os.path.join(feat_dir, egs["name"]) + ".npy" 400 | feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0) 401 | 402 | # Get tokens and token2wav 403 | est = nnet(feat) 404 | max, max_indices_1 = torch.max(est[1], dim=1) 405 | recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() 406 | 407 | # Save the wav 408 | target_path = os.path.join(wav_dir, egs["name"] + ".wav") 409 | print(target_path) 410 | sf.write(target_path , recon_1, sr) 411 | 412 | elif conf["task"] == "SS": 413 | # Get feat 414 | feat_path = os.path.join(feat_dir, egs["name"]) + ".npy" 415 | feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0) 416 | 417 | # Separate the first speaker 418 | est = nnet(feat) 419 | max, max_indices_1 = torch.max(est[1], dim=1) 420 | recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() 421 | 422 | target_path_1 = os.path.join(wav_dir, egs["name"] + ".wav") 423 | sf.write(target_path_1 , recon_1, sr) 424 | 425 | # Separate the second speaker, SS need at least 2 inference time in config 426 | if inference_time > 0: 427 | origin_feat_path = os.path.join(origin_feat_dir, egs["name"]) + ".npy" 428 | origin_feat = 
torch.from_numpy(np.load(origin_feat_path)).unsqueeze(0) 429 | 430 | est2 = nnet(origin_feat, feat) 431 | max, max_indices_2 = torch.max(est2[1], dim=1) 432 | recon_2 = codec.token2wav(max_indices_2.unsqueeze(0)).squeeze().detach().cpu().numpy() 433 | 434 | if not os.path.exists(last_wav_dir + "s2"): 435 | os.makedirs(last_wav_dir + "s2") 436 | 437 | target_path_2 = os.path.join(last_wav_dir + "s2", egs["name"] + ".wav") 438 | sf.write(target_path_2 , recon_2, sr) 439 | 440 | else: 441 | # Get feat 442 | feat_path = os.path.join(feat_dir, egs["name"]) + ".npy" 443 | feat = torch.from_numpy(np.load(feat_path)).unsqueeze(0) 444 | 445 | # Get tokens and token2wav 446 | est = nnet(feat) 447 | max, max_indices_1 = torch.max(est[1], dim=1) 448 | recon_1 = codec.token2wav(max_indices_1.unsqueeze(0)).squeeze().detach().cpu().numpy() 449 | 450 | # Save the wav 451 | target_path = os.path.join(wav_dir, egs["name"] + ".wav") 452 | print(target_path) 453 | sf.write(target_path , recon_1, sr) 454 | 455 | # For next inference 456 | last_feat_dir = feat_dir 457 | last_wav_dir = wav_dir 458 | 459 | if __name__ == "__main__": 460 | parser = argparse.ArgumentParser( 461 | description = "Command to test separation model in Pytorch", 462 | formatter_class = argparse.ArgumentDefaultsHelpFormatter) 463 | parser.add_argument("-conf", 464 | type=str, 465 | required=True, 466 | help="Yaml configuration file for training") 467 | parser.add_argument("--backend", 468 | type=str, 469 | default="nccl", 470 | choices=["nccl", "gloo"]) 471 | args = parser.parse_args() 472 | # for nccl debug 473 | os.environ["NCCL_DEBUG"] = "INFO" 474 | run(args) 475 | -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 torchrun \ 2 | --nnodes=1 \ 3 | --nproc_per_node=1 \ 4 | --master_port=21547 \ 5 | inference.py \ 6 | -conf ./config/test.yml -------------------------------------------------------------------------------- /loader/__pycache__/datareader_fe.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/loader/__pycache__/datareader_fe.cpython-310.pyc -------------------------------------------------------------------------------- /loader/datareader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchaudio 3 | import torch 4 | 5 | def get_firstchannel_read(path, fs=16000): 6 | wave_data, sr = torchaudio.load(path) 7 | if sr != fs: 8 | wave_data = torchaudio.functional.resample(wave_data, sr, fs) 9 | if len(wave_data.shape) > 1: 10 | wave_data = wave_data[0,...] 
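    # convert the (first-channel) tensor back to a numpy array so extract_feature()
    # can normalize and reshape it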
11 | wave_data = wave_data.cpu().numpy() 12 | return wave_data 13 | 14 | def parse_scp(scp, path_list): 15 | with open(scp) as fid: 16 | for line in fid: 17 | tmp = line.strip().split() 18 | if len(tmp) > 1: 19 | path_list.append({"inputs": tmp[0], "duration": tmp[1]}) 20 | else: 21 | path_list.append({"inputs": tmp[0]}) 22 | 23 | class DataReader(object): 24 | def __init__(self, filename, sample_rate): 25 | self.file_list = [] 26 | self.sample_rate = sample_rate 27 | parse_scp(filename, self.file_list) 28 | 29 | def extract_feature(self, path): 30 | path = path["inputs"] 31 | name = path.split("/")[-1].split(".")[0] 32 | data = get_firstchannel_read(path, fs=self.sample_rate).astype(np.float32) 33 | max_norm = np.max(np.abs(data)) 34 | if max_norm == 0: 35 | max_norm = 1 36 | data = data / max_norm 37 | inputs = np.reshape(data, [1, data.shape[0]]) 38 | inputs = torch.from_numpy(inputs) 39 | 40 | egs = { 41 | "mix": inputs, 42 | "max_norm": max_norm, 43 | "name": name 44 | } 45 | return egs 46 | 47 | def __len__(self): 48 | return len(self.file_list) 49 | 50 | def __getitem__(self, index): 51 | return self.extract_feature(self.file_list[index]) 52 | 53 | def get_utt2spk(self, path): 54 | lines = open(path, "r").readlines() 55 | for line in lines: 56 | line = line.strip().split() 57 | utt_path, spk_id = line[0], line[1] 58 | self.utt2spk[utt_path] = spk_id 59 | 60 | def get_spk2utt(self, path): 61 | lines = open(path, "r").readlines() 62 | for line in lines: 63 | line = line.strip().split() 64 | utt_path, spk_id = line[0], line[1] 65 | self.spk2aux[spk_id] = utt_path 66 | -------------------------------------------------------------------------------- /loader/datareader_aec.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch as th 3 | import numpy as np 4 | import soundfile as sf 5 | 6 | import sys, os 7 | sys.path.append(os.path.dirname(__file__)) 8 | # from speex_linear.lp_or_tde import LP_or_TDE 9 | 10 | 11 | def audio(path, fs=16000): 12 | wave_data, sr = sf.read(path) 13 | if sr != fs: 14 | wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) 15 | return wave_data 16 | 17 | def get_firstchannel_read(path, fs=16000): 18 | wave_data, sr = sf.read(path) 19 | if sr != fs: 20 | if len(wave_data.shape) != 1: 21 | wave_data = wave_data.transpose((1, 0)) 22 | wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) 23 | if len(wave_data.shape) != 1: 24 | wave_data = wave_data.transpose((1, 0)) 25 | if len(wave_data.shape) > 1: 26 | wave_data = wave_data[:, 0] 27 | return wave_data 28 | 29 | def parse_scp(scp, path_list): 30 | with open(scp) as fid: 31 | for line in fid: 32 | tmp = line.strip().split() 33 | if len(tmp) > 1: 34 | path_list.append({"inputs": tmp[0], "duration": tmp[1]}) 35 | else: 36 | path_list.append({"inputs": tmp[0]}) 37 | 38 | class DataReaderAEC(object): 39 | def __init__(self, filename, sample_rate): #, aux_segment): # filename是不带id的待解码音频,noisy_id是带id的带解码音频,clean是带id的注册音频 40 | self.file_list = [] 41 | parse_scp(filename, self.file_list) 42 | self.sample_rate = sample_rate 43 | 44 | # self.aux_segment_length = aux_segment * sample_rate 45 | 46 | def extract_feature(self, path): 47 | mic_path = path["inputs"] 48 | utt_id = mic_path.split("/")[-1] 49 | mic_name = mic_path.split("/")[-1].split(".")[0] 50 | 51 | ref_path = mic_path.replace("mic.wav", "lpb.wav") 52 | ref_name = ref_path.split("/")[-1].split(".")[0] 53 | 54 | mic = get_firstchannel_read(mic_path, 
self.sample_rate).astype(np.float32) 55 | ref = get_firstchannel_read(ref_path, self.sample_rate).astype(np.float32) 56 | 57 | min_len = min(mic.shape[0], ref.shape[0]) 58 | mic = mic[:min_len] 59 | ref = ref[:min_len] 60 | 61 | inputs_mic = np.reshape(mic, [1, mic.shape[0]]) 62 | inputs_ref = np.reshape(ref, [1, ref.shape[0]]).astype(np.float32) 63 | 64 | 65 | inputs_mic = th.from_numpy(inputs_mic) 66 | inputs_ref = th.from_numpy(inputs_ref) 67 | 68 | # print(f'e: {inputs_e.shape}') 69 | # print(f'mic: {inputs_mic.shape}') 70 | # print(f'ref: {inputs_ref.shape}') 71 | 72 | egs = { 73 | "mic": inputs_mic, 74 | "ref": inputs_ref, 75 | "utt_id": utt_id, 76 | "mic_name": mic_name, 77 | "ref_name": ref_name 78 | # "max_norm": max_norm 79 | } 80 | return egs 81 | 82 | def __len__(self): 83 | return len(self.file_list) 84 | 85 | def __getitem__(self, index): 86 | return self.extract_feature(self.file_list[index]) 87 | -------------------------------------------------------------------------------- /loader/datareader_tse.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch as th 3 | import numpy as np 4 | import soundfile as sf 5 | 6 | import sys, os 7 | sys.path.append(os.path.dirname(__file__)) 8 | # from speex_linear.lp_or_tde import LP_or_TDE 9 | 10 | 11 | def audio(path, fs=16000): 12 | wave_data, sr = sf.read(path) 13 | if sr != fs: 14 | wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) 15 | return wave_data 16 | 17 | def get_firstchannel_read(path, fs=16000): 18 | wave_data, sr = sf.read(path) 19 | if sr != fs: 20 | if len(wave_data.shape) != 1: 21 | wave_data = wave_data.transpose((1, 0)) 22 | wave_data = librosa.resample(wave_data, orig_sr=sr, target_sr=fs) 23 | if len(wave_data.shape) != 1: 24 | wave_data = wave_data.transpose((1, 0)) 25 | if len(wave_data.shape) > 1: 26 | wave_data = wave_data[:, 0] 27 | return wave_data 28 | 29 | def parse_scp(scp, path_list): 30 | with open(scp) as fid: 31 | for line in fid: 32 | tmp = line.strip().split() 33 | if len(tmp) > 1: 34 | path_list.append({"inputs": tmp[0], "duration": tmp[1]}) 35 | else: 36 | path_list.append({"inputs": tmp[0]}) 37 | 38 | class DataReaderTSE(object): 39 | def __init__(self, filename, sample_rate): 40 | self.file_list = [] 41 | parse_scp(filename, self.file_list) 42 | self.sample_rate = sample_rate 43 | 44 | 45 | def extract_feature(self, path): 46 | mic_path = path["inputs"] 47 | utt_id = mic_path.split("/")[-1] 48 | mic_name = mic_path.split("/")[-1].split(".")[0] 49 | 50 | ref_path = mic_path.replace("noisy/", "enrol/") 51 | ref_name = ref_path.split("/")[-1].split(".")[0] 52 | 53 | mic = get_firstchannel_read(mic_path, self.sample_rate).astype(np.float32) 54 | ref = get_firstchannel_read(ref_path, self.sample_rate).astype(np.float32) 55 | 56 | if ref.shape[0] > mic.shape[0]: 57 | min_len = mic.shape[0] 58 | ref = ref[:min_len] 59 | 60 | inputs_mic = np.reshape(mic, [1, mic.shape[0]]).astype(np.float32) 61 | inputs_ref = np.reshape(ref, [1, ref.shape[0]]).astype(np.float32) 62 | 63 | inputs_mic = th.from_numpy(inputs_mic) 64 | inputs_ref = th.from_numpy(inputs_ref) 65 | 66 | # print(f'e: {inputs_e.shape}') 67 | # print(f'mic: {inputs_mic.shape}') 68 | # print(f'ref: {inputs_ref.shape}') 69 | 70 | egs = { 71 | "mic": inputs_mic, 72 | "ref": inputs_ref, 73 | "utt_id": utt_id, 74 | "mic_name": mic_name, 75 | "ref_name": ref_name 76 | # "max_norm": max_norm 77 | } 78 | return egs 79 | 80 | def __len__(self): 81 | return len(self.file_list) 
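    # __getitem__ returns the dict consumed by inference.py: "mic"/"ref" tensors plus
    # mic_name/ref_name, which name the cached .npy features and the enhanced output wavs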
82 | 83 | def __getitem__(self, index): 84 | return self.extract_feature(self.file_list[index]) 85 | -------------------------------------------------------------------------------- /nnet/WavLM.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import logging 12 | from typing import List, Optional, Tuple 13 | 14 | import sys,os 15 | sys.path.append(os.path.dirname(sys.path[0])) 16 | import numpy as np 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn import LayerNorm 22 | from nnet.modules import ( 23 | Fp32GroupNorm, 24 | Fp32LayerNorm, 25 | GradMultiply, 26 | MultiheadAttention, 27 | SamePad, 28 | init_bert_params, 29 | get_activation_fn, 30 | TransposeLast, 31 | GLU_Linear, 32 | ) 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def compute_mask_indices( 38 | shape: Tuple[int, int], 39 | padding_mask: Optional[torch.Tensor], 40 | mask_prob: float, 41 | mask_length: int, 42 | mask_type: str = "static", 43 | mask_other: float = 0.0, 44 | min_masks: int = 0, 45 | no_overlap: bool = False, 46 | min_space: int = 0, 47 | ) -> np.ndarray: 48 | """ 49 | Computes random mask spans for a given shape 50 | 51 | Args: 52 | shape: the the shape for which to compute masks. 53 | should be of size 2 where first element is batch size and 2nd is timesteps 54 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements 55 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by 56 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 57 | however due to overlaps, the actual number will be smaller (unless no_overlap is True) 58 | mask_type: how to compute mask lengths 59 | static = fixed size 60 | uniform = sample from uniform distribution [mask_other, mask_length*2] 61 | normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element 62 | poisson = sample from possion distribution with lambda = mask length 63 | min_masks: minimum number of masked spans 64 | no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping 65 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans 66 | """ 67 | 68 | bsz, all_sz = shape 69 | mask = np.full((bsz, all_sz), False) 70 | 71 | all_num_mask = int( 72 | # add a random number for probabilistic rounding 73 | mask_prob * all_sz / float(mask_length) 74 | + np.random.rand() 75 | ) 76 | 77 | all_num_mask = max(min_masks, all_num_mask) 78 | 79 | mask_idcs = [] 80 | for i in range(bsz): 81 | if padding_mask is not None: 82 | sz = all_sz - padding_mask[i].long().sum().item() 83 | num_mask = int( 84 | # add a random number for probabilistic rounding 85 | mask_prob * sz / float(mask_length) 86 | + np.random.rand() 87 | ) 88 | num_mask = max(min_masks, num_mask) 89 | else: 90 | sz = all_sz 91 | num_mask = all_num_mask 92 | 93 | if mask_type == "static": 94 | lengths = np.full(num_mask, mask_length) 95 | elif mask_type == "uniform": 96 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) 97 | elif mask_type == "normal": 98 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 99 | lengths = [max(1, int(round(x))) for x in lengths] 100 | elif mask_type == "poisson": 101 | lengths = np.random.poisson(mask_length, size=num_mask) 102 | lengths = [int(round(x)) for x in lengths] 103 | else: 104 | raise Exception("unknown mask selection " + mask_type) 105 | 106 | if sum(lengths) == 0: 107 | lengths[0] = min(mask_length, sz - 1) 108 | 109 | if no_overlap: 110 | mask_idc = [] 111 | 112 | def arrange(s, e, length, keep_length): 113 | span_start = np.random.randint(s, e - length) 114 | mask_idc.extend(span_start + i for i in range(length)) 115 | 116 | new_parts = [] 117 | if span_start - s - min_space >= keep_length: 118 | new_parts.append((s, span_start - min_space + 1)) 119 | if e - span_start - keep_length - min_space > keep_length: 120 | new_parts.append((span_start + length + min_space, e)) 121 | return new_parts 122 | 123 | parts = [(0, sz)] 124 | min_length = min(lengths) 125 | for length in sorted(lengths, reverse=True): 126 | lens = np.fromiter( 127 | (e - s if e - s >= length + min_space else 0 for s, e in parts), 128 | np.int, 129 | ) 130 | l_sum = np.sum(lens) 131 | if l_sum == 0: 132 | break 133 | probs = lens / np.sum(lens) 134 | c = np.random.choice(len(parts), p=probs) 135 | s, e = parts.pop(c) 136 | parts.extend(arrange(s, e, length, min_length)) 137 | mask_idc = np.asarray(mask_idc) 138 | else: 139 | min_len = min(lengths) 140 | if sz - min_len <= num_mask: 141 | min_len = sz - num_mask - 1 142 | 143 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 144 | 145 | mask_idc = np.asarray( 146 | [ 147 | mask_idc[j] + offset 148 | for j in range(len(mask_idc)) 149 | for offset in range(lengths[j]) 150 | ] 151 | ) 152 | 153 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 154 | 155 | min_len = min([len(m) for m in mask_idcs]) 156 | for i, mask_idc in enumerate(mask_idcs): 157 | if len(mask_idc) > min_len: 158 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 159 | mask[i, mask_idc] = True 160 | 161 | return mask 162 | 163 | 164 | class WavLMConfig: 165 | def __init__(self, cfg=None): 166 | self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) 167 | self.encoder_layers: int = 12 # num encoder layers in the transformer 168 | 169 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 170 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 171 | self.encoder_attention_heads: int = 12 # num encoder attention heads 172 | self.activation_fn: str = "gelu" # activation function to use 173 | 174 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 175 | self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] 176 | self.conv_bias: bool = False # include bias in conv encoder 177 | self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this 178 | 179 | self.normalize: bool = False # normalize input to have 0 mean and unit variance during training 180 | 181 | # dropouts 182 | self.dropout: float = 0.1 # dropout probability for the transformer 183 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 184 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 185 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 186 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 187 | self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) 188 | 189 | # masking 190 | self.mask_length: int = 10 # mask length 191 | self.mask_prob: float = 0.65 # probability of replacing a token with mask 192 | self.mask_selection: str = "static" # how to choose mask length 193 | self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh 194 | self.no_mask_overlap: bool = False # whether to allow masks to overlap 195 | self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) 196 | 197 | # channel masking 198 | self.mask_channel_length: int = 10 # length of the mask for features (channels) 199 | self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 200 | self.mask_channel_selection: str = "static" # how to choose mask length for channel masking 201 | self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices 202 | self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap 203 | self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) 204 | 205 | # positional embeddings 206 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings 207 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding 208 | 209 | # relative position embedding 210 | self.relative_position_embedding: bool = False # apply relative position embedding 211 | self.num_buckets: int = 320 # number of buckets for relative position embedding 212 | self.max_distance: int = 1280 # maximum distance for relative position embedding 213 | self.gru_rel_pos: bool = False # apply gated relative position embedding 214 | 215 | if cfg is not None: 216 | self.update(cfg) 217 | 218 | def update(self, cfg: dict): 219 | self.__dict__.update(cfg) 220 | 221 | 222 | class 
WavLM(nn.Module): 223 | def __init__( 224 | self, 225 | cfg: WavLMConfig, 226 | ) -> None: 227 | super().__init__() 228 | logger.info(f"WavLM Config: {cfg.__dict__}") 229 | 230 | self.cfg = cfg 231 | feature_enc_layers = eval(cfg.conv_feature_layers) 232 | self.embed = feature_enc_layers[-1][0] 233 | 234 | self.feature_extractor = ConvFeatureExtractionModel( 235 | conv_layers=feature_enc_layers, 236 | dropout=0.0, 237 | mode=cfg.extractor_mode, 238 | conv_bias=cfg.conv_bias, 239 | ) 240 | 241 | self.post_extract_proj = ( 242 | nn.Linear(self.embed, cfg.encoder_embed_dim) 243 | if self.embed != cfg.encoder_embed_dim 244 | else None 245 | ) 246 | 247 | self.mask_prob = cfg.mask_prob 248 | self.mask_selection = cfg.mask_selection 249 | self.mask_other = cfg.mask_other 250 | self.mask_length = cfg.mask_length 251 | self.no_mask_overlap = cfg.no_mask_overlap 252 | self.mask_min_space = cfg.mask_min_space 253 | 254 | self.mask_channel_prob = cfg.mask_channel_prob 255 | self.mask_channel_selection = cfg.mask_channel_selection 256 | self.mask_channel_other = cfg.mask_channel_other 257 | self.mask_channel_length = cfg.mask_channel_length 258 | self.no_mask_channel_overlap = cfg.no_mask_channel_overlap 259 | self.mask_channel_min_space = cfg.mask_channel_min_space 260 | 261 | self.dropout_input = nn.Dropout(cfg.dropout_input) 262 | self.dropout_features = nn.Dropout(cfg.dropout_features) 263 | 264 | self.feature_grad_mult = cfg.feature_grad_mult 265 | 266 | self.mask_emb = nn.Parameter( 267 | torch.FloatTensor(cfg.encoder_embed_dim).uniform_() 268 | ) 269 | 270 | self.encoder = TransformerEncoder(cfg) 271 | self.layer_norm = LayerNorm(self.embed) 272 | 273 | def apply_mask(self, x, padding_mask): 274 | B, T, C = x.shape 275 | if self.mask_prob > 0: 276 | mask_indices = compute_mask_indices( 277 | (B, T), 278 | padding_mask, 279 | self.mask_prob, 280 | self.mask_length, 281 | self.mask_selection, 282 | self.mask_other, 283 | min_masks=2, 284 | no_overlap=self.no_mask_overlap, 285 | min_space=self.mask_min_space, 286 | ) 287 | mask_indices = torch.from_numpy(mask_indices).to(x.device) 288 | x[mask_indices] = self.mask_emb 289 | else: 290 | mask_indices = None 291 | 292 | if self.mask_channel_prob > 0: 293 | mask_channel_indices = compute_mask_indices( 294 | (B, C), 295 | None, 296 | self.mask_channel_prob, 297 | self.mask_channel_length, 298 | self.mask_channel_selection, 299 | self.mask_channel_other, 300 | no_overlap=self.no_mask_channel_overlap, 301 | min_space=self.mask_channel_min_space, 302 | ) 303 | mask_channel_indices = ( 304 | torch.from_numpy(mask_channel_indices) 305 | .to(x.device) 306 | .unsqueeze(1) 307 | .expand(-1, T, -1) 308 | ) 309 | x[mask_channel_indices] = 0 310 | 311 | return x, mask_indices 312 | 313 | def forward_padding_mask( 314 | self, features: torch.Tensor, padding_mask: torch.Tensor, 315 | ) -> torch.Tensor: 316 | extra = padding_mask.size(1) % features.size(1) 317 | if extra > 0: 318 | padding_mask = padding_mask[:, :-extra] 319 | padding_mask = padding_mask.view( 320 | padding_mask.size(0), features.size(1), -1 321 | ) 322 | padding_mask = padding_mask.all(-1) 323 | return padding_mask 324 | 325 | def extract_features( 326 | self, 327 | source: torch.Tensor, 328 | padding_mask: Optional[torch.Tensor] = None, 329 | mask: bool = False, 330 | ret_conv: bool = False, 331 | output_layer: Optional[int] = None, 332 | ret_layer_results: bool = False, 333 | ): 334 | 335 | if self.feature_grad_mult > 0: 336 | features = self.feature_extractor(source) 337 | if 
self.feature_grad_mult != 1.0: 338 | features = GradMultiply.apply(features, self.feature_grad_mult) 339 | else: 340 | with torch.no_grad(): 341 | features = self.feature_extractor(source) 342 | 343 | features = features.transpose(1, 2) 344 | features = self.layer_norm(features) 345 | 346 | if padding_mask is not None: 347 | padding_mask = self.forward_padding_mask(features, padding_mask) 348 | 349 | if self.post_extract_proj is not None: 350 | features = self.post_extract_proj(features) 351 | 352 | features = self.dropout_input(features) 353 | 354 | if mask: 355 | x, mask_indices = self.apply_mask( 356 | features, padding_mask 357 | ) 358 | else: 359 | x = features 360 | 361 | # feature: (B, T, D), float 362 | # target: (B, T), long 363 | # x: (B, T, D), float 364 | # padding_mask: (B, T), bool 365 | # mask_indices: (B, T), bool 366 | x, layer_results = self.encoder( 367 | x, 368 | padding_mask=padding_mask, 369 | layer=None if output_layer is None else output_layer - 1 370 | ) 371 | 372 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} 373 | 374 | feature = res["features"] if ret_conv else res["x"] 375 | if ret_layer_results: 376 | feature = (feature, res["layer_results"]) 377 | return feature, res["padding_mask"] 378 | 379 | 380 | def long_term_modeling( 381 | self, 382 | source: torch.Tensor, 383 | padding_mask: Optional[torch.Tensor] = None, 384 | mask: bool = False, 385 | ret_conv: bool = False, 386 | output_layer: Optional[int] = None, 387 | ret_layer_results: bool = False, 388 | ): 389 | 390 | features = source.transpose(1, 2) 391 | features = self.layer_norm(features) 392 | 393 | if padding_mask is not None: 394 | padding_mask = self.forward_padding_mask(features, padding_mask) 395 | 396 | if self.post_extract_proj is not None: 397 | features = self.post_extract_proj(features) 398 | 399 | features = self.dropout_input(features) 400 | 401 | if mask: 402 | x, mask_indices = self.apply_mask( 403 | features, padding_mask 404 | ) 405 | else: 406 | x = features 407 | 408 | # feature: (B, T, D), float 409 | # target: (B, T), long 410 | # x: (B, T, D), float 411 | # padding_mask: (B, T), bool 412 | # mask_indices: (B, T), bool 413 | x, layer_results = self.encoder( 414 | x, 415 | padding_mask=padding_mask, 416 | layer=None if output_layer is None else output_layer - 1 417 | ) 418 | 419 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} 420 | 421 | feature = res["features"] if ret_conv else res["x"] 422 | if ret_layer_results: 423 | feature = (feature, res["layer_results"]) 424 | return feature, res["padding_mask"] 425 | 426 | 427 | 428 | class ConvFeatureExtractionModel(nn.Module): 429 | def __init__( 430 | self, 431 | conv_layers: List[Tuple[int, int, int]], 432 | dropout: float = 0.0, 433 | mode: str = "default", 434 | conv_bias: bool = False, 435 | conv_type: str = "default" 436 | ): 437 | super().__init__() 438 | 439 | assert mode in {"default", "layer_norm"} 440 | 441 | def block( 442 | n_in, 443 | n_out, 444 | k, 445 | stride, 446 | is_layer_norm=False, 447 | is_group_norm=False, 448 | conv_bias=False, 449 | ): 450 | def make_conv(): 451 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) 452 | nn.init.kaiming_normal_(conv.weight) 453 | return conv 454 | 455 | assert ( 456 | is_layer_norm and is_group_norm 457 | ) == False, "layer norm and group norm are exclusive" 458 | 459 | if is_layer_norm: 460 | return nn.Sequential( 461 | make_conv(), 462 | nn.Dropout(p=dropout), 
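                # (editor's note, not in the original file) `dim` used by the norm layers in this
                # branch is not a parameter of block(); it is captured from the enclosing __init__
                # scope, where it equals n_out at the moment block() is called inside the layer
                # loop below. This mirrors the upstream fairseq / wav2vec 2.0 implementation.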
463 | nn.Sequential( 464 | TransposeLast(), 465 | Fp32LayerNorm(dim, elementwise_affine=True), 466 | TransposeLast(), 467 | ), 468 | nn.GELU(), 469 | ) 470 | elif is_group_norm: 471 | return nn.Sequential( 472 | make_conv(), 473 | nn.Dropout(p=dropout), 474 | Fp32GroupNorm(dim, dim, affine=True), 475 | nn.GELU(), 476 | ) 477 | else: 478 | return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) 479 | 480 | self.conv_type = conv_type 481 | if self.conv_type == "default": 482 | in_d = 1 483 | self.conv_layers = nn.ModuleList() 484 | for i, cl in enumerate(conv_layers): 485 | assert len(cl) == 3, "invalid conv definition: " + str(cl) 486 | (dim, k, stride) = cl 487 | 488 | self.conv_layers.append( 489 | block( 490 | in_d, 491 | dim, 492 | k, 493 | stride, 494 | is_layer_norm=mode == "layer_norm", 495 | is_group_norm=mode == "default" and i == 0, 496 | conv_bias=conv_bias, 497 | ) 498 | ) 499 | in_d = dim 500 | elif self.conv_type == "conv2d": 501 | in_d = 1 502 | self.conv_layers = nn.ModuleList() 503 | for i, cl in enumerate(conv_layers): 504 | assert len(cl) == 3 505 | (dim, k, stride) = cl 506 | 507 | self.conv_layers.append( 508 | torch.nn.Conv2d(in_d, dim, k, stride) 509 | ) 510 | self.conv_layers.append(torch.nn.ReLU()) 511 | in_d = dim 512 | elif self.conv_type == "custom": 513 | in_d = 1 514 | idim = 80 515 | self.conv_layers = nn.ModuleList() 516 | for i, cl in enumerate(conv_layers): 517 | assert len(cl) == 3 518 | (dim, k, stride) = cl 519 | self.conv_layers.append( 520 | torch.nn.Conv2d(in_d, dim, k, stride, padding=1) 521 | ) 522 | self.conv_layers.append( 523 | torch.nn.LayerNorm([dim, idim]) 524 | ) 525 | self.conv_layers.append(torch.nn.ReLU()) 526 | in_d = dim 527 | if (i + 1) % 2 == 0: 528 | self.conv_layers.append( 529 | torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) 530 | ) 531 | idim = int(math.ceil(idim / 2)) 532 | else: 533 | pass 534 | 535 | def forward(self, x, mask=None): 536 | 537 | # BxT -> BxCxT 538 | x = x.unsqueeze(1) 539 | if self.conv_type == "custom": 540 | for conv in self.conv_layers: 541 | if isinstance(conv, nn.LayerNorm): 542 | x = x.transpose(1, 2) 543 | x = conv(x).transpose(1, 2) 544 | else: 545 | x = conv(x) 546 | x = x.transpose(2, 3).contiguous() 547 | x = x.view(x.size(0), -1, x.size(-1)) 548 | else: 549 | for conv in self.conv_layers: 550 | x = conv(x) 551 | if self.conv_type == "conv2d": 552 | b, c, t, f = x.size() 553 | x = x.transpose(2, 3).contiguous().view(b, c * f, t) 554 | return x 555 | 556 | 557 | class TransformerEncoder(nn.Module): 558 | def __init__(self, args): 559 | super().__init__() 560 | 561 | self.dropout = args.dropout 562 | self.embedding_dim = args.encoder_embed_dim 563 | 564 | self.pos_conv = nn.Conv1d( 565 | self.embedding_dim, 566 | self.embedding_dim, 567 | kernel_size=args.conv_pos, 568 | padding=args.conv_pos // 2, 569 | groups=args.conv_pos_groups, 570 | ) 571 | dropout = 0 572 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) 573 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std) 574 | nn.init.constant_(self.pos_conv.bias, 0) 575 | 576 | self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) 577 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) 578 | 579 | if hasattr(args, "relative_position_embedding"): 580 | self.relative_position_embedding = args.relative_position_embedding 581 | self.num_buckets = args.num_buckets 582 | self.max_distance = args.max_distance 583 | else: 584 | self.relative_position_embedding = False 
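            # (editor's note) Configs that predate relative position embeddings fall back here;
            # the model then relies only on the convolutional positional embedding (pos_conv)
            # defined above.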
585 | self.num_buckets = 0 586 | self.max_distance = 0 587 | 588 | self.layers = nn.ModuleList( 589 | [ 590 | TransformerSentenceEncoderLayer( 591 | embedding_dim=self.embedding_dim, 592 | ffn_embedding_dim=args.encoder_ffn_embed_dim, 593 | num_attention_heads=args.encoder_attention_heads, 594 | dropout=self.dropout, 595 | attention_dropout=args.attention_dropout, 596 | activation_dropout=args.activation_dropout, 597 | activation_fn=args.activation_fn, 598 | layer_norm_first=args.layer_norm_first, 599 | has_relative_attention_bias=(self.relative_position_embedding and i == 0), 600 | num_buckets=self.num_buckets, 601 | max_distance=self.max_distance, 602 | gru_rel_pos=args.gru_rel_pos, 603 | ) 604 | for i in range(args.encoder_layers) 605 | ] 606 | ) 607 | 608 | self.layer_norm_first = args.layer_norm_first 609 | self.layer_norm = LayerNorm(self.embedding_dim) 610 | self.layerdrop = args.encoder_layerdrop 611 | 612 | self.apply(init_bert_params) 613 | 614 | def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): 615 | x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) 616 | 617 | if self.layer_norm_first and layer is None: 618 | x = self.layer_norm(x) 619 | 620 | return x, layer_results 621 | 622 | def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): 623 | 624 | if padding_mask is not None: 625 | x[padding_mask] = 0 626 | 627 | y = x.transpose(1, 2).clone() 628 | x_conv = self.pos_conv(y) 629 | x_conv = x_conv.transpose(1, 2) 630 | x += x_conv 631 | 632 | if not self.layer_norm_first: 633 | x = self.layer_norm(x) 634 | 635 | x = F.dropout(x, p=self.dropout, training=self.training) 636 | 637 | # B x T x C -> T x B x C 638 | x = x.transpose(0, 1) 639 | 640 | layer_results = [] 641 | z = None 642 | if tgt_layer is not None: 643 | layer_results.append((x, z)) 644 | r = None 645 | pos_bias = None 646 | for i, layer in enumerate(self.layers): 647 | dropout_probability = np.random.random() 648 | if not self.training or (dropout_probability > self.layerdrop): 649 | x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, 650 | self_attn_mask=streaming_mask, pos_bias=pos_bias) 651 | if tgt_layer is not None: 652 | layer_results.append((x, z)) 653 | if i == tgt_layer: 654 | r = x 655 | break 656 | 657 | if r is not None: 658 | x = r 659 | 660 | # T x B x C -> B x T x C 661 | x = x.transpose(0, 1) 662 | 663 | return x, layer_results 664 | 665 | 666 | class TransformerSentenceEncoderLayer(nn.Module): 667 | """ 668 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 669 | models. 
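
    A minimal usage sketch (editor's addition; the values below are illustrative
    only). The layer expects time-major input of shape (T, B, C), which is what
    TransformerEncoder.extract_features passes in:

        layer = TransformerSentenceEncoderLayer(
            embedding_dim=768,
            ffn_embedding_dim=3072,
            num_attention_heads=12,
        )
        x = torch.randn(100, 1, 768)        # (T, B, C)
        out, attn, pos_bias = layer(x)      # out keeps the (T, B, C) shape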
670 | """ 671 | 672 | def __init__( 673 | self, 674 | embedding_dim: float = 768, 675 | ffn_embedding_dim: float = 3072, 676 | num_attention_heads: float = 8, 677 | dropout: float = 0.1, 678 | attention_dropout: float = 0.1, 679 | activation_dropout: float = 0.1, 680 | activation_fn: str = "relu", 681 | layer_norm_first: bool = False, 682 | has_relative_attention_bias: bool = False, 683 | num_buckets: int = 0, 684 | max_distance: int = 0, 685 | rescale_init: bool = False, 686 | gru_rel_pos: bool = False, 687 | ) -> None: 688 | 689 | super().__init__() 690 | # Initialize parameters 691 | self.embedding_dim = embedding_dim 692 | self.dropout = dropout 693 | self.activation_dropout = activation_dropout 694 | 695 | # Initialize blocks 696 | self.activation_name = activation_fn 697 | self.activation_fn = get_activation_fn(activation_fn) 698 | self.self_attn = MultiheadAttention( 699 | self.embedding_dim, 700 | num_attention_heads, 701 | dropout=attention_dropout, 702 | self_attention=True, 703 | has_relative_attention_bias=has_relative_attention_bias, 704 | num_buckets=num_buckets, 705 | max_distance=max_distance, 706 | rescale_init=rescale_init, 707 | gru_rel_pos=gru_rel_pos, 708 | ) 709 | 710 | self.dropout1 = nn.Dropout(dropout) 711 | self.dropout2 = nn.Dropout(self.activation_dropout) 712 | self.dropout3 = nn.Dropout(dropout) 713 | 714 | self.layer_norm_first = layer_norm_first 715 | 716 | # layer norm associated with the self attention layer 717 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim) 718 | 719 | if self.activation_name == "glu": 720 | self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") 721 | else: 722 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 723 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 724 | 725 | # layer norm associated with the position wise feed-forward NN 726 | self.final_layer_norm = LayerNorm(self.embedding_dim) 727 | 728 | def forward( 729 | self, 730 | x: torch.Tensor, 731 | self_attn_mask: torch.Tensor = None, 732 | self_attn_padding_mask: torch.Tensor = None, 733 | need_weights: bool = False, 734 | pos_bias=None 735 | ): 736 | """ 737 | LayerNorm is applied either before or after the self-attention/ffn 738 | modules similar to the original Transformer imlementation. 
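        (Editor's note) Concretely: with layer_norm_first=True the pre-norm path is
        taken (norm -> attention -> residual, then norm -> FFN -> residual); with the
        default layer_norm_first=False the post-norm path is taken (attention ->
        residual -> norm, then FFN -> residual -> norm).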
739 | """ 740 | residual = x 741 | 742 | if self.layer_norm_first: 743 | x = self.self_attn_layer_norm(x) 744 | x, attn, pos_bias = self.self_attn( 745 | query=x, 746 | key=x, 747 | value=x, 748 | key_padding_mask=self_attn_padding_mask, 749 | need_weights=False, 750 | attn_mask=self_attn_mask, 751 | position_bias=pos_bias 752 | ) 753 | x = self.dropout1(x) 754 | x = residual + x 755 | 756 | residual = x 757 | x = self.final_layer_norm(x) 758 | if self.activation_name == "glu": 759 | x = self.fc1(x) 760 | else: 761 | x = self.activation_fn(self.fc1(x)) 762 | x = self.dropout2(x) 763 | x = self.fc2(x) 764 | x = self.dropout3(x) 765 | x = residual + x 766 | else: 767 | x, attn, pos_bias = self.self_attn( 768 | query=x, 769 | key=x, 770 | value=x, 771 | key_padding_mask=self_attn_padding_mask, 772 | need_weights=need_weights, 773 | attn_mask=self_attn_mask, 774 | position_bias=pos_bias 775 | ) 776 | 777 | x = self.dropout1(x) 778 | x = residual + x 779 | 780 | x = self.self_attn_layer_norm(x) 781 | 782 | residual = x 783 | if self.activation_name == "glu": 784 | x = self.fc1(x) 785 | else: 786 | x = self.activation_fn(self.fc1(x)) 787 | x = self.dropout2(x) 788 | x = self.fc2(x) 789 | x = self.dropout3(x) 790 | x = residual + x 791 | x = self.final_layer_norm(x) 792 | 793 | return x, attn, pos_bias 794 | -------------------------------------------------------------------------------- /nnet/__pycache__/WavLM.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/nnet/__pycache__/WavLM.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/embedding.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/nnet/__pycache__/embedding.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/llama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/nnet/__pycache__/llama.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/nnet/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/llase.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import sys,os 7 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 8 | 9 | from typing import Union, Optional 10 | from transformers import LlamaConfig, LlamaForCausalLM 11 | 12 | NUM_AUDIO_TOKENS = 65536 # Codebook size 13 | 14 | class LLM_AR(nn.Module): 15 | def __init__( 16 | self, 17 | d_model: int, 18 | nhead: int, 19 | num_layers: int 20 | ): 21 | super().__init__() 22 | self.d_model = d_model 23 | 24 | self.audio_linear_y = nn.Linear(1024, d_model) 25 | self.audio_linear_x = nn.Linear(1024, d_model) 26 | 27 | self.Llama_config = LlamaConfig( 28 | 
hidden_size=d_model*2, 29 | intermediate_size=d_model * 4, 30 | num_attention_heads=nhead, 31 | num_hidden_layers=num_layers, 32 | dropout_rate=0.1, 33 | attention_dropout=0.1, 34 | is_decoder=True, 35 | use_cache=True 36 | ) 37 | 38 | self.llama= LlamaForCausalLM(config=self.Llama_config) 39 | self.predict_layer_x = nn.Linear(2*d_model, NUM_AUDIO_TOKENS) 40 | self.predict_layer_y = nn.Linear(2*d_model, NUM_AUDIO_TOKENS) 41 | 42 | def forward( 43 | self, 44 | y: torch.Tensor, 45 | x: Union[torch.Tensor, None] = None, 46 | ) -> torch.Tensor: 47 | # y = y.transpose(1,2) # if codec input use this transpose 48 | 49 | if x is None: 50 | x = torch.zeros_like(y) 51 | elif x.dim() == 2: 52 | x = x.unsqueeze(-1) 53 | x = x.expand_as(y) 54 | 55 | 56 | y_emb = self.audio_linear_y(y) # [B, T, D] 57 | x_emb = self.audio_linear_x(x) # [B, T, D] 58 | 59 | if x_emb.shape[1] < y_emb.shape[1]: 60 | pad_length = y_emb.shape[1] - x_emb.shape[1] 61 | x_emb= F.pad(x_emb, (0, 0, 0, pad_length), mode='constant', value=0) 62 | 63 | if y_emb.shape[1] < x_emb.shape[1]: 64 | pad_length = x_emb.shape[1] - y_emb.shape[1] 65 | y_emb= F.pad(y_emb, (0, 0, 0, pad_length), mode='constant', value=0) 66 | 67 | y_emb = torch.concat([x_emb, y_emb], dim = -1) # [B, T_y, D*2] 68 | 69 | outputs = self.llama(inputs_embeds = y_emb, output_hidden_states=True) 70 | 71 | dec = outputs.hidden_states[-1] # [B, T_y, D*2] 72 | 73 | logits_y = self.predict_layer_y(dec) # [B, T, NUM_AUDIO_TOKENS] 74 | logits_x = self.predict_layer_x(dec) 75 | 76 | logits_y = logits_y.transpose(-1, -2) # [B, NUM_AUDIO_TOKENS, T] 77 | logits_x = logits_x.transpose(-1, -2) 78 | 79 | return logits_y, logits_x 80 | 81 | if __name__=="__main__": 82 | # Simple test 83 | model = LLM_AR(d_model=1024, nhead=8, num_layers=16) 84 | ce_loss = nn.CrossEntropyLoss() 85 | 86 | y = torch.randn([1,199,1024]) 87 | x = torch.randn([1,99,1024]) 88 | label = torch.from_numpy(np.random.randint(0, 300, size=[2,1,199])) 89 | 90 | total_params = sum(p.numel() for p in model.parameters()) 91 | 92 | print(f"Total Params: {total_params}") 93 | 94 | logits = model(y) 95 | print(logits[0].shape) 96 | print(logits[1].shape) 97 | 98 | logits = model(y,x) 99 | print(logits[0].shape) 100 | print(logits[1].shape) 101 | 102 | logits = model(y,y) 103 | print(logits[0].shape) 104 | print(logits[1].shape) 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.1.0 2 | aiohttp==3.10.5 3 | datasets==3.0.1 4 | deepspeed==0.15.1 5 | einops==0.8.0 6 | huggingface-hub==0.25.1 7 | numpy==1.23.5 8 | pandas==2.2.3 9 | pillow==10.4.0 10 | scikit-learn==1.5.2 11 | scipy==1.13.1 12 | torch==2.4.1 13 | torchaudio==2.4.1 14 | torchmetrics==1.4.2 15 | transformers==4.47.1 16 | tqdm==4.66.5 17 | torchtune==0.3.1 18 | triton==3.0.0 19 | torchao==0.5.0 -------------------------------------------------------------------------------- /vq/__init__.py: -------------------------------------------------------------------------------- 1 | from vq.codec_encoder import CodecEncoder 2 | from vq.codec_decoder import CodecDecoder 3 | from vq.codec_decoder_vocos import CodecDecoderVocos 4 | from vq.codec_encoder import CodecEncoder_Transformer,CodecEncoder_only_Transformer -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/activations.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/activations.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/activations.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/activations.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/activations.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/activations.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/blocks.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/blocks.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/bs_roformer5.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/bs_roformer5.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/bs_roformer5.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/bs_roformer5.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-312.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_encoder.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/codec_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/module.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/module.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/module.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/module.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/module.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/module.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/residual_vq.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/residual_vq.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/residual_vq.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/residual_vq.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/unet.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/unet.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/unet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/__pycache__/unet.cpython-39.pyc -------------------------------------------------------------------------------- /vq/activations.py: -------------------------------------------------------------------------------- 1 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | from torch import nn, sin, pow 6 | from torch.nn import Parameter 7 | 8 | 9 | class Snake(nn.Module): 10 | ''' 11 | Implementation of a sine-based periodic activation function 12 | Shape: 13 | - Input: (B, C, T) 14 | - Output: (B, C, T), same shape as the input 15 | Parameters: 16 | - alpha - trainable parameter 17 | References: 18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 19 | https://arxiv.org/abs/2006.08195 20 | Examples: 21 | >>> a1 = snake(256) 22 | >>> x = torch.randn(256) 23 | >>> x = a1(x) 24 | ''' 25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 26 | ''' 27 | Initialization. 
28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 33 | ''' 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | ''' 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | ''' 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | ''' 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | ''' 79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 80 | ''' 81 | Initialization. 82 | INPUT: 83 | - in_features: shape of the input 84 | - alpha - trainable parameter that controls frequency 85 | - beta - trainable parameter that controls magnitude 86 | alpha is initialized to 1 by default, higher values = higher-frequency. 87 | beta is initialized to 1 by default, higher values = higher-magnitude. 88 | alpha will be trained along with the rest of your model. 89 | ''' 90 | super(SnakeBeta, self).__init__() 91 | self.in_features = in_features 92 | 93 | # initialize alpha 94 | self.alpha_logscale = alpha_logscale 95 | if self.alpha_logscale: # log scale alphas initialized to zeros 96 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 97 | self.beta = Parameter(torch.zeros(in_features) * alpha) 98 | else: # linear scale alphas initialized to ones 99 | self.alpha = Parameter(torch.ones(in_features) * alpha) 100 | self.beta = Parameter(torch.ones(in_features) * alpha) 101 | 102 | self.alpha.requires_grad = alpha_trainable 103 | self.beta.requires_grad = alpha_trainable 104 | 105 | self.no_div_by_zero = 0.000000001 106 | 107 | def forward(self, x): 108 | ''' 109 | Forward pass of the function. 110 | Applies the function to the input elementwise. 
111 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 112 | ''' 113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 114 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 115 | if self.alpha_logscale: 116 | alpha = torch.exp(alpha) 117 | beta = torch.exp(beta) 118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 119 | 120 | return x -------------------------------------------------------------------------------- /vq/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/act.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/act.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/act.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/act.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/act.cpython-38.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/act.cpython-39.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/filter.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc 
-------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/resample.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE-G1/67918c52c9269921b7ba7434a57a8c043371841e/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
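# (Editor's note -- illustrative sketch, not part of the original file.)
# Activation1d applies a pointwise nonlinearity between an upsampler and a
# downsampler, so the nonlinearity acts at a higher sample rate and its harmonics
# are low-pass filtered away before decimation (alias suppression). Assuming a
# [B, C, T] input and this module's own imports:
#
#     act = Activation1d(activation=nn.SiLU(), up_ratio=2, down_ratio=2)
#     y = act(torch.randn(1, 64, 16000))   # output keeps the [B, C, T] shape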
3 | 4 | import torch.nn as nn 5 | from .resample import UpSample1d, DownSample1d 6 | 7 | 8 | class Activation1d(nn.Module): 9 | def __init__(self, 10 | activation, 11 | up_ratio: int = 2, 12 | down_ratio: int = 2, 13 | up_kernel_size: int = 12, 14 | down_kernel_size: int = 12): 15 | super().__init__() 16 | self.up_ratio = up_ratio 17 | self.down_ratio = down_ratio 18 | self.act = activation 19 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 20 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 21 | 22 | # x: [B,C,T] 23 | def forward(self, x): 24 | x = self.upsample(x) 25 | x = self.act(x) 26 | x = self.downsample(x) 27 | 28 | return x -------------------------------------------------------------------------------- /vq/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 
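        # (editor's note) cutoff and half_width are normalized frequencies where 0.5
        # corresponds to the Nyquist rate; DownSample1d (in resample.py) constructs this
        # filter with cutoff = 0.5 / ratio so content above the post-decimation Nyquist
        # rate is attenuated before striding.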
70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out -------------------------------------------------------------------------------- /vq/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx -------------------------------------------------------------------------------- /vq/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Sequence, Type, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]] 8 | 9 | 10 | class FeedForwardModule(nn.Module): 11 | 12 | def __init__(self) -> None: 13 | super().__init__() 14 | self.net = None 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | return self.net(x) 18 | 19 | 20 | class Residual(nn.Module): 21 | 22 | def __init__(self, module: nn.Module) -> None: 23 | super().__init__() 24 | self.module = module 25 | 26 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 27 | return self.module(x) + x 28 | 29 | 30 | class DilatedConvolutionalUnit(FeedForwardModule): 31 | 32 | def __init__( 33 | self, 34 | hidden_dim: int, 35 | dilation: int, 36 | kernel_size: int, 37 | activation: ModuleFactory, 38 | normalization: Callable[[nn.Module], 39 | nn.Module] = lambda x: x) -> None: 40 | super().__init__() 41 | self.net = nn.Sequential( 42 | activation(), 43 | normalization( 44 | nn.Conv1d( 45 | in_channels=hidden_dim, 46 | out_channels=hidden_dim, 47 | kernel_size=kernel_size, 48 | dilation=dilation, 49 | padding=((kernel_size - 1) * dilation) // 2, 50 | )), 51 | activation(), 52 | nn.Conv1d(in_channels=hidden_dim, 53 | out_channels=hidden_dim, 54 | kernel_size=1), 55 | ) 56 | 57 | 58 | class UpsamplingUnit(FeedForwardModule): 59 | 60 | def __init__( 61 | self, 62 | input_dim: int, 63 | output_dim: int, 64 | stride: int, 65 | activation: ModuleFactory, 66 | normalization: Callable[[nn.Module], 67 | nn.Module] = lambda x: x) -> None: 68 | super().__init__() 69 | self.net = nn.Sequential( 70 | activation(), 71 | normalization( 72 | nn.ConvTranspose1d( 73 | in_channels=input_dim, 74 | out_channels=output_dim, 75 | kernel_size=2 * stride, 76 | stride=stride, 77 | padding=stride // 2+ stride % 2, 78 | output_padding=1 if stride % 2 != 0 else 0 79 | ))) 80 | 81 | 82 | class DownsamplingUnit(FeedForwardModule): 83 | 84 | def __init__( 85 | self, 86 | input_dim: int, 87 | output_dim: int, 88 | stride: int, 89 | activation: ModuleFactory, 90 | normalization: Callable[[nn.Module], 91 | nn.Module] = lambda x: x) -> None: 92 | super().__init__() 93 | self.net = nn.Sequential( 94 | activation(), 95 | normalization( 96 | nn.Conv1d( 97 | in_channels=input_dim, 98 | out_channels=output_dim, 99 | kernel_size=2 * stride, 100 | stride=stride, 101 | padding= stride // 2+ stride % 2, 102 | 103 | ))) 104 | 105 | 106 | class DilatedResidualEncoder(FeedForwardModule): 107 | 108 | def __init__( 109 | self, 110 | capacity: int, 111 | dilated_unit: Type[DilatedConvolutionalUnit], 112 | downsampling_unit: Type[DownsamplingUnit], 113 | ratios: Sequence[int], 114 | dilations: Union[Sequence[int], Sequence[Sequence[int]]], 115 | pre_network_conv: Type[nn.Conv1d], 116 | post_network_conv: Type[nn.Conv1d], 117 | normalization: Callable[[nn.Module], 118 | nn.Module] = lambda x: x) -> None: 119 | super().__init__() 120 | channels = capacity * 2**np.arange(len(ratios) + 1) 121 | 122 | dilations_list = self.normalize_dilations(dilations, ratios) 123 | 124 | net = [normalization(pre_network_conv(out_channels=channels[0]))] 125 | 126 | for ratio, dilations, input_dim, output_dim in zip( 127 | ratios, dilations_list, channels[:-1], channels[1:]): 128 | for dilation in dilations: 129 | net.append(Residual(dilated_unit(input_dim, dilation))) 130 | net.append(downsampling_unit(input_dim, output_dim, ratio)) 131 | 132 | net.append(post_network_conv(in_channels=output_dim)) 133 | 134 | self.net = nn.Sequential(*net) 135 | 136 | @staticmethod 137 | def normalize_dilations(dilations: Union[Sequence[int], 138 | Sequence[Sequence[int]]], 139 | ratios: Sequence[int]): 140 | if isinstance(dilations[0], int): 141 | dilations = [dilations for _ in ratios] 142 | return dilations 143 | 144 | 145 | class DilatedResidualDecoder(FeedForwardModule): 146 | 147 | def __init__( 148 | self, 149 | capacity: int, 150 | dilated_unit: Type[DilatedConvolutionalUnit], 151 | upsampling_unit: Type[UpsamplingUnit], 152 | ratios: Sequence[int], 153 | dilations: Union[Sequence[int], Sequence[Sequence[int]]], 154 | 
pre_network_conv: Type[nn.Conv1d], 155 | post_network_conv: Type[nn.Conv1d], 156 | normalization: Callable[[nn.Module], 157 | nn.Module] = lambda x: x) -> None: 158 | super().__init__() 159 | channels = capacity * 2**np.arange(len(ratios) + 1) 160 | channels = channels[::-1] 161 | 162 | dilations_list = self.normalize_dilations(dilations, ratios) 163 | dilations_list = dilations_list[::-1] 164 | 165 | net = [pre_network_conv(out_channels=channels[0])] 166 | 167 | for ratio, dilations, input_dim, output_dim in zip( 168 | ratios, dilations_list, channels[:-1], channels[1:]): 169 | net.append(upsampling_unit(input_dim, output_dim, ratio)) 170 | for dilation in dilations: 171 | net.append(Residual(dilated_unit(output_dim, dilation))) 172 | 173 | net.append(normalization(post_network_conv(in_channels=output_dim))) 174 | 175 | self.net = nn.Sequential(*net) 176 | 177 | @staticmethod 178 | def normalize_dilations(dilations: Union[Sequence[int], 179 | Sequence[Sequence[int]]], 180 | ratios: Sequence[int]): 181 | if isinstance(dilations[0], int): 182 | dilations = [dilations for _ in ratios] 183 | return dilations -------------------------------------------------------------------------------- /vq/bs_roformer5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Module, ModuleList 5 | import torchaudio 6 | from einops import rearrange 7 | import numpy as np 8 | # from rotary_embedding_torch import RotaryEmbedding 9 | 10 | from torchtune.modules import RotaryPositionalEmbeddings 11 | 12 | 13 | 14 | class RMSNorm(torch.nn.Module): 15 | def __init__(self, dim: int, eps: float = 1e-6): 16 | r"""https://github.com/meta-llama/llama/blob/main/llama/model.py""" 17 | super().__init__() 18 | self.eps = eps 19 | self.weight = nn.Parameter(torch.ones(dim)) 20 | 21 | def forward(self, x): 22 | norm_x = torch.mean(x ** 2, dim=-1, keepdim=True) 23 | output = x * torch.rsqrt(norm_x + self.eps) * self.weight 24 | return output 25 | 26 | 27 | 28 | class MLP(nn.Module): 29 | def __init__(self, dim: int) -> None: 30 | super().__init__() 31 | 32 | self.fc1 = nn.Linear(dim, 4 * dim, bias=False) 33 | self.silu = nn.SiLU() 34 | self.fc2 = nn.Linear(4 * dim, dim, bias=False) 35 | 36 | def forward(self, x): 37 | x = self.fc1(x) 38 | x = self.silu(x) 39 | x = self.fc2(x) 40 | return x 41 | 42 | 43 | class Attention(nn.Module): 44 | 45 | def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): 46 | super().__init__() 47 | 48 | assert dim % n_heads == 0 49 | 50 | self.n_heads = n_heads 51 | self.dim = dim 52 | self.rotary_embed = rotary_embed 53 | 54 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 55 | assert self.flash, "Must have flash attention." 
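# Note (illustrative): c_attn below packs the query, key and value projections
# into a single Linear of width 3 * dim; forward() splits the result with einops
# into (b, n_heads, t, head_dim) and applies the rotary embedding to q and k
# before scaled_dot_product_attention. For example, dim=1024 with n_heads=8
# gives head_dim=128, matching RotaryPositionalEmbeddings(dim=128) as built in
# the __main__ block at the bottom of this file.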
56 | 57 | self.c_attn = nn.Linear(dim, 3 * dim, bias=False) 58 | self.c_proj = nn.Linear(dim, dim, bias=False) 59 | 60 | def forward(self, x): 61 | r""" 62 | Args: 63 | x: (b, t, h*d) 64 | 65 | Constants: 66 | b: batch_size 67 | t: time steps 68 | r: 3 69 | h: heads_num 70 | d: heads_dim 71 | """ 72 | B, T, C = x.size() 73 | 74 | q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads) 75 | # q, k, v: (b, h, t, d) 76 | 77 | q = self.rotary_embed(q) 78 | k = self.rotary_embed(k) 79 | 80 | if self.flash: 81 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False) 82 | 83 | y = rearrange(y, 'b h t d -> b t (h d)') 84 | 85 | y = self.c_proj(y) 86 | # shape: (b, t, h*d) 87 | 88 | return y 89 | 90 | 91 | class TransformerBlock(nn.Module): 92 | def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): 93 | 94 | super().__init__() 95 | self.dim = dim 96 | self.n_heads = n_heads 97 | 98 | self.att_norm = RMSNorm(dim) 99 | self.ffn_norm = RMSNorm(dim) 100 | self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed) 101 | self.mlp = MLP(dim=dim) 102 | 103 | 104 | def forward( 105 | self, 106 | x: torch.Tensor, 107 | ): 108 | x = x + self.att(self.att_norm(x)) 109 | x = x + self.mlp(self.ffn_norm(x)) 110 | return x 111 | 112 | 113 | if __name__ == '__main__': 114 | rotary_embed_128 = RotaryPositionalEmbeddings(dim=128) 115 | transformer_block = TransformerBlock( 116 | dim=1024, 117 | n_heads=8, 118 | rotary_embed=rotary_embed_128 119 | ) 120 | x = torch.randn(2, 128, 1024) 121 | y = transformer_block(x) 122 | print(y.shape) 123 | c=1 -------------------------------------------------------------------------------- /vq/codec_decoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv1d_transformer') 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from vq.residual_vq import ResidualVQ 7 | from vq.module import WNConv1d, DecoderBlock, ResLSTM 8 | from vq.alias_free_torch import * 9 | from vq import activations 10 | import vq.blocks as blocks 11 | from torch.nn import utils 12 | 13 | from vq.bs_roformer5 import TransformerBlock 14 | 15 | from torchtune.modules import RotaryPositionalEmbeddings 16 | 17 | def init_weights(m): 18 | if isinstance(m, nn.Conv1d): 19 | nn.init.trunc_normal_(m.weight, std=0.02) 20 | nn.init.constant_(m.bias, 0) 21 | 22 | class CodecDecoder(nn.Module): 23 | def __init__(self, 24 | in_channels=1024, 25 | upsample_initial_channel=1536, 26 | ngf=48, 27 | use_rnn=True, 28 | rnn_bidirectional=False, 29 | rnn_num_layers=2, 30 | up_ratios=(5, 4, 4, 4, 2), 31 | dilations=(1, 3, 9), 32 | vq_num_quantizers=1, 33 | vq_dim=2048, 34 | vq_commit_weight=0.25, 35 | vq_weight_init=False, 36 | vq_full_commit_loss=False, 37 | codebook_size=16384, 38 | codebook_dim=32, 39 | ): 40 | super().__init__() 41 | self.hop_length = np.prod(up_ratios) 42 | self.ngf = ngf 43 | self.up_ratios = up_ratios 44 | 45 | self.quantizer = ResidualVQ( 46 | num_quantizers=vq_num_quantizers, 47 | dim=vq_dim, # double the dim for acousitc and semantic 48 | codebook_size=codebook_size, 49 | codebook_dim=codebook_dim, 50 | threshold_ema_dead_code=2, 51 | commitment=vq_commit_weight, 52 | weight_init=vq_weight_init, 53 | full_commit_loss=vq_full_commit_loss, 54 | ) 55 | channels = upsample_initial_channel 56 | layers = [WNConv1d(in_channels, channels, kernel_size=7, 
padding=3)] 57 | 58 | if use_rnn: 59 | layers += [ 60 | ResLSTM(channels, 61 | num_layers=rnn_num_layers, 62 | bidirectional=rnn_bidirectional 63 | ) 64 | ] 65 | 66 | for i, stride in enumerate(up_ratios): 67 | input_dim = channels // 2**i 68 | output_dim = channels // 2 ** (i + 1) 69 | layers += [DecoderBlock(input_dim, output_dim, stride, dilations)] 70 | 71 | layers += [ 72 | Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)), 73 | WNConv1d(output_dim, 1, kernel_size=7, padding=3), 74 | nn.Tanh(), 75 | ] 76 | 77 | self.model = nn.Sequential(*layers) 78 | 79 | self.reset_parameters() 80 | 81 | def forward(self, x, vq=True): 82 | if vq is True: 83 | x, q, commit_loss = self.quantizer(x) 84 | return x, q, commit_loss 85 | x = self.model(x) 86 | return x 87 | 88 | def vq2emb(self, vq): 89 | self.quantizer = self.quantizer.eval() 90 | x = self.quantizer.vq2emb(vq) 91 | return x 92 | 93 | def get_emb(self): 94 | self.quantizer = self.quantizer.eval() 95 | embs = self.quantizer.get_emb() 96 | return embs 97 | 98 | def inference_vq(self, vq): 99 | x = vq[None,:,:] 100 | x = self.model(x) 101 | return x 102 | 103 | def inference_0(self, x): 104 | x, q, loss, perp = self.quantizer(x) 105 | x = self.model(x) 106 | return x, None 107 | 108 | def inference(self, x): 109 | x = self.model(x) 110 | return x, None 111 | 112 | 113 | def remove_weight_norm(self): 114 | """Remove weight normalization module from all of the layers.""" 115 | 116 | def _remove_weight_norm(m): 117 | try: 118 | torch.nn.utils.remove_weight_norm(m) 119 | except ValueError: # this module didn't have weight norm 120 | return 121 | 122 | self.apply(_remove_weight_norm) 123 | 124 | def apply_weight_norm(self): 125 | """Apply weight normalization module from all of the layers.""" 126 | 127 | def _apply_weight_norm(m): 128 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 129 | torch.nn.utils.weight_norm(m) 130 | 131 | self.apply(_apply_weight_norm) 132 | 133 | def reset_parameters(self): 134 | self.apply(init_weights) 135 | 136 | 137 | class CodecDecoder_oobleck_Transformer(nn.Module): 138 | def __init__(self, 139 | ngf=32, 140 | up_ratios=(5, 4, 4, 4, 2), 141 | dilations=(1, 3, 9), 142 | vq_num_quantizers=1, 143 | vq_dim=1024, 144 | vq_commit_weight=0.25, 145 | vq_weight_init=False, 146 | vq_full_commit_loss=False, 147 | codebook_size=16384, 148 | codebook_dim=16, 149 | hidden_dim=1024, 150 | depth=12, 151 | heads=16, 152 | pos_meb_dim=64, 153 | ): 154 | super().__init__() 155 | self.hop_length = np.prod(up_ratios) 156 | self.capacity = ngf 157 | self.up_ratios = up_ratios 158 | self.hidden_dim = hidden_dim 159 | self.quantizer = ResidualVQ( 160 | num_quantizers=vq_num_quantizers, 161 | dim=vq_dim, # double the dim for acousitc and semantic 162 | codebook_size=codebook_size, 163 | codebook_dim=codebook_dim, 164 | threshold_ema_dead_code=2, 165 | commitment=vq_commit_weight, 166 | weight_init=vq_weight_init, 167 | full_commit_loss=vq_full_commit_loss, 168 | ) 169 | 170 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 171 | 172 | transformer_blocks = [ 173 | TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 174 | for _ in range(depth) 175 | ] 176 | 177 | self.transformers = nn.Sequential(*transformer_blocks) 178 | 179 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 180 | 181 | self.conv_blocks = blocks.DilatedResidualDecoder( 182 | capacity=self.capacity, 183 | dilated_unit=self.dilated_unit, 184 | 
upsampling_unit=self.upsampling_unit, 185 | ratios=up_ratios, # 逆转编码器的下采样比率 186 | dilations=dilations, 187 | pre_network_conv=self.pre_conv, 188 | post_network_conv=self.post_conv, 189 | ) 190 | 191 | 192 | 193 | self.reset_parameters() 194 | 195 | def forward(self, x, vq=True): 196 | if vq is True: 197 | x, q, commit_loss = self.quantizer(x) 198 | return x, q, commit_loss 199 | x= self.transformers(x) 200 | x = self.final_layer_norm(x) 201 | x = x.permute(0, 2, 1) 202 | x = self.conv_blocks(x) 203 | return x 204 | 205 | def vq2emb(self, vq): 206 | self.quantizer = self.quantizer.eval() 207 | x = self.quantizer.vq2emb(vq) 208 | return x 209 | 210 | def get_emb(self): 211 | self.quantizer = self.quantizer.eval() 212 | embs = self.quantizer.get_emb() 213 | return embs 214 | 215 | def inference_vq(self, vq): 216 | x = vq[None,:,:] 217 | x = self.model(x) 218 | return x 219 | 220 | def inference_0(self, x): 221 | x, q, loss, perp = self.quantizer(x) 222 | x = self.model(x) 223 | return x, None 224 | 225 | def inference(self, x): 226 | x = self.model(x) 227 | return x, None 228 | 229 | 230 | def remove_weight_norm(self): 231 | """Remove weight normalization module from all of the layers.""" 232 | 233 | def _remove_weight_norm(m): 234 | try: 235 | torch.nn.utils.remove_weight_norm(m) 236 | except ValueError: # this module didn't have weight norm 237 | return 238 | 239 | self.apply(_remove_weight_norm) 240 | 241 | def apply_weight_norm(self): 242 | """Apply weight normalization module from all of the layers.""" 243 | 244 | def _apply_weight_norm(m): 245 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 246 | torch.nn.utils.weight_norm(m) 247 | 248 | self.apply(_apply_weight_norm) 249 | 250 | def reset_parameters(self): 251 | self.apply(init_weights) 252 | 253 | def pre_conv(self, out_channels): 254 | return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1) 255 | 256 | # 定义后处理卷积层,将模型的输出映射到最终的输出通道数 257 | def post_conv(self,in_channels): 258 | return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1) 259 | 260 | def dilated_unit(self, hidden_dim, dilation): 261 | return blocks.DilatedConvolutionalUnit( 262 | hidden_dim=hidden_dim, 263 | dilation=dilation, 264 | kernel_size=3, 265 | activation=nn.ReLU , 266 | normalization=utils.weight_norm 267 | ) 268 | 269 | # 定义上采样单元 270 | def upsampling_unit(self,input_dim, output_dim, stride): 271 | return blocks.UpsamplingUnit( 272 | input_dim=input_dim, 273 | output_dim=output_dim, 274 | stride=stride, 275 | activation=nn.ReLU , 276 | normalization=utils.weight_norm 277 | ) 278 | 279 | def main(): 280 | # 设置设备 281 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 282 | print(f"Using device: {device}") 283 | 284 | # 初始化模型 285 | model = CodecDecoder_oobleck_Transformer().to(device) 286 | print("Model initialized.") 287 | 288 | # 创建测试输入: batch_size x in_channels x sequence_length 289 | batch_size = 2 290 | in_channels = 1024 291 | sequence_length = 100 # 示例长度,可以根据需要调整 292 | dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device) 293 | print(f"Dummy input shape: {dummy_input.shape}") 294 | 295 | # 将模型设为评估模式 296 | model.eval() 297 | 298 | 299 | 300 | output_no_vq = model(dummy_input, vq=False) 301 | c=1 302 | 303 | if __name__ == "__main__": 304 | main() -------------------------------------------------------------------------------- /vq/codec_decoder_vocos.py: -------------------------------------------------------------------------------- 1 | import sys 2 
| sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv_transformer_vocos') 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from vq.residual_vq import ResidualVQ 7 | from vq.module import WNConv1d, DecoderBlock, ResLSTM 8 | from vq.alias_free_torch import * 9 | from vq import activations 10 | from typing import Optional 11 | from vq.module import ConvNeXtBlock, AdaLayerNorm 12 | from vq.bs_roformer5 import TransformerBlock 13 | # from rotary_embedding_torch import RotaryEmbedding 14 | from torchtune.modules import RotaryPositionalEmbeddings 15 | from vector_quantize_pytorch import ResidualFSQ 16 | from torch.nn import Module, ModuleList 17 | class ISTFT(nn.Module): 18 | """ 19 | Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with 20 | windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. 21 | See issue: https://github.com/pytorch/pytorch/issues/62323 22 | Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. 23 | The NOLA constraint is met as we trim padded samples anyway. 24 | 25 | Args: 26 | n_fft (int): Size of Fourier transform. 27 | hop_length (int): The distance between neighboring sliding window frames. 28 | win_length (int): The size of window frame and STFT filter. 29 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 30 | """ 31 | 32 | def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): 33 | super().__init__() 34 | if padding not in ["center", "same"]: 35 | raise ValueError("Padding must be 'center' or 'same'.") 36 | self.padding = padding 37 | self.n_fft = n_fft 38 | self.hop_length = hop_length 39 | self.win_length = win_length 40 | window = torch.hann_window(win_length) 41 | self.register_buffer("window", window) 42 | 43 | def forward(self, spec: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. 46 | 47 | Args: 48 | spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, 49 | N is the number of frequency bins, and T is the number of time frames. 50 | 51 | Returns: 52 | Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
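Note (illustrative): with "same" padding, T input frames reconstruct
T * hop_length output samples (when win_length - hop_length is even).
The ISTFTHead instances in this repo use n_fft = win_length = 4 * hop_length
with hop_length = 320, so T frames map to T * 320 samples.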
53 | """ 54 | if self.padding == "center": 55 | # Fallback to pytorch native implementation 56 | return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) 57 | elif self.padding == "same": 58 | pad = (self.win_length - self.hop_length) // 2 59 | else: 60 | raise ValueError("Padding must be 'center' or 'same'.") 61 | 62 | assert spec.dim() == 3, "Expected a 3D tensor as input" 63 | B, N, T = spec.shape 64 | 65 | # Inverse FFT 66 | ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") 67 | ifft = ifft * self.window[None, :, None] 68 | 69 | # Overlap and Add 70 | output_size = (T - 1) * self.hop_length + self.win_length 71 | y = torch.nn.functional.fold( 72 | ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 73 | )[:, 0, 0, pad:-pad] 74 | 75 | # Window envelope 76 | window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) 77 | window_envelope = torch.nn.functional.fold( 78 | window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 79 | ).squeeze()[pad:-pad] 80 | 81 | # Normalize 82 | assert (window_envelope > 1e-11).all() 83 | y = y / window_envelope 84 | 85 | return y 86 | 87 | 88 | 89 | class FourierHead(nn.Module): 90 | """Base class for inverse fourier modules.""" 91 | 92 | def forward(self, x: torch.Tensor) -> torch.Tensor: 93 | """ 94 | Args: 95 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 96 | L is the sequence length, and H denotes the model dimension. 97 | 98 | Returns: 99 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 100 | """ 101 | raise NotImplementedError("Subclasses must implement the forward method.") 102 | 103 | 104 | class ISTFTHead(FourierHead): 105 | """ 106 | ISTFT Head module for predicting STFT complex coefficients. 107 | 108 | Args: 109 | dim (int): Hidden dimension of the model. 110 | n_fft (int): Size of Fourier transform. 111 | hop_length (int): The distance between neighboring sliding window frames, which should align with 112 | the resolution of the input features. 113 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 114 | """ 115 | 116 | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): 117 | super().__init__() 118 | out_dim = n_fft + 2 119 | self.out = torch.nn.Linear(dim, out_dim) 120 | self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) 121 | 122 | def forward(self, x: torch.Tensor) -> torch.Tensor: 123 | """ 124 | Forward pass of the ISTFTHead module. 125 | 126 | Args: 127 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 128 | L is the sequence length, and H denotes the model dimension. 129 | 130 | Returns: 131 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 132 | """ 133 | x_pred = self.out(x ) 134 | # x_pred = x 135 | x_pred = x_pred.transpose(1, 2) 136 | mag, p = x_pred.chunk(2, dim=1) 137 | mag = torch.exp(mag) 138 | mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes 139 | # wrapping happens here. 
These two lines produce real and imaginary value 140 | x = torch.cos(p) 141 | y = torch.sin(p) 142 | # recalculating phase here does not produce anything new 143 | # only costs time 144 | # phase = torch.atan2(y, x) 145 | # S = mag * torch.exp(phase * 1j) 146 | # better directly produce the complex value 147 | S = mag * (x + 1j * y) 148 | audio = self.istft(S) 149 | return audio.unsqueeze(1),x_pred 150 | 151 | 152 | def nonlinearity(x): 153 | # swish 154 | return x * torch.sigmoid(x) 155 | 156 | 157 | def Normalize(in_channels, num_groups=32): 158 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) 159 | 160 | 161 | class ResnetBlock(nn.Module): 162 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 163 | dropout, temb_channels=512): 164 | super().__init__() 165 | self.in_channels = in_channels 166 | out_channels = in_channels if out_channels is None else out_channels 167 | self.out_channels = out_channels 168 | self.use_conv_shortcut = conv_shortcut 169 | 170 | self.norm1 = Normalize(in_channels) 171 | self.conv1 = torch.nn.Conv1d(in_channels, 172 | out_channels, 173 | kernel_size=3, 174 | stride=1, 175 | padding=1) 176 | if temb_channels > 0: 177 | self.temb_proj = torch.nn.Linear(temb_channels, 178 | out_channels) 179 | self.norm2 = Normalize(out_channels) 180 | self.dropout = torch.nn.Dropout(dropout) 181 | self.conv2 = torch.nn.Conv1d(out_channels, 182 | out_channels, 183 | kernel_size=3, 184 | stride=1, 185 | padding=1) 186 | if self.in_channels != self.out_channels: 187 | if self.use_conv_shortcut: 188 | self.conv_shortcut = torch.nn.Conv1d(in_channels, 189 | out_channels, 190 | kernel_size=3, 191 | stride=1, 192 | padding=1) 193 | else: 194 | self.nin_shortcut = torch.nn.Conv1d(in_channels, 195 | out_channels, 196 | kernel_size=1, 197 | stride=1, 198 | padding=0) 199 | 200 | def forward(self, x, temb=None): 201 | h = x 202 | h = self.norm1(h) 203 | h = nonlinearity(h) 204 | h = self.conv1(h) 205 | 206 | if temb is not None: 207 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 208 | 209 | h = self.norm2(h) 210 | h = nonlinearity(h) 211 | h = self.dropout(h) 212 | h = self.conv2(h) 213 | 214 | if self.in_channels != self.out_channels: 215 | if self.use_conv_shortcut: 216 | x = self.conv_shortcut(x) 217 | else: 218 | x = self.nin_shortcut(x) 219 | 220 | return x + h 221 | 222 | class AttnBlock(nn.Module): 223 | def __init__(self, in_channels): 224 | super().__init__() 225 | self.in_channels = in_channels 226 | 227 | self.norm = Normalize(in_channels) 228 | self.q = torch.nn.Conv1d(in_channels, 229 | in_channels, 230 | kernel_size=1, 231 | stride=1, 232 | padding=0) 233 | self.k = torch.nn.Conv1d(in_channels, 234 | in_channels, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | self.v = torch.nn.Conv1d(in_channels, 239 | in_channels, 240 | kernel_size=1, 241 | stride=1, 242 | padding=0) 243 | self.proj_out = torch.nn.Conv1d(in_channels, 244 | in_channels, 245 | kernel_size=1, 246 | stride=1, 247 | padding=0) 248 | 249 | def forward(self, x): 250 | h_ = x 251 | h_ = self.norm(h_) 252 | q = self.q(h_) 253 | k = self.k(h_) 254 | v = self.v(h_) 255 | 256 | # compute attention 257 | b, c, h = q.shape 258 | q = q.permute(0, 2, 1) # b,hw,c 259 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 260 | w_ = w_ * (int(c) ** (-0.5)) 261 | w_ = torch.nn.functional.softmax(w_, dim=2) 262 | 263 | # attend to values 264 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 265 | 
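# Note: the "hw" in the shape comments appears to be carried over from a 2D
# (image) attention implementation; in this 1D module it is simply the time
# axis T, and the bmm below is plain dot-product attention over [B, C, T] maps.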
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 266 | 267 | h_ = self.proj_out(h_) 268 | 269 | return x + h_ 270 | 271 | def make_attn(in_channels, attn_type="vanilla"): 272 | assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' 273 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 274 | if attn_type == "vanilla": 275 | return AttnBlock(in_channels) 276 | 277 | 278 | class Backbone(nn.Module): 279 | """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" 280 | 281 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 282 | """ 283 | Args: 284 | x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, 285 | C denotes output features, and L is the sequence length. 286 | 287 | Returns: 288 | Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, 289 | and H denotes the model dimension. 290 | """ 291 | raise NotImplementedError("Subclasses must implement the forward method.") 292 | 293 | 294 | class VocosBackbone(Backbone): 295 | """ 296 | Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization 297 | 298 | Args: 299 | input_channels (int): Number of input features channels. 300 | dim (int): Hidden dimension of the model. 301 | intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. 302 | num_layers (int): Number of ConvNeXtBlock layers. 303 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. 304 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 305 | None means non-conditional model. Defaults to None. 
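Note: the argument list above describes the original ConvNeXt-based Vocos
backbone; the constructor in this file instead takes hidden_dim, depth, heads
and pos_meb_dim, and wraps rotary-embedding TransformerBlocks between
ResnetBlock prior/post networks (see __init__ below).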
306 | """ 307 | 308 | def __init__( 309 | self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): 310 | super().__init__() 311 | 312 | self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3) 313 | 314 | 315 | 316 | self.temb_ch = 0 317 | block_in = hidden_dim 318 | dropout = 0.1 319 | 320 | prior_net : tp.List[nn.Module] = [ 321 | ResnetBlock(in_channels=block_in,out_channels=block_in, 322 | temb_channels=self.temb_ch,dropout=dropout), 323 | ResnetBlock(in_channels=block_in,out_channels=block_in, 324 | temb_channels=self.temb_ch,dropout=dropout), 325 | ] 326 | self.prior_net = nn.Sequential(*prior_net) 327 | 328 | depth = depth 329 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 330 | 331 | 332 | transformer_blocks = [ 333 | TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 334 | for _ in range(depth) 335 | ] 336 | 337 | 338 | self.transformers = nn.Sequential(*transformer_blocks) 339 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 340 | post_net : tp.List[nn.Module] = [ 341 | ResnetBlock(in_channels=block_in,out_channels=block_in, 342 | temb_channels=self.temb_ch,dropout=dropout), 343 | ResnetBlock(in_channels=block_in,out_channels=block_in, 344 | temb_channels=self.temb_ch,dropout=dropout), 345 | ] 346 | self.post_net = nn.Sequential(*post_net) 347 | 348 | def forward(self, x: torch.Tensor ) -> torch.Tensor: 349 | x = x.transpose(1, 2) 350 | x = self.embed(x) 351 | x = self.prior_net(x) 352 | x = x.transpose(1, 2) 353 | x= self.transformers(x) 354 | x = x.transpose(1, 2) 355 | x = self.post_net(x) 356 | x = x.transpose(1, 2) 357 | x = self.final_layer_norm(x) 358 | return x 359 | 360 | def init_weights(m): 361 | if isinstance(m, nn.Conv1d): 362 | nn.init.trunc_normal_(m.weight, std=0.02) 363 | nn.init.constant_(m.bias, 0) 364 | 365 | class CodecDecoderVocos(nn.Module): 366 | def __init__(self, 367 | hidden_dim=1024, 368 | depth=12, 369 | heads=16, 370 | pos_meb_dim=64, 371 | hop_length=320, 372 | vq_num_quantizers=1, 373 | vq_dim=2048, #1024 2048 374 | vq_commit_weight=0.25, 375 | vq_weight_init=False, 376 | vq_full_commit_loss=False, 377 | codebook_size=16384, 378 | codebook_dim=16, 379 | ): 380 | super().__init__() 381 | self.hop_length = hop_length 382 | 383 | self.quantizer = ResidualFSQ( 384 | dim = vq_dim, 385 | levels = [4, 4, 4, 4, 4,4,4,4], 386 | num_quantizers = 1 387 | ) 388 | 389 | # self.quantizer = ResidualVQ( 390 | # num_quantizers=vq_num_quantizers, 391 | # dim=vq_dim, 392 | # codebook_size=codebook_size, 393 | # codebook_dim=codebook_dim, 394 | # threshold_ema_dead_code=2, 395 | # commitment=vq_commit_weight, 396 | # weight_init=vq_weight_init, 397 | # full_commit_loss=vq_full_commit_loss, 398 | # ) 399 | 400 | 401 | self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) 402 | 403 | self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") 404 | 405 | self.reset_parameters() 406 | 407 | def forward(self, x, vq=True): 408 | if vq is True: 409 | # x, q, commit_loss = self.quantizer(x) 410 | x = x.permute(0, 2, 1) 411 | x, q = self.quantizer(x) 412 | x = x.permute(0, 2, 1) 413 | q = q.permute(0, 2, 1) 414 | return x, q, None 415 | x = self.backbone(x) 416 | x,_ = self.head(x) 417 | 418 | return x ,_ 419 | 420 | def vq2emb(self, vq): 421 | self.quantizer = self.quantizer.eval() 422 | x = self.quantizer.vq2emb(vq) 423 | return x 424 | 425 | def get_emb(self): 426 | self.quantizer = self.quantizer.eval() 
427 | embs = self.quantizer.get_emb() 428 | return embs 429 | 430 | def inference_vq(self, vq): 431 | x = vq[None,:,:] 432 | x = self.model(x) 433 | return x 434 | 435 | def inference_0(self, x): 436 | x, q, loss, perp = self.quantizer(x) 437 | x = self.model(x) 438 | return x, None 439 | 440 | def inference(self, x): 441 | x = self.model(x) 442 | return x, None 443 | 444 | 445 | def remove_weight_norm(self): 446 | """Remove weight normalization module from all of the layers.""" 447 | 448 | def _remove_weight_norm(m): 449 | try: 450 | torch.nn.utils.remove_weight_norm(m) 451 | except ValueError: # this module didn't have weight norm 452 | return 453 | 454 | self.apply(_remove_weight_norm) 455 | 456 | def apply_weight_norm(self): 457 | """Apply weight normalization module from all of the layers.""" 458 | 459 | def _apply_weight_norm(m): 460 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 461 | torch.nn.utils.weight_norm(m) 462 | 463 | self.apply(_apply_weight_norm) 464 | 465 | def reset_parameters(self): 466 | self.apply(init_weights) 467 | 468 | 469 | 470 | class CodecDecoderVocos_transpose(nn.Module): 471 | def __init__(self, 472 | hidden_dim=1024, 473 | depth=12, 474 | heads=16, 475 | pos_meb_dim=64, 476 | hop_length=320, 477 | vq_num_quantizers=1, 478 | vq_dim=1024, #1024 2048 479 | vq_commit_weight=0.25, 480 | vq_weight_init=False, 481 | vq_full_commit_loss=False, 482 | codebook_size=16384, 483 | codebook_dim=16, 484 | ): 485 | super().__init__() 486 | self.hop_length = hop_length 487 | 488 | 489 | self.quantizer = ResidualVQ( 490 | num_quantizers=vq_num_quantizers, 491 | dim=vq_dim, 492 | codebook_size=codebook_size, 493 | codebook_dim=codebook_dim, 494 | threshold_ema_dead_code=2, 495 | commitment=vq_commit_weight, 496 | weight_init=vq_weight_init, 497 | full_commit_loss=vq_full_commit_loss, 498 | ) 499 | 500 | 501 | self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) 502 | 503 | self.inverse_mel_conv = nn.Sequential( 504 | nn.GELU(), 505 | nn.ConvTranspose1d( 506 | in_channels=hidden_dim, 507 | out_channels=hidden_dim, 508 | kernel_size=3, 509 | stride=2, 510 | padding=1, 511 | output_padding=1 # 确保输出长度与编码前匹配 512 | ), 513 | nn.GELU(), 514 | nn.ConvTranspose1d( 515 | in_channels=hidden_dim, 516 | out_channels=hidden_dim, 517 | kernel_size=3, 518 | padding=1 519 | ) 520 | ) 521 | 522 | self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") 523 | 524 | self.reset_parameters() 525 | 526 | def forward(self, x, vq=True): 527 | if vq is True: 528 | x, q, commit_loss = self.quantizer(x) 529 | return x, q, commit_loss 530 | x = self.backbone(x) 531 | x,_ = self.head(x) 532 | 533 | return x ,_ 534 | 535 | def vq2emb(self, vq): 536 | self.quantizer = self.quantizer.eval() 537 | x = self.quantizer.vq2emb(vq) 538 | return x 539 | 540 | def get_emb(self): 541 | self.quantizer = self.quantizer.eval() 542 | embs = self.quantizer.get_emb() 543 | return embs 544 | 545 | def inference_vq(self, vq): 546 | x = vq[None,:,:] 547 | x = self.model(x) 548 | return x 549 | 550 | def inference_0(self, x): 551 | x, q, loss, perp = self.quantizer(x) 552 | x = self.model(x) 553 | return x, None 554 | 555 | def inference(self, x): 556 | x = self.model(x) 557 | return x, None 558 | 559 | 560 | def remove_weight_norm(self): 561 | """Remove weight normalization module from all of the layers.""" 562 | 563 | def _remove_weight_norm(m): 564 | try: 565 | torch.nn.utils.remove_weight_norm(m) 566 | 
except ValueError: # this module didn't have weight norm 567 | return 568 | 569 | self.apply(_remove_weight_norm) 570 | 571 | def apply_weight_norm(self): 572 | """Apply weight normalization module from all of the layers.""" 573 | 574 | def _apply_weight_norm(m): 575 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 576 | torch.nn.utils.weight_norm(m) 577 | 578 | self.apply(_apply_weight_norm) 579 | 580 | def reset_parameters(self): 581 | self.apply(init_weights) 582 | 583 | 584 | 585 | 586 | def main(): 587 | # 设置设备 588 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 589 | print(f"Using device: {device}") 590 | 591 | # 初始化模型 592 | model = CodecDecoderVocos_transpose().to(device) 593 | print("Model initialized.") 594 | 595 | # 创建测试输入: batch_size x in_channels x sequence_length 596 | batch_size = 2 597 | in_channels = 1024 598 | sequence_length = 50 # 示例长度,可以根据需要调整 599 | dummy_input = torch.randn(batch_size, in_channels, sequence_length).to(device) 600 | print(f"Dummy input shape: {dummy_input.shape}") 601 | 602 | # 将模型设为评估模式 603 | model.eval() 604 | 605 | # 前向传播(使用 VQ) 606 | # with torch.no_grad(): 607 | # try: 608 | # output, q, commit_loss = model(dummy_input, vq=True) 609 | # print("Forward pass with VQ:") 610 | # print(f"Output shape: {output.shape}") 611 | # print(f"Quantized codes shape: {q.shape}") 612 | # print(f"Commitment loss: {commit_loss}") 613 | # except Exception as e: 614 | # print(f"Error during forward pass with VQ: {e}") 615 | 616 | # 前向传播(不使用 VQ) 617 | with torch.no_grad(): 618 | # try: 619 | output_no_vq = model(dummy_input, vq=False) 620 | print("\nForward pass without VQ:") 621 | print(f"Output shape: {output_no_vq.shape}") 622 | c=1 623 | # except Exception as e: 624 | # print(f"Error during forward pass without VQ: {e}") 625 | 626 | 627 | # model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters()) 628 | # model_size_mb = model_size_bytes / (1024 ** 2) 629 | # print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)") 630 | 631 | if __name__ == "__main__": 632 | main() -------------------------------------------------------------------------------- /vq/codec_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv1d_transformer') 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | from vq.module import WNConv1d, EncoderBlock, ResLSTM 7 | from vq.alias_free_torch import * 8 | from vq import activations 9 | from vq.bs_roformer5 import TransformerBlock 10 | # from rotary_embedding_torch import RotaryEmbedding 11 | from torchtune.modules import RotaryPositionalEmbeddings 12 | import vq.blocks as blocks 13 | from torch.nn import utils 14 | def init_weights(m): 15 | if isinstance(m, nn.Conv1d): 16 | nn.init.trunc_normal_(m.weight, std=0.02) 17 | nn.init.constant_(m.bias, 0) 18 | 19 | class CodecEncoder(nn.Module): 20 | def __init__(self, 21 | ngf=48, 22 | use_rnn=True, 23 | rnn_bidirectional=False, 24 | rnn_num_layers=2, 25 | up_ratios=(2, 2, 4, 4, 5), 26 | dilations=(1, 3, 9), 27 | out_channels=1024): 28 | super().__init__() 29 | self.hop_length = np.prod(up_ratios) 30 | self.ngf = ngf 31 | self.up_ratios = up_ratios 32 | 33 | # Create first convolution 34 | d_model = ngf 35 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 36 | 37 | # Create EncoderBlocks that double channels as they downsample by `stride` 38 | for i, stride in 
enumerate(up_ratios): 39 | d_model *= 2 40 | self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)] 41 | # RNN 42 | if use_rnn: 43 | self.block += [ 44 | ResLSTM(d_model, 45 | num_layers=rnn_num_layers, 46 | bidirectional=rnn_bidirectional 47 | ) 48 | ] 49 | # Create last convolution 50 | self.block += [ 51 | Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), 52 | WNConv1d(d_model, out_channels, kernel_size=3, padding=1), 53 | ] 54 | 55 | # Wrap black into nn.Sequential 56 | self.block = nn.Sequential(*self.block) 57 | self.enc_dim = d_model 58 | 59 | self.reset_parameters() 60 | 61 | def forward(self, x): 62 | out = self.block(x) 63 | return out 64 | 65 | def inference(self, x): 66 | return self.block(x) 67 | 68 | def remove_weight_norm(self): 69 | """Remove weight normalization module from all of the layers.""" 70 | 71 | def _remove_weight_norm(m): 72 | try: 73 | torch.nn.utils.remove_weight_norm(m) 74 | except ValueError: # this module didn't have weight norm 75 | return 76 | 77 | self.apply(_remove_weight_norm) 78 | 79 | def apply_weight_norm(self): 80 | """Apply weight normalization module from all of the layers.""" 81 | 82 | def _apply_weight_norm(m): 83 | if isinstance(m, nn.Conv1d): 84 | torch.nn.utils.weight_norm(m) 85 | 86 | self.apply(_apply_weight_norm) 87 | 88 | def reset_parameters(self): 89 | self.apply(init_weights) 90 | 91 | 92 | class Transpose(nn.Module): 93 | def __init__(self, dim1, dim2): 94 | super(Transpose, self).__init__() 95 | self.dim1 = dim1 96 | self.dim2 = dim2 97 | 98 | def forward(self, x): 99 | return x.transpose(self.dim1, self.dim2) 100 | 101 | class CodecEncoder_Transformer(nn.Module): 102 | def __init__(self, 103 | ngf=48, 104 | up_ratios=[2, 2, 4, 4, 5], 105 | dilations=(1, 3, 9), 106 | hidden_dim=1024, 107 | depth=12, 108 | heads=12, 109 | pos_meb_dim=64, 110 | ): 111 | super().__init__() 112 | self.hop_length = np.prod(up_ratios) 113 | self.ngf =ngf 114 | self.up_ratios = up_ratios 115 | 116 | d_model = ngf 117 | self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 118 | 119 | 120 | for i, stride in enumerate(up_ratios): 121 | d_model *= 2 122 | self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)] 123 | 124 | self.conv_blocks = nn.Sequential(*self.conv_blocks) 125 | 126 | 127 | # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 128 | 129 | 130 | # transformer_blocks = [ 131 | # TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 132 | # for _ in range(depth) 133 | # ] 134 | 135 | 136 | # self.transformers = nn.Sequential(*transformer_blocks) 137 | 138 | # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 139 | 140 | self.conv_final_block = [ 141 | Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), 142 | WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1), 143 | ] 144 | self.conv_final_block = nn.Sequential(*self.conv_final_block) 145 | 146 | self.reset_parameters() 147 | 148 | def forward(self, x): 149 | x = self.conv_blocks(x) 150 | # x = x.permute(0, 2, 1) 151 | # x= self.transformers(x) 152 | # x = self.final_layer_norm(x) 153 | # x = x.permute(0, 2, 1) 154 | x = self.conv_final_block (x) 155 | x = x.permute(0, 2, 1) 156 | return x 157 | 158 | def inference(self, x): 159 | return self.block(x) 160 | 161 | def remove_weight_norm(self): 162 | """Remove weight normalization module from all of the layers.""" 163 | 164 | def _remove_weight_norm(m): 165 | try: 166 | 
torch.nn.utils.remove_weight_norm(m) 167 | except ValueError: # this module didn't have weight norm 168 | return 169 | 170 | self.apply(_remove_weight_norm) 171 | 172 | def apply_weight_norm(self): 173 | """Apply weight normalization module from all of the layers.""" 174 | 175 | def _apply_weight_norm(m): 176 | if isinstance(m, nn.Conv1d): 177 | torch.nn.utils.weight_norm(m) 178 | 179 | self.apply(_apply_weight_norm) 180 | 181 | def reset_parameters(self): 182 | self.apply(init_weights) 183 | 184 | 185 | 186 | class Codec_oobleck_Transformer(nn.Module): 187 | def __init__(self, 188 | ngf=32, 189 | up_ratios=(2, 2,4,4, 5), 190 | dilations=(1, 3, 9), 191 | hidden_dim=1024, 192 | depth=12, 193 | heads=16, 194 | pos_meb_dim=64, 195 | ): 196 | super().__init__() 197 | self.hop_length = np.prod(up_ratios) 198 | self.ngf =ngf 199 | self.up_ratios = up_ratios 200 | self.hidden_dim = hidden_dim 201 | 202 | 203 | self.conv_blocks = blocks.DilatedResidualEncoder( 204 | capacity=ngf, 205 | dilated_unit=self.dilated_unit, 206 | downsampling_unit=self.downsampling_unit, 207 | ratios=up_ratios, 208 | dilations=dilations, 209 | pre_network_conv=self.pre_conv, 210 | post_network_conv=self.post_conv, 211 | ) 212 | 213 | 214 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 215 | 216 | transformer_blocks = [ 217 | TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 218 | for _ in range(depth) 219 | ] 220 | 221 | self.transformers = nn.Sequential(*transformer_blocks) 222 | 223 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 224 | 225 | 226 | self.reset_parameters() 227 | 228 | def forward(self, x): 229 | x = self.conv_blocks(x) 230 | x = x.permute(0, 2, 1) 231 | x= self.transformers(x) 232 | x = self.final_layer_norm(x) 233 | return x 234 | 235 | def inference(self, x): 236 | return self.block(x) 237 | 238 | def remove_weight_norm(self): 239 | """Remove weight normalization module from all of the layers.""" 240 | 241 | def _remove_weight_norm(m): 242 | try: 243 | torch.nn.utils.remove_weight_norm(m) 244 | except ValueError: # this module didn't have weight norm 245 | return 246 | 247 | self.apply(_remove_weight_norm) 248 | 249 | def apply_weight_norm(self): 250 | """Apply weight normalization module from all of the layers.""" 251 | 252 | def _apply_weight_norm(m): 253 | if isinstance(m, nn.Conv1d): 254 | torch.nn.utils.weight_norm(m) 255 | 256 | self.apply(_apply_weight_norm) 257 | 258 | def reset_parameters(self): 259 | self.apply(init_weights) 260 | 261 | def dilated_unit(self,hidden_dim, dilation): 262 | return blocks.DilatedConvolutionalUnit(hidden_dim, 263 | dilation, 264 | kernel_size=3, 265 | activation=nn.ReLU, 266 | normalization=utils.weight_norm) 267 | 268 | def downsampling_unit(self, input_dim: int, output_dim: int, stride: int): 269 | return blocks.DownsamplingUnit(input_dim, 270 | output_dim, 271 | stride, 272 | nn.ReLU, 273 | normalization=utils.weight_norm) 274 | 275 | def pre_conv(self,out_channels): 276 | return nn.Conv1d(1, out_channels, 1) 277 | 278 | def post_conv(self,in_channels): 279 | return nn.Conv1d(in_channels, self.hidden_dim, 1) 280 | 281 | 282 | 283 | 284 | 285 | class CodecEncoder_only_Transformer(nn.Module): 286 | def __init__(self,hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): 287 | super().__init__() 288 | # self.embed = nn.Linear(input_dim, hidden_dim )input_dim=300, 289 | 290 | depth = depth 291 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 292 | 293 | 294 | transformer_blocks = [ 295 | 
TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 296 | for _ in range(depth) 297 | ] 298 | 299 | 300 | self.transformers = nn.Sequential(*transformer_blocks) 301 | 302 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 303 | 304 | def forward(self, x: torch.Tensor ) -> torch.Tensor: 305 | # x = self.embed(x) 306 | 307 | 308 | x= self.transformers(x) 309 | x = self.final_layer_norm(x) 310 | 311 | return x 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | def get_model_size(model): 320 | # 计算总参数数 321 | total_params = sum(p.numel() for p in model.parameters()) 322 | 323 | # 假设每个参数都是32位浮点数,计算模型大小(以字节为单位) 324 | model_size_bytes = total_params # 每个参数4字节 325 | 326 | # 转换为更易读的单位(例如,MB) 327 | model_size_mb = model_size_bytes / (1024 ** 2) 328 | 329 | return total_params, model_size_mb 330 | 331 | if __name__ == '__main__': 332 | model = Codec_oobleck_Transformer() 333 | x = torch.randn(1, 1, 16000) # example input tensor 334 | output = model(x) 335 | print("Output shape:", output.shape) 336 | -------------------------------------------------------------------------------- /vq/factorized_vector_quantize.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | class FactorizedVectorQuantize(nn.Module): 11 | def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs): 12 | super().__init__() 13 | self.codebook_size = codebook_size 14 | self.codebook_dim = codebook_dim 15 | self.commitment = commitment 16 | 17 | if dim != self.codebook_dim: 18 | self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim)) 19 | self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim)) 20 | else: 21 | self.in_proj = nn.Identity() 22 | self.out_proj = nn.Identity() 23 | self._codebook = nn.Embedding(codebook_size, self.codebook_dim) 24 | 25 | @property 26 | def codebook(self): 27 | return self._codebook 28 | 29 | def forward(self, z): 30 | """Quantized the input tensor using a fixed codebook and returns 31 | the corresponding codebook vectors 32 | 33 | Parameters 34 | ---------- 35 | z : Tensor[B x D x T] 36 | 37 | Returns 38 | ------- 39 | Tensor[B x D x T] 40 | Quantized continuous representation of input 41 | Tensor[1] 42 | Commitment loss to train encoder to predict vectors closer to codebook 43 | entries 44 | Tensor[1] 45 | Codebook loss to update the codebook 46 | Tensor[B x T] 47 | Codebook indices (quantized discrete representation of input) 48 | Tensor[B x D x T] 49 | Projected latents (continuous representation of input before quantization) 50 | """ 51 | # transpose since we use linear 52 | 53 | z = rearrange(z, "b d t -> b t d") 54 | 55 | # Factorized codes project input into low-dimensional space 56 | z_e = self.in_proj(z) # z_e : (B x T x D) 57 | z_e = rearrange(z_e, "b t d -> b d t") 58 | z_q, indices = self.decode_latents(z_e) 59 | 60 | 61 | if self.training: 62 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment 63 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2]) 64 | commit_loss = commitment_loss + codebook_loss 65 | else: 66 | commit_loss = torch.zeros(z.shape[0], device = z.device) 67 | 68 | z_q = ( 69 | z_e + (z_q - z_e).detach() 70 | ) # noop in forward pass, straight-through gradient estimator in backward pass 71 | 72 | 
z_q = rearrange(z_q, "b d t -> b t d") 73 | z_q = self.out_proj(z_q) 74 | z_q = rearrange(z_q, "b t d -> b d t") 75 | 76 | return z_q, indices, commit_loss 77 | 78 | def vq2emb(self, vq, proj=True): 79 | emb = self.embed_code(vq) 80 | if proj: 81 | emb = self.out_proj(emb) 82 | return emb 83 | 84 | def get_emb(self): 85 | return self.codebook.weight 86 | 87 | def embed_code(self, embed_id): 88 | return F.embedding(embed_id, self.codebook.weight) 89 | 90 | def decode_code(self, embed_id): 91 | return self.embed_code(embed_id).transpose(1, 2) 92 | 93 | def decode_latents(self, latents): 94 | encodings = rearrange(latents, "b d t -> (b t) d") 95 | codebook = self.codebook.weight # codebook: (N x D) 96 | 97 | # L2 normalize encodings and codebook 98 | encodings = F.normalize(encodings) 99 | codebook = F.normalize(codebook) 100 | 101 | # Compute euclidean distance with codebook 102 | dist = ( 103 | encodings.pow(2).sum(1, keepdim=True) 104 | - 2 * encodings @ codebook.t() 105 | + codebook.pow(2).sum(1, keepdim=True).t() 106 | ) 107 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 108 | z_q = self.decode_code(indices) 109 | return z_q, indices -------------------------------------------------------------------------------- /vq/module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from einops import rearrange 3 | from . import activations 4 | from .alias_free_torch import * 5 | from torch.nn.utils import weight_norm 6 | 7 | from typing import Optional, Tuple 8 | 9 | from torch.nn.utils import weight_norm, remove_weight_norm 10 | 11 | 12 | def WNConv1d(*args, **kwargs): 13 | return weight_norm(nn.Conv1d(*args, **kwargs)) 14 | 15 | 16 | def WNConvTranspose1d(*args, **kwargs): 17 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 18 | 19 | class ResidualUnit(nn.Module): 20 | def __init__(self, dim: int = 16, dilation: int = 1): 21 | super().__init__() 22 | pad = ((7 - 1) * dilation) // 2 23 | self.block = nn.Sequential( 24 | Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), 25 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), 26 | Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), 27 | WNConv1d(dim, dim, kernel_size=1), 28 | ) 29 | 30 | def forward(self, x): 31 | return x + self.block(x) 32 | 33 | class EncoderBlock(nn.Module): 34 | def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)): 35 | super().__init__() 36 | runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations] 37 | self.block = nn.Sequential( 38 | *runits, 39 | Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)), 40 | WNConv1d( 41 | dim // 2, 42 | dim, 43 | kernel_size=2 * stride, 44 | stride=stride, 45 | padding=stride // 2 + stride % 2, 46 | ), 47 | ) 48 | 49 | def forward(self, x): 50 | return self.block(x) 51 | 52 | class DecoderBlock(nn.Module): 53 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)): 54 | super().__init__() 55 | self.block = nn.Sequential( 56 | Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)), 57 | WNConvTranspose1d( 58 | input_dim, 59 | output_dim, 60 | kernel_size=2 * stride, 61 | stride=stride, 62 | padding=stride // 2 + stride % 2, 63 | output_padding= stride % 2, 64 | ) 65 | ) 66 | self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations]) 67 | 68 | def forward(self, x): 69 | return self.block(x) 70 
| 71 | class ResLSTM(nn.Module): 72 | def __init__(self, dimension: int, 73 | num_layers: int = 2, 74 | bidirectional: bool = False, 75 | skip: bool = True): 76 | super().__init__() 77 | self.skip = skip 78 | self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2, 79 | num_layers, batch_first=True, 80 | bidirectional=bidirectional) 81 | 82 | def forward(self, x): 83 | """ 84 | Args: 85 | x: [B, F, T] 86 | 87 | Returns: 88 | y: [B, F, T] 89 | """ 90 | x = rearrange(x, "b f t -> b t f") 91 | y, _ = self.lstm(x) 92 | if self.skip: 93 | y = y + x 94 | y = rearrange(y, "b t f -> b f t") 95 | return y 96 | 97 | 98 | 99 | class ConvNeXtBlock(nn.Module): 100 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 101 | 102 | Args: 103 | dim (int): Number of input channels. 104 | intermediate_dim (int): Dimensionality of the intermediate layer. 105 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 106 | Defaults to None. 107 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 108 | None means non-conditional LayerNorm. Defaults to None. 109 | """ 110 | 111 | def __init__( 112 | self, 113 | dim: int, 114 | intermediate_dim: int, 115 | layer_scale_init_value: float, 116 | adanorm_num_embeddings: Optional[int] = None, 117 | ): 118 | super().__init__() 119 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 120 | self.adanorm = adanorm_num_embeddings is not None 121 | if adanorm_num_embeddings: 122 | self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) 123 | else: 124 | self.norm = nn.LayerNorm(dim, eps=1e-6) 125 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 126 | self.act = nn.GELU() 127 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 128 | self.gamma = ( 129 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 130 | if layer_scale_init_value > 0 131 | else None 132 | ) 133 | 134 | def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: 135 | residual = x 136 | x = self.dwconv(x) 137 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 138 | if self.adanorm: 139 | assert cond_embedding_id is not None 140 | x = self.norm(x, cond_embedding_id) 141 | else: 142 | x = self.norm(x) 143 | x = self.pwconv1(x) 144 | x = self.act(x) 145 | x = self.pwconv2(x) 146 | if self.gamma is not None: 147 | x = self.gamma * x 148 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 149 | 150 | x = residual + x 151 | return x 152 | 153 | 154 | class AdaLayerNorm(nn.Module): 155 | """ 156 | Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes 157 | 158 | Args: 159 | num_embeddings (int): Number of embeddings. 160 | embedding_dim (int): Dimension of the embeddings. 
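Example (illustrative): the looked-up scale and shift vectors (each of length
embedding_dim) modulate the normalized input as
layer_norm(x, (embedding_dim,)) * scale + shift, so every conditioning id
gets its own affine parameters.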
161 | """ 162 | 163 | def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): 164 | super().__init__() 165 | self.eps = eps 166 | self.dim = embedding_dim 167 | self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 168 | self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 169 | torch.nn.init.ones_(self.scale.weight) 170 | torch.nn.init.zeros_(self.shift.weight) 171 | 172 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: 173 | scale = self.scale(cond_embedding_id) 174 | shift = self.shift(cond_embedding_id) 175 | x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) 176 | x = x * scale + shift 177 | return x 178 | 179 | 180 | class ResBlock1(nn.Module): 181 | """ 182 | ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, 183 | but without upsampling layers. 184 | 185 | Args: 186 | dim (int): Number of input channels. 187 | kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. 188 | dilation (tuple[int], optional): Dilation factors for the dilated convolutions. 189 | Defaults to (1, 3, 5). 190 | lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. 191 | Defaults to 0.1. 192 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 193 | Defaults to None. 194 | """ 195 | 196 | def __init__( 197 | self, 198 | dim: int, 199 | kernel_size: int = 3, 200 | dilation: Tuple[int, int, int] = (1, 3, 5), 201 | lrelu_slope: float = 0.1, 202 | layer_scale_init_value: Optional[float] = None, 203 | ): 204 | super().__init__() 205 | self.lrelu_slope = lrelu_slope 206 | self.convs1 = nn.ModuleList( 207 | [ 208 | weight_norm( 209 | nn.Conv1d( 210 | dim, 211 | dim, 212 | kernel_size, 213 | 1, 214 | dilation=dilation[0], 215 | padding=self.get_padding(kernel_size, dilation[0]), 216 | ) 217 | ), 218 | weight_norm( 219 | nn.Conv1d( 220 | dim, 221 | dim, 222 | kernel_size, 223 | 1, 224 | dilation=dilation[1], 225 | padding=self.get_padding(kernel_size, dilation[1]), 226 | ) 227 | ), 228 | weight_norm( 229 | nn.Conv1d( 230 | dim, 231 | dim, 232 | kernel_size, 233 | 1, 234 | dilation=dilation[2], 235 | padding=self.get_padding(kernel_size, dilation[2]), 236 | ) 237 | ), 238 | ] 239 | ) 240 | 241 | self.convs2 = nn.ModuleList( 242 | [ 243 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 244 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 245 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 246 | ] 247 | ) 248 | 249 | self.gamma = nn.ParameterList( 250 | [ 251 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 252 | if layer_scale_init_value is not None 253 | else None, 254 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 255 | if layer_scale_init_value is not None 256 | else None, 257 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 258 | if layer_scale_init_value is not None 259 | else None, 260 | ] 261 | ) 262 | 263 | def forward(self, x: torch.Tensor) -> torch.Tensor: 264 | for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): 265 | xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) 266 | xt = c1(xt) 267 | xt = torch.nn.functional.leaky_relu(xt, 
negative_slope=self.lrelu_slope)
268 |             xt = c2(xt)
269 |             if gamma is not None:
270 |                 xt = gamma * xt
271 |             x = xt + x
272 |         return x
273 | 
274 |     def remove_weight_norm(self):
275 |         for l in self.convs1:
276 |             remove_weight_norm(l)
277 |         for l in self.convs2:
278 |             remove_weight_norm(l)
279 | 
280 |     @staticmethod
281 |     def get_padding(kernel_size: int, dilation: int = 1) -> int:
282 |         return int((kernel_size * dilation - dilation) / 2)
283 | 
284 | 
285 | def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
286 |     """
287 |     Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
288 | 
289 |     Args:
290 |         x (Tensor): Input tensor.
291 |         clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
292 | 
293 |     Returns:
294 |         Tensor: Element-wise logarithm of the input tensor with clipping applied.
295 |     """
296 |     return torch.log(torch.clip(x, min=clip_val))
297 | 
298 | 
299 | def symlog(x: torch.Tensor) -> torch.Tensor:
300 |     return torch.sign(x) * torch.log1p(x.abs())
301 | 
302 | 
303 | def symexp(x: torch.Tensor) -> torch.Tensor:
304 |     return torch.sign(x) * (torch.exp(x.abs()) - 1)
305 | 
306 | 
307 | 
308 | class SemanticEncoder(nn.Module):
309 |     def __init__(
310 |         self,
311 |         input_channels: int,
312 |         code_dim: int,
313 |         encode_channels: int,
314 |         kernel_size: int = 3,
315 |         bias: bool = True,
316 |     ):
317 |         super(SemanticEncoder, self).__init__()
318 | 
319 |         # Initial convolution mapping input_channels to encode_channels
320 |         self.initial_conv = nn.Conv1d(
321 |             in_channels=input_channels,
322 |             out_channels=encode_channels,
323 |             kernel_size=kernel_size,
324 |             stride=1,
325 |             padding=(kernel_size - 1) // 2,
326 |             bias=False
327 |         )
328 | 
329 |         # Residual blocks
330 |         self.residual_blocks = nn.Sequential(
331 |             nn.ReLU(inplace=True),
332 |             nn.Conv1d(
333 |                 encode_channels,
334 |                 encode_channels,
335 |                 kernel_size=kernel_size,
336 |                 stride=1,
337 |                 padding=(kernel_size - 1) // 2,
338 |                 bias=bias
339 |             ),
340 |             nn.ReLU(inplace=True),
341 |             nn.Conv1d(
342 |                 encode_channels,
343 |                 encode_channels,
344 |                 kernel_size=kernel_size,
345 |                 stride=1,
346 |                 padding=(kernel_size - 1) // 2,
347 |                 bias=bias
348 |             )
349 |         )
350 | 
351 |         # Final convolution mapping encode_channels to code_dim
352 |         self.final_conv = nn.Conv1d(
353 |             in_channels=encode_channels,
354 |             out_channels=code_dim,
355 |             kernel_size=kernel_size,
356 |             stride=1,
357 |             padding=(kernel_size - 1) // 2,
358 |             bias=False
359 |         )
360 | 
361 |     def forward(self, x):
362 |         """
363 |         Forward pass.
364 | 
365 |         Args:
366 |             x (Tensor): Input tensor of shape (Batch, Input_channels, Length)
367 | 
368 |         Returns:
369 |             Tensor: Encoded tensor of shape (Batch, Code_dim, Length)
370 |         """
371 |         x = self.initial_conv(x)         # (Batch, Encode_channels, Length)
372 |         x = self.residual_blocks(x) + x  # Residual connection
373 |         x = self.final_conv(x)           # (Batch, Code_dim, Length)
374 |         return x
375 | 
376 | class SemanticDecoder(nn.Module):
377 |     def __init__(
378 |         self,
379 |         code_dim: int,
380 |         output_channels: int,
381 |         decode_channels: int,
382 |         kernel_size: int = 3,
383 |         bias: bool = True,
384 |     ):
385 |         super(SemanticDecoder, self).__init__()
386 | 
387 |         # Initial convolution to map code_dim to decode_channels
388 |         self.initial_conv = nn.Conv1d(
389 |             in_channels=code_dim,
390 |             out_channels=decode_channels,
391 |             kernel_size=kernel_size,
392 |             stride=1,
393 |             padding=(kernel_size - 1) // 2,
394 |             bias=False
395 |         )
396 | 
397 |         # Residual Blocks
398 |         self.residual_blocks = nn.Sequential(
399 |             nn.ReLU(inplace=True),
400 |             nn.Conv1d(decode_channels, decode_channels,
kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias), 401 | nn.ReLU(inplace=True), 402 | nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias) 403 | ) 404 | 405 | # Final convolution to map decode_channels to output_channels 406 | self.final_conv = nn.Conv1d( 407 | in_channels=decode_channels, 408 | out_channels=output_channels, 409 | kernel_size=kernel_size, 410 | stride=1, 411 | padding=(kernel_size - 1) // 2, 412 | bias=False 413 | ) 414 | 415 | def forward(self, z): 416 | # z: (Batch, Code_dim, Length) 417 | x = self.initial_conv(z) # (Batch, Decode_channels, Length) 418 | x = self.residual_blocks(x) + x # Residual connection 419 | x = self.final_conv(x) # (Batch, Output_channels, Length) 420 | return x -------------------------------------------------------------------------------- /vq/residual_vq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from .factorized_vector_quantize import FactorizedVectorQuantize 5 | 6 | class ResidualVQ(nn.Module): 7 | def __init__( 8 | self, 9 | *, 10 | num_quantizers, 11 | codebook_size, 12 | **kwargs 13 | ): 14 | super().__init__() 15 | VQ = FactorizedVectorQuantize 16 | if type(codebook_size) == int: 17 | codebook_size = [codebook_size] * num_quantizers 18 | self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size]) 19 | self.num_quantizers = num_quantizers 20 | 21 | def forward(self, x): 22 | quantized_out = 0. 23 | residual = x 24 | 25 | all_losses = [] 26 | all_indices = [] 27 | 28 | for idx, layer in enumerate(self.layers): 29 | quantized, indices, loss = layer(residual) 30 | 31 | residual = residual - quantized 32 | 33 | quantized_out = quantized_out + quantized 34 | 35 | loss = loss.mean() 36 | 37 | all_indices.append(indices) 38 | all_losses.append(loss) 39 | all_losses, all_indices = map(torch.stack, (all_losses, all_indices)) 40 | return quantized_out, all_indices, all_losses 41 | 42 | def vq2emb(self, vq, proj=True): 43 | # [B, T, num_quantizers] 44 | quantized_out = 0. 
45 |         for idx, layer in enumerate(self.layers):
46 |             quantized = layer.vq2emb(vq[:, :, idx], proj=proj)
47 |             quantized_out = quantized_out + quantized
48 |         return quantized_out
49 |     def get_emb(self):
50 |         embs = []
51 |         for idx, layer in enumerate(self.layers):
52 |             embs.append(layer.get_emb())
53 |         return embs
54 | 
--------------------------------------------------------------------------------
/vq/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | import numpy as np
6 | 
7 | 
8 | class EncoderBlock(nn.Module):
9 |     def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
10 |         super(EncoderBlock, self).__init__()
11 | 
12 |         self.pool_size = 2
13 | 
14 |         self.conv_block = ConvBlock(in_channels, out_channels, kernel_size)
15 | 
16 |     def forward(self, x):
17 |         latent = self.conv_block(x)
18 |         output = F.avg_pool2d(latent, kernel_size=self.pool_size)
19 |         return output, latent
20 | 
21 | class DecoderBlock(nn.Module):
22 |     def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
23 |         super(DecoderBlock, self).__init__()
24 | 
25 |         stride = 2
26 | 
27 |         self.upsample = nn.ConvTranspose2d(
28 |             in_channels=in_channels,
29 |             out_channels=in_channels,
30 |             kernel_size=stride,
31 |             stride=stride,
32 |             padding=(0, 0),
33 |             bias=False,
34 |         )
35 | 
36 |         self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size)
37 | 
38 |     def forward(self, x, latent):
39 |         x = self.upsample(x)
40 |         x = torch.cat((x, latent), dim=1)
41 |         output = self.conv_block(x)
42 |         return output
43 | 
44 | 
45 | class UNet(nn.Module):
46 |     def __init__(self, freq_dim=1281, out_channel=1024):
47 |         super(UNet, self).__init__()
48 | 
49 |         self.downsample_ratio = 16
50 | 
51 | 
52 |         in_channels = 1  # self.audio_channels * self.cmplx_num
53 | 
54 |         self.encoder_block1 = EncoderBlock(in_channels, 16)
55 |         self.encoder_block2 = EncoderBlock(16, 64)
56 |         self.encoder_block3 = EncoderBlock(64, 256)
57 |         self.encoder_block4 = EncoderBlock(256, 1024)
58 |         self.middle = EncoderBlock(1024, 1024)
59 |         self.decoder_block1 = DecoderBlock(1024, 256)
60 |         self.decoder_block2 = DecoderBlock(256, 64)
61 |         self.decoder_block3 = DecoderBlock(64, 16)
62 |         self.decoder_block4 = DecoderBlock(16, 16)
63 | 
64 |         self.fc = nn.Linear(freq_dim * 16, out_channel)
65 | 
66 |     def forward(self, x_ori):
67 |         """
68 |         Args:
69 |             x_ori: (batch_size, channels_num, time_steps, freq_bins) input spectrogram feature
70 | 
71 |         Returns:
72 |             output: (batch_size, time_steps, out_channel) frame-level embeddings
73 |         """
74 | 
75 | 
76 |         x = self.process_image(x_ori)
77 |         x1, latent1 = self.encoder_block1(x)
78 |         x2, latent2 = self.encoder_block2(x1)
79 |         x3, latent3 = self.encoder_block3(x2)
80 |         x4, latent4 = self.encoder_block4(x3)
81 |         _, h = self.middle(x4)
82 |         x5 = self.decoder_block1(h, latent4)
83 |         x6 = self.decoder_block2(x5, latent3)
84 |         x7 = self.decoder_block3(x6, latent2)
85 |         x8 = self.decoder_block4(x7, latent1)
86 |         x = self.unprocess_image(x8, x_ori.shape[2])
87 |         x = x.permute(0, 2, 1, 3).contiguous()  # (B, C, T, F) -> (B, T, C, F)
88 |         x = x.view(x.size(0), x.size(1), -1)
89 |         x = self.fc(x)
90 | 
91 |         return x
92 | 
93 |     def process_image(self, x):
94 |         """
95 |         Pad the spectrogram so that its time axis is divisible by downsample_ratio.
96 | 
97 |         Args:
98 |             x: (B, C, T, F)
99 | 
100 |         Returns:
101 |             output: (B, C, T_padded, F_reduced)
102 |         """
103 | 
104 |         B, C, T, Freq = x.shape
105 | 
106 |         pad_len = (
107 |             int(np.ceil(T / self.downsample_ratio)) * self.downsample_ratio
108 |             - T
109 |         )
110 |         x = F.pad(x, pad=(0, 0, 0, pad_len))
111 | 
112 |         output = x[:, :, :, 0 : Freq - 1]
113 | 
114 |         return output
115 | 
116 |     def unprocess_image(self, x, time_steps):
117 |         """
118 |         Restore the spectrogram to its original shape.
119 | 
120 |         Args:
121 |             x: (B, C, T_padded, F_reduced)
122 | 
123 |         Returns:
124 |             output: (B, C, T_original, F_original)
125 |         """
126 |         x = F.pad(x, pad=(0, 1))
127 | 
128 |         output = x[:, :, 0:time_steps, :]
129 | 
130 |         return output
131 | 
132 | class ConvBlock(nn.Module):
133 |     def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
134 |         super(ConvBlock, self).__init__()
135 | 
136 |         padding = [kernel_size[0] // 2, kernel_size[1] // 2]
137 | 
138 |         self.bn1 = nn.BatchNorm2d(in_channels)
139 |         self.bn2 = nn.BatchNorm2d(out_channels)
140 | 
141 |         self.conv1 = nn.Conv2d(
142 |             in_channels=in_channels,
143 |             out_channels=out_channels,
144 |             kernel_size=kernel_size,
145 |             padding=padding,
146 |             bias=False,
147 |         )
148 | 
149 |         self.conv2 = nn.Conv2d(
150 |             in_channels=out_channels,
151 |             out_channels=out_channels,
152 |             kernel_size=kernel_size,
153 |             padding=padding,
154 |             bias=False,
155 |         )
156 | 
157 |         if in_channels != out_channels:
158 |             self.shortcut = nn.Conv2d(
159 |                 in_channels=in_channels,
160 |                 out_channels=out_channels,
161 |                 kernel_size=(1, 1),
162 |                 padding=(0, 0),
163 |             )
164 |             self.is_shortcut = True
165 |         else:
166 |             self.is_shortcut = False
167 | 
168 |     def forward(self, x):
169 |         h = self.conv1(F.leaky_relu_(self.bn1(x)))
170 |         h = self.conv2(F.leaky_relu_(self.bn2(h)))
171 | 
172 |         if self.is_shortcut:
173 |             return self.shortcut(x) + h
174 |         else:
175 |             return x + h
176 | 
177 | 
178 | def test_unet():
179 |     # Define the input parameters
180 |     batch_size = 6
181 |     channels = 1      # number of audio channels
182 |     time_steps = 256  # number of time steps
183 |     freq_bins = 1281  # number of frequency bins; must match UNet's default freq_dim
184 | 
185 |     # Create a random input tensor (the complex branch below is commented out, so the input stays real-valued)
186 |     real_part = torch.randn(batch_size, channels, time_steps, freq_bins)
187 |     imag_part = torch.randn(batch_size, channels, time_steps, freq_bins)
188 |     complex_sp = real_part  # torch.complex(real_part, imag_part)
189 | 
190 |     # Instantiate the UNet model
191 |     model = UNet()
192 | 
193 |     # Forward pass
194 |     output = model(complex_sp)
195 | 
196 |     # Print the input and output shapes
197 |     print("Input shape:", complex_sp.shape)
198 |     print("Output shape:", output.shape)
199 | 
200 |     # Check that the output has the expected frame-level embedding shape
201 |     assert output.shape == (batch_size, time_steps, 1024), "Unexpected output shape"
202 | 
203 |     # Check that the forward pass produced finite values
204 |     assert torch.isfinite(output).all(), "Output contains non-finite values"
205 | 
206 |     print("Test passed; the model works as expected.")
207 | 
208 | # Run the test function
209 | if __name__ == "__main__":
210 |     test_unet()
--------------------------------------------------------------------------------
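For orientation, the ResidualVQ module in vq/residual_vq.py applies its quantizers as a cascade: each stage quantizes the residual left over by the previous stages, and the per-stage quantized outputs are summed to approximate the input. The sketch below is a toy illustration of that loop only; the nearest_codeword helper, the random codebooks, and the sizes are hypothetical stand-ins for FactorizedVectorQuantize, not code from this repository.

import torch

def nearest_codeword(x, codebook):
    # x: (B, T, D), codebook: (N, D); pick the closest codeword per frame.
    dist = torch.cdist(x, codebook.unsqueeze(0).expand(x.size(0), -1, -1))
    idx = dist.argmin(-1)               # (B, T)
    return codebook[idx], idx           # (B, T, D), (B, T)

torch.manual_seed(0)
B, T, D, num_quantizers, codebook_size = 1, 4, 8, 3, 64
x = torch.randn(B, T, D)
codebooks = [torch.randn(codebook_size, D) for _ in range(num_quantizers)]

quantized_out = torch.zeros_like(x)
residual = x
all_indices = []
for codebook in codebooks:
    quantized, indices = nearest_codeword(residual, codebook)
    residual = residual - quantized             # later stages model what is left over
    quantized_out = quantized_out + quantized   # running sum approximates x
    all_indices.append(indices)

print(torch.stack(all_indices).shape)           # (num_quantizers, B, T), as in ResidualVQ.forward
print((x - quantized_out).pow(2).mean())        # residual error after all stages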