├── LLaSE.png ├── README.md ├── ckpt ├── codec_ckpt │ └── hub │ │ ├── .locks │ │ └── models--facebook--w2v-bert-2.0 │ │ │ ├── 5db61951cdf5edab6337fd84ee619500c27aaa3d.lock │ │ │ ├── a383a594dac18459628cd2837168cd276342a31a.lock │ │ │ └── eb890c9660ed6e3414b6812e27257b8ce5454365d5490d3ad581ea60b93be043.lock │ │ ├── models--facebook--w2v-bert-2.0 │ │ ├── blobs │ │ │ ├── 5db61951cdf5edab6337fd84ee619500c27aaa3d │ │ │ └── a383a594dac18459628cd2837168cd276342a31a │ │ ├── refs │ │ │ └── main │ │ └── snapshots │ │ │ └── da985ba0987f70aaeb84a80f2851cfac8c697a7b │ │ │ ├── config.json │ │ │ └── preprocessor_config.json │ │ └── version.txt ├── download.sh └── download_ckpt.py ├── config └── test.yml ├── inference.py ├── inference.sh ├── loader ├── __pycache__ │ └── datareader_fe.cpython-310.pyc └── datareader_fe.py ├── nnet ├── WavLM.py ├── __pycache__ │ ├── WavLM.cpython-310.pyc │ ├── embedding.cpython-310.pyc │ ├── llama.cpython-310.pyc │ └── modules.cpython-310.pyc ├── embedding.py ├── llama.py └── modules.py ├── requirements.txt └── vq ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── __init__.cpython-311.pyc ├── __init__.cpython-312.pyc ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── __init__.cpython-39.pyc ├── activations.cpython-310.pyc ├── activations.cpython-311.pyc ├── activations.cpython-312.pyc ├── activations.cpython-37.pyc ├── activations.cpython-38.pyc ├── activations.cpython-39.pyc ├── blocks.cpython-310.pyc ├── blocks.cpython-39.pyc ├── bs_roformer5.cpython-310.pyc ├── bs_roformer5.cpython-37.pyc ├── bs_roformer5.cpython-38.pyc ├── bs_roformer5.cpython-39.pyc ├── codec_decoder.cpython-310.pyc ├── codec_decoder.cpython-311.pyc ├── codec_decoder.cpython-312.pyc ├── codec_decoder.cpython-39.pyc ├── codec_decoder_vocos.cpython-310.pyc ├── codec_decoder_vocos.cpython-311.pyc ├── codec_decoder_vocos.cpython-312.pyc ├── codec_decoder_vocos.cpython-39.pyc ├── codec_encoder.cpython-310.pyc ├── codec_encoder.cpython-311.pyc ├── codec_encoder.cpython-312.pyc ├── codec_encoder.cpython-37.pyc ├── codec_encoder.cpython-38.pyc ├── codec_encoder.cpython-39.pyc ├── factorized_vector_quantize.cpython-310.pyc ├── factorized_vector_quantize.cpython-311.pyc ├── factorized_vector_quantize.cpython-312.pyc ├── factorized_vector_quantize.cpython-39.pyc ├── module.cpython-310.pyc ├── module.cpython-311.pyc ├── module.cpython-312.pyc ├── module.cpython-37.pyc ├── module.cpython-38.pyc ├── module.cpython-39.pyc ├── residual_vq.cpython-310.pyc ├── residual_vq.cpython-311.pyc ├── residual_vq.cpython-312.pyc ├── residual_vq.cpython-39.pyc ├── unet.cpython-312.pyc └── unet.cpython-39.pyc ├── activations.py ├── alias_free_torch ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── __init__.cpython-312.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── act.cpython-310.pyc │ ├── act.cpython-311.pyc │ ├── act.cpython-312.pyc │ ├── act.cpython-37.pyc │ ├── act.cpython-38.pyc │ ├── act.cpython-39.pyc │ ├── filter.cpython-310.pyc │ ├── filter.cpython-311.pyc │ ├── filter.cpython-312.pyc │ ├── filter.cpython-37.pyc │ ├── filter.cpython-38.pyc │ ├── filter.cpython-39.pyc │ ├── resample.cpython-310.pyc │ ├── resample.cpython-311.pyc │ ├── resample.cpython-312.pyc │ ├── resample.cpython-37.pyc │ ├── resample.cpython-38.pyc │ └── resample.cpython-39.pyc ├── act.py ├── filter.py └── resample.py ├── blocks.py ├── bs_roformer5.py ├── codec_decoder.py ├── codec_decoder_vocos.py ├── codec_encoder.py ├── 
factorized_vector_quantize.py ├── module.py ├── residual_vq.py └── unet.py /LLaSE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/LLaSE.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **LLaSE: Maximizing Acoustic Preservation for LLaMA-based Speech Enhancement** 2 | 3 | Boyi Kang\*¹, Xinfa Zhu\*¹, Zihan Zhang¹, Zhen Ye², Ziqian Wang¹, Wei Xue², Lei Xie¹ 4 | ¹ **Audio, Speech and Language Processing Group (ASLP@NPU)**, 5 | School of Computer Science, Northwestern Polytechnical University, Xi’an, China 6 | ² **The Hong Kong University of Science and Technology** 7 | 8 | --- 9 | ## News 10 | LLaSE-G1 has been released. 11 | 12 | GitHub: https://github.com/Kevin-naticl/LLaSE-G1 13 | 14 | Huggingface: https://huggingface.co/ASLP-lab/LLaSE-G1 15 | 16 | Paper: https://arxiv.org/abs/2503.00493 17 | 18 | ## Abstract 19 | Language Models (LMs) have shown strong semantic understanding and contextual modeling capabilities, which have recently flourished in generative speech enhancement. However, most LM-based speech enhancement approaches focus on semantic information while ignoring the vital role of acoustic information, which leads to acoustic inconsistency after enhancement, including variations in speaker timbre and intonation. This paper proposes LLaSE, a LLaMA-based language model for Speech Enhancement. To address the challenge of acoustic inconsistency, LLaSE takes continuous representations from WavLM as input and predicts speech tokens from XCodec2, a recently released efficient codec, maximizing acoustic preservation. Experimental results demonstrate that LLaSE achieves state-of-the-art performance on speech enhancement, offering a robust and scalable solution for speech denoising and quality improvement. 20 | 21 | ## Huggingface 22 | 23 | Our checkpoint is available [here](https://huggingface.co/BeauKang01/LLaSE). 
24 | 25 | ![Overall Architecture of LLaSE](LLaSE.png) 26 | 27 | ## DNSMOS results on DNS Challenge testset 28 | 29 | | Model | Type | Testset | SIG | BAK | OVRL | 30 | |-------------|---------------|------------------|---------|---------|---------| 31 | | Unprocessed | - | syn_with_reverb | 1.76 | 1.50 | 1.39 | 32 | | | | syn_no_reverb | 3.39 | 2.62 | 2.48 | 33 | | | | real_recording | 3.05 | 2.51 | 2.26 | 34 | | Conv-TasNet | Discriminative | syn_with_reverb | 2.42 | 2.71 | 2.01 | 35 | | | | syn_no_reverb | 3.09 | 3.34 | 3.00 | 36 | | | | real_recording | 3.10 | 2.98 | 2.41 | 37 | | DEMUCS | Discriminative | syn_with_reverb | 2.86 | 3.90 | 2.55 | 38 | | | | syn_no_reverb | 3.58 | 4.15 | 3.35 | 39 | | | | real_recording | 3.26 | 4.03 | 2.99 | 40 | | FRCRN | Discriminative | syn_with_reverb | 2.93 | 2.92 | 2.28 | 41 | | | | syn_no_reverb | 3.58 | 4.13 | 3.34 | 42 | | | | real_recording | 3.37 | 3.98 | 3.04 | 43 | | SELM | Generative | syn_with_reverb | 3.16 | 3.58 | 2.70 | 44 | | | | syn_no_reverb | 3.51 | 4.10 | 3.26 | 45 | | | | real_recording | **3.59**| 3.44 | 3.12 | 46 | | MaskSR | Generative | syn_with_reverb | 3.53 | 4.07 | 3.25 | 47 | | | | syn_no_reverb | 3.59 | 4.12 | 3.34 | 48 | | | | real_recording | 3.43 | 4.03 | 3.14 | 49 | | GENSE | Generative | syn_with_reverb | 3.49 | 3.73 | 3.19 | 50 | | | | syn_no_reverb | **3.65**| **4.18**| **3.43**| 51 | | | | real_recording | - | - | - | 52 | | LLaSE | Generative | syn_with_reverb | **3.59**| **4.10**| **3.33**| 53 | | | | syn_no_reverb | **3.65**| 4.17 | **3.43**| 54 | | | | real_recording | 3.50 | **4.10**| **3.24**| 55 | 56 | ## Usage 57 | 58 | ### 1. Clone the Repo 59 | ```bash 60 | git clone https://github.com/Kevin-naticl/LLaSE.git 61 | cd LLaSE 62 | ``` 63 | 64 | ### 2. Install Requirements 65 | ```bash 66 | conda create -n LLaSE python=3.10 67 | conda activate LLaSE 68 | pip install -r requirements.txt 69 | ``` 70 | 71 | ### 3. Download the Checkpoint from Hugging Face 72 | You can use the provided shell script to download the checkpoint or manually download it from [Hugging Face](https://huggingface.co/). 73 | 74 | ```bash 75 | cd ckpt 76 | bash download.sh 77 | ``` 78 | 79 | ### 4. Inference 80 | 1. Provide the file list in `./config/test.yml`. 81 | 2. Run the inference script: 82 | 83 | ```bash 84 | bash inference.sh 85 | ``` 86 | 87 | The processed `.wav` files will be saved in `./decode/wav` by default (16k sample rate). 88 | 89 | --- 90 | 91 | ### Future Updates 92 | - A Python module will be available in the future. 
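### File list format

The file list referenced by `datareader.filename` in `./config/test.yml` (Usage step 4.1 above) is parsed by `loader/datareader_fe.py`: each line gives the path to a noisy input `.wav` file, optionally followed by a duration field. The paths below are hypothetical placeholders; the data reader resamples each file to 16 kHz, keeps the first channel, and peak-normalizes it before feature extraction.

```
/data/noisy/utt_0001.wav
/data/noisy/utt_0002.wav 6.2
```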
93 | -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/.locks/models--facebook--w2v-bert-2.0/5db61951cdf5edab6337fd84ee619500c27aaa3d.lock: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/ckpt/codec_ckpt/hub/.locks/models--facebook--w2v-bert-2.0/5db61951cdf5edab6337fd84ee619500c27aaa3d.lock -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/.locks/models--facebook--w2v-bert-2.0/a383a594dac18459628cd2837168cd276342a31a.lock: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/ckpt/codec_ckpt/hub/.locks/models--facebook--w2v-bert-2.0/a383a594dac18459628cd2837168cd276342a31a.lock -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/.locks/models--facebook--w2v-bert-2.0/eb890c9660ed6e3414b6812e27257b8ce5454365d5490d3ad581ea60b93be043.lock: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/ckpt/codec_ckpt/hub/.locks/models--facebook--w2v-bert-2.0/eb890c9660ed6e3414b6812e27257b8ce5454365d5490d3ad581ea60b93be043.lock -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/blobs/5db61951cdf5edab6337fd84ee619500c27aaa3d: -------------------------------------------------------------------------------- 1 | { 2 | "feature_extractor_type": "SeamlessM4TFeatureExtractor", 3 | "feature_size": 80, 4 | "num_mel_bins": 80, 5 | "padding_side": "right", 6 | "padding_value": 1, 7 | "processor_class": "Wav2Vec2BertProcessor", 8 | "return_attention_mask": true, 9 | "sampling_rate": 16000, 10 | "stride": 2 11 | } 12 | -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/blobs/a383a594dac18459628cd2837168cd276342a31a: -------------------------------------------------------------------------------- 1 | { 2 | "activation_dropout": 0.0, 3 | "adapter_act": "relu", 4 | "adapter_kernel_size": 3, 5 | "adapter_stride": 2, 6 | "add_adapter": false, 7 | "apply_spec_augment": false, 8 | "architectures": [ 9 | "Wav2Vec2BertModel" 10 | ], 11 | "attention_dropout": 0.0, 12 | "bos_token_id": 1, 13 | "classifier_proj_size": 768, 14 | "codevector_dim": 768, 15 | "conformer_conv_dropout": 0.1, 16 | "contrastive_logits_temperature": 0.1, 17 | "conv_depthwise_kernel_size": 31, 18 | "ctc_loss_reduction": "sum", 19 | "ctc_zero_infinity": false, 20 | "diversity_loss_weight": 0.1, 21 | "eos_token_id": 2, 22 | "feat_proj_dropout": 0.0, 23 | "feat_quantizer_dropout": 0.0, 24 | "feature_projection_input_dim": 160, 25 | "final_dropout": 0.1, 26 | "hidden_act": "swish", 27 | "hidden_dropout": 0.0, 28 | "hidden_size": 1024, 29 | "initializer_range": 0.02, 30 | "intermediate_size": 4096, 31 | "layer_norm_eps": 1e-05, 32 | "layerdrop": 0.1, 33 | "left_max_position_embeddings": 64, 34 | "mask_feature_length": 10, 35 | "mask_feature_min_masks": 0, 36 | "mask_feature_prob": 0.0, 37 | "mask_time_length": 10, 38 | "mask_time_min_masks": 2, 39 | "mask_time_prob": 0.05, 40 | "max_source_positions": 5000, 41 | "model_type": "wav2vec2-bert", 42 | "num_adapter_layers": 1, 43 | 
"num_attention_heads": 16, 44 | "num_codevector_groups": 2, 45 | "num_codevectors_per_group": 320, 46 | "num_hidden_layers": 24, 47 | "num_negatives": 100, 48 | "output_hidden_size": 1024, 49 | "pad_token_id": 0, 50 | "position_embeddings_type": "relative_key", 51 | "proj_codevector_dim": 768, 52 | "right_max_position_embeddings": 8, 53 | "rotary_embedding_base": 10000, 54 | "tdnn_dilation": [ 55 | 1, 56 | 2, 57 | 3, 58 | 1, 59 | 1 60 | ], 61 | "tdnn_dim": [ 62 | 512, 63 | 512, 64 | 512, 65 | 512, 66 | 1500 67 | ], 68 | "tdnn_kernel": [ 69 | 5, 70 | 3, 71 | 3, 72 | 1, 73 | 1 74 | ], 75 | "torch_dtype": "float32", 76 | "transformers_version": "4.37.0.dev0", 77 | "use_intermediate_ffn_before_adapter": false, 78 | "use_weighted_layer_sum": false, 79 | "vocab_size": null, 80 | "xvector_output_dim": 512 81 | } 82 | -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/refs/main: -------------------------------------------------------------------------------- 1 | da985ba0987f70aaeb84a80f2851cfac8c697a7b -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_dropout": 0.0, 3 | "adapter_act": "relu", 4 | "adapter_kernel_size": 3, 5 | "adapter_stride": 2, 6 | "add_adapter": false, 7 | "apply_spec_augment": false, 8 | "architectures": [ 9 | "Wav2Vec2BertModel" 10 | ], 11 | "attention_dropout": 0.0, 12 | "bos_token_id": 1, 13 | "classifier_proj_size": 768, 14 | "codevector_dim": 768, 15 | "conformer_conv_dropout": 0.1, 16 | "contrastive_logits_temperature": 0.1, 17 | "conv_depthwise_kernel_size": 31, 18 | "ctc_loss_reduction": "sum", 19 | "ctc_zero_infinity": false, 20 | "diversity_loss_weight": 0.1, 21 | "eos_token_id": 2, 22 | "feat_proj_dropout": 0.0, 23 | "feat_quantizer_dropout": 0.0, 24 | "feature_projection_input_dim": 160, 25 | "final_dropout": 0.1, 26 | "hidden_act": "swish", 27 | "hidden_dropout": 0.0, 28 | "hidden_size": 1024, 29 | "initializer_range": 0.02, 30 | "intermediate_size": 4096, 31 | "layer_norm_eps": 1e-05, 32 | "layerdrop": 0.1, 33 | "left_max_position_embeddings": 64, 34 | "mask_feature_length": 10, 35 | "mask_feature_min_masks": 0, 36 | "mask_feature_prob": 0.0, 37 | "mask_time_length": 10, 38 | "mask_time_min_masks": 2, 39 | "mask_time_prob": 0.05, 40 | "max_source_positions": 5000, 41 | "model_type": "wav2vec2-bert", 42 | "num_adapter_layers": 1, 43 | "num_attention_heads": 16, 44 | "num_codevector_groups": 2, 45 | "num_codevectors_per_group": 320, 46 | "num_hidden_layers": 24, 47 | "num_negatives": 100, 48 | "output_hidden_size": 1024, 49 | "pad_token_id": 0, 50 | "position_embeddings_type": "relative_key", 51 | "proj_codevector_dim": 768, 52 | "right_max_position_embeddings": 8, 53 | "rotary_embedding_base": 10000, 54 | "tdnn_dilation": [ 55 | 1, 56 | 2, 57 | 3, 58 | 1, 59 | 1 60 | ], 61 | "tdnn_dim": [ 62 | 512, 63 | 512, 64 | 512, 65 | 512, 66 | 1500 67 | ], 68 | "tdnn_kernel": [ 69 | 5, 70 | 3, 71 | 3, 72 | 1, 73 | 1 74 | ], 75 | "torch_dtype": "float32", 76 | "transformers_version": "4.37.0.dev0", 77 | "use_intermediate_ffn_before_adapter": false, 78 | "use_weighted_layer_sum": false, 79 | "vocab_size": null, 80 | "xvector_output_dim": 512 81 | } 82 | -------------------------------------------------------------------------------- 
/ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "feature_extractor_type": "SeamlessM4TFeatureExtractor", 3 | "feature_size": 80, 4 | "num_mel_bins": 80, 5 | "padding_side": "right", 6 | "padding_value": 1, 7 | "processor_class": "Wav2Vec2BertProcessor", 8 | "return_attention_mask": true, 9 | "sampling_rate": 16000, 10 | "stride": 2 11 | } 12 | -------------------------------------------------------------------------------- /ckpt/codec_ckpt/hub/version.txt: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /ckpt/download.sh: -------------------------------------------------------------------------------- 1 | python download_ckpt.py \ 2 | --source hf \ 3 | --repo_id microsoft/wavlm-large \ 4 | --filename pytorch_model.bin \ 5 | --save_path ./WavLM-Large.pt 6 | 7 | python download_ckpt.py \ 8 | --source hf \ 9 | --repo_id facebook/w2v-bert-2.0 \ 10 | --filename model.safetensors \ 11 | --save_path \ 12 | ./codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b/model.safetensors 13 | 14 | python download_ckpt.py \ 15 | --source hf \ 16 | --repo_id HKUSTAudio/xcodec2 \ 17 | --filename ckpt/epoch=4-step=1400000.ckpt \ 18 | --save_path ./codec_ckpt/epoch=4-step=1400000.ckpt 19 | 20 | python download_ckpt.py \ 21 | --source hf \ 22 | --repo_id BeauKang01/LLaSE \ 23 | --filename best.pt.tar \ 24 | --save_path ./best.pt.tar -------------------------------------------------------------------------------- /ckpt/download_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import argparse 4 | from huggingface_hub import hf_hub_download 5 | from tqdm import tqdm 6 | 7 | def download_from_url(url, save_path): 8 | """Download a file from a given URL and save it locally.""" 9 | response = requests.get(url, stream=True) 10 | total_size = int(response.headers.get("content-length", 0)) 11 | block_size = 1024 # 1 KB 12 | progress_bar = tqdm(total=total_size, unit="B", unit_scale=True) 13 | 14 | with open(save_path, "wb") as file: 15 | for data in response.iter_content(block_size): 16 | progress_bar.update(len(data)) 17 | file.write(data) 18 | progress_bar.close() 19 | 20 | if total_size != 0 and progress_bar.n != total_size: 21 | print("Download failed!") 22 | else: 23 | print(f"File downloaded to: {save_path}") 24 | 25 | def download_from_hf(repo_id, filename, save_path): 26 | """Download a file from Hugging Face Hub.""" 27 | print(f"Downloading from Hugging Face Hub: {repo_id}/{filename}") 28 | try: 29 | hf_hub_download(repo_id=repo_id, filename=filename, local_dir=os.path.dirname(save_path), local_dir_use_symlinks=False) 30 | print(f"File downloaded to: {save_path}") 31 | except Exception as e: 32 | print(f"Download failed: {e}") 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser(description="Automatically download model checkpoints") 36 | parser.add_argument("--source", type=str, required=True, choices=["hf", "url"], help="Download source: hf (Hugging Face Hub) or url (custom URL)") 37 | parser.add_argument("--repo_id", type=str, help="Hugging Face model repository ID (e.g., google/bert-base-uncased)") 38 | parser.add_argument("--filename", type=str, help="Filename in the Hugging Face repository") 
39 | parser.add_argument("--url", type=str, help="Custom download URL") 40 | parser.add_argument("--save_path", type=str, required=True, help="Path to save the file (including filename)") 41 | args = parser.parse_args() 42 | 43 | # Ensure the save directory exists 44 | os.makedirs(os.path.dirname(args.save_path), exist_ok=True) 45 | 46 | if args.source == "hf": 47 | if not args.repo_id or not args.filename: 48 | print("Please provide a Hugging Face repository ID and filename!") 49 | return 50 | download_from_hf(args.repo_id, args.filename, args.save_path) 51 | elif args.source == "url": 52 | if not args.url: 53 | print("Please provide a download URL!") 54 | return 55 | download_from_url(args.url, args.save_path) 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /config/test.yml: -------------------------------------------------------------------------------- 1 | test: 2 | checkpoint: ./ckpt/best.pt.tar 3 | use_cuda: True 4 | if_chunk: True # if chunk or not 5 | chunk_seconds: 10 #chunk when inference 6 | overlap_seconds: 2 #overlap seconds 7 | 8 | save: 9 | feat_dir: ./decode/feat 10 | wav_dir: ./decode/wav 11 | 12 | # llama config 13 | nnet_conf: 14 | d_model: 1024 15 | nhead: 16 16 | num_layers: 12 17 | 18 | datareader: 19 | sample_rate: 16000 20 | filename: /path/to/your/filelist -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import torch 8 | import torch as th 9 | import torch.nn as nn 10 | import soundfile as sf 11 | import torch.distributed as dist 12 | import torch.nn.functional as F 13 | from torch.nn.parallel import DistributedDataParallel 14 | import yaml 15 | from tqdm import tqdm 16 | import argparse 17 | from transformers import AutoFeatureExtractor, Wav2Vec2BertModel 18 | from collections import OrderedDict 19 | import numpy as np 20 | import os 21 | 22 | sys.path.append(os.path.dirname(__file__)) 23 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 24 | 25 | # for WavLM 26 | from nnet.WavLM import WavLM, WavLMConfig 27 | 28 | # for encodec 29 | from vq.codec_encoder import CodecEncoder_Transformer 30 | from vq.codec_decoder_vocos import CodecDecoderVocos 31 | from vq.module import SemanticEncoder 32 | 33 | # Simple Datareader 34 | from loader.datareader_fe import DataReader 35 | 36 | # llama 37 | from nnet.llama import LLM_AR as model 38 | 39 | class Encodec(): 40 | ''' 41 | load Xcodec2 42 | ''' 43 | def __init__(self,device="cpu") -> None: 44 | self.device=device 45 | ckpt = './ckpt/codec_ckpt/epoch=4-step=1400000.ckpt' 46 | ckpt = torch.load(ckpt, map_location='cpu') 47 | state_dict = ckpt['state_dict'] 48 | filtered_state_dict_codec = OrderedDict() 49 | filtered_state_dict_semantic_encoder = OrderedDict() 50 | filtered_state_dict_gen = OrderedDict() 51 | filtered_state_dict_fc_post_a = OrderedDict() 52 | filtered_state_dict_fc_prior = OrderedDict() 53 | for key, value in state_dict.items(): 54 | if key.startswith('CodecEnc.'): 55 | new_key = key[len('CodecEnc.'):] 56 | filtered_state_dict_codec[new_key] = value 57 | elif key.startswith('generator.'): 58 | new_key = key[len('generator.'):] 59 | filtered_state_dict_gen[new_key] = value 60 | elif key.startswith('fc_post_a.'): 61 | new_key = key[len('fc_post_a.'):] 62 | 
filtered_state_dict_fc_post_a[new_key] = value 63 | elif key.startswith('SemanticEncoder_module.'): 64 | new_key = key[len('SemanticEncoder_module.'):] 65 | filtered_state_dict_semantic_encoder[new_key] = value 66 | elif key.startswith('fc_prior.'): 67 | new_key = key[len('fc_prior.'):] 68 | filtered_state_dict_fc_prior[new_key] = value 69 | 70 | self.semantic_model = Wav2Vec2BertModel.from_pretrained( 71 | "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b", 72 | output_hidden_states=True) 73 | self.semantic_model=self.semantic_model.eval().to(self.device) 74 | 75 | self.SemanticEncoder_module = SemanticEncoder(1024,1024,1024) 76 | self.SemanticEncoder_module.load_state_dict(filtered_state_dict_semantic_encoder) 77 | self.SemanticEncoder_module = self.SemanticEncoder_module.eval().to(self.device) 78 | 79 | self.encoder = CodecEncoder_Transformer() 80 | self.encoder.load_state_dict(filtered_state_dict_codec) 81 | self.encoder = self.encoder.eval().to(self.device) 82 | 83 | self.decoder = CodecDecoderVocos() 84 | self.decoder.load_state_dict(filtered_state_dict_gen) 85 | self.decoder = self.decoder.eval().to(self.device) 86 | 87 | self.fc_post_a = nn.Linear( 2048, 1024 ) 88 | self.fc_post_a.load_state_dict(filtered_state_dict_fc_post_a) 89 | self.fc_post_a = self.fc_post_a.eval().to(self.device) 90 | 91 | self.fc_prior = nn.Linear( 2048, 2048 ) 92 | self.fc_prior.load_state_dict(filtered_state_dict_fc_prior) 93 | self.fc_prior = self.fc_prior.eval().to(self.device) 94 | 95 | self.feature_extractor = AutoFeatureExtractor.from_pretrained( 96 | "./ckpt/codec_ckpt/hub/models--facebook--w2v-bert-2.0/snapshots/da985ba0987f70aaeb84a80f2851cfac8c697a7b") 97 | 98 | 99 | def get_feat(self, wav_batch, pad=None): 100 | 101 | if len(wav_batch.shape) != 2: 102 | return self.feature_extractor(F.pad(wav_batch, pad), sampling_rate=16000, return_tensors="pt") .data['input_features'] 103 | 104 | padded_wavs = torch.stack([F.pad(wav, pad) for wav in wav_batch]) 105 | batch_feats = [] 106 | 107 | for wav in padded_wavs: 108 | feat = self.feature_extractor( 109 | wav, 110 | sampling_rate=16000, 111 | return_tensors="pt" 112 | ).data['input_features'] 113 | 114 | batch_feats.append(feat) 115 | feat_batch = torch.concat(batch_feats, dim=0).to(self.device) 116 | return feat_batch 117 | 118 | def get_embedding(self, wav_cpu): 119 | wav_cpu = wav_cpu.cpu() 120 | feat = self.get_feat(wav_cpu,pad=(160,160)) 121 | feat = feat.to(self.device) 122 | 123 | if(len(wav_cpu.shape)==1): 124 | wav = wav_cpu.unsqueeze(0).to(self.device) 125 | else: 126 | wav = wav_cpu.to(self.device) 127 | 128 | wav = torch.nn.functional.pad(wav, (0, (200 - (wav.shape[1] % 200)))) 129 | with torch.no_grad(): 130 | vq_emb = self.encoder(wav.unsqueeze(1)) 131 | vq_emb = vq_emb.transpose(1, 2) 132 | 133 | if vq_emb.shape[2]!=feat.shape[1]: 134 | feat = self.get_feat(wav_cpu) 135 | feat = feat.to(self.device) 136 | 137 | semantic_target = self.semantic_model(feat[:, :,:]) 138 | semantic_target = semantic_target.hidden_states[16] 139 | semantic_target = semantic_target.transpose(1, 2) 140 | semantic_target = self.SemanticEncoder_module(semantic_target) 141 | 142 | vq_emb = torch.cat([semantic_target, vq_emb], dim=1) 143 | # vq_emb = self.fc_prior(vq_emb.transpose(1, 2)).transpose(1, 2) 144 | 145 | return vq_emb 146 | 147 | def emb2token(self, emb): 148 | emb.to(self.device) 149 | emb = self.fc_prior(emb.transpose(1, 2)).transpose(1, 2) 150 | _, vq_code, _ = self.decoder(emb, vq=True) 151 | return 
vq_code 152 | 153 | def token2wav(self, vq_code): 154 | vq_code.to(self.device) 155 | vq_post_emb = self.decoder.quantizer.get_output_from_indices(vq_code.transpose(1, 2)) 156 | vq_post_emb = vq_post_emb.transpose(1, 2) 157 | vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1,2)).transpose(1,2) 158 | recon = self.decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze() 159 | # if write the wav, add .squeeze().detach().cpu().numpy() 160 | # if need gradient use the config right now 161 | return recon 162 | 163 | class WavLM_feat(object): 164 | ''' 165 | reload pretrained wavlm and extract audio feature 166 | ''' 167 | 168 | def __init__(self, device): 169 | self.wavlm = self._reload_wavLM_large(device=device) 170 | self.wavlm.eval() 171 | 172 | def __call__(self, wav): 173 | T = wav.shape[-1] 174 | wav = wav.reshape(-1, T) 175 | with torch.no_grad(): 176 | feat = self.wavlm.extract_features(wav, output_layer=6, ret_layer_results=False)[0] 177 | # B x T x 768(1024) -> B*T x 768(1024) 178 | B, T, D = feat.shape 179 | feat = torch.reshape(feat, (-1, D)) 180 | 181 | return feat 182 | 183 | def _reload_wavLM_large(self, path="./ckpt/WavLM-Large.pt", device: Optional[torch.device] = None): 184 | cpt = torch.load(path, map_location="cpu") 185 | cfg = WavLMConfig(cpt['cfg']) 186 | wavLM = WavLM(cfg) 187 | wavLM.load_state_dict(cpt['model']) 188 | wavLM.eval() 189 | if device != None: 190 | wavLM = wavLM.to(device) 191 | for p in wavLM.parameters(): 192 | p.requires_grad = False 193 | print('successful to reload wavLM', path) 194 | return wavLM 195 | 196 | def load_obj(obj, device): 197 | ''' 198 | Offload tensor object in obj to cuda device 199 | ''' 200 | def cuda(obj): 201 | return obj.to(device) if isinstance(obj, th.Tensor) else obj 202 | 203 | if isinstance(obj, dict): 204 | return {key: load_obj(obj[key], device) for key in obj} 205 | elif isinstance(obj, list): 206 | return [load_obj(val, device) for val in obj] 207 | else: 208 | return cuda(obj) 209 | 210 | def run(args): 211 | LOCAL_RANK = int(os.environ['LOCAL_RANK']) 212 | WORLD_SIZE = int(os.environ['WORLD_SIZE']) 213 | WORLD_RANK = int(os.environ['RANK']) 214 | dist.init_process_group(args.backend, rank=WORLD_RANK, world_size=WORLD_SIZE) 215 | torch.cuda.set_device(LOCAL_RANK) 216 | 217 | device = torch.device('cuda', LOCAL_RANK) 218 | print(f"[{os.getpid()}] using device: {device}", torch.cuda.current_device(), "local rank", LOCAL_RANK) 219 | 220 | with open(args.conf, "r") as f: 221 | conf = yaml.load(f, Loader=yaml.FullLoader) 222 | 223 | data_reader = DataReader(**conf["datareader"]) 224 | 225 | # Encodec and WavLM 226 | codec = Encodec(device) 227 | wavlm_feat = WavLM_feat(device) 228 | 229 | nnet = model(**conf["nnet_conf"]) 230 | cpt_fname = Path(conf["test"]["checkpoint"]) 231 | cpt = th.load(cpt_fname, map_location="cpu") 232 | 233 | nnet = nnet.to(device) 234 | nnet = DistributedDataParallel(nnet, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, find_unused_parameters=True) 235 | nnet.load_state_dict(cpt["model_state_dict"]) 236 | nnet.eval() 237 | 238 | if not os.path.exists(conf["save"]["wav_dir"]): 239 | os.makedirs(conf["save"]["wav_dir"]) 240 | if not os.path.exists(conf["save"]["feat_dir"]): 241 | os.makedirs(conf["save"]["feat_dir"]) 242 | 243 | # inference is by chunk 244 | chunk_seconds = conf["test"]["chunk_seconds"] 245 | overlap_seconds = conf["test"]["overlap_seconds"] 246 | 247 | # Feature Extraction 248 | if_chunk = conf["test"]["if_chunk"] 249 | 250 | with th.no_grad(): 251 | for egs in 
tqdm(data_reader, desc="Feature Extraction"): 252 | egs = load_obj(egs, device) 253 | audio = egs["mix"].contiguous() 254 | 255 | if if_chunk: 256 | total_samples = audio.shape[-1] 257 | feat_list = [] 258 | 259 | chunk_size=16000 * chunk_seconds 260 | overlap_size = 16000 * overlap_seconds 261 | 262 | for start in range(0, total_samples, chunk_size): 263 | left = max(0, start - overlap_size) 264 | right = min(start + chunk_size + overlap_size, total_samples) 265 | 266 | left_overlap = (start - left) 267 | right_overlap = (right - (start + chunk_size)) 268 | 269 | chunk = audio[:, left:right] 270 | 271 | # too short to process 272 | if total_samples - start < 400: 273 | break 274 | 275 | feat_chunk = wavlm_feat(chunk) # (1, seq_len, feat_dim) 276 | if len(feat_chunk.shape)!=2: 277 | continue 278 | zeros_row = torch.zeros((1, 1024)).to(device) 279 | feat_chunk = torch.concat((feat_chunk, zeros_row), dim = 0) 280 | 281 | if right_overlap <= 0: 282 | feat_chunk = feat_chunk[left_overlap//320:, :] 283 | else: 284 | feat_chunk = feat_chunk[left_overlap//320: -right_overlap//320, :] 285 | 286 | feat_chunk = feat_chunk.detach().squeeze(0).cpu().numpy() 287 | feat_list.append(feat_chunk) 288 | 289 | del chunk, feat_chunk, zeros_row 290 | th.cuda.empty_cache() 291 | 292 | full_feat = np.concatenate(feat_list, axis=0) 293 | 294 | del audio, feat_list 295 | 296 | else: 297 | full_feat = wavlm_feat(audio) 298 | zeros_row = torch.zeros((1, 1024)).to(device) 299 | full_feat = torch.concat((full_feat, zeros_row), dim = 0).detach().squeeze(0).cpu().numpy() 300 | 301 | np.save(os.path.join(conf["save"]["feat_dir"], egs["name"]), full_feat) 302 | 303 | del full_feat 304 | th.cuda.empty_cache() 305 | 306 | with th.no_grad(): 307 | for egs in tqdm(data_reader, desc="Audio Generation"): 308 | feat_path = os.path.join(conf["save"]["feat_dir"], egs["name"] + ".npy") 309 | full_feat = np.load(feat_path) 310 | total_frames = full_feat.shape[0] 311 | 312 | if if_chunk: 313 | 314 | recon_list = [] 315 | chunk_step = chunk_seconds * 50 316 | overlap_step = overlap_seconds * 50 317 | 318 | for start in range(0, total_frames, chunk_step): 319 | 320 | left = max(0, start - overlap_step) 321 | right = min(start + chunk_step + overlap_step, total_frames) 322 | 323 | left_overlap = (start - left) 324 | right_overlap = (right - (start + chunk_step)) 325 | 326 | feat_chunk = th.from_numpy(full_feat[left:right, :]).unsqueeze(0) 327 | feat_chunk = feat_chunk.to(device) 328 | 329 | est = nnet(feat_chunk) 330 | max_indices = th.argmax(est, dim=1) 331 | 332 | recon_chunk = codec.token2wav(max_indices.unsqueeze(0)) 333 | 334 | if right_overlap <= 0: 335 | recon_chunk = recon_chunk[left_overlap//50 * 16000 :] 336 | else: 337 | recon_chunk = recon_chunk[left_overlap//50 * 16000 : - right_overlap//50 * 16000] 338 | 339 | recon_chunk = recon_chunk.squeeze().detach().cpu().numpy() 340 | recon_list.append(recon_chunk) 341 | 342 | del feat_chunk, est, max_indices, recon_chunk 343 | th.cuda.empty_cache() 344 | 345 | full_recon = np.concatenate(recon_list) 346 | del recon_list 347 | 348 | else: 349 | est = nnet(th.from_numpy(full_feat).unsqueeze(0)) 350 | max_indices = th.argmax(est, dim=1) 351 | full_recon = codec.token2wav(max_indices.unsqueeze(0)).squeeze().detach().cpu().numpy() 352 | 353 | sf.write( 354 | os.path.join(conf["save"]["wav_dir"], egs["name"] + ".wav"), 355 | full_recon, 356 | 16000 357 | ) 358 | 359 | del full_feat,full_recon 360 | th.cuda.empty_cache() 361 | 362 | 363 | if __name__ == "__main__": 364 | parser = 
argparse.ArgumentParser( 365 | description = "Command to test separation model in Pytorch", 366 | formatter_class = argparse.ArgumentDefaultsHelpFormatter) 367 | parser.add_argument("-conf", 368 | type=str, 369 | required=True, 370 | help="Yaml configuration file for training") 371 | parser.add_argument("--backend", 372 | type=str, 373 | default="nccl", 374 | choices=["nccl", "gloo"]) 375 | args = parser.parse_args() 376 | 377 | os.environ["NCCL_DEBUG"] = "INFO" 378 | run(args) -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 torchrun \ 2 | --nnodes=1 \ 3 | --nproc_per_node=1 \ 4 | --master_port=21547 \ 5 | inference.py \ 6 | -conf ./config/test.yml -------------------------------------------------------------------------------- /loader/__pycache__/datareader_fe.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/loader/__pycache__/datareader_fe.cpython-310.pyc -------------------------------------------------------------------------------- /loader/datareader_fe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchaudio 3 | import torch 4 | 5 | def get_firstchannel_read(path, fs=16000): 6 | wave_data, sr = torchaudio.load(path) 7 | if sr != fs: 8 | wave_data = torchaudio.functional.resample(wave_data, sr, fs) 9 | if len(wave_data.shape) > 1: 10 | wave_data = wave_data[0,...] 11 | wave_data = wave_data.cpu().numpy() 12 | return wave_data 13 | 14 | def parse_scp(scp, path_list): 15 | with open(scp) as fid: 16 | for line in fid: 17 | tmp = line.strip().split() 18 | if len(tmp) > 1: 19 | path_list.append({"inputs": tmp[0], "duration": tmp[1]}) 20 | else: 21 | path_list.append({"inputs": tmp[0]}) 22 | 23 | class DataReader(object): 24 | def __init__(self, filename, sample_rate): 25 | self.file_list = [] 26 | self.sample_rate = sample_rate 27 | parse_scp(filename, self.file_list) 28 | 29 | def extract_feature(self, path): 30 | path = path["inputs"] 31 | name = path.split("/")[-1].split(".")[0] 32 | data = get_firstchannel_read(path, fs=self.sample_rate).astype(np.float32) 33 | max_norm = np.max(np.abs(data)) 34 | if max_norm == 0: 35 | max_norm = 1 36 | data = data / max_norm 37 | inputs = np.reshape(data, [1, data.shape[0]]) 38 | inputs = torch.from_numpy(inputs) 39 | 40 | egs = { 41 | "mix": inputs, 42 | "max_norm": max_norm, 43 | "name": name 44 | } 45 | return egs 46 | 47 | def __len__(self): 48 | return len(self.file_list) 49 | 50 | def __getitem__(self, index): 51 | return self.extract_feature(self.file_list[index]) 52 | 53 | def get_utt2spk(self, path): 54 | lines = open(path, "r").readlines() 55 | for line in lines: 56 | line = line.strip().split() 57 | utt_path, spk_id = line[0], line[1] 58 | self.utt2spk[utt_path] = spk_id 59 | 60 | def get_spk2utt(self, path): 61 | lines = open(path, "r").readlines() 62 | for line in lines: 63 | line = line.strip().split() 64 | utt_path, spk_id = line[0], line[1] 65 | self.spk2aux[spk_id] = utt_path 66 | -------------------------------------------------------------------------------- /nnet/WavLM.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # WavLM: Large-Scale Self-Supervised 
Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import logging 12 | from typing import List, Optional, Tuple 13 | 14 | import sys,os 15 | sys.path.append(os.path.dirname(sys.path[0])) 16 | import numpy as np 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn import LayerNorm 22 | from nnet.modules import ( 23 | Fp32GroupNorm, 24 | Fp32LayerNorm, 25 | GradMultiply, 26 | MultiheadAttention, 27 | SamePad, 28 | init_bert_params, 29 | get_activation_fn, 30 | TransposeLast, 31 | GLU_Linear, 32 | ) 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def compute_mask_indices( 38 | shape: Tuple[int, int], 39 | padding_mask: Optional[torch.Tensor], 40 | mask_prob: float, 41 | mask_length: int, 42 | mask_type: str = "static", 43 | mask_other: float = 0.0, 44 | min_masks: int = 0, 45 | no_overlap: bool = False, 46 | min_space: int = 0, 47 | ) -> np.ndarray: 48 | """ 49 | Computes random mask spans for a given shape 50 | 51 | Args: 52 | shape: the the shape for which to compute masks. 53 | should be of size 2 where first element is batch size and 2nd is timesteps 54 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements 55 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by 56 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 57 | however due to overlaps, the actual number will be smaller (unless no_overlap is True) 58 | mask_type: how to compute mask lengths 59 | static = fixed size 60 | uniform = sample from uniform distribution [mask_other, mask_length*2] 61 | normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element 62 | poisson = sample from possion distribution with lambda = mask length 63 | min_masks: minimum number of masked spans 64 | no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping 65 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans 66 | """ 67 | 68 | bsz, all_sz = shape 69 | mask = np.full((bsz, all_sz), False) 70 | 71 | all_num_mask = int( 72 | # add a random number for probabilistic rounding 73 | mask_prob * all_sz / float(mask_length) 74 | + np.random.rand() 75 | ) 76 | 77 | all_num_mask = max(min_masks, all_num_mask) 78 | 79 | mask_idcs = [] 80 | for i in range(bsz): 81 | if padding_mask is not None: 82 | sz = all_sz - padding_mask[i].long().sum().item() 83 | num_mask = int( 84 | # add a random number for probabilistic rounding 85 | mask_prob * sz / float(mask_length) 86 | + np.random.rand() 87 | ) 88 | num_mask = max(min_masks, num_mask) 89 | else: 90 | sz = all_sz 91 | num_mask = all_num_mask 92 | 93 | if mask_type == "static": 94 | lengths = np.full(num_mask, mask_length) 95 | elif mask_type == "uniform": 96 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) 97 | elif mask_type == "normal": 98 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 99 | lengths = [max(1, int(round(x))) for x in lengths] 100 | elif mask_type == "poisson": 101 | lengths = np.random.poisson(mask_length, size=num_mask) 102 | lengths = [int(round(x)) for x in lengths] 103 | else: 104 | raise Exception("unknown mask selection " + mask_type) 105 | 106 | if sum(lengths) == 0: 107 | lengths[0] = min(mask_length, sz - 1) 108 | 109 | if no_overlap: 110 | mask_idc = [] 111 | 112 | def arrange(s, e, length, keep_length): 113 | span_start = np.random.randint(s, e - length) 114 | mask_idc.extend(span_start + i for i in range(length)) 115 | 116 | new_parts = [] 117 | if span_start - s - min_space >= keep_length: 118 | new_parts.append((s, span_start - min_space + 1)) 119 | if e - span_start - keep_length - min_space > keep_length: 120 | new_parts.append((span_start + length + min_space, e)) 121 | return new_parts 122 | 123 | parts = [(0, sz)] 124 | min_length = min(lengths) 125 | for length in sorted(lengths, reverse=True): 126 | lens = np.fromiter( 127 | (e - s if e - s >= length + min_space else 0 for s, e in parts), 128 | np.int, 129 | ) 130 | l_sum = np.sum(lens) 131 | if l_sum == 0: 132 | break 133 | probs = lens / np.sum(lens) 134 | c = np.random.choice(len(parts), p=probs) 135 | s, e = parts.pop(c) 136 | parts.extend(arrange(s, e, length, min_length)) 137 | mask_idc = np.asarray(mask_idc) 138 | else: 139 | min_len = min(lengths) 140 | if sz - min_len <= num_mask: 141 | min_len = sz - num_mask - 1 142 | 143 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 144 | 145 | mask_idc = np.asarray( 146 | [ 147 | mask_idc[j] + offset 148 | for j in range(len(mask_idc)) 149 | for offset in range(lengths[j]) 150 | ] 151 | ) 152 | 153 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 154 | 155 | min_len = min([len(m) for m in mask_idcs]) 156 | for i, mask_idc in enumerate(mask_idcs): 157 | if len(mask_idc) > min_len: 158 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 159 | mask[i, mask_idc] = True 160 | 161 | return mask 162 | 163 | 164 | class WavLMConfig: 165 | def __init__(self, cfg=None): 166 | self.extractor_mode: str = "default" # mode for feature extractor. 
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) 167 | self.encoder_layers: int = 12 # num encoder layers in the transformer 168 | 169 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 170 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 171 | self.encoder_attention_heads: int = 12 # num encoder attention heads 172 | self.activation_fn: str = "gelu" # activation function to use 173 | 174 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 175 | self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] 176 | self.conv_bias: bool = False # include bias in conv encoder 177 | self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this 178 | 179 | self.normalize: bool = False # normalize input to have 0 mean and unit variance during training 180 | 181 | # dropouts 182 | self.dropout: float = 0.1 # dropout probability for the transformer 183 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 184 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 185 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 186 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 187 | self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) 188 | 189 | # masking 190 | self.mask_length: int = 10 # mask length 191 | self.mask_prob: float = 0.65 # probability of replacing a token with mask 192 | self.mask_selection: str = "static" # how to choose mask length 193 | self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh 194 | self.no_mask_overlap: bool = False # whether to allow masks to overlap 195 | self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) 196 | 197 | # channel masking 198 | self.mask_channel_length: int = 10 # length of the mask for features (channels) 199 | self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 200 | self.mask_channel_selection: str = "static" # how to choose mask length for channel masking 201 | self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices 202 | self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap 203 | self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) 204 | 205 | # positional embeddings 206 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings 207 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding 208 | 209 | # relative position embedding 210 | self.relative_position_embedding: bool = False # apply relative position embedding 211 | self.num_buckets: int = 320 # number of buckets for relative position embedding 212 | self.max_distance: int = 1280 # maximum distance for relative position embedding 213 | self.gru_rel_pos: bool = False # apply gated relative position embedding 214 | 215 | if cfg is not None: 216 | self.update(cfg) 217 | 218 | def update(self, cfg: dict): 219 | self.__dict__.update(cfg) 220 | 221 | 222 | class 
WavLM(nn.Module): 223 | def __init__( 224 | self, 225 | cfg: WavLMConfig, 226 | ) -> None: 227 | super().__init__() 228 | logger.info(f"WavLM Config: {cfg.__dict__}") 229 | 230 | self.cfg = cfg 231 | feature_enc_layers = eval(cfg.conv_feature_layers) 232 | self.embed = feature_enc_layers[-1][0] 233 | 234 | self.feature_extractor = ConvFeatureExtractionModel( 235 | conv_layers=feature_enc_layers, 236 | dropout=0.0, 237 | mode=cfg.extractor_mode, 238 | conv_bias=cfg.conv_bias, 239 | ) 240 | 241 | self.post_extract_proj = ( 242 | nn.Linear(self.embed, cfg.encoder_embed_dim) 243 | if self.embed != cfg.encoder_embed_dim 244 | else None 245 | ) 246 | 247 | self.mask_prob = cfg.mask_prob 248 | self.mask_selection = cfg.mask_selection 249 | self.mask_other = cfg.mask_other 250 | self.mask_length = cfg.mask_length 251 | self.no_mask_overlap = cfg.no_mask_overlap 252 | self.mask_min_space = cfg.mask_min_space 253 | 254 | self.mask_channel_prob = cfg.mask_channel_prob 255 | self.mask_channel_selection = cfg.mask_channel_selection 256 | self.mask_channel_other = cfg.mask_channel_other 257 | self.mask_channel_length = cfg.mask_channel_length 258 | self.no_mask_channel_overlap = cfg.no_mask_channel_overlap 259 | self.mask_channel_min_space = cfg.mask_channel_min_space 260 | 261 | self.dropout_input = nn.Dropout(cfg.dropout_input) 262 | self.dropout_features = nn.Dropout(cfg.dropout_features) 263 | 264 | self.feature_grad_mult = cfg.feature_grad_mult 265 | 266 | self.mask_emb = nn.Parameter( 267 | torch.FloatTensor(cfg.encoder_embed_dim).uniform_() 268 | ) 269 | 270 | self.encoder = TransformerEncoder(cfg) 271 | self.layer_norm = LayerNorm(self.embed) 272 | 273 | def apply_mask(self, x, padding_mask): 274 | B, T, C = x.shape 275 | if self.mask_prob > 0: 276 | mask_indices = compute_mask_indices( 277 | (B, T), 278 | padding_mask, 279 | self.mask_prob, 280 | self.mask_length, 281 | self.mask_selection, 282 | self.mask_other, 283 | min_masks=2, 284 | no_overlap=self.no_mask_overlap, 285 | min_space=self.mask_min_space, 286 | ) 287 | mask_indices = torch.from_numpy(mask_indices).to(x.device) 288 | x[mask_indices] = self.mask_emb 289 | else: 290 | mask_indices = None 291 | 292 | if self.mask_channel_prob > 0: 293 | mask_channel_indices = compute_mask_indices( 294 | (B, C), 295 | None, 296 | self.mask_channel_prob, 297 | self.mask_channel_length, 298 | self.mask_channel_selection, 299 | self.mask_channel_other, 300 | no_overlap=self.no_mask_channel_overlap, 301 | min_space=self.mask_channel_min_space, 302 | ) 303 | mask_channel_indices = ( 304 | torch.from_numpy(mask_channel_indices) 305 | .to(x.device) 306 | .unsqueeze(1) 307 | .expand(-1, T, -1) 308 | ) 309 | x[mask_channel_indices] = 0 310 | 311 | return x, mask_indices 312 | 313 | def forward_padding_mask( 314 | self, features: torch.Tensor, padding_mask: torch.Tensor, 315 | ) -> torch.Tensor: 316 | extra = padding_mask.size(1) % features.size(1) 317 | if extra > 0: 318 | padding_mask = padding_mask[:, :-extra] 319 | padding_mask = padding_mask.view( 320 | padding_mask.size(0), features.size(1), -1 321 | ) 322 | padding_mask = padding_mask.all(-1) 323 | return padding_mask 324 | 325 | def extract_features( 326 | self, 327 | source: torch.Tensor, 328 | padding_mask: Optional[torch.Tensor] = None, 329 | mask: bool = False, 330 | ret_conv: bool = False, 331 | output_layer: Optional[int] = None, 332 | ret_layer_results: bool = False, 333 | ): 334 | 335 | if self.feature_grad_mult > 0: 336 | features = self.feature_extractor(source) 337 | if 
self.feature_grad_mult != 1.0: 338 | features = GradMultiply.apply(features, self.feature_grad_mult) 339 | else: 340 | with torch.no_grad(): 341 | features = self.feature_extractor(source) 342 | 343 | features = features.transpose(1, 2) 344 | features = self.layer_norm(features) 345 | 346 | if padding_mask is not None: 347 | padding_mask = self.forward_padding_mask(features, padding_mask) 348 | 349 | if self.post_extract_proj is not None: 350 | features = self.post_extract_proj(features) 351 | 352 | features = self.dropout_input(features) 353 | 354 | if mask: 355 | x, mask_indices = self.apply_mask( 356 | features, padding_mask 357 | ) 358 | else: 359 | x = features 360 | 361 | # feature: (B, T, D), float 362 | # target: (B, T), long 363 | # x: (B, T, D), float 364 | # padding_mask: (B, T), bool 365 | # mask_indices: (B, T), bool 366 | x, layer_results = self.encoder( 367 | x, 368 | padding_mask=padding_mask, 369 | layer=None if output_layer is None else output_layer - 1 370 | ) 371 | 372 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} 373 | 374 | feature = res["features"] if ret_conv else res["x"] 375 | if ret_layer_results: 376 | feature = (feature, res["layer_results"]) 377 | return feature, res["padding_mask"] 378 | 379 | 380 | def long_term_modeling( 381 | self, 382 | source: torch.Tensor, 383 | padding_mask: Optional[torch.Tensor] = None, 384 | mask: bool = False, 385 | ret_conv: bool = False, 386 | output_layer: Optional[int] = None, 387 | ret_layer_results: bool = False, 388 | ): 389 | 390 | features = source.transpose(1, 2) 391 | features = self.layer_norm(features) 392 | 393 | if padding_mask is not None: 394 | padding_mask = self.forward_padding_mask(features, padding_mask) 395 | 396 | if self.post_extract_proj is not None: 397 | features = self.post_extract_proj(features) 398 | 399 | features = self.dropout_input(features) 400 | 401 | if mask: 402 | x, mask_indices = self.apply_mask( 403 | features, padding_mask 404 | ) 405 | else: 406 | x = features 407 | 408 | # feature: (B, T, D), float 409 | # target: (B, T), long 410 | # x: (B, T, D), float 411 | # padding_mask: (B, T), bool 412 | # mask_indices: (B, T), bool 413 | x, layer_results = self.encoder( 414 | x, 415 | padding_mask=padding_mask, 416 | layer=None if output_layer is None else output_layer - 1 417 | ) 418 | 419 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} 420 | 421 | feature = res["features"] if ret_conv else res["x"] 422 | if ret_layer_results: 423 | feature = (feature, res["layer_results"]) 424 | return feature, res["padding_mask"] 425 | 426 | 427 | 428 | class ConvFeatureExtractionModel(nn.Module): 429 | def __init__( 430 | self, 431 | conv_layers: List[Tuple[int, int, int]], 432 | dropout: float = 0.0, 433 | mode: str = "default", 434 | conv_bias: bool = False, 435 | conv_type: str = "default" 436 | ): 437 | super().__init__() 438 | 439 | assert mode in {"default", "layer_norm"} 440 | 441 | def block( 442 | n_in, 443 | n_out, 444 | k, 445 | stride, 446 | is_layer_norm=False, 447 | is_group_norm=False, 448 | conv_bias=False, 449 | ): 450 | def make_conv(): 451 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) 452 | nn.init.kaiming_normal_(conv.weight) 453 | return conv 454 | 455 | assert ( 456 | is_layer_norm and is_group_norm 457 | ) == False, "layer norm and group norm are exclusive" 458 | 459 | if is_layer_norm: 460 | return nn.Sequential( 461 | make_conv(), 462 | nn.Dropout(p=dropout), 
463 | nn.Sequential( 464 | TransposeLast(), 465 | Fp32LayerNorm(dim, elementwise_affine=True), 466 | TransposeLast(), 467 | ), 468 | nn.GELU(), 469 | ) 470 | elif is_group_norm: 471 | return nn.Sequential( 472 | make_conv(), 473 | nn.Dropout(p=dropout), 474 | Fp32GroupNorm(dim, dim, affine=True), 475 | nn.GELU(), 476 | ) 477 | else: 478 | return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) 479 | 480 | self.conv_type = conv_type 481 | if self.conv_type == "default": 482 | in_d = 1 483 | self.conv_layers = nn.ModuleList() 484 | for i, cl in enumerate(conv_layers): 485 | assert len(cl) == 3, "invalid conv definition: " + str(cl) 486 | (dim, k, stride) = cl 487 | 488 | self.conv_layers.append( 489 | block( 490 | in_d, 491 | dim, 492 | k, 493 | stride, 494 | is_layer_norm=mode == "layer_norm", 495 | is_group_norm=mode == "default" and i == 0, 496 | conv_bias=conv_bias, 497 | ) 498 | ) 499 | in_d = dim 500 | elif self.conv_type == "conv2d": 501 | in_d = 1 502 | self.conv_layers = nn.ModuleList() 503 | for i, cl in enumerate(conv_layers): 504 | assert len(cl) == 3 505 | (dim, k, stride) = cl 506 | 507 | self.conv_layers.append( 508 | torch.nn.Conv2d(in_d, dim, k, stride) 509 | ) 510 | self.conv_layers.append(torch.nn.ReLU()) 511 | in_d = dim 512 | elif self.conv_type == "custom": 513 | in_d = 1 514 | idim = 80 515 | self.conv_layers = nn.ModuleList() 516 | for i, cl in enumerate(conv_layers): 517 | assert len(cl) == 3 518 | (dim, k, stride) = cl 519 | self.conv_layers.append( 520 | torch.nn.Conv2d(in_d, dim, k, stride, padding=1) 521 | ) 522 | self.conv_layers.append( 523 | torch.nn.LayerNorm([dim, idim]) 524 | ) 525 | self.conv_layers.append(torch.nn.ReLU()) 526 | in_d = dim 527 | if (i + 1) % 2 == 0: 528 | self.conv_layers.append( 529 | torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) 530 | ) 531 | idim = int(math.ceil(idim / 2)) 532 | else: 533 | pass 534 | 535 | def forward(self, x, mask=None): 536 | 537 | # BxT -> BxCxT 538 | x = x.unsqueeze(1) 539 | if self.conv_type == "custom": 540 | for conv in self.conv_layers: 541 | if isinstance(conv, nn.LayerNorm): 542 | x = x.transpose(1, 2) 543 | x = conv(x).transpose(1, 2) 544 | else: 545 | x = conv(x) 546 | x = x.transpose(2, 3).contiguous() 547 | x = x.view(x.size(0), -1, x.size(-1)) 548 | else: 549 | for conv in self.conv_layers: 550 | x = conv(x) 551 | if self.conv_type == "conv2d": 552 | b, c, t, f = x.size() 553 | x = x.transpose(2, 3).contiguous().view(b, c * f, t) 554 | return x 555 | 556 | 557 | class TransformerEncoder(nn.Module): 558 | def __init__(self, args): 559 | super().__init__() 560 | 561 | self.dropout = args.dropout 562 | self.embedding_dim = args.encoder_embed_dim 563 | 564 | self.pos_conv = nn.Conv1d( 565 | self.embedding_dim, 566 | self.embedding_dim, 567 | kernel_size=args.conv_pos, 568 | padding=args.conv_pos // 2, 569 | groups=args.conv_pos_groups, 570 | ) 571 | dropout = 0 572 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) 573 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std) 574 | nn.init.constant_(self.pos_conv.bias, 0) 575 | 576 | self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) 577 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) 578 | 579 | if hasattr(args, "relative_position_embedding"): 580 | self.relative_position_embedding = args.relative_position_embedding 581 | self.num_buckets = args.num_buckets 582 | self.max_distance = args.max_distance 583 | else: 584 | self.relative_position_embedding = False 
585 | self.num_buckets = 0 586 | self.max_distance = 0 587 | 588 | self.layers = nn.ModuleList( 589 | [ 590 | TransformerSentenceEncoderLayer( 591 | embedding_dim=self.embedding_dim, 592 | ffn_embedding_dim=args.encoder_ffn_embed_dim, 593 | num_attention_heads=args.encoder_attention_heads, 594 | dropout=self.dropout, 595 | attention_dropout=args.attention_dropout, 596 | activation_dropout=args.activation_dropout, 597 | activation_fn=args.activation_fn, 598 | layer_norm_first=args.layer_norm_first, 599 | has_relative_attention_bias=(self.relative_position_embedding and i == 0), 600 | num_buckets=self.num_buckets, 601 | max_distance=self.max_distance, 602 | gru_rel_pos=args.gru_rel_pos, 603 | ) 604 | for i in range(args.encoder_layers) 605 | ] 606 | ) 607 | 608 | self.layer_norm_first = args.layer_norm_first 609 | self.layer_norm = LayerNorm(self.embedding_dim) 610 | self.layerdrop = args.encoder_layerdrop 611 | 612 | self.apply(init_bert_params) 613 | 614 | def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): 615 | x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) 616 | 617 | if self.layer_norm_first and layer is None: 618 | x = self.layer_norm(x) 619 | 620 | return x, layer_results 621 | 622 | def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): 623 | 624 | if padding_mask is not None: 625 | x[padding_mask] = 0 626 | 627 | y = x.transpose(1, 2).clone() 628 | x_conv = self.pos_conv(y) 629 | x_conv = x_conv.transpose(1, 2) 630 | x += x_conv 631 | 632 | if not self.layer_norm_first: 633 | x = self.layer_norm(x) 634 | 635 | x = F.dropout(x, p=self.dropout, training=self.training) 636 | 637 | # B x T x C -> T x B x C 638 | x = x.transpose(0, 1) 639 | 640 | layer_results = [] 641 | z = None 642 | if tgt_layer is not None: 643 | layer_results.append((x, z)) 644 | r = None 645 | pos_bias = None 646 | for i, layer in enumerate(self.layers): 647 | dropout_probability = np.random.random() 648 | if not self.training or (dropout_probability > self.layerdrop): 649 | x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, 650 | self_attn_mask=streaming_mask, pos_bias=pos_bias) 651 | if tgt_layer is not None: 652 | layer_results.append((x, z)) 653 | if i == tgt_layer: 654 | r = x 655 | break 656 | 657 | if r is not None: 658 | x = r 659 | 660 | # T x B x C -> B x T x C 661 | x = x.transpose(0, 1) 662 | 663 | return x, layer_results 664 | 665 | 666 | class TransformerSentenceEncoderLayer(nn.Module): 667 | """ 668 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 669 | models. 
670 | """ 671 | 672 | def __init__( 673 | self, 674 | embedding_dim: float = 768, 675 | ffn_embedding_dim: float = 3072, 676 | num_attention_heads: float = 8, 677 | dropout: float = 0.1, 678 | attention_dropout: float = 0.1, 679 | activation_dropout: float = 0.1, 680 | activation_fn: str = "relu", 681 | layer_norm_first: bool = False, 682 | has_relative_attention_bias: bool = False, 683 | num_buckets: int = 0, 684 | max_distance: int = 0, 685 | rescale_init: bool = False, 686 | gru_rel_pos: bool = False, 687 | ) -> None: 688 | 689 | super().__init__() 690 | # Initialize parameters 691 | self.embedding_dim = embedding_dim 692 | self.dropout = dropout 693 | self.activation_dropout = activation_dropout 694 | 695 | # Initialize blocks 696 | self.activation_name = activation_fn 697 | self.activation_fn = get_activation_fn(activation_fn) 698 | self.self_attn = MultiheadAttention( 699 | self.embedding_dim, 700 | num_attention_heads, 701 | dropout=attention_dropout, 702 | self_attention=True, 703 | has_relative_attention_bias=has_relative_attention_bias, 704 | num_buckets=num_buckets, 705 | max_distance=max_distance, 706 | rescale_init=rescale_init, 707 | gru_rel_pos=gru_rel_pos, 708 | ) 709 | 710 | self.dropout1 = nn.Dropout(dropout) 711 | self.dropout2 = nn.Dropout(self.activation_dropout) 712 | self.dropout3 = nn.Dropout(dropout) 713 | 714 | self.layer_norm_first = layer_norm_first 715 | 716 | # layer norm associated with the self attention layer 717 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim) 718 | 719 | if self.activation_name == "glu": 720 | self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") 721 | else: 722 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 723 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 724 | 725 | # layer norm associated with the position wise feed-forward NN 726 | self.final_layer_norm = LayerNorm(self.embedding_dim) 727 | 728 | def forward( 729 | self, 730 | x: torch.Tensor, 731 | self_attn_mask: torch.Tensor = None, 732 | self_attn_padding_mask: torch.Tensor = None, 733 | need_weights: bool = False, 734 | pos_bias=None 735 | ): 736 | """ 737 | LayerNorm is applied either before or after the self-attention/ffn 738 | modules similar to the original Transformer imlementation. 
739 | """ 740 | residual = x 741 | 742 | if self.layer_norm_first: 743 | x = self.self_attn_layer_norm(x) 744 | x, attn, pos_bias = self.self_attn( 745 | query=x, 746 | key=x, 747 | value=x, 748 | key_padding_mask=self_attn_padding_mask, 749 | need_weights=False, 750 | attn_mask=self_attn_mask, 751 | position_bias=pos_bias 752 | ) 753 | x = self.dropout1(x) 754 | x = residual + x 755 | 756 | residual = x 757 | x = self.final_layer_norm(x) 758 | if self.activation_name == "glu": 759 | x = self.fc1(x) 760 | else: 761 | x = self.activation_fn(self.fc1(x)) 762 | x = self.dropout2(x) 763 | x = self.fc2(x) 764 | x = self.dropout3(x) 765 | x = residual + x 766 | else: 767 | x, attn, pos_bias = self.self_attn( 768 | query=x, 769 | key=x, 770 | value=x, 771 | key_padding_mask=self_attn_padding_mask, 772 | need_weights=need_weights, 773 | attn_mask=self_attn_mask, 774 | position_bias=pos_bias 775 | ) 776 | 777 | x = self.dropout1(x) 778 | x = residual + x 779 | 780 | x = self.self_attn_layer_norm(x) 781 | 782 | residual = x 783 | if self.activation_name == "glu": 784 | x = self.fc1(x) 785 | else: 786 | x = self.activation_fn(self.fc1(x)) 787 | x = self.dropout2(x) 788 | x = self.fc2(x) 789 | x = self.dropout3(x) 790 | x = residual + x 791 | x = self.final_layer_norm(x) 792 | 793 | return x, attn, pos_bias 794 | -------------------------------------------------------------------------------- /nnet/__pycache__/WavLM.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/nnet/__pycache__/WavLM.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/embedding.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/nnet/__pycache__/embedding.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/llama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/nnet/__pycache__/llama.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/__pycache__/modules.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/nnet/__pycache__/modules.cpython-310.pyc -------------------------------------------------------------------------------- /nnet/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | class SimpleMLP(nn.Module): 21 | def __init__( 22 | self, 23 | dim_model: int, 24 | dropout: float = 0.0, 25 | ): 26 | super().__init__() 27 | 28 | self.dim_model = dim_model 29 | 30 | self.dropout = torch.nn.Dropout(p=dropout) 31 | self.mlp = nn.Sequential( 32 | nn.Linear(dim_model, dim_model), 33 | nn.ReLU(), 34 | nn.Linear(dim_model, dim_model), 35 | ) 36 | self.init_weights() 37 | 38 | def init_weights(self, gain: float = 1.0): 39 | torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain) 40 | 41 | @property 42 | def weight(self) -> torch.Tensor: 43 | return self.word_embeddings.weight 44 | 45 | def forward(self, x: torch.Tensor): 46 | X = self.mlp(x) 47 | X = self.dropout(X) 48 | 49 | return X 50 | 51 | class TokenEmbedding(nn.Module): 52 | def __init__( 53 | self, 54 | dim_model: int, 55 | vocab_size: int, 56 | dropout: float = 0.0, 57 | ): 58 | super().__init__() 59 | 60 | self.vocab_size = vocab_size 61 | self.dim_model = dim_model 62 | 63 | self.dropout = torch.nn.Dropout(p=dropout) 64 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) 65 | self.init_weights() 66 | 67 | def init_weights(self, gain: float = 1.0): 68 | torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain) 69 | 70 | @property 71 | def weight(self) -> torch.Tensor: 72 | return self.word_embeddings.weight 73 | 74 | def forward(self, x: torch.Tensor): 75 | X = self.word_embeddings(x) 76 | X = self.dropout(X) 77 | 78 | return X 79 | 80 | 81 | class SinePositionalEmbedding(nn.Module): 82 | def __init__(self, dim_model: int): 83 | super().__init__() 84 | self.dim_model = dim_model 85 | 86 | def forward(self, x: torch.Tensor) -> torch.Tensor: 87 | seq_len = x.shape[1] 88 | pos = ( 89 | torch.arange(0, seq_len, device=x.device, dtype=torch.float32) 90 | .unsqueeze(1) 91 | .repeat(1, self.dim_model) 92 | ) 93 | dim = ( 94 | torch.arange( 95 | 0, self.dim_model, device=x.device, dtype=torch.float32 96 | ) 97 | .unsqueeze(0) 98 | .repeat(seq_len, 1) 99 | ) 100 | # div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model)) 101 | div = torch.exp(-math.log(10000) * (2 * (torch.div(dim, 2)) / self.dim_model)) 102 | 103 | pos *= div 104 | pos[:, 0::2] = torch.sin(pos[:, 0::2]) 105 | pos[:, 1::2] = torch.cos(pos[:, 1::2]) 106 | 107 | output = x.unsqueeze(-1) if x.ndim == 2 else x 108 | 109 | return output + pos.unsqueeze(0) 110 | 111 | 112 | class SinePositionalEmbedding_V2(nn.Module): 113 | def __init__(self, feature_dim: int, max_seq_len: int = 1024, temperature=10000): 114 | super().__init__() 115 | self.feature_dim = feature_dim 116 | self.max_seq_len = max_seq_len 117 | self.temperature = temperature 118 | self.positional_embeddings = self._generate_positional_embeddings() 119 | 120 | def _generate_positional_embeddings(self): 121 | div_term = torch.exp( 122 | torch.arange(0, self.feature_dim, 2).float() 123 | * -(torch.log(torch.tensor(self.temperature)) / self.feature_dim) 124 | ) 125 | 126 | positions = torch.arange(0, self.max_seq_len).float().unsqueeze(1) 127 | pos_emb = torch.zeros(self.max_seq_len, self.feature_dim) 128 | pos_emb[:, 0::2] = torch.sin(positions * div_term) 129 | pos_emb[:, 1::2] = torch.cos(positions * div_term) 130 | 131 | return pos_emb.unsqueeze(0) 132 | 133 | def forward(self, x: torch.Tensor) -> torch.Tensor: 134 | batch_size, seq_len, _ = x.shape 135 | if seq_len > self.max_seq_len: 136 | raise ValueError("Input sequence length exceeds maximum sequence length.") 
137 | 138 | pos_emb = self.positional_embeddings[:, :seq_len, :] 139 | pos_emb = pos_emb.to(x.device) 140 | 141 | output = x + pos_emb 142 | 143 | return output 144 | 145 | 146 | if __name__=="__main__": 147 | x = torch.rand([4, 199, 1024]) 148 | pos_emb = SinePositionalEmbedding_V2(1024) 149 | out = pos_emb(x) 150 | print(out.shape) -------------------------------------------------------------------------------- /nnet/llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import sys,os 5 | import numpy as np 6 | 7 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 8 | 9 | from embedding import SinePositionalEmbedding, TokenEmbedding 10 | from transformers import LlamaConfig, LlamaForCausalLM 11 | 12 | NUM_AUDIO_TOKENS = 65536 # Number of entries in the XCodec2 codebook 13 | 14 | class AdaptiveLayerNorm(nn.Module): 15 | r"""Adaptive Layer Normalization""" 16 | 17 | def __init__(self, d_model, norm) -> None: 18 | super(AdaptiveLayerNorm, self).__init__() 19 | self.project_layer = nn.Linear(d_model, 2 * d_model) 20 | self.norm = norm 21 | self.d_model = d_model 22 | self.eps = self.norm.eps 23 | 24 | def forward(self, input: torch.Tensor, embedding: torch.Tensor = None) -> torch.Tensor: 25 | if isinstance(input, tuple): 26 | input, embedding = input 27 | weight, bias = torch.split( 28 | self.project_layer(embedding), 29 | split_size_or_sections=self.d_model, 30 | dim=-1, 31 | ) 32 | return (weight * self.norm(input) + bias, embedding) 33 | 34 | weight, bias = torch.split( 35 | self.project_layer(embedding), 36 | split_size_or_sections=self.d_model, 37 | dim=-1, 38 | ) 39 | 40 | return weight * self.norm(input) + bias 41 | 42 | class LLM_AR(nn.Module): 43 | def __init__( 44 | self, 45 | d_model: int, 46 | nhead: int, 47 | num_layers: int 48 | ): 49 | super().__init__() 50 | 51 | self.audio_linear = nn.Linear(1024, d_model) 52 | self.audio_position = SinePositionalEmbedding(d_model) 53 | self.stage_embedding = TokenEmbedding(d_model, 1) 54 | self.adaLN = AdaptiveLayerNorm(d_model, norm=nn.LayerNorm(d_model)) 55 | 56 | self.Llama_config = LlamaConfig( 57 | hidden_size=d_model, 58 | intermediate_size=d_model * 4, 59 | num_attention_heads=nhead, 60 | num_hidden_layers=num_layers, 61 | dropout_rate=0.1, 62 | attention_dropout=0.1, 63 | is_decoder=True, 64 | use_cache=True 65 | ) 66 | 67 | self.lm = LlamaForCausalLM(config=self.Llama_config) 68 | self.predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS) 69 | 70 | def forward( 71 | self, 72 | y: torch.Tensor, 73 | ) -> torch.Tensor: 74 | 75 | y_emb = self.audio_linear(y) # [B, T, D] 76 | y_pos = self.audio_position(y_emb) # [B, T, D] 77 | 78 | stage_embedding = self.stage_embedding(torch.tensor(0, device=y_pos.device)) 79 | y_pos = self.adaLN(y_pos, stage_embedding) 80 | 81 | outputs = self.lm(inputs_embeds=y_pos, output_hidden_states=True) 82 | y_dec = outputs.hidden_states[-1] # [B, T, D] 83 | 84 | logits = self.predict_layer(y_dec) # [B, T, NUM_AUDIO_TOKENS] 85 | 86 | logits = logits.transpose(-1, -2) # [B, NUM_AUDIO_TOKENS, T] 87 | 88 | return logits 89 | 90 | if __name__=="__main__": 91 | # for test 92 | model = LLM_AR(d_model=1024, nhead=8, num_layers=12) 93 | ce_loss = nn.CrossEntropyLoss() 94 | x = torch.randn([2,199,1024]) 95 | label = torch.from_numpy(np.random.randint(0, 300, size=[2,1,199])) 96 | logits = model(x) 97 | print(logits.shape)
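The LLM_AR module above is the core predictor in this repository: it projects 1024-dim frame features to the model width, adds sinusoidal positions and a stage embedding through adaptive LayerNorm, runs a LLaMA decoder over the resulting embeddings, and maps the final hidden states to logits over the 65536-entry XCodec2 codebook. The snippet below is a minimal usage sketch, not the repository's inference script: it assumes the repository root is on PYTHONPATH, the random feature tensor stands in for WavLM-style frame features, and greedy argmax decoding of the logits is an illustrative choice.

import torch
from nnet.llama import LLM_AR

model = LLM_AR(d_model=1024, nhead=8, num_layers=12).eval()

feats = torch.randn(2, 199, 1024)       # [B, T, 1024] continuous frame features
with torch.no_grad():
    logits = model(feats)               # [B, NUM_AUDIO_TOKENS, T]
tokens = logits.argmax(dim=1)           # [B, T] predicted codec token ids
print(logits.shape, tokens.shape)       # torch.Size([2, 65536, 199]) torch.Size([2, 199])

Training would instead compare these logits against codec token targets with a cross-entropy loss, as hinted by the ce_loss stub in the file's self-test.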
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.1.0 2 | aiohttp==3.10.5 3 | datasets==3.0.1 4 | deepspeed==0.15.1 5 | einops==0.8.0 6 | huggingface-hub==0.25.1 7 | numpy==1.23.5 8 | pandas==2.2.3 9 | pillow==10.4.0 10 | scikit-learn==1.5.2 11 | scipy==1.13.1 12 | torch==2.4.1 13 | torchaudio==2.4.1 14 | torchmetrics==1.4.2 15 | transformers==4.47.1 16 | tqdm==4.66.5 -------------------------------------------------------------------------------- /vq/__init__.py: -------------------------------------------------------------------------------- 1 | from vq.codec_encoder import CodecEncoder 2 | from vq.codec_decoder import CodecDecoder 3 | from vq.codec_decoder_vocos import CodecDecoderVocos 4 | from vq.codec_encoder import CodecEncoder_Transformer,CodecEncoder_only_Transformer -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/activations.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/activations.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/activations.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/activations.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/activations.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/activations.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/activations.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/blocks.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/blocks.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/blocks.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/bs_roformer5.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/bs_roformer5.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/bs_roformer5.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/bs_roformer5.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/bs_roformer5.cpython-39.pyc 
-------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_decoder_vocos.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_encoder.cpython-311.pyc 
-------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_encoder.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_encoder.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/codec_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/codec_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/factorized_vector_quantize.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/module.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/module.cpython-311.pyc 
-------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/module.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/module.cpython-37.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/module.cpython-38.pyc -------------------------------------------------------------------------------- /vq/__pycache__/module.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/module.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/residual_vq.cpython-310.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/residual_vq.cpython-311.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/residual_vq.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/residual_vq.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/residual_vq.cpython-39.pyc -------------------------------------------------------------------------------- /vq/__pycache__/unet.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/unet.cpython-312.pyc -------------------------------------------------------------------------------- /vq/__pycache__/unet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/__pycache__/unet.cpython-39.pyc -------------------------------------------------------------------------------- /vq/activations.py: -------------------------------------------------------------------------------- 1 | # 
Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | from torch import nn, sin, pow 6 | from torch.nn import Parameter 7 | 8 | 9 | class Snake(nn.Module): 10 | ''' 11 | Implementation of a sine-based periodic activation function 12 | Shape: 13 | - Input: (B, C, T) 14 | - Output: (B, C, T), same shape as the input 15 | Parameters: 16 | - alpha - trainable parameter 17 | References: 18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 19 | https://arxiv.org/abs/2006.08195 20 | Examples: 21 | >>> a1 = snake(256) 22 | >>> x = torch.randn(256) 23 | >>> x = a1(x) 24 | ''' 25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 26 | ''' 27 | Initialization. 28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 33 | ''' 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | ''' 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | ''' 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | ''' 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | ''' 79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 80 | ''' 81 | Initialization. 82 | INPUT: 83 | - in_features: shape of the input 84 | - alpha - trainable parameter that controls frequency 85 | - beta - trainable parameter that controls magnitude 86 | alpha is initialized to 1 by default, higher values = higher-frequency. 87 | beta is initialized to 1 by default, higher values = higher-magnitude. 88 | alpha will be trained along with the rest of your model. 
89 | ''' 90 | super(SnakeBeta, self).__init__() 91 | self.in_features = in_features 92 | 93 | # initialize alpha 94 | self.alpha_logscale = alpha_logscale 95 | if self.alpha_logscale: # log scale alphas initialized to zeros 96 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 97 | self.beta = Parameter(torch.zeros(in_features) * alpha) 98 | else: # linear scale alphas initialized to ones 99 | self.alpha = Parameter(torch.ones(in_features) * alpha) 100 | self.beta = Parameter(torch.ones(in_features) * alpha) 101 | 102 | self.alpha.requires_grad = alpha_trainable 103 | self.beta.requires_grad = alpha_trainable 104 | 105 | self.no_div_by_zero = 0.000000001 106 | 107 | def forward(self, x): 108 | ''' 109 | Forward pass of the function. 110 | Applies the function to the input elementwise. 111 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 112 | ''' 113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 114 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 115 | if self.alpha_logscale: 116 | alpha = torch.exp(alpha) 117 | beta = torch.exp(beta) 118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 119 | 120 | return x -------------------------------------------------------------------------------- /vq/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
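Returning to vq/activations.py above: Snake and SnakeBeta implement the periodic activation x + (1/alpha) * sin^2(alpha * x) (SnakeBeta uses a separate magnitude parameter beta in place of alpha in the 1/alpha factor), applied per channel to [B, C, T] tensors. The following is a minimal shape check, a sketch that assumes the repository root is on PYTHONPATH and that the vq package's imports resolve; note that vq/__init__.py pulls in the codec encoder/decoder modules, which in turn import torchtune, a dependency not pinned in requirements.txt.

import torch
from vq.activations import Snake, SnakeBeta

x = torch.randn(2, 256, 100)                      # [B, C, T]
snake = Snake(256)                                # per-channel alpha, initialised to 1
snake_beta = SnakeBeta(256, alpha_logscale=True)  # separate alpha (frequency) and beta (magnitude)

y1 = snake(x)                                     # x + (1/alpha) * sin^2(alpha * x)
y2 = snake_beta(x)                                # x + (1/beta)  * sin^2(alpha * x)
print(y1.shape, y2.shape)                         # both preserve the input shape [2, 256, 100]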
/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/act.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/act.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/act.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/act.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/act.cpython-38.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/act.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/act.cpython-39.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-312.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/filter.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/filter.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/resample.cpython-37.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc -------------------------------------------------------------------------------- /vq/alias_free_torch/__pycache__/resample.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevin-naticl/LLaSE/6b9221aecabdc8e64ae0d738b5278ba7f4011d06/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc -------------------------------------------------------------------------------- 
/vq/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from .resample import UpSample1d, DownSample1d 6 | 7 | 8 | class Activation1d(nn.Module): 9 | def __init__(self, 10 | activation, 11 | up_ratio: int = 2, 12 | down_ratio: int = 2, 13 | up_kernel_size: int = 12, 14 | down_kernel_size: int = 12): 15 | super().__init__() 16 | self.up_ratio = up_ratio 17 | self.down_ratio = down_ratio 18 | self.act = activation 19 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 20 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 21 | 22 | # x: [B,C,T] 23 | def forward(self, x): 24 | x = self.upsample(x) 25 | x = self.act(x) 26 | x = self.downsample(x) 27 | 28 | return x -------------------------------------------------------------------------------- /vq/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if 'sinc' in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where(x == 0, 21 | torch.tensor(1., device=x.device, dtype=x.dtype), 22 | torch.sin(math.pi * x) / math.pi / x) 23 | 24 | 25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 26 | # https://adefossez.github.io/julius/julius/lowpass.html 27 | # LICENSE is in incl_licenses directory. 28 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 29 | even = (kernel_size % 2 == 0) 30 | half_size = kernel_size // 2 31 | 32 | #For kaiser window 33 | delta_f = 4 * half_width 34 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 35 | if A > 50.: 36 | beta = 0.1102 * (A - 8.7) 37 | elif A >= 21.: 38 | beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) 39 | else: 40 | beta = 0. 41 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 42 | 43 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 44 | if even: 45 | time = (torch.arange(-half_size, half_size) + 0.5) 46 | else: 47 | time = torch.arange(kernel_size) - half_size 48 | if cutoff == 0: 49 | filter_ = torch.zeros_like(time) 50 | else: 51 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 52 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 53 | # of the constant component in the input signal. 54 | filter_ /= filter_.sum() 55 | filter = filter_.view(1, 1, kernel_size) 56 | 57 | return filter 58 | 59 | 60 | class LowPassFilter1d(nn.Module): 61 | def __init__(self, 62 | cutoff=0.5, 63 | half_width=0.6, 64 | stride: int = 1, 65 | padding: bool = True, 66 | padding_mode: str = 'replicate', 67 | kernel_size: int = 12): 68 | # kernel_size should be even number for stylegan3 setup, 69 | # in this implementation, odd number is also possible. 
70 | super().__init__() 71 | if cutoff < -0.: 72 | raise ValueError("Minimum cutoff must be larger than zero.") 73 | if cutoff > 0.5: 74 | raise ValueError("A cutoff above 0.5 does not make sense.") 75 | self.kernel_size = kernel_size 76 | self.even = (kernel_size % 2 == 0) 77 | self.pad_left = kernel_size // 2 - int(self.even) 78 | self.pad_right = kernel_size // 2 79 | self.stride = stride 80 | self.padding = padding 81 | self.padding_mode = padding_mode 82 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 83 | self.register_buffer("filter", filter) 84 | 85 | #input [B, C, T] 86 | def forward(self, x): 87 | _, C, _ = x.shape 88 | 89 | if self.padding: 90 | x = F.pad(x, (self.pad_left, self.pad_right), 91 | mode=self.padding_mode) 92 | out = F.conv1d(x, self.filter.expand(C, -1, -1), 93 | stride=self.stride, groups=C) 94 | 95 | return out -------------------------------------------------------------------------------- /vq/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, 20 | half_width=0.6 / ratio, 21 | kernel_size=self.kernel_size) 22 | self.register_buffer("filter", filter) 23 | 24 | # x: [B, C, T] 25 | def forward(self, x): 26 | _, C, _ = x.shape 27 | 28 | x = F.pad(x, (self.pad, self.pad), mode='replicate') 29 | x = self.ratio * F.conv_transpose1d( 30 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 31 | x = x[..., self.pad_left:-self.pad_right] 32 | 33 | return x 34 | 35 | 36 | class DownSample1d(nn.Module): 37 | def __init__(self, ratio=2, kernel_size=None): 38 | super().__init__() 39 | self.ratio = ratio 40 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 41 | self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, 42 | half_width=0.6 / ratio, 43 | stride=ratio, 44 | kernel_size=self.kernel_size) 45 | 46 | def forward(self, x): 47 | xx = self.lowpass(x) 48 | 49 | return xx -------------------------------------------------------------------------------- /vq/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Sequence, Type, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]] 8 | 9 | 10 | class FeedForwardModule(nn.Module): 11 | 12 | def __init__(self) -> None: 13 | super().__init__() 14 | self.net = None 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | return self.net(x) 18 | 19 | 20 | class Residual(nn.Module): 21 | 22 | def __init__(self, module: nn.Module) -> None: 23 | super().__init__() 24 | self.module = module 25 | 26 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 27 | return self.module(x) + x 28 | 29 | 30 | class DilatedConvolutionalUnit(FeedForwardModule): 31 | 32 | def __init__( 33 | self, 34 | hidden_dim: int, 35 | dilation: int, 36 | kernel_size: int, 37 | activation: ModuleFactory, 38 | normalization: Callable[[nn.Module], 39 | nn.Module] = lambda x: x) -> None: 40 | super().__init__() 41 | self.net = nn.Sequential( 42 | activation(), 43 | normalization( 44 | nn.Conv1d( 45 | in_channels=hidden_dim, 46 | out_channels=hidden_dim, 47 | kernel_size=kernel_size, 48 | dilation=dilation, 49 | padding=((kernel_size - 1) * dilation) // 2, 50 | )), 51 | activation(), 52 | nn.Conv1d(in_channels=hidden_dim, 53 | out_channels=hidden_dim, 54 | kernel_size=1), 55 | ) 56 | 57 | 58 | class UpsamplingUnit(FeedForwardModule): 59 | 60 | def __init__( 61 | self, 62 | input_dim: int, 63 | output_dim: int, 64 | stride: int, 65 | activation: ModuleFactory, 66 | normalization: Callable[[nn.Module], 67 | nn.Module] = lambda x: x) -> None: 68 | super().__init__() 69 | self.net = nn.Sequential( 70 | activation(), 71 | normalization( 72 | nn.ConvTranspose1d( 73 | in_channels=input_dim, 74 | out_channels=output_dim, 75 | kernel_size=2 * stride, 76 | stride=stride, 77 | padding=stride // 2+ stride % 2, 78 | output_padding=1 if stride % 2 != 0 else 0 79 | ))) 80 | 81 | 82 | class DownsamplingUnit(FeedForwardModule): 83 | 84 | def __init__( 85 | self, 86 | input_dim: int, 87 | output_dim: int, 88 | stride: int, 89 | activation: ModuleFactory, 90 | normalization: Callable[[nn.Module], 91 | nn.Module] = lambda x: x) -> None: 92 | super().__init__() 93 | self.net = nn.Sequential( 94 | activation(), 95 | normalization( 96 | nn.Conv1d( 97 | in_channels=input_dim, 98 | out_channels=output_dim, 99 | kernel_size=2 * stride, 100 | stride=stride, 101 | padding= stride // 2+ stride % 2, 102 | 103 | ))) 104 | 105 | 106 | class DilatedResidualEncoder(FeedForwardModule): 107 | 108 | def __init__( 109 | self, 110 | capacity: int, 111 | dilated_unit: Type[DilatedConvolutionalUnit], 112 | downsampling_unit: Type[DownsamplingUnit], 113 | ratios: Sequence[int], 114 | dilations: Union[Sequence[int], Sequence[Sequence[int]]], 115 | pre_network_conv: Type[nn.Conv1d], 116 | post_network_conv: Type[nn.Conv1d], 117 | normalization: Callable[[nn.Module], 118 | nn.Module] = lambda x: x) -> None: 119 | super().__init__() 120 | channels = capacity * 2**np.arange(len(ratios) + 1) 121 | 122 | dilations_list = self.normalize_dilations(dilations, ratios) 123 | 124 | net = [normalization(pre_network_conv(out_channels=channels[0]))] 125 | 126 | for ratio, dilations, input_dim, output_dim in zip( 127 | ratios, dilations_list, channels[:-1], channels[1:]): 128 | for dilation in dilations: 129 | net.append(Residual(dilated_unit(input_dim, dilation))) 130 | net.append(downsampling_unit(input_dim, output_dim, ratio)) 131 | 132 | net.append(post_network_conv(in_channels=output_dim)) 133 | 134 | self.net = nn.Sequential(*net) 135 | 136 | @staticmethod 137 | def normalize_dilations(dilations: Union[Sequence[int], 138 | Sequence[Sequence[int]]], 139 | ratios: Sequence[int]): 140 | if isinstance(dilations[0], int): 141 | dilations = [dilations for _ in ratios] 142 | return dilations 143 | 144 | 145 | class DilatedResidualDecoder(FeedForwardModule): 146 | 147 | def __init__( 148 | self, 149 | capacity: int, 150 | dilated_unit: Type[DilatedConvolutionalUnit], 151 | upsampling_unit: Type[UpsamplingUnit], 152 | ratios: Sequence[int], 153 | dilations: Union[Sequence[int], Sequence[Sequence[int]]], 154 | 
pre_network_conv: Type[nn.Conv1d], 155 | post_network_conv: Type[nn.Conv1d], 156 | normalization: Callable[[nn.Module], 157 | nn.Module] = lambda x: x) -> None: 158 | super().__init__() 159 | channels = capacity * 2**np.arange(len(ratios) + 1) 160 | channels = channels[::-1] 161 | 162 | dilations_list = self.normalize_dilations(dilations, ratios) 163 | dilations_list = dilations_list[::-1] 164 | 165 | net = [pre_network_conv(out_channels=channels[0])] 166 | 167 | for ratio, dilations, input_dim, output_dim in zip( 168 | ratios, dilations_list, channels[:-1], channels[1:]): 169 | net.append(upsampling_unit(input_dim, output_dim, ratio)) 170 | for dilation in dilations: 171 | net.append(Residual(dilated_unit(output_dim, dilation))) 172 | 173 | net.append(normalization(post_network_conv(in_channels=output_dim))) 174 | 175 | self.net = nn.Sequential(*net) 176 | 177 | @staticmethod 178 | def normalize_dilations(dilations: Union[Sequence[int], 179 | Sequence[Sequence[int]]], 180 | ratios: Sequence[int]): 181 | if isinstance(dilations[0], int): 182 | dilations = [dilations for _ in ratios] 183 | return dilations -------------------------------------------------------------------------------- /vq/bs_roformer5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Module, ModuleList 5 | import torchaudio 6 | from einops import rearrange 7 | import numpy as np 8 | # from rotary_embedding_torch import RotaryEmbedding 9 | 10 | from torchtune.modules import RotaryPositionalEmbeddings 11 | 12 | 13 | 14 | class RMSNorm(torch.nn.Module): 15 | def __init__(self, dim: int, eps: float = 1e-6): 16 | r"""https://github.com/meta-llama/llama/blob/main/llama/model.py""" 17 | super().__init__() 18 | self.eps = eps 19 | self.weight = nn.Parameter(torch.ones(dim)) 20 | 21 | def forward(self, x): 22 | norm_x = torch.mean(x ** 2, dim=-1, keepdim=True) 23 | output = x * torch.rsqrt(norm_x + self.eps) * self.weight 24 | return output 25 | 26 | 27 | 28 | class MLP(nn.Module): 29 | def __init__(self, dim: int) -> None: 30 | super().__init__() 31 | 32 | self.fc1 = nn.Linear(dim, 4 * dim, bias=False) 33 | self.silu = nn.SiLU() 34 | self.fc2 = nn.Linear(4 * dim, dim, bias=False) 35 | 36 | def forward(self, x): 37 | x = self.fc1(x) 38 | x = self.silu(x) 39 | x = self.fc2(x) 40 | return x 41 | 42 | 43 | class Attention(nn.Module): 44 | 45 | def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): 46 | super().__init__() 47 | 48 | assert dim % n_heads == 0 49 | 50 | self.n_heads = n_heads 51 | self.dim = dim 52 | self.rotary_embed = rotary_embed 53 | 54 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 55 | assert self.flash, "Must have flash attention." 
56 | 57 | self.c_attn = nn.Linear(dim, 3 * dim, bias=False) 58 | self.c_proj = nn.Linear(dim, dim, bias=False) 59 | 60 | def forward(self, x): 61 | r""" 62 | Args: 63 | x: (b, t, h*d) 64 | 65 | Constants: 66 | b: batch_size 67 | t: time steps 68 | r: 3 69 | h: heads_num 70 | d: heads_dim 71 | """ 72 | B, T, C = x.size() 73 | 74 | q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads) 75 | # q, k, v: (b, h, t, d) 76 | 77 | q = self.rotary_embed(q) 78 | k = self.rotary_embed(k) 79 | 80 | if self.flash: 81 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False) 82 | 83 | y = rearrange(y, 'b h t d -> b t (h d)') 84 | 85 | y = self.c_proj(y) 86 | # shape: (b, t, h*d) 87 | 88 | return y 89 | 90 | 91 | class TransformerBlock(nn.Module): 92 | def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): 93 | 94 | super().__init__() 95 | self.dim = dim 96 | self.n_heads = n_heads 97 | 98 | self.att_norm = RMSNorm(dim) 99 | self.ffn_norm = RMSNorm(dim) 100 | self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed) 101 | self.mlp = MLP(dim=dim) 102 | 103 | 104 | def forward( 105 | self, 106 | x: torch.Tensor, 107 | ): 108 | x = x + self.att(self.att_norm(x)) 109 | x = x + self.mlp(self.ffn_norm(x)) 110 | return x 111 | 112 | 113 | if __name__ == '__main__': 114 | rotary_embed_128 = RotaryPositionalEmbeddings(dim=128) 115 | transformer_block = TransformerBlock( 116 | dim=1024, 117 | n_heads=8, 118 | rotary_embed=rotary_embed_128 119 | ) 120 | x = torch.randn(2, 128, 1024) 121 | y = transformer_block(x) 122 | print(y.shape) 123 | c=1 -------------------------------------------------------------------------------- /vq/codec_decoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv1d_transformer') 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from vq.residual_vq import ResidualVQ 7 | from vq.module import WNConv1d, DecoderBlock, ResLSTM 8 | from vq.alias_free_torch import * 9 | from vq import activations 10 | import vq.blocks as blocks 11 | from torch.nn import utils 12 | 13 | from vq.bs_roformer5 import TransformerBlock 14 | 15 | from torchtune.modules import RotaryPositionalEmbeddings 16 | 17 | def init_weights(m): 18 | if isinstance(m, nn.Conv1d): 19 | nn.init.trunc_normal_(m.weight, std=0.02) 20 | nn.init.constant_(m.bias, 0) 21 | 22 | class CodecDecoder(nn.Module): 23 | def __init__(self, 24 | in_channels=1024, 25 | upsample_initial_channel=1536, 26 | ngf=48, 27 | use_rnn=True, 28 | rnn_bidirectional=False, 29 | rnn_num_layers=2, 30 | up_ratios=(5, 4, 4, 4, 2), 31 | dilations=(1, 3, 9), 32 | vq_num_quantizers=1, 33 | vq_dim=2048, 34 | vq_commit_weight=0.25, 35 | vq_weight_init=False, 36 | vq_full_commit_loss=False, 37 | codebook_size=16384, 38 | codebook_dim=32, 39 | ): 40 | super().__init__() 41 | self.hop_length = np.prod(up_ratios) 42 | self.ngf = ngf 43 | self.up_ratios = up_ratios 44 | 45 | self.quantizer = ResidualVQ( 46 | num_quantizers=vq_num_quantizers, 47 | dim=vq_dim, # double the dim for acousitc and semantic 48 | codebook_size=codebook_size, 49 | codebook_dim=codebook_dim, 50 | threshold_ema_dead_code=2, 51 | commitment=vq_commit_weight, 52 | weight_init=vq_weight_init, 53 | full_commit_loss=vq_full_commit_loss, 54 | ) 55 | channels = upsample_initial_channel 56 | layers = [WNConv1d(in_channels, channels, kernel_size=7, 
padding=3)] 57 | 58 | if use_rnn: 59 | layers += [ 60 | ResLSTM(channels, 61 | num_layers=rnn_num_layers, 62 | bidirectional=rnn_bidirectional 63 | ) 64 | ] 65 | 66 | for i, stride in enumerate(up_ratios): 67 | input_dim = channels // 2**i 68 | output_dim = channels // 2 ** (i + 1) 69 | layers += [DecoderBlock(input_dim, output_dim, stride, dilations)] 70 | 71 | layers += [ 72 | Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)), 73 | WNConv1d(output_dim, 1, kernel_size=7, padding=3), 74 | nn.Tanh(), 75 | ] 76 | 77 | self.model = nn.Sequential(*layers) 78 | 79 | self.reset_parameters() 80 | 81 | def forward(self, x, vq=True): 82 | if vq is True: 83 | x, q, commit_loss = self.quantizer(x) 84 | return x, q, commit_loss 85 | x = self.model(x) 86 | return x 87 | 88 | def vq2emb(self, vq): 89 | self.quantizer = self.quantizer.eval() 90 | x = self.quantizer.vq2emb(vq) 91 | return x 92 | 93 | def get_emb(self): 94 | self.quantizer = self.quantizer.eval() 95 | embs = self.quantizer.get_emb() 96 | return embs 97 | 98 | def inference_vq(self, vq): 99 | x = vq[None,:,:] 100 | x = self.model(x) 101 | return x 102 | 103 | def inference_0(self, x): 104 | x, q, loss, perp = self.quantizer(x) 105 | x = self.model(x) 106 | return x, None 107 | 108 | def inference(self, x): 109 | x = self.model(x) 110 | return x, None 111 | 112 | 113 | def remove_weight_norm(self): 114 | """Remove weight normalization module from all of the layers.""" 115 | 116 | def _remove_weight_norm(m): 117 | try: 118 | torch.nn.utils.remove_weight_norm(m) 119 | except ValueError: # this module didn't have weight norm 120 | return 121 | 122 | self.apply(_remove_weight_norm) 123 | 124 | def apply_weight_norm(self): 125 | """Apply weight normalization module from all of the layers.""" 126 | 127 | def _apply_weight_norm(m): 128 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 129 | torch.nn.utils.weight_norm(m) 130 | 131 | self.apply(_apply_weight_norm) 132 | 133 | def reset_parameters(self): 134 | self.apply(init_weights) 135 | 136 | 137 | class CodecDecoder_oobleck_Transformer(nn.Module): 138 | def __init__(self, 139 | ngf=32, 140 | up_ratios=(5, 4, 4, 4, 2), 141 | dilations=(1, 3, 9), 142 | vq_num_quantizers=1, 143 | vq_dim=1024, 144 | vq_commit_weight=0.25, 145 | vq_weight_init=False, 146 | vq_full_commit_loss=False, 147 | codebook_size=16384, 148 | codebook_dim=16, 149 | hidden_dim=1024, 150 | depth=12, 151 | heads=16, 152 | pos_meb_dim=64, 153 | ): 154 | super().__init__() 155 | self.hop_length = np.prod(up_ratios) 156 | self.capacity = ngf 157 | self.up_ratios = up_ratios 158 | self.hidden_dim = hidden_dim 159 | self.quantizer = ResidualVQ( 160 | num_quantizers=vq_num_quantizers, 161 | dim=vq_dim, # double the dim for acousitc and semantic 162 | codebook_size=codebook_size, 163 | codebook_dim=codebook_dim, 164 | threshold_ema_dead_code=2, 165 | commitment=vq_commit_weight, 166 | weight_init=vq_weight_init, 167 | full_commit_loss=vq_full_commit_loss, 168 | ) 169 | 170 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 171 | 172 | transformer_blocks = [ 173 | TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 174 | for _ in range(depth) 175 | ] 176 | 177 | self.transformers = nn.Sequential(*transformer_blocks) 178 | 179 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 180 | 181 | self.conv_blocks = blocks.DilatedResidualDecoder( 182 | capacity=self.capacity, 183 | dilated_unit=self.dilated_unit, 184 | 
upsampling_unit=self.upsampling_unit, 185 | ratios=up_ratios, # 逆转编码器的下采样比率 186 | dilations=dilations, 187 | pre_network_conv=self.pre_conv, 188 | post_network_conv=self.post_conv, 189 | ) 190 | 191 | 192 | 193 | self.reset_parameters() 194 | 195 | def forward(self, x, vq=True): 196 | if vq is True: 197 | x, q, commit_loss = self.quantizer(x) 198 | return x, q, commit_loss 199 | x= self.transformers(x) 200 | x = self.final_layer_norm(x) 201 | x = x.permute(0, 2, 1) 202 | x = self.conv_blocks(x) 203 | return x 204 | 205 | def vq2emb(self, vq): 206 | self.quantizer = self.quantizer.eval() 207 | x = self.quantizer.vq2emb(vq) 208 | return x 209 | 210 | def get_emb(self): 211 | self.quantizer = self.quantizer.eval() 212 | embs = self.quantizer.get_emb() 213 | return embs 214 | 215 | def inference_vq(self, vq): 216 | x = vq[None,:,:] 217 | x = self.model(x) 218 | return x 219 | 220 | def inference_0(self, x): 221 | x, q, loss, perp = self.quantizer(x) 222 | x = self.model(x) 223 | return x, None 224 | 225 | def inference(self, x): 226 | x = self.model(x) 227 | return x, None 228 | 229 | 230 | def remove_weight_norm(self): 231 | """Remove weight normalization module from all of the layers.""" 232 | 233 | def _remove_weight_norm(m): 234 | try: 235 | torch.nn.utils.remove_weight_norm(m) 236 | except ValueError: # this module didn't have weight norm 237 | return 238 | 239 | self.apply(_remove_weight_norm) 240 | 241 | def apply_weight_norm(self): 242 | """Apply weight normalization module from all of the layers.""" 243 | 244 | def _apply_weight_norm(m): 245 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 246 | torch.nn.utils.weight_norm(m) 247 | 248 | self.apply(_apply_weight_norm) 249 | 250 | def reset_parameters(self): 251 | self.apply(init_weights) 252 | 253 | def pre_conv(self, out_channels): 254 | return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1) 255 | 256 | # 定义后处理卷积层,将模型的输出映射到最终的输出通道数 257 | def post_conv(self,in_channels): 258 | return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1) 259 | 260 | def dilated_unit(self, hidden_dim, dilation): 261 | return blocks.DilatedConvolutionalUnit( 262 | hidden_dim=hidden_dim, 263 | dilation=dilation, 264 | kernel_size=3, 265 | activation=nn.ReLU , 266 | normalization=utils.weight_norm 267 | ) 268 | 269 | # 定义上采样单元 270 | def upsampling_unit(self,input_dim, output_dim, stride): 271 | return blocks.UpsamplingUnit( 272 | input_dim=input_dim, 273 | output_dim=output_dim, 274 | stride=stride, 275 | activation=nn.ReLU , 276 | normalization=utils.weight_norm 277 | ) 278 | 279 | def main(): 280 | # 设置设备 281 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 282 | print(f"Using device: {device}") 283 | 284 | # 初始化模型 285 | model = CodecDecoder_oobleck_Transformer().to(device) 286 | print("Model initialized.") 287 | 288 | # 创建测试输入: batch_size x in_channels x sequence_length 289 | batch_size = 2 290 | in_channels = 1024 291 | sequence_length = 100 # 示例长度,可以根据需要调整 292 | dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device) 293 | print(f"Dummy input shape: {dummy_input.shape}") 294 | 295 | # 将模型设为评估模式 296 | model.eval() 297 | 298 | 299 | 300 | output_no_vq = model(dummy_input, vq=False) 301 | c=1 302 | 303 | if __name__ == "__main__": 304 | main() -------------------------------------------------------------------------------- /vq/codec_decoder_vocos.py: -------------------------------------------------------------------------------- 1 | import sys 2 
| sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv_transformer_vocos') 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from vq.residual_vq import ResidualVQ 7 | from vq.module import WNConv1d, DecoderBlock, ResLSTM 8 | from vq.alias_free_torch import * 9 | from vq import activations 10 | from typing import Optional 11 | from vq.module import ConvNeXtBlock, AdaLayerNorm 12 | from vq.bs_roformer5 import TransformerBlock 13 | # from rotary_embedding_torch import RotaryEmbedding 14 | from torchtune.modules import RotaryPositionalEmbeddings 15 | from vector_quantize_pytorch import ResidualFSQ 16 | from torch.nn import Module, ModuleList 17 | class ISTFT(nn.Module): 18 | """ 19 | Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with 20 | windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. 21 | See issue: https://github.com/pytorch/pytorch/issues/62323 22 | Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. 23 | The NOLA constraint is met as we trim padded samples anyway. 24 | 25 | Args: 26 | n_fft (int): Size of Fourier transform. 27 | hop_length (int): The distance between neighboring sliding window frames. 28 | win_length (int): The size of window frame and STFT filter. 29 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 30 | """ 31 | 32 | def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): 33 | super().__init__() 34 | if padding not in ["center", "same"]: 35 | raise ValueError("Padding must be 'center' or 'same'.") 36 | self.padding = padding 37 | self.n_fft = n_fft 38 | self.hop_length = hop_length 39 | self.win_length = win_length 40 | window = torch.hann_window(win_length) 41 | self.register_buffer("window", window) 42 | 43 | def forward(self, spec: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. 46 | 47 | Args: 48 | spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, 49 | N is the number of frequency bins, and T is the number of time frames. 50 | 51 | Returns: 52 | Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 
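With "same" padding, (win_length - hop_length) // 2 samples are trimmed from each edge after overlap-add, so L works out to T * hop_length whenever win_length - hop_length is even (as with the defaults used below).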
53 | """ 54 | if self.padding == "center": 55 | # Fallback to pytorch native implementation 56 | return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) 57 | elif self.padding == "same": 58 | pad = (self.win_length - self.hop_length) // 2 59 | else: 60 | raise ValueError("Padding must be 'center' or 'same'.") 61 | 62 | assert spec.dim() == 3, "Expected a 3D tensor as input" 63 | B, N, T = spec.shape 64 | 65 | # Inverse FFT 66 | ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") 67 | ifft = ifft * self.window[None, :, None] 68 | 69 | # Overlap and Add 70 | output_size = (T - 1) * self.hop_length + self.win_length 71 | y = torch.nn.functional.fold( 72 | ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 73 | )[:, 0, 0, pad:-pad] 74 | 75 | # Window envelope 76 | window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) 77 | window_envelope = torch.nn.functional.fold( 78 | window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 79 | ).squeeze()[pad:-pad] 80 | 81 | # Normalize 82 | assert (window_envelope > 1e-11).all() 83 | y = y / window_envelope 84 | 85 | return y 86 | 87 | 88 | 89 | class FourierHead(nn.Module): 90 | """Base class for inverse fourier modules.""" 91 | 92 | def forward(self, x: torch.Tensor) -> torch.Tensor: 93 | """ 94 | Args: 95 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 96 | L is the sequence length, and H denotes the model dimension. 97 | 98 | Returns: 99 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 100 | """ 101 | raise NotImplementedError("Subclasses must implement the forward method.") 102 | 103 | 104 | class ISTFTHead(FourierHead): 105 | """ 106 | ISTFT Head module for predicting STFT complex coefficients. 107 | 108 | Args: 109 | dim (int): Hidden dimension of the model. 110 | n_fft (int): Size of Fourier transform. 111 | hop_length (int): The distance between neighboring sliding window frames, which should align with 112 | the resolution of the input features. 113 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 114 | """ 115 | 116 | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): 117 | super().__init__() 118 | out_dim = n_fft + 2 119 | self.out = torch.nn.Linear(dim, out_dim) 120 | self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) 121 | 122 | def forward(self, x: torch.Tensor) -> torch.Tensor: 123 | """ 124 | Forward pass of the ISTFTHead module. 125 | 126 | Args: 127 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 128 | L is the sequence length, and H denotes the model dimension. 129 | 130 | Returns: 131 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 132 | """ 133 | x_pred = self.out(x ) 134 | # x_pred = x 135 | x_pred = x_pred.transpose(1, 2) 136 | mag, p = x_pred.chunk(2, dim=1) 137 | mag = torch.exp(mag) 138 | mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes 139 | # wrapping happens here. 
These two lines produce real and imaginary value 140 | x = torch.cos(p) 141 | y = torch.sin(p) 142 | # recalculating phase here does not produce anything new 143 | # only costs time 144 | # phase = torch.atan2(y, x) 145 | # S = mag * torch.exp(phase * 1j) 146 | # better directly produce the complex value 147 | S = mag * (x + 1j * y) 148 | audio = self.istft(S) 149 | return audio.unsqueeze(1),x_pred 150 | 151 | 152 | def nonlinearity(x): 153 | # swish 154 | return x * torch.sigmoid(x) 155 | 156 | 157 | def Normalize(in_channels, num_groups=32): 158 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) 159 | 160 | 161 | class ResnetBlock(nn.Module): 162 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 163 | dropout, temb_channels=512): 164 | super().__init__() 165 | self.in_channels = in_channels 166 | out_channels = in_channels if out_channels is None else out_channels 167 | self.out_channels = out_channels 168 | self.use_conv_shortcut = conv_shortcut 169 | 170 | self.norm1 = Normalize(in_channels) 171 | self.conv1 = torch.nn.Conv1d(in_channels, 172 | out_channels, 173 | kernel_size=3, 174 | stride=1, 175 | padding=1) 176 | if temb_channels > 0: 177 | self.temb_proj = torch.nn.Linear(temb_channels, 178 | out_channels) 179 | self.norm2 = Normalize(out_channels) 180 | self.dropout = torch.nn.Dropout(dropout) 181 | self.conv2 = torch.nn.Conv1d(out_channels, 182 | out_channels, 183 | kernel_size=3, 184 | stride=1, 185 | padding=1) 186 | if self.in_channels != self.out_channels: 187 | if self.use_conv_shortcut: 188 | self.conv_shortcut = torch.nn.Conv1d(in_channels, 189 | out_channels, 190 | kernel_size=3, 191 | stride=1, 192 | padding=1) 193 | else: 194 | self.nin_shortcut = torch.nn.Conv1d(in_channels, 195 | out_channels, 196 | kernel_size=1, 197 | stride=1, 198 | padding=0) 199 | 200 | def forward(self, x, temb=None): 201 | h = x 202 | h = self.norm1(h) 203 | h = nonlinearity(h) 204 | h = self.conv1(h) 205 | 206 | if temb is not None: 207 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 208 | 209 | h = self.norm2(h) 210 | h = nonlinearity(h) 211 | h = self.dropout(h) 212 | h = self.conv2(h) 213 | 214 | if self.in_channels != self.out_channels: 215 | if self.use_conv_shortcut: 216 | x = self.conv_shortcut(x) 217 | else: 218 | x = self.nin_shortcut(x) 219 | 220 | return x + h 221 | 222 | class AttnBlock(nn.Module): 223 | def __init__(self, in_channels): 224 | super().__init__() 225 | self.in_channels = in_channels 226 | 227 | self.norm = Normalize(in_channels) 228 | self.q = torch.nn.Conv1d(in_channels, 229 | in_channels, 230 | kernel_size=1, 231 | stride=1, 232 | padding=0) 233 | self.k = torch.nn.Conv1d(in_channels, 234 | in_channels, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | self.v = torch.nn.Conv1d(in_channels, 239 | in_channels, 240 | kernel_size=1, 241 | stride=1, 242 | padding=0) 243 | self.proj_out = torch.nn.Conv1d(in_channels, 244 | in_channels, 245 | kernel_size=1, 246 | stride=1, 247 | padding=0) 248 | 249 | def forward(self, x): 250 | h_ = x 251 | h_ = self.norm(h_) 252 | q = self.q(h_) 253 | k = self.k(h_) 254 | v = self.v(h_) 255 | 256 | # compute attention 257 | b, c, h = q.shape 258 | q = q.permute(0, 2, 1) # b,hw,c 259 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 260 | w_ = w_ * (int(c) ** (-0.5)) 261 | w_ = torch.nn.functional.softmax(w_, dim=2) 262 | 263 | # attend to values 264 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 265 | 
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 266 | 267 | h_ = self.proj_out(h_) 268 | 269 | return x + h_ 270 | 271 | def make_attn(in_channels, attn_type="vanilla"): 272 | assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' 273 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 274 | if attn_type == "vanilla": 275 | return AttnBlock(in_channels) 276 | 277 | 278 | class Backbone(nn.Module): 279 | """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" 280 | 281 | def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: 282 | """ 283 | Args: 284 | x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, 285 | C denotes output features, and L is the sequence length. 286 | 287 | Returns: 288 | Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, 289 | and H denotes the model dimension. 290 | """ 291 | raise NotImplementedError("Subclasses must implement the forward method.") 292 | 293 | 294 | class VocosBackbone(Backbone): 295 | """ 296 | Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization 297 | 298 | Args: 299 | input_channels (int): Number of input features channels. 300 | dim (int): Hidden dimension of the model. 301 | intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. 302 | num_layers (int): Number of ConvNeXtBlock layers. 303 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. 304 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 305 | None means non-conditional model. Defaults to None. 
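Note: this variant is parameterized by hidden_dim, depth, heads and pos_meb_dim (ResnetBlock pre/post stages around rotary-embedding Transformer blocks) rather than by the ConvNeXt arguments listed above; see __init__ below.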
306 | """ 307 | 308 | def __init__( 309 | self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): 310 | super().__init__() 311 | 312 | self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3) 313 | 314 | 315 | 316 | self.temb_ch = 0 317 | block_in = hidden_dim 318 | dropout = 0.1 319 | 320 | prior_net : tp.List[nn.Module] = [ 321 | ResnetBlock(in_channels=block_in,out_channels=block_in, 322 | temb_channels=self.temb_ch,dropout=dropout), 323 | ResnetBlock(in_channels=block_in,out_channels=block_in, 324 | temb_channels=self.temb_ch,dropout=dropout), 325 | ] 326 | self.prior_net = nn.Sequential(*prior_net) 327 | 328 | depth = depth 329 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 330 | 331 | 332 | transformer_blocks = [ 333 | TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 334 | for _ in range(depth) 335 | ] 336 | 337 | 338 | self.transformers = nn.Sequential(*transformer_blocks) 339 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 340 | post_net : tp.List[nn.Module] = [ 341 | ResnetBlock(in_channels=block_in,out_channels=block_in, 342 | temb_channels=self.temb_ch,dropout=dropout), 343 | ResnetBlock(in_channels=block_in,out_channels=block_in, 344 | temb_channels=self.temb_ch,dropout=dropout), 345 | ] 346 | self.post_net = nn.Sequential(*post_net) 347 | 348 | def forward(self, x: torch.Tensor ) -> torch.Tensor: 349 | x = x.transpose(1, 2) 350 | x = self.embed(x) 351 | x = self.prior_net(x) 352 | x = x.transpose(1, 2) 353 | x= self.transformers(x) 354 | x = x.transpose(1, 2) 355 | x = self.post_net(x) 356 | x = x.transpose(1, 2) 357 | x = self.final_layer_norm(x) 358 | return x 359 | 360 | def init_weights(m): 361 | if isinstance(m, nn.Conv1d): 362 | nn.init.trunc_normal_(m.weight, std=0.02) 363 | nn.init.constant_(m.bias, 0) 364 | 365 | class CodecDecoderVocos(nn.Module): 366 | def __init__(self, 367 | hidden_dim=1024, 368 | depth=12, 369 | heads=16, 370 | pos_meb_dim=64, 371 | hop_length=320, 372 | vq_num_quantizers=1, 373 | vq_dim=2048, #1024 2048 374 | vq_commit_weight=0.25, 375 | vq_weight_init=False, 376 | vq_full_commit_loss=False, 377 | codebook_size=16384, 378 | codebook_dim=16, 379 | ): 380 | super().__init__() 381 | self.hop_length = hop_length 382 | 383 | self.quantizer = ResidualFSQ( 384 | dim = vq_dim, 385 | levels = [4, 4, 4, 4, 4,4,4,4], 386 | num_quantizers = 1 387 | ) 388 | 389 | # self.quantizer = ResidualVQ( 390 | # num_quantizers=vq_num_quantizers, 391 | # dim=vq_dim, 392 | # codebook_size=codebook_size, 393 | # codebook_dim=codebook_dim, 394 | # threshold_ema_dead_code=2, 395 | # commitment=vq_commit_weight, 396 | # weight_init=vq_weight_init, 397 | # full_commit_loss=vq_full_commit_loss, 398 | # ) 399 | 400 | 401 | self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) 402 | 403 | self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") 404 | 405 | self.reset_parameters() 406 | 407 | def forward(self, x, vq=True): 408 | if vq is True: 409 | # x, q, commit_loss = self.quantizer(x) 410 | x = x.permute(0, 2, 1) 411 | x, q = self.quantizer(x) 412 | x = x.permute(0, 2, 1) 413 | q = q.permute(0, 2, 1) 414 | return x, q, None 415 | x = self.backbone(x) 416 | x,_ = self.head(x) 417 | 418 | return x ,_ 419 | 420 | def vq2emb(self, vq): 421 | self.quantizer = self.quantizer.eval() 422 | x = self.quantizer.vq2emb(vq) 423 | return x 424 | 425 | def get_emb(self): 426 | self.quantizer = self.quantizer.eval() 
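# NOTE: vq2emb/get_emb above and below follow the interface of the (commented-out) ResidualVQ;
# with the ResidualFSQ quantizer configured in __init__ they may need adapting, since
# ResidualFSQ may not expose the same methods.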
427 | embs = self.quantizer.get_emb() 428 | return embs 429 | 430 | def inference_vq(self, vq): 431 | x = vq[None,:,:] 432 | x = self.model(x) 433 | return x 434 | 435 | def inference_0(self, x): 436 | x, q, loss, perp = self.quantizer(x) 437 | x = self.model(x) 438 | return x, None 439 | 440 | def inference(self, x): 441 | x = self.model(x) 442 | return x, None 443 | 444 | 445 | def remove_weight_norm(self): 446 | """Remove weight normalization module from all of the layers.""" 447 | 448 | def _remove_weight_norm(m): 449 | try: 450 | torch.nn.utils.remove_weight_norm(m) 451 | except ValueError: # this module didn't have weight norm 452 | return 453 | 454 | self.apply(_remove_weight_norm) 455 | 456 | def apply_weight_norm(self): 457 | """Apply weight normalization module from all of the layers.""" 458 | 459 | def _apply_weight_norm(m): 460 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 461 | torch.nn.utils.weight_norm(m) 462 | 463 | self.apply(_apply_weight_norm) 464 | 465 | def reset_parameters(self): 466 | self.apply(init_weights) 467 | 468 | 469 | 470 | class CodecDecoderVocos_transpose(nn.Module): 471 | def __init__(self, 472 | hidden_dim=1024, 473 | depth=12, 474 | heads=16, 475 | pos_meb_dim=64, 476 | hop_length=320, 477 | vq_num_quantizers=1, 478 | vq_dim=1024, #1024 2048 479 | vq_commit_weight=0.25, 480 | vq_weight_init=False, 481 | vq_full_commit_loss=False, 482 | codebook_size=16384, 483 | codebook_dim=16, 484 | ): 485 | super().__init__() 486 | self.hop_length = hop_length 487 | 488 | 489 | self.quantizer = ResidualVQ( 490 | num_quantizers=vq_num_quantizers, 491 | dim=vq_dim, 492 | codebook_size=codebook_size, 493 | codebook_dim=codebook_dim, 494 | threshold_ema_dead_code=2, 495 | commitment=vq_commit_weight, 496 | weight_init=vq_weight_init, 497 | full_commit_loss=vq_full_commit_loss, 498 | ) 499 | 500 | 501 | self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) 502 | 503 | self.inverse_mel_conv = nn.Sequential( 504 | nn.GELU(), 505 | nn.ConvTranspose1d( 506 | in_channels=hidden_dim, 507 | out_channels=hidden_dim, 508 | kernel_size=3, 509 | stride=2, 510 | padding=1, 511 | output_padding=1 # 确保输出长度与编码前匹配 512 | ), 513 | nn.GELU(), 514 | nn.ConvTranspose1d( 515 | in_channels=hidden_dim, 516 | out_channels=hidden_dim, 517 | kernel_size=3, 518 | padding=1 519 | ) 520 | ) 521 | 522 | self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") 523 | 524 | self.reset_parameters() 525 | 526 | def forward(self, x, vq=True): 527 | if vq is True: 528 | x, q, commit_loss = self.quantizer(x) 529 | return x, q, commit_loss 530 | x = self.backbone(x) 531 | x,_ = self.head(x) 532 | 533 | return x ,_ 534 | 535 | def vq2emb(self, vq): 536 | self.quantizer = self.quantizer.eval() 537 | x = self.quantizer.vq2emb(vq) 538 | return x 539 | 540 | def get_emb(self): 541 | self.quantizer = self.quantizer.eval() 542 | embs = self.quantizer.get_emb() 543 | return embs 544 | 545 | def inference_vq(self, vq): 546 | x = vq[None,:,:] 547 | x = self.model(x) 548 | return x 549 | 550 | def inference_0(self, x): 551 | x, q, loss, perp = self.quantizer(x) 552 | x = self.model(x) 553 | return x, None 554 | 555 | def inference(self, x): 556 | x = self.model(x) 557 | return x, None 558 | 559 | 560 | def remove_weight_norm(self): 561 | """Remove weight normalization module from all of the layers.""" 562 | 563 | def _remove_weight_norm(m): 564 | try: 565 | torch.nn.utils.remove_weight_norm(m) 566 | 
except ValueError: # this module didn't have weight norm 567 | return 568 | 569 | self.apply(_remove_weight_norm) 570 | 571 | def apply_weight_norm(self): 572 | """Apply weight normalization module from all of the layers.""" 573 | 574 | def _apply_weight_norm(m): 575 | if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): 576 | torch.nn.utils.weight_norm(m) 577 | 578 | self.apply(_apply_weight_norm) 579 | 580 | def reset_parameters(self): 581 | self.apply(init_weights) 582 | 583 | 584 | 585 | 586 | def main(): 587 | # 设置设备 588 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 589 | print(f"Using device: {device}") 590 | 591 | # 初始化模型 592 | model = CodecDecoderVocos_transpose().to(device) 593 | print("Model initialized.") 594 | 595 | # 创建测试输入: batch_size x in_channels x sequence_length 596 | batch_size = 2 597 | in_channels = 1024 598 | sequence_length = 50 # 示例长度,可以根据需要调整 599 | dummy_input = torch.randn(batch_size, in_channels, sequence_length).to(device) 600 | print(f"Dummy input shape: {dummy_input.shape}") 601 | 602 | # 将模型设为评估模式 603 | model.eval() 604 | 605 | # 前向传播(使用 VQ) 606 | # with torch.no_grad(): 607 | # try: 608 | # output, q, commit_loss = model(dummy_input, vq=True) 609 | # print("Forward pass with VQ:") 610 | # print(f"Output shape: {output.shape}") 611 | # print(f"Quantized codes shape: {q.shape}") 612 | # print(f"Commitment loss: {commit_loss}") 613 | # except Exception as e: 614 | # print(f"Error during forward pass with VQ: {e}") 615 | 616 | # 前向传播(不使用 VQ) 617 | with torch.no_grad(): 618 | # try: 619 | output_no_vq = model(dummy_input, vq=False) 620 | print("\nForward pass without VQ:") 621 | print(f"Output shape: {output_no_vq.shape}") 622 | c=1 623 | # except Exception as e: 624 | # print(f"Error during forward pass without VQ: {e}") 625 | 626 | 627 | # model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters()) 628 | # model_size_mb = model_size_bytes / (1024 ** 2) 629 | # print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)") 630 | 631 | if __name__ == "__main__": 632 | main() -------------------------------------------------------------------------------- /vq/codec_encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv1d_transformer') 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | from vq.module import WNConv1d, EncoderBlock, ResLSTM 7 | from vq.alias_free_torch import * 8 | from vq import activations 9 | from vq.bs_roformer5 import TransformerBlock 10 | # from rotary_embedding_torch import RotaryEmbedding 11 | from torchtune.modules import RotaryPositionalEmbeddings 12 | import vq.blocks as blocks 13 | from torch.nn import utils 14 | def init_weights(m): 15 | if isinstance(m, nn.Conv1d): 16 | nn.init.trunc_normal_(m.weight, std=0.02) 17 | nn.init.constant_(m.bias, 0) 18 | 19 | class CodecEncoder(nn.Module): 20 | def __init__(self, 21 | ngf=48, 22 | use_rnn=True, 23 | rnn_bidirectional=False, 24 | rnn_num_layers=2, 25 | up_ratios=(2, 2, 4, 4, 5), 26 | dilations=(1, 3, 9), 27 | out_channels=1024): 28 | super().__init__() 29 | self.hop_length = np.prod(up_ratios) 30 | self.ngf = ngf 31 | self.up_ratios = up_ratios 32 | 33 | # Create first convolution 34 | d_model = ngf 35 | self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 36 | 37 | # Create EncoderBlocks that double channels as they downsample by `stride` 38 | for i, stride in 
enumerate(up_ratios): 39 | d_model *= 2 40 | self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)] 41 | # RNN 42 | if use_rnn: 43 | self.block += [ 44 | ResLSTM(d_model, 45 | num_layers=rnn_num_layers, 46 | bidirectional=rnn_bidirectional 47 | ) 48 | ] 49 | # Create last convolution 50 | self.block += [ 51 | Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), 52 | WNConv1d(d_model, out_channels, kernel_size=3, padding=1), 53 | ] 54 | 55 | # Wrap black into nn.Sequential 56 | self.block = nn.Sequential(*self.block) 57 | self.enc_dim = d_model 58 | 59 | self.reset_parameters() 60 | 61 | def forward(self, x): 62 | out = self.block(x) 63 | return out 64 | 65 | def inference(self, x): 66 | return self.block(x) 67 | 68 | def remove_weight_norm(self): 69 | """Remove weight normalization module from all of the layers.""" 70 | 71 | def _remove_weight_norm(m): 72 | try: 73 | torch.nn.utils.remove_weight_norm(m) 74 | except ValueError: # this module didn't have weight norm 75 | return 76 | 77 | self.apply(_remove_weight_norm) 78 | 79 | def apply_weight_norm(self): 80 | """Apply weight normalization module from all of the layers.""" 81 | 82 | def _apply_weight_norm(m): 83 | if isinstance(m, nn.Conv1d): 84 | torch.nn.utils.weight_norm(m) 85 | 86 | self.apply(_apply_weight_norm) 87 | 88 | def reset_parameters(self): 89 | self.apply(init_weights) 90 | 91 | 92 | class Transpose(nn.Module): 93 | def __init__(self, dim1, dim2): 94 | super(Transpose, self).__init__() 95 | self.dim1 = dim1 96 | self.dim2 = dim2 97 | 98 | def forward(self, x): 99 | return x.transpose(self.dim1, self.dim2) 100 | 101 | class CodecEncoder_Transformer(nn.Module): 102 | def __init__(self, 103 | ngf=48, 104 | up_ratios=[2, 2, 4, 4, 5], 105 | dilations=(1, 3, 9), 106 | hidden_dim=1024, 107 | depth=12, 108 | heads=12, 109 | pos_meb_dim=64, 110 | ): 111 | super().__init__() 112 | self.hop_length = np.prod(up_ratios) 113 | self.ngf =ngf 114 | self.up_ratios = up_ratios 115 | 116 | d_model = ngf 117 | self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 118 | 119 | 120 | for i, stride in enumerate(up_ratios): 121 | d_model *= 2 122 | self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)] 123 | 124 | self.conv_blocks = nn.Sequential(*self.conv_blocks) 125 | 126 | 127 | # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 128 | 129 | 130 | # transformer_blocks = [ 131 | # TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 132 | # for _ in range(depth) 133 | # ] 134 | 135 | 136 | # self.transformers = nn.Sequential(*transformer_blocks) 137 | 138 | # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 139 | 140 | self.conv_final_block = [ 141 | Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), 142 | WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1), 143 | ] 144 | self.conv_final_block = nn.Sequential(*self.conv_final_block) 145 | 146 | self.reset_parameters() 147 | 148 | def forward(self, x): 149 | x = self.conv_blocks(x) 150 | # x = x.permute(0, 2, 1) 151 | # x= self.transformers(x) 152 | # x = self.final_layer_norm(x) 153 | # x = x.permute(0, 2, 1) 154 | x = self.conv_final_block (x) 155 | x = x.permute(0, 2, 1) 156 | return x 157 | 158 | def inference(self, x): 159 | return self.block(x) 160 | 161 | def remove_weight_norm(self): 162 | """Remove weight normalization module from all of the layers.""" 163 | 164 | def _remove_weight_norm(m): 165 | try: 166 | 
torch.nn.utils.remove_weight_norm(m) 167 | except ValueError: # this module didn't have weight norm 168 | return 169 | 170 | self.apply(_remove_weight_norm) 171 | 172 | def apply_weight_norm(self): 173 | """Apply weight normalization module from all of the layers.""" 174 | 175 | def _apply_weight_norm(m): 176 | if isinstance(m, nn.Conv1d): 177 | torch.nn.utils.weight_norm(m) 178 | 179 | self.apply(_apply_weight_norm) 180 | 181 | def reset_parameters(self): 182 | self.apply(init_weights) 183 | 184 | 185 | 186 | class Codec_oobleck_Transformer(nn.Module): 187 | def __init__(self, 188 | ngf=32, 189 | up_ratios=(2, 2,4,4, 5), 190 | dilations=(1, 3, 9), 191 | hidden_dim=1024, 192 | depth=12, 193 | heads=16, 194 | pos_meb_dim=64, 195 | ): 196 | super().__init__() 197 | self.hop_length = np.prod(up_ratios) 198 | self.ngf =ngf 199 | self.up_ratios = up_ratios 200 | self.hidden_dim = hidden_dim 201 | 202 | 203 | self.conv_blocks = blocks.DilatedResidualEncoder( 204 | capacity=ngf, 205 | dilated_unit=self.dilated_unit, 206 | downsampling_unit=self.downsampling_unit, 207 | ratios=up_ratios, 208 | dilations=dilations, 209 | pre_network_conv=self.pre_conv, 210 | post_network_conv=self.post_conv, 211 | ) 212 | 213 | 214 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 215 | 216 | transformer_blocks = [ 217 | TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 218 | for _ in range(depth) 219 | ] 220 | 221 | self.transformers = nn.Sequential(*transformer_blocks) 222 | 223 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 224 | 225 | 226 | self.reset_parameters() 227 | 228 | def forward(self, x): 229 | x = self.conv_blocks(x) 230 | x = x.permute(0, 2, 1) 231 | x= self.transformers(x) 232 | x = self.final_layer_norm(x) 233 | return x 234 | 235 | def inference(self, x): 236 | return self.block(x) 237 | 238 | def remove_weight_norm(self): 239 | """Remove weight normalization module from all of the layers.""" 240 | 241 | def _remove_weight_norm(m): 242 | try: 243 | torch.nn.utils.remove_weight_norm(m) 244 | except ValueError: # this module didn't have weight norm 245 | return 246 | 247 | self.apply(_remove_weight_norm) 248 | 249 | def apply_weight_norm(self): 250 | """Apply weight normalization module from all of the layers.""" 251 | 252 | def _apply_weight_norm(m): 253 | if isinstance(m, nn.Conv1d): 254 | torch.nn.utils.weight_norm(m) 255 | 256 | self.apply(_apply_weight_norm) 257 | 258 | def reset_parameters(self): 259 | self.apply(init_weights) 260 | 261 | def dilated_unit(self,hidden_dim, dilation): 262 | return blocks.DilatedConvolutionalUnit(hidden_dim, 263 | dilation, 264 | kernel_size=3, 265 | activation=nn.ReLU, 266 | normalization=utils.weight_norm) 267 | 268 | def downsampling_unit(self, input_dim: int, output_dim: int, stride: int): 269 | return blocks.DownsamplingUnit(input_dim, 270 | output_dim, 271 | stride, 272 | nn.ReLU, 273 | normalization=utils.weight_norm) 274 | 275 | def pre_conv(self,out_channels): 276 | return nn.Conv1d(1, out_channels, 1) 277 | 278 | def post_conv(self,in_channels): 279 | return nn.Conv1d(in_channels, self.hidden_dim, 1) 280 | 281 | 282 | 283 | 284 | 285 | class CodecEncoder_only_Transformer(nn.Module): 286 | def __init__(self,hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): 287 | super().__init__() 288 | # self.embed = nn.Linear(input_dim, hidden_dim )input_dim=300, 289 | 290 | depth = depth 291 | time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) 292 | 293 | 294 | transformer_blocks = [ 295 | 
TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) 296 | for _ in range(depth) 297 | ] 298 | 299 | 300 | self.transformers = nn.Sequential(*transformer_blocks) 301 | 302 | self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) 303 | 304 | def forward(self, x: torch.Tensor ) -> torch.Tensor: 305 | # x = self.embed(x) 306 | 307 | 308 | x= self.transformers(x) 309 | x = self.final_layer_norm(x) 310 | 311 | return x 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | def get_model_size(model): 320 | # 计算总参数数 321 | total_params = sum(p.numel() for p in model.parameters()) 322 | 323 | # 假设每个参数都是32位浮点数,计算模型大小(以字节为单位) 324 | model_size_bytes = total_params # 每个参数4字节 325 | 326 | # 转换为更易读的单位(例如,MB) 327 | model_size_mb = model_size_bytes / (1024 ** 2) 328 | 329 | return total_params, model_size_mb 330 | 331 | if __name__ == '__main__': 332 | model = Codec_oobleck_Transformer() 333 | x = torch.randn(1, 1, 16000) # example input tensor 334 | output = model(x) 335 | print("Output shape:", output.shape) 336 | -------------------------------------------------------------------------------- /vq/factorized_vector_quantize.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch.nn.utils import weight_norm 9 | 10 | class FactorizedVectorQuantize(nn.Module): 11 | def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs): 12 | super().__init__() 13 | self.codebook_size = codebook_size 14 | self.codebook_dim = codebook_dim 15 | self.commitment = commitment 16 | 17 | if dim != self.codebook_dim: 18 | self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim)) 19 | self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim)) 20 | else: 21 | self.in_proj = nn.Identity() 22 | self.out_proj = nn.Identity() 23 | self._codebook = nn.Embedding(codebook_size, self.codebook_dim) 24 | 25 | @property 26 | def codebook(self): 27 | return self._codebook 28 | 29 | def forward(self, z): 30 | """Quantized the input tensor using a fixed codebook and returns 31 | the corresponding codebook vectors 32 | 33 | Parameters 34 | ---------- 35 | z : Tensor[B x D x T] 36 | 37 | Returns 38 | ------- 39 | Tensor[B x D x T] 40 | Quantized continuous representation of input 41 | Tensor[1] 42 | Commitment loss to train encoder to predict vectors closer to codebook 43 | entries 44 | Tensor[1] 45 | Codebook loss to update the codebook 46 | Tensor[B x T] 47 | Codebook indices (quantized discrete representation of input) 48 | Tensor[B x D x T] 49 | Projected latents (continuous representation of input before quantization) 50 | """ 51 | # transpose since we use linear 52 | 53 | z = rearrange(z, "b d t -> b t d") 54 | 55 | # Factorized codes project input into low-dimensional space 56 | z_e = self.in_proj(z) # z_e : (B x T x D) 57 | z_e = rearrange(z_e, "b t d -> b d t") 58 | z_q, indices = self.decode_latents(z_e) 59 | 60 | 61 | if self.training: 62 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment 63 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2]) 64 | commit_loss = commitment_loss + codebook_loss 65 | else: 66 | commit_loss = torch.zeros(z.shape[0], device = z.device) 67 | 68 | z_q = ( 69 | z_e + (z_q - z_e).detach() 70 | ) # noop in forward pass, straight-through gradient estimator in backward pass 71 | 72 | 
z_q = rearrange(z_q, "b d t -> b t d") 73 | z_q = self.out_proj(z_q) 74 | z_q = rearrange(z_q, "b t d -> b d t") 75 | 76 | return z_q, indices, commit_loss 77 | 78 | def vq2emb(self, vq, proj=True): 79 | emb = self.embed_code(vq) 80 | if proj: 81 | emb = self.out_proj(emb) 82 | return emb 83 | 84 | def get_emb(self): 85 | return self.codebook.weight 86 | 87 | def embed_code(self, embed_id): 88 | return F.embedding(embed_id, self.codebook.weight) 89 | 90 | def decode_code(self, embed_id): 91 | return self.embed_code(embed_id).transpose(1, 2) 92 | 93 | def decode_latents(self, latents): 94 | encodings = rearrange(latents, "b d t -> (b t) d") 95 | codebook = self.codebook.weight # codebook: (N x D) 96 | 97 | # L2 normalize encodings and codebook 98 | encodings = F.normalize(encodings) 99 | codebook = F.normalize(codebook) 100 | 101 | # Compute euclidean distance with codebook 102 | dist = ( 103 | encodings.pow(2).sum(1, keepdim=True) 104 | - 2 * encodings @ codebook.t() 105 | + codebook.pow(2).sum(1, keepdim=True).t() 106 | ) 107 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 108 | z_q = self.decode_code(indices) 109 | return z_q, indices -------------------------------------------------------------------------------- /vq/module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from einops import rearrange 3 | from . import activations 4 | from .alias_free_torch import * 5 | from torch.nn.utils import weight_norm 6 | 7 | from typing import Optional, Tuple 8 | 9 | from torch.nn.utils import weight_norm, remove_weight_norm 10 | 11 | 12 | def WNConv1d(*args, **kwargs): 13 | return weight_norm(nn.Conv1d(*args, **kwargs)) 14 | 15 | 16 | def WNConvTranspose1d(*args, **kwargs): 17 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 18 | 19 | class ResidualUnit(nn.Module): 20 | def __init__(self, dim: int = 16, dilation: int = 1): 21 | super().__init__() 22 | pad = ((7 - 1) * dilation) // 2 23 | self.block = nn.Sequential( 24 | Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), 25 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), 26 | Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), 27 | WNConv1d(dim, dim, kernel_size=1), 28 | ) 29 | 30 | def forward(self, x): 31 | return x + self.block(x) 32 | 33 | class EncoderBlock(nn.Module): 34 | def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)): 35 | super().__init__() 36 | runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations] 37 | self.block = nn.Sequential( 38 | *runits, 39 | Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)), 40 | WNConv1d( 41 | dim // 2, 42 | dim, 43 | kernel_size=2 * stride, 44 | stride=stride, 45 | padding=stride // 2 + stride % 2, 46 | ), 47 | ) 48 | 49 | def forward(self, x): 50 | return self.block(x) 51 | 52 | class DecoderBlock(nn.Module): 53 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)): 54 | super().__init__() 55 | self.block = nn.Sequential( 56 | Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)), 57 | WNConvTranspose1d( 58 | input_dim, 59 | output_dim, 60 | kernel_size=2 * stride, 61 | stride=stride, 62 | padding=stride // 2 + stride % 2, 63 | output_padding= stride % 2, 64 | ) 65 | ) 66 | self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations]) 67 | 68 | def forward(self, x): 69 | return self.block(x) 70 
| 71 | class ResLSTM(nn.Module): 72 | def __init__(self, dimension: int, 73 | num_layers: int = 2, 74 | bidirectional: bool = False, 75 | skip: bool = True): 76 | super().__init__() 77 | self.skip = skip 78 | self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2, 79 | num_layers, batch_first=True, 80 | bidirectional=bidirectional) 81 | 82 | def forward(self, x): 83 | """ 84 | Args: 85 | x: [B, F, T] 86 | 87 | Returns: 88 | y: [B, F, T] 89 | """ 90 | x = rearrange(x, "b f t -> b t f") 91 | y, _ = self.lstm(x) 92 | if self.skip: 93 | y = y + x 94 | y = rearrange(y, "b t f -> b f t") 95 | return y 96 | 97 | 98 | 99 | class ConvNeXtBlock(nn.Module): 100 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 101 | 102 | Args: 103 | dim (int): Number of input channels. 104 | intermediate_dim (int): Dimensionality of the intermediate layer. 105 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 106 | Defaults to None. 107 | adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. 108 | None means non-conditional LayerNorm. Defaults to None. 109 | """ 110 | 111 | def __init__( 112 | self, 113 | dim: int, 114 | intermediate_dim: int, 115 | layer_scale_init_value: float, 116 | adanorm_num_embeddings: Optional[int] = None, 117 | ): 118 | super().__init__() 119 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 120 | self.adanorm = adanorm_num_embeddings is not None 121 | if adanorm_num_embeddings: 122 | self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) 123 | else: 124 | self.norm = nn.LayerNorm(dim, eps=1e-6) 125 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 126 | self.act = nn.GELU() 127 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 128 | self.gamma = ( 129 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 130 | if layer_scale_init_value > 0 131 | else None 132 | ) 133 | 134 | def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: 135 | residual = x 136 | x = self.dwconv(x) 137 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 138 | if self.adanorm: 139 | assert cond_embedding_id is not None 140 | x = self.norm(x, cond_embedding_id) 141 | else: 142 | x = self.norm(x) 143 | x = self.pwconv1(x) 144 | x = self.act(x) 145 | x = self.pwconv2(x) 146 | if self.gamma is not None: 147 | x = self.gamma * x 148 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 149 | 150 | x = residual + x 151 | return x 152 | 153 | 154 | class AdaLayerNorm(nn.Module): 155 | """ 156 | Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes 157 | 158 | Args: 159 | num_embeddings (int): Number of embeddings. 160 | embedding_dim (int): Dimension of the embeddings. 
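eps (float, optional): Small constant passed to layer_norm for numerical stability. Defaults to 1e-6.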
161 | """ 162 | 163 | def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): 164 | super().__init__() 165 | self.eps = eps 166 | self.dim = embedding_dim 167 | self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 168 | self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) 169 | torch.nn.init.ones_(self.scale.weight) 170 | torch.nn.init.zeros_(self.shift.weight) 171 | 172 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: 173 | scale = self.scale(cond_embedding_id) 174 | shift = self.shift(cond_embedding_id) 175 | x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) 176 | x = x * scale + shift 177 | return x 178 | 179 | 180 | class ResBlock1(nn.Module): 181 | """ 182 | ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, 183 | but without upsampling layers. 184 | 185 | Args: 186 | dim (int): Number of input channels. 187 | kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. 188 | dilation (tuple[int], optional): Dilation factors for the dilated convolutions. 189 | Defaults to (1, 3, 5). 190 | lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. 191 | Defaults to 0.1. 192 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 193 | Defaults to None. 194 | """ 195 | 196 | def __init__( 197 | self, 198 | dim: int, 199 | kernel_size: int = 3, 200 | dilation: Tuple[int, int, int] = (1, 3, 5), 201 | lrelu_slope: float = 0.1, 202 | layer_scale_init_value: Optional[float] = None, 203 | ): 204 | super().__init__() 205 | self.lrelu_slope = lrelu_slope 206 | self.convs1 = nn.ModuleList( 207 | [ 208 | weight_norm( 209 | nn.Conv1d( 210 | dim, 211 | dim, 212 | kernel_size, 213 | 1, 214 | dilation=dilation[0], 215 | padding=self.get_padding(kernel_size, dilation[0]), 216 | ) 217 | ), 218 | weight_norm( 219 | nn.Conv1d( 220 | dim, 221 | dim, 222 | kernel_size, 223 | 1, 224 | dilation=dilation[1], 225 | padding=self.get_padding(kernel_size, dilation[1]), 226 | ) 227 | ), 228 | weight_norm( 229 | nn.Conv1d( 230 | dim, 231 | dim, 232 | kernel_size, 233 | 1, 234 | dilation=dilation[2], 235 | padding=self.get_padding(kernel_size, dilation[2]), 236 | ) 237 | ), 238 | ] 239 | ) 240 | 241 | self.convs2 = nn.ModuleList( 242 | [ 243 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 244 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 245 | weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), 246 | ] 247 | ) 248 | 249 | self.gamma = nn.ParameterList( 250 | [ 251 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 252 | if layer_scale_init_value is not None 253 | else None, 254 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 255 | if layer_scale_init_value is not None 256 | else None, 257 | nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) 258 | if layer_scale_init_value is not None 259 | else None, 260 | ] 261 | ) 262 | 263 | def forward(self, x: torch.Tensor) -> torch.Tensor: 264 | for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): 265 | xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) 266 | xt = c1(xt) 267 | xt = torch.nn.functional.leaky_relu(xt, 
negative_slope=self.lrelu_slope) 268 | xt = c2(xt) 269 | if gamma is not None: 270 | xt = gamma * xt 271 | x = xt + x 272 | return x 273 | 274 | def remove_weight_norm(self): 275 | for l in self.convs1: 276 | remove_weight_norm(l) 277 | for l in self.convs2: 278 | remove_weight_norm(l) 279 | 280 | @staticmethod 281 | def get_padding(kernel_size: int, dilation: int = 1) -> int: 282 | return int((kernel_size * dilation - dilation) / 2) 283 | 284 | 285 | def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: 286 | """ 287 | Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. 288 | 289 | Args: 290 | x (Tensor): Input tensor. 291 | clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. 292 | 293 | Returns: 294 | Tensor: Element-wise logarithm of the input tensor with clipping applied. 295 | """ 296 | return torch.log(torch.clip(x, min=clip_val)) 297 | 298 | 299 | def symlog(x: torch.Tensor) -> torch.Tensor: 300 | return torch.sign(x) * torch.log1p(x.abs()) 301 | 302 | 303 | def symexp(x: torch.Tensor) -> torch.Tensor: 304 | return torch.sign(x) * (torch.exp(x.abs()) - 1) 305 | 306 | 307 | 308 | class SemanticEncoder(nn.Module): 309 | def __init__( 310 | self, 311 | input_channels: int, 312 | code_dim: int, 313 | encode_channels: int, 314 | kernel_size: int = 3, 315 | bias: bool = True, 316 | ): 317 | super(SemanticEncoder, self).__init__() 318 | 319 | # 初始卷积,将 input_channels 映射到 encode_channels 320 | self.initial_conv = nn.Conv1d( 321 | in_channels=input_channels, 322 | out_channels=encode_channels, 323 | kernel_size=kernel_size, 324 | stride=1, 325 | padding=(kernel_size - 1) // 2, 326 | bias=False 327 | ) 328 | 329 | # 残差块 330 | self.residual_blocks = nn.Sequential( 331 | nn.ReLU(inplace=True), 332 | nn.Conv1d( 333 | encode_channels, 334 | encode_channels, 335 | kernel_size=kernel_size, 336 | stride=1, 337 | padding=(kernel_size - 1) // 2, 338 | bias=bias 339 | ), 340 | nn.ReLU(inplace=True), 341 | nn.Conv1d( 342 | encode_channels, 343 | encode_channels, 344 | kernel_size=kernel_size, 345 | stride=1, 346 | padding=(kernel_size - 1) // 2, 347 | bias=bias 348 | ) 349 | ) 350 | 351 | # 最终卷积,将 encode_channels 映射到 code_dim 352 | self.final_conv = nn.Conv1d( 353 | in_channels=encode_channels, 354 | out_channels=code_dim, 355 | kernel_size=kernel_size, 356 | stride=1, 357 | padding=(kernel_size - 1) // 2, 358 | bias=False 359 | ) 360 | 361 | def forward(self, x): 362 | """ 363 | 前向传播方法。 364 | 365 | Args: 366 | x (Tensor): 输入张量,形状为 (Batch, Input_channels, Length) 367 | 368 | Returns: 369 | Tensor: 编码后的张量,形状为 (Batch, Code_dim, Length) 370 | """ 371 | x = self.initial_conv(x) # (Batch, Encode_channels, Length) 372 | x = self.residual_blocks(x) + x # 残差连接 373 | x = self.final_conv(x) # (Batch, Code_dim, Length) 374 | return x 375 | 376 | class SemanticDecoder(nn.Module): 377 | def __init__( 378 | self, 379 | code_dim: int, 380 | output_channels: int, 381 | decode_channels: int, 382 | kernel_size: int = 3, 383 | bias: bool = True, 384 | ): 385 | super(SemanticDecoder, self).__init__() 386 | 387 | # Initial convolution to map code_dim to decode_channels 388 | self.initial_conv = nn.Conv1d( 389 | in_channels=code_dim, 390 | out_channels=decode_channels, 391 | kernel_size=kernel_size, 392 | stride=1, 393 | padding=(kernel_size - 1) // 2, 394 | bias=False 395 | ) 396 | 397 | # Residual Blocks 398 | self.residual_blocks = nn.Sequential( 399 | nn.ReLU(inplace=True), 400 | nn.Conv1d(decode_channels, decode_channels, 
kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias), 401 | nn.ReLU(inplace=True), 402 | nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias) 403 | ) 404 | 405 | # Final convolution to map decode_channels to output_channels 406 | self.final_conv = nn.Conv1d( 407 | in_channels=decode_channels, 408 | out_channels=output_channels, 409 | kernel_size=kernel_size, 410 | stride=1, 411 | padding=(kernel_size - 1) // 2, 412 | bias=False 413 | ) 414 | 415 | def forward(self, z): 416 | # z: (Batch, Code_dim, Length) 417 | x = self.initial_conv(z) # (Batch, Decode_channels, Length) 418 | x = self.residual_blocks(x) + x # Residual connection 419 | x = self.final_conv(x) # (Batch, Output_channels, Length) 420 | return x -------------------------------------------------------------------------------- /vq/residual_vq.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from .factorized_vector_quantize import FactorizedVectorQuantize 5 | 6 | class ResidualVQ(nn.Module): 7 | def __init__( 8 | self, 9 | *, 10 | num_quantizers, 11 | codebook_size, 12 | **kwargs 13 | ): 14 | super().__init__() 15 | VQ = FactorizedVectorQuantize 16 | if type(codebook_size) == int: 17 | codebook_size = [codebook_size] * num_quantizers 18 | self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size]) 19 | self.num_quantizers = num_quantizers 20 | 21 | def forward(self, x): 22 | quantized_out = 0. 23 | residual = x 24 | 25 | all_losses = [] 26 | all_indices = [] 27 | 28 | for idx, layer in enumerate(self.layers): 29 | quantized, indices, loss = layer(residual) 30 | 31 | residual = residual - quantized 32 | 33 | quantized_out = quantized_out + quantized 34 | 35 | loss = loss.mean() 36 | 37 | all_indices.append(indices) 38 | all_losses.append(loss) 39 | all_losses, all_indices = map(torch.stack, (all_losses, all_indices)) 40 | return quantized_out, all_indices, all_losses 41 | 42 | def vq2emb(self, vq, proj=True): 43 | # [B, T, num_quantizers] 44 | quantized_out = 0. 
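# mirror of the residual decomposition in forward(): each quantizer stage decodes its own
# code indices, and the final embedding is the sum of the per-stage reconstructions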
        for idx, layer in enumerate(self.layers):
            quantized = layer.vq2emb(vq[:, :, idx], proj=proj)
            quantized_out = quantized_out + quantized
        return quantized_out

    def get_emb(self):
        embs = []
        for idx, layer in enumerate(self.layers):
            embs.append(layer.get_emb())
        return embs
--------------------------------------------------------------------------------
/vq/unet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np


class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(EncoderBlock, self).__init__()

        self.pool_size = 2

        self.conv_block = ConvBlock(in_channels, out_channels, kernel_size)

    def forward(self, x):
        latent = self.conv_block(x)
        output = F.avg_pool2d(latent, kernel_size=self.pool_size)
        return output, latent


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(DecoderBlock, self).__init__()

        stride = 2

        self.upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=stride,
            stride=stride,
            padding=(0, 0),
            bias=False,
        )

        self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size)

    def forward(self, x, latent):
        x = self.upsample(x)
        x = torch.cat((x, latent), dim=1)
        output = self.conv_block(x)
        return output


class UNet(nn.Module):
    def __init__(self, freq_dim=1281, out_channel=1024):
        super(UNet, self).__init__()

        self.downsample_ratio = 16

        in_channels = 1  # self.audio_channels * self.cmplx_num

        self.encoder_block1 = EncoderBlock(in_channels, 16)
        self.encoder_block2 = EncoderBlock(16, 64)
        self.encoder_block3 = EncoderBlock(64, 256)
        self.encoder_block4 = EncoderBlock(256, 1024)
        self.middle = EncoderBlock(1024, 1024)
        self.decoder_block1 = DecoderBlock(1024, 256)
        self.decoder_block2 = DecoderBlock(256, 64)
        self.decoder_block3 = DecoderBlock(64, 16)
        self.decoder_block4 = DecoderBlock(16, 16)

        self.fc = nn.Linear(freq_dim * 16, out_channel)

    def forward(self, x_ori):
        """
        Args:
            x_ori: (batch_size, channels_num, time_steps, freq_bins) spectrogram-like tensor.

        Returns:
            output: (batch_size, time_steps, out_channel) feature tensor.
        """
        x = self.process_image(x_ori)
        x1, latent1 = self.encoder_block1(x)
        x2, latent2 = self.encoder_block2(x1)
        x3, latent3 = self.encoder_block3(x2)
        x4, latent4 = self.encoder_block4(x3)
        _, h = self.middle(x4)
        x5 = self.decoder_block1(h, latent4)
        x6 = self.decoder_block2(x5, latent3)
        x7 = self.decoder_block3(x6, latent2)
        x8 = self.decoder_block4(x7, latent1)
        x = self.unprocess_image(x8, x_ori.shape[2])
        x = x.permute(0, 2, 1, 3).contiguous()   # (B, T, C, F)
        x = x.view(x.size(0), x.size(1), -1)     # flatten channel and frequency dims
        x = self.fc(x)

        return x

    def process_image(self, x):
        """
        Pad the time axis so it is divisible by downsample_ratio and drop the last frequency bin.

        Args:
            x: (B, C, T, F)

        Returns:
            output: (B, C, T_padded, F_reduced)
        """
        B, C, T, Freq = x.shape

        pad_len = (
            int(np.ceil(T / self.downsample_ratio)) *
                self.downsample_ratio
            - T
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))

        output = x[:, :, :, 0 : Freq - 1]

        return output

    def unprocess_image(self, x, time_steps):
        """
        Restore the spectrogram to its original shape.

        Args:
            x: (B, C, T_padded, F_reduced)

        Returns:
            output: (B, C, T_original, F_original)
        """
        x = F.pad(x, pad=(0, 1))

        output = x[:, :, 0:time_steps, :]

        return output


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=(3, 3)):
        super(ConvBlock, self).__init__()

        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=False,
        )

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

    def forward(self, x):
        h = self.conv1(F.leaky_relu_(self.bn1(x)))
        h = self.conv2(F.leaky_relu_(self.bn2(h)))

        if self.is_shortcut:
            return self.shortcut(x) + h
        else:
            return x + h


def test_unet():
    # Input configuration; freq_bins must match UNet's default freq_dim (1281)
    # so that the skip connections and the final linear layer line up.
    batch_size = 6
    channels = 1       # number of audio channels
    time_steps = 256   # number of time steps
    freq_bins = 1281   # number of frequency bins

    # Create a random real-valued spectrogram-like tensor as input
    real_part = torch.randn(batch_size, channels, time_steps, freq_bins)
    imag_part = torch.randn(batch_size, channels, time_steps, freq_bins)
    complex_sp = real_part  # torch.complex(real_part, imag_part)

    # Instantiate the UNet model
    model = UNet()

    # Forward pass
    output = model(complex_sp)

    # Print input and output shapes
    print("Input shape:", complex_sp.shape)
    print("Output shape:", output.shape)

    # The model returns a real-valued tensor of shape (batch, time_steps, out_channel)
    assert output.shape == (batch_size, time_steps, 1024), "Unexpected output shape"

    print("Test passed; the model works as expected.")


# Run the test function
if __name__ == "__main__":
    test_unet()
--------------------------------------------------------------------------------
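A minimal usage sketch for the SemanticEncoder/SemanticDecoder pair defined above, assuming they are importable from vq.module as the repository tree suggests; the channel sizes (1024/768/512) and the batch/length values are illustrative placeholders rather than values taken from the LLaSE configuration:

import torch

from vq.module import SemanticEncoder, SemanticDecoder  # assumed import path

# Encode 1024-dim frame-level features into a lower-dimensional code sequence
# and decode them back; every convolution is stride-1 with symmetric padding,
# so the time dimension (Length) is preserved end to end.
encoder = SemanticEncoder(input_channels=1024, code_dim=512, encode_channels=768)
decoder = SemanticDecoder(code_dim=512, output_channels=1024, decode_channels=768)

features = torch.randn(2, 1024, 50)     # (Batch, Input_channels, Length)
codes = encoder(features)               # (2, 512, 50)
reconstructed = decoder(codes)          # (2, 1024, 50)
print(codes.shape, reconstructed.shape)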