├── README.md ├── bins └── tta │ └── inference.py ├── config ├── audioldm.json ├── autoencoderkl.json └── base.json ├── egs └── tta │ ├── audiolfm │ ├── exp_config_base.json │ ├── exp_config_lfm.json │ └── run_inference.sh │ └── autoencoderkl │ ├── exp_config.json │ └── exp_config_base.json ├── imgs └── lafma.png ├── models ├── __init__.py └── tta │ ├── autoencoder │ ├── __init__.py │ └── autoencoder.py │ ├── hifigan │ ├── LICENSE │ ├── __init__.py │ ├── models.py │ └── models_hifires.py │ └── lfm │ ├── __init__.py │ ├── attention.py │ ├── audioldm.py │ ├── audiolfm_inference.py │ └── fm_scheduler.py ├── modules ├── __init__.py └── distributions │ ├── __init__.py │ └── distributions.py ├── requirements.txt └── utils ├── HyperParams ├── __init__.py └── hps.py ├── __init__.py ├── data_utils.py ├── hparam.py ├── io.py ├── tensor_utils.py ├── util.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # LAFMA 2 | Official implementation of the paper "LAFMA: A Latent Flow Matching Model for Text-to-Audio Generation" (INTERSPEECH 2024). [Paper Link](https://arxiv.org/pdf/2406.08203.pdf) and [Demo Page](https://lafma.github.io). 3 | 4 | ![](imgs/lafma.png) 5 | 6 | ## Checkpoints 7 | 8 | [VAEGAN Model](https://drive.google.com/file/d/1FRTMxcKHafTcDvEK-c_zRYckF25-UbjQ/view?usp=drive_link): 9 | The VAEGAN model is the audio VAE that compresses the audio mel-spectrogram into an audio latent. 10 | 11 | [LAFMA Model](https://drive.google.com/file/d/1lpX8rN1GvDar4quoLfofI0UireVmuHay/view?usp=drive_link): 12 | The LAFMA model is the latent flow matching model for text-guided audio generation. 13 | 14 | We use the HiFi-GAN vocoder checkpoint provided by [AudioLDM](https://zenodo.org/records/7884686). 15 | 16 | ## Inference 17 | ``` 18 | # install dependencies 19 | pip install -r requirements.txt 20 | 21 | # infer 22 | (first download the Hugging Face flan-t5-large model into the huggingface/flan-t5-large dir) 23 | (replace checkpoint_path in the .sh file with your own checkpoint path) 24 | cd LAFMA 25 | sh egs/tta/audiolfm/run_inference.sh 26 | ``` 27 | ## Acknowledgements 28 | - [Amphion](https://github.com/open-mmlab/Amphion) 29 | - [Fabric](https://github.com/Lightning-AI/pytorch-lightning) 30 | - [AudioLDM](https://github.com/haoheliu/AudioLDM) 31 | - [Flow Matching](https://github.com/atong01/conditional-flow-matching/tree/main) 32 | 33 | 34 | ## Citation 35 | ``` 36 | @misc{guan2024lafma, 37 | title={LAFMA: A Latent Flow Matching Model for Text-to-Audio Generation}, 38 | author={Wenhao Guan and Kaidi Wang and Wangjin Zhou and Yang Wang and Feng Deng and Hui Wang and Lin Li and Qingyang Hong and Yong Qin}, 39 | year={2024}, 40 | eprint={2406.08203}, 41 | archivePrefix={arXiv}, 42 | primaryClass={eess.AS} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /bins/tta/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
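# Usage note (illustrative, not part of the original file): this script is normally
# launched via egs/tta/audiolfm/run_inference.sh; with PYTHONPATH set to the
# repository root, an equivalent direct call (the checkpoint path is a placeholder) is:
#   python bins/tta/inference.py --config egs/tta/audiolfm/exp_config_lfm.json \
#       --devices 0 --checkpoint_file /path/to/lafma_checkpoint.ckpt \
#       --num_steps 10 --guidance_scale 3 --infer --text "Birds are chirping."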
5 | 6 | import argparse 7 | 8 | from models.tta.lfm.audiolfm_inference import LAFMA_Inference 9 | from utils.util import load_config 10 | 11 | 12 | def build_inference(cfg): 13 | supported_inference = {"AudioLFM": LAFMA_Inference} 14 | 15 | inference = supported_inference[cfg.train.project]( 16 | cfg, precision=cfg.train.precision 17 | ) 18 | 19 | return inference 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument( 26 | "--config", 27 | default="config.json", 28 | help="json files for configurations.", 29 | required=True, 30 | ) 31 | 32 | parser.add_argument( 33 | "--devices", nargs="+", type=int, default=None, help="gpu devices." 34 | ) 35 | 36 | parser.add_argument("--batch_size", type=int, default=None, help="batch size.") 37 | 38 | parser.add_argument( 39 | "--infer", action="store_true", default=False, help="test mode." 40 | ) 41 | 42 | parser.add_argument( 43 | "--text", 44 | help="Text to be synthesized", 45 | type=str, 46 | default="", 47 | ) 48 | parser.add_argument( 49 | "--checkpoint_file", 50 | type=str, 51 | default="final_checkpoint.ckpt", 52 | help="Checkpoint for test.(only test)", 53 | ) 54 | parser.add_argument( 55 | "--num_steps", 56 | type=int, 57 | default=200, 58 | help="The total number of denosing steps", 59 | ) 60 | parser.add_argument( 61 | "--guidance_scale", 62 | type=float, 63 | default=3.0, 64 | help="The scale of classifer free guidance", 65 | ) 66 | args = parser.parse_args() 67 | config = load_config(args.config) 68 | if args.infer: 69 | config.checkpoint_file = args.checkpoint_file 70 | config.infer = args.infer 71 | if args.batch_size is not None: 72 | config.train.batch_size = args.batch_size 73 | if args.num_steps is not None: 74 | config.num_steps = args.num_steps 75 | if args.guidance_scale is not None: 76 | config.guidance_scale = args.guidance_scale 77 | if args.text is not None: 78 | config.infer_text = args.text 79 | if args.devices is not None: 80 | config.train.devices = args.devices 81 | 82 | if config.infer: 83 | inferencer = build_inference(config) 84 | inferencer.test() 85 | 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /config/audioldm.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_config": "config/base.json", 3 | "model_type": "AudioLDM", 4 | "task_type": "tta", 5 | "dataset": [ 6 | "AudioCaps" 7 | ], 8 | "preprocess": { 9 | // feature used for model training 10 | "use_spkid": false, 11 | "use_uv": false, 12 | "use_frame_pitch": false, 13 | "use_phone_pitch": false, 14 | "use_frame_energy": false, 15 | "use_phone_energy": false, 16 | "use_mel": false, 17 | "use_audio": false, 18 | "use_label": false, 19 | "use_one_hot": false, 20 | "cond_mask_prob": 0.1 21 | }, 22 | // model 23 | "model": { 24 | "audioldm": { 25 | "image_size": 32, 26 | "in_channels": 4, 27 | "out_channels": 4, 28 | "model_channels": 256, 29 | "attention_resolutions": [ 30 | 4, 31 | 2, 32 | 1 33 | ], 34 | "num_res_blocks": 2, 35 | "channel_mult": [ 36 | 1, 37 | 2, 38 | 4 39 | ], 40 | "num_heads": 8, 41 | "use_spatial_transformer": true, 42 | "transformer_depth": 1, 43 | "context_dim": 768, 44 | "use_checkpoint": true, 45 | "legacy": false 46 | }, 47 | "autoencoderkl": { 48 | "ch": 128, 49 | "ch_mult": [ 50 | 1, 51 | 1, 52 | 2, 53 | 2, 54 | 4 55 | ], 56 | "num_res_blocks": 2, 57 | "in_channels": 1, 58 | "z_channels": 4, 59 | "out_ch": 1, 60 | "double_z": true 61 | }, 62 | 
"noise_scheduler": { 63 | "num_train_timesteps": 1000, 64 | "beta_start": 0.00085, 65 | "beta_end": 0.012, 66 | "beta_schedule": "scaled_linear", 67 | "clip_sample": false, 68 | "steps_offset": 1, 69 | "set_alpha_to_one": false, 70 | "skip_prk_steps": true, 71 | "prediction_type": "epsilon" 72 | } 73 | }, 74 | // train 75 | "train": { 76 | "lronPlateau": { 77 | "factor": 0.9, 78 | "patience": 100, 79 | "min_lr": 4.0e-5, 80 | "verbose": true 81 | }, 82 | "adam": { 83 | "lr": 5.0e-5, 84 | "betas": [ 85 | 0.9, 86 | 0.999 87 | ], 88 | "weight_decay": 1.0e-2, 89 | "eps": 1.0e-8 90 | } 91 | } 92 | } -------------------------------------------------------------------------------- /config/autoencoderkl.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_config": "config/base.json", 3 | "model_type": "AutoencoderKL", 4 | "task_type": "tta", 5 | "dataset": [ 6 | "AudioCaps" 7 | ], 8 | "preprocess": { 9 | // feature used for model training 10 | "use_spkid": false, 11 | "use_uv": false, 12 | "use_frame_pitch": false, 13 | "use_phone_pitch": false, 14 | "use_frame_energy": false, 15 | "use_phone_energy": false, 16 | "use_mel": false, 17 | "use_audio": false, 18 | "use_label": false, 19 | "use_one_hot": false, 20 | // Settings for data preprocessing 21 | "n_mel": 64, 22 | "win_size": 1024, 23 | "hop_size": 160, 24 | "sample_rate": 16000, 25 | "n_fft": 1024, 26 | "fmin": 0, 27 | "fmax": 8000, 28 | // "bits": 8 29 | }, 30 | // model 31 | "model": { 32 | "autoencoderkl": { 33 | "ch": 128, 34 | "ch_mult": [ 35 | 1, 36 | 2, 37 | 4 38 | ], 39 | "num_res_blocks": 2, 40 | "in_channels": 1, 41 | "z_channels": 8, 42 | "out_ch": 1, 43 | "double_z": true 44 | }, 45 | "loss": { 46 | "kl_weight": 1e-8, 47 | "disc_weight": 0.5, 48 | "disc_factor": 1.0, 49 | "logvar_init": 0.0, 50 | "disc_start": 20001, 51 | "disc_in_channels": 1, 52 | "disc_num_layers": 3, 53 | "use_actnorm": false 54 | } 55 | }, 56 | // train 57 | "train": { 58 | "adam": { 59 | "lr": 4.0e-5 60 | } 61 | } 62 | } -------------------------------------------------------------------------------- /config/base.json: -------------------------------------------------------------------------------- 1 | { 2 | "task_type": "", 3 | "dataset": [], 4 | "use_custom_dataset": [], 5 | "preprocess": { 6 | "phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon" 7 | // trim audio silence 8 | "data_augment": false, 9 | "trim_silence": false, 10 | "num_silent_frames": 8, 11 | "trim_fft_size": 512, // fft size used in trimming 12 | "trim_hop_size": 128, // hop size used in trimming 13 | "trim_top_db": 30, // top db used in trimming sensitive to each dataset 14 | // acoustic features 15 | "extract_mel": false, 16 | "mel_extract_mode": "", 17 | "extract_linear_spec": false, 18 | "extract_mcep": false, 19 | "extract_pitch": false, 20 | "extract_acoustic_token": false, 21 | "pitch_remove_outlier": false, 22 | "extract_uv": false, 23 | "pitch_norm": false, 24 | "extract_audio": false, 25 | "extract_label": false, 26 | "pitch_extractor": "parselmouth", // pyin, dio, pyworld, pyreaper, parselmouth, CWT (Continuous Wavelet Transform) 27 | "extract_energy": false, 28 | "energy_remove_outlier": false, 29 | "energy_norm": false, 30 | "energy_extract_mode": "from_mel", 31 | "extract_duration": false, 32 | "extract_amplitude_phase": false, 33 | "mel_min_max_norm": false, 34 | // lingusitic features 35 | "extract_phone": false, 36 | "lexicon_path": "./text/lexicon/librispeech-lexicon.txt", 37 | // content 
features 38 | "extract_whisper_feature": false, 39 | "extract_contentvec_feature": false, 40 | "extract_mert_feature": false, 41 | "extract_wenet_feature": false, 42 | // Settings for data preprocessing 43 | "n_mel": 80, 44 | "win_size": 480, 45 | "hop_size": 120, 46 | "sample_rate": 24000, 47 | "n_fft": 1024, 48 | "fmin": 0, 49 | "fmax": 12000, 50 | // "min_level_db": -115, 51 | // "ref_level_db": 20, 52 | // "bits": 8, 53 | // Directory names of processed data or extracted features 54 | "processed_dir": "processed_data", 55 | "trimmed_wav_dir": "trimmed_wavs", // directory name of silence trimed wav 56 | "raw_data": "raw_data", 57 | "phone_dir": "phones", 58 | "wav_dir": "wavs", // directory name of processed wav (such as downsampled waveform) 59 | "audio_dir": "audios", 60 | "log_amplitude_dir": "log_amplitudes", 61 | "phase_dir": "phases", 62 | "real_dir": "reals", 63 | "imaginary_dir": "imaginarys", 64 | "label_dir": "labels", 65 | "linear_dir": "linears", 66 | "mel_dir": "mels", // directory name of extraced mel features 67 | "mcep_dir": "mcep", // directory name of extraced mcep features 68 | "dur_dir": "durs", 69 | "symbols_dict": "symbols.dict", 70 | "lab_dir": "labs", // directory name of extraced label features 71 | "wenet_dir": "wenet", // directory name of extraced wenet features 72 | "contentvec_dir": "contentvec", // directory name of extraced wenet features 73 | "pitch_dir": "pitches", // directory name of extraced pitch features 74 | "energy_dir": "energys", // directory name of extracted energy features 75 | "phone_pitch_dir": "phone_pitches", // directory name of extraced pitch features 76 | "phone_energy_dir": "phone_energys", // directory name of extracted energy features 77 | "uv_dir": "uvs", // directory name of extracted unvoiced features 78 | "duration_dir": "duration", // ground-truth duration file 79 | "phone_seq_file": "phone_seq_file", // phoneme sequence file 80 | "file_lst": "file.lst", 81 | "train_file": "train.json", // training set, the json file contains detailed information about the dataset, including dataset name, utterance id, duration of the utterance 82 | "valid_file": "valid.json", // validattion set 83 | "spk2id": "spk2id.json", // used for multi-speaker dataset 84 | "utt2spk": "utt2spk", // used for multi-speaker dataset 85 | "emo2id": "emo2id.json", // used for multi-emotion dataset 86 | "utt2emo": "utt2emo", // used for multi-emotion dataset 87 | // Features used for model training 88 | "use_text": false, 89 | "use_phone": false, 90 | "use_phn_seq": false, 91 | "use_lab": false, 92 | "use_linear": false, 93 | "use_mel": false, 94 | "use_min_max_norm_mel": false, 95 | "use_wav": false, 96 | "use_phone_pitch": false, 97 | "use_log_scale_pitch": false, 98 | "use_phone_energy": false, 99 | "use_phone_duration": false, 100 | "use_log_scale_energy": false, 101 | "use_wenet": false, 102 | "use_dur": false, 103 | "use_spkid": false, // True: use speaker id for multi-speaker dataset 104 | "use_emoid": false, // True: use emotion id for multi-emotion dataset 105 | "use_frame_pitch": false, 106 | "use_uv": false, 107 | "use_frame_energy": false, 108 | "use_frame_duration": false, 109 | "use_audio": false, 110 | "use_label": false, 111 | "use_one_hot": false, 112 | "use_amplitude_phase": false, 113 | "align_mel_duration": false 114 | }, 115 | "train": { 116 | "ddp": true, 117 | "batch_size": 16, 118 | "max_steps": 1000000, 119 | // Trackers 120 | "tracker": [ 121 | "tensorboard" 122 | // "wandb", 123 | // "cometml", 124 | // "mlflow", 125 | ], 126 | 
"max_epoch": -1, 127 | // -1 means no limit 128 | "save_checkpoint_stride": [ 129 | 5, 130 | 20 131 | ], 132 | // unit is epoch 133 | "keep_last": [ 134 | 3, 135 | -1 136 | ], 137 | // -1 means infinite, if one number will broadcast 138 | "run_eval": [ 139 | false, 140 | true 141 | ], 142 | // if one number will broadcast 143 | // Fix the random seed 144 | "random_seed": 10086, 145 | // Optimizer 146 | "optimizer": "AdamW", 147 | "adamw": { 148 | "lr": 4.0e-4 149 | // nn model lr 150 | }, 151 | // LR Scheduler 152 | "scheduler": "ReduceLROnPlateau", 153 | "reducelronplateau": { 154 | "factor": 0.8, 155 | "patience": 10, 156 | // unit is epoch 157 | "min_lr": 1.0e-4 158 | }, 159 | // Batchsampler 160 | "sampler": { 161 | "holistic_shuffle": true, 162 | "drop_last": true 163 | }, 164 | // Dataloader 165 | "dataloader": { 166 | "num_worker": 32, 167 | "pin_memory": true 168 | }, 169 | "gradient_accumulation_step": 1, 170 | "total_training_steps": 50000, 171 | "save_summary_steps": 500, 172 | "save_checkpoints_steps": 10000, 173 | "valid_interval": 10000, 174 | "keep_checkpoint_max": 5, 175 | "multi_speaker_training": false // True: train multi-speaker model; False: training single-speaker model; 176 | } 177 | } -------------------------------------------------------------------------------- /egs/tta/audiolfm/exp_config_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_config": "config/audioldm.json", 3 | "model_type": "AudioLDM", 4 | "dataset": [ 5 | "AudioCaps" 6 | ], 7 | "preprocess": { 8 | "train_file": "train.json", 9 | "valid_file": "test.json" 10 | } 11 | } -------------------------------------------------------------------------------- /egs/tta/audiolfm/exp_config_lfm.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_config": "egs/tta/audiolfm/exp_config_base.json", 3 | "dataset": [ 4 | "AudioCaps" 5 | ], 6 | "preprocess": { 7 | // Specify the output root path to save the processed data 8 | "processed_dir": "/work/gwh/Amphion/processed_data/tta", 9 | // feature 10 | "use_spkid": false, 11 | "use_uv": false, 12 | "use_frame_pitch": false, 13 | "use_phone_pitch": false, 14 | "use_frame_energy": false, 15 | "use_phone_energy": false, 16 | "use_mel": false, 17 | "use_audio": false, 18 | "use_label": false, 19 | "use_one_hot": false, 20 | // feature for text to audio 21 | "use_caption": true, 22 | "use_melspec": true, 23 | "use_wav": false, 24 | // feature dir 25 | "melspec_dir": "mels", 26 | "wav_dir": "wavs", 27 | // mel features 28 | "n_mel": 64, 29 | "win_size": 1024, 30 | "hop_size": 160, // 24000/120*10=2000frames 2000 31 | "sample_rate": 16000, 32 | "n_fft": 1024, 33 | "fmin": 0, 34 | "fmax": 8000, 35 | "filter_length": 1024, 36 | "target_length": 1024, //10.24*16000/160=1024 37 | "duration": 10.24 38 | }, 39 | // Specify the output root path to save model ckpts and logs 40 | "root_dir": "", 41 | "checkpoint_dir": "ckpts/tta/audiolfm", 42 | "checkpoint_file": "", 43 | "test_out_dir": "test_results", 44 | "num_workers": 16, 45 | "seed": 3369, 46 | "warmup_steps": 1000, 47 | "gamma": 0.95, 48 | //resume 49 | "resume": "", 50 | //infer 51 | "infer": true, 52 | "infer_text": "", 53 | "test_batch_size": 16, 54 | // diff 55 | "num_steps": 200, 56 | "guidance_scale": 3, 57 | // model 58 | "model": { 59 | "audioldm": { 60 | "image_size": 64, 61 | "in_channels": 8, 62 | "out_channels": 8, 63 | "model_channels": 128, 64 | "attention_resolutions": [ 65 | 8, 66 | 4, 67 | 2 68 | ], 69 | 
"num_res_blocks": 2, 70 | "channel_mult": [ 71 | 1, 72 | 2, 73 | 3, 74 | 5 75 | ], 76 | "num_heads": 32, 77 | "use_spatial_transformer": true, 78 | "transformer_depth": 1, 79 | "context_dim": 1024, 80 | "use_checkpoint": true, 81 | "legacy": true, 82 | "extra_sa_layer": true // ### 83 | }, 84 | "autoencoderkl": { 85 | "ch": 128, 86 | "ch_mult": [ 87 | 1, 88 | 2, 89 | 4 90 | ], 91 | "num_res_blocks": 2, 92 | "in_channels": 1, 93 | "z_channels": 8, 94 | "out_ch": 1, 95 | "double_z": true 96 | }, 97 | "autoencoder_path": "/work/gwh/Amphion/ckpts/tta/autoencoder/best_729021.ckpt" 98 | }, 99 | // train 100 | "train": { 101 | "adam": { 102 | "lr": 1.0e-4 103 | }, 104 | "max_steps": 1000000, 105 | "total_training_steps": 800000, 106 | "save_summary_steps": 1000, 107 | "save_checkpoints_steps": 5000, 108 | "valid_interval": 5000, 109 | "keep_checkpoint_max": 100, 110 | "accelerator": "cuda", 111 | "devices": [ 112 | 1, 113 | 2 114 | ], 115 | "strategy": "auto", 116 | "precision": "32-true", //16-mixed 117 | "out_dir": "ckpts/tta/audiolfm", 118 | "batch_size": 8, 119 | "epochs": 60, 120 | "steps": "", // 1000000 121 | "project": "AudioLFM", 122 | "task_type": "tta", 123 | "gradient_accumulation_steps": 1, 124 | "exponential_lr": { 125 | "lr_decay": 0.999 126 | }, 127 | }, 128 | "logger": { 129 | "logger_name": [ 130 | "csv", 131 | "tensorboard" 132 | ], 133 | "log_interval": 10, 134 | "log_per_epoch": "", 135 | "checkpoint_frequency": 2, 136 | "num_checkpoint_keep": 4 137 | } 138 | } -------------------------------------------------------------------------------- /egs/tta/audiolfm/run_inference.sh: -------------------------------------------------------------------------------- 1 | 2 | ######## Build Experiment Environment ########### 3 | exp_dir=$(cd `dirname $0`; pwd) 4 | work_dir=$(dirname $(dirname $(dirname $exp_dir))) 5 | 6 | export WORK_DIR=$work_dir 7 | export PYTHONPATH=$work_dir 8 | export PYTHONIOENCODING=UTF-8 9 | 10 | exp_config="$exp_dir/exp_config_lfm.json" 11 | 12 | 13 | ######## Run inference ########### 14 | python "${work_dir}"/bins/tta/inference.py \ 15 | --config $exp_config \ 16 | --devices 3 \ 17 | --checkpoint_file "checkpoint_path" \ 18 | --num_steps 10 \ 19 | --guidance_scale 3 \ 20 | --infer \ 21 | --text "Birds are chirping." 
22 | -------------------------------------------------------------------------------- /egs/tta/autoencoderkl/exp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_config": "egs/tta/autoencoderkl/exp_config_base.json", 3 | "dataset": [ 4 | "audioset", 5 | "AudioCaps", 6 | "BBC", 7 | "FreeSound", 8 | "SoundBible" 9 | ], 10 | "dataset_path": "", 11 | "preprocess": { 12 | // Specify the output root path to save the processed data 13 | "processed_dir": "/work/gwh/Amphion/processed_data/tta", 14 | // feature 15 | "use_spk": false, 16 | "use_spkid": false, 17 | "use_uv": false, 18 | "use_frame_pitch": false, 19 | "use_phone_pitch": false, 20 | "use_frame_energy": false, 21 | "use_phone_energy": false, 22 | "use_mel": false, 23 | "use_audio": false, 24 | "use_label": false, 25 | "use_one_hot": false, 26 | // feature for text to audio 27 | "use_caption": true, 28 | "use_melspec": true, 29 | "use_wav": false, 30 | // feature dir 31 | "melspec_dir": "mels", 32 | "wav_dir": "wavs", 33 | // mel features 34 | "n_mel": 64, 35 | "win_size": 1024, 36 | "hop_size": 160, // 24000/120*10=2000frames 2000 37 | "sample_rate": 16000, 38 | "n_fft": 1024, 39 | "fmin": 0, 40 | "fmax": 8000, 41 | "filter_length": 1024, 42 | "target_length": 1024, //10.24*16000/160=1024 43 | "duration": 10.24 44 | }, 45 | // Specify the output root path to save model ckpts and logs 46 | "root_dir": "", 47 | "checkpoint_dir": "ckpts/tta/autoencoder", 48 | "checkpoint_file": "", 49 | "test_out_dir": "test_results", 50 | "num_workers": 16, 51 | "seed": 1234, 52 | "warmup_steps": 5000, 53 | "gamma": 0.95, 54 | //resume 55 | "resume": "", 56 | //infer 57 | "infer": false, 58 | // train 59 | "train": { 60 | "accelerator": "cuda", 61 | "devices": [ 62 | 1, 63 | 2 64 | ], 65 | "strategy": "auto", 66 | "precision": "16-mixed", 67 | "out_dir": "ckpts/tta/autoencoder", 68 | "batch_size": 16, 69 | "epochs": 8, 70 | "steps": "", // 1000000 71 | "project": "AutoencoderKL", 72 | "task_type": "tta", 73 | "gradient_accumulation_steps": 1, 74 | "exponential_lr": { 75 | "lr_decay": 0.999 76 | }, 77 | "adam": { 78 | "lr": 1.0e-5, 79 | "betas": [ 80 | 0.5, 81 | 0.9 82 | ], 83 | "weight_decay": 0, 84 | "eps": 1.0e-8 85 | }, 86 | }, 87 | "logger": { 88 | "logger_name": [ 89 | "csv", 90 | "tensorboard" 91 | ], 92 | "log_interval": 10, 93 | "log_per_epoch": "", 94 | "checkpoint_frequency": 1, 95 | "num_checkpoint_keep": 4 96 | } 97 | } -------------------------------------------------------------------------------- /egs/tta/autoencoderkl/exp_config_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_config": "config/autoencoderkl.json", 3 | "model_type": "AutoencoderKL", 4 | "dataset": [ 5 | "AudioCaps" 6 | ], 7 | "preprocess": { 8 | "train_file": "train.json", 9 | "valid_file": "test.json" 10 | } 11 | } -------------------------------------------------------------------------------- /imgs/lafma.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gwh22/LAFMA/0a11ac2d4d176018f0aa0531dd6aca84cc532138/imgs/lafma.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gwh22/LAFMA/0a11ac2d4d176018f0aa0531dd6aca84cc532138/models/__init__.py -------------------------------------------------------------------------------- 
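The `num_steps` and `guidance_scale` entries in the configs above control sampling from the latent flow matching model (implemented in `models/tta/lfm/fm_scheduler.py`, which is not reproduced in this listing). As a hedged sketch of what such a sampler typically does, the snippet below runs Euler integration of a learned velocity field with classifier-free guidance; the function and argument names are illustrative and are not this repository's API.

```
import torch


@torch.no_grad()
def sample_latent(velocity_fn, latent_shape, text_emb, uncond_emb,
                  num_steps=10, guidance_scale=3.0, device="cuda"):
    # Start from Gaussian noise and integrate dx/dt = v(x, t, cond) from t=0 to t=1.
    x = torch.randn(latent_shape, device=device)
    ts = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
    for i in range(num_steps):
        t = ts[i].expand(latent_shape[0])
        v_cond = velocity_fn(x, t, text_emb)      # text-conditioned velocity
        v_uncond = velocity_fn(x, t, uncond_emb)  # unconditional velocity
        # classifier-free guidance applied to the velocity field
        v = v_uncond + guidance_scale * (v_cond - v_uncond)
        x = x + (ts[i + 1] - ts[i]) * v           # Euler update
    # x is an audio latent; it is then decoded by the VAE and vocoded by HiFi-GAN.
    return x
```

With `guidance_scale` equal to 1 the update reduces to the purely text-conditioned velocity; larger values trade diversity for stronger adherence to the text prompt.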
/models/tta/autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gwh22/LAFMA/0a11ac2d4d176018f0aa0531dd6aca84cc532138/models/tta/autoencoder/__init__.py -------------------------------------------------------------------------------- /models/tta/autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from modules.distributions.distributions import DiagonalGaussianDistribution 11 | from taming.modules.losses.vqperceptual import * 12 | 13 | 14 | def nonlinearity(x): 15 | # swish 16 | return x * torch.sigmoid(x) 17 | 18 | 19 | def Normalize(in_channels): 20 | return torch.nn.GroupNorm( 21 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 22 | ) 23 | 24 | 25 | class Upsample2d(nn.Module): 26 | def __init__(self, in_channels, with_conv): 27 | super().__init__() 28 | self.with_conv = with_conv 29 | if self.with_conv: 30 | self.conv = torch.nn.Conv2d( 31 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 32 | ) 33 | 34 | def forward(self, x): 35 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 36 | if self.with_conv: 37 | x = self.conv(x) 38 | return x 39 | 40 | 41 | class Upsample1d(Upsample2d): 42 | def __init__(self, in_channels, with_conv): 43 | super().__init__(in_channels, with_conv) 44 | if self.with_conv: 45 | self.conv = torch.nn.Conv1d( 46 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 47 | ) 48 | 49 | 50 | class Downsample2d(nn.Module): 51 | def __init__(self, in_channels, with_conv): 52 | super().__init__() 53 | self.with_conv = with_conv 54 | if self.with_conv: 55 | # no asymmetric padding in torch conv, must do it ourselves 56 | self.conv = torch.nn.Conv2d( 57 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 58 | ) 59 | self.pad = (0, 1, 0, 1) 60 | else: 61 | self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) 62 | 63 | def forward(self, x): 64 | if self.with_conv: # bp: check self.avgpool and self.pad 65 | x = torch.nn.functional.pad(x, self.pad, mode="constant", value=0) 66 | x = self.conv(x) 67 | else: 68 | x = self.avg_pool(x) 69 | return x 70 | 71 | 72 | class Downsample1d(Downsample2d): 73 | def __init__(self, in_channels, with_conv): 74 | super().__init__(in_channels, with_conv) 75 | if self.with_conv: 76 | # no asymmetric padding in torch conv, must do it ourselves 77 | # TODO: can we replace it just with conv2d with padding 1? 
78 | self.conv = torch.nn.Conv1d( 79 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 80 | ) 81 | self.pad = (1, 1) 82 | else: 83 | self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2) 84 | 85 | 86 | class ResnetBlock(nn.Module): 87 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): 88 | super().__init__() 89 | self.in_channels = in_channels 90 | out_channels = in_channels if out_channels is None else out_channels 91 | self.out_channels = out_channels 92 | self.use_conv_shortcut = conv_shortcut 93 | 94 | self.norm1 = Normalize(in_channels) 95 | self.conv1 = torch.nn.Conv2d( 96 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 97 | ) 98 | 99 | self.norm2 = Normalize(out_channels) 100 | self.dropout = torch.nn.Dropout(dropout) 101 | self.conv2 = torch.nn.Conv2d( 102 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 103 | ) 104 | if self.in_channels != self.out_channels: 105 | if self.use_conv_shortcut: 106 | self.conv_shortcut = torch.nn.Conv2d( 107 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 108 | ) 109 | else: 110 | self.nin_shortcut = torch.nn.Conv2d( 111 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 112 | ) 113 | 114 | def forward(self, x): 115 | h = x 116 | h = self.norm1(h) 117 | h = nonlinearity(h) 118 | h = self.conv1(h) 119 | 120 | h = self.norm2(h) 121 | h = nonlinearity(h) 122 | h = self.dropout(h) 123 | h = self.conv2(h) 124 | 125 | if self.in_channels != self.out_channels: 126 | if self.use_conv_shortcut: 127 | x = self.conv_shortcut(x) 128 | else: 129 | x = self.nin_shortcut(x) 130 | 131 | return x + h 132 | 133 | 134 | class ResnetBlock1d(ResnetBlock): 135 | def __init__( 136 | self, 137 | *, 138 | in_channels, 139 | out_channels=None, 140 | conv_shortcut=False, 141 | dropout, 142 | temb_channels=512 143 | ): 144 | super().__init__( 145 | in_channels=in_channels, 146 | out_channels=out_channels, 147 | conv_shortcut=conv_shortcut, 148 | dropout=dropout, 149 | ) 150 | 151 | self.conv1 = torch.nn.Conv1d( 152 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 153 | ) 154 | self.conv2 = torch.nn.Conv1d( 155 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 156 | ) 157 | if self.in_channels != self.out_channels: 158 | if self.use_conv_shortcut: 159 | self.conv_shortcut = torch.nn.Conv1d( 160 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 161 | ) 162 | else: 163 | self.nin_shortcut = torch.nn.Conv1d( 164 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 165 | ) 166 | 167 | 168 | class Encoder2d(nn.Module): 169 | def __init__( 170 | self, 171 | *, 172 | ch, 173 | ch_mult=(1, 2, 4, 8), 174 | num_res_blocks, 175 | dropout=0.0, 176 | resamp_with_conv=True, 177 | in_channels, 178 | z_channels, 179 | double_z=True, 180 | **ignore_kwargs 181 | ): 182 | super().__init__() 183 | self.ch = ch 184 | self.num_resolutions = len(ch_mult) 185 | self.num_res_blocks = num_res_blocks 186 | self.in_channels = in_channels 187 | 188 | # downsampling 189 | self.conv_in = torch.nn.Conv2d( 190 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 191 | ) 192 | 193 | in_ch_mult = (1,) + tuple(ch_mult) 194 | self.down = nn.ModuleList() 195 | for i_level in range(self.num_resolutions): 196 | block = nn.ModuleList() 197 | block_in = ch * in_ch_mult[i_level] 198 | block_out = ch * ch_mult[i_level] 199 | for i_block in range(self.num_res_blocks): 200 | block.append( 201 | ResnetBlock( 202 | in_channels=block_in, 
out_channels=block_out, dropout=dropout 203 | ) 204 | ) 205 | block_in = block_out 206 | down = nn.Module() 207 | down.block = block 208 | if i_level != self.num_resolutions - 1: 209 | down.downsample = Downsample2d(block_in, resamp_with_conv) 210 | self.down.append(down) 211 | 212 | # middle 213 | self.mid = nn.Module() 214 | self.mid.block_1 = ResnetBlock( 215 | in_channels=block_in, out_channels=block_in, dropout=dropout 216 | ) 217 | self.mid.block_2 = ResnetBlock( 218 | in_channels=block_in, out_channels=block_in, dropout=dropout 219 | ) 220 | 221 | # end 222 | self.norm_out = Normalize(block_in) 223 | self.conv_out = torch.nn.Conv2d( 224 | block_in, 225 | 2 * z_channels if double_z else z_channels, 226 | kernel_size=3, 227 | stride=1, 228 | padding=1, 229 | ) 230 | 231 | def forward(self, x): 232 | # downsampling 233 | hs = [self.conv_in(x)] 234 | for i_level in range(self.num_resolutions): 235 | for i_block in range(self.num_res_blocks): 236 | h = self.down[i_level].block[i_block](hs[-1]) 237 | hs.append(h) 238 | if i_level != self.num_resolutions - 1: 239 | hs.append(self.down[i_level].downsample(hs[-1])) 240 | 241 | # middle 242 | h = hs[-1] 243 | h = self.mid.block_1(h) 244 | h = self.mid.block_2(h) 245 | 246 | # end 247 | h = self.norm_out(h) 248 | h = nonlinearity(h) 249 | h = self.conv_out(h) 250 | return h 251 | 252 | 253 | # TODO: Encoder1d 254 | # class Encoder1d(Encoder2d): ... 255 | 256 | 257 | class Decoder2d(nn.Module): 258 | def __init__( 259 | self, 260 | *, 261 | ch, 262 | out_ch, 263 | ch_mult=(1, 2, 4, 8), 264 | num_res_blocks, 265 | dropout=0.0, 266 | resamp_with_conv=True, 267 | in_channels, 268 | z_channels, 269 | give_pre_end=False, 270 | **ignorekwargs 271 | ): 272 | super().__init__() 273 | self.ch = ch 274 | self.num_resolutions = len(ch_mult) 275 | self.num_res_blocks = num_res_blocks 276 | self.in_channels = in_channels 277 | self.give_pre_end = give_pre_end 278 | 279 | # compute in_ch_mult, block_in and curr_res at lowest res 280 | in_ch_mult = (1,) + tuple(ch_mult) 281 | block_in = ch * ch_mult[self.num_resolutions - 1] 282 | # self.z_shape = (1,z_channels,curr_res,curr_res) 283 | # print("Working with z of shape {} = {} dimensions.".format( 284 | # self.z_shape, np.prod(self.z_shape))) 285 | 286 | # z to block_in 287 | self.conv_in = torch.nn.Conv2d( 288 | z_channels, block_in, kernel_size=3, stride=1, padding=1 289 | ) 290 | 291 | # middle 292 | self.mid = nn.Module() 293 | self.mid.block_1 = ResnetBlock( 294 | in_channels=block_in, out_channels=block_in, dropout=dropout 295 | ) 296 | self.mid.block_2 = ResnetBlock( 297 | in_channels=block_in, out_channels=block_in, dropout=dropout 298 | ) 299 | 300 | # upsampling 301 | self.up = nn.ModuleList() 302 | for i_level in reversed(range(self.num_resolutions)): 303 | block = nn.ModuleList() 304 | attn = nn.ModuleList() 305 | block_out = ch * ch_mult[i_level] 306 | for i_block in range(self.num_res_blocks + 1): 307 | block.append( 308 | ResnetBlock( 309 | in_channels=block_in, out_channels=block_out, dropout=dropout 310 | ) 311 | ) 312 | block_in = block_out 313 | up = nn.Module() 314 | up.block = block 315 | up.attn = attn 316 | if i_level != 0: 317 | up.upsample = Upsample2d(block_in, resamp_with_conv) 318 | self.up.insert(0, up) # prepend to get consistent order 319 | 320 | # end 321 | self.norm_out = Normalize(block_in) 322 | self.conv_out = torch.nn.Conv2d( 323 | block_in, out_ch, kernel_size=3, stride=1, padding=1 324 | ) 325 | 326 | def forward(self, z): 327 | self.last_z_shape = z.shape 328 | 329 | # z 
to block_in 330 | h = self.conv_in(z) 331 | 332 | # middle 333 | h = self.mid.block_1(h) 334 | h = self.mid.block_2(h) 335 | 336 | # upsampling 337 | for i_level in reversed(range(self.num_resolutions)): 338 | for i_block in range(self.num_res_blocks + 1): 339 | h = self.up[i_level].block[i_block](h) 340 | if i_level != 0: 341 | h = self.up[i_level].upsample(h) 342 | 343 | # end 344 | if self.give_pre_end: 345 | return h 346 | 347 | h = self.norm_out(h) 348 | h = nonlinearity(h) 349 | h = self.conv_out(h) 350 | return h 351 | 352 | 353 | class AutoencoderKL(nn.Module): 354 | def __init__(self, cfg): 355 | super().__init__() 356 | self.cfg = cfg 357 | self.encoder = Encoder2d( 358 | ch=cfg.ch, 359 | ch_mult=cfg.ch_mult, 360 | num_res_blocks=cfg.num_res_blocks, 361 | in_channels=cfg.in_channels, 362 | z_channels=cfg.z_channels, 363 | double_z=cfg.double_z, 364 | ) 365 | self.decoder = Decoder2d( 366 | ch=cfg.ch, 367 | ch_mult=cfg.ch_mult, 368 | num_res_blocks=cfg.num_res_blocks, 369 | out_ch=cfg.out_ch, 370 | z_channels=cfg.z_channels, 371 | in_channels=None, 372 | ) 373 | assert self.cfg.double_z 374 | 375 | self.quant_conv = torch.nn.Conv2d(2 * cfg.z_channels, 2 * cfg.z_channels, 1) 376 | self.post_quant_conv = torch.nn.Conv2d(cfg.z_channels, cfg.z_channels, 1) 377 | self.embed_dim = cfg.z_channels 378 | 379 | def encode(self, x): 380 | h = self.encoder(x) 381 | moments = self.quant_conv(h) 382 | posterior = DiagonalGaussianDistribution(moments) 383 | return posterior 384 | 385 | def decode(self, z): 386 | z = self.post_quant_conv(z) 387 | dec = self.decoder(z) 388 | return dec 389 | 390 | def forward(self, input, sample_posterior=True): 391 | posterior = self.encode(input) 392 | if sample_posterior: 393 | z = posterior.sample() 394 | else: 395 | z = posterior.mode() 396 | dec = self.decode(z) 397 | return dec, posterior 398 | 399 | def get_last_layer(self): 400 | return self.decoder.conv_out.weight 401 | -------------------------------------------------------------------------------- /models/tta/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /models/tta/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_hifires import Generator_HiFiRes 2 | from .models import Generator as Generator 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | -------------------------------------------------------------------------------- /models/tta/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2**i), 129 | 
h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) 175 | -------------------------------------------------------------------------------- /models/tta/hifigan/models_hifires.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock1(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock1, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 
| xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class ResBlock2(torch.nn.Module): 113 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 114 | super(ResBlock2, self).__init__() 115 | self.h = h 116 | self.convs = nn.ModuleList( 117 | [ 118 | weight_norm( 119 | Conv1d( 120 | channels, 121 | channels, 122 | kernel_size, 123 | 1, 124 | dilation=dilation[0], 125 | padding=get_padding(kernel_size, dilation[0]), 126 | ) 127 | ), 128 | weight_norm( 129 | Conv1d( 130 | channels, 131 | channels, 132 | kernel_size, 133 | 1, 134 | dilation=dilation[1], 135 | padding=get_padding(kernel_size, dilation[1]), 136 | ) 137 | ), 138 | ] 139 | ) 140 | self.convs.apply(init_weights) 141 | 142 | def forward(self, x): 143 | for c in self.convs: 144 | xt = F.leaky_relu(x, LRELU_SLOPE) 145 | xt = c(xt) 146 | x = xt + x 147 | return x 148 | 149 | def remove_weight_norm(self): 150 | for l in self.convs: 151 | remove_weight_norm(l) 152 | 153 | 154 | class Generator_HiFiRes(torch.nn.Module): 155 | def __init__(self, h): 156 | super(Generator_HiFiRes, self).__init__() 157 | self.h = h 158 | self.num_kernels = len(h.resblock_kernel_sizes) 159 | self.num_upsamples = len(h.upsample_rates) 160 | self.conv_pre = weight_norm( 161 | Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3) 162 | ) 163 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 164 | 165 | self.ups = nn.ModuleList() 166 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 167 | self.ups.append( 168 | weight_norm( 169 | ConvTranspose1d( 170 | h.upsample_initial_channel // (2**i), 171 | h.upsample_initial_channel // (2 ** (i + 1)), 172 | u * 2, 173 | u, 174 | padding=u // 2 + u % 2, 175 | output_padding=u % 2, 176 | ) 177 | ) 178 | ) 179 | 180 | self.resblocks = nn.ModuleList() 181 | for i in range(len(self.ups)): 182 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 183 | for j, (k, d) in enumerate( 184 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 185 | ): 186 | self.resblocks.append(resblock(h, ch, k, d)) 187 | 188 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 189 | self.ups.apply(init_weights) 190 | self.conv_post.apply(init_weights) 191 | 192 | def forward(self, x): 193 | x = self.conv_pre(x) 194 | for i in range(self.num_upsamples): 195 | x = F.leaky_relu(x, LRELU_SLOPE) 196 | x = self.ups[i](x) 197 | xs = None 198 | for j in range(self.num_kernels): 199 | if xs is None: 200 | xs = self.resblocks[i * self.num_kernels + j](x) 201 | else: 202 | xs += self.resblocks[i * self.num_kernels + j](x) 203 | x = xs / self.num_kernels 204 | x = F.leaky_relu(x) 205 | x = self.conv_post(x) 206 | x = torch.tanh(x) 207 | 208 | return x 209 | 210 | def remove_weight_norm(self): 211 | print("Removing weight norm...") 212 | for l in self.ups: 213 | remove_weight_norm(l) 214 | for l in self.resblocks: 215 | l.remove_weight_norm() 216 | remove_weight_norm(self.conv_pre) 217 | remove_weight_norm(self.conv_post) 218 | -------------------------------------------------------------------------------- /models/tta/lfm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gwh22/LAFMA/0a11ac2d4d176018f0aa0531dd6aca84cc532138/models/tta/lfm/__init__.py 
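The README notes that the HiFi-GAN vocoder checkpoint released with AudioLDM is used to turn decoded mel-spectrograms back into waveforms. The sketch below shows one plausible way to drive the `Generator` defined above; the config keys match what the constructor reads, but the file names and the "generator" key in the checkpoint follow the original HiFi-GAN release and are assumptions, not something documented in this repository.

```
import json

import torch

from models.tta.hifigan import AttrDict, Generator

# Hypothetical paths: point these at the vocoder config/checkpoint you downloaded.
with open("hifigan/config.json") as f:
    h = AttrDict(json.load(f))

vocoder = Generator(h)
state = torch.load("hifigan/g_checkpoint.pt", map_location="cpu")
vocoder.load_state_dict(state["generator"])  # key layout assumed from the original HiFi-GAN release
vocoder.eval()
vocoder.remove_weight_norm()

mel = torch.randn(1, h.num_mels, 1024)  # [batch, n_mel, frames], e.g. a decoded VAE output
with torch.no_grad():
    wav = vocoder(mel)                  # [batch, 1, frames * prod(upsample_rates)]
```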
-------------------------------------------------------------------------------- /models/tta/lfm/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from inspect import isfunction 7 | import math 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn, einsum 11 | from einops import rearrange, repeat 12 | 13 | 14 | class CheckpointFunction(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, run_function, length, *args): 17 | ctx.run_function = run_function 18 | ctx.input_tensors = list(args[:length]) 19 | ctx.input_params = list(args[length:]) 20 | 21 | with torch.no_grad(): 22 | output_tensors = ctx.run_function(*ctx.input_tensors) 23 | return output_tensors 24 | 25 | @staticmethod 26 | def backward(ctx, *output_grads): 27 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 28 | with torch.enable_grad(): 29 | # Fixes a bug where the first op in run_function modifies the 30 | # Tensor storage in place, which is not allowed for detach()'d 31 | # Tensors. 32 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 33 | output_tensors = ctx.run_function(*shallow_copies) 34 | input_grads = torch.autograd.grad( 35 | output_tensors, 36 | ctx.input_tensors + ctx.input_params, 37 | output_grads, 38 | allow_unused=True, 39 | ) 40 | del ctx.input_tensors 41 | del ctx.input_params 42 | del output_tensors 43 | return (None, None) + input_grads 44 | 45 | 46 | def checkpoint(func, inputs, params, flag): 47 | """ 48 | Evaluate a function without caching intermediate activations, allowing for 49 | reduced memory at the expense of extra compute in the backward pass. 50 | :param func: the function to evaluate. 51 | :param inputs: the argument sequence to pass to `func`. 52 | :param params: a sequence of parameters `func` depends on but does not 53 | explicitly take as arguments. 54 | :param flag: if False, disable gradient checkpointing. 
55 | """ 56 | if flag: 57 | args = tuple(inputs) + tuple(params) 58 | return CheckpointFunction.apply(func, len(inputs), *args) 59 | else: 60 | return func(*inputs) 61 | 62 | 63 | def exists(val): 64 | return val is not None 65 | 66 | 67 | def uniq(arr): 68 | return {el: True for el in arr}.keys() 69 | 70 | 71 | def default(val, d): 72 | if exists(val): 73 | return val 74 | return d() if isfunction(d) else d 75 | 76 | 77 | def max_neg_value(t): 78 | return -torch.finfo(t.dtype).max 79 | 80 | 81 | def init_(tensor): 82 | dim = tensor.shape[-1] 83 | std = 1 / math.sqrt(dim) 84 | tensor.uniform_(-std, std) 85 | return tensor 86 | 87 | 88 | # feedforward 89 | class GEGLU(nn.Module): 90 | def __init__(self, dim_in, dim_out): 91 | super().__init__() 92 | self.proj = nn.Linear(dim_in, dim_out * 2) 93 | 94 | def forward(self, x): 95 | x, gate = self.proj(x).chunk(2, dim=-1) 96 | return x * F.gelu(gate) 97 | 98 | 99 | class FeedForward(nn.Module): 100 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 101 | super().__init__() 102 | inner_dim = int(dim * mult) 103 | dim_out = default(dim_out, dim) 104 | project_in = ( 105 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 106 | if not glu 107 | else GEGLU(dim, inner_dim) 108 | ) 109 | 110 | self.net = nn.Sequential( 111 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 112 | ) 113 | 114 | def forward(self, x): 115 | return self.net(x) 116 | 117 | 118 | def zero_module(module): 119 | """ 120 | Zero out the parameters of a module and return it. 121 | """ 122 | for p in module.parameters(): 123 | p.detach().zero_() 124 | return module 125 | 126 | 127 | def Normalize(in_channels): 128 | return torch.nn.GroupNorm( 129 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 130 | ) 131 | 132 | 133 | class LinearAttention(nn.Module): 134 | def __init__(self, dim, heads=4, dim_head=32): 135 | super().__init__() 136 | self.heads = heads 137 | hidden_dim = dim_head * heads 138 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 139 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 140 | 141 | def forward(self, x): 142 | b, c, h, w = x.shape 143 | qkv = self.to_qkv(x) 144 | q, k, v = rearrange( 145 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 146 | ) 147 | k = k.softmax(dim=-1) 148 | context = torch.einsum("bhdn,bhen->bhde", k, v) 149 | out = torch.einsum("bhde,bhdn->bhen", context, q) 150 | out = rearrange( 151 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 152 | ) 153 | return self.to_out(out) 154 | 155 | 156 | class SpatialSelfAttention(nn.Module): 157 | def __init__(self, in_channels): 158 | super().__init__() 159 | self.in_channels = in_channels 160 | 161 | self.norm = Normalize(in_channels) 162 | self.q = torch.nn.Conv2d( 163 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 164 | ) 165 | self.k = torch.nn.Conv2d( 166 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 167 | ) 168 | self.v = torch.nn.Conv2d( 169 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 170 | ) 171 | self.proj_out = torch.nn.Conv2d( 172 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 173 | ) 174 | 175 | def forward(self, x): 176 | h_ = x 177 | h_ = self.norm(h_) 178 | q = self.q(h_) 179 | k = self.k(h_) 180 | v = self.v(h_) 181 | 182 | # compute attention 183 | b, c, h, w = q.shape 184 | q = rearrange(q, "b c h w -> b (h w) c") 185 | k = rearrange(k, "b c h w -> b c (h w)") 186 | w_ = torch.einsum("bij,bjk->bik", q, k) 
187 | 188 | w_ = w_ * (int(c) ** (-0.5)) 189 | w_ = torch.nn.functional.softmax(w_, dim=2) 190 | 191 | # attend to values 192 | v = rearrange(v, "b c h w -> b c (h w)") 193 | w_ = rearrange(w_, "b i j -> b j i") 194 | h_ = torch.einsum("bij,bjk->bik", v, w_) 195 | h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) 196 | h_ = self.proj_out(h_) 197 | 198 | return x + h_ 199 | 200 | 201 | class CrossAttention(nn.Module): 202 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 203 | super().__init__() 204 | inner_dim = dim_head * heads 205 | context_dim = default(context_dim, query_dim) 206 | 207 | self.scale = dim_head**-0.5 208 | self.heads = heads 209 | 210 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 211 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 212 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 213 | 214 | self.to_out = nn.Sequential( 215 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 216 | ) 217 | 218 | def forward(self, x, context=None, mask=None): 219 | h = self.heads 220 | 221 | q = self.to_q(x) 222 | context = default(context, x) 223 | # print("context", context.dtype) 224 | k = self.to_k(context) 225 | v = self.to_v(context) 226 | 227 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 228 | 229 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 230 | 231 | if exists(mask): 232 | mask = rearrange(mask, "b ... -> b (...)") 233 | max_neg_value = -torch.finfo(sim.dtype).max 234 | mask = repeat(mask, "b j -> (b h) () j", h=h) 235 | sim.masked_fill_(~mask, max_neg_value) 236 | 237 | # attention, what we cannot get enough of 238 | attn = sim.softmax(dim=-1) 239 | 240 | out = einsum("b i j, b j d -> b i d", attn, v) 241 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 242 | return self.to_out(out) 243 | 244 | 245 | class BasicTransformerBlock(nn.Module): 246 | def __init__( 247 | self, 248 | dim, 249 | n_heads, 250 | d_head, 251 | dropout=0.0, 252 | context_dim=None, 253 | gated_ff=True, 254 | checkpoint=True, 255 | ): 256 | super().__init__() 257 | self.attn1 = CrossAttention( 258 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 259 | ) # is a self-attention 260 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 261 | self.attn2 = CrossAttention( 262 | query_dim=dim, 263 | context_dim=context_dim, 264 | heads=n_heads, 265 | dim_head=d_head, 266 | dropout=dropout, 267 | ) # is self-attn if context is none 268 | self.norm1 = nn.LayerNorm(dim) 269 | self.norm2 = nn.LayerNorm(dim) 270 | self.norm3 = nn.LayerNorm(dim) 271 | self.checkpoint = checkpoint 272 | 273 | def forward(self, x, context=None): 274 | return checkpoint( 275 | self._forward, (x, context), self.parameters(), self.checkpoint 276 | ) 277 | 278 | def _forward(self, x, context=None): 279 | x = self.attn1(self.norm1(x)) + x 280 | x = self.attn2(self.norm2(x), context=context) + x 281 | x = self.ff(self.norm3(x)) + x 282 | return x 283 | 284 | 285 | class SpatialTransformer(nn.Module): 286 | """ 287 | Transformer block for image-like data. 288 | First, project the input (aka embedding) 289 | and reshape to b, t, d. 290 | Then apply standard transformer action. 
291 | Finally, reshape to image 292 | """ 293 | 294 | def __init__( 295 | self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None 296 | ): 297 | super().__init__() 298 | self.in_channels = in_channels 299 | inner_dim = n_heads * d_head 300 | self.norm = Normalize(in_channels) 301 | 302 | self.proj_in = nn.Conv2d( 303 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 304 | ) 305 | 306 | self.transformer_blocks = nn.ModuleList( 307 | [ 308 | BasicTransformerBlock( 309 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 310 | ) 311 | for d in range(depth) 312 | ] 313 | ) 314 | 315 | self.proj_out = zero_module( 316 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 317 | ) 318 | 319 | def forward(self, x, context=None): 320 | # note: if no context is given, cross-attention defaults to self-attention 321 | b, c, h, w = x.shape 322 | x_in = x 323 | x = self.norm(x) 324 | x = self.proj_in(x) 325 | x = rearrange(x, "b c h w -> b (h w) c") 326 | for block in self.transformer_blocks: 327 | x = block(x, context=context) 328 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 329 | x = self.proj_out(x) 330 | return x + x_in 331 | -------------------------------------------------------------------------------- /models/tta/lfm/audioldm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from abc import abstractmethod 7 | from functools import partial 8 | import math 9 | from typing import Iterable 10 | 11 | import os 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from models.tta.lfm.attention import SpatialTransformer 19 | 20 | 21 | class CheckpointFunction(torch.autograd.Function): 22 | @staticmethod 23 | def forward(ctx, run_function, length, *args): 24 | ctx.run_function = run_function 25 | ctx.input_tensors = list(args[:length]) 26 | ctx.input_params = list(args[length:]) 27 | 28 | with torch.no_grad(): 29 | output_tensors = ctx.run_function(*ctx.input_tensors) 30 | return output_tensors 31 | 32 | @staticmethod 33 | def backward(ctx, *output_grads): 34 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 35 | with torch.enable_grad(): 36 | # Fixes a bug where the first op in run_function modifies the 37 | # Tensor storage in place, which is not allowed for detach()'d 38 | # Tensors. 39 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 40 | output_tensors = ctx.run_function(*shallow_copies) 41 | input_grads = torch.autograd.grad( 42 | output_tensors, 43 | ctx.input_tensors + ctx.input_params, 44 | output_grads, 45 | allow_unused=True, 46 | ) 47 | del ctx.input_tensors 48 | del ctx.input_params 49 | del output_tensors 50 | return (None, None) + input_grads 51 | 52 | 53 | def checkpoint(func, inputs, params, flag): 54 | """ 55 | Evaluate a function without caching intermediate activations, allowing for 56 | reduced memory at the expense of extra compute in the backward pass. 57 | :param func: the function to evaluate. 58 | :param inputs: the argument sequence to pass to `func`. 59 | :param params: a sequence of parameters `func` depends on but does not 60 | explicitly take as arguments. 61 | :param flag: if False, disable gradient checkpointing. 
62 | """ 63 | if flag: 64 | args = tuple(inputs) + tuple(params) 65 | return CheckpointFunction.apply(func, len(inputs), *args) 66 | else: 67 | return func(*inputs) 68 | 69 | 70 | def zero_module(module): 71 | """ 72 | Zero out the parameters of a module and return it. 73 | """ 74 | for p in module.parameters(): 75 | p.detach().zero_() 76 | return module 77 | 78 | 79 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 80 | """ 81 | Create sinusoidal timestep embeddings. 82 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 83 | These may be fractional. 84 | :param dim: the dimension of the output. 85 | :param max_period: controls the minimum frequency of the embeddings. 86 | :return: an [N x dim] Tensor of positional embeddings. 87 | """ 88 | if not repeat_only: 89 | half = dim // 2 90 | freqs = torch.exp( 91 | -math.log(max_period) 92 | * torch.arange(start=0, end=half, dtype=torch.float32) 93 | / half 94 | ).to(device=timesteps.device) 95 | args = timesteps[:, None].float() * freqs[None] 96 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 97 | if dim % 2: 98 | embedding = torch.cat( 99 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 100 | ) 101 | else: 102 | embedding = repeat(timesteps, "b -> b d", d=dim) 103 | return embedding 104 | 105 | 106 | class GroupNorm32(nn.GroupNorm): 107 | def forward(self, x): 108 | return super().forward(x.float()).type(x.dtype) 109 | 110 | 111 | def normalization(channels): 112 | """ 113 | Make a standard normalization layer. 114 | :param channels: number of input channels. 115 | :return: an nn.Module for normalization. 116 | """ 117 | return GroupNorm32(32, channels) 118 | 119 | 120 | def count_flops_attn(model, _x, y): 121 | """ 122 | A counter for the `thop` package to count the operations in an 123 | attention operation. 124 | Meant to be used like: 125 | macs, params = thop.profile( 126 | model, 127 | inputs=(inputs, timestamps), 128 | custom_ops={QKVAttention: QKVAttention.count_flops}, 129 | ) 130 | """ 131 | b, c, *spatial = y[0].shape 132 | num_spatial = int(np.prod(spatial)) 133 | # We perform two matmuls with the same number of ops. 134 | # The first computes the weight matrix, the second computes 135 | # the combination of the value vectors. 136 | matmul_ops = 2 * b * (num_spatial**2) * c 137 | model.total_ops += torch.DoubleTensor([matmul_ops]) 138 | 139 | 140 | def conv_nd(dims, *args, **kwargs): 141 | """ 142 | Create a 1D, 2D, or 3D convolution module. 143 | """ 144 | if dims == 1: 145 | return nn.Conv1d(*args, **kwargs) 146 | elif dims == 2: 147 | return nn.Conv2d(*args, **kwargs) 148 | elif dims == 3: 149 | return nn.Conv3d(*args, **kwargs) 150 | raise ValueError(f"unsupported dimensions: {dims}") 151 | 152 | 153 | def avg_pool_nd(dims, *args, **kwargs): 154 | """ 155 | Create a 1D, 2D, or 3D average pooling module. 156 | """ 157 | if dims == 1: 158 | return nn.AvgPool1d(*args, **kwargs) 159 | elif dims == 2: 160 | return nn.AvgPool2d(*args, **kwargs) 161 | elif dims == 3: 162 | return nn.AvgPool3d(*args, **kwargs) 163 | raise ValueError(f"unsupported dimensions: {dims}") 164 | 165 | 166 | class QKVAttention(nn.Module): 167 | """ 168 | A module which performs QKV attention and splits in a different order. 169 | """ 170 | 171 | def __init__(self, n_heads): 172 | super().__init__() 173 | self.n_heads = n_heads 174 | 175 | def forward(self, qkv): 176 | """ 177 | Apply QKV attention. 178 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 
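    Q, K and V are split from `qkv` before the heads are separated (the
    "new order"), unlike `QKVAttentionLegacy`, which reshapes into heads first.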
179 | :return: an [N x (H * C) x T] tensor after attention. 180 | """ 181 | 182 | bs, width, length = qkv.shape 183 | assert width % (3 * self.n_heads) == 0 184 | ch = width // (3 * self.n_heads) 185 | q, k, v = qkv.chunk(3, dim=1) # [N x (H * C) x T] 186 | scale = 1 / math.sqrt(math.sqrt(ch)) 187 | weight = torch.einsum( 188 | "bct,bcs->bts", 189 | (q * scale).view(bs * self.n_heads, ch, length), 190 | (k * scale).view(bs * self.n_heads, ch, length), 191 | ) # More stable with f16 than dividing afterwards 192 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 193 | a = torch.einsum( 194 | "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) 195 | ) 196 | return a.reshape(bs, -1, length) 197 | 198 | @staticmethod 199 | def count_flops(model, _x, y): 200 | return count_flops_attn(model, _x, y) 201 | 202 | 203 | class QKVAttentionLegacy(nn.Module): 204 | """ 205 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 206 | """ 207 | 208 | def __init__(self, n_heads): 209 | super().__init__() 210 | self.n_heads = n_heads 211 | 212 | def forward(self, qkv): 213 | """ 214 | Apply QKV attention. 215 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 216 | :return: an [N x (H * C) x T] tensor after attention. 217 | """ 218 | bs, width, length = qkv.shape 219 | assert width % (3 * self.n_heads) == 0 220 | ch = width // (3 * self.n_heads) 221 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 222 | scale = 1 / math.sqrt(math.sqrt(ch)) 223 | weight = torch.einsum( 224 | "bct,bcs->bts", q * scale, k * scale 225 | ) # More stable with f16 than dividing afterwards 226 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 227 | a = torch.einsum("bts,bcs->bct", weight, v) 228 | return a.reshape(bs, -1, length) 229 | 230 | @staticmethod 231 | def count_flops(model, _x, y): 232 | return count_flops_attn(model, _x, y) 233 | 234 | 235 | class AttentionPool2d(nn.Module): 236 | """ 237 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 238 | """ 239 | 240 | def __init__( 241 | self, 242 | spacial_dim: int, 243 | embed_dim: int, 244 | num_heads_channels: int, 245 | output_dim: int = None, 246 | ): 247 | super().__init__() 248 | self.positional_embedding = nn.Parameter( 249 | torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 250 | ) 251 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 252 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 253 | self.num_heads = embed_dim // num_heads_channels 254 | self.attention = QKVAttention(self.num_heads) 255 | 256 | def forward(self, x): 257 | b, c, *_spatial = x.shape 258 | x = x.reshape(b, c, -1) # NC(HW) 259 | x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 260 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 261 | x = self.qkv_proj(x) 262 | x = self.attention(x) 263 | x = self.c_proj(x) 264 | return x[:, :, 0] 265 | 266 | 267 | class TimestepBlock(nn.Module): 268 | """ 269 | Any module where forward() takes timestep embeddings as a second argument. 270 | """ 271 | 272 | @abstractmethod 273 | def forward(self, x, emb): 274 | """ 275 | Apply the module to `x` given `emb` timestep embeddings. 276 | """ 277 | 278 | 279 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 280 | """ 281 | A sequential module that passes timestep embeddings to the children that 282 | support it as an extra input. 
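    `SpatialTransformer` children receive the cross-attention `context`
    instead of the timestep embedding.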
283 | """ 284 | 285 | def forward(self, x, emb, context=None): 286 | for layer in self: 287 | if isinstance(layer, TimestepBlock): 288 | x = layer(x, emb) 289 | elif isinstance(layer, SpatialTransformer): 290 | x = layer(x, context) 291 | else: 292 | x = layer(x) 293 | return x 294 | 295 | 296 | class Upsample(nn.Module): 297 | """ 298 | An upsampling layer with an optional convolution. 299 | :param channels: channels in the inputs and outputs. 300 | :param use_conv: a bool determining if a convolution is applied. 301 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 302 | upsampling occurs in the inner-two dimensions. 303 | """ 304 | 305 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 306 | super().__init__() 307 | self.channels = channels 308 | self.out_channels = out_channels or channels 309 | self.use_conv = use_conv 310 | self.dims = dims 311 | if use_conv: 312 | self.conv = conv_nd( 313 | dims, self.channels, self.out_channels, 3, padding=padding 314 | ) 315 | 316 | def forward(self, x): 317 | assert x.shape[1] == self.channels 318 | if self.dims == 3: 319 | x = F.interpolate( 320 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 321 | ) 322 | else: 323 | x = F.interpolate(x, scale_factor=2, mode="nearest") 324 | if self.use_conv: 325 | x = self.conv(x) 326 | return x 327 | 328 | 329 | class TransposedUpsample(nn.Module): 330 | "Learned 2x upsampling without padding" 331 | 332 | def __init__(self, channels, out_channels=None, ks=5): 333 | super().__init__() 334 | self.channels = channels 335 | self.out_channels = out_channels or channels 336 | 337 | self.up = nn.ConvTranspose2d( 338 | self.channels, self.out_channels, kernel_size=ks, stride=2 339 | ) 340 | 341 | def forward(self, x): 342 | return self.up(x) 343 | 344 | 345 | class Downsample(nn.Module): 346 | """ 347 | A downsampling layer with an optional convolution. 348 | :param channels: channels in the inputs and outputs. 349 | :param use_conv: a bool determining if a convolution is applied. 350 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 351 | downsampling occurs in the inner-two dimensions. 352 | """ 353 | 354 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 355 | super().__init__() 356 | self.channels = channels 357 | self.out_channels = out_channels or channels 358 | self.use_conv = use_conv 359 | self.dims = dims 360 | stride = 2 if dims != 3 else (1, 2, 2) 361 | if use_conv: 362 | self.op = conv_nd( 363 | dims, 364 | self.channels, 365 | self.out_channels, 366 | 3, 367 | stride=stride, 368 | padding=padding, 369 | ) 370 | else: 371 | assert self.channels == self.out_channels 372 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 373 | 374 | def forward(self, x): 375 | assert x.shape[1] == self.channels 376 | return self.op(x) 377 | 378 | 379 | class ResBlock(TimestepBlock): 380 | """ 381 | A residual block that can optionally change the number of channels. 382 | :param channels: the number of input channels. 383 | :param emb_channels: the number of timestep embedding channels. 384 | :param dropout: the rate of dropout. 385 | :param out_channels: if specified, the number of out channels. 386 | :param use_conv: if True and out_channels is specified, use a spatial 387 | convolution instead of a smaller 1x1 convolution to change the 388 | channels in the skip connection. 389 | :param dims: determines if the signal is 1D, 2D, or 3D. 
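    :param use_scale_shift_norm: if True, inject the timestep embedding as a
    FiLM-style scale and shift after the output normalization instead of
    adding it to the features.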
390 | :param use_checkpoint: if True, use gradient checkpointing on this module. 391 | :param up: if True, use this block for upsampling. 392 | :param down: if True, use this block for downsampling. 393 | """ 394 | 395 | def __init__( 396 | self, 397 | channels, 398 | emb_channels, 399 | dropout, 400 | out_channels=None, 401 | use_conv=False, 402 | use_scale_shift_norm=False, 403 | dims=2, 404 | use_checkpoint=False, 405 | up=False, 406 | down=False, 407 | ): 408 | super().__init__() 409 | self.channels = channels 410 | self.emb_channels = emb_channels 411 | self.dropout = dropout 412 | self.out_channels = out_channels or channels 413 | self.use_conv = use_conv 414 | self.use_checkpoint = use_checkpoint 415 | self.use_scale_shift_norm = use_scale_shift_norm 416 | 417 | self.in_layers = nn.Sequential( 418 | normalization(channels), 419 | nn.SiLU(), 420 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 421 | ) 422 | 423 | self.updown = up or down 424 | 425 | if up: 426 | self.h_upd = Upsample(channels, False, dims) 427 | self.x_upd = Upsample(channels, False, dims) 428 | elif down: 429 | self.h_upd = Downsample(channels, False, dims) 430 | self.x_upd = Downsample(channels, False, dims) 431 | else: 432 | self.h_upd = self.x_upd = nn.Identity() 433 | 434 | self.emb_layers = nn.Sequential( 435 | nn.SiLU(), 436 | nn.Linear( 437 | emb_channels, 438 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 439 | ), 440 | ) 441 | self.out_layers = nn.Sequential( 442 | normalization(self.out_channels), 443 | nn.SiLU(), 444 | nn.Dropout(p=dropout), 445 | zero_module( 446 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 447 | ), 448 | ) 449 | 450 | if self.out_channels == channels: 451 | self.skip_connection = nn.Identity() 452 | elif use_conv: 453 | self.skip_connection = conv_nd( 454 | dims, channels, self.out_channels, 3, padding=1 455 | ) 456 | else: 457 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 458 | 459 | def forward(self, x, emb): 460 | """ 461 | Apply the block to a Tensor, conditioned on a timestep embedding. 462 | :param x: an [N x C x ...] Tensor of features. 463 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 464 | :return: an [N x C x ...] Tensor of outputs. 465 | """ 466 | return checkpoint( 467 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 468 | ) 469 | 470 | def _forward(self, x, emb): 471 | if self.updown: 472 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 473 | h = in_rest(x) 474 | h = self.h_upd(h) 475 | x = self.x_upd(x) 476 | h = in_conv(h) 477 | else: 478 | h = self.in_layers(x) 479 | emb_out = self.emb_layers(emb).type(h.dtype) 480 | while len(emb_out.shape) < len(h.shape): 481 | emb_out = emb_out[..., None] 482 | if self.use_scale_shift_norm: 483 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 484 | scale, shift = torch.chunk(emb_out, 2, dim=1) 485 | h = out_norm(h) * (1 + scale) + shift 486 | h = out_rest(h) 487 | else: 488 | h = h + emb_out 489 | h = self.out_layers(h) 490 | return self.skip_connection(x) + h 491 | 492 | 493 | class AttentionBlock(nn.Module): 494 | """ 495 | An attention block that allows spatial positions to attend to each other. 496 | Originally ported from here, but adapted to the N-d case. 497 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
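    The output projection is zero-initialized, so the block acts as an
    identity mapping at the start of training.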
498 | """ 499 | 500 | def __init__( 501 | self, 502 | channels, 503 | num_heads=1, 504 | num_head_channels=-1, 505 | use_checkpoint=False, 506 | use_new_attention_order=False, 507 | ): 508 | super().__init__() 509 | self.channels = channels 510 | if num_head_channels == -1: 511 | self.num_heads = num_heads 512 | else: 513 | assert ( 514 | channels % num_head_channels == 0 515 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 516 | self.num_heads = channels // num_head_channels 517 | self.use_checkpoint = use_checkpoint 518 | self.norm = normalization(channels) 519 | self.qkv = conv_nd(1, channels, channels * 3, 1) 520 | if use_new_attention_order: 521 | # split qkv before split heads 522 | self.attention = QKVAttention(self.num_heads) 523 | else: 524 | # split heads before split qkv 525 | self.attention = QKVAttentionLegacy(self.num_heads) 526 | 527 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 528 | 529 | def forward(self, x): 530 | return checkpoint( 531 | self._forward, (x,), self.parameters(), True 532 | ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! 533 | # return pt_checkpoint(self._forward, x) # pytorch 534 | 535 | def _forward(self, x): 536 | b, c, *spatial = x.shape 537 | x = x.reshape(b, c, -1) 538 | qkv = self.qkv(self.norm(x)) 539 | h = self.attention(qkv) 540 | h = self.proj_out(h) 541 | return (x + h).reshape(b, c, *spatial) 542 | 543 | 544 | class UNetModel(nn.Module): 545 | """ 546 | The full UNet model with attention and timestep embedding. 547 | :param in_channels: channels in the input Tensor. 548 | :param model_channels: base channel count for the model. 549 | :param out_channels: channels in the output Tensor. 550 | :param num_res_blocks: number of residual blocks per downsample. 551 | :param attention_resolutions: a collection of downsample rates at which 552 | attention will take place. May be a set, list, or tuple. 553 | For example, if this contains 4, then at 4x downsampling, attention 554 | will be used. 555 | :param dropout: the dropout probability. 556 | :param channel_mult: channel multiplier for each level of the UNet. 557 | :param conv_resample: if True, use learned convolutions for upsampling and 558 | downsampling. 559 | :param dims: determines if the signal is 1D, 2D, or 3D. 560 | :param num_classes: if specified (as an int), then this model will be 561 | class-conditional with `num_classes` classes. 562 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 563 | :param num_heads: the number of attention heads in each attention layer. 564 | :param num_heads_channels: if specified, ignore num_heads and instead use 565 | a fixed channel width per attention head. 566 | :param num_heads_upsample: works with num_heads to set a different number 567 | of heads for upsampling. Deprecated. 568 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 569 | :param resblock_updown: use residual blocks for up/downsampling. 570 | :param use_new_attention_order: use a different attention pattern for potentially 571 | increased efficiency. 
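    :param use_spatial_transformer: if True, use SpatialTransformer blocks
    (self-attention plus cross-attention) instead of plain AttentionBlocks.
    :param transformer_depth: number of BasicTransformerBlocks inside each
    SpatialTransformer.
    :param context_dim: dimensionality of the cross-attention conditioning
    (e.g. the text-encoder hidden size); required when
    use_spatial_transformer is True.
    :param extra_sa_layer: if True, prepend an additional SpatialTransformer
    (built with the same context_dim) before each attention layer at the
    attention resolutions and in the middle block.
    :param use_fp16: if True, keep the intermediate features in float16.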
572 | """ 573 | 574 | def __init__( 575 | self, 576 | image_size, 577 | in_channels, 578 | model_channels, 579 | out_channels, 580 | num_res_blocks, 581 | attention_resolutions, 582 | dropout=0, 583 | channel_mult=(1, 2, 4, 8), 584 | conv_resample=True, 585 | dims=2, 586 | extra_sa_layer=True, 587 | num_classes=None, 588 | use_checkpoint=False, 589 | use_fp16=False, # use fp16 590 | num_heads=-1, 591 | num_head_channels=-1, 592 | num_heads_upsample=-1, 593 | use_scale_shift_norm=False, 594 | resblock_updown=False, 595 | use_new_attention_order=False, 596 | use_spatial_transformer=False, # custom transformer support 597 | transformer_depth=1, # custom transformer support 598 | context_dim=None, # custom transformer support 599 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model 600 | legacy=True, 601 | ): 602 | super().__init__() 603 | if use_spatial_transformer: 604 | assert ( 605 | context_dim is not None 606 | ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." 607 | 608 | if context_dim is not None: 609 | assert ( 610 | use_spatial_transformer 611 | ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." 612 | from omegaconf.listconfig import ListConfig 613 | 614 | if type(context_dim) is ListConfig: 615 | context_dim = list(context_dim) 616 | 617 | if num_heads_upsample == -1: 618 | num_heads_upsample = num_heads 619 | 620 | if num_heads == -1: 621 | assert ( 622 | num_head_channels != -1 623 | ), "Either num_heads or num_head_channels has to be set" 624 | 625 | if num_head_channels == -1: 626 | assert ( 627 | num_heads != -1 628 | ), "Either num_heads or num_head_channels has to be set" 629 | 630 | self.image_size = image_size 631 | self.in_channels = in_channels 632 | self.model_channels = model_channels 633 | self.out_channels = out_channels 634 | self.num_res_blocks = num_res_blocks 635 | self.attention_resolutions = attention_resolutions 636 | self.dropout = dropout 637 | self.channel_mult = channel_mult 638 | self.conv_resample = conv_resample 639 | self.num_classes = num_classes 640 | self.use_checkpoint = use_checkpoint 641 | self.dtype = torch.float16 if use_fp16 else torch.float32 642 | self.num_heads = num_heads 643 | self.num_head_channels = num_head_channels 644 | self.num_heads_upsample = num_heads_upsample 645 | self.predict_codebook_ids = n_embed is not None 646 | 647 | time_embed_dim = model_channels * 4 648 | self.time_embed = nn.Sequential( 649 | nn.Linear(model_channels, time_embed_dim), 650 | nn.SiLU(), 651 | nn.Linear(time_embed_dim, time_embed_dim), 652 | ) 653 | 654 | if self.num_classes is not None: 655 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 656 | 657 | self.input_blocks = nn.ModuleList( 658 | [ 659 | TimestepEmbedSequential( 660 | conv_nd( 661 | dims, 662 | in_channels, 663 | model_channels, 664 | 3, 665 | padding=1, 666 | ) # float--->half 667 | ) 668 | ] 669 | ) 670 | self._feature_size = model_channels 671 | input_block_chans = [model_channels] 672 | ch = model_channels 673 | ds = 1 674 | for level, mult in enumerate(channel_mult): 675 | for _ in range(num_res_blocks): 676 | layers = [ 677 | ResBlock( 678 | ch, 679 | time_embed_dim, 680 | dropout, 681 | out_channels=mult * model_channels, 682 | dims=dims, 683 | use_checkpoint=use_checkpoint, 684 | use_scale_shift_norm=use_scale_shift_norm, 685 | ) 686 | ] 687 | ch = mult * model_channels 688 | if ds in attention_resolutions: 689 | if num_head_channels == -1: 690 
| dim_head = ch // num_heads 691 | else: 692 | num_heads = ch // num_head_channels 693 | dim_head = num_head_channels 694 | if legacy: 695 | # num_heads = 1 696 | dim_head = ( 697 | ch // num_heads 698 | if use_spatial_transformer 699 | else num_head_channels 700 | ) 701 | if extra_sa_layer: 702 | layers.append( 703 | SpatialTransformer( 704 | ch, 705 | num_heads, 706 | dim_head, 707 | depth=transformer_depth, 708 | context_dim=context_dim, 709 | ) 710 | ) 711 | 712 | layers.append( 713 | AttentionBlock( 714 | ch, 715 | use_checkpoint=use_checkpoint, 716 | num_heads=num_heads, 717 | num_head_channels=dim_head, 718 | use_new_attention_order=use_new_attention_order, 719 | ) 720 | if not use_spatial_transformer 721 | else SpatialTransformer( 722 | ch, 723 | num_heads, 724 | dim_head, 725 | depth=transformer_depth, 726 | context_dim=context_dim, 727 | ) 728 | ) 729 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 730 | self._feature_size += ch 731 | input_block_chans.append(ch) 732 | if level != len(channel_mult) - 1: 733 | out_ch = ch 734 | self.input_blocks.append( 735 | TimestepEmbedSequential( 736 | ResBlock( 737 | ch, 738 | time_embed_dim, 739 | dropout, 740 | out_channels=out_ch, 741 | dims=dims, 742 | use_checkpoint=use_checkpoint, 743 | use_scale_shift_norm=use_scale_shift_norm, 744 | down=True, 745 | ) 746 | if resblock_updown 747 | else Downsample( 748 | ch, conv_resample, dims=dims, out_channels=out_ch 749 | ) 750 | ) 751 | ) 752 | ch = out_ch 753 | input_block_chans.append(ch) 754 | ds *= 2 755 | self._feature_size += ch 756 | 757 | if num_head_channels == -1: 758 | dim_head = ch // num_heads 759 | else: 760 | num_heads = ch // num_head_channels 761 | dim_head = num_head_channels 762 | if legacy: 763 | # num_heads = 1 764 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 765 | middle_layers = [ 766 | ResBlock( 767 | ch, 768 | (time_embed_dim), 769 | dropout, 770 | dims=dims, 771 | use_checkpoint=use_checkpoint, 772 | use_scale_shift_norm=use_scale_shift_norm, 773 | ) 774 | ] 775 | if extra_sa_layer: 776 | middle_layers.append( 777 | SpatialTransformer( 778 | ch, 779 | num_heads, 780 | dim_head, 781 | depth=transformer_depth, 782 | context_dim=context_dim, 783 | ) 784 | ) 785 | 786 | middle_layers.append( 787 | AttentionBlock( 788 | ch, 789 | use_checkpoint=use_checkpoint, 790 | num_heads=num_heads, 791 | num_head_channels=dim_head, 792 | use_new_attention_order=use_new_attention_order, 793 | ) 794 | if not use_spatial_transformer 795 | else SpatialTransformer( 796 | ch, 797 | num_heads, 798 | dim_head, 799 | depth=transformer_depth, 800 | context_dim=context_dim, 801 | ) 802 | ) 803 | middle_layers.append( 804 | ResBlock( 805 | ch, 806 | (time_embed_dim), 807 | dropout, 808 | dims=dims, 809 | use_checkpoint=use_checkpoint, 810 | use_scale_shift_norm=use_scale_shift_norm, 811 | ) 812 | ) 813 | self.middle_block = TimestepEmbedSequential(*middle_layers) 814 | self._feature_size += ch 815 | 816 | self.output_blocks = nn.ModuleList([]) 817 | for level, mult in list(enumerate(channel_mult))[::-1]: 818 | for i in range(num_res_blocks + 1): 819 | ich = input_block_chans.pop() 820 | layers = [ 821 | ResBlock( 822 | ch + ich, 823 | time_embed_dim, 824 | dropout, 825 | out_channels=model_channels * mult, 826 | dims=dims, 827 | use_checkpoint=use_checkpoint, 828 | use_scale_shift_norm=use_scale_shift_norm, 829 | ) 830 | ] 831 | ch = model_channels * mult 832 | if ds in attention_resolutions: 833 | if num_head_channels == -1: 834 | dim_head = ch // 
num_heads 835 | else: 836 | num_heads = ch // num_head_channels 837 | dim_head = num_head_channels 838 | if legacy: 839 | # num_heads = 1 840 | dim_head = ( 841 | ch // num_heads 842 | if use_spatial_transformer 843 | else num_head_channels 844 | ) 845 | if extra_sa_layer: 846 | layers.append( 847 | SpatialTransformer( 848 | ch, 849 | num_heads, 850 | dim_head, 851 | depth=transformer_depth, 852 | context_dim=context_dim, 853 | ) 854 | ) 855 | 856 | layers.append( 857 | AttentionBlock( 858 | ch, 859 | use_checkpoint=use_checkpoint, 860 | num_heads=num_heads_upsample, 861 | num_head_channels=dim_head, 862 | use_new_attention_order=use_new_attention_order, 863 | ) 864 | if not use_spatial_transformer 865 | else SpatialTransformer( 866 | ch, 867 | num_heads, 868 | dim_head, 869 | depth=transformer_depth, 870 | context_dim=context_dim, 871 | ) 872 | ) 873 | if level and i == num_res_blocks: 874 | out_ch = ch 875 | layers.append( 876 | ResBlock( 877 | ch, 878 | time_embed_dim, 879 | dropout, 880 | out_channels=out_ch, 881 | dims=dims, 882 | use_checkpoint=use_checkpoint, 883 | use_scale_shift_norm=use_scale_shift_norm, 884 | up=True, 885 | ) 886 | if resblock_updown 887 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 888 | ) 889 | ds //= 2 890 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 891 | self._feature_size += ch 892 | 893 | self.out = nn.Sequential( 894 | normalization(ch), 895 | nn.SiLU(), 896 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 897 | ) 898 | if self.predict_codebook_ids: 899 | self.id_predictor = nn.Sequential( 900 | normalization(ch), 901 | conv_nd(dims, model_channels, n_embed, 1), 902 | # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits 903 | ) 904 | 905 | def forward(self, x, timesteps=None, context=None, y=None, **kwargs): 906 | """ 907 | Apply the model to an input batch. 908 | :param x: an [N x C x ...] Tensor of inputs. 909 | :param timesteps: a 1-D batch of timesteps. 910 | :param context: conditioning plugged in via crossattn 911 | :param y: an [N] Tensor of labels, if class-conditional. 912 | :return: an [N x C x ...] Tensor of outputs. 
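    In the decoder loop, the upsampled features are cropped to the skip
    connection's spatial size before the two are concatenated, so slight
    shape mismatches introduced by down/upsampling are tolerated.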
913 | """ 914 | assert (y is not None) == ( 915 | self.num_classes is not None 916 | ), "must specify y if and only if the model is class-conditional" 917 | hs = [] 918 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 919 | emb = self.time_embed(t_emb) 920 | 921 | if self.num_classes is not None: 922 | assert y.shape == (x.shape[0],) 923 | emb = emb + self.label_emb(y) 924 | 925 | h = x.type(self.dtype) 926 | for module in self.input_blocks: 927 | h = module(h, emb, context) 928 | hs.append(h) 929 | h = self.middle_block(h, emb, context) 930 | for module in self.output_blocks: 931 | # print(h.shape, hs[-1].shape) 932 | if h.shape != hs[-1].shape: 933 | if h.shape[-1] > hs[-1].shape[-1]: 934 | h = h[:, :, :, : hs[-1].shape[-1]] 935 | if h.shape[-2] > hs[-1].shape[-2]: 936 | h = h[:, :, : hs[-1].shape[-2], :] 937 | h = torch.cat([h, hs.pop()], dim=1) 938 | h = module(h, emb, context) 939 | # print(h.shape) 940 | h = h.type(x.dtype) 941 | if self.predict_codebook_ids: 942 | return self.id_predictor(h) 943 | else: 944 | return self.out(h) 945 | 946 | 947 | class AudioLDM(nn.Module): 948 | def __init__(self, cfg): 949 | super().__init__() 950 | self.cfg = cfg 951 | self.unet = UNetModel( 952 | image_size=cfg.image_size, 953 | in_channels=cfg.in_channels, 954 | out_channels=cfg.out_channels, 955 | model_channels=cfg.model_channels, 956 | attention_resolutions=cfg.attention_resolutions, 957 | num_res_blocks=cfg.num_res_blocks, 958 | channel_mult=cfg.channel_mult, 959 | num_heads=cfg.num_heads, 960 | use_spatial_transformer=cfg.use_spatial_transformer, 961 | transformer_depth=cfg.transformer_depth, 962 | context_dim=cfg.context_dim, 963 | use_checkpoint=cfg.use_checkpoint, 964 | legacy=cfg.legacy, 965 | extra_sa_layer=cfg.extra_sa_layer, 966 | ) 967 | 968 | def forward(self, x, timesteps=None, context=None, y=None): 969 | x = self.unet(x=x, timesteps=timesteps, context=context, y=y) 970 | return x 971 | -------------------------------------------------------------------------------- /models/tta/lfm/audiolfm_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
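# NOTE (illustrative sketch, not part of the original file): the sampler used by
# this inference driver -- FlowMatchingTrainer.euler_sample in
# models/tta/lfm/fm_scheduler.py further down -- applies classifier-free guidance
# at every Euler step. The helper below restates that guidance rule in isolation;
# `model`, `cond_emb`, `uncond_emb` and `guidance_scale` are placeholders for the
# UNet velocity predictor, the FLAN-T5 text embeddings and the guidance weight
# used by this file.


def _guided_velocity_sketch(model, x, t, cond_emb, uncond_emb, guidance_scale):
    import torch

    # Batch the unconditional and conditional branches together, as
    # FlowMatchingTrainer.euler_sample does, then extrapolate from the
    # unconditional prediction toward the conditional one.
    context = torch.cat([uncond_emb, cond_emb])
    out = model(torch.cat([x] * 2), torch.cat([t * 999] * 2), context)
    v_uncond, v_cond = out.chunk(2)
    return v_uncond + guidance_scale * (v_cond - v_uncond)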
5 | 6 | import torch 7 | import os 8 | import lightning as L 9 | import numpy as np 10 | from torch.optim.lr_scheduler import ExponentialLR 11 | from utils.utils import num_parameters 12 | from lightning.fabric import Fabric 13 | 14 | from models.tta.autoencoder.autoencoder import AutoencoderKL 15 | from models.tta.lfm.audioldm import AudioLDM 16 | 17 | from transformers import T5EncoderModel, AutoTokenizer 18 | from models.tta.lfm.fm_scheduler import ( 19 | FlowMatchingTrainer, 20 | ) 21 | from utils.io import save_audio 22 | import datetime 23 | 24 | # from diffusers import DDPMScheduler, PNDMScheduler, DDIMScheduler 25 | from utils.tensor_utils import ( 26 | get_vocoder, 27 | vocoder_infer, 28 | ) 29 | 30 | 31 | class LAFMA_Inference(object): 32 | def __init__(self, config, loggers=None, precision="32-true") -> None: 33 | self.fabric = Fabric( 34 | accelerator=config.train.accelerator, 35 | strategy=config.train.strategy, 36 | devices=config.train.devices, 37 | loggers=loggers, 38 | precision=precision, 39 | ) 40 | self.cfg = config 41 | 42 | self.fabric.launch() 43 | self.vocoder = get_vocoder(None, "cpu") 44 | 45 | def _build_model(self): 46 | if self.cfg.infer: 47 | with self.fabric.init_module(empty_init=True): 48 | # audioldm 49 | model = AudioLDM(self.cfg.model.audioldm) 50 | # pretrained autoencoder 51 | autoencoder = AutoencoderKL(self.cfg.model.autoencoderkl) 52 | autoencoder_path = self.cfg.model.autoencoder_path 53 | checkpoint = self.fabric.load(autoencoder_path) 54 | autoencoder.load_state_dict(checkpoint["autoencoder"]) 55 | autoencoder.requires_grad_(requires_grad=False) 56 | autoencoder.eval() 57 | # pretrained text encoder 58 | tokenizer = AutoTokenizer.from_pretrained( 59 | "huggingface/flan-t5-large", 60 | model_max_length=512, 61 | ) 62 | text_encoder = T5EncoderModel.from_pretrained( 63 | "huggingface/flan-t5-large", 64 | ) 65 | text_encoder.requires_grad_(requires_grad=False) 66 | text_encoder.eval() 67 | 68 | if self.fabric.local_rank == 0: 69 | output_dir = self.cfg.test_out_dir 70 | model_file = os.path.join(output_dir, "model.log") 71 | if os.path.exists(model_file): 72 | os.remove(model_file) 73 | log = open(model_file, mode="a+", encoding="utf-8") 74 | print(model, file=log) 75 | log.close() 76 | self.fabric.barrier() 77 | return model, autoencoder, text_encoder, tokenizer 78 | 79 | @torch.no_grad() 80 | def mel_to_latent(self, melspec): 81 | posterior = self.autoencoder.encode(melspec) 82 | latent = posterior.sample() # (B, 4, 5, 78) 83 | return latent 84 | 85 | @torch.no_grad() 86 | def get_text_embedding(self, text_input_ids, text_attention_mask): 87 | text_embedding = self.text_encoder( 88 | input_ids=text_input_ids, attention_mask=text_attention_mask 89 | )[0] 90 | return text_embedding # (B, T, 768) 91 | 92 | def before_test(self): 93 | self.fabric.barrier() 94 | test = f"test_out-{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}" 95 | self.cfg.test_out_dir = os.path.join(self.cfg.train.out_dir, test) 96 | os.makedirs(self.cfg.test_out_dir, exist_ok=True) 97 | 98 | def _prepare_model(self): 99 | checkpoint_file = self.cfg.checkpoint_file 100 | self.fabric.print(f"start test from {checkpoint_file}.") 101 | if checkpoint_file.endswith(".ckpt"): 102 | checkpoint = self.fabric.load(checkpoint_file) 103 | self.model.load_state_dict(checkpoint["model"]) 104 | else: 105 | raise ValueError("supported checkpoint file format : .ckpt") 106 | 107 | output_dir = self.cfg.test_out_dir 108 | 109 | if self.fabric.local_rank == 0: 110 | if self.cfg.infer: 111 | 
log_file = os.path.join(output_dir, "run.log") 112 | else: 113 | log_file = os.path.join(output_dir, "logs/run.log") 114 | log = open(log_file, mode="a+", encoding="utf-8") 115 | print("----------------------", file=log) 116 | print(f"accelerator: {self.fabric.accelerator}", file=log) 117 | print(f"strategy: {self.fabric.strategy}", file=log) 118 | print("----------------------", file=log) 119 | print(f"use dataset: {self.cfg.dataset}", file=log) 120 | print(f"sampling rate: {self.cfg.preprocess.sample_rate}", file=log) 121 | print(f"project name: {self.cfg.train.project}", file=log) 122 | print("----------------------", file=log) 123 | log.close() 124 | self.fabric.print("----------------------") 125 | self.fabric.print(f"accelerator: {self.fabric.accelerator}") 126 | self.fabric.print(f"strategy: {self.fabric.strategy}") 127 | self.fabric.print("----------------------") 128 | self.fabric.print(f"use dataset: {self.cfg.dataset}") 129 | self.fabric.print(f"sampling rate: {self.cfg.preprocess.sample_rate}") 130 | self.fabric.print(f"project name: {self.cfg.train.project}") 131 | self.fabric.barrier() 132 | 133 | def test(self): 134 | self.before_test() 135 | self.model, self.autoencoder, self.text_encoder, self.tokenizer = ( 136 | self._build_model() 137 | ) 138 | self.model = self.fabric.setup(self.model) 139 | self.trainer = FlowMatchingTrainer(self.model, sample_N=self.cfg.num_steps) 140 | 141 | self._prepare_model() 142 | 143 | # test 144 | self.test_step() 145 | self.fabric.print("-" * 16) 146 | if self.fabric.device.type == "cuda": 147 | self.fabric.print( 148 | f"memory used: {torch.cuda.max_memory_allocated()/1e9:.02f} GB" 149 | ) 150 | 151 | def test_step(self): 152 | assert self.vocoder is not None, "Vocoder is not loaded." 153 | if self.cfg.infer_text is not None: 154 | out_dir = self.cfg.test_out_dir 155 | os.makedirs(out_dir, exist_ok=True) 156 | 157 | pred_audio = self.inference_for_single_utterance() 158 | save_path = os.path.join(out_dir, "test_pred.wav") 159 | save_audio(save_path, pred_audio, self.cfg.preprocess.sample_rate) 160 | 161 | @torch.inference_mode() 162 | def inference_for_single_utterance(self): 163 | text = self.cfg.infer_text 164 | 165 | text_input = self.tokenizer( 166 | [text], 167 | max_length=self.tokenizer.model_max_length, 168 | truncation=True, 169 | padding="do_not_pad", 170 | return_tensors="pt", 171 | ) 172 | 173 | text_input = self.fabric.to_device(text_input) 174 | text_embedding = self.text_encoder(text_input.input_ids)[0] 175 | 176 | max_length = text_input.input_ids.shape[-1] 177 | uncond_input = self.tokenizer( 178 | [""] * 1, 179 | padding="max_length", 180 | max_length=max_length, 181 | return_tensors="pt", 182 | ) 183 | uncond_input = self.fabric.to_device(uncond_input) 184 | uncond_embedding = self.text_encoder(uncond_input.input_ids)[0] 185 | text_embeddings = torch.cat([uncond_embedding, text_embedding]) 186 | 187 | guidance_scale = self.cfg.guidance_scale 188 | 189 | self.model.eval() 190 | 191 | # sample 192 | latents_t = torch.randn( 193 | ( 194 | 1, 195 | 8, 196 | 256, 197 | 16, 198 | ) 199 | ) 200 | latents_out, nfe = self.trainer.euler_sample( 201 | text_embeddings, latents_t.shape, guidance_scale 202 | ) 203 | 204 | print(latents_out.shape, nfe) 205 | 206 | with torch.no_grad(): 207 | mel_pred = self.autoencoder.decode(latents_out) 208 | print(mel_pred.shape) 209 | wav_pred = vocoder_infer(mel_pred.transpose(2, 3)[0].cpu(), self.vocoder) 210 | wav_pred = ( 211 | wav_pred / np.max(np.abs(wav_pred)) 212 | ) * 0.8 # Normalize the 
energy of the generation output 213 | return wav_pred 214 | -------------------------------------------------------------------------------- /models/tta/lfm/fm_scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | def extract(v, i, shape): 10 | """ 11 | Get the i-th number in v, and the shape of v is mostly (T, ), the shape of i is mostly (batch_size, ). 12 | equal to [v[index] for index in i] 13 | """ 14 | out = torch.gather(v, index=i, dim=0) 15 | out = out.to(device=i.device, dtype=torch.float32) 16 | 17 | # reshape to (batch_size, 1, 1, 1, 1, ...) for broadcasting purposes. 18 | out = out.view([i.shape[0]] + [1] * (len(shape) - 1)) 19 | return out 20 | 21 | 22 | class FlowMatchingTrainer(nn.Module): 23 | def __init__( 24 | self, 25 | model: nn.Module, 26 | init_type="gaussian", 27 | noise_scale=1.0, 28 | reflow_t_schedule="uniform", 29 | use_ode_sampler="euler", 30 | sigma_var=0.0, 31 | ode_tol=1e-5, 32 | sample_N=25, 33 | ): 34 | super().__init__() 35 | self.model = model 36 | self.init_type = init_type 37 | self.noise_scale = noise_scale 38 | self.reflow_t_schedule = reflow_t_schedule 39 | self.use_ode_sampler = use_ode_sampler 40 | self.sigma_var = sigma_var 41 | self.ode_tol = ode_tol 42 | self.sample_N = sample_N 43 | self.T = 1 44 | self.eps = 1e-3 45 | self.sigma_t = lambda t: (1.0 - t) * sigma_var 46 | print("Init. Distribution Variance:", self.noise_scale) 47 | print("SDE Sampler Variance:", sigma_var) 48 | print("ODE Tolerence:", self.ode_tol) 49 | 50 | def forward(self, x_0, c): 51 | # get a random training step $t \sim Uniform({1, ..., T})$ 52 | # t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device) 53 | t = torch.rand(x_0.shape[0], device=x_0.device) * (self.T - self.eps) + self.eps 54 | t_expand = t.view(-1, 1, 1, 1).repeat( 55 | 1, x_0.shape[1], x_0.shape[2], x_0.shape[3] 56 | ) 57 | c = c.to(x_0.device) 58 | 59 | noise = torch.randn_like(x_0) 60 | target = x_0 - noise 61 | perturbed_data = t_expand * x_0 + (1 - t_expand) * noise 62 | 63 | model_out = self.model(perturbed_data, t * 999, c) 64 | 65 | # get the gradient 66 | loss = F.mse_loss(model_out, target, reduction="none").mean([1, 2, 3]).mean() 67 | return loss 68 | 69 | @torch.no_grad() 70 | def euler_sample(self, cond, shape, guidance_scale): 71 | device = self.model.device 72 | batch = torch.randn(shape, device=device) 73 | x = torch.randn_like(batch) 74 | # uniform 75 | dt = 1.0 / self.sample_N 76 | eps = 1e-3 77 | for i in range(self.sample_N): 78 | num_t = i / self.sample_N * (self.T - eps) + eps 79 | t = torch.ones(batch.shape[0], device=device) * num_t 80 | 81 | model_out = self.model(torch.cat([x] * 2), torch.cat([t * 999] * 2), cond) 82 | # perform guidance 83 | noise_pred_uncond, noise_pred_text = model_out.chunk(2) 84 | pred = noise_pred_uncond + guidance_scale * ( 85 | noise_pred_text - noise_pred_uncond 86 | ) 87 | 88 | # pred = self.apply_model(x, t*999, cond) 89 | 90 | sigma_t = self.sigma_t(num_t) 91 | pred_sigma = pred + (sigma_t**2) / ( 92 | 2 * (self.noise_scale**2) * ((1.0 - num_t) ** 2) 93 | ) * ( 94 | 0.5 * num_t * (1.0 - num_t) * pred 95 | - 0.5 * (2.0 - num_t) * x.detach().clone() 96 | ) 97 | 98 | x = ( 99 | x.detach().clone() 100 | + pred_sigma * dt 101 | + sigma_t * np.sqrt(dt) * torch.randn_like(pred_sigma).to(device) 102 | ) 103 | 104 | nfe = self.sample_N 105 | return 
x, nfe 106 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gwh22/LAFMA/0a11ac2d4d176018f0aa0531dd6aca84cc532138/modules/__init__.py -------------------------------------------------------------------------------- /modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gwh22/LAFMA/0a11ac2d4d176018f0aa0531dd6aca84cc532138/modules/distributions/__init__.py -------------------------------------------------------------------------------- /modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import numpy as np 8 | 9 | 10 | class AbstractDistribution: 11 | def sample(self): 12 | raise NotImplementedError() 13 | 14 | def mode(self): 15 | raise NotImplementedError() 16 | 17 | 18 | class DiracDistribution(AbstractDistribution): 19 | def __init__(self, value): 20 | self.value = value 21 | 22 | def sample(self): 23 | return self.value 24 | 25 | def mode(self): 26 | return self.value 27 | 28 | 29 | class DiagonalGaussianDistribution(object): 30 | def __init__(self, parameters, deterministic=False): 31 | self.parameters = parameters 32 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 33 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 34 | self.deterministic = deterministic 35 | self.std = torch.exp(0.5 * self.logvar) 36 | self.var = torch.exp(self.logvar) 37 | if self.deterministic: 38 | self.var = self.std = torch.zeros_like(self.mean).to( 39 | device=self.parameters.device 40 | ) 41 | 42 | def sample(self): 43 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 44 | device=self.parameters.device 45 | ) 46 | return x 47 | 48 | def kl(self, other=None): 49 | if self.deterministic: 50 | return torch.Tensor([0.0]) 51 | else: 52 | if other is None: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 55 | dim=[1, 2, 3], 56 | ) 57 | else: 58 | return 0.5 * torch.sum( 59 | torch.pow(self.mean - other.mean, 2) / other.var 60 | + self.var / other.var 61 | - 1.0 62 | - self.logvar 63 | + other.logvar, 64 | dim=[1, 2, 3], 65 | ) 66 | 67 | def nll(self, sample, dims=[1, 2, 3]): 68 | if self.deterministic: 69 | return torch.Tensor([0.0]) 70 | logtwopi = np.log(2.0 * np.pi) 71 | return 0.5 * torch.sum( 72 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 73 | dim=dims, 74 | ) 75 | 76 | def mode(self): 77 | return self.mean 78 | 79 | 80 | def normal_kl(mean1, logvar1, mean2, logvar2): 81 | """ 82 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 83 | Compute the KL divergence between two gaussians. 84 | Shapes are automatically broadcasted, so batches can be compared to 85 | scalars, among other use cases. 86 | """ 87 | tensor = None 88 | for obj in (mean1, logvar1, mean2, logvar2): 89 | if isinstance(obj, torch.Tensor): 90 | tensor = obj 91 | break 92 | assert tensor is not None, "at least one argument must be a Tensor" 93 | 94 | # Force variances to be Tensors. 
Broadcasting helps convert scalars to 95 | # Tensors, but it does not work for torch.exp(). 96 | logvar1, logvar2 = [ 97 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 98 | for x in (logvar1, logvar2) 99 | ] 100 | 101 | return 0.5 * ( 102 | -1.0 103 | + logvar2 104 | - logvar1 105 | + torch.exp(logvar1 - logvar2) 106 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 107 | ) 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | json5==0.9.25 3 | lightning==2.2.5 4 | matplotlib==3.9.0 5 | numpy==1.26.4 6 | omegaconf==2.3.0 7 | ruamel.base==1.0.0 8 | ruamel_yaml==0.18.6 9 | scikit_learn==1.5.0 10 | scipy==1.13.1 11 | six==1.16.0 12 | torch==2.0.1 13 | torchaudio==2.0.2 14 | tqdm==4.66.1 15 | transformers==4.41.2 16 | taming-transformers-rom1504==0.0.6 -------------------------------------------------------------------------------- /utils/HyperParams/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .hps import HyperParams 7 | -------------------------------------------------------------------------------- /utils/HyperParams/hps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | class HyperParams: 8 | """The class to store hyperparameters. The key is case-insensitive. 9 | 10 | Args: 11 | *args: a list of dict or HyperParams. 12 | **kwargs: a list of key-value pairs. 13 | """ 14 | 15 | def __init__(self, **kwargs): 16 | for k, v in kwargs.items(): 17 | if type(v) == dict: 18 | v = HyperParams(**v) 19 | self[k] = v 20 | 21 | def keys(self): 22 | return self.__dict__.keys() 23 | 24 | def items(self): 25 | return self.__dict__.items() 26 | 27 | def values(self): 28 | return self.__dict__.values() 29 | 30 | def __len__(self): 31 | return len(self.__dict__) 32 | 33 | def __getitem__(self, key): 34 | return getattr(self, key) 35 | 36 | def __setitem__(self, key, value): 37 | return setattr(self, key, value) 38 | 39 | def __contains__(self, key): 40 | return key in self.__dict__ 41 | 42 | def __repr__(self): 43 | return self.__dict__.__repr__() 44 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gwh22/LAFMA/0a11ac2d4d176018f0aa0531dd6aca84cc532138/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
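# NOTE (illustrative sketch, not part of the original file): the configuration
# container defined in utils/HyperParams/hps.py above wraps nested dicts
# recursively and exposes them both as attributes and as dict-style keys. The
# helper below only demonstrates those access patterns; the field names are
# made up for illustration and are not used anywhere in the repository.


def _hyperparams_usage_sketch():
    from utils.HyperParams import HyperParams

    hps = HyperParams(train={"batch_size": 8, "devices": [0]}, num_steps=25)
    assert hps.train.batch_size == 8   # nested dicts become nested HyperParams
    assert hps["num_steps"] == 25      # dict-style access via __getitem__
    assert "train" in hps and len(hps) == 2
    return hps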
5 | 6 | import json 7 | import os 8 | 9 | import numpy as np 10 | from scipy.interpolate import interp1d 11 | from tqdm import tqdm 12 | from sklearn.preprocessing import StandardScaler 13 | 14 | 15 | def intersperse(lst, item): 16 | """ 17 | Insert an item in between any two consecutive elements of the given list, including beginning and end of list 18 | 19 | Example: 20 | >>> intersperse(0, [1, 74, 5, 31]) 21 | [0, 1, 0, 74, 0, 5, 0, 31, 0] 22 | """ 23 | result = [item] * (len(lst) * 2 + 1) 24 | result[1::2] = lst 25 | return result 26 | 27 | 28 | def load_content_feature_path(meta_data, processed_dir, feat_dir): 29 | utt2feat_path = {} 30 | for utt_info in meta_data: 31 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 32 | feat_path = os.path.join( 33 | processed_dir, utt_info["Dataset"], feat_dir, f'{utt_info["Uid"]}.npy' 34 | ) 35 | utt2feat_path[utt] = feat_path 36 | 37 | return utt2feat_path 38 | 39 | 40 | def load_source_content_feature_path(meta_data, feat_dir): 41 | utt2feat_path = {} 42 | for utt in meta_data: 43 | feat_path = os.path.join(feat_dir, f"{utt}.npy") 44 | utt2feat_path[utt] = feat_path 45 | 46 | return utt2feat_path 47 | 48 | 49 | def get_spk_map(spk2id_path, utt2spk_path): 50 | utt2spk = {} 51 | with open(spk2id_path, "r") as spk2id_file: 52 | spk2id = json.load(spk2id_file) 53 | with open(utt2spk_path, encoding="utf-8") as f: 54 | for line in f.readlines(): 55 | utt, spk = line.strip().split("\t") 56 | utt2spk[utt] = spk 57 | return spk2id, utt2spk 58 | 59 | 60 | def get_target_f0_median(f0_dir): 61 | total_f0 = [] 62 | for utt in os.listdir(f0_dir): 63 | if not utt.endswith(".npy"): 64 | continue 65 | f0_feat_path = os.path.join(f0_dir, utt) 66 | f0 = np.load(f0_feat_path) 67 | total_f0 += f0.tolist() 68 | 69 | total_f0 = np.array(total_f0) 70 | voiced_position = np.where(total_f0 != 0) 71 | return np.median(total_f0[voiced_position]) 72 | 73 | 74 | def get_conversion_f0_factor(source_f0, target_median, source_median=None): 75 | """Align the median between source f0 and target f0 76 | 77 | Note: Here we use multiplication, whose factor is target_median/source_median 78 | 79 | Reference: Frequency and pitch interval 80 | http://blog.ccyg.studio/article/be12c2ee-d47c-4098-9782-ca76da3035e4/ 81 | """ 82 | if source_median is None: 83 | voiced_position = np.where(source_f0 != 0) 84 | source_median = np.median(source_f0[voiced_position]) 85 | factor = target_median / source_median 86 | return source_median, factor 87 | 88 | 89 | def transpose_key(frame_pitch, trans_key): 90 | # Transpose by user's argument 91 | print("Transpose key = {} ...\n".format(trans_key)) 92 | 93 | transed_pitch = frame_pitch * 2 ** (trans_key / 12) 94 | return transed_pitch 95 | 96 | 97 | def pitch_shift_to_target(frame_pitch, target_pitch_median, source_pitch_median=None): 98 | # Loading F0 Base (median) and shift 99 | source_pitch_median, factor = get_conversion_f0_factor( 100 | frame_pitch, target_pitch_median, source_pitch_median 101 | ) 102 | print( 103 | "Auto transposing: source f0 median = {:.1f}, target f0 median = {:.1f}, factor = {:.2f}".format( 104 | source_pitch_median, target_pitch_median, factor 105 | ) 106 | ) 107 | transed_pitch = frame_pitch * factor 108 | return transed_pitch 109 | 110 | 111 | def load_frame_pitch( 112 | meta_data, 113 | processed_dir, 114 | pitch_dir, 115 | use_log_scale=False, 116 | return_norm=False, 117 | interoperate=False, 118 | utt2spk=None, 119 | ): 120 | utt2pitch = {} 121 | utt2uv = {} 122 | if utt2spk is None: 123 | pitch_scaler = 
StandardScaler() 124 | for utt_info in meta_data: 125 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 126 | pitch_path = os.path.join( 127 | processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy' 128 | ) 129 | pitch = np.load(pitch_path) 130 | assert len(pitch) > 0 131 | uv = pitch != 0 132 | utt2uv[utt] = uv 133 | if use_log_scale: 134 | nonzero_idxes = np.where(pitch != 0)[0] 135 | pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes]) 136 | utt2pitch[utt] = pitch 137 | pitch_scaler.partial_fit(pitch.reshape(-1, 1)) 138 | 139 | mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] 140 | if return_norm: 141 | for utt_info in meta_data: 142 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 143 | pitch = utt2pitch[utt] 144 | normalized_pitch = (pitch - mean) / std 145 | utt2pitch[utt] = normalized_pitch 146 | pitch_statistic = {"mean": mean, "std": std} 147 | else: 148 | spk2utt = {} 149 | pitch_statistic = [] 150 | for utt_info in meta_data: 151 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 152 | if not utt2spk[utt] in spk2utt: 153 | spk2utt[utt2spk[utt]] = [] 154 | spk2utt[utt2spk[utt]].append(utt) 155 | 156 | for spk in spk2utt: 157 | pitch_scaler = StandardScaler() 158 | for utt in spk2utt[spk]: 159 | dataset = utt.split("_")[0] 160 | uid = "_".join(utt.split("_")[1:]) 161 | pitch_path = os.path.join( 162 | processed_dir, dataset, pitch_dir, f"{uid}.npy" 163 | ) 164 | pitch = np.load(pitch_path) 165 | assert len(pitch) > 0 166 | uv = pitch != 0 167 | utt2uv[utt] = uv 168 | if use_log_scale: 169 | nonzero_idxes = np.where(pitch != 0)[0] 170 | pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes]) 171 | utt2pitch[utt] = pitch 172 | pitch_scaler.partial_fit(pitch.reshape(-1, 1)) 173 | 174 | mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] 175 | if return_norm: 176 | for utt in spk2utt[spk]: 177 | pitch = utt2pitch[utt] 178 | normalized_pitch = (pitch - mean) / std 179 | utt2pitch[utt] = normalized_pitch 180 | pitch_statistic.append({"spk": spk, "mean": mean, "std": std}) 181 | 182 | return utt2pitch, utt2uv, pitch_statistic 183 | 184 | 185 | # discard 186 | def load_phone_pitch( 187 | meta_data, 188 | processed_dir, 189 | pitch_dir, 190 | utt2dur, 191 | use_log_scale=False, 192 | return_norm=False, 193 | interoperate=True, 194 | utt2spk=None, 195 | ): 196 | print("Load Phone Pitch") 197 | utt2pitch = {} 198 | utt2uv = {} 199 | if utt2spk is None: 200 | pitch_scaler = StandardScaler() 201 | for utt_info in tqdm(meta_data): 202 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 203 | pitch_path = os.path.join( 204 | processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy' 205 | ) 206 | frame_pitch = np.load(pitch_path) 207 | assert len(frame_pitch) > 0 208 | uv = frame_pitch != 0 209 | utt2uv[utt] = uv 210 | phone_pitch = phone_average_pitch(frame_pitch, utt2dur[utt], interoperate) 211 | if use_log_scale: 212 | nonzero_idxes = np.where(phone_pitch != 0)[0] 213 | phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes]) 214 | utt2pitch[utt] = phone_pitch 215 | pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1)) 216 | 217 | mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] 218 | max_value = np.finfo(np.float64).min 219 | min_value = np.finfo(np.float64).max 220 | if return_norm: 221 | for utt_info in meta_data: 222 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 223 | pitch = utt2pitch[utt] 224 | normalized_pitch = (pitch - mean) / std 225 | max_value = max(max_value, max(normalized_pitch)) 226 | min_value = 
min(min_value, min(normalized_pitch)) 227 | utt2pitch[utt] = normalized_pitch 228 | phone_normalized_pitch_path = os.path.join( 229 | processed_dir, 230 | utt_info["Dataset"], 231 | "phone_level_" + pitch_dir, 232 | f'{utt_info["Uid"]}.npy', 233 | ) 234 | pitch_statistic = { 235 | "mean": mean, 236 | "std": std, 237 | "min_value": min_value, 238 | "max_value": max_value, 239 | } 240 | else: 241 | spk2utt = {} 242 | pitch_statistic = [] 243 | for utt_info in tqdm(meta_data): 244 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 245 | if not utt2spk[utt] in spk2utt: 246 | spk2utt[utt2spk[utt]] = [] 247 | spk2utt[utt2spk[utt]].append(utt) 248 | 249 | for spk in spk2utt: 250 | pitch_scaler = StandardScaler() 251 | for utt in spk2utt[spk]: 252 | dataset = utt.split("_")[0] 253 | uid = "_".join(utt.split("_")[1:]) 254 | pitch_path = os.path.join( 255 | processed_dir, dataset, pitch_dir, f"{uid}.npy" 256 | ) 257 | frame_pitch = np.load(pitch_path) 258 | assert len(frame_pitch) > 0 259 | uv = frame_pitch != 0 260 | utt2uv[utt] = uv 261 | phone_pitch = phone_average_pitch( 262 | frame_pitch, utt2dur[utt], interoperate 263 | ) 264 | if use_log_scale: 265 | nonzero_idxes = np.where(phone_pitch != 0)[0] 266 | phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes]) 267 | utt2pitch[utt] = phone_pitch 268 | pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1)) 269 | 270 | mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] 271 | max_value = np.finfo(np.float64).min 272 | min_value = np.finfo(np.float64).max 273 | 274 | if return_norm: 275 | for utt in spk2utt[spk]: 276 | pitch = utt2pitch[utt] 277 | normalized_pitch = (pitch - mean) / std 278 | max_value = max(max_value, max(normalized_pitch)) 279 | min_value = min(min_value, min(normalized_pitch)) 280 | utt2pitch[utt] = normalized_pitch 281 | pitch_statistic.append( 282 | { 283 | "spk": spk, 284 | "mean": mean, 285 | "std": std, 286 | "min_value": min_value, 287 | "max_value": max_value, 288 | } 289 | ) 290 | 291 | return utt2pitch, utt2uv, pitch_statistic 292 | 293 | 294 | def phone_average_pitch(pitch, dur, interoperate=False): 295 | pos = 0 296 | 297 | if interoperate: 298 | nonzero_ids = np.where(pitch != 0)[0] 299 | interp_fn = interp1d( 300 | nonzero_ids, 301 | pitch[nonzero_ids], 302 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), 303 | bounds_error=False, 304 | ) 305 | pitch = interp_fn(np.arange(0, len(pitch))) 306 | phone_pitch = np.zeros(len(dur)) 307 | 308 | for i, d in enumerate(dur): 309 | d = int(d) 310 | if d > 0 and pos < len(pitch): 311 | phone_pitch[i] = np.mean(pitch[pos : pos + d]) 312 | else: 313 | phone_pitch[i] = 0 314 | pos += d 315 | return phone_pitch 316 | 317 | 318 | def load_energy( 319 | meta_data, 320 | processed_dir, 321 | energy_dir, 322 | use_log_scale=False, 323 | return_norm=False, 324 | utt2spk=None, 325 | ): 326 | utt2energy = {} 327 | if utt2spk is None: 328 | for utt_info in meta_data: 329 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 330 | energy_path = os.path.join( 331 | processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy' 332 | ) 333 | if not os.path.exists(energy_path): 334 | continue 335 | energy = np.load(energy_path) 336 | assert len(energy) > 0 337 | 338 | if use_log_scale: 339 | nonzero_idxes = np.where(energy != 0)[0] 340 | energy[nonzero_idxes] = np.log(energy[nonzero_idxes]) 341 | utt2energy[utt] = energy 342 | 343 | if return_norm: 344 | with open( 345 | os.path.join( 346 | processed_dir, utt_info["Dataset"], energy_dir, 
"statistics.json" 347 | ) 348 | ) as f: 349 | stats = json.load(f) 350 | mean, std = ( 351 | stats[utt_info["Dataset"] + "_" + utt_info["Singer"]][ 352 | "voiced_positions" 353 | ]["mean"], 354 | stats["LJSpeech_LJSpeech"]["voiced_positions"]["std"], 355 | ) 356 | for utt in utt2energy.keys(): 357 | energy = utt2energy[utt] 358 | normalized_energy = (energy - mean) / std 359 | utt2energy[utt] = normalized_energy 360 | 361 | energy_statistic = {"mean": mean, "std": std} 362 | else: 363 | spk2utt = {} 364 | energy_statistic = [] 365 | for utt_info in meta_data: 366 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 367 | if not utt2spk[utt] in spk2utt: 368 | spk2utt[utt2spk[utt]] = [] 369 | spk2utt[utt2spk[utt]].append(utt) 370 | 371 | for spk in spk2utt: 372 | energy_scaler = StandardScaler() 373 | for utt in spk2utt[spk]: 374 | dataset = utt.split("_")[0] 375 | uid = "_".join(utt.split("_")[1:]) 376 | energy_path = os.path.join( 377 | processed_dir, dataset, energy_dir, f"{uid}.npy" 378 | ) 379 | if not os.path.exists(energy_path): 380 | continue 381 | frame_energy = np.load(energy_path) 382 | assert len(frame_energy) > 0 383 | 384 | if use_log_scale: 385 | nonzero_idxes = np.where(frame_energy != 0)[0] 386 | frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) 387 | utt2energy[utt] = frame_energy 388 | energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) 389 | 390 | mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] 391 | if return_norm: 392 | for utt in spk2utt[spk]: 393 | energy = utt2energy[utt] 394 | normalized_energy = (energy - mean) / std 395 | utt2energy[utt] = normalized_energy 396 | energy_statistic.append({"spk": spk, "mean": mean, "std": std}) 397 | 398 | return utt2energy, energy_statistic 399 | 400 | 401 | def load_frame_energy( 402 | meta_data, 403 | processed_dir, 404 | energy_dir, 405 | use_log_scale=False, 406 | return_norm=False, 407 | interoperate=False, 408 | utt2spk=None, 409 | ): 410 | utt2energy = {} 411 | if utt2spk is None: 412 | energy_scaler = StandardScaler() 413 | for utt_info in meta_data: 414 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 415 | energy_path = os.path.join( 416 | processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy' 417 | ) 418 | frame_energy = np.load(energy_path) 419 | assert len(frame_energy) > 0 420 | 421 | if use_log_scale: 422 | nonzero_idxes = np.where(frame_energy != 0)[0] 423 | frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) 424 | utt2energy[utt] = frame_energy 425 | energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) 426 | 427 | mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] 428 | if return_norm: 429 | for utt_info in meta_data: 430 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 431 | energy = utt2energy[utt] 432 | normalized_energy = (energy - mean) / std 433 | utt2energy[utt] = normalized_energy 434 | energy_statistic = {"mean": mean, "std": std} 435 | 436 | else: 437 | spk2utt = {} 438 | energy_statistic = [] 439 | for utt_info in meta_data: 440 | utt = utt_info["Dataset"] + "_" + utt_info["Uid"] 441 | if not utt2spk[utt] in spk2utt: 442 | spk2utt[utt2spk[utt]] = [] 443 | spk2utt[utt2spk[utt]].append(utt) 444 | 445 | for spk in spk2utt: 446 | energy_scaler = StandardScaler() 447 | for utt in spk2utt[spk]: 448 | dataset = utt.split("_")[0] 449 | uid = "_".join(utt.split("_")[1:]) 450 | energy_path = os.path.join( 451 | processed_dir, dataset, energy_dir, f"{uid}.npy" 452 | ) 453 | frame_energy = np.load(energy_path) 454 | assert 
len(frame_energy) > 0 455 | 456 | if use_log_scale: 457 | nonzero_idxes = np.where(frame_energy != 0)[0] 458 | frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) 459 | utt2energy[utt] = frame_energy 460 | energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) 461 | 462 | mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] 463 | if return_norm: 464 | for utt in spk2utt[spk]: 465 | energy = utt2energy[utt] 466 | normalized_energy = (energy - mean) / std 467 | utt2energy[utt] = normalized_energy 468 | energy_statistic.append({"spk": spk, "mean": mean, "std": std}) 469 | 470 | return utt2energy, energy_statistic 471 | 472 | 473 | def align_length(feature, target_len, pad_value=0.0): 474 | feature_len = feature.shape[-1] 475 | dim = len(feature.shape) 476 | # align 1-D data 477 | if dim == 2: 478 | if target_len > feature_len: 479 | feature = np.pad( 480 | feature, 481 | ((0, 0), (0, target_len - feature_len)), 482 | constant_values=pad_value, 483 | ) 484 | else: 485 | feature = feature[:, :target_len] 486 | # align 2-D data 487 | elif dim == 1: 488 | if target_len > feature_len: 489 | feature = np.pad( 490 | feature, (0, target_len - feature_len), constant_values=pad_value 491 | ) 492 | else: 493 | feature = feature[:target_len] 494 | else: 495 | raise NotImplementedError 496 | return feature 497 | 498 | 499 | def align_whisper_feauture_length( 500 | feature, target_len, fast_mapping=True, source_hop=320, target_hop=256 501 | ): 502 | factor = np.gcd(source_hop, target_hop) 503 | source_hop //= factor 504 | target_hop //= factor 505 | # print( 506 | # "Mapping source's {} frames => target's {} frames".format( 507 | # target_hop, source_hop 508 | # ) 509 | # ) 510 | 511 | max_source_len = 1500 512 | target_len = min(target_len, max_source_len * source_hop // target_hop) 513 | 514 | width = feature.shape[-1] 515 | 516 | if fast_mapping: 517 | source_len = target_len * target_hop // source_hop + 1 518 | feature = feature[:source_len] 519 | 520 | else: 521 | source_len = max_source_len 522 | 523 | # const ~= target_len * target_hop 524 | const = source_len * source_hop // target_hop * target_hop 525 | 526 | # (source_len * source_hop, dim) 527 | up_sampling_feats = np.repeat(feature, source_hop, axis=0) 528 | # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) 529 | down_sampling_feats = np.average( 530 | up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 531 | ) 532 | assert len(down_sampling_feats) >= target_len 533 | 534 | # (target_len, dim) 535 | feat = down_sampling_feats[:target_len] 536 | 537 | return feat 538 | 539 | 540 | def align_content_feature_length(feature, target_len, source_hop=320, target_hop=256): 541 | factor = np.gcd(source_hop, target_hop) 542 | source_hop //= factor 543 | target_hop //= factor 544 | # print( 545 | # "Mapping source's {} frames => target's {} frames".format( 546 | # target_hop, source_hop 547 | # ) 548 | # ) 549 | 550 | # (source_len, 256) 551 | source_len, width = feature.shape 552 | 553 | # const ~= target_len * target_hop 554 | const = source_len * source_hop // target_hop * target_hop 555 | 556 | # (source_len * source_hop, dim) 557 | up_sampling_feats = np.repeat(feature, source_hop, axis=0) 558 | # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) 559 | down_sampling_feats = np.average( 560 | up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 561 | ) 562 | 563 | err = abs(target_len - len(down_sampling_feats)) 564 | if err > 4: ## why 4 not 3? 
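# A mismatch of a few frames is expected from integer rounding of the reduced
# hop ratio (source_hop // target_hop) above; beyond this tolerance the lengths
# are treated as genuinely misaligned and the diagnostics below abort the run.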
565 | print("target_len:", target_len) 566 | print("raw feature:", feature.shape) 567 | print("up_sampling:", up_sampling_feats.shape) 568 | print("down_sampling_feats:", down_sampling_feats.shape) 569 | exit() 570 | if len(down_sampling_feats) < target_len: 571 | # (1, dim) -> (err, dim) 572 | end = down_sampling_feats[-1][None, :].repeat(err, axis=0) 573 | down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) 574 | 575 | # (target_len, dim) 576 | feat = down_sampling_feats[:target_len] 577 | 578 | return feat 579 | 580 | 581 | def remove_outlier(values): 582 | values = np.array(values) 583 | p25 = np.percentile(values, 25) 584 | p75 = np.percentile(values, 75) 585 | lower = p25 - 1.5 * (p75 - p25) 586 | upper = p75 + 1.5 * (p75 - p25) 587 | normal_indices = np.logical_and(values > lower, values < upper) 588 | return values[normal_indices] 589 | -------------------------------------------------------------------------------- /utils/hparam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long 7 | """Hyperparameter values.""" 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import json 13 | import numbers 14 | import re 15 | import six 16 | 17 | # Define the regular expression for parsing a single clause of the input 18 | # (delimited by commas). A legal clause looks like: 19 | # []? = 20 | # where is either a single token or [] enclosed list of tokens. 21 | # For example: "var[1] = a" or "x = [1,2,3]" 22 | PARAM_RE = re.compile( 23 | r""" 24 | (?P[a-zA-Z][\w\.]*) # variable name: "var" or "x" 25 | (\[\s*(?P\d+)\s*\])? # (optional) index: "1" or None 26 | \s*=\s* 27 | ((?P[^,\[]*) # single value: "a" or None 28 | | 29 | \[(?P[^\]]*)\]) # list of values: None or "1,2,3" 30 | ($|,\s*)""", 31 | re.VERBOSE, 32 | ) 33 | 34 | 35 | def _parse_fail(name, var_type, value, values): 36 | """Helper function for raising a value error for bad assignment.""" 37 | raise ValueError( 38 | "Could not parse hparam '%s' of type '%s' with value '%s' in %s" 39 | % (name, var_type.__name__, value, values) 40 | ) 41 | 42 | 43 | def _reuse_fail(name, values): 44 | """Helper function for raising a value error for reuse of name.""" 45 | raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values)) 46 | 47 | 48 | def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary): 49 | """Update results_dictionary with a scalar value. 50 | 51 | Used to update the results_dictionary to be returned by parse_values when 52 | encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) 53 | 54 | Mutates results_dictionary. 55 | 56 | Args: 57 | name: Name of variable in assignment ("s" or "arr"). 58 | parse_fn: Function for parsing the actual value. 59 | var_type: Type of named variable. 60 | m_dict: Dictionary constructed from regex parsing. 61 | m_dict['val']: RHS value (scalar) 62 | m_dict['index']: List index value (or None) 63 | values: Full expression being parsed 64 | results_dictionary: The dictionary being updated for return by the parsing 65 | function. 
66 | 67 | Raises: 68 | ValueError: If the name has already been used. 69 | """ 70 | try: 71 | parsed_value = parse_fn(m_dict["val"]) 72 | except ValueError: 73 | _parse_fail(name, var_type, m_dict["val"], values) 74 | 75 | # If no index is provided 76 | if not m_dict["index"]: 77 | if name in results_dictionary: 78 | _reuse_fail(name, values) 79 | results_dictionary[name] = parsed_value 80 | else: 81 | if name in results_dictionary: 82 | # The name has already been used as a scalar, then it 83 | # will be in this dictionary and map to a non-dictionary. 84 | if not isinstance(results_dictionary.get(name), dict): 85 | _reuse_fail(name, values) 86 | else: 87 | results_dictionary[name] = {} 88 | 89 | index = int(m_dict["index"]) 90 | # Make sure the index position hasn't already been assigned a value. 91 | if index in results_dictionary[name]: 92 | _reuse_fail("{}[{}]".format(name, index), values) 93 | results_dictionary[name][index] = parsed_value 94 | 95 | 96 | def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary): 97 | """Update results_dictionary from a list of values. 98 | 99 | Used to update results_dictionary to be returned by parse_values when 100 | encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) 101 | 102 | Mutates results_dictionary. 103 | 104 | Args: 105 | name: Name of variable in assignment ("arr"). 106 | parse_fn: Function for parsing individual values. 107 | var_type: Type of named variable. 108 | m_dict: Dictionary constructed from regex parsing. 109 | m_dict['val']: RHS value (scalar) 110 | values: Full expression being parsed 111 | results_dictionary: The dictionary being updated for return by the parsing 112 | function. 113 | 114 | Raises: 115 | ValueError: If the name has an index or the values cannot be parsed. 116 | """ 117 | if m_dict["index"] is not None: 118 | raise ValueError("Assignment of a list to a list index.") 119 | elements = filter(None, re.split("[ ,]", m_dict["vals"])) 120 | # Make sure the name hasn't already been assigned a value 121 | if name in results_dictionary: 122 | raise _reuse_fail(name, values) 123 | try: 124 | results_dictionary[name] = [parse_fn(e) for e in elements] 125 | except ValueError: 126 | _parse_fail(name, var_type, m_dict["vals"], values) 127 | 128 | 129 | def _cast_to_type_if_compatible(name, param_type, value): 130 | """Cast hparam to the provided type, if compatible. 131 | 132 | Args: 133 | name: Name of the hparam to be cast. 134 | param_type: The type of the hparam. 135 | value: The value to be cast, if compatible. 136 | 137 | Returns: 138 | The result of casting `value` to `param_type`. 139 | 140 | Raises: 141 | ValueError: If the type of `value` is not compatible with param_type. 142 | * If `param_type` is a string type, but `value` is not. 143 | * If `param_type` is a boolean, but `value` is not, or vice versa. 144 | * If `param_type` is an integer type, but `value` is not. 145 | * If `param_type` is a float type, but `value` is not a numeric type. 146 | """ 147 | fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % ( 148 | name, 149 | param_type, 150 | value, 151 | ) 152 | 153 | # Some callers use None, for which we can't do any casting/checking. :( 154 | if issubclass(param_type, type(None)): 155 | return value 156 | 157 | # Avoid converting a non-string type to a string. 
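# e.g. an hparam declared as str must not silently turn the int 3 into "3";
# an incompatible value raises ValueError instead of being coerced.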
158 | if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance( 159 | value, (six.string_types, six.binary_type) 160 | ): 161 | raise ValueError(fail_msg) 162 | 163 | # Avoid converting a number or string type to a boolean or vice versa. 164 | if issubclass(param_type, bool) != isinstance(value, bool): 165 | raise ValueError(fail_msg) 166 | 167 | # Avoid converting float to an integer (the reverse is fine). 168 | if issubclass(param_type, numbers.Integral) and not isinstance( 169 | value, numbers.Integral 170 | ): 171 | raise ValueError(fail_msg) 172 | 173 | # Avoid converting a non-numeric type to a numeric type. 174 | if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number): 175 | raise ValueError(fail_msg) 176 | 177 | return param_type(value) 178 | 179 | 180 | def parse_values(values, type_map, ignore_unknown=False): 181 | """Parses hyperparameter values from a string into a python map. 182 | 183 | `values` is a string containing comma-separated `name=value` pairs. 184 | For each pair, the value of the hyperparameter named `name` is set to 185 | `value`. 186 | 187 | If a hyperparameter name appears multiple times in `values`, a ValueError 188 | is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). 189 | 190 | If a hyperparameter name in both an index assignment and scalar assignment, 191 | a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). 192 | 193 | The hyperparameter name may contain '.' symbols, which will result in an 194 | attribute name that is only accessible through the getattr and setattr 195 | functions. (And must be first explicit added through add_hparam.) 196 | 197 | WARNING: Use of '.' in your variable names is allowed, but is not well 198 | supported and not recommended. 199 | 200 | The `value` in `name=value` must follows the syntax according to the 201 | type of the parameter: 202 | 203 | * Scalar integer: A Python-parsable integer point value. E.g.: 1, 204 | 100, -12. 205 | * Scalar float: A Python-parsable floating point value. E.g.: 1.0, 206 | -.54e89. 207 | * Boolean: Either true or false. 208 | * Scalar string: A non-empty sequence of characters, excluding comma, 209 | spaces, and square brackets. E.g.: foo, bar_1. 210 | * List: A comma separated list of scalar values of the parameter type 211 | enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. 212 | 213 | When index assignment is used, the corresponding type_map key should be the 214 | list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not 215 | "arr[1]"). 216 | 217 | Args: 218 | values: String. Comma separated list of `name=value` pairs where 219 | 'value' must follow the syntax described above. 220 | type_map: A dictionary mapping hyperparameter names to types. Note every 221 | parameter name in values must be a key in type_map. The values must 222 | conform to the types indicated, where a value V is said to conform to a 223 | type T if either V has type T, or V is a list of elements of type T. 224 | Hence, for a multidimensional parameter 'x' taking float values, 225 | 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. 226 | ignore_unknown: Bool. Whether values that are missing a type in type_map 227 | should be ignored. If set to True, a ValueError will not be raised for 228 | unknown hyperparameter type. 229 | 230 | Returns: 231 | A python map mapping each name to either: 232 | * A scalar value. 233 | * A list of scalar values. 234 | * A dictionary mapping index numbers to scalar values. 235 | (e.g. 
"x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") 236 | 237 | Raises: 238 | ValueError: If there is a problem with input. 239 | * If `values` cannot be parsed. 240 | * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). 241 | * If the same rvalue is assigned two different values (e.g. 'a=1,a=2', 242 | 'a[1]=1,a[1]=2', or 'a=1,a=[1]') 243 | """ 244 | results_dictionary = {} 245 | pos = 0 246 | while pos < len(values): 247 | m = PARAM_RE.match(values, pos) 248 | if not m: 249 | raise ValueError("Malformed hyperparameter value: %s" % values[pos:]) 250 | # Check that there is a comma between parameters and move past it. 251 | pos = m.end() 252 | # Parse the values. 253 | m_dict = m.groupdict() 254 | name = m_dict["name"] 255 | if name not in type_map: 256 | if ignore_unknown: 257 | continue 258 | raise ValueError("Unknown hyperparameter type for %s" % name) 259 | type_ = type_map[name] 260 | 261 | # Set up correct parsing function (depending on whether type_ is a bool) 262 | if type_ == bool: 263 | 264 | def parse_bool(value): 265 | if value in ["true", "True"]: 266 | return True 267 | elif value in ["false", "False"]: 268 | return False 269 | else: 270 | try: 271 | return bool(int(value)) 272 | except ValueError: 273 | _parse_fail(name, type_, value, values) 274 | 275 | parse = parse_bool 276 | else: 277 | parse = type_ 278 | 279 | # If a singe value is provided 280 | if m_dict["val"] is not None: 281 | _process_scalar_value( 282 | name, parse, type_, m_dict, values, results_dictionary 283 | ) 284 | 285 | # If the assigned value is a list: 286 | elif m_dict["vals"] is not None: 287 | _process_list_value(name, parse, type_, m_dict, values, results_dictionary) 288 | 289 | else: # Not assigned a list or value 290 | _parse_fail(name, type_, "", values) 291 | 292 | return results_dictionary 293 | 294 | 295 | class HParams(object): 296 | """Class to hold a set of hyperparameters as name-value pairs. 297 | 298 | A `HParams` object holds hyperparameters used to build and train a model, 299 | such as the number of hidden units in a neural net layer or the learning rate 300 | to use when training. 301 | 302 | You first create a `HParams` object by specifying the names and values of the 303 | hyperparameters. 304 | 305 | To make them easily accessible the parameter names are added as direct 306 | attributes of the class. A typical usage is as follows: 307 | 308 | ```python 309 | # Create a HParams object specifying names and values of the model 310 | # hyperparameters: 311 | hparams = HParams(learning_rate=0.1, num_hidden_units=100) 312 | 313 | # The hyperparameter are available as attributes of the HParams object: 314 | hparams.learning_rate ==> 0.1 315 | hparams.num_hidden_units ==> 100 316 | ``` 317 | 318 | Hyperparameters have type, which is inferred from the type of their value 319 | passed at construction type. The currently supported types are: integer, 320 | float, boolean, string, and list of integer, float, boolean, or string. 321 | 322 | You can override hyperparameter values by calling the 323 | [`parse()`](#HParams.parse) method, passing a string of comma separated 324 | `name=value` pairs. This is intended to make it possible to override 325 | any hyperparameter values from a single command-line flag to which 326 | the user passes 'hyper-param=value' pairs. It avoids having to define 327 | one flag for each hyperparameter. 328 | 329 | The syntax expected for each value depends on the type of the parameter. 330 | See `parse()` for a description of the syntax. 
331 | 332 | Example: 333 | 334 | ```python 335 | # Define a command line flag to pass name=value pairs. 336 | # For example using argparse: 337 | import argparse 338 | parser = argparse.ArgumentParser(description='Train my model.') 339 | parser.add_argument('--hparams', type=str, 340 | help='Comma separated list of "name=value" pairs.') 341 | args = parser.parse_args() 342 | ... 343 | def my_program(): 344 | # Create a HParams object specifying the names and values of the 345 | # model hyperparameters: 346 | hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, 347 | activations=['relu', 'tanh']) 348 | 349 | # Override hyperparameters values by parsing the command line 350 | hparams.parse(args.hparams) 351 | 352 | # If the user passed `--hparams=learning_rate=0.3` on the command line 353 | # then 'hparams' has the following attributes: 354 | hparams.learning_rate ==> 0.3 355 | hparams.num_hidden_units ==> 100 356 | hparams.activations ==> ['relu', 'tanh'] 357 | 358 | # If the hyperparameters are in json format use parse_json: 359 | hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') 360 | ``` 361 | """ 362 | 363 | _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. 364 | 365 | def __init__(self, model_structure=None, **kwargs): 366 | """Create an instance of `HParams` from keyword arguments. 367 | 368 | The keyword arguments specify name-values pairs for the hyperparameters. 369 | The parameter types are inferred from the type of the values passed. 370 | 371 | The parameter names are added as attributes of `HParams` object, so they 372 | can be accessed directly with the dot notation `hparams._name_`. 373 | 374 | Example: 375 | 376 | ```python 377 | # Define 3 hyperparameters: 'learning_rate' is a float parameter, 378 | # 'num_hidden_units' an integer parameter, and 'activation' a string 379 | # parameter. 380 | hparams = tf.HParams( 381 | learning_rate=0.1, num_hidden_units=100, activation='relu') 382 | 383 | hparams.activation ==> 'relu' 384 | ``` 385 | 386 | Note that a few names are reserved and cannot be used as hyperparameter 387 | names. If you use one of the reserved name the constructor raises a 388 | `ValueError`. 389 | 390 | Args: 391 | model_structure: An instance of ModelStructure, defining the feature 392 | crosses to be used in the Trial. 393 | **kwargs: Key-value pairs where the key is the hyperparameter name and 394 | the value is the value for the parameter. 395 | 396 | Raises: 397 | ValueError: If both `hparam_def` and initialization values are provided, 398 | or if one of the arguments is invalid. 399 | 400 | """ 401 | # Register the hyperparameters and their type in _hparam_types. 402 | # This simplifies the implementation of parse(). 403 | # _hparam_types maps the parameter name to a tuple (type, bool). 404 | # The type value is the type of the parameter for scalar hyperparameters, 405 | # or the type of the list elements for multidimensional hyperparameters. 406 | # The bool value is True if the value is a list, False otherwise. 407 | self._hparam_types = {} 408 | self._model_structure = model_structure 409 | for name, value in six.iteritems(kwargs): 410 | self.add_hparam(name, value) 411 | 412 | def add_hparam(self, name, value): 413 | """Adds {name, value} pair to hyperparameters. 414 | 415 | Args: 416 | name: Name of the hyperparameter. 417 | value: Value of the hyperparameter. Can be one of the following types: 418 | int, float, string, int list, float list, or string list. 
419 | 420 | Raises: 421 | ValueError: if one of the arguments is invalid. 422 | """ 423 | # Keys in kwargs are unique, but 'name' could the name of a pre-existing 424 | # attribute of this object. In that case we refuse to use it as a 425 | # hyperparameter name. 426 | if getattr(self, name, None) is not None: 427 | raise ValueError("Hyperparameter name is reserved: %s" % name) 428 | if isinstance(value, (list, tuple)): 429 | if not value: 430 | raise ValueError( 431 | "Multi-valued hyperparameters cannot be empty: %s" % name 432 | ) 433 | self._hparam_types[name] = (type(value[0]), True) 434 | else: 435 | self._hparam_types[name] = (type(value), False) 436 | setattr(self, name, value) 437 | 438 | def set_hparam(self, name, value): 439 | """Set the value of an existing hyperparameter. 440 | 441 | This function verifies that the type of the value matches the type of the 442 | existing hyperparameter. 443 | 444 | Args: 445 | name: Name of the hyperparameter. 446 | value: New value of the hyperparameter. 447 | 448 | Raises: 449 | KeyError: If the hyperparameter doesn't exist. 450 | ValueError: If there is a type mismatch. 451 | """ 452 | param_type, is_list = self._hparam_types[name] 453 | if isinstance(value, list): 454 | if not is_list: 455 | raise ValueError( 456 | "Must not pass a list for single-valued parameter: %s" % name 457 | ) 458 | setattr( 459 | self, 460 | name, 461 | [_cast_to_type_if_compatible(name, param_type, v) for v in value], 462 | ) 463 | else: 464 | if is_list: 465 | raise ValueError( 466 | "Must pass a list for multi-valued parameter: %s." % name 467 | ) 468 | setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) 469 | 470 | def del_hparam(self, name): 471 | """Removes the hyperparameter with key 'name'. 472 | 473 | Does nothing if it isn't present. 474 | 475 | Args: 476 | name: Name of the hyperparameter. 477 | """ 478 | if hasattr(self, name): 479 | delattr(self, name) 480 | del self._hparam_types[name] 481 | 482 | def parse(self, values): 483 | """Override existing hyperparameter values, parsing new values from a string. 484 | 485 | See parse_values for more detail on the allowed format for values. 486 | 487 | Args: 488 | values: String. Comma separated list of `name=value` pairs where 'value' 489 | must follow the syntax described above. 490 | 491 | Returns: 492 | The `HParams` instance. 493 | 494 | Raises: 495 | ValueError: If `values` cannot be parsed or a hyperparameter in `values` 496 | doesn't exist. 497 | """ 498 | type_map = {} 499 | for name, t in self._hparam_types.items(): 500 | param_type, _ = t 501 | type_map[name] = param_type 502 | 503 | values_map = parse_values(values, type_map) 504 | return self.override_from_dict(values_map) 505 | 506 | def override_from_dict(self, values_dict): 507 | """Override existing hyperparameter values, parsing new values from a dictionary. 508 | 509 | Args: 510 | values_dict: Dictionary of name:value pairs. 511 | 512 | Returns: 513 | The `HParams` instance. 514 | 515 | Raises: 516 | KeyError: If a hyperparameter in `values_dict` doesn't exist. 517 | ValueError: If `values_dict` cannot be parsed. 
518 | """ 519 | for name, value in values_dict.items(): 520 | self.set_hparam(name, value) 521 | return self 522 | 523 | def set_model_structure(self, model_structure): 524 | self._model_structure = model_structure 525 | 526 | def get_model_structure(self): 527 | return self._model_structure 528 | 529 | def to_json(self, indent=None, separators=None, sort_keys=False): 530 | """Serializes the hyperparameters into JSON. 531 | 532 | Args: 533 | indent: If a non-negative integer, JSON array elements and object members 534 | will be pretty-printed with that indent level. An indent level of 0, or 535 | negative, will only insert newlines. `None` (the default) selects the 536 | most compact representation. 537 | separators: Optional `(item_separator, key_separator)` tuple. Default is 538 | `(', ', ': ')`. 539 | sort_keys: If `True`, the output dictionaries will be sorted by key. 540 | 541 | Returns: 542 | A JSON string. 543 | """ 544 | 545 | def remove_callables(x): 546 | """Omit callable elements from input with arbitrary nesting.""" 547 | if isinstance(x, dict): 548 | return { 549 | k: remove_callables(v) 550 | for k, v in six.iteritems(x) 551 | if not callable(v) 552 | } 553 | elif isinstance(x, list): 554 | return [remove_callables(i) for i in x if not callable(i)] 555 | return x 556 | 557 | return json.dumps( 558 | remove_callables(self.values()), 559 | indent=indent, 560 | separators=separators, 561 | sort_keys=sort_keys, 562 | ) 563 | 564 | def parse_json(self, values_json): 565 | """Override existing hyperparameter values, parsing new values from a json object. 566 | 567 | Args: 568 | values_json: String containing a json object of name:value pairs. 569 | 570 | Returns: 571 | The `HParams` instance. 572 | 573 | Raises: 574 | KeyError: If a hyperparameter in `values_json` doesn't exist. 575 | ValueError: If `values_json` cannot be parsed. 576 | """ 577 | values_map = json.loads(values_json) 578 | return self.override_from_dict(values_map) 579 | 580 | def values(self): 581 | """Return the hyperparameter values as a Python dictionary. 582 | 583 | Returns: 584 | A dictionary with hyperparameter names as keys. The values are the 585 | hyperparameter values. 586 | """ 587 | return {n: getattr(self, n) for n in self._hparam_types.keys()} 588 | 589 | def get(self, key, default=None): 590 | """Returns the value of `key` if it exists, else `default`.""" 591 | if key in self._hparam_types: 592 | # Ensure that default is compatible with the parameter type. 593 | if default is not None: 594 | param_type, is_param_list = self._hparam_types[key] 595 | type_str = "list<%s>" % param_type if is_param_list else str(param_type) 596 | fail_msg = ( 597 | "Hparam '%s' of type '%s' is incompatible with " 598 | "default=%s" % (key, type_str, default) 599 | ) 600 | 601 | is_default_list = isinstance(default, list) 602 | if is_param_list != is_default_list: 603 | raise ValueError(fail_msg) 604 | 605 | try: 606 | if is_default_list: 607 | for value in default: 608 | _cast_to_type_if_compatible(key, param_type, value) 609 | else: 610 | _cast_to_type_if_compatible(key, param_type, default) 611 | except ValueError as e: 612 | raise ValueError("%s. 
%s" % (fail_msg, e)) 613 | 614 | return getattr(self, key) 615 | 616 | return default 617 | 618 | def __contains__(self, key): 619 | return key in self._hparam_types 620 | 621 | def __str__(self): 622 | return str(sorted(self.values().items())) 623 | 624 | def __repr__(self): 625 | return "%s(%s)" % (type(self).__name__, self.__str__()) 626 | 627 | @staticmethod 628 | def _get_kind_name(param_type, is_list): 629 | """Returns the field name given parameter type and is_list. 630 | 631 | Args: 632 | param_type: Data type of the hparam. 633 | is_list: Whether this is a list. 634 | 635 | Returns: 636 | A string representation of the field name. 637 | 638 | Raises: 639 | ValueError: If parameter type is not recognized. 640 | """ 641 | if issubclass(param_type, bool): 642 | # This check must happen before issubclass(param_type, six.integer_types), 643 | # since Python considers bool to be a subclass of int. 644 | typename = "bool" 645 | elif issubclass(param_type, six.integer_types): 646 | # Setting 'int' and 'long' types to be 'int64' to ensure the type is 647 | # compatible with both Python2 and Python3. 648 | typename = "int64" 649 | elif issubclass(param_type, (six.string_types, six.binary_type)): 650 | # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is 651 | # compatible with both Python2 and Python3. 652 | typename = "bytes" 653 | elif issubclass(param_type, float): 654 | typename = "float" 655 | else: 656 | raise ValueError("Unsupported parameter type: %s" % str(param_type)) 657 | 658 | suffix = "list" if is_list else "value" 659 | return "_".join([typename, suffix]) 660 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import torchaudio 10 | 11 | 12 | def save_feature(process_dir, feature_dir, item, feature, overrides=True): 13 | """Save features to path 14 | 15 | Args: 16 | process_dir (str): directory to store features 17 | feature_dir (_type_): directory to store one type of features (mel, energy, ...) 18 | item (str): uid 19 | feature (tensor): feature tensor 20 | overrides (bool, optional): whether to override existing files. Defaults to True. 
21 | """ 22 | process_dir = os.path.join(process_dir, feature_dir) 23 | os.makedirs(process_dir, exist_ok=True) 24 | out_path = os.path.join(process_dir, item + ".npy") 25 | 26 | if os.path.exists(out_path): 27 | if overrides: 28 | np.save(out_path, feature) 29 | else: 30 | np.save(out_path, feature) 31 | 32 | 33 | def save_txt(process_dir, feature_dir, item, feature, overrides=True): 34 | process_dir = os.path.join(process_dir, feature_dir) 35 | os.makedirs(process_dir, exist_ok=True) 36 | out_path = os.path.join(process_dir, item + ".txt") 37 | 38 | if os.path.exists(out_path): 39 | if overrides: 40 | f = open(out_path, "w") 41 | f.writelines(feature) 42 | f.close() 43 | else: 44 | f = open(out_path, "w") 45 | f.writelines(feature) 46 | f.close() 47 | 48 | 49 | def save_audio(path, waveform, fs, add_silence=False, turn_up=False, volume_peak=0.9): 50 | """Save audio to path with processing (turn up volume, add silence) 51 | Args: 52 | path (str): path to save audio 53 | waveform (numpy array): waveform to save 54 | fs (int): sampling rate 55 | add_silence (bool, optional): whether to add silence to beginning and end. Defaults to False. 56 | turn_up (bool, optional): whether to turn up volume. Defaults to False. 57 | volume_peak (float, optional): volume peak. Defaults to 0.9. 58 | """ 59 | if turn_up: 60 | # continue to turn up to volume_peak 61 | ratio = volume_peak / max(waveform.max(), abs(waveform.min())) 62 | waveform = waveform * ratio 63 | 64 | if add_silence: 65 | silence_len = fs // 20 66 | silence = np.zeros((silence_len,), dtype=waveform.dtype) 67 | result = np.concatenate([silence, waveform, silence]) 68 | waveform = result 69 | 70 | waveform = torch.as_tensor(waveform, dtype=torch.float32, device="cpu") 71 | if len(waveform.size()) == 1: 72 | waveform = waveform[None, :] 73 | elif waveform.size(0) != 1: 74 | # Stereo to mono 75 | waveform = torch.mean(waveform, dim=0, keepdim=True) 76 | torchaudio.save(path, waveform, fs, encoding="PCM_S", bits_per_sample=16) 77 | 78 | 79 | def save_torch_audio(process_dir, feature_dir, item, wav_torch, fs, overrides=True): 80 | """Save torch audio to path without processing 81 | Args: 82 | process_dir (str): directory to store features 83 | feature_dir (_type_): directory to store one type of features (mel, energy, ...) 84 | item (str): uid 85 | wav_torch (tensor): feature tensor 86 | fs (int): sampling rate 87 | overrides (bool, optional): whether to override existing files. Defaults to True. 88 | """ 89 | if wav_torch.shape != 2: 90 | wav_torch = wav_torch.unsqueeze(0) 91 | 92 | process_dir = os.path.join(process_dir, feature_dir) 93 | os.makedirs(process_dir, exist_ok=True) 94 | out_path = os.path.join(process_dir, item + ".wav") 95 | 96 | torchaudio.save(out_path, wav_torch, fs) 97 | 98 | 99 | async def async_load_audio(path, sample_rate: int = 24000): 100 | r""" 101 | Args: 102 | path: The source loading path. 103 | sample_rate: The target sample rate, will automatically resample if necessary. 104 | 105 | Returns: 106 | waveform: The waveform object. Should be [1 x sequence_len]. 
107 | """ 108 | 109 | async def use_torchaudio_load(path): 110 | return torchaudio.load(path) 111 | 112 | waveform, sr = await use_torchaudio_load(path) 113 | waveform = torch.mean(waveform, dim=0, keepdim=True) 114 | 115 | if sr != sample_rate: 116 | waveform = torchaudio.functional.resample(waveform, sr, sample_rate) 117 | 118 | if torch.any(torch.isnan(waveform) or torch.isinf(waveform)): 119 | raise ValueError("NaN or Inf found in waveform.") 120 | return waveform 121 | 122 | 123 | async def async_save_audio( 124 | path, 125 | waveform, 126 | sample_rate: int = 24000, 127 | add_silence: bool = False, 128 | volume_peak: float = 0.9, 129 | ): 130 | r""" 131 | Args: 132 | path: The target saving path. 133 | waveform: The waveform object. Should be [n_channel x sequence_len]. 134 | sample_rate: Sample rate. 135 | add_silence: If ``true``, concat 0.05s silence to beginning and end. 136 | volume_peak: Turn up volume for larger number, vice versa. 137 | """ 138 | 139 | async def use_torchaudio_save(path, waveform, sample_rate): 140 | torchaudio.save( 141 | path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16 142 | ) 143 | 144 | waveform = torch.as_tensor(waveform, device="cpu", dtype=torch.float32) 145 | shape = waveform.size()[:-1] 146 | 147 | ratio = abs(volume_peak) / max(waveform.max(), abs(waveform.min())) 148 | waveform = waveform * ratio 149 | 150 | if add_silence: 151 | silence_len = sample_rate // 20 152 | silence = torch.zeros((*shape, silence_len), dtype=waveform.type()) 153 | waveform = torch.concatenate((silence, waveform, silence), dim=-1) 154 | 155 | if waveform.dim() == 1: 156 | waveform = waveform[None] 157 | 158 | await use_torchaudio_save(path, waveform, sample_rate) 159 | 160 | 161 | def load_mel_extrema(cfg, dataset_name, split): 162 | dataset_dir = os.path.join( 163 | cfg.OUTPUT_PATH, 164 | "preprocess/{}_version".format(cfg.data.process_version), 165 | dataset_name, 166 | ) 167 | 168 | min_file = os.path.join( 169 | dataset_dir, 170 | "mel_min_max", 171 | split.split("_")[-1], 172 | "mel_min.npy", 173 | ) 174 | max_file = os.path.join( 175 | dataset_dir, 176 | "mel_min_max", 177 | split.split("_")[-1], 178 | "mel_max.npy", 179 | ) 180 | mel_min = np.load(min_file) 181 | mel_max = np.load(max_file) 182 | return mel_min, mel_max 183 | -------------------------------------------------------------------------------- /utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | import matplotlib 5 | 6 | matplotlib.use("Agg") 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | import os 11 | import json 12 | import models.tta.hifigan as hifigan 13 | 14 | 15 | def reduce_tensors(metrics): 16 | new_metrics = {} 17 | for k, v in metrics.items(): 18 | if isinstance(v, torch.Tensor): 19 | dist.all_reduce(v) 20 | v = v / dist.get_world_size() 21 | if type(v) is dict: 22 | v = reduce_tensors(v) 23 | new_metrics[k] = v 24 | return new_metrics 25 | 26 | 27 | def tensors_to_scalars(tensors): 28 | if isinstance(tensors, torch.Tensor): 29 | tensors = tensors.item() 30 | return tensors 31 | elif isinstance(tensors, dict): 32 | new_tensors = {} 33 | for k, v in tensors.items(): 34 | v = tensors_to_scalars(v) 35 | new_tensors[k] = v 36 | return new_tensors 37 | elif isinstance(tensors, list): 38 | return [tensors_to_scalars(v) for v in tensors] 39 | else: 40 | return tensors 41 | 42 | 43 | def tensors_to_np(tensors): 44 | if isinstance(tensors, dict): 45 | 
new_np = {} 46 | for k, v in tensors.items(): 47 | if isinstance(v, torch.Tensor): 48 | v = v.cpu().numpy() 49 | if type(v) is dict: 50 | v = tensors_to_np(v) 51 | new_np[k] = v 52 | elif isinstance(tensors, list): 53 | new_np = [] 54 | for v in tensors: 55 | if isinstance(v, torch.Tensor): 56 | v = v.cpu().numpy() 57 | if type(v) is dict: 58 | v = tensors_to_np(v) 59 | new_np.append(v) 60 | elif isinstance(tensors, torch.Tensor): 61 | v = tensors 62 | if isinstance(v, torch.Tensor): 63 | v = v.cpu().numpy() 64 | if type(v) is dict: 65 | v = tensors_to_np(v) 66 | new_np = v 67 | else: 68 | raise Exception(f"tensors_to_np does not support type {type(tensors)}.") 69 | return new_np 70 | 71 | 72 | def move_to_cpu(tensors): 73 | ret = {} 74 | for k, v in tensors.items(): 75 | if isinstance(v, torch.Tensor): 76 | v = v.cpu() 77 | if type(v) is dict: 78 | v = move_to_cpu(v) 79 | ret[k] = v 80 | return ret 81 | 82 | 83 | def move_to_cuda(batch, gpu_id=0): 84 | # base case: object can be directly moved using `cuda` or `to` 85 | if callable(getattr(batch, "cuda", None)): 86 | return batch.cuda(gpu_id, non_blocking=True) 87 | elif callable(getattr(batch, "to", None)): 88 | return batch.to(torch.device("cuda", gpu_id), non_blocking=True) 89 | elif isinstance(batch, list): 90 | for i, x in enumerate(batch): 91 | batch[i] = move_to_cuda(x, gpu_id) 92 | return batch 93 | elif isinstance(batch, tuple): 94 | batch = list(batch) 95 | for i, x in enumerate(batch): 96 | batch[i] = move_to_cuda(x, gpu_id) 97 | return tuple(batch) 98 | elif isinstance(batch, dict): 99 | for k, v in batch.items(): 100 | batch[k] = move_to_cuda(v, gpu_id) 101 | return batch 102 | return batch 103 | 104 | 105 | def log_metrics(logger, metrics, step=None): 106 | for k, v in metrics.items(): 107 | if isinstance(v, torch.Tensor): 108 | v = v.item() 109 | logger.add_scalar(k, v, step) 110 | 111 | 112 | def spec_to_figure(spec, vmin=None, vmax=None, title="", f0s=None, dur_info=None): 113 | if isinstance(spec, torch.Tensor): 114 | spec = spec.cpu().numpy() 115 | H = spec.shape[1] // 2 116 | fig = plt.figure(figsize=(12, 6), dpi=100) 117 | plt.title(title) 118 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 119 | if dur_info is not None: 120 | assert isinstance(dur_info, dict) 121 | txt = dur_info["txt"] 122 | dur_gt = dur_info["dur_gt"] 123 | if isinstance(dur_gt, torch.Tensor): 124 | dur_gt = dur_gt.cpu().numpy() 125 | dur_gt = np.cumsum(dur_gt).astype(int) 126 | for i in range(len(dur_gt)): 127 | shift = (i % 8) + 1 128 | plt.text(dur_gt[i], shift * 4, txt[i]) 129 | plt.vlines(dur_gt[i], 0, H // 2, colors="b") # blue is gt 130 | plt.xlim(0, dur_gt[-1]) 131 | if "dur_pred" in dur_info: 132 | dur_pred = dur_info["dur_pred"] 133 | if isinstance(dur_pred, torch.Tensor): 134 | dur_pred = dur_pred.cpu().numpy() 135 | dur_pred = np.cumsum(dur_pred).astype(int) 136 | for i in range(len(dur_pred)): 137 | shift = (i % 8) + 1 138 | plt.text(dur_pred[i], H + shift * 4, txt[i]) 139 | plt.vlines(dur_pred[i], H, H * 1.5, colors="r") # red is pred 140 | plt.xlim(0, max(dur_gt[-1], dur_pred[-1])) 141 | # if f0s is not None: 142 | # ax = plt.gca() 143 | # ax2 = ax.twinx() 144 | # if not isinstance(f0s, dict): 145 | # f0s = {"f0": f0s} 146 | # for i, (k, f0) in enumerate(f0s.items()): 147 | # if isinstance(f0, torch.Tensor): 148 | # f0 = f0.cpu().numpy() 149 | # ax2.plot(f0, label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.5) 150 | # ax2.set_ylim(0, 1000) 151 | # ax2.legend() 152 | return fig 153 | 154 | 155 | def get_vocoder(config, device): 156 | ROOT = 
"/work/gwh/Amphion/ckpts/tta/hifigan" 157 | 158 | model_path = os.path.join(ROOT, "hifigan_16k_64bins") 159 | with open(model_path + ".json", "r") as f: 160 | config = json.load(f) 161 | config = hifigan.AttrDict(config) 162 | vocoder = hifigan.Generator(config) 163 | 164 | ckpt = torch.load(model_path + ".ckpt") 165 | ckpt = torch_version_orig_mod_remove(ckpt) 166 | vocoder.load_state_dict(ckpt["generator"]) 167 | vocoder.eval() 168 | vocoder.remove_weight_norm() 169 | vocoder.to(device) 170 | return vocoder 171 | 172 | 173 | def torch_version_orig_mod_remove(state_dict): 174 | new_state_dict = {} 175 | new_state_dict["generator"] = {} 176 | for key in state_dict["generator"].keys(): 177 | if "_orig_mod." in key: 178 | new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ 179 | "generator" 180 | ][key] 181 | else: 182 | new_state_dict["generator"][key] = state_dict["generator"][key] 183 | return new_state_dict 184 | 185 | 186 | def vocoder_infer(mels, vocoder, lengths=None): 187 | with torch.no_grad(): 188 | wavs = vocoder(mels).squeeze(1) 189 | 190 | wavs = (wavs.cpu().numpy() * 32768).astype("int16") 191 | 192 | if lengths is not None: 193 | wavs = wavs[:, :lengths] 194 | 195 | return wavs 196 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Amphion. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import collections 8 | import glob 9 | import os 10 | import random 11 | import time 12 | import argparse 13 | from collections import OrderedDict 14 | 15 | import json5 16 | import numpy as np 17 | import glob 18 | from torch.nn import functional as F 19 | 20 | 21 | try: 22 | from ruamel.yaml import YAML as yaml 23 | except: 24 | from ruamel_yaml import YAML as yaml 25 | 26 | import torch 27 | 28 | from utils.hparam import HParams 29 | import logging 30 | from logging import handlers 31 | 32 | 33 | def str2bool(v): 34 | """Used in argparse.ArgumentParser.add_argument to indicate 35 | that a type is a bool type and user can enter 36 | 37 | - yes, true, t, y, 1, to represent True 38 | - no, false, f, n, 0, to represent False 39 | 40 | See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa 41 | """ 42 | if isinstance(v, bool): 43 | return v 44 | if v.lower() in ("yes", "true", "t", "y", "1"): 45 | return True 46 | elif v.lower() in ("no", "false", "f", "n", "0"): 47 | return False 48 | else: 49 | raise argparse.ArgumentTypeError("Boolean value expected.") 50 | 51 | 52 | def find_checkpoint_of_mapper(mapper_ckpt_dir): 53 | mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt")) 54 | 55 | # Select the max steps 56 | mapper_ckpts.sort() 57 | mapper_weights_file = mapper_ckpts[-1] 58 | return mapper_weights_file 59 | 60 | 61 | def pad_f0_to_tensors(f0s, batched=None): 62 | # Initialize 63 | tensors = [] 64 | 65 | if batched == None: 66 | # Get the max frame for padding 67 | size = -1 68 | for f0 in f0s: 69 | size = max(size, f0.shape[-1]) 70 | 71 | tensor = torch.zeros(len(f0s), size) 72 | 73 | for i, f0 in enumerate(f0s): 74 | tensor[i, : f0.shape[-1]] = f0[:] 75 | 76 | tensors.append(tensor) 77 | else: 78 | start = 0 79 | while start + batched - 1 < len(f0s): 80 | end = start + batched - 1 81 | 82 | # Get the max frame for padding 83 | size = -1 84 | for i in range(start, 
end + 1): 85 | size = max(size, f0s[i].shape[-1]) 86 | 87 | tensor = torch.zeros(batched, size) 88 | 89 | for i in range(start, end + 1): 90 | tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:] 91 | 92 | tensors.append(tensor) 93 | 94 | start = start + batched 95 | 96 | if start != len(f0s): 97 | end = len(f0s) 98 | 99 | # Get the max frame for padding 100 | size = -1 101 | for i in range(start, end): 102 | size = max(size, f0s[i].shape[-1]) 103 | 104 | tensor = torch.zeros(len(f0s) - start, size) 105 | 106 | for i in range(start, end): 107 | tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:] 108 | 109 | tensors.append(tensor) 110 | 111 | return tensors 112 | 113 | 114 | def pad_mels_to_tensors(mels, batched=None): 115 | """ 116 | Args: 117 | mels: A list of mel-specs 118 | Returns: 119 | tensors: A list of tensors containing the batched mel-specs 120 | mel_frames: A list of tensors containing the frames of the original mel-specs 121 | """ 122 | # Initialize 123 | tensors = [] 124 | mel_frames = [] 125 | 126 | # Split mel-specs into batches to avoid cuda memory exceed 127 | if batched == None: 128 | # Get the max frame for padding 129 | size = -1 130 | for mel in mels: 131 | size = max(size, mel.shape[-1]) 132 | 133 | tensor = torch.zeros(len(mels), mels[0].shape[0], size) 134 | mel_frame = torch.zeros(len(mels), dtype=torch.int32) 135 | 136 | for i, mel in enumerate(mels): 137 | tensor[i, :, : mel.shape[-1]] = mel[:] 138 | mel_frame[i] = mel.shape[-1] 139 | 140 | tensors.append(tensor) 141 | mel_frames.append(mel_frame) 142 | else: 143 | start = 0 144 | while start + batched - 1 < len(mels): 145 | end = start + batched - 1 146 | 147 | # Get the max frame for padding 148 | size = -1 149 | for i in range(start, end + 1): 150 | size = max(size, mels[i].shape[-1]) 151 | 152 | tensor = torch.zeros(batched, mels[0].shape[0], size) 153 | mel_frame = torch.zeros(batched, dtype=torch.int32) 154 | 155 | for i in range(start, end + 1): 156 | tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:] 157 | mel_frame[i - start] = mels[i].shape[-1] 158 | 159 | tensors.append(tensor) 160 | mel_frames.append(mel_frame) 161 | 162 | start = start + batched 163 | 164 | if start != len(mels): 165 | end = len(mels) 166 | 167 | # Get the max frame for padding 168 | size = -1 169 | for i in range(start, end): 170 | size = max(size, mels[i].shape[-1]) 171 | 172 | tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size) 173 | mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32) 174 | 175 | for i in range(start, end): 176 | tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:] 177 | mel_frame[i - start] = mels[i].shape[-1] 178 | 179 | tensors.append(tensor) 180 | mel_frames.append(mel_frame) 181 | 182 | return tensors, mel_frames 183 | 184 | 185 | def load_model_config(args): 186 | """Load model configurations (in args.json under checkpoint directory) 187 | 188 | Args: 189 | args (ArgumentParser): arguments to run bins/preprocess.py 190 | 191 | Returns: 192 | dict: dictionary that stores model configurations 193 | """ 194 | if args.checkpoint_dir is None: 195 | assert args.checkpoint_file is not None 196 | checkpoint_dir = os.path.split(args.checkpoint_file)[0] 197 | else: 198 | checkpoint_dir = args.checkpoint_dir 199 | config_path = os.path.join(checkpoint_dir, "args.json") 200 | print("config_path: ", config_path) 201 | 202 | config = load_config(config_path) 203 | return config 204 | 205 | 206 | def remove_and_create(dir): 207 | if os.path.exists(dir): 208 | os.system("rm -r {}".format(dir)) 209 | 
os.makedirs(dir, exist_ok=True) 210 | 211 | 212 | def has_existed(path, warning=False): 213 | if not warning: 214 | return os.path.exists(path) 215 | 216 | if os.path.exists(path): 217 | answer = input( 218 | "The path {} has existed. \nInput 'y' (or hit Enter) to skip it, and input 'n' to re-write it [y/n]\n".format( 219 | path 220 | ) 221 | ) 222 | if not answer == "n": 223 | return True 224 | 225 | return False 226 | 227 | 228 | def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5): 229 | if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")): 230 | with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f: 231 | ckpts = [x.strip() for x in f.readlines()] 232 | else: 233 | ckpts = [] 234 | ckpts.append(saved_model_name) 235 | for item in ckpts[:-max_to_keep]: 236 | if os.path.exists(os.path.join(checkpoint_dir, item)): 237 | os.remove(os.path.join(checkpoint_dir, item)) 238 | with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f: 239 | for item in ckpts[-max_to_keep:]: 240 | f.write("{}\n".format(item)) 241 | 242 | 243 | def set_all_random_seed(seed: int): 244 | random.seed(seed) 245 | np.random.seed(seed) 246 | torch.random.manual_seed(seed) 247 | 248 | 249 | def save_checkpoint( 250 | args, 251 | generator, 252 | g_optimizer, 253 | step, 254 | discriminator=None, 255 | d_optimizer=None, 256 | max_to_keep=5, 257 | ): 258 | saved_model_name = "model.ckpt-{}.pt".format(step) 259 | checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name) 260 | 261 | if discriminator and d_optimizer: 262 | torch.save( 263 | { 264 | "generator": generator.state_dict(), 265 | "discriminator": discriminator.state_dict(), 266 | "g_optimizer": g_optimizer.state_dict(), 267 | "d_optimizer": d_optimizer.state_dict(), 268 | "global_step": step, 269 | }, 270 | checkpoint_path, 271 | ) 272 | else: 273 | torch.save( 274 | { 275 | "generator": generator.state_dict(), 276 | "g_optimizer": g_optimizer.state_dict(), 277 | "global_step": step, 278 | }, 279 | checkpoint_path, 280 | ) 281 | 282 | print("Saved checkpoint: {}".format(checkpoint_path)) 283 | 284 | if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")): 285 | with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f: 286 | ckpts = [x.strip() for x in f.readlines()] 287 | else: 288 | ckpts = [] 289 | ckpts.append(saved_model_name) 290 | for item in ckpts[:-max_to_keep]: 291 | if os.path.exists(os.path.join(args.checkpoint_dir, item)): 292 | os.remove(os.path.join(args.checkpoint_dir, item)) 293 | with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f: 294 | for item in ckpts[-max_to_keep:]: 295 | f.write("{}\n".format(item)) 296 | 297 | 298 | def attempt_to_restore( 299 | generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None 300 | ): 301 | checkpoint_list = os.path.join(checkpoint_dir, "checkpoint") 302 | if os.path.exists(checkpoint_list): 303 | checkpoint_filename = open(checkpoint_list).readlines()[-1].strip() 304 | checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename)) 305 | print("Restore from {}".format(checkpoint_path)) 306 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 307 | if generator: 308 | if not list(generator.state_dict().keys())[0].startswith("module."): 309 | raw_dict = checkpoint["generator"] 310 | clean_dict = OrderedDict() 311 | for k, v in raw_dict.items(): 312 | if k.startswith("module."): 313 | clean_dict[k[7:]] = v 314 | else: 315 | clean_dict[k] = v 316 | 
generator.load_state_dict(clean_dict) 317 | else: 318 | generator.load_state_dict(checkpoint["generator"]) 319 | if g_optimizer: 320 | g_optimizer.load_state_dict(checkpoint["g_optimizer"]) 321 | global_step = 100000 322 | if discriminator and "discriminator" in checkpoint.keys(): 323 | discriminator.load_state_dict(checkpoint["discriminator"]) 324 | global_step = checkpoint["global_step"] 325 | print("restore discriminator") 326 | if d_optimizer and "d_optimizer" in checkpoint.keys(): 327 | d_optimizer.load_state_dict(checkpoint["d_optimizer"]) 328 | print("restore d_optimizer...") 329 | else: 330 | global_step = 0 331 | return global_step 332 | 333 | 334 | class ExponentialMovingAverage(object): 335 | def __init__(self, decay): 336 | self.decay = decay 337 | self.shadow = {} 338 | 339 | def register(self, name, val): 340 | self.shadow[name] = val.clone() 341 | 342 | def update(self, name, x): 343 | assert name in self.shadow 344 | update_delta = self.shadow[name] - x 345 | self.shadow[name] -= (1.0 - self.decay) * update_delta 346 | 347 | 348 | def apply_moving_average(model, ema): 349 | for name, param in model.named_parameters(): 350 | if name in ema.shadow: 351 | ema.update(name, param.data) 352 | 353 | 354 | def register_model_to_ema(model, ema): 355 | for name, param in model.named_parameters(): 356 | if param.requires_grad: 357 | ema.register(name, param.data) 358 | 359 | 360 | class YParams(HParams): 361 | def __init__(self, yaml_file): 362 | if not os.path.exists(yaml_file): 363 | raise IOError("yaml file: {} is not existed".format(yaml_file)) 364 | super().__init__() 365 | self.d = collections.OrderedDict() 366 | with open(yaml_file) as fp: 367 | for _, v in yaml().load(fp).items(): 368 | for k1, v1 in v.items(): 369 | try: 370 | if self.get(k1): 371 | self.set_hparam(k1, v1) 372 | else: 373 | self.add_hparam(k1, v1) 374 | self.d[k1] = v1 375 | except Exception: 376 | import traceback 377 | 378 | print(traceback.format_exc()) 379 | 380 | # @property 381 | def get_elements(self): 382 | return self.d.items() 383 | 384 | 385 | def override_config(base_config, new_config): 386 | """Update new configurations in the original dict with the new dict 387 | 388 | Args: 389 | base_config (dict): original dict to be overridden 390 | new_config (dict): dict with new configurations 391 | 392 | Returns: 393 | dict: updated configuration dict 394 | """ 395 | for k, v in new_config.items(): 396 | if type(v) == dict: 397 | if k not in base_config.keys(): 398 | base_config[k] = {} 399 | base_config[k] = override_config(base_config[k], v) 400 | else: 401 | base_config[k] = v 402 | return base_config 403 | 404 | 405 | def get_lowercase_keys_config(cfg): 406 | """Change all keys in cfg to lower case 407 | 408 | Args: 409 | cfg (dict): dictionary that stores configurations 410 | 411 | Returns: 412 | dict: dictionary that stores configurations 413 | """ 414 | updated_cfg = dict() 415 | for k, v in cfg.items(): 416 | if type(v) == dict: 417 | v = get_lowercase_keys_config(v) 418 | updated_cfg[k.lower()] = v 419 | return updated_cfg 420 | 421 | 422 | def _load_config(config_fn, lowercase=False): 423 | """Load configurations into a dictionary 424 | 425 | Args: 426 | config_fn (str): path to configuration file 427 | lowercase (bool, optional): whether changing keys to lower case. Defaults to False. 
428 | 429 | Returns: 430 | dict: dictionary that stores configurations 431 | """ 432 | with open(config_fn, "r") as f: 433 | data = f.read() 434 | config_ = json5.loads(data) 435 | if "base_config" in config_: 436 | # load configurations from new path 437 | p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"]) 438 | p_config_ = _load_config(p_config_path) 439 | config_ = override_config(p_config_, config_) 440 | if lowercase: 441 | # change keys in config_ to lower case 442 | config_ = get_lowercase_keys_config(config_) 443 | return config_ 444 | 445 | 446 | def load_config(config_fn, lowercase=False): 447 | """Load configurations into a dictionary 448 | 449 | Args: 450 | config_fn (str): path to configuration file 451 | lowercase (bool, optional): _description_. Defaults to False. 452 | 453 | Returns: 454 | JsonHParams: an object that stores configurations 455 | """ 456 | config_ = _load_config(config_fn, lowercase=lowercase) 457 | # create an JsonHParams object with configuration dict 458 | cfg = JsonHParams(**config_) 459 | return cfg 460 | 461 | 462 | def save_config(save_path, cfg): 463 | """Save configurations into a json file 464 | 465 | Args: 466 | save_path (str): path to save configurations 467 | cfg (dict): dictionary that stores configurations 468 | """ 469 | with open(save_path, "w") as f: 470 | json5.dump( 471 | cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True 472 | ) 473 | 474 | 475 | class JsonHParams: 476 | def __init__(self, **kwargs): 477 | for k, v in kwargs.items(): 478 | if type(v) == dict: 479 | v = JsonHParams(**v) 480 | self[k] = v 481 | 482 | def keys(self): 483 | return self.__dict__.keys() 484 | 485 | def items(self): 486 | return self.__dict__.items() 487 | 488 | def values(self): 489 | return self.__dict__.values() 490 | 491 | def __len__(self): 492 | return len(self.__dict__) 493 | 494 | def __getitem__(self, key): 495 | return getattr(self, key) 496 | 497 | def __setitem__(self, key, value): 498 | return setattr(self, key, value) 499 | 500 | def __contains__(self, key): 501 | return key in self.__dict__ 502 | 503 | def __repr__(self): 504 | return self.__dict__.__repr__() 505 | 506 | 507 | class ValueWindow: 508 | def __init__(self, window_size=100): 509 | self._window_size = window_size 510 | self._values = [] 511 | 512 | def append(self, x): 513 | self._values = self._values[-(self._window_size - 1) :] + [x] 514 | 515 | @property 516 | def sum(self): 517 | return sum(self._values) 518 | 519 | @property 520 | def count(self): 521 | return len(self._values) 522 | 523 | @property 524 | def average(self): 525 | return self.sum / max(1, self.count) 526 | 527 | def reset(self): 528 | self._values = [] 529 | 530 | 531 | class Logger(object): 532 | def __init__( 533 | self, 534 | filename, 535 | level="info", 536 | when="D", 537 | backCount=10, 538 | fmt="%(asctime)s : %(message)s", 539 | ): 540 | self.level_relations = { 541 | "debug": logging.DEBUG, 542 | "info": logging.INFO, 543 | "warning": logging.WARNING, 544 | "error": logging.ERROR, 545 | "crit": logging.CRITICAL, 546 | } 547 | if level == "debug": 548 | fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s" 549 | self.logger = logging.getLogger(filename) 550 | format_str = logging.Formatter(fmt) 551 | self.logger.setLevel(self.level_relations.get(level)) 552 | sh = logging.StreamHandler() 553 | sh.setFormatter(format_str) 554 | th = handlers.TimedRotatingFileHandler( 555 | filename=filename, when=when, backupCount=backCount, 
531 | class Logger(object):
532 |     def __init__(
533 |         self,
534 |         filename,
535 |         level="info",
536 |         when="D",
537 |         backCount=10,
538 |         fmt="%(asctime)s : %(message)s",
539 |     ):
540 |         self.level_relations = {
541 |             "debug": logging.DEBUG,
542 |             "info": logging.INFO,
543 |             "warning": logging.WARNING,
544 |             "error": logging.ERROR,
545 |             "crit": logging.CRITICAL,
546 |         }
547 |         if level == "debug":
548 |             fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
549 |         self.logger = logging.getLogger(filename)
550 |         format_str = logging.Formatter(fmt)
551 |         self.logger.setLevel(self.level_relations.get(level))
552 |         sh = logging.StreamHandler()
553 |         sh.setFormatter(format_str)
554 |         th = handlers.TimedRotatingFileHandler(
555 |             filename=filename, when=when, backupCount=backCount, encoding="utf-8"
556 |         )
557 |         th.setFormatter(format_str)
558 |         self.logger.addHandler(sh)
559 |         self.logger.addHandler(th)
560 |         self.logger.info(
561 |             "==========================New Starting Here=============================="
562 |         )
563 |
564 |
565 | def init_weights(m, mean=0.0, std=0.01):
566 |     classname = m.__class__.__name__
567 |     if classname.find("Conv") != -1:
568 |         m.weight.data.normal_(mean, std)
569 |
570 |
571 | def get_padding(kernel_size, dilation=1):
572 |     return int((kernel_size * dilation - dilation) / 2)
573 |
574 |
575 | def slice_segments(x, ids_str, segment_size=4):
576 |     ret = torch.zeros_like(x[:, :, :segment_size])
577 |     for i in range(x.size(0)):
578 |         idx_str = ids_str[i]
579 |         idx_end = idx_str + segment_size
580 |         ret[i] = x[i, :, idx_str:idx_end]
581 |     return ret
582 |
583 |
584 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
585 |     b, d, t = x.size()
586 |     if x_lengths is None:
587 |         x_lengths = t
588 |     ids_str_max = x_lengths - segment_size + 1
589 |     ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
590 |     ret = slice_segments(x, ids_str, segment_size)
591 |     return ret, ids_str
592 |
593 |
594 | def subsequent_mask(length):
595 |     mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
596 |     return mask
597 |
598 |
599 | @torch.jit.script
600 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
601 |     n_channels_int = n_channels[0]
602 |     in_act = input_a + input_b
603 |     t_act = torch.tanh(in_act[:, :n_channels_int, :])
604 |     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
605 |     acts = t_act * s_act
606 |     return acts
607 |
608 |
609 | def convert_pad_shape(pad_shape):
610 |     l = pad_shape[::-1]
611 |     pad_shape = [item for sublist in l for item in sublist]
612 |     return pad_shape
613 |
614 |
615 | def sequence_mask(length, max_length=None):
616 |     if max_length is None:
617 |         max_length = length.max()
618 |     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
619 |     return x.unsqueeze(0) < length.unsqueeze(1)
620 |
621 |
622 | def generate_path(duration, mask):
623 |     """
624 |     duration: [b, 1, t_x]
625 |     mask: [b, 1, t_y, t_x]
626 |     """
627 |     device = duration.device
628 |
629 |     b, _, t_y, t_x = mask.shape
630 |     cum_duration = torch.cumsum(duration, -1)
631 |
632 |     cum_duration_flat = cum_duration.view(b * t_x)
633 |     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
634 |     path = path.view(b, t_x, t_y)
635 |     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
636 |     path = path.unsqueeze(1).transpose(2, 3) * mask
637 |     return path
638 |
639 |
640 | def clip_grad_value_(parameters, clip_value, norm_type=2):
641 |     if isinstance(parameters, torch.Tensor):
642 |         parameters = [parameters]
643 |     parameters = list(filter(lambda p: p.grad is not None, parameters))
644 |     norm_type = float(norm_type)
645 |     if clip_value is not None:
646 |         clip_value = float(clip_value)
647 |
648 |     total_norm = 0
649 |     for p in parameters:
650 |         param_norm = p.grad.data.norm(norm_type)
651 |         total_norm += param_norm.item() ** norm_type
652 |         if clip_value is not None:
653 |             p.grad.data.clamp_(min=-clip_value, max=clip_value)
654 |     total_norm = total_norm ** (1.0 / norm_type)
655 |     return total_norm
656 |
657 |
658 | def get_current_time():
659 |     pass
660 |
661 |
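# A minimal sketch of sequence_mask on hypothetical lengths; True marks valid
# (non-padded) positions, which is the opposite convention of make_pad_mask below:
#
#     >>> sequence_mask(torch.tensor([2, 4]), max_length=4)
#     tensor([[ True,  True, False, False],
#             [ True,  True,  True,  True]])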
662 | def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
663 |     """
664 |     Args:
665 |       lengths:
666 |         A 1-D tensor containing sentence lengths.
667 |       max_len:
668 |         The length of masks.
669 |     Returns:
670 |       Return a 2-D bool tensor, where masked positions
671 |       are filled with `True` and non-masked positions are
672 |       filled with `False`.
673 |
674 |     >>> lengths = torch.tensor([1, 3, 2, 5])
675 |     >>> make_pad_mask(lengths)
676 |     tensor([[False,  True,  True,  True,  True],
677 |             [False, False, False,  True,  True],
678 |             [False, False,  True,  True,  True],
679 |             [False, False, False, False, False]])
680 |     """
681 |     assert lengths.ndim == 1, lengths.ndim
682 |     max_len = max(max_len, lengths.max())
683 |     n = lengths.size(0)
684 |     seq_range = torch.arange(0, max_len, device=lengths.device)
685 |     expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)
686 |
687 |     return expanded_lengths >= lengths.unsqueeze(-1)
688 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from typing import Optional
4 |
5 | import torch.nn as nn
6 |
7 | from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
8 | from lightning.pytorch.loggers import WandbLogger
9 |
10 | ZERO = 1e-12
11 |
12 |
13 | def lr_lambda(epoch, warmup, gamma):
14 |     if epoch <= warmup:
15 |         return epoch / warmup + 1e-3
16 |     else:
17 |         return gamma ** (epoch - warmup)
18 |
19 |
20 | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
21 |     total = 0
22 |     for p in module.parameters():
23 |         if requires_grad is None or p.requires_grad == requires_grad:
24 |             total += p.numel()
25 |     return total
26 |
27 |
28 | def choose_logger(
29 |     logger_name: str,
30 |     log_dir,
31 |     project: Optional[str] = None,
32 |     comment: Optional[str] = None,
33 |     *args,
34 |     **kwargs,
35 | ):
36 |     if logger_name == "csv":
37 |         return CSVLogger(root_dir=log_dir, name="csv", *args, **kwargs)
38 |     elif logger_name == "tensorboard":
39 |         return TensorBoardLogger(root_dir=log_dir, name="tensorboard", *args, **kwargs)
40 |     elif logger_name == "wandb":
41 |         return WandbLogger(
42 |             project=project, save_dir=log_dir, notes=comment, *args, **kwargs
43 |         )
44 |     else:
45 |         raise ValueError(f"`logger={logger_name}` is not a valid option.")
46 |
47 |
48 | def get_checkpoint_files(checkpoint_dir):
49 |     checkpoint_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*.ckpt")))
50 |     return checkpoint_files
51 |
--------------------------------------------------------------------------------
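For reference, `lr_lambda` above has the shape expected by PyTorch's `LambdaLR` once `warmup` and `gamma` are bound: the multiplier ramps up over `warmup` epochs and then decays by `gamma` per epoch. A minimal sketch, assuming a hypothetical model and optimizer (none of the objects below come from this repository):
```
from functools import partial

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

from utils.utils import lr_lambda, num_parameters

model = nn.Linear(8, 8)  # hypothetical stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda=partial(lr_lambda, warmup=10, gamma=0.999))

print(num_parameters(model, requires_grad=True))  # count of trainable parameters
for epoch in range(3):
    optimizer.step()   # the actual training step would go here
    scheduler.step()   # multiplier follows lr_lambda(epoch, warmup, gamma)
```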