├── .gitignore ├── README.md ├── README_en.md ├── configs ├── ar_model.json ├── ar_model.yaml ├── finetune_ar_model.yaml ├── finetune_nar_model.yaml ├── nar_model.json └── nar_model.yaml ├── requirements.txt ├── res ├── o1.wav ├── o2.wav ├── o3.wav ├── o4.wav ├── o5.wav ├── test-in.wav ├── vclm-ar.png └── vclm-nar.png ├── run.py ├── sh ├── run.sh ├── train_ar_model.sh ├── train_finetune_ar_model.sh ├── train_finetune_nar_model.sh └── train_nar_model.sh ├── tests ├── test_ar_datamodule.py ├── test_ar_dataset.py ├── test_nar_datamodule.py ├── test_nar_dataset.py └── test_whisper_encoder.py ├── tools ├── construct_dataset.py ├── construct_parallel_dataset.py ├── construct_wavs_file.py ├── extract_whisper_encoder_model.py ├── save_ar_model.py └── save_model.py └── vc_lm ├── __init__.py ├── callbacks └── __init__.py ├── datamodules ├── __init__.py ├── ar_datamodule.py ├── datasets │ ├── __init__.py │ ├── ar_dataset.py │ └── nar_dataset.py └── nar_datamodule.py ├── models ├── __init__.py ├── ar_model_pl.py ├── bart │ ├── __init__.py │ ├── configuration_bart.py │ └── modeling_bart.py ├── base.py ├── decoders │ ├── __init__.py │ ├── ar_decoder.py │ ├── layers.py │ └── nar_decoder.py ├── encoders │ ├── __init__.py │ └── whisper_encoder.py ├── misc.py ├── models │ ├── __init__.py │ ├── ar_model.py │ └── nar_model.py └── nar_model_pl.py ├── utils ├── __init__.py └── data_utils.py └── vc_engine.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | local/ 3 | __pycache__/ 4 | .DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vc-lm 2 | [**中文**](./README.md) | [**English**](./README_en.md) 3 | 4 | vc-lm是一个可以将任意人的音色转换为成千上万种不同音色的音频的项目。 5 | 6 | ## 🔄 最近更新 7 | * [2023/06/09] 新增Any-to-One声音转换模型训练. 
8 | 9 | ## 算法架构 10 | 该项目参考论文 [Vall-E](https://arxiv.org/abs/2301.02111) 11 | 12 | 使用[encodec](https://github.com/facebookresearch/encodec), 13 | 将音频离散化成tokens, 在tokens上构建transformer语言模型。 14 | 该项目包含两阶段模型 AR模型和NAR模型。 15 | 16 | 输入: 3s音色prompt音频 + 被转换音频 17 | 18 | 输出: 转换后音频 19 | 20 | 在训练阶段,采用了自监督的方式,其中源音频和目标音频是相同的。 21 | ### AR阶段 22 | 输入: prompt音频 + 源音频 23 | 24 | 输出: 目标音频 0 level tokens 25 | 26 | ![ar](res/vclm-ar.png) 27 | 28 | ### NAR阶段 29 | 输入: 目标音频(0~k)level tokens 30 | 31 | 输出: 目标音频k+1 level tokens 32 | 33 | ![nar](res/vclm-nar.png) 34 | 35 | ## 构造数据集 36 | 37 | ``` 38 | # 所有wav文件先处理成长度10~24s的文件, 参考文件[tools/construct_wavs_file.py] 39 | python tools/construct_dataset.py 40 | ``` 41 | ## 转换whisper encoder模型 42 | 43 | ``` 44 | python tools/extract_whisper_encoder_model.py --input_model=../whisper/medium.pt --output_model=../whisper-encoder/medium-encoder.pt 45 | ``` 46 | ## 训练 47 | ``` 48 | bash ./sh/train_ar_model.sh 49 | bash ./sh/train_nar_model.sh 50 | ``` 51 | ## 推理 52 | ``` 53 | from vc_lm.vc_engine import VCEngine 54 | engine = VCEngine('/root/autodl-tmp/vc-models/ar.ckpt', 55 | '/root/autodl-tmp/vc-models/nar.ckpt', 56 | '/root/project/vc-lm/configs/ar_model.json', 57 | '/root/project/vc-lm/configs/nar_model.json') 58 | output_wav = engine.process_audio(content_wav, 59 | style_wav, max_style_len=3, use_ar=True) 60 | ``` 61 | 62 | ## 样例展示 63 | [输入音频](res/test-in.wav) 64 | 65 | [输出音频1](res/o1.wav) 66 | 67 | [输出音频2](res/o2.wav) 68 | 69 | [输出音频3](res/o3.wav) 70 | 71 | [输出音频4](res/o4.wav) 72 | 73 | [输出音频5](res/o5.wav) 74 | 75 | --- 76 | ``` 77 | 本项目模型可以生成大量one-to-any的平行数据(也就是any-to-one)。这些平行数据可以被用来训练 Any-to-One 的变声模型。 78 | ``` 79 | --- 80 | ## 训练Any-to-One VC模型 81 | 目标人数据仅需10分钟,即可达到很好的效果。 82 | 83 | ### 构造数据集 84 | ``` 85 | # 所有wav文件先处理成长度10~24s的文件, 参考文件[tools/construct_wavs_file.py] 86 | python tools/construct_dataset.py 87 | ``` 88 | 89 | ### 构造Any-to-one平行数据 90 | ``` 91 | # 需要构造train, val, test数据 92 | python tools.construct_parallel_dataset.py 93 | ``` 94 | ### 训练模型 95 | 加载上面的预训练模型,在指定人数据上训练。 96 | ``` 97 | bash ./sh/train_finetune_ar_model.sh 98 | bash ./sh/train_finetune_nar_model.sh 99 | ``` 100 | 101 | ### 推理 102 | ``` 103 | from vc_lm.vc_engine import VCEngine 104 | engine = VCEngine('/root/autodl-tmp/vc-models/jr-ar.ckpt', 105 | '/root/autodl-tmp/vc-models/jr-nar.ckpt', 106 | '/root/project/vc-lm/configs/ar_model.json', 107 | '/root/project/vc-lm/configs/nar_model.json') 108 | output_wav = engine.process_audio(content_wav, 109 | style_wav, max_style_len=3, use_ar=True) 110 | ``` 111 | ### DEMO 112 | #### 输入音频: 113 | https://github.com/nilboy/vc-lm/assets/17962699/d9c7fb99-7d34-468b-a376-1c8c882d97e2 114 | #### 输出音频: 115 | https://github.com/nilboy/vc-lm/assets/17962699/7a7620d7-e71b-4655-8ad4-2fb543c92960 116 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | # vc-lm 2 | [**中文**](./README.md) | [**English**](./README_en.md) 3 | 4 | vc-lm is a project that can transform anyone's voice into thousands of different voices in audio. 5 | 6 | ## 🔄 What‘s new 7 | * [2023/06/09] Support Any-to-One voice conversion model. 8 | 9 | ## Algorithm Architecture 10 | This project references the paper [Vall-E](https://arxiv.org/abs/2301.02111) 11 | 12 | It uses [encodec](https://github.com/facebookresearch/encodec) 13 | to discretize audio into tokens and build a transformer language model on tokens. The project consists of two-stage models: AR model and NAR model. 
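For reference, here is a minimal sketch of this audio front-end (it mirrors `tools/construct_dataset.py` further down in this repo; the wav path is a placeholder):

```
# Sketch only: discretize a wav into EnCodec tokens and a Whisper mel spectrogram,
# as tools/construct_dataset.py does; "example.wav" is a placeholder path.
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from whisper.audio import log_mel_spectrogram

model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)  # 6 kbps -> 8 codebooks per frame

wav, sr = torchaudio.load("example.wav")
wav = convert_audio(wav, sr, model.sample_rate, model.channels).unsqueeze(0)
with torch.no_grad():
    encoded_frames = model.encode(wav)
codes = torch.cat([f[0] for f in encoded_frames], dim=-1)  # [1, 8, T], 75 tokens/s per codebook

mel = log_mel_spectrogram("example.wav")  # [80, T'], 100 frames/s
```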
14 | 15 | Input: 3-second voice prompt audio + voice to be transformed 16 | 17 | Output: Transformed audio 18 | 19 | During the training phase, a self-supervised approach is used where the source audio and target audio are the same. 20 | ### AR Stage 21 | Input: Prompt audio + source audio 22 | 23 | Output: Target audio with 0-level tokens 24 | 25 | ![ar](res/vclm-ar.png) 26 | 27 | ### NAR Stage 28 | Input: Target audio with 0 to k-level tokens 29 | 30 | Output: Target audio with k+1 level tokens 31 | 32 | ![nar](res/vclm-nar.png) 33 | 34 | ## Dataset Construction 35 | 36 | ``` 37 | # All WAV files are first processed into files with a length of 10 to 24 seconds. Reference to[[tools/construct_wavs_file.py] 38 | python tools/construct_dataset.py 39 | ``` 40 | ## Convert Whisper Encoder Model 41 | 42 | ``` 43 | python tools/extract_whisper_encoder_model.py --input_model=../whisper/medium.pt --output_model=../whisper-encoder/medium-encoder.pt 44 | ``` 45 | ## Training 46 | ``` 47 | bash ./sh/train_ar_model.sh 48 | bash ./sh/train_nar_model.sh 49 | ``` 50 | ## Inference 51 | ``` 52 | from vc_lm.vc_engine import VCEngine 53 | engine = VCEngine('/root/autodl-tmp/vc-models/ar.ckpt', 54 | '/root/autodl-tmp/vc-models/nar.ckpt', 55 | '/root/project/vc-lm/configs/ar_model.json', 56 | '/root/project/vc-lm/configs/nar_model.json') 57 | output_wav = engine.process_audio(content_wav, 58 | style_wav, max_style_len=3, use_ar=True) 59 | ``` 60 | 61 | ## Models 62 | The models were trained on the Wenetspeech dataset, which consists of thousands of hours of audio data, including the AR model and NAR model. 63 | 64 | Model download link: 65 | 66 | Link: https://pan.baidu.com/s/1bJUXrSH7tJ1QLPTv3tZzRQ 67 | Extract code: 4kao 68 | 69 | ## Examples 70 | [Input Audio](res/test-in.wav) 71 | 72 | [Output Audio 1](res/o1.wav) 73 | 74 | [Output Audio 2](res/o2.wav) 75 | 76 | [Output Audio 3](res/o3.wav) 77 | 78 | [Output Audio 4](res/o4.wav) 79 | 80 | [Output Audio 5](res/o5.wav) 81 | 82 | --- 83 | ``` 84 | This project's models can generate a large number of one-to-any parallel data (i.e., any-to-one). These parallel data can be used to train any-to-one voice conversion models. 85 | ``` 86 | ## Training Any-to-One VC Model 87 | The target speaker's data achieves excellent results in just 10 minutes. 88 | 89 | ### Dataset Construction 90 | ``` 91 | # All WAV files are first processed into files with a length of 10 to 24 seconds. Reference to[[tools/construct_wavs_file.py] 92 | python tools/construct_dataset.py 93 | ``` 94 | 95 | ### Constructing Any-to-One Parallel Data 96 | ``` 97 | # Construct train, val, test data 98 | python tools.construct_parallel_dataset.py 99 | ``` 100 | ### Training 101 | Load the pre-trained model mentioned above and train it on the specified speaker's data. 
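The fine-tuning configs load this pretrained checkpoint through the `load_pretrain` and `pretrain_model_path` fields of `configs/finetune_ar_model.yaml` and `configs/finetune_nar_model.yaml` (both shown later in this repo), e.g.:

```
load_pretrain: True
pretrain_model_path: /root/autodl-tmp/vc-models/ar-1024
```

Point `pretrain_model_path` at your own AR/NAR checkpoints, then run: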
102 | ``` 103 | bash ./sh/train_finetune_ar_model.sh 104 | bash ./sh/train_finetune_nar_model.sh 105 | ``` 106 | 107 | ### Inference 108 | ``` 109 | from vc_lm.vc_engine import VCEngine 110 | engine = VCEngine('/root/autodl-tmp/vc-models/jr-ar.ckpt', 111 | '/root/autodl-tmp/vc-models/jr-nar.ckpt', 112 | '/root/project/vc-lm/configs/ar_model.json', 113 | '/root/project/vc-lm/configs/nar_model.json') 114 | output_wav = engine.process_audio(content_wav, 115 | style_wav, max_style_len=3, use_ar=True) 116 | ``` 117 | ### DEMO 118 | #### Input Audio: 119 | https://github.com/nilboy/vc-lm/assets/17962699/d9c7fb99-7d34-468b-a376-1c8c882d97e2 120 | #### Output Audio: 121 | https://github.com/nilboy/vc-lm/assets/17962699/7a7620d7-e71b-4655-8ad4-2fb543c92960 122 | -------------------------------------------------------------------------------- /configs/ar_model.json: -------------------------------------------------------------------------------- 1 | { 2 | "content_layer_num": -1, 3 | "n_q": 8, 4 | "q_size": 1024, 5 | "vocab_size": 8195, 6 | "max_position_embeddings": 2048, 7 | "style_length": 225, 8 | "d_model": 1024, 9 | "decoder_ffn_dim": 4096, 10 | "decoder_layers": 10, 11 | "decoder_attention_heads": 16, 12 | "dropout": 0.1, 13 | "attention_dropout": 0.0, 14 | "activation_dropout": 0.0, 15 | "activation_function": "gelu", 16 | "init_std": 0.02, 17 | "decoder_layerdrop": 0.0, 18 | "classifier_dropout": 0.0, 19 | "use_cache": true, 20 | "scale_embedding": false, 21 | "encoder_model_path": "/root/autodl-tmp/pretrained-models/whisper/medium-encoder.pt", 22 | "return_dict": false, 23 | "output_hidden_states": false, 24 | "output_attentions": false, 25 | "torchscript": false, 26 | "torch_dtype": null, 27 | "use_bfloat16": false, 28 | "tf_legacy_loss": false, 29 | "pruned_heads": {}, 30 | "tie_word_embeddings": true, 31 | "is_encoder_decoder": true, 32 | "is_decoder": false, 33 | "cross_attention_hidden_size": null, 34 | "add_cross_attention": false, 35 | "tie_encoder_decoder": false, 36 | "max_length": 20, 37 | "min_length": 0, 38 | "do_sample": false, 39 | "early_stopping": false, 40 | "num_beams": 1, 41 | "num_beam_groups": 1, 42 | "diversity_penalty": 0.0, 43 | "temperature": 1.0, 44 | "top_k": 50, 45 | "top_p": 1.0, 46 | "typical_p": 1.0, 47 | "repetition_penalty": 1.0, 48 | "length_penalty": 1.0, 49 | "no_repeat_ngram_size": 0, 50 | "encoder_no_repeat_ngram_size": 0, 51 | "bad_words_ids": null, 52 | "num_return_sequences": 1, 53 | "chunk_size_feed_forward": 0, 54 | "output_scores": false, 55 | "return_dict_in_generate": false, 56 | "forced_bos_token_id": null, 57 | "forced_eos_token_id": 8192, 58 | "remove_invalid_values": false, 59 | "exponential_decay_length_penalty": null, 60 | "architectures": null, 61 | "finetuning_task": null, 62 | "id2label": { 63 | "0": "LABEL_0", 64 | "1": "LABEL_1", 65 | "2": "LABEL_2" 66 | }, 67 | "label2id": { 68 | "LABEL_0": 0, 69 | "LABEL_1": 1, 70 | "LABEL_2": 2 71 | }, 72 | "tokenizer_class": null, 73 | "prefix": null, 74 | "bos_token_id": 8194, 75 | "pad_token_id": 8193, 76 | "eos_token_id": 8192, 77 | "sep_token_id": null, 78 | "decoder_start_token_id": 2, 79 | "task_specific_params": null, 80 | "problem_type": null, 81 | "_name_or_path": "", 82 | "transformers_version": "4.22.2", 83 | "model_type": "VCLM" 84 | } -------------------------------------------------------------------------------- /configs/ar_model.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | #resume_from_checkpoint: 
"/root/autodl-tmp/models/ar/last1.ckpt" 4 | logger: true 5 | default_root_dir: "/root/autodl-tmp/models/ar" 6 | accelerator: gpu 7 | devices: 1 8 | strategy: ddp_find_unused_parameters_false 9 | accumulate_grad_batches: 3 10 | precision: 16 11 | val_check_interval: 1000 12 | max_steps: 300000 13 | gradient_clip_val: 0.5 14 | callbacks: 15 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 16 | init_args: 17 | save_top_k: 1 # save k best models (determined by above metric) 18 | monitor: 'val/loss' 19 | mode: 'min' 20 | save_last: True # additionaly always save model from last epoch 21 | verbose: True 22 | dirpath: "/root/autodl-tmp/models/ar" 23 | filename: "epoch_{epoch}_{step}" 24 | auto_insert_metric_name: False 25 | every_n_train_steps: 1000 26 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 27 | init_args: 28 | logging_interval: step 29 | 30 | model: 31 | class_path: vc_lm.models.ar_model_pl.ARModelPL 32 | init_args: 33 | config_file: configs/ar_model.json 34 | lr: 0.00002 35 | weight_decay: 0.01 36 | warmup_step: 150 37 | max_iters: 60000 38 | 39 | data: 40 | class_path: vc_lm.datamodules.ar_datamodule.ARDataModule 41 | init_args: 42 | data_dir: "/root/autodl-tmp/data/wds" 43 | batch_size: 10 44 | max_audio_time: 24 45 | num_workers: 2 46 | # 2262692 47 | train_dataset_size: 1262692 48 | train_pattern: "shard-{000100..000290}.tar" 49 | val_dataset_size: 800 50 | val_pattern: "shard-{000000..000009}.tar" 51 | 52 | -------------------------------------------------------------------------------- /configs/finetune_ar_model.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | #resume_from_checkpoint: "/root/autodl-tmp/models/jr-ar/last.ckpt" 4 | logger: true 5 | default_root_dir: "/root/autodl-tmp/models/jr-ar" 6 | accelerator: gpu 7 | devices: 1 8 | strategy: ddp_find_unused_parameters_false 9 | accumulate_grad_batches: 2 10 | precision: 16 11 | val_check_interval: 100 12 | max_steps: 300000 13 | gradient_clip_val: 0.5 14 | callbacks: 15 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 16 | init_args: 17 | save_top_k: 1 # save k best models (determined by above metric) 18 | monitor: 'val/loss' 19 | mode: 'min' 20 | save_last: True # additionaly always save model from last epoch 21 | verbose: True 22 | dirpath: "/root/autodl-tmp/models/jr-ar" 23 | filename: "epoch_{epoch}_{step}" 24 | auto_insert_metric_name: False 25 | every_n_train_steps: 100 26 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 27 | init_args: 28 | logging_interval: step 29 | 30 | model: 31 | class_path: vc_lm.models.ar_model_pl.ARModelPL 32 | init_args: 33 | config_file: configs/ar_model.json 34 | lr: 0.00002 35 | weight_decay: 0.01 36 | warmup_step: 100 37 | max_iters: 300000 38 | load_pretrain: True 39 | pretrain_model_path: /root/autodl-tmp/vc-models/ar-1024 40 | 41 | 42 | data: 43 | class_path: vc_lm.datamodules.ar_datamodule.ARDataModule 44 | init_args: 45 | data_dir: "/root/autodl-tmp/data/jr-wds-pair" 46 | batch_size: 4 47 | max_audio_time: 24 48 | num_workers: 1 49 | # 2262692 50 | train_dataset_size: 7714 51 | train_pattern: "shard-{000000..000001}.tar" 52 | val_dataset_size: 144 53 | val_pattern: "shard-{000000..000000}.tar" 54 | 55 | -------------------------------------------------------------------------------- /configs/finetune_nar_model.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | 
#resume_from_checkpoint: "/root/autodl-tmp/models/lyh-nar/last.ckpt" 4 | logger: true 5 | default_root_dir: "/root/autodl-tmp/models/jr-nar" 6 | accelerator: gpu 7 | devices: 1 8 | strategy: ddp_find_unused_parameters_false 9 | accumulate_grad_batches: 1 10 | precision: 16 11 | val_check_interval: 100 12 | max_steps: 300000 13 | gradient_clip_val: 0.5 14 | callbacks: 15 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 16 | init_args: 17 | save_top_k: 1 # save k best models (determined by above metric) 18 | monitor: 'val/loss' 19 | mode: 'min' 20 | save_last: True # additionaly always save model from last epoch 21 | verbose: True 22 | dirpath: "/root/autodl-tmp/models/jr-nar" 23 | filename: "epoch_{epoch}_{step}" 24 | auto_insert_metric_name: False 25 | every_n_train_steps: 100 26 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 27 | init_args: 28 | logging_interval: step 29 | 30 | model: 31 | class_path: vc_lm.models.nar_model_pl.NARModelPL 32 | init_args: 33 | config_file: configs/nar_model.json 34 | lr: 0.00002 35 | weight_decay: 0.01 36 | warmup_step: 100 37 | max_iters: 300000 38 | load_pretrain: True 39 | pretrain_model_path: /root/autodl-tmp/vc-models/nar-1024 40 | 41 | data: 42 | class_path: vc_lm.datamodules.nar_datamodule.NARDataModule 43 | init_args: 44 | data_dir: "/root/autodl-tmp/data/jr-wds-pair" 45 | batch_size: 4 46 | max_audio_time: 24 47 | num_workers: 1 48 | # 2262692 49 | train_dataset_size: 7714 50 | train_pattern: "shard-{000000..000001}.tar" 51 | val_dataset_size: 144 52 | val_pattern: "shard-{000000..000000}.tar" 53 | -------------------------------------------------------------------------------- /configs/nar_model.json: -------------------------------------------------------------------------------- 1 | { 2 | "content_layer_num": -1, 3 | "n_q": 8, 4 | "q_size": 1024, 5 | "vocab_size": 8195, 6 | "max_position_embeddings": 2048, 7 | "style_length": 225, 8 | "d_model": 1024, 9 | "decoder_ffn_dim": 4096, 10 | "decoder_layers": 10, 11 | "decoder_attention_heads": 16, 12 | "dropout": 0.1, 13 | "attention_dropout": 0.0, 14 | "activation_dropout": 0.0, 15 | "activation_function": "gelu", 16 | "init_std": 0.02, 17 | "decoder_layerdrop": 0.0, 18 | "classifier_dropout": 0.0, 19 | "use_cache": true, 20 | "scale_embedding": false, 21 | "encoder_model_path": "/root/autodl-tmp/pretrained-models/whisper/medium-encoder.pt", 22 | "return_dict": false, 23 | "output_hidden_states": false, 24 | "output_attentions": false, 25 | "torchscript": false, 26 | "torch_dtype": null, 27 | "use_bfloat16": false, 28 | "tf_legacy_loss": false, 29 | "pruned_heads": {}, 30 | "tie_word_embeddings": true, 31 | "is_encoder_decoder": true, 32 | "is_decoder": false, 33 | "cross_attention_hidden_size": null, 34 | "add_cross_attention": false, 35 | "tie_encoder_decoder": false, 36 | "max_length": 20, 37 | "min_length": 0, 38 | "do_sample": false, 39 | "early_stopping": false, 40 | "num_beams": 1, 41 | "num_beam_groups": 1, 42 | "diversity_penalty": 0.0, 43 | "temperature": 1.0, 44 | "top_k": 50, 45 | "top_p": 1.0, 46 | "typical_p": 1.0, 47 | "repetition_penalty": 1.0, 48 | "length_penalty": 1.0, 49 | "no_repeat_ngram_size": 0, 50 | "encoder_no_repeat_ngram_size": 0, 51 | "bad_words_ids": null, 52 | "num_return_sequences": 1, 53 | "chunk_size_feed_forward": 0, 54 | "output_scores": false, 55 | "return_dict_in_generate": false, 56 | "forced_bos_token_id": null, 57 | "forced_eos_token_id": 8192, 58 | "remove_invalid_values": false, 59 | "exponential_decay_length_penalty": null, 60 | 
"architectures": null, 61 | "finetuning_task": null, 62 | "id2label": { 63 | "0": "LABEL_0", 64 | "1": "LABEL_1", 65 | "2": "LABEL_2" 66 | }, 67 | "label2id": { 68 | "LABEL_0": 0, 69 | "LABEL_1": 1, 70 | "LABEL_2": 2 71 | }, 72 | "tokenizer_class": null, 73 | "prefix": null, 74 | "bos_token_id": 8194, 75 | "pad_token_id": 8193, 76 | "eos_token_id": 8192, 77 | "sep_token_id": null, 78 | "decoder_start_token_id": 2, 79 | "task_specific_params": null, 80 | "problem_type": null, 81 | "_name_or_path": "", 82 | "transformers_version": "4.22.2", 83 | "model_type": "VCLM" 84 | } -------------------------------------------------------------------------------- /configs/nar_model.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | #resume_from_checkpoint: "/root/autodl-tmp/models/nar/last1.ckpt" 4 | logger: true 5 | default_root_dir: "/root/autodl-tmp/models/nar" 6 | accelerator: gpu 7 | devices: 1 8 | strategy: ddp_find_unused_parameters_false 9 | accumulate_grad_batches: 3 10 | precision: 16 11 | val_check_interval: 1000 12 | max_steps: 300000 13 | gradient_clip_val: 0.5 14 | callbacks: 15 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 16 | init_args: 17 | save_top_k: 3 # save k best models (determined by above metric) 18 | monitor: 'val/loss' 19 | mode: 'min' 20 | save_last: True # additionaly always save model from last epoch 21 | verbose: True 22 | dirpath: "/root/autodl-tmp/models/nar" 23 | filename: "epoch_{epoch}_{step}" 24 | auto_insert_metric_name: False 25 | every_n_train_steps: 1000 26 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 27 | init_args: 28 | logging_interval: step 29 | 30 | model: 31 | class_path: vc_lm.models.nar_model_pl.NARModelPL 32 | init_args: 33 | config_file: configs/nar_model.json 34 | lr: 0.00002 35 | weight_decay: 0.01 36 | warmup_step: 150 37 | max_iters: 60000 38 | 39 | data: 40 | class_path: vc_lm.datamodules.nar_datamodule.NARDataModule 41 | init_args: 42 | data_dir: "/root/autodl-tmp/data/wds" 43 | batch_size: 8 44 | max_audio_time: 24 45 | num_workers: 2 46 | # 2262692 47 | train_dataset_size: 1262692 48 | train_pattern: "shard-{000100..000290}.tar" 49 | val_dataset_size: 800 50 | val_pattern: "shard-{000000..000009}.tar" 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning=1.6.5 2 | jsonargparse[signatures] 3 | transformers==4.26.0 4 | git+https://github.com/openai/whisper.git 5 | einops 6 | torchaudio -------------------------------------------------------------------------------- /res/o1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/o1.wav -------------------------------------------------------------------------------- /res/o2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/o2.wav -------------------------------------------------------------------------------- /res/o3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/o3.wav -------------------------------------------------------------------------------- /res/o4.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/o4.wav -------------------------------------------------------------------------------- /res/o5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/o5.wav -------------------------------------------------------------------------------- /res/test-in.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/test-in.wav -------------------------------------------------------------------------------- /res/vclm-ar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/vclm-ar.png -------------------------------------------------------------------------------- /res/vclm-nar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/res/vclm-nar.png -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pytorch_lightning.utilities.cli import LightningCLI 3 | from pytorch_lightning import LightningDataModule, LightningModule 4 | sys.path.insert(0, '.') 5 | 6 | def cli_main(): 7 | cli = LightningCLI(LightningModule, LightningDataModule, subclass_mode_model=True, subclass_mode_data=True) 8 | 9 | 10 | if __name__ == "__main__": 11 | cli_main() 12 | -------------------------------------------------------------------------------- /sh/run.sh: -------------------------------------------------------------------------------- 1 | python tools/construct_parallel_dataset.py 2 | 3 | nohup bash ./sh/train_finetune_ar_model.sh & 4 | nohup bash ./sh/train_finetune_nar_model.sh & -------------------------------------------------------------------------------- /sh/train_ar_model.sh: -------------------------------------------------------------------------------- 1 | python run.py fit --config configs/ar_model.yaml -------------------------------------------------------------------------------- /sh/train_finetune_ar_model.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | python run.py fit --config configs/finetune_ar_model.yaml -------------------------------------------------------------------------------- /sh/train_finetune_nar_model.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=1 2 | python run.py fit --config configs/finetune_nar_model.yaml -------------------------------------------------------------------------------- /sh/train_nar_model.sh: -------------------------------------------------------------------------------- 1 | python run.py fit --config configs/nar_model.yaml -------------------------------------------------------------------------------- /tests/test_ar_datamodule.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from vc_lm.datamodules.ar_datamodule import ARDataModule 4 | 5 | from tqdm.auto import tqdm 6 | 7 | 
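# Smoke test: build an ARDataModule over local webdataset shards (machine-specific path below)
# and pull a single batch from the train dataloader.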
class TestArDataModule(unittest.TestCase): 8 | def setUp(self) -> None: 9 | self.data_module = ARDataModule('/home/jiangxinghua/data/vc-lm/audios-dataset-small', 10 | batch_size=2, 11 | max_audio_time=24, num_workers=2, train_dataset_size=1000, train_pattern="shard-{000000..000012}.tar", 12 | val_dataset_size=300, val_pattern="shard-{000000..000009}.tar") 13 | 14 | def test_ar_datamodule(self): 15 | self.data_module.prepare_data() 16 | self.data_module.setup() 17 | assert self.data_module.train_dataloader() is not None and self.data_module.val_dataloader() is not None and self.data_module.test_dataloader() is not None 18 | item = next(iter(self.data_module.train_dataloader())) 19 | self.assertTrue(True) 20 | 21 | if __name__ == '__main__': 22 | unittest.main() -------------------------------------------------------------------------------- /tests/test_ar_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from vc_lm.datamodules.datasets.ar_dataset import ARDataset 4 | 5 | class TestARDataset(unittest.TestCase): 6 | def setUp(self) -> None: 7 | self.dataset = ARDataset('/home/jiangxinghua/data/vc-lm/wds/train') 8 | 9 | def test_ar_dataset(self): 10 | item = next(iter(self.dataset.get_dataset())) 11 | self.assertTrue(True) 12 | 13 | if __name__ == '__main__': 14 | unittest.main() -------------------------------------------------------------------------------- /tests/test_nar_datamodule.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from vc_lm.datamodules.nar_datamodule import NARDataModule 4 | 5 | 6 | class TestArDataModule(unittest.TestCase): 7 | def setUp(self) -> None: 8 | self.data_module = NARDataModule('/root/autodl-tmp/data/vc-lm-sample', 9 | batch_size=4) 10 | 11 | def test_ar_datamodule(self): 12 | self.data_module.prepare_data() 13 | self.data_module.setup() 14 | assert self.data_module.train_dataloader() and self.data_module.val_dataloader() and self.data_module.test_dataloader() 15 | item = next(iter(self.data_module.train_dataloader())) 16 | self.assertTrue(True) 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /tests/test_nar_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from vc_lm.datamodules.datasets.nar_dataset import NARDataset 4 | 5 | class TestARDataset(unittest.TestCase): 6 | def setUp(self) -> None: 7 | self.dataset = NARDataset('/root/autodl-tmp/data/vc-lm-sample/train') 8 | 9 | def test_nar_dataset(self): 10 | item = next(iter(self.dataset)) 11 | self.assertTrue(True) 12 | 13 | if __name__ == '__main__': 14 | unittest.main() 15 | -------------------------------------------------------------------------------- /tests/test_whisper_encoder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest 3 | import torch 4 | 5 | from vc_lm.models.encoders.whisper_encoder import WhisperEncoder 6 | from vc_lm.models.base import VCLMConfig 7 | 8 | 9 | class TestWhisperEncoder(unittest.TestCase): 10 | def setUp(self) -> None: 11 | config = VCLMConfig(**json.load(open("configs/ar_model.json"))) 12 | self.model = WhisperEncoder(config) 13 | self.model.cuda() 14 | 15 | def test_whisper_encoder(self): 16 | mels = torch.rand([2, 80, 3000]).cuda() 17 | with torch.inference_mode(): 18 | content_feats = self.model(mels) 19 | 
self.assertEqual(list(content_feats.shape), [2, 1500, 1024]) 20 | 21 | 22 | if __name__ == '__main__': 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /tools/construct_dataset.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import glob 3 | import math 4 | import numpy as np 5 | import os 6 | from typing import List, Any 7 | from tqdm.auto import tqdm 8 | from joblib.parallel import Parallel, delayed 9 | 10 | from encodec import EncodecModel 11 | from encodec.utils import convert_audio 12 | 13 | import torchaudio 14 | import torch 15 | 16 | from whisper.audio import log_mel_spectrogram 17 | 18 | import webdataset as wds 19 | 20 | 21 | def get_code_list(audio_list: List[str], 22 | gpu_id: int = 0): 23 | device = f'cuda:{gpu_id}' 24 | # Instantiate a pretrained EnCodec model 25 | model = EncodecModel.encodec_model_24khz() 26 | model.set_target_bandwidth(6.0) 27 | model = model.cuda(device) 28 | code_list = [] 29 | for audio in tqdm(audio_list, desc='calculate codes...'): 30 | code_list.append(get_code(audio, device, model)) 31 | return code_list 32 | 33 | def get_code(audio: str, 34 | device: str, 35 | model: Any): 36 | try: 37 | wav, sr = torchaudio.load(audio) 38 | wav = convert_audio(wav, sr, model.sample_rate, model.channels) 39 | wav = wav.unsqueeze(0) 40 | wav = wav.cuda(device) 41 | # Extract discrete codes from EnCodec 42 | with torch.no_grad(): 43 | encoded_frames = model.encode(wav) 44 | code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] 45 | return code.cpu().numpy().astype(np.int16)[0] 46 | except: 47 | print(f'{audio} code error...') 48 | return None 49 | 50 | def get_mel_spectrogram(audio): 51 | try: 52 | mel = log_mel_spectrogram(audio) 53 | return mel.numpy() 54 | except: 55 | print(f'{audio} mel error...') 56 | return None 57 | 58 | def process_audios(audios: List[str], 59 | num_workers: int): 60 | """ 61 | 处理audios文件 62 | Args: 63 | audios: List[str] 64 | Returns: 65 | records: List[Dict] 66 | """ 67 | records = [] 68 | # 计算mel 69 | mels = Parallel(n_jobs=num_workers)(delayed(get_mel_spectrogram)(audio) for audio in tqdm(audios, desc='calculate mels...')) 70 | # 计算code 71 | num_gpus = torch.cuda.device_count() 72 | per_gpu_samples = math.ceil(len(audios)/num_gpus) 73 | codes_list = Parallel(n_jobs=num_gpus)(delayed(get_code_list)(audios[i*per_gpu_samples:(i+1)*per_gpu_samples], i) \ 74 | for i in range(0, num_gpus)) 75 | codes = [] 76 | for codes_item in codes_list: 77 | codes.extend(codes_item) 78 | for mel, code in zip(mels, codes): 79 | records.append({ 80 | 'mel': mel, 81 | 'code': code 82 | }) 83 | return records 84 | 85 | 86 | def construct_dataset(input_dir, 87 | output_dir, 88 | partition_size=1000, 89 | num_workers=10): 90 | os.makedirs(output_dir, exist_ok=True) 91 | input_files = glob.glob(f"{input_dir}/**/*.wav", recursive=True) 92 | 93 | with wds.ShardWriter(os.path.join(output_dir, 'shard-%06d.tar'), 94 | maxcount=10000000, maxsize=1<<32) as sink: 95 | index = 0 96 | for partition_start in tqdm(range(0, len(input_files), partition_size)): 97 | audios = input_files[partition_start:partition_start+partition_size] 98 | records = process_audios(audios, num_workers=num_workers) 99 | for record in records: 100 | if record['mel'] is not None and record['code'] is not None: 101 | sink.write({ 102 | '__key__': "%011d" % index, 103 | 'data.pyd': record}) 104 | index += 1 105 | print(f'records number: {index}') 106 | 107 | 108 
| if __name__ == '__main__': 109 | fire.Fire(construct_dataset) 110 | -------------------------------------------------------------------------------- /tools/construct_parallel_dataset.py: -------------------------------------------------------------------------------- 1 | import webdataset as wds 2 | import fire 3 | import os 4 | from vc_lm.vc_engine import VCEngineDataFactory 5 | from tqdm.auto import tqdm 6 | import numpy as np 7 | 8 | def process_records(record_list, 9 | device_id=0, 10 | ar_model_path='/root/autodl-tmp/vc-models/ar-1024.ckpt', 11 | nar_model_path='/root/autodl-tmp/vc-models/nar-1024.ckpt', 12 | ar_config_file='configs/ar_model.json', 13 | nar_config_file='configs/nar_model.json'): 14 | import os 15 | os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id) 16 | device_id = 0 17 | engine = VCEngineDataFactory(ar_model_path, 18 | nar_model_path, 19 | ar_config_file, 20 | nar_config_file, 21 | device=f'cuda:{device_id}') 22 | output_records = [] 23 | for record in tqdm(record_list): 24 | mel1, code1, mel2, code2 = record['mel1'], record['code1'], record['mel2'], record['code2'] 25 | outputs_list = engine.process_multistep_audio(mel1, code1, mel2, code2) 26 | for outputs in outputs_list: 27 | output_mel = outputs['mel_alpha'] 28 | output_code = outputs['code'] 29 | mel_len = output_mel.shape[1] / 100 30 | code_len = output_code.shape[1] / 75 31 | if code_len > mel_len: 32 | output_code = output_code[:, 0:int(mel_len * 75)] 33 | else: 34 | output_mel = output_mel[:, 0:int(code_len * 100)] 35 | sub_record_idx = len(output_records) 36 | output_records.append( 37 | { 38 | '__key__': "%011d" % sub_record_idx, 39 | 'data.pyd': { 40 | 'mel': output_mel, 41 | 'code': output_code.detach().cpu().numpy().astype(np.int16) 42 | } 43 | } 44 | ) 45 | return output_records 46 | 47 | def construct_parallel_dataset(input_data_path: str ="/root/autodl-tmp/jr_dataset/shard-000000.tar", 48 | ref_data_path: str ="/root/autodl-tmp/shard-000000.tar", 49 | repeat_num: int = 3, 50 | num_devices: int = 1, 51 | output_dir: str = "/root/autodl-tmp/data/jr-wds-pair/train", 52 | ar_model_path = '/root/autodl-tmp/vc-models/ar-1024.ckpt', 53 | nar_model_path = '/root/autodl-tmp/vc-models/nar-1024.ckpt', 54 | ar_config_file = 'configs/ar_model.json', 55 | nar_config_file = 'configs/nar_model.json'): 56 | """ 57 | Args: 58 | input_data_path: str. The target person's voice audio file. 59 | ref_data_path: str. files consisting of a large number of different voices, used for prompts. 60 | repeat_num: int. The number of repetitions in constructing the dataset. 61 | output_dir: str. 
62 | """ 63 | os.makedirs(output_dir, exist_ok=True) 64 | dataset1 = wds.WebDataset(input_data_path) 65 | dataset1 = dataset1.decode() 66 | 67 | dataset2 = wds.WebDataset(ref_data_path) 68 | dataset2 = dataset2.decode() 69 | 70 | dataset2 = iter(dataset2) 71 | 72 | 73 | index = 0 74 | 75 | records = [] 76 | 77 | for i in range(repeat_num): 78 | for record_idx, item1 in tqdm(enumerate(dataset1)): 79 | # if record_idx >= 40/2: 80 | # continue 81 | item2 = next(dataset2) 82 | obj1, obj2 = item1['data.pyd'], item2['data.pyd'] 83 | mel1, code1, mel2, code2 = obj1['mel'], obj1['code'], obj2['mel'], obj2['code'] 84 | records.append({ 85 | 'index': index, 86 | 'mel1': mel1, 87 | 'code1': code1, 88 | 'mel2': mel2, 89 | 'code2': code2 90 | }) 91 | index += 1 92 | print(index) 93 | 94 | from joblib.parallel import Parallel, delayed 95 | 96 | n_jobs = num_devices 97 | segment_num = int(len(records) / n_jobs) 98 | 99 | result_list = Parallel(n_jobs=n_jobs)(delayed(process_records)(records[i * segment_num:(i + 1) * segment_num], 100 | i % num_devices, 101 | ar_model_path, 102 | nar_model_path, 103 | ar_config_file, 104 | nar_config_file) \ 105 | for i in range(n_jobs)) 106 | outputs = [] 107 | for item in result_list: 108 | outputs.extend(item) 109 | 110 | with wds.ShardWriter(os.path.join(output_dir, 'shard-%06d.tar'), 111 | maxcount=10000000, maxsize=1 << 32) as sink: 112 | for record in outputs: 113 | sink.write(record) 114 | 115 | if __name__ == '__main__': 116 | fire.Fire(construct_parallel_dataset) 117 | -------------------------------------------------------------------------------- /tools/construct_wavs_file.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import librosa 3 | import os 4 | import random 5 | import soundfile as sf 6 | 7 | from joblib.parallel import Parallel, delayed 8 | from tqdm.auto import tqdm 9 | 10 | min_time, max_time = 10, 24 11 | 12 | def process_audio(input_file, 13 | audio_id, output_dir): 14 | x, sr = librosa.load(input_file, 15 | sr=16000) 16 | min_segment_size, max_segment_size = min_time * sr, max_time * sr 17 | x_len = x.shape[0] 18 | if x_len < min_segment_size: 19 | return 0 20 | segments = [] 21 | pos = 0 22 | while pos < x_len: 23 | cur_segment_size = random.randint(min_segment_size, max_segment_size) 24 | if pos + cur_segment_size > x_len: 25 | cur_pos = min(pos, x_len - min_segment_size) 26 | segment = x[cur_pos:] 27 | segments.append(segment) 28 | break 29 | segment = x[pos:pos+cur_segment_size] 30 | segments.append(segment) 31 | pos += cur_segment_size 32 | for sub_id, segment in enumerate(segments): 33 | sf.write(os.path.join(output_dir, 34 | f"{audio_id}_{sub_id}.wav"), 35 | segment, sr, subtype='PCM_16') 36 | return sum([item.shape[0]/sr for item in segments]) 37 | 38 | files = [] 39 | 40 | for line in open('/home1/jiangxinghua/data/files.txt'): 41 | files.append(line.strip()) 42 | 43 | output_dir = '/home/jiangxinghua/data/audios' 44 | 45 | r = Parallel(n_jobs=64)(delayed(process_audio)(filename, idx, output_dir) for idx, filename in tqdm(enumerate(files))) 46 | 47 | print(f"total time: {sum(r)/3600} h") 48 | 49 | with open('outputs.txt', 'w') as fout: 50 | fout.write(f"total time: {sum(r)/3600} h") 51 | 52 | 53 | -------------------------------------------------------------------------------- /tools/extract_whisper_encoder_model.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import torch 3 | import whisper 4 | 5 | def 
extract_whisper_encoder_model(input_model='/root/autodl-tmp/cache/whisper/medium.pt', 6 | output_model=None): 7 | checkpoint = torch.load(input_model) 8 | dims = checkpoint['dims'] 9 | model = whisper.load_model(input_model) 10 | model_state_dict = model.encoder.state_dict() 11 | torch.save({ 12 | 'dims': dims, 13 | 'model_state_dict': model_state_dict 14 | }, output_model) 15 | 16 | 17 | if __name__ == '__main__': 18 | fire.Fire(extract_whisper_encoder_model) -------------------------------------------------------------------------------- /tools/save_ar_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fire 3 | 4 | def save_model(input_model, 5 | output_model): 6 | m = torch.load(input_model, map_location=torch.device('cpu')) 7 | del m['optimizer_states'] 8 | for k in list(m['state_dict'].keys()): 9 | v = m['state_dict'][k] 10 | del m['state_dict'][k] 11 | m['state_dict'][k.replace('linear3', 'linear1').replace('linear4', 'linear2')] = v 12 | torch.save(m, output_model) 13 | 14 | 15 | if __name__ == '__main__': 16 | fire.Fire(save_model) -------------------------------------------------------------------------------- /tools/save_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fire 3 | 4 | def save_model(input_model, 5 | output_model): 6 | m = torch.load(input_model, map_location=torch.device('cpu')) 7 | del m['optimizer_states'] 8 | torch.save(m, output_model) 9 | 10 | 11 | if __name__ == '__main__': 12 | fire.Fire(save_model) -------------------------------------------------------------------------------- /vc_lm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/__init__.py -------------------------------------------------------------------------------- /vc_lm/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/callbacks/__init__.py -------------------------------------------------------------------------------- /vc_lm/datamodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/datamodules/__init__.py -------------------------------------------------------------------------------- /vc_lm/datamodules/ar_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | import pytorch_lightning as pl 5 | import webdataset as wds 6 | from torch.utils.data import Dataset, DataLoader, default_collate 7 | from vc_lm.datamodules.datasets.ar_dataset import ARDataset 8 | 9 | from webdataset.utils import pytorch_worker_info 10 | 11 | def ar_collect_fn(x): 12 | y = default_collate(x) 13 | if '__key__' in y: 14 | del y['__key__'] 15 | return y 16 | 17 | class ARDataModule(pl.LightningDataModule): 18 | def __init__(self, 19 | data_dir: str, 20 | batch_size: int = 64, 21 | max_audio_time: float = 24, 22 | num_workers: int = 0, 23 | train_dataset_size: int = -1, 24 | val_dataset_size: int = 200, 25 | train_pattern: str = None, 26 | val_pattern: str = None, 27 | pin_memory: bool = False): 28 | super().__init__() 29 | self.data_dir = data_dir 30 | self.batch_size = 
batch_size 31 | self.max_audio_time = max_audio_time 32 | self.num_workers = num_workers 33 | self.pin_memory = pin_memory 34 | self.data_train: Optional[Dataset] = None 35 | self.data_val: Optional[Dataset] = None 36 | self.data_test: Optional[Dataset] = None 37 | self.train_dataset_size = train_dataset_size 38 | self.val_dataset_size = val_dataset_size 39 | self.train_pattern = train_pattern 40 | self.val_pattern = val_pattern 41 | 42 | def prepare_data(self) -> None: 43 | pass 44 | 45 | def setup(self, stage: Optional[str] = None): 46 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 47 | This method is called by lightning separately when using `trainer.fit()` and `trainer.test()`! 48 | The `stage` can be used to differentiate whether the `setup()` is called before trainer.fit()` or `trainer.test()`.""" 49 | if self.data_train is None or self.data_val is None or self.data_test is None: 50 | self.data_train = ARDataset(os.path.join(self.data_dir, 'train'), 51 | pattern=self.train_pattern, 52 | max_audio_time=self.max_audio_time, 53 | shuffle=True).get_dataset() 54 | self.data_val = ARDataset(os.path.join(self.data_dir, 'val'), 55 | pattern=self.val_pattern, 56 | max_audio_time=self.max_audio_time).get_dataset() 57 | self.data_test = ARDataset(os.path.join(self.data_dir, 'test'), 58 | pattern=self.val_pattern, 59 | max_audio_time=self.max_audio_time).get_dataset() 60 | 61 | def train_dataloader(self): 62 | return self.get_dataloader(self.data_train, self.train_dataset_size) 63 | 64 | def val_dataloader(self): 65 | return self.get_dataloader(self.data_val, self.val_dataset_size) 66 | 67 | def test_dataloader(self): 68 | return self.get_dataloader(self.data_test, self.val_dataset_size) 69 | 70 | def get_dataloader(self, dataset, dataset_size): 71 | # batch 72 | dataset = dataset.batched(self.batch_size, collation_fn=ar_collect_fn, partial=False) 73 | _, world_size, _, _ = pytorch_worker_info() 74 | number_of_batches = int(dataset_size // (world_size * self.batch_size)) 75 | loader = wds.WebLoader(dataset, 76 | batch_size=None, 77 | shuffle=False, num_workers=self.num_workers).with_length(number_of_batches).with_epoch(number_of_batches) 78 | loader = loader.repeat(2).slice(number_of_batches) 79 | return loader 80 | -------------------------------------------------------------------------------- /vc_lm/datamodules/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/datamodules/datasets/__init__.py -------------------------------------------------------------------------------- /vc_lm/datamodules/datasets/ar_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | import webdataset as wds 6 | 7 | from vc_lm.utils.data_utils import pad_or_trim 8 | 9 | class ARDataset(object): 10 | _MEL_SAMPLE_RATE = 100 11 | _CODE_SAMPLE_RATE = 75 12 | _NUM_Q = 8 13 | _EOS_ID = 1024 * 8 14 | _PAD_ID = 1024 * 8 + 1 15 | _MAX_MEL_AUDIO_TIME = 30 16 | 17 | def __init__(self, 18 | local, 19 | remote=None, 20 | pattern=None, 21 | max_audio_time=24, 22 | shuffle=False, 23 | shuffle_buffer=2000): 24 | self.local_path = local 25 | self.max_audio_time = max_audio_time 26 | self.shuffle = shuffle 27 | self.shuffle_buffer = shuffle_buffer 28 | self.max_mel_len = int(self._MEL_SAMPLE_RATE * self._MAX_MEL_AUDIO_TIME) 29 | self.max_code_len = 
int(self._CODE_SAMPLE_RATE * self.max_audio_time) 30 | self.max_content_len = math.ceil(self.max_mel_len/2) 31 | self.pattern = pattern 32 | 33 | def process_record(self, record): 34 | obj = record['data.pyd'] 35 | mel_len = obj['mel'].shape[1] 36 | # (max_content_len,) 37 | content_mask = torch.lt(torch.arange(0, self.max_content_len), math.ceil(mel_len//2)).type(torch.long) 38 | # (80, max_mel_len) 39 | mel = pad_or_trim(obj['mel'], self.max_mel_len) 40 | # (_NUM_Q, code_len) 41 | input_code = obj['code'].astype(np.int64) 42 | output_code = np.concatenate([input_code, np.ones([self._NUM_Q, 1], dtype=np.int64) * self._EOS_ID], axis=1) 43 | code_len = input_code.shape[1] 44 | # (max_code_len,) 45 | code_mask = torch.lt(torch.arange(0, self.max_code_len), 46 | code_len).type(torch.long) 47 | # pad input_code, output_code 48 | input_code = pad_or_trim(input_code[0], self.max_code_len, pad_value=self._PAD_ID) 49 | output_code = pad_or_trim(output_code[0][1:], self.max_code_len, 50 | pad_value=self._PAD_ID) 51 | return { 52 | 'mel': torch.tensor(mel), 53 | 'content_mask': content_mask, 54 | 'input_code': torch.tensor(input_code), 55 | 'output_code': torch.tensor(output_code), 56 | 'code_mask': code_mask 57 | } 58 | 59 | def get_dataset(self): 60 | dataset = wds.WebDataset(self.local_path + '/' + self.pattern, 61 | nodesplitter=wds.split_by_node) 62 | if self.shuffle: 63 | dataset = dataset.shuffle(self.shuffle_buffer) 64 | dataset = dataset.decode().map(self.process_record) 65 | return dataset -------------------------------------------------------------------------------- /vc_lm/datamodules/datasets/nar_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | import webdataset as wds 6 | 7 | from vc_lm.utils.data_utils import pad_or_trim 8 | 9 | class NARDataset(object): 10 | _MEL_SAMPLE_RATE = 100 11 | _CODE_SAMPLE_RATE = 75 12 | _NUM_Q = 8 13 | _EOS_ID = 1024 14 | _PAD_ID = 1024 * 8 + 1 15 | _MAX_MEL_AUDIO_TIME = 30 16 | def __init__(self, 17 | local, 18 | remote=None, 19 | pattern=None, 20 | max_audio_time=24, 21 | style_audio_time=3, 22 | shuffle=False, 23 | shuffle_buffer=2000): 24 | self.local_path = local 25 | self.max_audio_time = max_audio_time 26 | self.style_audio_time = 3 27 | self.max_mel_len = int(self._MEL_SAMPLE_RATE * self._MAX_MEL_AUDIO_TIME) 28 | self.max_code_len = int(self._CODE_SAMPLE_RATE * self.max_audio_time) 29 | self.max_content_len = math.ceil(self.max_mel_len/2) 30 | self.style_code_len = int(self._CODE_SAMPLE_RATE * style_audio_time) 31 | self.pattern = pattern 32 | self.shuffle = shuffle 33 | self.shuffle_buffer = shuffle_buffer 34 | 35 | def process_record(self, record): 36 | """ 37 | return: 38 | { 39 | "mel": (80, max_mel_len), 40 | "content_mask": (max_content_len,), 41 | "input_code": (8, code_len), 42 | "code_mask": (max_code_len,) 43 | } 44 | """ 45 | obj = record['data.pyd'] 46 | mel_len = obj['mel'].shape[1] 47 | # (max_content_len,) 48 | content_mask = torch.lt(torch.arange(0, self.max_content_len), math.ceil(mel_len//2)).type(torch.long) 49 | # (80, max_mel_len) 50 | mel = pad_or_trim(obj['mel'], self.max_mel_len) 51 | # (_NUM_Q, code_len) 52 | input_code = obj['code'].astype(np.int64) 53 | code_len = input_code.shape[1] 54 | # (max_code_len,) 55 | code_mask = torch.lt(torch.arange(0, self.max_code_len), 56 | code_len).type(torch.long) 57 | # style_code (8, style_code_len) 58 | style_start_pos = torch.randint(0, input_code.shape[1] - 
self.style_code_len, ()) 59 | style_code = input_code[:, style_start_pos:style_start_pos+self.style_code_len] 60 | 61 | # pad input_code, output_code 62 | input_code = pad_or_trim(input_code, self.max_code_len, 63 | pad_value=self._PAD_ID) 64 | return { 65 | 'mel': torch.tensor(mel), 66 | 'content_mask': content_mask, 67 | 'input_code': torch.tensor(input_code), 68 | 'code_mask': code_mask, 69 | 'style_code': torch.tensor(style_code) 70 | } 71 | 72 | def get_dataset(self): 73 | dataset = wds.WebDataset(self.local_path + '/' + self.pattern, 74 | nodesplitter=wds.split_by_node) 75 | if self.shuffle: 76 | dataset = dataset.shuffle(self.shuffle_buffer) 77 | dataset = dataset.decode().map(self.process_record) 78 | return dataset 79 | -------------------------------------------------------------------------------- /vc_lm/datamodules/nar_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import Optional 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader, default_collate 8 | from vc_lm.datamodules.datasets.nar_dataset import NARDataset 9 | 10 | import webdataset as wds 11 | from webdataset.utils import pytorch_worker_info 12 | 13 | def nar_collate_fn(x): 14 | nar_stage = random.randint(0, NARDataset._NUM_Q - 2) 15 | for idx, item in enumerate(x): 16 | item['nar_stage'] = torch.tensor(nar_stage) 17 | item['output_code'] = item['input_code'][item['nar_stage'] + 1] 18 | item['input_code'] = item['input_code'][0:item['nar_stage'] + 1] 19 | y = default_collate(x) 20 | if '__key__' in y: 21 | del y['__key__'] 22 | return y 23 | 24 | 25 | class NARDataModule(pl.LightningDataModule): 26 | def __init__(self, 27 | data_dir: str, 28 | batch_size: int = 64, 29 | max_audio_time: float = 24, 30 | style_audio_time: float = 3, 31 | num_workers: int = 0, 32 | train_dataset_size: int = -1, 33 | val_dataset_size: int = 200, 34 | train_pattern: str = None, 35 | val_pattern: str = None, 36 | pin_memory: bool = False): 37 | super().__init__() 38 | self.data_dir = data_dir 39 | self.batch_size = batch_size 40 | self.max_audio_time = max_audio_time 41 | self.style_audio_time = style_audio_time 42 | self.num_workers = num_workers 43 | self.pin_memory = pin_memory 44 | self.data_train: Optional[Dataset] = None 45 | self.data_val: Optional[Dataset] = None 46 | self.data_test: Optional[Dataset] = None 47 | self.train_dataset_size = train_dataset_size 48 | self.val_dataset_size = val_dataset_size 49 | self.train_pattern = train_pattern 50 | self.val_pattern = val_pattern 51 | 52 | def prepare_data(self) -> None: 53 | pass 54 | 55 | def setup(self, stage: Optional[str] = None): 56 | """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. 57 | This method is called by lightning separately when using `trainer.fit()` and `trainer.test()`! 
58 | The `stage` can be used to differentiate whether the `setup()` is called before trainer.fit()` or `trainer.test()`.""" 59 | if not self.data_train or not self.data_val or not self.data_test: 60 | self.data_train = NARDataset(os.path.join(self.data_dir, 'train'), 61 | pattern=self.train_pattern, 62 | max_audio_time=self.max_audio_time, 63 | style_audio_time=self.style_audio_time, 64 | shuffle=True).get_dataset() 65 | self.data_val = NARDataset(os.path.join(self.data_dir, 'val'), 66 | pattern=self.val_pattern, 67 | max_audio_time=self.max_audio_time, 68 | style_audio_time=self.style_audio_time).get_dataset() 69 | self.data_test = NARDataset(os.path.join(self.data_dir, 'test'), 70 | pattern=self.val_pattern, 71 | max_audio_time=self.max_audio_time, 72 | style_audio_time=self.style_audio_time).get_dataset() 73 | 74 | def train_dataloader(self): 75 | return self.get_dataloader(self.data_train, self.train_dataset_size) 76 | 77 | def val_dataloader(self): 78 | return self.get_dataloader(self.data_val, self.val_dataset_size) 79 | 80 | def test_dataloader(self): 81 | return self.get_dataloader(self.data_test, self.val_dataset_size) 82 | 83 | def get_dataloader(self, dataset, dataset_size): 84 | # batch 85 | dataset = dataset.batched(self.batch_size, 86 | collation_fn=nar_collate_fn, partial=False) 87 | _, world_size, _, _ = pytorch_worker_info() 88 | number_of_batches = int(dataset_size // (world_size * self.batch_size)) 89 | loader = wds.WebLoader(dataset, 90 | batch_size=None, 91 | shuffle=False, num_workers=self.num_workers).with_length(number_of_batches).with_epoch(number_of_batches) 92 | loader = loader.repeat(2).slice(number_of_batches) 93 | return loader 94 | -------------------------------------------------------------------------------- /vc_lm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/models/__init__.py -------------------------------------------------------------------------------- /vc_lm/models/ar_model_pl.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import json 4 | 5 | from typing import Any, List 6 | from torch import nn 7 | from torchmetrics.classification.accuracy import Accuracy 8 | from torch.optim import AdamW 9 | 10 | from vc_lm.models.base import VCLMConfig 11 | from vc_lm.models.models.ar_model import ARModel, ARModelForConditionalGeneration 12 | from vc_lm.models.misc import CosineWarmupScheduler 13 | 14 | from transformers.optimization import get_polynomial_decay_schedule_with_warmup 15 | 16 | 17 | class ARModelPL(pl.LightningModule): 18 | def __init__(self, config_file: str, 19 | lr: float = 0.001, 20 | weight_decay: float = 0.0005, 21 | warmup_step: int = 10000, 22 | max_iters: int = 800000, 23 | load_pretrain: bool = False, 24 | pretrain_model_path: str = None): 25 | super().__init__() 26 | self.save_hyperparameters() 27 | with open(config_file) as f: 28 | config = json.load(f) 29 | config = VCLMConfig(**config) 30 | self.model = ARModelForConditionalGeneration(config) 31 | # 加载whisper模型参数. 
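# (i.e., load the pretrained Whisper encoder parameters before training.)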
32 | self.model.model.encoder.load_pretrained_whisper_params() 33 | 34 | self.loss_fct = nn.CrossEntropyLoss() 35 | 36 | if load_pretrain: 37 | loaded_state = torch.load(pretrain_model_path)['state_dict'] 38 | self.load_state_dict(loaded_state, 39 | strict=False) 40 | 41 | self.train_accuracy = Accuracy(task="multiclass", 42 | num_classes=self.model.model.shared.num_embeddings, 43 | average='micro', 44 | ignore_index=-100) 45 | self.val_accuracy = Accuracy(task="multiclass", 46 | num_classes=self.model.model.shared.num_embeddings, 47 | average='micro', 48 | ignore_index=-100) 49 | self.test_accuracy = Accuracy(task="multiclass", 50 | num_classes=self.model.model.shared.num_embeddings, 51 | average='micro', 52 | ignore_index=-100) 53 | 54 | def load_bart_decoder_params(self): 55 | decoder = self.model.model.decoder 56 | bart_state_dict = torch.load('/root/autodl-tmp/pretrained-models/bart-large/pytorch_model.bin') 57 | filtered_state_dict = {} 58 | for k, v in bart_state_dict.items(): 59 | if 'decoder.layers' in k: 60 | filtered_state_dict[".".join(k.split('.')[1:])] = v 61 | decoder.load_state_dict(filtered_state_dict, strict=False) 62 | 63 | def forward(self, 64 | input_mels=None, 65 | attention_mask=None, 66 | decoder_input_ids=None, 67 | decoder_attention_mask=None): 68 | outputs = self.model(input_ids=input_mels, 69 | attention_mask=attention_mask, 70 | decoder_input_ids=decoder_input_ids, 71 | decoder_attention_mask=decoder_attention_mask) 72 | return outputs[0] 73 | 74 | def step(self, batch: Any): 75 | mel = batch['mel'] 76 | content_mask = batch['content_mask'] 77 | input_code = batch['input_code'] 78 | output_code = batch['output_code'] 79 | output_code[output_code == self.model.config.pad_token_id] = -100 80 | code_mask = batch['code_mask'] 81 | lm_logits = self.forward(input_mels=mel, 82 | attention_mask=content_mask, 83 | decoder_input_ids=input_code, 84 | decoder_attention_mask=code_mask) 85 | lm_loss = self.loss_fct(lm_logits.view(-1, self.model.config.vocab_size), output_code.view(-1)) 86 | preds = torch.argmax(lm_logits, dim=-1) 87 | return lm_loss, preds, output_code 88 | 89 | def training_step(self, 90 | batch: Any, 91 | batch_idx: int): 92 | loss, preds, targets = self.step(batch) 93 | acc = self.train_accuracy(preds, targets) 94 | self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=False) 95 | self.log('train/acc', acc, on_step=True, on_epoch=True, prog_bar=True) 96 | # return { 97 | # 'loss': loss, 98 | # 'preds': preds, 99 | # 'targets': targets 100 | # } 101 | return {'loss': loss} 102 | 103 | def training_epoch_end(self, outputs: List[Any]): 104 | pass 105 | 106 | def validation_step(self, batch: Any, batch_idx: int): 107 | loss, preds, targets = self.step(batch) 108 | acc = self.val_accuracy(preds, targets) 109 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 110 | self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True) 111 | # return {"loss": loss, "preds": preds, "targets": targets} 112 | return {'loss': loss} 113 | 114 | def validation_epoch_end(self, outputs: List[Any]): 115 | pass 116 | 117 | def test_step(self, batch: Any, batch_idx: int): 118 | loss, preds, targets = self.step(batch) 119 | # log test metrics 120 | acc = self.test_accuracy(preds, targets) 121 | self.log("test/loss", loss, on_step=False, on_epoch=True) 122 | self.log("test/acc", acc, on_step=False, on_epoch=True) 123 | # return {"loss": loss, "preds": preds, "targets": targets} 124 | return {"loss": loss} 125 | 126 | def 
test_epoch_end(self, outputs: List[Any]): 127 | pass 128 | 129 | def configure_optimizers(self) -> Any: 130 | no_decay = ["bias", "LayerNorm.weight"] 131 | optimizer_grouped_parameters = [ 132 | { 133 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 134 | "weight_decay": self.hparams.weight_decay, 135 | }, 136 | { 137 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 138 | "weight_decay": 0.0, 139 | }, 140 | ] 141 | optimizer = AdamW(optimizer_grouped_parameters, 142 | lr=self.hparams.lr, 143 | weight_decay=self.hparams.weight_decay) 144 | # scheduler = CosineWarmupScheduler(optimizer=optimizer, 145 | # warmup=self.hparams.warmup_step, 146 | # max_iters=self.hparams.max_iters) 147 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer=optimizer, 148 | num_warmup_steps=self.hparams.warmup_step, 149 | num_training_steps=self.hparams.max_iters) 150 | 151 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 152 | -------------------------------------------------------------------------------- /vc_lm/models/bart/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/models/bart/__init__.py -------------------------------------------------------------------------------- /vc_lm/models/bart/configuration_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ BART model configuration""" 16 | import warnings 17 | from collections import OrderedDict 18 | from typing import Any, Mapping, Optional 19 | 20 | from transformers import PreTrainedTokenizer 21 | from transformers.configuration_utils import PretrainedConfig 22 | from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast 23 | from transformers.onnx.utils import compute_effective_axis_dimension 24 | from transformers.utils import TensorType, is_torch_available, logging 25 | 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | BART_PRETRAINED_CONFIG_ARCHIVE_MAP = { 30 | "facebook/bart-large": "https://huggingface.co/facebook/bart-large/resolve/main/config.json", 31 | # See all BART models at https://huggingface.co/models?filter=bart 32 | } 33 | 34 | 35 | class BartConfig(PretrainedConfig): 36 | r""" 37 | This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART 38 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 39 | defaults will yield a similar configuration to that of the BART 40 | [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture. 
41 | 42 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 43 | documentation from [`PretrainedConfig`] for more information. 44 | 45 | 46 | Args: 47 | vocab_size (`int`, *optional*, defaults to 50265): 48 | Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the 49 | `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`]. 50 | d_model (`int`, *optional*, defaults to 1024): 51 | Dimensionality of the layers and the pooler layer. 52 | encoder_layers (`int`, *optional*, defaults to 12): 53 | Number of encoder layers. 54 | decoder_layers (`int`, *optional*, defaults to 12): 55 | Number of decoder layers. 56 | encoder_attention_heads (`int`, *optional*, defaults to 16): 57 | Number of attention heads for each attention layer in the Transformer encoder. 58 | decoder_attention_heads (`int`, *optional*, defaults to 16): 59 | Number of attention heads for each attention layer in the Transformer decoder. 60 | decoder_ffn_dim (`int`, *optional*, defaults to 4096): 61 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 62 | encoder_ffn_dim (`int`, *optional*, defaults to 4096): 63 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 64 | activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): 65 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 66 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 67 | dropout (`float`, *optional*, defaults to 0.1): 68 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 69 | attention_dropout (`float`, *optional*, defaults to 0.0): 70 | The dropout ratio for the attention probabilities. 71 | activation_dropout (`float`, *optional*, defaults to 0.0): 72 | The dropout ratio for activations inside the fully connected layer. 73 | classifier_dropout (`float`, *optional*, defaults to 0.0): 74 | The dropout ratio for classifier. 75 | max_position_embeddings (`int`, *optional*, defaults to 1024): 76 | The maximum sequence length that this model might ever be used with. Typically set this to something large 77 | just in case (e.g., 512 or 1024 or 2048). 78 | init_std (`float`, *optional*, defaults to 0.02): 79 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 80 | encoder_layerdrop (`float`, *optional*, defaults to 0.0): 81 | The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 82 | for more details. 83 | decoder_layerdrop (`float`, *optional*, defaults to 0.0): 84 | The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 85 | for more details. 86 | scale_embedding (`bool`, *optional*, defaults to `False`): 87 | Scale embeddings by diving by sqrt(d_model). 88 | use_cache (`bool`, *optional*, defaults to `True`): 89 | Whether or not the model should return the last key/values attentions (not used by all models). 90 | num_labels (`int`, *optional*, defaults to 3): 91 | The number of labels to use in [`BartForSequenceClassification`]. 92 | forced_eos_token_id (`int`, *optional*, defaults to 2): 93 | The id of the token to force as the last generated token when `max_length` is reached. Usually set to 94 | `eos_token_id`. 
95 | 96 | Example: 97 | 98 | ```python 99 | >>> from transformers import BartConfig, BartModel 100 | 101 | >>> # Initializing a BART facebook/bart-large style configuration 102 | >>> configuration = BartConfig() 103 | 104 | >>> # Initializing a model (with random weights) from the facebook/bart-large style configuration 105 | >>> model = BartModel(configuration) 106 | 107 | >>> # Accessing the model configuration 108 | >>> configuration = model.config 109 | ```""" 110 | model_type = "bart" 111 | keys_to_ignore_at_inference = ["past_key_values"] 112 | attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} 113 | 114 | def __init__( 115 | self, 116 | vocab_size=50265, 117 | max_position_embeddings=1024, 118 | encoder_layers=12, 119 | encoder_ffn_dim=4096, 120 | encoder_attention_heads=16, 121 | decoder_layers=12, 122 | decoder_ffn_dim=4096, 123 | decoder_attention_heads=16, 124 | encoder_layerdrop=0.0, 125 | decoder_layerdrop=0.0, 126 | activation_function="gelu", 127 | d_model=1024, 128 | dropout=0.1, 129 | attention_dropout=0.0, 130 | activation_dropout=0.0, 131 | init_std=0.02, 132 | classifier_dropout=0.0, 133 | scale_embedding=False, 134 | use_cache=True, 135 | num_labels=3, 136 | pad_token_id=1, 137 | bos_token_id=0, 138 | eos_token_id=2, 139 | is_encoder_decoder=True, 140 | decoder_start_token_id=2, 141 | forced_eos_token_id=2, 142 | **kwargs, 143 | ): 144 | self.vocab_size = vocab_size 145 | self.max_position_embeddings = max_position_embeddings 146 | self.d_model = d_model 147 | self.encoder_ffn_dim = encoder_ffn_dim 148 | self.encoder_layers = encoder_layers 149 | self.encoder_attention_heads = encoder_attention_heads 150 | self.decoder_ffn_dim = decoder_ffn_dim 151 | self.decoder_layers = decoder_layers 152 | self.decoder_attention_heads = decoder_attention_heads 153 | self.dropout = dropout 154 | self.attention_dropout = attention_dropout 155 | self.activation_dropout = activation_dropout 156 | self.activation_function = activation_function 157 | self.init_std = init_std 158 | self.encoder_layerdrop = encoder_layerdrop 159 | self.decoder_layerdrop = decoder_layerdrop 160 | self.classifier_dropout = classifier_dropout 161 | self.use_cache = use_cache 162 | self.num_hidden_layers = encoder_layers 163 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True 164 | 165 | super().__init__( 166 | num_labels=num_labels, 167 | pad_token_id=pad_token_id, 168 | bos_token_id=bos_token_id, 169 | eos_token_id=eos_token_id, 170 | is_encoder_decoder=is_encoder_decoder, 171 | decoder_start_token_id=decoder_start_token_id, 172 | forced_eos_token_id=forced_eos_token_id, 173 | **kwargs, 174 | ) 175 | 176 | # ensure backward compatibility for BART CNN models 177 | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): 178 | self.forced_bos_token_id = self.bos_token_id 179 | warnings.warn( 180 | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " 181 | "The config can simply be saved and uploaded again to be fixed." 
182 | ) 183 | 184 | 185 | class BartOnnxConfig(OnnxSeq2SeqConfigWithPast): 186 | @property 187 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 188 | if self.task in ["default", "seq2seq-lm"]: 189 | common_inputs = OrderedDict( 190 | [ 191 | ("input_ids", {0: "batch", 1: "encoder_sequence"}), 192 | ("attention_mask", {0: "batch", 1: "encoder_sequence"}), 193 | ] 194 | ) 195 | 196 | if self.use_past: 197 | common_inputs["decoder_input_ids"] = {0: "batch"} 198 | common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} 199 | else: 200 | common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} 201 | common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} 202 | 203 | if self.use_past: 204 | self.fill_with_past_key_values_(common_inputs, direction="inputs") 205 | elif self.task == "causal-lm": 206 | # TODO: figure this case out. 207 | common_inputs = OrderedDict( 208 | [ 209 | ("input_ids", {0: "batch", 1: "encoder_sequence"}), 210 | ("attention_mask", {0: "batch", 1: "encoder_sequence"}), 211 | ] 212 | ) 213 | if self.use_past: 214 | num_encoder_layers, _ = self.num_layers 215 | for i in range(num_encoder_layers): 216 | common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} 217 | common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} 218 | else: 219 | common_inputs = OrderedDict( 220 | [ 221 | ("input_ids", {0: "batch", 1: "encoder_sequence"}), 222 | ("attention_mask", {0: "batch", 1: "encoder_sequence"}), 223 | ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), 224 | ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), 225 | ] 226 | ) 227 | 228 | return common_inputs 229 | 230 | @property 231 | def outputs(self) -> Mapping[str, Mapping[int, str]]: 232 | if self.task in ["default", "seq2seq-lm"]: 233 | common_outputs = super().outputs 234 | else: 235 | common_outputs = super(OnnxConfigWithPast, self).outputs 236 | if self.use_past: 237 | num_encoder_layers, _ = self.num_layers 238 | for i in range(num_encoder_layers): 239 | common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} 240 | common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} 241 | return common_outputs 242 | 243 | def _generate_dummy_inputs_for_default_and_seq2seq_lm( 244 | self, 245 | tokenizer: PreTrainedTokenizer, 246 | batch_size: int = -1, 247 | seq_length: int = -1, 248 | is_pair: bool = False, 249 | framework: Optional[TensorType] = None, 250 | ) -> Mapping[str, Any]: 251 | encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( 252 | tokenizer, batch_size, seq_length, is_pair, framework 253 | ) 254 | 255 | # Generate decoder inputs 256 | decoder_seq_length = seq_length if not self.use_past else 1 257 | decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( 258 | tokenizer, batch_size, decoder_seq_length, is_pair, framework 259 | ) 260 | decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} 261 | common_inputs = dict(**encoder_inputs, **decoder_inputs) 262 | 263 | if self.use_past: 264 | if not is_torch_available(): 265 | raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") 266 | else: 267 | import torch 268 | batch, encoder_seq_length = common_inputs["input_ids"].shape 269 | decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] 270 | 
num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads 271 | encoder_shape = ( 272 | batch, 273 | num_encoder_attention_heads, 274 | encoder_seq_length, 275 | self._config.hidden_size // num_encoder_attention_heads, 276 | ) 277 | decoder_past_length = decoder_seq_length + 3 278 | decoder_shape = ( 279 | batch, 280 | num_decoder_attention_heads, 281 | decoder_past_length, 282 | self._config.hidden_size // num_decoder_attention_heads, 283 | ) 284 | 285 | common_inputs["decoder_attention_mask"] = torch.cat( 286 | [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 287 | ) 288 | 289 | common_inputs["past_key_values"] = [] 290 | # If the number of encoder and decoder layers are present in the model configuration, both are considered 291 | num_encoder_layers, num_decoder_layers = self.num_layers 292 | min_num_layers = min(num_encoder_layers, num_decoder_layers) 293 | max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers 294 | remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" 295 | 296 | for _ in range(min_num_layers): 297 | common_inputs["past_key_values"].append( 298 | ( 299 | torch.zeros(decoder_shape), 300 | torch.zeros(decoder_shape), 301 | torch.zeros(encoder_shape), 302 | torch.zeros(encoder_shape), 303 | ) 304 | ) 305 | # TODO: test this. 306 | shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape 307 | for _ in range(min_num_layers, max_num_layers): 308 | common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) 309 | return common_inputs 310 | 311 | def _generate_dummy_inputs_for_causal_lm( 312 | self, 313 | tokenizer: PreTrainedTokenizer, 314 | batch_size: int = -1, 315 | seq_length: int = -1, 316 | is_pair: bool = False, 317 | framework: Optional[TensorType] = None, 318 | ) -> Mapping[str, Any]: 319 | common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( 320 | tokenizer, batch_size, seq_length, is_pair, framework 321 | ) 322 | 323 | if self.use_past: 324 | if not is_torch_available(): 325 | raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") 326 | else: 327 | import torch 328 | batch, seqlen = common_inputs["input_ids"].shape 329 | # Not using the same length for past_key_values 330 | past_key_values_length = seqlen + 2 331 | num_encoder_layers, _ = self.num_layers 332 | num_encoder_attention_heads, _ = self.num_attention_heads 333 | past_shape = ( 334 | batch, 335 | num_encoder_attention_heads, 336 | past_key_values_length, 337 | self._config.hidden_size // num_encoder_attention_heads, 338 | ) 339 | 340 | mask_dtype = common_inputs["attention_mask"].dtype 341 | common_inputs["attention_mask"] = torch.cat( 342 | [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 343 | ) 344 | common_inputs["past_key_values"] = [ 345 | (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) 346 | ] 347 | return common_inputs 348 | 349 | def _generate_dummy_inputs_for_sequence_classification_and_question_answering( 350 | self, 351 | tokenizer: PreTrainedTokenizer, 352 | batch_size: int = -1, 353 | seq_length: int = -1, 354 | is_pair: bool = False, 355 | framework: Optional[TensorType] = None, 356 | ) -> Mapping[str, Any]: 357 | # Copied from OnnxConfig.generate_dummy_inputs 358 | # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. 
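# This helper only builds plain tokenizer inputs (input_ids / attention_mask);
# the seq2seq and causal-lm generators defined above call it first and then add
# decoder inputs and dummy past_key_values on top of its output, while
# generate_dummy_inputs below dispatches between the three variants via self.task.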
359 | # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX 360 | batch_size = compute_effective_axis_dimension( 361 | batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0 362 | ) 363 | 364 | # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX 365 | token_to_add = tokenizer.num_special_tokens_to_add(is_pair) 366 | seq_length = compute_effective_axis_dimension( 367 | seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add 368 | ) 369 | 370 | # Generate dummy inputs according to compute batch and sequence 371 | dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size 372 | common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) 373 | return common_inputs 374 | 375 | def generate_dummy_inputs( 376 | self, 377 | tokenizer: PreTrainedTokenizer, 378 | batch_size: int = -1, 379 | seq_length: int = -1, 380 | is_pair: bool = False, 381 | framework: Optional[TensorType] = None, 382 | ) -> Mapping[str, Any]: 383 | if self.task in ["default", "seq2seq-lm"]: 384 | common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( 385 | tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework 386 | ) 387 | 388 | elif self.task == "causal-lm": 389 | common_inputs = self._generate_dummy_inputs_for_causal_lm( 390 | tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework 391 | ) 392 | else: 393 | common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( 394 | tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework 395 | ) 396 | 397 | return common_inputs 398 | 399 | def _flatten_past_key_values_(self, flattened_output, name, idx, t): 400 | if self.task in ["default", "seq2seq-lm"]: 401 | flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) 402 | else: 403 | flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( 404 | flattened_output, name, idx, t 405 | ) 406 | -------------------------------------------------------------------------------- /vc_lm/models/base.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | from torch import nn 4 | from transformers.configuration_utils import PretrainedConfig 5 | from transformers.modeling_utils import PreTrainedModel 6 | 7 | class VCLMConfig(PretrainedConfig): 8 | r""" 9 | This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART 10 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 11 | defaults will yield a similar configuration to that of the BART 12 | [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture. 13 | 14 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 15 | documentation from [`PretrainedConfig`] for more information. 16 | 17 | 18 | Args: 19 | vocab_size (`int`, *optional*, defaults to 50265): 20 | Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the 21 | `inputs_ids` passed when calling [`BartModel`] or [`TFBartModel`]. 
22 | d_model (`int`, *optional*, defaults to 1024): 23 | Dimensionality of the layers and the pooler layer. 24 | encoder_layers (`int`, *optional*, defaults to 12): 25 | Number of encoder layers. 26 | decoder_layers (`int`, *optional*, defaults to 12): 27 | Number of decoder layers. 28 | encoder_attention_heads (`int`, *optional*, defaults to 16): 29 | Number of attention heads for each attention layer in the Transformer encoder. 30 | decoder_attention_heads (`int`, *optional*, defaults to 16): 31 | Number of attention heads for each attention layer in the Transformer decoder. 32 | decoder_ffn_dim (`int`, *optional*, defaults to 4096): 33 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 34 | encoder_ffn_dim (`int`, *optional*, defaults to 4096): 35 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 36 | activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): 37 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 38 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 39 | dropout (`float`, *optional*, defaults to 0.1): 40 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 41 | attention_dropout (`float`, *optional*, defaults to 0.0): 42 | The dropout ratio for the attention probabilities. 43 | activation_dropout (`float`, *optional*, defaults to 0.0): 44 | The dropout ratio for activations inside the fully connected layer. 45 | classifier_dropout (`float`, *optional*, defaults to 0.0): 46 | The dropout ratio for classifier. 47 | max_position_embeddings (`int`, *optional*, defaults to 1024): 48 | The maximum sequence length that this model might ever be used with. Typically set this to something large 49 | just in case (e.g., 512 or 1024 or 2048). 50 | init_std (`float`, *optional*, defaults to 0.02): 51 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 52 | encoder_layerdrop (`float`, *optional*, defaults to 0.0): 53 | The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 54 | for more details. 55 | decoder_layerdrop (`float`, *optional*, defaults to 0.0): 56 | The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 57 | for more details. 58 | scale_embedding (`bool`, *optional*, defaults to `False`): 59 | Scale embeddings by diving by sqrt(d_model). 60 | use_cache (`bool`, *optional*, defaults to `True`): 61 | Whether or not the model should return the last key/values attentions (not used by all models). 62 | num_labels (`int`, *optional*, defaults to 3): 63 | The number of labels to use in [`BartForSequenceClassification`]. 64 | forced_eos_token_id (`int`, *optional*, defaults to 2): 65 | The id of the token to force as the last generated token when `max_length` is reached. Usually set to 66 | `eos_token_id`. 
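In vc-lm itself the configuration is usually not assembled from these keyword defaults but loaded from a JSON file and splatted into the constructor, as `ARModelPL.__init__` does. A minimal sketch of that pattern (the config path is illustrative):

```python
import json

from vc_lm.models.base import VCLMConfig

# Illustrative path -- any of the project's model-config JSON files would do here.
with open("configs/ar_model.json") as f:
    config = VCLMConfig(**json.load(f))

# With the defaults above: 8 quantizer levels of 1024 codes each plus the
# eos/pad/bos specials, i.e. vocab_size = 8 * 1024 + 3 = 8195.
print(config.n_q, config.q_size, config.vocab_size)
```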
67 | 68 | Example: 69 | 70 | ```python 71 | >>> from transformers import BartConfig, BartModel 72 | 73 | >>> # Initializing a BART facebook/bart-large style configuration 74 | >>> configuration = BartConfig() 75 | 76 | >>> # Initializing a model (with random weights) from the facebook/bart-large style configuration 77 | >>> model = BartModel(configuration) 78 | 79 | >>> # Accessing the model configuration 80 | >>> configuration = model.config 81 | ```""" 82 | model_type = "VCLM" 83 | keys_to_ignore_at_inference = ["past_key_values"] 84 | attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} 85 | 86 | def __init__( 87 | self, 88 | vocab_size=1024 * 8 + 3, 89 | max_position_embeddings=2048, 90 | style_length=225, 91 | encoder_ffn_dim=4096, 92 | encoder_attention_heads=16, 93 | encoder_layerdrop=0.0, 94 | decoder_layers=12, 95 | decoder_ffn_dim=4096, 96 | decoder_attention_heads=16, 97 | decoder_layerdrop=0.0, 98 | activation_function="gelu", 99 | d_model=1024, 100 | dropout=0.1, 101 | attention_dropout=0.0, 102 | activation_dropout=0.0, 103 | init_std=0.02, 104 | classifier_dropout=0.0, 105 | scale_embedding=False, 106 | use_cache=True, 107 | num_labels=3, 108 | pad_token_id=1024 * 8 + 1, 109 | eos_token_id=1024 * 8, 110 | bos_token_id=1024 * 8 + 2, 111 | is_encoder_decoder=True, 112 | decoder_start_token_id=2, 113 | forced_eos_token_id=1024 * 8, 114 | encoder_model_path="/whisper-medium-encoder.pt", 115 | n_q=8, 116 | q_size=1024, 117 | content_layer_num=4, 118 | **kwargs, 119 | ): 120 | self.vocab_size = vocab_size 121 | self.max_position_embeddings = max_position_embeddings 122 | self.style_length = style_length 123 | self.encoder_ffn_dim = encoder_ffn_dim 124 | self.encoder_attention_heads = encoder_attention_heads 125 | self.encoder_layerdrop = encoder_layerdrop 126 | self.d_model = d_model 127 | self.decoder_ffn_dim = decoder_ffn_dim 128 | self.decoder_layers = decoder_layers 129 | self.decoder_attention_heads = decoder_attention_heads 130 | self.dropout = dropout 131 | self.attention_dropout = attention_dropout 132 | self.activation_dropout = activation_dropout 133 | self.activation_function = activation_function 134 | self.init_std = init_std 135 | self.decoder_layerdrop = decoder_layerdrop 136 | self.classifier_dropout = classifier_dropout 137 | self.use_cache = use_cache 138 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True 139 | self.encoder_model_path = encoder_model_path 140 | self.n_q = n_q 141 | self.q_size = q_size 142 | self.content_layer_num = content_layer_num 143 | 144 | super().__init__( 145 | num_labels=num_labels, 146 | pad_token_id=pad_token_id, 147 | bos_token_id=bos_token_id, 148 | eos_token_id=eos_token_id, 149 | is_encoder_decoder=is_encoder_decoder, 150 | decoder_start_token_id=decoder_start_token_id, 151 | forced_eos_token_id=forced_eos_token_id, 152 | **kwargs, 153 | ) 154 | 155 | # ensure backward compatibility for BART CNN models 156 | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): 157 | self.forced_bos_token_id = self.bos_token_id 158 | warnings.warn( 159 | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " 160 | "The config can simply be saved and uploaded again to be fixed." 
161 | ) 162 | 163 | 164 | class VCLMPretrainedModel(PreTrainedModel): 165 | config_class = VCLMConfig 166 | base_model_prefix = "model" 167 | supports_gradient_checkpointing = True 168 | _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"] 169 | _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] 170 | 171 | def _init_weights(self, module): 172 | std = self.config.init_std 173 | if isinstance(module, nn.Linear): 174 | module.weight.data.normal_(mean=0.0, std=std) 175 | if module.bias is not None: 176 | module.bias.data.zero_() 177 | elif isinstance(module, nn.Embedding): 178 | module.weight.data.normal_(mean=0.0, std=std) 179 | if module.padding_idx is not None: 180 | module.weight.data[module.padding_idx].zero_() 181 | 182 | def _set_gradient_checkpointing(self, module, value=False): 183 | # @todo 184 | pass 185 | # if isinstance(module, (ARDecoder, NARDecoder)): 186 | #if isinstance(module, (ARDecoder, )): 187 | # module.gradient_checkpointing = value 188 | 189 | @property 190 | def dummy_inputs(self): 191 | pad_token = self.config.pad_token_id 192 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 193 | dummy_inputs = { 194 | "attention_mask": input_ids.ne(pad_token), 195 | "input_ids": input_ids, 196 | } 197 | return dummy_inputs -------------------------------------------------------------------------------- /vc_lm/models/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/models/decoders/__init__.py -------------------------------------------------------------------------------- /vc_lm/models/decoders/ar_decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | 5 | from torch import nn 6 | 7 | from typing import Optional, Union, List, Tuple 8 | from vc_lm.models.bart.modeling_bart import BartLearnedPositionalEmbedding, BartDecoderLayer, _make_causal_mask, _expand_mask, BaseModelOutputWithPastAndCrossAttentions 9 | 10 | from transformers.utils import logging 11 | 12 | from vc_lm.models.base import VCLMConfig, VCLMPretrainedModel 13 | 14 | logger = logging.get_logger(__name__) 15 | 16 | class ARDecoder(VCLMPretrainedModel): 17 | """ 18 | Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a [`BartDecoderLayer`] 19 | 20 | Args: 21 | config: BartConfig 22 | embed_tokens (nn.Embedding): output embedding 23 | """ 24 | 25 | def __init__(self, config: VCLMConfig, embed_tokens: Optional[nn.Embedding] = None): 26 | super().__init__(config) 27 | self.dropout = config.dropout 28 | self.layerdrop = config.decoder_layerdrop 29 | self.padding_idx = config.pad_token_id 30 | self.max_target_positions = config.max_position_embeddings 31 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 32 | 33 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 34 | 35 | if embed_tokens is not None: 36 | self.embed_tokens.weight = embed_tokens.weight 37 | 38 | self.embed_positions = BartLearnedPositionalEmbedding( 39 | config.max_position_embeddings, 40 | config.d_model, 41 | ) 42 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) 43 | self.layernorm_embedding = nn.LayerNorm(config.d_model) 44 | 45 | self.gradient_checkpointing = False 46 | # Initialize weights and apply final processing 47 | self.post_init() 48 | 49 | def get_input_embeddings(self): 50 | return self.embed_tokens 51 | 52 | def set_input_embeddings(self, value): 53 | self.embed_tokens = value 54 | 55 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 56 | # create causal mask 57 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 58 | combined_attention_mask = None 59 | if input_shape[-1] > 1: 60 | combined_attention_mask = _make_causal_mask( 61 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 62 | ).to(inputs_embeds.device) 63 | 64 | if attention_mask is not None: 65 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 66 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 67 | inputs_embeds.device 68 | ) 69 | combined_attention_mask = ( 70 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 71 | ) 72 | 73 | return combined_attention_mask 74 | 75 | def forward( 76 | self, 77 | input_ids: torch.LongTensor = None, 78 | attention_mask: Optional[torch.Tensor] = None, 79 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 80 | encoder_attention_mask: Optional[torch.LongTensor] = None, 81 | head_mask: Optional[torch.Tensor] = None, 82 | cross_attn_head_mask: Optional[torch.Tensor] = None, 83 | past_key_values: Optional[List[torch.FloatTensor]] = None, 84 | inputs_embeds: Optional[torch.FloatTensor] = None, 85 | use_cache: Optional[bool] = None, 86 | output_attentions: Optional[bool] = None, 87 | output_hidden_states: Optional[bool] = None, 88 | return_dict: Optional[bool] = None, 89 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 90 | r""" 91 | Args: 92 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 93 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 94 | provide it. 95 | 96 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 97 | [`PreTrainedTokenizer.__call__`] for details. 98 | 99 | [What are input IDs?](../glossary#input-ids) 100 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 101 | Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: 102 | 103 | - 1 for tokens that are **not masked**, 104 | - 0 for tokens that are **masked**. 105 | 106 | [What are attention masks?](../glossary#attention-mask) 107 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): 108 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 109 | of the decoder. 110 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): 111 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 112 | selected in `[0, 1]`: 113 | 114 | - 1 for tokens that are **not masked**, 115 | - 0 for tokens that are **masked**. 116 | 117 | [What are attention masks?](../glossary#attention-mask) 118 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 119 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 120 | 121 | - 1 indicates the head is **not masked**, 122 | - 0 indicates the head is **masked**. 123 | 124 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 125 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing 126 | cross-attention on hidden heads. Mask values selected in `[0, 1]`: 127 | 128 | - 1 indicates the head is **not masked**, 129 | - 0 indicates the head is **masked**. 130 | 131 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 132 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 133 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 134 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 135 | 136 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 137 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 138 | 139 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 140 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 141 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of 142 | shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing 143 | `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more 144 | control over how to convert `input_ids` indices into associated vectors than the model's internal 145 | embedding lookup matrix. 146 | output_attentions (`bool`, *optional*): 147 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 148 | returned tensors for more detail. 149 | output_hidden_states (`bool`, *optional*): 150 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 151 | for more detail. 152 | return_dict (`bool`, *optional*): 153 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
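A minimal, illustrative sketch of driving the decoder on its own (tiny hyperparameters and random inputs chosen only to keep the example small; in the project the decoder is wired together with the whisper/content encoder inside the AR model):

```python
import torch

from vc_lm.models.base import VCLMConfig
from vc_lm.models.decoders.ar_decoder import ARDecoder

config = VCLMConfig(d_model=256, decoder_layers=2,
                    decoder_ffn_dim=512, decoder_attention_heads=4)
decoder = ARDecoder(config)

codes = torch.randint(0, config.q_size, (1, 10))   # level-0 codec tokens
enc_out = torch.randn(1, 20, config.d_model)       # stand-in for encoder features
enc_mask = torch.ones(1, 20, dtype=torch.long)

out = decoder(input_ids=codes,
              encoder_hidden_states=enc_out,
              encoder_attention_mask=enc_mask)
print(out.last_hidden_state.shape)                 # torch.Size([1, 10, 256])
```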
154 | """ 155 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 156 | output_hidden_states = ( 157 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 158 | ) 159 | use_cache = use_cache if use_cache is not None else self.config.use_cache 160 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 161 | 162 | # retrieve input_ids and inputs_embeds 163 | if input_ids is not None and inputs_embeds is not None: 164 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 165 | elif input_ids is not None: 166 | input = input_ids 167 | input_shape = input.shape 168 | input_ids = input_ids.view(-1, input_shape[-1]) 169 | elif inputs_embeds is not None: 170 | input_shape = inputs_embeds.size()[:-1] 171 | input = inputs_embeds[:, :, -1] 172 | else: 173 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 174 | 175 | # past_key_values_length 176 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 177 | 178 | if inputs_embeds is None: 179 | inputs_embeds = self.embed_tokens(input) * self.embed_scale 180 | 181 | attention_mask = self._prepare_decoder_attention_mask( 182 | attention_mask, input_shape, inputs_embeds, past_key_values_length 183 | ) 184 | 185 | # expand encoder attention mask 186 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 187 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 188 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 189 | # embed positions 190 | positions = self.embed_positions(input, past_key_values_length) 191 | positions = positions.to(inputs_embeds.device) 192 | 193 | hidden_states = inputs_embeds + positions 194 | hidden_states = self.layernorm_embedding(hidden_states) 195 | 196 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 197 | 198 | # decoder layers 199 | all_hidden_states = () if output_hidden_states else None 200 | all_self_attns = () if output_attentions else None 201 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 202 | next_decoder_cache = () if use_cache else None 203 | 204 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired 205 | for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): 206 | if attn_mask is not None: 207 | if attn_mask.size()[0] != (len(self.layers)): 208 | raise ValueError( 209 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" 210 | f" {head_mask.size()[0]}." 211 | ) 212 | 213 | for idx, decoder_layer in enumerate(self.layers): 214 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 215 | if output_hidden_states: 216 | all_hidden_states += (hidden_states,) 217 | dropout_probability = random.uniform(0, 1) 218 | if self.training and (dropout_probability < self.layerdrop): 219 | continue 220 | 221 | past_key_value = past_key_values[idx] if past_key_values is not None else None 222 | 223 | if self.gradient_checkpointing and self.training: 224 | if use_cache: 225 | logger.warning( 226 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
227 | ) 228 | use_cache = False 229 | 230 | def create_custom_forward(module): 231 | def custom_forward(*inputs): 232 | # None for past_key_value 233 | return module(*inputs, output_attentions, use_cache) 234 | 235 | return custom_forward 236 | 237 | layer_outputs = torch.utils.checkpoint.checkpoint( 238 | create_custom_forward(decoder_layer), 239 | hidden_states, 240 | attention_mask, 241 | encoder_hidden_states, 242 | encoder_attention_mask, 243 | head_mask[idx] if head_mask is not None else None, 244 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, 245 | None, 246 | ) 247 | else: 248 | layer_outputs = decoder_layer( 249 | hidden_states, 250 | attention_mask=attention_mask, 251 | encoder_hidden_states=encoder_hidden_states, 252 | encoder_attention_mask=encoder_attention_mask, 253 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 254 | cross_attn_layer_head_mask=( 255 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None 256 | ), 257 | past_key_value=past_key_value, 258 | output_attentions=output_attentions, 259 | use_cache=use_cache, 260 | ) 261 | hidden_states = layer_outputs[0] 262 | 263 | if use_cache: 264 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 265 | 266 | if output_attentions: 267 | all_self_attns += (layer_outputs[1],) 268 | 269 | if encoder_hidden_states is not None: 270 | all_cross_attentions += (layer_outputs[2],) 271 | 272 | # add hidden states from the last decoder layer 273 | if output_hidden_states: 274 | all_hidden_states += (hidden_states,) 275 | 276 | next_cache = next_decoder_cache if use_cache else None 277 | if not return_dict: 278 | return tuple( 279 | v 280 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 281 | if v is not None 282 | ) 283 | return BaseModelOutputWithPastAndCrossAttentions( 284 | last_hidden_state=hidden_states, 285 | past_key_values=next_cache, 286 | hidden_states=all_hidden_states, 287 | attentions=all_self_attns, 288 | cross_attentions=all_cross_attentions, 289 | ) -------------------------------------------------------------------------------- /vc_lm/models/decoders/layers.py: -------------------------------------------------------------------------------- 1 | from vc_lm.models.bart.modeling_bart import * 2 | from vc_lm.models.base import VCLMConfig 3 | from vc_lm.models.misc import StageAdaLN 4 | 5 | class NARStageDecoderLayer(nn.Module): 6 | def __init__(self, config: VCLMConfig): 7 | super().__init__() 8 | self.embed_dim = config.d_model 9 | 10 | self.self_attn = BartAttention( 11 | embed_dim=self.embed_dim, 12 | num_heads=config.decoder_attention_heads, 13 | dropout=config.attention_dropout, 14 | is_decoder=True, 15 | ) 16 | self.dropout = config.dropout 17 | self.activation_fn = ACT2FN[config.activation_function] 18 | self.activation_dropout = config.activation_dropout 19 | 20 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 21 | self.self_attn_layer_norm = StageAdaLN(self.self_attn_layer_norm, config.n_q - 1) 22 | 23 | self.encoder_attn = BartAttention( 24 | self.embed_dim, 25 | config.decoder_attention_heads, 26 | dropout=config.attention_dropout, 27 | is_decoder=True, 28 | ) 29 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) 30 | self.encoder_attn_layer_norm = StageAdaLN(self.encoder_attn_layer_norm, config.n_q - 1) 31 | 32 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 33 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 34 | 
self.final_layer_norm = nn.LayerNorm(self.embed_dim) 35 | self.final_layer_norm = StageAdaLN(self.final_layer_norm, config.n_q - 1) 36 | 37 | 38 | def forward( 39 | self, 40 | hidden_states: torch.Tensor, 41 | attention_mask: Optional[torch.Tensor] = None, 42 | encoder_hidden_states: Optional[torch.Tensor] = None, 43 | encoder_attention_mask: Optional[torch.Tensor] = None, 44 | layer_head_mask: Optional[torch.Tensor] = None, 45 | cross_attn_layer_head_mask: Optional[torch.Tensor] = None, 46 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 47 | output_attentions: Optional[bool] = False, 48 | use_cache: Optional[bool] = True, 49 | nar_stage: torch.LongTensor = None, 50 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 51 | """ 52 | Args: 53 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 54 | attention_mask (`torch.FloatTensor`): attention mask of size 55 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 56 | encoder_hidden_states (`torch.FloatTensor`): 57 | cross attention input to the layer of shape `(batch, seq_len, embed_dim)` 58 | encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size 59 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 60 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 61 | `(encoder_attention_heads,)`. 62 | cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of 63 | size `(decoder_attention_heads,)`. 64 | past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states 65 | output_attentions (`bool`, *optional*): 66 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 67 | returned tensors for more detail. 
68 | nar_stage: (`torch.LongTensor` of shape `(batch_size,)`) 69 | """ 70 | residual = hidden_states 71 | 72 | # Self Attention 73 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 74 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 75 | # add present self-attn cache to positions 1,2 of present_key_value tuple 76 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 77 | hidden_states=hidden_states, 78 | past_key_value=self_attn_past_key_value, 79 | attention_mask=attention_mask, 80 | layer_head_mask=layer_head_mask, 81 | output_attentions=output_attentions, 82 | ) 83 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 84 | hidden_states = residual + hidden_states 85 | hidden_states = self.self_attn_layer_norm(hidden_states, nar_stage) 86 | 87 | # Cross-Attention Block 88 | cross_attn_present_key_value = None 89 | cross_attn_weights = None 90 | if encoder_hidden_states is not None: 91 | residual = hidden_states 92 | 93 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 94 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 95 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 96 | hidden_states=hidden_states, 97 | key_value_states=encoder_hidden_states, 98 | attention_mask=encoder_attention_mask, 99 | layer_head_mask=cross_attn_layer_head_mask, 100 | past_key_value=cross_attn_past_key_value, 101 | output_attentions=output_attentions, 102 | ) 103 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 104 | hidden_states = residual + hidden_states 105 | hidden_states = self.encoder_attn_layer_norm(hidden_states, nar_stage) 106 | 107 | # add cross-attn to positions 3,4 of present_key_value tuple 108 | present_key_value = present_key_value + cross_attn_present_key_value 109 | 110 | # Fully Connected 111 | residual = hidden_states 112 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 113 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 114 | hidden_states = self.fc2(hidden_states) 115 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 116 | hidden_states = residual + hidden_states 117 | hidden_states = self.final_layer_norm(hidden_states, nar_stage) 118 | 119 | outputs = (hidden_states,) 120 | 121 | if output_attentions: 122 | outputs += (self_attn_weights, cross_attn_weights) 123 | 124 | if use_cache: 125 | outputs += (present_key_value,) 126 | 127 | return outputs -------------------------------------------------------------------------------- /vc_lm/models/decoders/nar_decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | 5 | from torch import nn 6 | 7 | from typing import Optional, Union, List, Tuple 8 | from vc_lm.models.bart.modeling_bart import BartLearnedPositionalEmbedding, BartDecoderLayer, _make_causal_mask, _expand_mask, BaseModelOutputWithPastAndCrossAttentions 9 | 10 | from transformers.utils import logging 11 | 12 | from vc_lm.models.base import VCLMConfig, VCLMPretrainedModel 13 | from vc_lm.models.misc import StageAdaLN 14 | from vc_lm.models.decoders.layers import NARStageDecoderLayer 15 | from vc_lm.datamodules.datasets.nar_dataset import NARDataset 16 | 17 | logger = 
logging.get_logger(__name__) 18 | 19 | 20 | class AccumulateMultiStageEmbedding(nn.Module): 21 | def __init__(self, embed_tokens: nn.Embedding, 22 | q_size: int = 1024): 23 | """AccumulateMultiStageEmbedding""" 24 | super().__init__() 25 | self.embed_tokens = embed_tokens 26 | self.q_size = q_size 27 | 28 | def forward(self, multistage_code: torch.LongTensor): 29 | """ 30 | Args: 31 | multistage_code: (batch_size, stage_num, seq_len) 32 | Return: 33 | multistage_code_emb: (batch_size, seq_len, dim) 34 | """ 35 | stage_id = torch.arange(0, multistage_code.shape[1], device=multistage_code.device)[None, ..., None] 36 | multistage_code = stage_id * self.q_size + multistage_code 37 | multistage_code[multistage_code >= NARDataset._PAD_ID] = NARDataset._PAD_ID 38 | # (batch_size, stage_num, seq_len, dim) 39 | multistage_code = self.embed_tokens(multistage_code) 40 | return torch.sum(multistage_code, dim=1) 41 | 42 | 43 | class NARDecoder(VCLMPretrainedModel): 44 | """ 45 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] 46 | 47 | Args: 48 | config: BartConfig 49 | embed_tokens (nn.Embedding): output embedding 50 | """ 51 | 52 | def __init__(self, config: VCLMConfig, embed_tokens: Optional[nn.Embedding] = None): 53 | super().__init__(config) 54 | self.dropout = config.dropout 55 | self.layerdrop = config.decoder_layerdrop 56 | self.padding_idx = config.pad_token_id 57 | self.max_target_positions = config.max_position_embeddings 58 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 59 | 60 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 61 | 62 | self.stage_embed = nn.Embedding(config.n_q, config.d_model) 63 | 64 | if embed_tokens is not None: 65 | self.embed_tokens.weight = embed_tokens.weight 66 | 67 | self.embed_positions = BartLearnedPositionalEmbedding( 68 | config.max_position_embeddings, 69 | config.d_model, 70 | ) 71 | self.style_positions = BartLearnedPositionalEmbedding(config.style_length, config.d_model) 72 | self.accumulate_multistage_embedding_layer = AccumulateMultiStageEmbedding(self.embed_tokens, 73 | q_size=config.q_size) 74 | self.register_buffer('style_mask', torch.ones((1, config.style_length), dtype=torch.int64)) 75 | 76 | self.layers = nn.ModuleList([NARStageDecoderLayer(config) for _ in range(config.decoder_layers)]) 77 | self.layernorm_embedding = StageAdaLN(nn.LayerNorm(config.d_model), 78 | num_stage=config.n_q - 1) 79 | 80 | self.gradient_checkpointing = False 81 | # Initialize weights and apply final processing 82 | self.post_init() 83 | 84 | def get_input_embeddings(self): 85 | return self.embed_tokens 86 | 87 | def set_input_embeddings(self, value): 88 | self.embed_tokens = value 89 | 90 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, 91 | past_key_values_length=None): 92 | # create causal mask 93 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 94 | combined_attention_mask = None 95 | if input_shape[-1] > 1: 96 | combined_attention_mask = _make_causal_mask( 97 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 98 | ).to(inputs_embeds.device) 99 | 100 | if attention_mask is not None: 101 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 102 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 103 | inputs_embeds.device 104 | ) 105 | combined_attention_mask = ( 106 | expanded_attn_mask if combined_attention_mask 
is None else expanded_attn_mask + combined_attention_mask 107 | ) 108 | 109 | return combined_attention_mask 110 | 111 | def forward( 112 | self, 113 | input_code: torch.LongTensor = None, 114 | attention_mask: Optional[torch.Tensor] = None, 115 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 116 | encoder_attention_mask: Optional[torch.LongTensor] = None, 117 | style_code: torch.LongTensor = None, 118 | nar_stage: torch.LongTensor = None, 119 | ): 120 | r""" 121 | Args: 122 | input_code: (`torch.LongTensor` of shape `(batch_size, num_stage, sequence_length)`) 123 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 124 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 125 | 126 | - 1 for tokens that are **not masked**, 127 | - 0 for tokens that are **masked**. 128 | 129 | [What are attention masks?](../glossary#attention-mask) 130 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): 131 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 132 | of the decoder. 133 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): 134 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 135 | selected in `[0, 1]`: 136 | 137 | - 1 for tokens that are **not masked**, 138 | - 0 for tokens that are **masked**. 139 | 140 | [What are attention masks?](../glossary#attention-mask) 141 | style_code: (`torch.LongTensor` of shape `(batch_size, num_q, style_len)`) 142 | nar_stage: (`torch.LongTensor` of shape `(batch_size)`) 143 | """ 144 | batch_size = input_code.shape[0] 145 | # get input_code_embeds 146 | # (batch_size, seq_len, dim) 147 | input_code_embeds = self.accumulate_multistage_embedding_layer(input_code) 148 | 149 | # add stage embedding 150 | input_code_embeds = self.stage_embed(nar_stage).unsqueeze(1) + input_code_embeds 151 | 152 | # (batch_size, style_len, dim) 153 | style_code_embeds = self.accumulate_multistage_embedding_layer(style_code) 154 | # (batch_size, style_len + seq_len, dim) 155 | inputs_embeds = torch.cat([style_code_embeds, input_code_embeds], 1) * self.embed_scale 156 | # pad style_mask: attention_mask (batch_size, style_len + seq_len) 157 | attention_mask = torch.cat([self.style_mask.expand(batch_size, -1), attention_mask], 1) 158 | input_shape = attention_mask.shape 159 | 160 | attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) 161 | 162 | # expand encoder attention mask 163 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 164 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 165 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 166 | # embed positions 167 | style_positions = self.style_positions(style_code[:, 0, :], past_key_values_length=0) 168 | input_code_positions = self.embed_positions(input_code[:, 0, :], past_key_values_length=0) 169 | positions = torch.cat([style_positions, input_code_positions], 1) 170 | positions = positions.to(inputs_embeds.device) 171 | 172 | hidden_states = inputs_embeds + positions 173 | hidden_states = self.layernorm_embedding(hidden_states, nar_stage) 174 | 175 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 176 | 177 | # decoder layers 178 | for idx, decoder_layer in 
enumerate(self.layers): 179 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 180 | dropout_probability = random.uniform(0, 1) 181 | if self.training and (dropout_probability < self.layerdrop): 182 | continue 183 | 184 | past_key_value = None 185 | 186 | if self.gradient_checkpointing and self.training: 187 | if use_cache: 188 | logger.warning( 189 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 190 | ) 191 | use_cache = False 192 | 193 | def create_custom_forward(module): 194 | def custom_forward(*inputs): 195 | # None for past_key_value 196 | return module(*inputs, None, use_cache) 197 | 198 | return custom_forward 199 | 200 | layer_outputs = torch.utils.checkpoint.checkpoint( 201 | create_custom_forward(decoder_layer), 202 | hidden_states, 203 | attention_mask, 204 | encoder_hidden_states, 205 | encoder_attention_mask, 206 | None, 207 | None, 208 | None, 209 | ) 210 | else: 211 | layer_outputs = decoder_layer( 212 | hidden_states, 213 | attention_mask=attention_mask, 214 | encoder_hidden_states=encoder_hidden_states, 215 | encoder_attention_mask=encoder_attention_mask, 216 | layer_head_mask=None, 217 | cross_attn_layer_head_mask=None, 218 | past_key_value=past_key_value, 219 | output_attentions=None, 220 | use_cache=False, 221 | nar_stage=nar_stage 222 | ) 223 | hidden_states = layer_outputs[0] 224 | 225 | # extract code hidden_states 226 | return hidden_states[:, self.config.style_length:] 227 | -------------------------------------------------------------------------------- /vc_lm/models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/models/encoders/__init__.py -------------------------------------------------------------------------------- /vc_lm/models/encoders/whisper_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import random 4 | from torch import nn 5 | from typing import List, Optional, Union, Tuple 6 | 7 | import pytorch_lightning as pl 8 | 9 | from whisper.model import AudioEncoder 10 | 11 | from vc_lm.models.base import VCLMPretrainedModel, VCLMConfig 12 | 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutput, 15 | BaseModelOutputWithPastAndCrossAttentions, 16 | CausalLMOutputWithCrossAttentions, 17 | Seq2SeqLMOutput, 18 | Seq2SeqModelOutput, 19 | Seq2SeqQuestionAnsweringModelOutput, 20 | Seq2SeqSequenceClassifierOutput, 21 | ) 22 | 23 | from vc_lm.models.bart.modeling_bart import BartLearnedPositionalEmbedding, BartEncoderLayer, _expand_mask 24 | from vc_lm.datamodules.datasets.ar_dataset import ARDataset 25 | 26 | class ContentEncoder(VCLMPretrainedModel): 27 | def __init__(self, config: VCLMConfig): 28 | super().__init__(config) 29 | 30 | self.dropout = config.dropout 31 | self.layerdrop = config.encoder_layerdrop 32 | 33 | embed_dim = config.d_model 34 | self.padding_idx = config.pad_token_id 35 | self.max_source_positions = 1500 36 | 37 | self.embed_positions = BartLearnedPositionalEmbedding( 38 | self.max_source_positions, 39 | embed_dim, 40 | ) 41 | self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.content_layer_num)]) 42 | self.layernorm_embedding = nn.LayerNorm(embed_dim) 43 | 44 | self.gradient_checkpointing = False 45 | # Initialize weights and apply final processing 46 | self.post_init() 47 | 48 | def forward(self, 49 | 
inputs: torch.FloatTensor = None, 50 | attention_mask: Optional[torch.Tensor] = None): 51 | embed_pos = self.embed_positions(inputs) 52 | embed_pos = embed_pos.to(inputs.device) 53 | 54 | hidden_states = inputs + embed_pos 55 | hidden_states = self.layernorm_embedding(hidden_states) 56 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 57 | 58 | # expand attention_mask 59 | if attention_mask is not None: 60 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 61 | attention_mask = _expand_mask(attention_mask, inputs.dtype) 62 | 63 | for idx, encoder_layer in enumerate(self.layers): 64 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 65 | dropout_probability = random.uniform(0, 1) 66 | if self.training and (dropout_probability < self.layerdrop): # skip the layer 67 | layer_outputs = (None, None) 68 | else: 69 | if self.gradient_checkpointing and self.training: 70 | 71 | def create_custom_forward(module): 72 | def custom_forward(*inputs): 73 | return module(*inputs, output_attention=None) 74 | 75 | return custom_forward 76 | 77 | layer_outputs = torch.utils.checkpoint.checkpoint( 78 | create_custom_forward(encoder_layer), 79 | hidden_states, 80 | attention_mask, 81 | None, 82 | ) 83 | else: 84 | layer_outputs = encoder_layer( 85 | hidden_states, 86 | attention_mask, 87 | layer_head_mask=None, 88 | output_attentions=None, 89 | ) 90 | 91 | hidden_states = layer_outputs[0] 92 | return hidden_states 93 | 94 | 95 | class WhisperEncoder(VCLMPretrainedModel): 96 | """ 97 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a 98 | [`BartEncoderLayer`]. 99 | Args: 100 | config: BartConfig 101 | embed_tokens (nn.Embedding): output embedding 102 | """ 103 | 104 | def __init__(self, config: VCLMConfig, embed_tokens: Optional[nn.Embedding] = None): 105 | super().__init__(config) 106 | checkpoint = torch.load(config.encoder_model_path) 107 | self.audio_encoder = AudioEncoder(n_mels=checkpoint['dims']['n_mels'], 108 | n_ctx=checkpoint['dims']['n_audio_ctx'], 109 | n_state=checkpoint['dims']['n_audio_state'], 110 | n_head=checkpoint['dims']['n_audio_head'], 111 | n_layer=checkpoint['dims']['n_audio_layer']) 112 | if config.content_layer_num >= 0: 113 | self.content_encoder = ContentEncoder(config) 114 | else: 115 | self.content_encoder = None 116 | # 117 | self.linear1 = nn.Linear(checkpoint['dims']['n_audio_state'], 256) 118 | self.activate_fn = nn.GELU() 119 | self.linear2 = nn.Linear(256, config.d_model) 120 | 121 | self.load_pretrained_whisper_params() 122 | 123 | def get_input_embeddings(self): 124 | return self.embed_tokens 125 | 126 | def set_input_embeddings(self, value): 127 | self.embed_tokens = value 128 | 129 | def load_pretrained_whisper_params(self): 130 | checkpoint = torch.load(self.config.encoder_model_path) 131 | self.audio_encoder.load_state_dict(checkpoint['model_state_dict']) 132 | 133 | def forward( 134 | self, 135 | input_ids: torch.Tensor = None, 136 | attention_mask: Optional[torch.Tensor] = None, 137 | head_mask: Optional[torch.Tensor] = None, 138 | inputs_embeds: Optional[torch.FloatTensor] = None, 139 | output_attentions: Optional[bool] = None, 140 | output_hidden_states: Optional[bool] = None, 141 | return_dict: Optional[bool] = None, 142 | ) -> Union[Tuple, BaseModelOutput]: 143 | r""" 144 | Args: 145 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 146 | Indices of input sequence tokens in the vocabulary. 
Padding will be ignored by default should you 147 | provide it. 148 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 149 | [`PreTrainedTokenizer.__call__`] for details. 150 | [What are input IDs?](../glossary#input-ids) 151 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 152 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 153 | - 1 for tokens that are **not masked**, 154 | - 0 for tokens that are **masked**. 155 | [What are attention masks?](../glossary#attention-mask) 156 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 157 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 158 | - 1 indicates the head is **not masked**, 159 | - 0 indicates the head is **masked**. 160 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 161 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 162 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 163 | than the model's internal embedding lookup matrix. 164 | output_attentions (`bool`, *optional*): 165 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 166 | returned tensors for more detail. 167 | output_hidden_states (`bool`, *optional*): 168 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 169 | for more detail. 170 | return_dict (`bool`, *optional*): 171 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 172 | """ 173 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 174 | output_hidden_states = ( 175 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 176 | ) 177 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 178 | 179 | hidden_states = self.audio_encoder.forward(input_ids) 180 | 181 | if self.content_encoder is not None: 182 | hidden_states = self.content_encoder(hidden_states, attention_mask=attention_mask) 183 | # project 184 | hidden_states = self.linear1(hidden_states) 185 | hidden_states = self.activate_fn(hidden_states) 186 | hidden_states = self.linear2(hidden_states) 187 | # # 188 | # hidden_states = torch.zeros_like(hidden_states, device=hidden_states.device) * hidden_states 189 | # # 190 | 191 | if not return_dict: 192 | return tuple(v for v in [hidden_states, hidden_states, attention_mask] if v is not None) 193 | return BaseModelOutput( 194 | last_hidden_state=hidden_states, hidden_states=hidden_states, attentions=attention_mask 195 | ) 196 | 197 | def freeze(self, only_whisper=False) -> None: 198 | r""" 199 | Freeze all params for inference. 200 | 201 | Example:: 202 | 203 | model = MyLightningModule(...) 204 | model.freeze() 205 | 206 | """ 207 | if only_whisper: 208 | for param in self.audio_encoder.parameters(): 209 | param.requires_grad = False 210 | self.audio_encoder.eval() 211 | else: 212 | for param in self.parameters(): 213 | param.requires_grad = False 214 | self.eval() 215 | 216 | def unfreeze(self) -> None: 217 | """Unfreeze all parameters for training. 218 | 219 | .. code-block:: python 220 | 221 | model = MyLightningModule(...) 
222 | model.unfreeze() 223 | """ 224 | for param in self.parameters(): 225 | param.requires_grad = True 226 | 227 | self.train() 228 | -------------------------------------------------------------------------------- /vc_lm/models/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import optim 4 | import numpy as np 5 | 6 | class StageAdaLN(nn.Module): 7 | def __init__(self, layer_norm: nn.LayerNorm, 8 | num_stage: int): 9 | super().__init__() 10 | self.layer_norm = layer_norm 11 | self.num_stage = num_stage 12 | self.register_parameter('stage_w', nn.Parameter(torch.ones((num_stage,) + self.layer_norm.normalized_shape))) 13 | self.register_parameter('stage_b', nn.Parameter(torch.zeros((num_stage,) + self.layer_norm.normalized_shape))) 14 | 15 | def forward(self, x: torch.Tensor, stage_id: torch.Tensor): 16 | """ 17 | Args: 18 | x: torch.Tensor (batch_size, ..., dim) 19 | stage_id: torch.Tensor (batch_size,) 20 | Return: 21 | y: torch.Tensor (batch_size, ..., dim) 22 | """ 23 | y = self.layer_norm(x) 24 | expand_number = y.ndim - len(self.layer_norm.normalized_shape) - 1 25 | # (batch_size, *dim) 26 | c_w, c_b = self.stage_w[stage_id], self.stage_b[stage_id] 27 | for i in range(expand_number): 28 | c_w = c_w.unsqueeze(1) 29 | c_b = c_b.unsqueeze(1) 30 | return y * c_w + c_b 31 | 32 | class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler): 33 | def __init__(self, optimizer, warmup, max_iters, 34 | min_lr=1e-5): 35 | self.warmup = warmup 36 | self.max_num_iters = max_iters 37 | self.min_lr = min_lr 38 | super().__init__(optimizer) 39 | 40 | def get_lr(self): 41 | lr_factor = self.get_lr_factor(epoch=self.last_epoch) 42 | return [max(base_lr * lr_factor, self.min_lr) for base_lr in self.base_lrs] 43 | 44 | def get_lr_factor(self, epoch): 45 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) 46 | if epoch <= self.warmup: 47 | lr_factor *= epoch * 1.0 / self.warmup 48 | return lr_factor 49 | 50 | -------------------------------------------------------------------------------- /vc_lm/models/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/models/models/__init__.py -------------------------------------------------------------------------------- /vc_lm/models/models/ar_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import warnings 4 | from typing import List, Optional, Tuple, Union 5 | from transformers.modeling_outputs import ( 6 | BaseModelOutput, 7 | BaseModelOutputWithPastAndCrossAttentions, 8 | CausalLMOutputWithCrossAttentions, 9 | Seq2SeqLMOutput, 10 | Seq2SeqModelOutput, 11 | Seq2SeqQuestionAnsweringModelOutput, 12 | Seq2SeqSequenceClassifierOutput, 13 | ) 14 | 15 | from vc_lm.models.base import VCLMPretrainedModel, VCLMConfig 16 | from vc_lm.models.encoders.whisper_encoder import WhisperEncoder 17 | from vc_lm.models.decoders.ar_decoder import ARDecoder 18 | 19 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 20 | """ 21 | Shift input ids one token to the right. 
22 | """ 23 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 24 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 25 | shifted_input_ids[:, 0] = decoder_start_token_id 26 | 27 | if pad_token_id is None: 28 | raise ValueError("self.model.config.pad_token_id has to be defined.") 29 | # replace possible -100 values in labels by `pad_token_id` 30 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 31 | 32 | return shifted_input_ids 33 | 34 | class ARModel(VCLMPretrainedModel): 35 | _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] 36 | 37 | def __init__(self, config: VCLMConfig): 38 | super().__init__(config) 39 | 40 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 41 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 42 | 43 | self.encoder = WhisperEncoder(config) 44 | self.encoder.freeze(only_whisper=True) 45 | self.decoder = ARDecoder(config, self.shared) 46 | # Initialize weights and apply final processing 47 | self.post_init() 48 | 49 | def get_input_embeddings(self): 50 | return self.shared 51 | 52 | def set_input_embeddings(self, value): 53 | self.shared = value 54 | self.decoder.embed_tokens = self.shared 55 | 56 | def get_encoder(self): 57 | return self.encoder 58 | 59 | def get_decoder(self): 60 | return self.decoder 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.FloatTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | decoder_input_ids: Optional[torch.LongTensor] = None, 67 | decoder_attention_mask: Optional[torch.LongTensor] = None, 68 | head_mask: Optional[torch.Tensor] = None, 69 | decoder_head_mask: Optional[torch.Tensor] = None, 70 | cross_attn_head_mask: Optional[torch.Tensor] = None, 71 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 72 | past_key_values: Optional[List[torch.FloatTensor]] = None, 73 | inputs_embeds: Optional[torch.FloatTensor] = None, 74 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | return_dict: Optional[bool] = None, 79 | ) -> Union[Tuple, Seq2SeqModelOutput]: 80 | 81 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 82 | output_hidden_states = ( 83 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 84 | ) 85 | use_cache = use_cache if use_cache is not None else self.config.use_cache 86 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 87 | 88 | if encoder_outputs is None: 89 | encoder_outputs = self.encoder(input_ids) 90 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 91 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 92 | encoder_outputs = BaseModelOutput( 93 | last_hidden_state=encoder_outputs[0], 94 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 95 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 96 | ) 97 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 98 | decoder_outputs = self.decoder( 99 | input_ids=decoder_input_ids, 100 | attention_mask=decoder_attention_mask, 101 | encoder_hidden_states=encoder_outputs[0], 102 | encoder_attention_mask=attention_mask, 103 | head_mask=decoder_head_mask, 104 | 
cross_attn_head_mask=cross_attn_head_mask, 105 | past_key_values=past_key_values, 106 | inputs_embeds=decoder_inputs_embeds, 107 | use_cache=use_cache, 108 | output_attentions=output_attentions, 109 | output_hidden_states=output_hidden_states, 110 | return_dict=return_dict, 111 | ) 112 | if not return_dict: 113 | return decoder_outputs + encoder_outputs 114 | 115 | return Seq2SeqModelOutput( 116 | last_hidden_state=decoder_outputs.last_hidden_state, 117 | past_key_values=decoder_outputs.past_key_values, 118 | decoder_hidden_states=decoder_outputs.hidden_states, 119 | decoder_attentions=decoder_outputs.attentions, 120 | cross_attentions=decoder_outputs.cross_attentions, 121 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 122 | encoder_hidden_states=encoder_outputs.hidden_states, 123 | encoder_attentions=encoder_outputs.attentions, 124 | ) 125 | 126 | class ARModelForConditionalGeneration(VCLMPretrainedModel): 127 | base_model_prefix = "model" 128 | _keys_to_ignore_on_load_missing = [ 129 | r"final_logits_bias", 130 | r"lm_head.weight", 131 | "encoder.embed_tokens.weight", 132 | "decoder.embed_tokens.weight", 133 | ] 134 | 135 | def __init__(self, config: VCLMConfig): 136 | super().__init__(config) 137 | self.model = ARModel(config) 138 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 139 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 140 | 141 | # Initialize weights and apply final processing 142 | self.post_init() 143 | 144 | def get_encoder(self): 145 | return self.model.get_encoder() 146 | 147 | def get_decoder(self): 148 | return self.model.get_decoder() 149 | 150 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 151 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 152 | self._resize_final_logits_bias(new_num_tokens) 153 | return new_embeddings 154 | 155 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 156 | old_num_tokens = self.final_logits_bias.shape[-1] 157 | if new_num_tokens <= old_num_tokens: 158 | new_bias = self.final_logits_bias[:, :new_num_tokens] 159 | else: 160 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 161 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 162 | self.register_buffer("final_logits_bias", new_bias) 163 | 164 | def get_output_embeddings(self): 165 | return self.lm_head 166 | 167 | def set_output_embeddings(self, new_embeddings): 168 | self.lm_head = new_embeddings 169 | 170 | def forward( 171 | self, 172 | input_ids: torch.LongTensor = None, 173 | attention_mask: Optional[torch.Tensor] = None, 174 | decoder_input_ids: Optional[torch.LongTensor] = None, 175 | decoder_attention_mask: Optional[torch.LongTensor] = None, 176 | head_mask: Optional[torch.Tensor] = None, 177 | decoder_head_mask: Optional[torch.Tensor] = None, 178 | cross_attn_head_mask: Optional[torch.Tensor] = None, 179 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 180 | past_key_values: Optional[List[torch.FloatTensor]] = None, 181 | inputs_embeds: Optional[torch.FloatTensor] = None, 182 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 183 | labels: Optional[torch.LongTensor] = None, 184 | use_cache: Optional[bool] = None, 185 | output_attentions: Optional[bool] = None, 186 | output_hidden_states: Optional[bool] = None, 187 | return_dict: Optional[bool] = None, 188 | ) -> Union[Tuple, Seq2SeqLMOutput]: 189 | r""" 190 | labels 
(`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 191 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 192 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 193 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 194 | Returns: 195 | """ 196 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 197 | 198 | if labels is not None: 199 | if use_cache: 200 | logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") 201 | use_cache = False 202 | if decoder_input_ids is None and decoder_inputs_embeds is None: 203 | decoder_input_ids = shift_tokens_right( 204 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 205 | ) 206 | 207 | outputs = self.model( 208 | input_ids, 209 | attention_mask=attention_mask, 210 | decoder_input_ids=decoder_input_ids, 211 | encoder_outputs=encoder_outputs, 212 | decoder_attention_mask=decoder_attention_mask, 213 | head_mask=head_mask, 214 | decoder_head_mask=decoder_head_mask, 215 | cross_attn_head_mask=cross_attn_head_mask, 216 | past_key_values=past_key_values, 217 | inputs_embeds=inputs_embeds, 218 | decoder_inputs_embeds=decoder_inputs_embeds, 219 | use_cache=use_cache, 220 | output_attentions=output_attentions, 221 | output_hidden_states=output_hidden_states, 222 | return_dict=return_dict, 223 | ) 224 | 225 | lm_logits = self.lm_head(outputs[0]) 226 | lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) 227 | 228 | masked_lm_loss = 0.0 229 | 230 | if not return_dict: 231 | output = (lm_logits,) + outputs[1:] 232 | return output 233 | 234 | return Seq2SeqLMOutput( 235 | loss=masked_lm_loss, 236 | logits=lm_logits, 237 | past_key_values=outputs.past_key_values, 238 | decoder_hidden_states=outputs.decoder_hidden_states, 239 | decoder_attentions=outputs.decoder_attentions, 240 | cross_attentions=outputs.cross_attentions, 241 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 242 | encoder_hidden_states=outputs.encoder_hidden_states, 243 | encoder_attentions=outputs.encoder_attentions, 244 | ) 245 | 246 | def prepare_inputs_for_generation( 247 | self, 248 | decoder_input_ids, 249 | past_key_values=None, 250 | attention_mask=None, 251 | decoder_attention_mask=None, 252 | head_mask=None, 253 | decoder_head_mask=None, 254 | cross_attn_head_mask=None, 255 | use_cache=None, 256 | encoder_outputs=None, 257 | **kwargs, 258 | ): 259 | # cut decoder_input_ids if past_key_values is used 260 | if past_key_values is not None: 261 | decoder_input_ids = decoder_input_ids[:, -1:] 262 | 263 | return { 264 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 265 | "encoder_outputs": encoder_outputs, 266 | "past_key_values": past_key_values, 267 | "decoder_input_ids": decoder_input_ids, 268 | "attention_mask": attention_mask, 269 | "decoder_attention_mask": decoder_attention_mask, 270 | "head_mask": head_mask, 271 | "decoder_head_mask": decoder_head_mask, 272 | "cross_attn_head_mask": cross_attn_head_mask, 273 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 274 | } 275 | 276 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 277 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 278 | 279 | @staticmethod 280 | def _reorder_cache(past_key_values, beam_idx): 281 | reordered_past = () 282 | for layer_past in past_key_values: 283 | # cached cross_attention states don't have to be reordered -> they are always the same 284 | reordered_past += ( 285 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 286 | ) 287 | return reordered_past -------------------------------------------------------------------------------- /vc_lm/models/models/nar_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import warnings 4 | from typing import List, Optional 5 | 6 | from vc_lm.models.base import VCLMPretrainedModel, VCLMConfig 7 | from vc_lm.models.encoders.whisper_encoder import WhisperEncoder 8 | from vc_lm.models.decoders.nar_decoder import NARDecoder 9 | 10 | 11 | class NARModel(VCLMPretrainedModel): 12 | _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] 13 | 14 | def __init__(self, config: VCLMConfig): 15 | super().__init__(config) 16 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 17 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 18 | self.encoder = WhisperEncoder(config) 19 | self.encoder.freeze(only_whisper=True) 20 | self.decoder = NARDecoder(config, self.shared) 21 | self.num_q = config.n_q 22 | self.q_size = config.q_size 23 | # Initialize weights and apply final processing 24 | self.post_init() 25 | 26 | def get_input_embeddings(self): 27 | return self.shared 28 | 29 | def set_input_embeddings(self, value): 30 | self.shared = value 31 | self.decoder.embed_tokens = self.shared 32 | 33 | def get_encoder(self): 34 | return self.encoder 35 | 36 | def get_decoder(self): 37 | return self.decoder 38 | 39 | def forward( 40 | self, 41 | input_ids: torch.FloatTensor = None, 42 | attention_mask: Optional[torch.Tensor] = None, 43 | decoder_input_ids: Optional[torch.LongTensor] = None, 44 | decoder_attention_mask: Optional[torch.LongTensor] = None, 45 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 46 | style_code: torch.LongTensor = None, 47 | nar_stage: torch.LongTensor = None): 48 | if encoder_outputs is None: 49 | encoder_outputs = self.encoder(input_ids, 50 | attention_mask=attention_mask) 51 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 52 | # (batch_size, target_len, dims) 53 | decoder_outputs = self.decoder( 54 | input_code=decoder_input_ids, 55 | attention_mask=decoder_attention_mask, 56 | encoder_hidden_states=encoder_outputs[0], 57 | encoder_attention_mask=attention_mask, 58 | style_code=style_code, 59 | nar_stage=nar_stage 60 | ) 61 | # (num_q, q_size, dim) 62 | reshaped_emb = torch.reshape(self.shared.weight[0:self.q_size * self.num_q], (self.num_q, self.q_size, 
-1)) 63 | # (batch_size, q_size, dim) 64 | predicted_emb = reshaped_emb[nar_stage + 1] 65 | # (batch_size, target_len, q_size) 66 | logits = torch.einsum('btd,bqd->btq', decoder_outputs, predicted_emb) 67 | 68 | return decoder_outputs, logits 69 | -------------------------------------------------------------------------------- /vc_lm/models/nar_model_pl.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import json 4 | 5 | from typing import Any, List 6 | from torch import nn 7 | from torchmetrics.classification.accuracy import Accuracy 8 | from torch.optim import AdamW 9 | 10 | from vc_lm.models.base import VCLMConfig 11 | from vc_lm.models.models.nar_model import NARModel 12 | 13 | from vc_lm.models.misc import CosineWarmupScheduler 14 | from transformers.optimization import get_polynomial_decay_schedule_with_warmup 15 | 16 | 17 | class NARModelPL(pl.LightningModule): 18 | def __init__(self, config_file: str, 19 | lr: float = 0.001, 20 | weight_decay: float = 0.0005, 21 | warmup_step: int = 10000, 22 | max_iters: int = 800000, 23 | load_pretrain: bool = False, 24 | pretrain_model_path: str = None): 25 | super().__init__() 26 | self.save_hyperparameters() 27 | with open(config_file) as f: 28 | config = json.load(f) 29 | config = VCLMConfig(**config) 30 | self.model = NARModel(config) 31 | # load whisper parameter 32 | self.model.encoder.load_pretrained_whisper_params() 33 | self.loss_fct = nn.CrossEntropyLoss() 34 | 35 | if load_pretrain: 36 | loaded_state = torch.load(pretrain_model_path)['state_dict'] 37 | self.load_state_dict(loaded_state, 38 | strict=False) 39 | 40 | self.train_accuracy = Accuracy(task="multiclass", 41 | num_classes=self.model.shared.num_embeddings, 42 | average='micro', 43 | ignore_index=-100) 44 | self.val_accuracy = Accuracy(task="multiclass", 45 | num_classes=self.model.shared.num_embeddings, 46 | average='micro', 47 | ignore_index=-100) 48 | self.test_accuracy = Accuracy(task="multiclass", 49 | num_classes=self.model.shared.num_embeddings, 50 | average='micro', 51 | ignore_index=-100) 52 | 53 | def forward(self, 54 | input_mels=None, 55 | attention_mask=None, 56 | decoder_input_ids=None, 57 | decoder_attention_mask=None, 58 | encoder_outputs=None, 59 | style_code=None, 60 | nar_stage=None): 61 | outputs = self.model(input_ids=input_mels, 62 | attention_mask=attention_mask, 63 | decoder_input_ids=decoder_input_ids, 64 | decoder_attention_mask=decoder_attention_mask, 65 | encoder_outputs=encoder_outputs, 66 | style_code=style_code, 67 | nar_stage=nar_stage) 68 | return outputs 69 | 70 | def step(self, batch: Any): 71 | _, lm_logits = self.forward(input_mels=batch['mel'], 72 | attention_mask=batch['content_mask'], 73 | decoder_input_ids=batch['input_code'], 74 | decoder_attention_mask=batch['code_mask'], 75 | encoder_outputs=batch.get('encoder_outputs', None), 76 | style_code=batch['style_code'], 77 | nar_stage=batch['nar_stage']) 78 | output_code = batch['output_code'] 79 | output_code[output_code == self.model.config.pad_token_id] = -100 80 | 81 | lm_loss = self.loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), output_code.view(-1)) 82 | preds = torch.argmax(lm_logits, dim=-1) 83 | return lm_loss, preds, output_code 84 | 85 | def training_step(self, 86 | batch: Any, 87 | batch_idx: int): 88 | loss, preds, targets = self.step(batch) 89 | acc = self.train_accuracy(preds, targets) 90 | self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=False) 91 | 
self.log('train/acc', acc, on_step=True, on_epoch=True, prog_bar=True) 92 | # return { 93 | # 'loss': loss, 94 | # 'preds': preds, 95 | # 'targets': targets 96 | # } 97 | return {'loss': loss} 98 | 99 | def training_epoch_end(self, outputs: List[Any]): 100 | pass 101 | 102 | def validation_step(self, batch: Any, batch_idx: int): 103 | loss, preds, targets = self.step(batch) 104 | acc = self.val_accuracy(preds, targets) 105 | self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False) 106 | self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True) 107 | #return {"loss": loss, "preds": preds, "targets": targets} 108 | return {'loss': loss} 109 | 110 | def validation_epoch_end(self, outputs: List[Any]): 111 | pass 112 | 113 | def test_step(self, batch: Any, batch_idx: int): 114 | loss, preds, targets = self.step(batch) 115 | # log test metrics 116 | acc = self.test_accuracy(preds, targets) 117 | self.log("test/loss", loss, on_step=False, on_epoch=True) 118 | self.log("test/acc", acc, on_step=False, on_epoch=True) 119 | #return {"loss": loss, "preds": preds, "targets": targets} 120 | return {'loss': loss} 121 | 122 | def test_epoch_end(self, outputs: List[Any]): 123 | pass 124 | 125 | def configure_optimizers(self) -> Any: 126 | no_decay = ["bias", "LayerNorm.weight"] 127 | optimizer_grouped_parameters = [ 128 | { 129 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 130 | "weight_decay": self.hparams.weight_decay, 131 | }, 132 | { 133 | "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 134 | "weight_decay": 0.0, 135 | }, 136 | ] 137 | optimizer = AdamW(optimizer_grouped_parameters, 138 | lr=self.hparams.lr, 139 | weight_decay=self.hparams.weight_decay) 140 | # scheduler = CosineWarmupScheduler(optimizer=optimizer, 141 | # warmup=self.hparams.warmup_step, 142 | # max_iters=self.hparams.max_iters) 143 | scheduler = get_polynomial_decay_schedule_with_warmup(optimizer=optimizer, 144 | num_warmup_steps=self.hparams.warmup_step, 145 | num_training_steps=self.hparams.max_iters) 146 | 147 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 148 | -------------------------------------------------------------------------------- /vc_lm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nilboy/vc-lm/f60856525a1effc622deb06537b91e5295996440/vc_lm/utils/__init__.py -------------------------------------------------------------------------------- /vc_lm/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | import numpy as np 4 | from typing import Any 5 | 6 | from encodec.utils import convert_audio 7 | 8 | import torchaudio 9 | import torch 10 | 11 | from whisper.audio import log_mel_spectrogram 12 | 13 | 14 | def pad_or_trim(array, length: int, axis: int = -1, 15 | pad_value=0): 16 | """ 17 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
18 | """ 19 | if torch.is_tensor(array): 20 | if array.shape[axis] > length: 21 | array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) 22 | 23 | if array.shape[axis] < length: 24 | pad_widths = [(0, 0)] * array.ndim 25 | pad_widths[axis] = (0, length - array.shape[axis]) 26 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes], 27 | value=pad_value) 28 | else: 29 | if array.shape[axis] > length: 30 | array = array.take(indices=range(length), axis=axis) 31 | 32 | if array.shape[axis] < length: 33 | pad_widths = [(0, 0)] * array.ndim 34 | pad_widths[axis] = (0, length - array.shape[axis]) 35 | array = np.pad(array, pad_widths, 36 | constant_values=pad_value) 37 | 38 | return array 39 | 40 | 41 | def get_code(audio: str, 42 | device: str, 43 | model: Any): 44 | wav, sr = torchaudio.load(audio) 45 | wav = convert_audio(wav, sr, model.sample_rate, model.channels) 46 | wav = wav.unsqueeze(0) 47 | wav = wav.cuda(device) 48 | # Extract discrete codes from EnCodec 49 | with torch.no_grad(): 50 | encoded_frames = model.encode(wav) 51 | code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] 52 | return code.cpu().numpy().astype(np.int16)[0] 53 | 54 | 55 | def get_mel_spectrogram(audio): 56 | mel = log_mel_spectrogram(audio) 57 | return mel.numpy() 58 | -------------------------------------------------------------------------------- /vc_lm/vc_engine.py: -------------------------------------------------------------------------------- 1 | from encodec import EncodecModel 2 | import torch 3 | import math 4 | from vc_lm.models.ar_model_pl import ARModelPL 5 | from vc_lm.models.nar_model_pl import NARModelPL 6 | from vc_lm.utils.data_utils import get_code, get_mel_spectrogram, pad_or_trim 7 | import soundfile as sf 8 | import librosa 9 | 10 | class VCEngine(object): 11 | def __init__(self, 12 | ar_model_path: str, 13 | nar_model_path: str, 14 | ar_config_file: str, 15 | nar_config_file: str, 16 | device: str = 'cuda:0'): 17 | self.device = device 18 | # Load AR model. 19 | ar_model = ARModelPL.load_from_checkpoint(ar_model_path, 20 | config_file=ar_config_file) 21 | self.ar_model = ar_model.eval().to(self.device).model 22 | # Load NAR model. 
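        # As with the AR checkpoint above, the Lightning wrapper is only used for
        # loading: `.eval().to(device)` is applied first and `.model` then unwraps
        # the underlying NARModel, so inference below runs on the bare module.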
23 | nar_model = NARModelPL.load_from_checkpoint(nar_model_path, 24 | config_file=nar_config_file) 25 | self.nar_model = nar_model.eval().to(self.device).model 26 | 27 | self.config = self.nar_model.config 28 | # load encodec model 29 | encodec_model = EncodecModel.encodec_model_24khz() 30 | encodec_model.set_target_bandwidth(6.0) 31 | self.encodec_model = encodec_model.to(self.device) 32 | self.max_mel_audio_time = 30 33 | self.max_mel_len = 100 * self.max_mel_audio_time 34 | self.max_content_len = math.ceil(self.max_mel_len/2) 35 | 36 | def process_ar(self, content_mel, content_code, style_mel, style_code): 37 | # Process ARModel 38 | style_code_len = style_code.shape[1] 39 | total_code_len = style_code.shape[1] + content_code.shape[1] 40 | # (80, len) 41 | content_mel = torch.cat([style_mel, content_mel], 1) 42 | mel_len = content_mel.shape[1] 43 | content_mask = torch.lt(torch.arange(0, self.max_content_len), math.ceil(mel_len//2)).type(torch.long).to(self.device) 44 | content_mel = pad_or_trim(content_mel, self.max_mel_len) 45 | # (style_code_len,) 46 | style_code = style_code[0] 47 | # batch input data 48 | content_mel = content_mel.unsqueeze(0) 49 | content_mask = content_mask.unsqueeze(0) 50 | style_code = style_code.unsqueeze(0) 51 | 52 | with torch.no_grad(): 53 | outputs = self.ar_model.generate(content_mel, 54 | attention_mask=content_mask, 55 | decoder_input_ids=style_code, 56 | min_length=total_code_len+1, 57 | max_length=total_code_len+1) 58 | return outputs[0, style_code_len:total_code_len] 59 | 60 | def process_nar(self, content_mel, style_code, codes_0): 61 | style_code = style_code[:, 0:75 * 3] 62 | # codes_0: (code_len,) 63 | mel_len = content_mel.shape[1] 64 | content_mask = torch.lt(torch.arange(0, self.max_content_len), math.ceil(mel_len//2)).type(torch.long).to(self.device) 65 | content_mel = pad_or_trim(content_mel, self.max_mel_len) 66 | # 67 | content_mel = content_mel.unsqueeze(0) 68 | content_mask = content_mask.unsqueeze(0) 69 | style_code = style_code.unsqueeze(0) 70 | target_len = codes_0.shape[0] 71 | target_mask = torch.ones((1, target_len), dtype=torch.int64).to(self.device) 72 | # 73 | encoder_outputs = None 74 | codes_list = [codes_0] 75 | 76 | for i in range(0, self.config.n_q - 1): 77 | # prepare data. 
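            # codes_list holds every codebook level predicted so far: level 0 from the
            # AR stage (or copied from the source audio when use_ar=False) plus the
            # levels produced by earlier NAR iterations. Stacking yields a tensor of
            # shape (levels_so_far, code_len); the [None] below adds the batch dimension.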
78 | decoder_input_ids = torch.stack(codes_list, 0) 79 | # (1, n_q, code_len) 80 | decoder_input_ids = decoder_input_ids[None] 81 | nar_stage = torch.LongTensor([i]).to(self.device) 82 | _, logits = self.nar_model(input_ids=content_mel, 83 | attention_mask=content_mask, 84 | decoder_input_ids=decoder_input_ids, 85 | decoder_attention_mask=target_mask, 86 | encoder_outputs=encoder_outputs, 87 | style_code=style_code, 88 | nar_stage=nar_stage) 89 | preds = torch.argmax(logits, dim=-1) 90 | codes_list.append(preds[0]) 91 | full_codes = torch.stack(codes_list, 0) 92 | return full_codes 93 | 94 | def process_audio(self, content_audio: str, style_audio: str, 95 | max_style_len=3, max_content_len=15, use_ar=True): 96 | dtype = torch.float32 97 | # (80, content_mel_len) 98 | content_mel = torch.tensor(get_mel_spectrogram(content_audio), dtype=dtype).to(self.device)[:, 0:100 * max_content_len] 99 | # (n_q, content_code_len) 100 | content_code = torch.tensor(get_code(content_audio, 'cuda:0', self.encodec_model), dtype=torch.int64).to(self.device)[:, 0:75 * max_content_len] 101 | # (80, style_mel_len) 102 | style_mel = torch.tensor(get_mel_spectrogram(style_audio), dtype=dtype).to(self.device)[:, 0:100 * max_style_len] 103 | # (n_q, style_code_len) 104 | style_code = torch.tensor(get_code(style_audio, 'cuda:0', self.encodec_model), dtype=torch.int64).to(self.device)[:, 0:75 * max_style_len] 105 | # Process ARModel 106 | if use_ar: 107 | codes_0 = self.process_ar(content_mel, content_code, style_mel, style_code) 108 | else: 109 | codes_0 = content_code[0] 110 | # Process NARModel 111 | full_codes = self.process_nar(content_mel, style_code, codes_0) 112 | # Decode encodec 113 | with torch.no_grad(): 114 | outputs = self.encodec_model.decode([(full_codes.unsqueeze(0), None)]) 115 | return outputs[0, 0].detach().cpu().numpy() 116 | 117 | class VCEngineDataFactory(object): 118 | def __init__(self, 119 | ar_model_path: str, 120 | nar_model_path: str, 121 | ar_config_file: str, 122 | nar_config_file: str, 123 | device: str = 'cuda:0'): 124 | self.device = device 125 | # Load AR model. 126 | ar_model = ARModelPL.load_from_checkpoint(ar_model_path, 127 | config_file=ar_config_file) 128 | self.ar_model = ar_model.eval().to(self.device).model 129 | # Load NAR model. 
130 | nar_model = NARModelPL.load_from_checkpoint(nar_model_path, 131 | config_file=nar_config_file) 132 | self.nar_model = nar_model.eval().to(self.device).model 133 | 134 | self.config = self.nar_model.config 135 | # load encodec model 136 | encodec_model = EncodecModel.encodec_model_24khz() 137 | encodec_model.set_target_bandwidth(6.0) 138 | self.encodec_model = encodec_model.to(self.device) 139 | self.max_mel_audio_time = 30 140 | self.max_mel_len = 100 * self.max_mel_audio_time 141 | self.max_content_len = math.ceil(self.max_mel_len/2) 142 | 143 | def process_ar(self, content_mel, content_code, style_mel, style_code): 144 | # Process ARModel 145 | style_code_len = style_code.shape[1] 146 | total_code_len = style_code.shape[1] + content_code.shape[1] 147 | # (80, len) 148 | content_mel = torch.cat([style_mel, content_mel], 1) 149 | mel_len = content_mel.shape[1] 150 | content_mask = torch.lt(torch.arange(0, self.max_content_len), math.ceil(mel_len//2)).type(torch.long).to(self.device) 151 | content_mel = pad_or_trim(content_mel, self.max_mel_len) 152 | # (style_code_len,) 153 | style_code = style_code[0] 154 | # batch input data 155 | content_mel = content_mel.unsqueeze(0) 156 | content_mask = content_mask.unsqueeze(0) 157 | style_code = style_code.unsqueeze(0) 158 | 159 | with torch.no_grad(): 160 | outputs = self.ar_model.generate(content_mel, 161 | attention_mask=content_mask, 162 | decoder_input_ids=style_code, 163 | min_length=total_code_len+1, 164 | max_length=total_code_len+1) 165 | return outputs[0, style_code_len:total_code_len] 166 | 167 | def process_nar(self, content_mel, style_code, codes_0): 168 | style_code = style_code[:, 0:75 * 3] 169 | # codes_0: (code_len,) 170 | mel_len = content_mel.shape[1] 171 | content_mask = torch.lt(torch.arange(0, self.max_content_len), math.ceil(mel_len//2)).type(torch.long).to(self.device) 172 | content_mel = pad_or_trim(content_mel, self.max_mel_len) 173 | # 174 | content_mel = content_mel.unsqueeze(0) 175 | content_mask = content_mask.unsqueeze(0) 176 | style_code = style_code.unsqueeze(0) 177 | target_len = codes_0.shape[0] 178 | target_mask = torch.ones((1, target_len), dtype=torch.int64).to(self.device) 179 | # 180 | encoder_outputs = None 181 | codes_list = [codes_0] 182 | 183 | for i in range(0, self.config.n_q - 1): 184 | # prepare data. 
185 | decoder_input_ids = torch.stack(codes_list, 0) 186 | # (1, n_q, code_len) 187 | decoder_input_ids = decoder_input_ids[None] 188 | nar_stage = torch.LongTensor([i]).to(self.device) 189 | _, logits = self.nar_model(input_ids=content_mel, 190 | attention_mask=content_mask, 191 | decoder_input_ids=decoder_input_ids, 192 | decoder_attention_mask=target_mask, 193 | encoder_outputs=encoder_outputs, 194 | style_code=style_code, 195 | nar_stage=nar_stage) 196 | preds = torch.argmax(logits, dim=-1) 197 | codes_list.append(preds[0]) 198 | full_codes = torch.stack(codes_list, 0) 199 | return full_codes 200 | 201 | def process_audio(self, mel1, code1, mel2, code2, 202 | max_style_len=3, max_content_len=21): 203 | dtype = torch.float32 204 | mel1 = torch.tensor(mel1, dtype=dtype).to(self.device) 205 | code1 = torch.tensor(code1, dtype=torch.int64).to(self.device) 206 | mel2 = torch.tensor(mel2, dtype=dtype).to(self.device) 207 | code2 = torch.tensor(code2, dtype=torch.int64).to(self.device) 208 | content_mel = mel1[:, 0:100 * max_content_len] 209 | content_code = code1[:, 0:75 * max_content_len] 210 | style_mel = mel2[:, 0:100 * max_style_len] 211 | style_code = code2[:, 0:75 * max_style_len] 212 | # Process ARModel 213 | codes_0 = self.process_ar(content_mel, content_code, style_mel, style_code) 214 | # Process NARModel 215 | full_codes = self.process_nar(content_mel, style_code, codes_0) 216 | code_alpha = full_codes.unsqueeze(0).detach() 217 | # Decode encodec 218 | with torch.no_grad(): 219 | outputs = self.encodec_model.decode([(full_codes.unsqueeze(0), None)]) 220 | wav0_gen = outputs[0, 0].detach().cpu().numpy() 221 | wav0_gen = librosa.resample(wav0_gen, orig_sr=24000, target_sr=16000) 222 | mel1_alpha = get_mel_spectrogram(wav0_gen) 223 | return { 224 | 'mel': mel1, 225 | 'code': code1, 226 | 'mel_alpha': mel1_alpha, 227 | 'wav_alpha': wav0_gen, 228 | 'code_alpha': code_alpha[0] 229 | } 230 | 231 | def process_multistep_audio(self, mel1, code1, mel2, code2, max_style_len=3, max_content_len=21, 232 | step_num=3): 233 | mel_alpha, code_alpha = mel1, code1 234 | outputs_list = [] 235 | for i in range(step_num): 236 | outputs = self.process_audio(mel_alpha, code_alpha, mel2, code2, max_style_len, max_content_len) 237 | mel_alpha, code_alpha = outputs['mel_alpha'], outputs['code_alpha'] 238 | outputs_list.append(outputs) 239 | return outputs_list 240 | --------------------------------------------------------------------------------
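A short usage note on the stage-conditioned layer norm defined in `vc_lm/models/misc.py`: `StageAdaLN` wraps a shared `nn.LayerNorm` with one learned scale/shift pair per codebook stage, which is presumably what backs the `layernorm_embedding(hidden_states, nar_stage)` call in the NAR decoder above and lets a single decoder serve every residual stage. The sketch below is a minimal, self-contained illustration of that behaviour; the feature size of 16 and the 8 stages are illustrative assumptions rather than values read from the shipped configs, and it assumes the repository root is on `PYTHONPATH`.

```
import torch
from torch import nn

from vc_lm.models.misc import StageAdaLN

# Wrap one shared LayerNorm with 8 stage-specific scale/shift parameter sets.
# Both the 8 stages and the feature size of 16 are illustrative values.
ada_ln = StageAdaLN(nn.LayerNorm(16), num_stage=8)

# Two sequences in a batch, each conditioned on a different NAR stage id.
x = torch.randn(2, 5, 16)        # (batch_size, seq_len, dim)
stage_id = torch.tensor([0, 3])  # (batch_size,)

y = ada_ln(x, stage_id)          # same shape as x: (2, 5, 16)
print(y.shape)
```

Each entry of `stage_id` selects its own affine parameters, so examples at different NAR stages can share a single forward pass.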