├── .gitignore ├── README.md ├── data └── dataset │ └── metadata │ └── dataset_root.json ├── drawspeech ├── __init__.py ├── conditional_models.py ├── config │ ├── drawspeech_ljspeech_22k.yaml │ └── vae_ljspeech_22k.yaml ├── dataset_plugin.py ├── infer.py ├── losses │ ├── __init__.py │ └── contperceptual.py ├── modules │ ├── __init__.py │ ├── contour_predictor │ │ └── model.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── distributions.py │ │ ├── ema.py │ │ ├── model.py │ │ ├── nn.py │ │ ├── openaimodel.py │ │ └── x_transformer.py │ ├── fastspeech2 │ │ ├── modules.py │ │ └── tools.py │ ├── hifigan │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── models.py │ │ └── models_hifires.py │ ├── latent_diffusion │ │ ├── __init__.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ │ ├── __init__.py │ │ │ ├── dpm_solver.py │ │ │ └── sampler.py │ │ └── plms.py │ ├── latent_encoder │ │ ├── __init__.py │ │ └── autoencoder.py │ └── text_encoder │ │ ├── __init__.py │ │ ├── attentions.py │ │ ├── commons.py │ │ └── encoder.py ├── train │ ├── autoencoder.py │ └── latent_diffusion.py └── utilities │ ├── __init__.py │ ├── audio │ ├── __init__.py │ ├── audio_processing.py │ ├── stft.py │ └── tools.py │ ├── data │ ├── __init__.py │ └── dataset.py │ ├── diffusion_util.py │ ├── model_util.py │ ├── preprocessor │ ├── preprocess_frame_level.yaml │ ├── preprocess_one_sample.py │ ├── preprocess_phoneme_level.yaml │ └── preprocessor.py │ ├── sampler_util.py │ ├── text │ ├── __init__.py │ ├── cleaners.py │ ├── cmudict.py │ ├── lexicon │ │ ├── librispeech-lexicon.txt │ │ └── pinyin-lexicon-r.txt │ ├── numbers.py │ ├── pinyin.py │ └── symbols.py │ └── tools.py └── preprocessing.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | taming 3 | log 4 | esc50.zip 5 | ESC-50-master 6 | ckpt 7 | lightning_logs 8 | mlx_submit_* 9 | job_queue.sh 10 | *.wav 11 | # *.txt 12 | *.cleaned 13 | *.tar 14 | *.npy 15 | *.TextGrid 16 | TextGrid 17 | LJSpeech-1.1 18 | condor* 19 | wandb 20 | temp.py 21 | temp 22 | tests 23 | test.yaml 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # DrawSpeech: Expressive Speech Synthesis Using Prosodic Sketches as Control Conditions 3 | This paper was submitted to ICASSP 2025. 4 | 5 | ## Status 6 | This project is currently under active development. We are continuously updating and improving it, with more usage details and features to be released in the future. 7 | 8 | # Getting started 9 | 10 | ## Download dataset and checkpoints 11 | 1. Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) dataset and place it in `data/dataset` so that the directory structure looks like the one below: 12 | ```plaintext 13 | data/dataset/LJSpeech-1.1 14 | ┣ metadata.csv 15 | ┣ wavs 16 | ┃ ┣ LJ001-0001.wav 17 | ┃ ┣ LJ001-0002.wav 18 | ┃ ┣ ... 19 | ┣ README 20 | ``` 21 | 2. Download the alignments of the LJSpeech dataset: [LJSpeech.zip](https://drive.google.com/drive/folders/1DBRkALpPd6FL9gjHMmMEdHODmkgNIIK4). Unzip the files into `data/dataset/LJSpeech-1.1`. 22 | 3. Download checkpoints (Coming Soon) 23 | 4.
Uncompress the checkpoint tar file and place its contents into **data/checkpoints/** 24 | 25 | ## Preprocessing 26 | ```shell 27 | python preprocessing.py 28 | ``` 29 | 30 | ## Training 31 | 32 | Train the VAE (Optional) 33 | ```shell 34 | CUDA_VISIBLE_DEVICES=0 python drawspeech/train/autoencoder.py -c drawspeech/config/vae_ljspeech_22k.yaml 35 | ``` 36 | 37 | If you don't want to train the VAE, you can simply use the VAE checkpoint that we provide: 38 | - Set the variable `reload_from_ckpt` in `drawspeech_ljspeech_22k.yaml` to `data/checkpoints/vae.ckpt` 39 | 40 | Train DrawSpeech 41 | ```shell 42 | CUDA_VISIBLE_DEVICES=0 python drawspeech/train/latent_diffusion.py -c drawspeech/config/drawspeech_ljspeech_22k.yaml 43 | ``` 44 | 45 | 46 | ## Inference 47 | 48 | If you have trained the model using `drawspeech_ljspeech_22k.yaml`, use the following command: 49 | ```shell 50 | CUDA_VISIBLE_DEVICES=0 python drawspeech/infer.py --config_yaml drawspeech/config/drawspeech_ljspeech_22k.yaml --list_inference tests/inference.json 51 | ``` 52 | 53 | Otherwise, specify the DrawSpeech checkpoint explicitly: 54 | ```shell 55 | CUDA_VISIBLE_DEVICES=0 python drawspeech/infer.py --config_yaml drawspeech/config/drawspeech_ljspeech_22k.yaml --list_inference tests/inference.json --reload_from_ckpt data/checkpoints/drawspeech.ckpt 56 | ``` 57 | 58 | ## Acknowledgement 59 | This repository borrows code from the following repositories. Many thanks to the authors for their great work. 60 | AudioLDM: https://github.com/haoheliu/AudioLDM-training-finetuning?tab=readme-ov-file#prepare-python-running-environment 61 | FastSpeech 2: https://github.com/ming024/FastSpeech2 62 | HiFi-GAN: https://github.com/jik876/hifi-gan 63 | 64 | -------------------------------------------------------------------------------- /data/dataset/metadata/dataset_root.json: -------------------------------------------------------------------------------- 1 | { 2 | "ljspeech": "./data/dataset/LJSpeech-1.1", 3 | 4 | "metadata":{ 5 | "path": { 6 | "ljspeech":{ 7 | "train": "./data/dataset/metadata/ljspeech/train.json", 8 | "test": "./data/dataset/metadata/ljspeech/test.json", 9 | "val": "./data/dataset/metadata/ljspeech/val.json" 10 | } 11 | } 12 | } 13 | } -------------------------------------------------------------------------------- /drawspeech/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HappyColor/DrawSpeech_PyTorch/9db245d68e87826a175494f037eb4b9482d3a836/drawspeech/__init__.py -------------------------------------------------------------------------------- /drawspeech/config/drawspeech_ljspeech_22k.yaml: -------------------------------------------------------------------------------- 1 | metadata_root: "./data/dataset/metadata/dataset_root.json" 2 | log_directory: "./log/latent_diffusion" 3 | project: "drawspeech" 4 | precision: "high" 5 | 6 | variables: 7 | sampling_rate: &sampling_rate 22050 8 | mel_bins: &mel_bins 80 9 | latent_embed_dim: &latent_embed_dim 8 # the output channel of the diffusion model 10 | latent_t_size: &latent_t_size 216 # TODO might need to change 11 | latent_f_size: &latent_f_size 20 12 | in_channels: &unet_in_channels 11 # the input channel of the diffusion model; because concat-style conditioning is used, in_channels can be larger than latent_embed_dim 13 | optimize_ddpm_parameter: &optimize_ddpm_parameter true 14 | optimize_gpt: &optimize_gpt true 15 | warmup_steps: &warmup_steps 2000 16 | 17 | data: 18 | train: ["ljspeech"] 19 | val: ["ljspeech"] 20 | test: ["ljspeech"] 21 |
class_label_indices: null 22 | dataloader_add_ons: ["get_preprocessed_meta"] 23 | 24 | step: 25 | # val_check_interval: 200 26 | validation_every_n_epochs: 5 27 | save_checkpoint_every_n_steps: 1000 28 | num_sanity_val_steps: 1 29 | limit_val_batches: 2 30 | max_steps: 80000 31 | save_top_k: 3 32 | save_last: true 33 | 34 | preprocessing: 35 | audio: 36 | sampling_rate: *sampling_rate 37 | max_wav_value: 32768.0 38 | duration: &duration 10.04 39 | stft: 40 | filter_length: 1024 41 | hop_length: 256 42 | win_length: 1024 43 | mel: 44 | n_mel_channels: *mel_bins 45 | mel_fmin: 0 46 | mel_fmax: 8000 47 | phoneme_pad_length: &phoneme_pad_length 135 48 | preprocessed_data: 49 | energy: data/dataset/metadata/ljspeech/phoneme_level/energy 50 | pitch: data/dataset/metadata/ljspeech/phoneme_level/pitch 51 | duration: data/dataset/metadata/ljspeech/phoneme_level/duration 52 | stats_json: &stats_json data/dataset/metadata/ljspeech/phoneme_level/stats.json 53 | feature: &feature_level phoneme_level 54 | 55 | model: 56 | target: drawspeech.modules.latent_diffusion.ddpm.LatentDiffusion 57 | params: 58 | # Autoencoder 59 | first_stage_config: 60 | base_learning_rate: 8.0e-06 61 | target: drawspeech.modules.latent_encoder.autoencoder.AutoencoderKL 62 | params: 63 | reload_from_ckpt: "log/latent_diffusion/vae_ljspeech_22k/checkpoints/checkpoint-79999.ckpt" 64 | sampling_rate: *sampling_rate 65 | batchsize: 4 66 | monitor: val/rec_loss # actually we use global_step 67 | image_key: fbank 68 | subband: 1 69 | embed_dim: *latent_embed_dim 70 | time_shuffle: 1 71 | lossconfig: 72 | target: drawspeech.losses.LPIPSWithDiscriminator 73 | params: 74 | disc_start: 50001 75 | kl_weight: 1000.0 76 | disc_weight: 0.5 77 | disc_in_channels: 1 78 | ddconfig: 79 | double_z: true 80 | mel_bins: *mel_bins # The frequency bins of mel spectrogram 81 | z_channels: 8 82 | resolution: *latent_t_size 83 | downsample_time: false 84 | in_channels: 1 85 | out_ch: 1 86 | ch: 128 87 | ch_mult: 88 | - 1 89 | - 2 90 | - 4 91 | num_res_blocks: 2 92 | attn_resolutions: [] 93 | dropout: 0.0 94 | 95 | # Other parameters 96 | base_learning_rate: 1.0e-4 97 | warmup_steps: *warmup_steps 98 | optimize_ddpm_parameter: *optimize_ddpm_parameter 99 | sampling_rate: *sampling_rate 100 | batchsize: 32 101 | linear_start: 0.0015 102 | linear_end: 0.0195 103 | num_timesteps_cond: 1 104 | log_every_t: 200 105 | timesteps: 1000 106 | unconditional_prob_cfg: 0.1 107 | parameterization: eps # [eps, x0, v] 108 | first_stage_key: fbank 109 | latent_t_size: *latent_t_size # TODO might need to change 110 | latent_f_size: *latent_f_size 111 | channels: *latent_embed_dim # TODO might need to change 112 | monitor: train/loss 113 | scale_by_std: true 114 | use_tts_data_for_training: true # useless, might need to delete this argument 115 | unet_config: 116 | target: drawspeech.modules.diffusionmodules.openaimodel.UNetModel 117 | params: 118 | image_size: 64 119 | in_channels: *unet_in_channels # The input channel of the UNet model 120 | out_channels: *latent_embed_dim # TODO might need to change 121 | model_channels: 128 # TODO might need to change 122 | attention_resolutions: 123 | # - 8 124 | - 4 125 | - 2 126 | num_res_blocks: 2 127 | channel_mult: 128 | - 1 129 | - 2 130 | - 3 131 | # - 5 132 | num_head_channels: 32 133 | use_spatial_transformer: true 134 | transformer_depth: 1 135 | use_scale_shift_norm: true 136 | 137 | cond_stage_config: 138 | concat_text_encoder_with_varianceadaptor: 139 | cond_stage_key: [phoneme_idx, pitch, pitch_length, energy,
energy_length, pitch_sketch, energy_sketch, phoneme_duration, mel_mask] 140 | conditioning_key: concat 141 | target: drawspeech.conditional_models.TextEncoderwithVarianceAdaptor 142 | params: 143 | vocabs_size: 360 144 | pad_token_id: 0 145 | pad_length: *phoneme_pad_length 146 | output_size: &phoneme_emb_dim 256 147 | latent_t_size: *latent_t_size 148 | latent_f_size: *latent_f_size 149 | adaptor_params: 150 | phoneme_embedding_dim: *phoneme_emb_dim 151 | pitch_embedding_dim: *phoneme_emb_dim 152 | energy_embedding_dim: *phoneme_emb_dim 153 | pitch_quantization: linear 154 | energy_quantization: linear 155 | n_bins: 256 156 | stats_json: *stats_json 157 | pitch_feature_level: *feature_level 158 | energy_feature_level: *feature_level 159 | predict_detailed_curve: true 160 | prob_drop_pitch: 0.2 161 | prob_drop_energy: 0.2 162 | predictor_params: 163 | pitch_embedding_dim: *phoneme_emb_dim 164 | energy_embedding_dim: *phoneme_emb_dim 165 | ffn_dim: 512 166 | n_bins: 256 167 | n_heads: 2 168 | n_layers: 2 169 | concat_pitch_energy_sketch: 170 | cond_stage_key: [pitch, pitch_sketch, pitch_length, energy, energy_sketch, energy_length, noncond_mel_mask, noncond_predicted_pitch, noncond_predicted_energy] 171 | conditioning_key: concat 172 | target: drawspeech.conditional_models.SketchEncoder 173 | params: 174 | latent_t_size: *latent_t_size 175 | latent_f_size: *latent_f_size 176 | stats_json: *stats_json 177 | pitch_feature_level: *feature_level 178 | energy_feature_level: *feature_level 179 | 180 | evaluation_params: 181 | unconditional_guidance_scale: 3.5 182 | ddim_sampling_steps: 200 183 | n_candidates_per_samples: 1 184 | -------------------------------------------------------------------------------- /drawspeech/config/vae_ljspeech_22k.yaml: -------------------------------------------------------------------------------- 1 | metadata_root: "./data/dataset/metadata/dataset_root.json" 2 | log_directory: "./log/latent_diffusion" 3 | project: "drawspeech" 4 | precision: "high" 5 | 6 | variables: 7 | sampling_rate: &sampling_rate 22050 8 | mel_bins: &mel_bins 80 9 | latent_embed_dim: &latent_embed_dim 8 10 | latent_t_size: &latent_t_size 216 # TODO might need to change 11 | latent_f_size: &latent_f_size 20 12 | in_channels: &unet_in_channels 8 13 | optimize_ddpm_parameter: &optimize_ddpm_parameter true 14 | optimize_gpt: &optimize_gpt true 15 | warmup_steps: &warmup_steps 2000 # only works in LDM training 16 | 17 | data: 18 | train: ["ljspeech"] 19 | val: "ljspeech" 20 | test: "ljspeech" 21 | class_label_indices: null 22 | dataloader_add_ons: [] 23 | 24 | step: 25 | # val_check_interval: 5000 26 | validation_every_n_epochs: 2 27 | save_checkpoint_every_n_steps: 1000 28 | limit_val_batches: 2 29 | max_steps: 80000 30 | save_top_k: 1 31 | save_last: true 32 | 33 | preprocessing: 34 | audio: 35 | sampling_rate: *sampling_rate 36 | max_wav_value: 32768.0 37 | duration: 10.04 38 | stft: 39 | filter_length: 1024 40 | hop_length: 256 41 | win_length: 1024 42 | mel: 43 | n_mel_channels: *mel_bins 44 | mel_fmin: 0 45 | mel_fmax: 8000 46 | preprocessed_data: 47 | energy: data/dataset/metadata/ljspeech/phoneme_level/energy 48 | pitch: data/dataset/metadata/ljspeech/phoneme_level/pitch 49 | duration: data/dataset/metadata/ljspeech/phoneme_level/duration 50 | 51 | model: 52 | base_learning_rate: 8.0e-06 53 | target: drawspeech.modules.latent_encoder.autoencoder.AutoencoderKL 54 | params: 55 | # reload_from_ckpt: "data/checkpoints/vae_mel_16k_64bins.ckpt" 56 | sampling_rate: *sampling_rate 57 | batchsize: 8
58 | monitor: val/rec_loss # actually we use global_step 59 | image_key: fbank 60 | subband: 1 61 | embed_dim: *latent_embed_dim 62 | time_shuffle: 1 63 | lossconfig: 64 | target: drawspeech.losses.LPIPSWithDiscriminator 65 | params: 66 | disc_start: 50001 67 | kl_weight: 1000.0 68 | disc_weight: 0.5 69 | disc_in_channels: 1 70 | ddconfig: 71 | double_z: true 72 | mel_bins: *mel_bins # The frequency bins of mel spectrogram 73 | z_channels: 8 74 | resolution: *latent_t_size 75 | downsample_time: false 76 | in_channels: 1 77 | out_ch: 1 78 | ch: 128 79 | ch_mult: 80 | - 1 81 | - 2 82 | - 4 83 | num_res_blocks: 2 84 | attn_resolutions: [] 85 | dropout: 0.0 86 | -------------------------------------------------------------------------------- /drawspeech/dataset_plugin.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import json 6 | 7 | from drawspeech.utilities.text import text_to_sequence 8 | from drawspeech.utilities.preprocessor.preprocess_one_sample import preprocess_english 9 | from drawspeech.utilities.tools import sketch_extractor, min_max_normalize 10 | 11 | 12 | PITCH_MIN, ENERGY_MIN = None, None 13 | def get_preprocessed_meta(config, dl_output, metadata): 14 | basename = os.path.split(metadata["wav"])[-1].replace(".wav", "") 15 | speaker = "LJSpeech" 16 | 17 | global PITCH_MIN, ENERGY_MIN 18 | if PITCH_MIN is None or ENERGY_MIN is None: 19 | # print("Loading pitch and energy stats from %s" % config["preprocessing"]["preprocessed_data"]["stats_json"]) 20 | with open(config["preprocessing"]["preprocessed_data"]["stats_json"], "r") as f: 21 | stats = json.load(f) 22 | PITCH_MIN, pitch_max, pitch_mean, pitch_std = stats["pitch"] 23 | ENERGY_MIN, energy_max, energy_mean, energy_std = stats["energy"] 24 | 25 | pad_token_id = 0 26 | r = 2 ** (len(config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]["ch_mult"]) - 1) # 4 27 | mel_pad_length = config["variables"]["latent_t_size"] * r 28 | phoneme_pad_length = config["preprocessing"]["phoneme_pad_length"] 29 | 30 | feature_level = config["preprocessing"]["preprocessed_data"]["feature"] 31 | if feature_level == "phoneme_level": 32 | pitch_pad_length = phoneme_pad_length 33 | energy_pad_length = phoneme_pad_length 34 | elif feature_level == "frame_level": 35 | pitch_pad_length = mel_pad_length 36 | energy_pad_length = mel_pad_length 37 | else: 38 | raise ValueError("Unknown feature level %s" % feature_level) 39 | duration_pad_length = phoneme_pad_length 40 | 41 | # load phoneme 42 | if "phonemes" in metadata.keys(): 43 | phoneme_idx = torch.LongTensor(text_to_sequence(metadata["phonemes"], ["english_cleaners"])) 44 | else: 45 | assert "transcription" in metadata.keys(), "You must provide the phoneme or transcription in the metadata" 46 | phoneme_idx, phoneme = preprocess_english(metadata["transcription"]) 47 | phoneme_idx = torch.LongTensor(phoneme_idx) 48 | phoneme_idx = F.pad(phoneme_idx, (0, phoneme_pad_length - phoneme_idx.size(0)), value=pad_token_id) if phoneme_idx.size(0) < phoneme_pad_length else phoneme_idx[:phoneme_pad_length] 49 | 50 | # load pitch and pitch sketch 51 | pitch_path = metadata["pitch"] if "pitch" in metadata.keys() else os.path.join(config["preprocessing"]["preprocessed_data"]["pitch"], "{}-pitch-{}.npy".format(speaker, basename)) 52 | if os.path.exists(pitch_path): 53 | original_pitch = np.load(pitch_path) 54 | pitch = torch.from_numpy(original_pitch).float() 55 | 
pitch_length = torch.LongTensor([min(pitch.size(0), pitch_pad_length)]) 56 | pitch = F.pad(pitch, (0, pitch_pad_length - pitch.size(0)), value=PITCH_MIN) if pitch.size(0) < pitch_pad_length else pitch[:pitch_pad_length] 57 | else: 58 | original_pitch = None 59 | pitch = "" 60 | pitch_length = "" 61 | 62 | if "pitch_sketch" in metadata.keys() and metadata["pitch_sketch"]: # under inference mode 63 | assert original_pitch is None, "You cannot provide both pitch and pitch_sketch in the metadata" 64 | pitch_sketch_path = metadata["pitch_sketch"] 65 | pitch_sketch = np.load(pitch_sketch_path) 66 | pitch_sketch = torch.from_numpy(pitch_sketch).float()[None, None, :] 67 | pitch_sketch = F.interpolate(pitch_sketch, size=pitch_pad_length, mode="linear", align_corners=True).squeeze(0).squeeze(0) 68 | pitch_length = torch.LongTensor([min(pitch_sketch.size(0), pitch_pad_length)]) 69 | elif original_pitch is not None: 70 | pitch_sketch = sketch_extractor(original_pitch) 71 | pitch_sketch = torch.from_numpy(pitch_sketch).float() 72 | pitch_sketch = F.pad(pitch_sketch, (0, pitch_pad_length - pitch_sketch.size(0)), value=PITCH_MIN) if pitch_sketch.size(0) < pitch_pad_length else pitch_sketch[:pitch_pad_length] 73 | pitch_sketch = min_max_normalize(pitch_sketch) 74 | else: 75 | pitch_sketch = "" 76 | 77 | # load energy and energy sketch 78 | energy_path = metadata["energy"] if "energy" in metadata.keys() else os.path.join(config["preprocessing"]["preprocessed_data"]["energy"], "{}-energy-{}.npy".format(speaker, basename)) 79 | if os.path.exists(energy_path): 80 | original_energy = np.load(energy_path) 81 | energy = torch.from_numpy(original_energy).float() 82 | energy_length = torch.LongTensor([min(energy.size(0), energy_pad_length)]) 83 | energy = F.pad(energy, (0, energy_pad_length - energy.size(0)), value=ENERGY_MIN) if energy.size(0) < energy_pad_length else energy[:energy_pad_length] 84 | else: 85 | original_energy = None 86 | energy = "" 87 | energy_length = "" 88 | 89 | if "energy_sketch" in metadata.keys() and metadata["energy_sketch"]: # under inference mode 90 | assert original_energy is None, "You cannot provide both energy and energy_sketch in the metadata" 91 | energy_sketch_path = metadata["energy_sketch"] 92 | energy_sketch = np.load(energy_sketch_path) 93 | energy_sketch = torch.from_numpy(energy_sketch).float()[None, None, :] 94 | energy_sketch = F.interpolate(energy_sketch, size=energy_pad_length, mode="linear", align_corners=True).squeeze(0).squeeze(0) 95 | energy_length = torch.LongTensor([min(energy_sketch.size(0), energy_pad_length)]) 96 | elif original_energy is not None: 97 | energy_sketch = sketch_extractor(original_energy) 98 | energy_sketch = torch.from_numpy(energy_sketch).float() 99 | energy_sketch = F.pad(energy_sketch, (0, energy_pad_length - energy_sketch.size(0)), value=ENERGY_MIN) if energy_sketch.size(0) < energy_pad_length else energy_sketch[:energy_pad_length] 100 | energy_sketch = min_max_normalize(energy_sketch) 101 | else: 102 | energy_sketch = "" 103 | 104 | # load phoneme duration 105 | duration_path = os.path.join(config["preprocessing"]["preprocessed_data"]["duration"], "{}-duration-{}.npy".format(speaker, basename)) 106 | if os.path.exists(duration_path): 107 | duration = np.load(duration_path) 108 | duration = torch.from_numpy(duration).float() 109 | duration = F.pad(duration, (0, duration_pad_length - duration.size(0))) if duration.size(0) < duration_pad_length else duration[:duration_pad_length] 110 | else: 111 | duration = "" 112 | 113 | return { 114 | 
"phoneme_idx": phoneme_idx, 115 | "pitch": pitch, 116 | "pitch_sketch": pitch_sketch, 117 | "pitch_length": pitch_length, 118 | "energy": energy, 119 | "energy_sketch": energy_sketch, 120 | "energy_length": energy_length, 121 | "phoneme_duration": duration, 122 | } 123 | 124 | -------------------------------------------------------------------------------- /drawspeech/infer.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | import argparse 5 | import yaml 6 | import torch 7 | import json 8 | 9 | from torch.utils.data import DataLoader 10 | from pytorch_lightning import seed_everything 11 | from drawspeech.utilities.tools import get_restore_step 12 | from drawspeech.utilities.model_util import instantiate_from_config 13 | from drawspeech.utilities.tools import build_dataset_json_from_list 14 | from drawspeech.conditional_models import * 15 | from drawspeech.utilities.data.dataset import AudioDataset 16 | 17 | def set_cond_infer_mode(latent_diffusion): 18 | for key in latent_diffusion.cond_stage_model_metadata.keys(): 19 | model_idx = latent_diffusion.cond_stage_model_metadata[key]["model_idx"] 20 | if isinstance(latent_diffusion.cond_stage_models[model_idx], TextEncoderwithVarianceAdaptor): 21 | print("Set infer mode for TextEncoderwithVarianceAdaptor") 22 | latent_diffusion.cond_stage_models[model_idx].infer = True 23 | if isinstance(latent_diffusion.cond_stage_models[model_idx], SketchEncoder): 24 | print("Set infer mode for SketchEncoder") 25 | latent_diffusion.cond_stage_models[model_idx].infer = True 26 | 27 | return latent_diffusion 28 | 29 | def infer(dataset_json, configs, config_yaml_path, exp_group_name, exp_name): 30 | if "seed" in configs.keys(): 31 | seed_everything(configs["seed"]) 32 | else: 33 | print("SEED EVERYTHING TO 0") 34 | seed_everything(0) 35 | 36 | if "precision" in configs.keys(): 37 | torch.set_float32_matmul_precision(configs["precision"]) 38 | 39 | log_path = configs["log_directory"] 40 | 41 | if "dataloader_add_ons" in configs["data"].keys(): 42 | dataloader_add_ons = configs["data"]["dataloader_add_ons"] 43 | else: 44 | dataloader_add_ons = [] 45 | 46 | val_dataset = AudioDataset( 47 | configs, split="test", add_ons=dataloader_add_ons, dataset_json=dataset_json 48 | ) 49 | 50 | val_loader = DataLoader( 51 | val_dataset, 52 | batch_size=1, 53 | ) 54 | 55 | try: 56 | config_reload_from_ckpt = configs["reload_from_ckpt"] 57 | except: 58 | config_reload_from_ckpt = None 59 | 60 | checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints") 61 | 62 | wandb_path = os.path.join(log_path, exp_group_name, exp_name) 63 | 64 | os.makedirs(checkpoint_path, exist_ok=True) 65 | shutil.copy(config_yaml_path, wandb_path) 66 | 67 | if config_reload_from_ckpt is not None: 68 | resume_from_checkpoint = config_reload_from_ckpt 69 | print("Reload ckpt specified in the config file %s" % resume_from_checkpoint) 70 | elif len(os.listdir(checkpoint_path)) > 0: 71 | print("Load checkpoint from path: %s" % checkpoint_path) 72 | restore_step, n_step = get_restore_step(checkpoint_path) 73 | resume_from_checkpoint = os.path.join(checkpoint_path, restore_step) 74 | print("Resume from checkpoint", resume_from_checkpoint) 75 | else: 76 | print("Train from scratch") 77 | resume_from_checkpoint = None 78 | 79 | latent_diffusion = instantiate_from_config(configs["model"]) 80 | latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name) 81 | 82 | guidance_scale = 
configs["model"]["params"]["evaluation_params"][ 83 | "unconditional_guidance_scale" 84 | ] 85 | ddim_sampling_steps = configs["model"]["params"]["evaluation_params"][ 86 | "ddim_sampling_steps" 87 | ] 88 | n_candidates_per_samples = configs["model"]["params"]["evaluation_params"][ 89 | "n_candidates_per_samples" 90 | ] 91 | 92 | checkpoint = torch.load(resume_from_checkpoint) 93 | latent_diffusion.load_state_dict(checkpoint["state_dict"]) 94 | 95 | latent_diffusion = set_cond_infer_mode(latent_diffusion) 96 | 97 | latent_diffusion.eval() 98 | latent_diffusion = latent_diffusion.cuda() 99 | 100 | latent_diffusion.generate_sample( 101 | val_loader, 102 | unconditional_guidance_scale=guidance_scale, 103 | ddim_steps=ddim_sampling_steps, 104 | n_gen=n_candidates_per_samples, 105 | ) 106 | 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser() 110 | 111 | parser.add_argument( 112 | "-c", 113 | "--config_yaml", 114 | type=str, 115 | required=False, 116 | help="path to config .yaml file", 117 | ) 118 | 119 | parser.add_argument( 120 | "-l", 121 | "--list_inference", 122 | type=str, 123 | required=False, 124 | help="The filelist that contain captions (and optionally filenames)", 125 | ) 126 | parser.add_argument( 127 | "-reload_from_ckpt", 128 | "--reload_from_ckpt", 129 | type=str, 130 | required=False, 131 | help="the checkpoint path for the model", 132 | ) 133 | 134 | args = parser.parse_args() 135 | 136 | assert torch.cuda.is_available(), "CUDA is not available" 137 | 138 | config_yaml = args.config_yaml 139 | if args.list_inference.endswith(".json"): 140 | dataset_json = json.load(open(args.list_inference, "r")) 141 | else: 142 | dataset_json = build_dataset_json_from_list(args.list_inference) 143 | exp_name = os.path.basename(config_yaml.split(".")[0]) 144 | exp_group_name = os.path.basename(os.path.dirname(config_yaml)) 145 | 146 | config_yaml_path = os.path.join(config_yaml) 147 | config_yaml = yaml.load(open(config_yaml_path, "r"), Loader=yaml.FullLoader) 148 | 149 | if args.reload_from_ckpt != None: 150 | config_yaml["reload_from_ckpt"] = args.reload_from_ckpt 151 | 152 | if "pitch" in dataset_json.keys() and dataset_json["pitch"] != "": 153 | config_yaml["preprocessing"]["preprocessed_data"]["pitch"] = dataset_json["pitch"] 154 | if "energy" in dataset_json.keys() and dataset_json["energy"] != "": 155 | config_yaml["preprocessing"]["preprocessed_data"]["energy"] = dataset_json["energy"] 156 | if "duration" in dataset_json.keys() and dataset_json["duration"] != "": 157 | config_yaml["preprocessing"]["preprocessed_data"]["duration"] = dataset_json["duration"] 158 | 159 | infer(dataset_json, config_yaml, config_yaml_path, exp_group_name, exp_name) 160 | -------------------------------------------------------------------------------- /drawspeech/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .contperceptual import LPIPSWithDiscriminator 2 | -------------------------------------------------------------------------------- /drawspeech/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__( 9 | self, 10 | disc_start, 11 | logvar_init=0.0, 12 | kl_weight=1.0, 13 | pixelloss_weight=1.0, 14 | disc_num_layers=3, 15 | disc_in_channels=3, 16 | disc_factor=1.0, 17 | disc_weight=1.0, 18 | perceptual_weight=1.0, 19 | use_actnorm=False, 20 | disc_conditional=False, 21 | disc_loss="hinge", 22 | ): 23 | super().__init__() 24 | assert disc_loss in ["hinge", "vanilla"] 25 | self.kl_weight = kl_weight 26 | self.pixel_weight = pixelloss_weight 27 | self.perceptual_loss = LPIPS().eval() 28 | self.perceptual_weight = perceptual_weight 29 | # output log variance 30 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 31 | 32 | self.discriminator = NLayerDiscriminator( 33 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm 34 | ).apply(weights_init) 35 | self.discriminator_iter_start = disc_start 36 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 37 | self.disc_factor = disc_factor 38 | self.discriminator_weight = disc_weight 39 | self.disc_conditional = disc_conditional 40 | 41 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 42 | if last_layer is not None: 43 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 44 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 45 | else: 46 | nll_grads = torch.autograd.grad( 47 | nll_loss, self.last_layer[0], retain_graph=True 48 | )[0] 49 | g_grads = torch.autograd.grad( 50 | g_loss, self.last_layer[0], retain_graph=True 51 | )[0] 52 | 53 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 54 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 55 | d_weight = d_weight * self.discriminator_weight 56 | return d_weight 57 | 58 | def forward( 59 | self, 60 | inputs, 61 | reconstructions, 62 | posteriors, 63 | optimizer_idx, 64 | global_step, 65 | waveform=None, 66 | rec_waveform=None, 67 | last_layer=None, 68 | cond=None, 69 | split="train", 70 | weights=None, 71 | ): 72 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 73 | 74 | # Always true 75 | if self.perceptual_weight > 0: 76 | p_loss = self.perceptual_loss( 77 | inputs.contiguous(), reconstructions.contiguous() 78 | ) 79 | rec_loss = rec_loss + self.perceptual_weight * p_loss 80 | 81 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 82 | weighted_nll_loss = nll_loss 83 | if weights is not None: 84 | weighted_nll_loss = weights * nll_loss 85 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 86 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 87 | kl_loss = posteriors.kl() 88 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 89 | 90 | # now the GAN part 91 | if optimizer_idx == 0: 92 | # generator update 93 | if cond is None: 94 | assert not self.disc_conditional 95 | logits_fake = self.discriminator(reconstructions.contiguous()) 96 | else: 97 | assert self.disc_conditional 98 | logits_fake = self.discriminator( 99 | torch.cat((reconstructions.contiguous(), cond), dim=1) 100 | ) 101 | g_loss = -torch.mean(logits_fake) 102 | 103 | if self.disc_factor > 0.0: 104 | try: 105 | d_weight = self.calculate_adaptive_weight( 106 | nll_loss, g_loss, last_layer=last_layer 107 | ) 108 | except RuntimeError: 109 | assert not self.training 110 | d_weight = torch.tensor(0.0) 111 | else: 112 | d_weight = torch.tensor(0.0) 113 | 114 | disc_factor = adopt_weight( 115 | self.disc_factor, global_step, 
threshold=self.discriminator_iter_start 116 | ) 117 | loss = ( 118 | weighted_nll_loss 119 | + self.kl_weight * kl_loss 120 | + d_weight * disc_factor * g_loss 121 | ) 122 | 123 | log = { 124 | "{}/total_loss".format(split): loss.clone().detach().mean(), 125 | "{}/logvar".format(split): self.logvar.detach(), 126 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 127 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 128 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 129 | "{}/d_weight".format(split): d_weight.detach(), 130 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 131 | "{}/g_loss".format(split): g_loss.detach().mean(), 132 | } 133 | return loss, log 134 | 135 | if optimizer_idx == 1: 136 | # second pass for discriminator update 137 | if cond is None: 138 | logits_real = self.discriminator(inputs.contiguous().detach()) 139 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 140 | else: 141 | logits_real = self.discriminator( 142 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 143 | ) 144 | logits_fake = self.discriminator( 145 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 146 | ) 147 | 148 | disc_factor = adopt_weight( 149 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 150 | ) 151 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 152 | 153 | log = { 154 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 155 | "{}/logits_real".format(split): logits_real.detach().mean(), 156 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 157 | } 158 | return d_loss, log 159 | -------------------------------------------------------------------------------- /drawspeech/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HappyColor/DrawSpeech_PyTorch/9db245d68e87826a175494f037eb4b9482d3a836/drawspeech/modules/__init__.py -------------------------------------------------------------------------------- /drawspeech/modules/contour_predictor/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import drawspeech.modules.text_encoder.attentions as attentions 5 | 6 | class Sketch2ContourPredictor(nn.Module): 7 | def __init__(self, pitch_embedding_dim, energy_embedding_dim, ffn_dim, n_bins, n_heads, n_layers): 8 | super(Sketch2ContourPredictor, self).__init__() 9 | 10 | pitch_min = energy_min = 0 11 | pitch_max = energy_max = 1 12 | self.pitch_bins = nn.Parameter(torch.linspace(pitch_min, pitch_max, n_bins - 1),requires_grad=False) 13 | self.energy_bins = nn.Parameter(torch.linspace(energy_min, energy_max, n_bins - 1),requires_grad=False) 14 | 15 | self.pitch_embedding = nn.Embedding(n_bins, pitch_embedding_dim) 16 | self.energy_embedding = nn.Embedding(n_bins, energy_embedding_dim) 17 | 18 | assert pitch_embedding_dim == energy_embedding_dim 19 | embedding_dim = pitch_embedding_dim 20 | self.encoder = attentions.Encoder(embedding_dim, ffn_dim, n_heads, n_layers, kernel_size=3, p_dropout=0.1) 21 | self.linear_layer = nn.Linear(embedding_dim, 2) 22 | 23 | def get_pitch_embedding(self, x, mask): 24 | if mask is not None: 25 | x = x.masked_fill(mask == 0, 0.0) 26 | embedding = self.pitch_embedding(torch.bucketize(x, self.pitch_bins)) 27 | return embedding 28 | 29 | def get_energy_embedding(self, x, mask): 30 | if mask is not None: 31 | x = x.masked_fill(mask == 0, 0.0) 32 | embedding = 
self.energy_embedding(torch.bucketize(x, self.energy_bins)) 33 | return embedding 34 | 35 | def forward(self, x, pitch_sketch, energy_sketch, mask=None): 36 | ''' 37 | x: expanded text embedding, [b, t, h] 38 | mask: [b, t], 1 for real data, 0 for padding 39 | ''' 40 | 41 | if pitch_sketch is None and energy_sketch is None: 42 | return None, None 43 | 44 | if pitch_sketch is None: 45 | pitch_sketch_embedding = 0 46 | else: 47 | pitch_sketch_embedding = self.get_pitch_embedding(pitch_sketch, mask) 48 | 49 | if energy_sketch is None: 50 | energy_sketch_embedding = 0 51 | else: 52 | energy_sketch_embedding = self.get_energy_embedding(energy_sketch, mask) 53 | 54 | x = x + pitch_sketch_embedding + energy_sketch_embedding 55 | 56 | x = x.transpose(1, 2) # [b, h, t] 57 | mask = mask.unsqueeze(1).to(x.dtype) # [b, 1, t], 1 for real data, 0 for padding 58 | x = self.encoder(x * mask, mask) 59 | 60 | out = self.linear_layer(x.transpose(1, 2)) * mask.transpose(1, 2) 61 | pitch, energy = out.chunk(2, dim=-1) 62 | 63 | pitch = pitch.squeeze(-1) 64 | energy = energy.squeeze(-1) 65 | 66 | return pitch, energy 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /drawspeech/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HappyColor/DrawSpeech_PyTorch/9db245d68e87826a175494f037eb4b9482d3a836/drawspeech/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /drawspeech/modules/diffusionmodules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from drawspeech.utilities.diffusion_util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return {el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = ( 53 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 54 | if not glu 55 | else GEGLU(dim, inner_dim) 56 | ) 57 | 58 | self.net = nn.Sequential( 59 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 60 | ) 61 | 62 | def forward(self, x): 63 | return self.net(x) 64 | 65 | 66 | def zero_module(module): 67 | """ 68 | Zero out the parameters of a module and return it. 
69 | """ 70 | for p in module.parameters(): 71 | p.detach().zero_() 72 | return module 73 | 74 | 75 | def Normalize(in_channels): 76 | return torch.nn.GroupNorm( 77 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 78 | ) 79 | 80 | 81 | class LinearAttention(nn.Module): 82 | def __init__(self, dim, heads=4, dim_head=32): 83 | super().__init__() 84 | self.heads = heads 85 | hidden_dim = dim_head * heads 86 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 87 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 88 | 89 | def forward(self, x): 90 | b, c, h, w = x.shape 91 | qkv = self.to_qkv(x) 92 | q, k, v = rearrange( 93 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 94 | ) 95 | k = k.softmax(dim=-1) 96 | context = torch.einsum("bhdn,bhen->bhde", k, v) 97 | out = torch.einsum("bhde,bhdn->bhen", context, q) 98 | out = rearrange( 99 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 100 | ) 101 | return self.to_out(out) 102 | 103 | 104 | class SpatialSelfAttention(nn.Module): 105 | def __init__(self, in_channels): 106 | super().__init__() 107 | self.in_channels = in_channels 108 | 109 | self.norm = Normalize(in_channels) 110 | self.q = torch.nn.Conv2d( 111 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 112 | ) 113 | self.k = torch.nn.Conv2d( 114 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 115 | ) 116 | self.v = torch.nn.Conv2d( 117 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 118 | ) 119 | self.proj_out = torch.nn.Conv2d( 120 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 121 | ) 122 | 123 | def forward(self, x): 124 | h_ = x 125 | h_ = self.norm(h_) 126 | q = self.q(h_) 127 | k = self.k(h_) 128 | v = self.v(h_) 129 | 130 | # compute attention 131 | b, c, h, w = q.shape 132 | q = rearrange(q, "b c h w -> b (h w) c") 133 | k = rearrange(k, "b c h w -> b c (h w)") 134 | w_ = torch.einsum("bij,bjk->bik", q, k) 135 | 136 | w_ = w_ * (int(c) ** (-0.5)) 137 | w_ = torch.nn.functional.softmax(w_, dim=2) 138 | 139 | # attend to values 140 | v = rearrange(v, "b c h w -> b c (h w)") 141 | w_ = rearrange(w_, "b i j -> b j i") 142 | h_ = torch.einsum("bij,bjk->bik", v, w_) 143 | h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) 144 | h_ = self.proj_out(h_) 145 | 146 | return x + h_ 147 | 148 | 149 | # class CrossAttention(nn.Module): 150 | # """ 151 | # ### Cross Attention Layer 152 | # This falls-back to self-attention when conditional embeddings are not specified. 
153 | # """ 154 | 155 | # use_flash_attention: bool = True 156 | 157 | # # use_flash_attention: bool = False 158 | # def __init__( 159 | # self, 160 | # query_dim, 161 | # context_dim=None, 162 | # heads=8, 163 | # dim_head=64, 164 | # dropout=0.0, 165 | # is_inplace: bool = True, 166 | # ): 167 | # # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): 168 | # """ 169 | # :param d_model: is the input embedding size 170 | # :param n_heads: is the number of attention heads 171 | # :param d_head: is the size of a attention head 172 | # :param d_cond: is the size of the conditional embeddings 173 | # :param is_inplace: specifies whether to perform the attention softmax computation inplace to 174 | # save memory 175 | # """ 176 | # super().__init__() 177 | 178 | # self.is_inplace = is_inplace 179 | # self.n_heads = heads 180 | # self.d_head = dim_head 181 | 182 | # # Attention scaling factor 183 | # self.scale = dim_head**-0.5 184 | 185 | # # The normal self-attention layer 186 | # if context_dim is None: 187 | # context_dim = query_dim 188 | 189 | # # Query, key and value mappings 190 | # d_attn = dim_head * heads 191 | # self.to_q = nn.Linear(query_dim, d_attn, bias=False) 192 | # self.to_k = nn.Linear(context_dim, d_attn, bias=False) 193 | # self.to_v = nn.Linear(context_dim, d_attn, bias=False) 194 | 195 | # # Final linear layer 196 | # self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) 197 | 198 | # # Setup [flash attention](https://github.com/HazyResearch/flash-attention). 199 | # # Flash attention is only used if it's installed 200 | # # and `CrossAttention.use_flash_attention` is set to `True`. 201 | # try: 202 | # # You can install flash attention by cloning their Github repo, 203 | # # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) 204 | # # and then running `python setup.py install` 205 | # from flash_attn.flash_attention import FlashAttention 206 | 207 | # self.flash = FlashAttention() 208 | # # Set the scale for scaled dot-product attention. 
209 | # self.flash.softmax_scale = self.scale 210 | # # Set to `None` if it's not installed 211 | # except ImportError: 212 | # self.flash = None 213 | 214 | # def forward(self, x, context=None, mask=None): 215 | # """ 216 | # :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` 217 | # :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` 218 | # """ 219 | 220 | # # If `cond` is `None` we perform self attention 221 | # has_cond = context is not None 222 | # if not has_cond: 223 | # context = x 224 | 225 | # # Get query, key and value vectors 226 | # q = self.to_q(x) 227 | # k = self.to_k(context) 228 | # v = self.to_v(context) 229 | 230 | # # Use flash attention if it's available and the head size is less than or equal to `128` 231 | # if ( 232 | # CrossAttention.use_flash_attention 233 | # and self.flash is not None 234 | # and not has_cond 235 | # and self.d_head <= 128 236 | # ): 237 | # return self.flash_attention(q, k, v) 238 | # # Otherwise, fallback to normal attention 239 | # else: 240 | # return self.normal_attention(q, k, v) 241 | 242 | # def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): 243 | # """ 244 | # #### Flash Attention 245 | # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` 246 | # :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` 247 | # :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` 248 | # """ 249 | 250 | # # Get batch size and number of elements along sequence axis (`width * height`) 251 | # batch_size, seq_len, _ = q.shape 252 | 253 | # # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of 254 | # # shape `[batch_size, seq_len, 3, n_heads * d_head]` 255 | # qkv = torch.stack((q, k, v), dim=2) 256 | # # Split the heads 257 | # qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) 258 | 259 | # # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to 260 | # # fit this size. 
261 | # if self.d_head <= 32: 262 | # pad = 32 - self.d_head 263 | # elif self.d_head <= 64: 264 | # pad = 64 - self.d_head 265 | # elif self.d_head <= 128: 266 | # pad = 128 - self.d_head 267 | # else: 268 | # raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") 269 | 270 | # # Pad the heads 271 | # if pad: 272 | # qkv = torch.cat( 273 | # (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 274 | # ) 275 | 276 | # # Compute attention 277 | # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ 278 | # # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` 279 | # # TODO here I add the dtype changing 280 | # out, _ = self.flash(qkv.type(torch.float16)) 281 | # # Truncate the extra head size 282 | # out = out[:, :, :, : self.d_head].float() 283 | # # Reshape to `[batch_size, seq_len, n_heads * d_head]` 284 | # out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) 285 | 286 | # # Map to `[batch_size, height * width, d_model]` with a linear layer 287 | # return self.to_out(out) 288 | 289 | # def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): 290 | # """ 291 | # #### Normal Attention 292 | 293 | # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` 294 | # :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` 295 | # :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` 296 | # """ 297 | 298 | # # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` 299 | # q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] 300 | # k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] 301 | # v = v.view(*v.shape[:2], self.n_heads, -1) 302 | 303 | # # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ 304 | # attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale 305 | 306 | # # Compute softmax 307 | # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ 308 | # if self.is_inplace: 309 | # half = attn.shape[0] // 2 310 | # attn[half:] = attn[half:].softmax(dim=-1) 311 | # attn[:half] = attn[:half].softmax(dim=-1) 312 | # else: 313 | # attn = attn.softmax(dim=-1) 314 | 315 | # # Compute attention output 316 | # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ 317 | # # attn: [bs, 20, 64, 1] 318 | # # v: [bs, 1, 20, 32] 319 | # out = torch.einsum("bhij,bjhd->bihd", attn, v) 320 | # # Reshape to `[batch_size, height * width, n_heads * d_head]` 321 | # out = out.reshape(*out.shape[:2], -1) 322 | # # Map to `[batch_size, height * width, d_model]` with a linear layer 323 | # return self.to_out(out) 324 | 325 | 326 | class CrossAttention(nn.Module): 327 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 328 | super().__init__() 329 | inner_dim = dim_head * heads 330 | context_dim = default(context_dim, query_dim) 331 | 332 | self.scale = dim_head**-0.5 333 | self.heads = heads 334 | 335 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 336 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 337 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 338 | 339 | self.to_out = nn.Sequential( 340 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 341 | ) 342 | 343 | def forward(self, x, context=None, mask=None): 344 | h = self.heads 345 | 346 | q = self.to_q(x) 347 | context = default(context, x) 348 | 349 | k = self.to_k(context) 350 | v = 
self.to_v(context) 351 | 352 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 353 | 354 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 355 | 356 | if exists(mask): 357 | mask = rearrange(mask, "b ... -> b (...)") 358 | max_neg_value = -torch.finfo(sim.dtype).max 359 | mask = repeat(mask, "b j -> (b h) () j", h=h) 360 | sim.masked_fill_(~(mask == 1), max_neg_value) 361 | 362 | # attention, what we cannot get enough of 363 | attn = sim.softmax(dim=-1) 364 | 365 | out = einsum("b i j, b j d -> b i d", attn, v) 366 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 367 | return self.to_out(out) 368 | 369 | 370 | class BasicTransformerBlock(nn.Module): 371 | def __init__( 372 | self, 373 | dim, 374 | n_heads, 375 | d_head, 376 | dropout=0.0, 377 | context_dim=None, 378 | gated_ff=True, 379 | checkpoint=True, 380 | ): 381 | super().__init__() 382 | self.attn1 = CrossAttention( 383 | query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout 384 | ) # is a self-attention 385 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 386 | self.attn2 = CrossAttention( 387 | query_dim=dim, 388 | context_dim=context_dim, 389 | heads=n_heads, 390 | dim_head=d_head, 391 | dropout=dropout, 392 | ) # is self-attn if context is none 393 | self.norm1 = nn.LayerNorm(dim) 394 | self.norm2 = nn.LayerNorm(dim) 395 | self.norm3 = nn.LayerNorm(dim) 396 | self.checkpoint = checkpoint 397 | 398 | def forward(self, x, context=None, mask=None): 399 | if context is None: 400 | return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) 401 | else: 402 | return checkpoint( 403 | self._forward, (x, context, mask), self.parameters(), self.checkpoint 404 | ) 405 | 406 | def _forward(self, x, context=None, mask=None): 407 | x = self.attn1(self.norm1(x)) + x 408 | x = self.attn2(self.norm2(x), context=context, mask=mask) + x 409 | x = self.ff(self.norm3(x)) + x 410 | return x 411 | 412 | 413 | class SpatialTransformer(nn.Module): 414 | """ 415 | Transformer block for image-like data. 416 | First, project the input (aka embedding) 417 | and reshape to b, t, d. 418 | Then apply standard transformer action. 
419 | Finally, reshape to image 420 | """ 421 | 422 | def __init__( 423 | self, 424 | in_channels, 425 | n_heads, 426 | d_head, 427 | depth=1, 428 | dropout=0.0, 429 | context_dim=None, 430 | ): 431 | super().__init__() 432 | 433 | context_dim = context_dim 434 | 435 | self.in_channels = in_channels 436 | inner_dim = n_heads * d_head 437 | self.norm = Normalize(in_channels) 438 | 439 | self.proj_in = nn.Conv2d( 440 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 441 | ) 442 | 443 | self.transformer_blocks = nn.ModuleList( 444 | [ 445 | BasicTransformerBlock( 446 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 447 | ) 448 | for d in range(depth) 449 | ] 450 | ) 451 | 452 | self.proj_out = zero_module( 453 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 454 | ) 455 | 456 | def forward(self, x, context=None, mask=None): 457 | # note: if no context is given, cross-attention defaults to self-attention 458 | b, c, h, w = x.shape 459 | x_in = x 460 | x = self.norm(x) 461 | x = self.proj_in(x) 462 | x = rearrange(x, "b c h w -> b (h w) c") 463 | for block in self.transformer_blocks: 464 | x = block(x, context=context, mask=mask) 465 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 466 | x = self.proj_out(x) 467 | return x + x_in 468 | -------------------------------------------------------------------------------- /drawspeech/modules/diffusionmodules/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.mean( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.mean( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /drawspeech/modules/diffusionmodules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_( 47 | one_minus_decay * (shadow_params[sname] - m_param[key]) 48 | ) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. 
Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /drawspeech/modules/diffusionmodules/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class GroupNorm32(nn.GroupNorm): 13 | def __init__(self, num_groups, num_channels, swish, eps=1e-5): 14 | super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) 15 | self.swish = swish 16 | 17 | def forward(self, x): 18 | y = super().forward(x.float()).to(x.dtype) 19 | if self.swish == 1.0: 20 | y = F.silu(y) 21 | elif self.swish: 22 | y = y * F.sigmoid(y * float(self.swish)) 23 | return y 24 | 25 | 26 | def conv_nd(dims, *args, **kwargs): 27 | """ 28 | Create a 1D, 2D, or 3D convolution module. 29 | """ 30 | if dims == 1: 31 | return nn.Conv1d(*args, **kwargs) 32 | elif dims == 2: 33 | return nn.Conv2d(*args, **kwargs) 34 | elif dims == 3: 35 | return nn.Conv3d(*args, **kwargs) 36 | raise ValueError(f"unsupported dimensions: {dims}") 37 | 38 | 39 | def linear(*args, **kwargs): 40 | """ 41 | Create a linear module. 42 | """ 43 | return nn.Linear(*args, **kwargs) 44 | 45 | 46 | def avg_pool_nd(dims, *args, **kwargs): 47 | """ 48 | Create a 1D, 2D, or 3D average pooling module. 49 | """ 50 | if dims == 1: 51 | return nn.AvgPool1d(*args, **kwargs) 52 | elif dims == 2: 53 | return nn.AvgPool2d(*args, **kwargs) 54 | elif dims == 3: 55 | return nn.AvgPool3d(*args, **kwargs) 56 | raise ValueError(f"unsupported dimensions: {dims}") 57 | 58 | 59 | def update_ema(target_params, source_params, rate=0.99): 60 | """ 61 | Update target parameters to be closer to those of source parameters using 62 | an exponential moving average. 63 | 64 | :param target_params: the target parameter sequence. 65 | :param source_params: the source parameter sequence. 66 | :param rate: the EMA rate (closer to 1 means slower). 67 | """ 68 | for targ, src in zip(target_params, source_params): 69 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 70 | 71 | 72 | def zero_module(module): 73 | """ 74 | Zero out the parameters of a module and return it. 75 | """ 76 | for p in module.parameters(): 77 | p.detach().zero_() 78 | return module 79 | 80 | 81 | def scale_module(module, scale): 82 | """ 83 | Scale the parameters of a module and return it. 84 | """ 85 | for p in module.parameters(): 86 | p.detach().mul_(scale) 87 | return module 88 | 89 | 90 | def mean_flat(tensor): 91 | """ 92 | Take the mean over all non-batch dimensions. 93 | """ 94 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 95 | 96 | 97 | def normalization(channels, swish=0.0): 98 | """ 99 | Make a standard normalization layer, with an optional swish activation. 100 | 101 | :param channels: number of input channels. 102 | :return: an nn.Module for normalization. 
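A nonzero swish fuses an activation into the norm: 1.0 applies SiLU, while any other nonzero value applies x * sigmoid(swish * x) to the normalized output.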
103 | """ 104 | return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) 105 | 106 | 107 | # def timestep_embedding(timesteps, dim, max_period=10000): 108 | # """ 109 | # Create sinusoidal timestep embeddings. 110 | 111 | # :param timesteps: a 1-D Tensor of N indices, one per batch element. 112 | # These may be fractional. 113 | # :param dim: the dimension of the output. 114 | # :param max_period: controls the minimum frequency of the embeddings. 115 | # :return: an [N x dim] Tensor of positional embeddings. 116 | # """ 117 | # half = dim // 2 118 | # freqs = th.exp( 119 | # -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 120 | # ).to(device=timesteps.device) 121 | # args = timesteps[:, None].float() * freqs[None] 122 | # embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 123 | # if dim % 2: 124 | # embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 125 | # return embedding 126 | 127 | 128 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 129 | """ 130 | Create sinusoidal timestep embeddings. 131 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 132 | These may be fractional. 133 | :param dim: the dimension of the output. 134 | :param max_period: controls the minimum frequency of the embeddings. 135 | :return: an [N x dim] Tensor of positional embeddings. 136 | """ 137 | if not repeat_only: 138 | half = dim // 2 139 | freqs = th.exp( 140 | -math.log(max_period) 141 | * th.arange(start=0, end=half, dtype=th.float32) 142 | / half 143 | ).to(device=timesteps.device) 144 | args = timesteps[:, None].float() * freqs[None] 145 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 146 | if dim % 2: 147 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 148 | else: 149 | embedding = repeat(timesteps, "b -> b d", d=dim) 150 | return embedding 151 | 152 | 153 | def checkpoint(func, inputs, params, flag): 154 | """ 155 | Evaluate a function without caching intermediate activations, allowing for 156 | reduced memory at the expense of extra compute in the backward pass. 157 | 158 | :param func: the function to evaluate. 159 | :param inputs: the argument sequence to pass to `func`. 160 | :param params: a sequence of parameters `func` depends on but does not 161 | explicitly take as arguments. 162 | :param flag: if False, disable gradient checkpointing. 163 | """ 164 | # flag = False 165 | if flag: 166 | args = tuple(inputs) + tuple(params) 167 | return CheckpointFunction.apply(func, len(inputs), *args) 168 | else: 169 | return func(*inputs) 170 | 171 | 172 | class CheckpointFunction(th.autograd.Function): 173 | @staticmethod 174 | def forward(ctx, run_function, length, *args): 175 | ctx.run_function = run_function 176 | ctx.input_tensors = list(args[:length]) 177 | ctx.input_params = list(args[length:]) 178 | with th.no_grad(): 179 | output_tensors = ctx.run_function(*ctx.input_tensors) 180 | return output_tensors 181 | 182 | @staticmethod 183 | def backward(ctx, *output_grads): 184 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 185 | with th.enable_grad(): 186 | # Fixes a bug where the first op in run_function modifies the 187 | # Tensor storage in place, which is not allowed for detach()'d 188 | # Tensors. 
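# Re-running the forward pass with gradients enabled lets autograd compute gradients for both the inputs and the parameters below.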
189 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 190 | output_tensors = ctx.run_function(*shallow_copies) 191 | input_grads = th.autograd.grad( 192 | output_tensors, 193 | ctx.input_tensors + ctx.input_params, 194 | output_grads, 195 | allow_unused=True, 196 | ) 197 | del ctx.input_tensors 198 | del ctx.input_params 199 | del output_tensors 200 | return (None, None) + input_grads 201 | -------------------------------------------------------------------------------- /drawspeech/modules/fastspeech2/modules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import copy 4 | import math 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | from .tools import get_mask_from_lengths, pad 13 | from drawspeech.utilities.tools import modify_curve_length 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | class VarianceAdaptor(nn.Module): 19 | """Variance Adaptor""" 20 | 21 | def __init__(self, 22 | phoneme_embedding_dim=192, 23 | pitch_embedding_dim=128, 24 | energy_embedding_dim=128, 25 | pitch_quantization="linear", 26 | energy_quantization="linear", 27 | n_bins=256, 28 | stats_json="path/to/stats.json", 29 | pitch_feature_level="frame_level", 30 | energy_feature_level="frame_level" 31 | ): 32 | super(VarianceAdaptor, self).__init__() 33 | self.duration_predictor = VariancePredictor(input_size=phoneme_embedding_dim) 34 | self.length_regulator = LengthRegulator() 35 | self.pitch_predictor = VariancePredictor(input_size=phoneme_embedding_dim) 36 | self.energy_predictor = VariancePredictor(input_size=phoneme_embedding_dim) 37 | 38 | self.pitch_feature_level = pitch_feature_level 39 | self.energy_feature_level = energy_feature_level 40 | assert self.pitch_feature_level in ["phoneme_level", "frame_level"] 41 | assert self.energy_feature_level in ["phoneme_level", "frame_level"] 42 | 43 | assert pitch_quantization in ["linear", "log"] 44 | assert energy_quantization in ["linear", "log"] 45 | 46 | with open(stats_json, "r") as f: 47 | stats = json.load(f) 48 | pitch_min, pitch_max = stats["pitch"][:2] 49 | energy_min, energy_max = stats["energy"][:2] 50 | 51 | self.pitch_min = pitch_min 52 | self.energy_min = energy_min 53 | 54 | if pitch_quantization == "log": 55 | self.pitch_bins = nn.Parameter( 56 | torch.exp( 57 | torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1) 58 | ), 59 | requires_grad=False, 60 | ) 61 | else: 62 | self.pitch_bins = nn.Parameter( 63 | torch.linspace(pitch_min, pitch_max, n_bins - 1), 64 | requires_grad=False, 65 | ) 66 | if energy_quantization == "log": 67 | self.energy_bins = nn.Parameter( 68 | torch.exp( 69 | torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1) 70 | ), 71 | requires_grad=False, 72 | ) 73 | else: 74 | self.energy_bins = nn.Parameter( 75 | torch.linspace(energy_min, energy_max, n_bins - 1), 76 | requires_grad=False, 77 | ) 78 | 79 | self.pitch_embedding = nn.Embedding( 80 | n_bins, pitch_embedding_dim 81 | ) 82 | self.energy_embedding = nn.Embedding( 83 | n_bins, energy_embedding_dim 84 | ) 85 | 86 | def get_pitch_embedding(self, x, target, mask, control): 87 | prediction = self.pitch_predictor(x, mask) 88 | if target is not None: 89 | embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins)) 90 | else: 91 | prediction = prediction * control 92 | embedding = self.pitch_embedding( 93 | 
torch.bucketize(prediction, self.pitch_bins) 94 | ) 95 | return prediction, embedding 96 | 97 | def get_energy_embedding(self, x, target, mask, control): 98 | prediction = self.energy_predictor(x, mask) 99 | if target is not None: 100 | embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins)) 101 | else: 102 | prediction = prediction * control 103 | embedding = self.energy_embedding( 104 | torch.bucketize(prediction, self.energy_bins) 105 | ) 106 | return prediction, embedding 107 | 108 | def get_expanded_text_embedding(self, x, src_mask, max_len, duration_target=None): 109 | 110 | if self.pitch_feature_level == "phoneme_level" and self.energy_feature_level == "phoneme_level": 111 | return x, src_mask 112 | 113 | if duration_target is not None: 114 | x, mel_len = self.length_regulator(x, duration_target, max_len) 115 | else: 116 | log_duration_prediction = self.duration_predictor(x, src_mask) 117 | duration_rounded = torch.clamp((torch.round(torch.exp(log_duration_prediction) - 1)), min=0) 118 | x, mel_len = self.length_regulator(x, duration_rounded, max_len) 119 | 120 | mel_mask = get_mask_from_lengths(mel_len, max_len) 121 | 122 | return x, mel_mask 123 | 124 | def predict_duration(self, x, src_mask=None): 125 | log_duration_prediction = self.duration_predictor(x, src_mask) 126 | duration_rounded = torch.clamp((torch.round(torch.exp(log_duration_prediction) - 1)), min=0) 127 | return duration_rounded 128 | 129 | def forward( 130 | self, 131 | x, 132 | src_mask, 133 | mel_mask=None, 134 | max_len=None, 135 | pitch_target=None, 136 | energy_target=None, 137 | duration_target=None, 138 | p_control=1.0, 139 | e_control=1.0, 140 | d_control=1.0, 141 | ): 142 | 143 | log_duration_prediction = self.duration_predictor(x, src_mask) 144 | if self.pitch_feature_level == "phoneme_level": 145 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 146 | x, pitch_target, src_mask, p_control 147 | ) 148 | x = x + pitch_embedding 149 | if self.energy_feature_level == "phoneme_level": 150 | energy_prediction, energy_embedding = self.get_energy_embedding( 151 | x, energy_target, src_mask, e_control 152 | ) 153 | x = x + energy_embedding 154 | 155 | if duration_target is not None: 156 | x, mel_len = self.length_regulator(x, duration_target, max_len) 157 | duration_rounded = duration_target 158 | else: 159 | duration_rounded = torch.clamp( 160 | (torch.round(torch.exp(log_duration_prediction) - 1) * d_control), 161 | min=0, 162 | ) 163 | x, mel_len = self.length_regulator(x, duration_rounded, max_len) 164 | mel_mask = get_mask_from_lengths(mel_len, max_len) 165 | if x.size(1) == 0: 166 | # print("Warning: predicted 0 duration") 167 | x = torch.zeros((x.size(0), 3, x.size(2))).to(device) 168 | mel_mask = torch.ones((x.size(0), 3)).to(device).bool() 169 | 170 | if self.pitch_feature_level == "frame_level": 171 | pitch_prediction, pitch_embedding = self.get_pitch_embedding( 172 | x, pitch_target, mel_mask, p_control 173 | ) 174 | x = x + pitch_embedding 175 | 176 | if self.energy_feature_level == "frame_level": 177 | energy_prediction, energy_embedding = self.get_energy_embedding( 178 | x, energy_target, mel_mask, e_control 179 | ) 180 | x = x + energy_embedding 181 | 182 | return ( 183 | x, 184 | pitch_prediction, 185 | energy_prediction, 186 | log_duration_prediction, 187 | duration_rounded, 188 | mel_len, 189 | mel_mask, 190 | ) 191 | 192 | 193 | class LengthRegulator(nn.Module): 194 | """Length Regulator""" 195 | 196 | def __init__(self): 197 | super(LengthRegulator, 
self).__init__() 198 | 199 | def LR(self, x, duration, max_len): 200 | output = list() 201 | mel_len = list() 202 | for batch, expand_target in zip(x, duration): 203 | expanded = self.expand(batch, expand_target) 204 | output.append(expanded) 205 | mel_len.append(expanded.shape[0]) 206 | 207 | if max_len is not None: 208 | output = pad(output, max_len) 209 | else: 210 | output = pad(output) 211 | 212 | return output, torch.LongTensor(mel_len).to(device) 213 | 214 | def expand(self, batch, predicted): 215 | out = list() 216 | 217 | for i, vec in enumerate(batch): 218 | expand_size = predicted[i].item() 219 | out.append(vec.expand(max(int(expand_size), 0), -1)) 220 | out = torch.cat(out, 0) 221 | 222 | return out 223 | 224 | def forward(self, x, duration, max_len): 225 | output, mel_len = self.LR(x, duration, max_len) 226 | return output, mel_len 227 | 228 | 229 | class VariancePredictor(nn.Module): 230 | """Duration, Pitch and Energy Predictor""" 231 | 232 | def __init__(self, input_size): 233 | super(VariancePredictor, self).__init__() 234 | 235 | self.input_size = input_size 236 | self.filter_size = 256 237 | self.kernel = 3 238 | self.conv_output_size = 256 239 | self.dropout = 0.5 240 | 241 | self.conv_layer = nn.Sequential( 242 | OrderedDict( 243 | [ 244 | ( 245 | "conv1d_1", 246 | Conv( 247 | self.input_size, 248 | self.filter_size, 249 | kernel_size=self.kernel, 250 | padding=(self.kernel - 1) // 2, 251 | ), 252 | ), 253 | ("relu_1", nn.ReLU()), 254 | ("layer_norm_1", nn.LayerNorm(self.filter_size)), 255 | ("dropout_1", nn.Dropout(self.dropout)), 256 | ( 257 | "conv1d_2", 258 | Conv( 259 | self.filter_size, 260 | self.filter_size, 261 | kernel_size=self.kernel, 262 | padding=1, 263 | ), 264 | ), 265 | ("relu_2", nn.ReLU()), 266 | ("layer_norm_2", nn.LayerNorm(self.filter_size)), 267 | ("dropout_2", nn.Dropout(self.dropout)), 268 | ] 269 | ) 270 | ) 271 | 272 | self.linear_layer = nn.Linear(self.conv_output_size, 1) 273 | 274 | def forward(self, encoder_output, mask): 275 | out = self.conv_layer(encoder_output) 276 | out = self.linear_layer(out) 277 | out = out.squeeze(-1) 278 | 279 | if mask is not None: 280 | out = out.masked_fill(mask, 0.0) 281 | 282 | return out 283 | 284 | 285 | class Conv(nn.Module): 286 | """ 287 | Convolution Module 288 | """ 289 | 290 | def __init__( 291 | self, 292 | in_channels, 293 | out_channels, 294 | kernel_size=1, 295 | stride=1, 296 | padding=0, 297 | dilation=1, 298 | bias=True, 299 | w_init="linear", 300 | ): 301 | """ 302 | :param in_channels: dimension of input 303 | :param out_channels: dimension of output 304 | :param kernel_size: size of kernel 305 | :param stride: size of stride 306 | :param padding: size of padding 307 | :param dilation: dilation rate 308 | :param bias: boolean. if True, bias is included. 309 | :param w_init: str. weight inits with xavier initialization. 
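(Note: this wrapper currently relies on the default nn.Conv1d initialization; w_init is kept for interface compatibility.)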
310 | """ 311 | super(Conv, self).__init__() 312 | 313 | self.conv = nn.Conv1d( 314 | in_channels, 315 | out_channels, 316 | kernel_size=kernel_size, 317 | stride=stride, 318 | padding=padding, 319 | dilation=dilation, 320 | bias=bias, 321 | ) 322 | 323 | def forward(self, x): 324 | x = x.contiguous().transpose(1, 2) 325 | x = self.conv(x) 326 | x = x.contiguous().transpose(1, 2) 327 | 328 | return x 329 | -------------------------------------------------------------------------------- /drawspeech/modules/fastspeech2/tools.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | def get_mask_from_lengths(lengths, max_len=None): 8 | batch_size = lengths.shape[0] 9 | if max_len is None: 10 | max_len = torch.max(lengths).item() 11 | 12 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device) 13 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 14 | 15 | return mask 16 | 17 | def pad(input_ele, mel_max_length=None): 18 | if mel_max_length: 19 | max_len = mel_max_length 20 | else: 21 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 22 | 23 | out_list = list() 24 | for i, batch in enumerate(input_ele): 25 | if len(batch.shape) == 1: 26 | one_batch_padded = F.pad( 27 | batch, (0, max_len - batch.size(0)), "constant", 0.0 28 | ) 29 | elif len(batch.shape) == 2: 30 | one_batch_padded = F.pad( 31 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 32 | ) 33 | out_list.append(one_batch_padded) 34 | out_padded = torch.stack(out_list) 35 | return out_padded 36 | 37 | -------------------------------------------------------------------------------- /drawspeech/modules/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /drawspeech/modules/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_hifires import Generator_HiFiRes 2 | from .models import Generator as Generator 3 | 4 | 5 | class AttrDict(dict): 6 | def __init__(self, *args, **kwargs): 7 | super(AttrDict, self).__init__(*args, **kwargs) 8 | self.__dict__ = self 9 | -------------------------------------------------------------------------------- /drawspeech/modules/hifigan/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class Generator(torch.nn.Module): 113 | def __init__(self, h): 114 | super(Generator, self).__init__() 115 | self.h = h 116 | self.num_kernels = len(h.resblock_kernel_sizes) 117 | self.num_upsamples = len(h.upsample_rates) 118 | self.conv_pre = weight_norm( 119 | Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) 120 | ) 121 | resblock = ResBlock 122 | 123 | self.ups = nn.ModuleList() 124 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 125 | self.ups.append( 126 | weight_norm( 127 | ConvTranspose1d( 128 | h.upsample_initial_channel // (2**i), 129 | 
h.upsample_initial_channel // (2 ** (i + 1)), 130 | k, 131 | u, 132 | padding=(k - u) // 2, 133 | ) 134 | ) 135 | ) 136 | 137 | self.resblocks = nn.ModuleList() 138 | for i in range(len(self.ups)): 139 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 140 | for j, (k, d) in enumerate( 141 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 142 | ): 143 | self.resblocks.append(resblock(h, ch, k, d)) 144 | 145 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 146 | self.ups.apply(init_weights) 147 | self.conv_post.apply(init_weights) 148 | 149 | def forward(self, x): 150 | x = self.conv_pre(x) 151 | for i in range(self.num_upsamples): 152 | x = F.leaky_relu(x, LRELU_SLOPE) 153 | x = self.ups[i](x) 154 | xs = None 155 | for j in range(self.num_kernels): 156 | if xs is None: 157 | xs = self.resblocks[i * self.num_kernels + j](x) 158 | else: 159 | xs += self.resblocks[i * self.num_kernels + j](x) 160 | x = xs / self.num_kernels 161 | x = F.leaky_relu(x) 162 | x = self.conv_post(x) 163 | x = torch.tanh(x) 164 | 165 | return x 166 | 167 | def remove_weight_norm(self): 168 | print("Removing weight norm...") 169 | for l in self.ups: 170 | remove_weight_norm(l) 171 | for l in self.resblocks: 172 | l.remove_weight_norm() 173 | remove_weight_norm(self.conv_pre) 174 | remove_weight_norm(self.conv_post) 175 | -------------------------------------------------------------------------------- /drawspeech/modules/hifigan/models_hifires.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn import Conv1d, ConvTranspose1d 5 | from torch.nn.utils import weight_norm, remove_weight_norm 6 | 7 | LRELU_SLOPE = 0.1 8 | 9 | 10 | def init_weights(m, mean=0.0, std=0.01): 11 | classname = m.__class__.__name__ 12 | if classname.find("Conv") != -1: 13 | m.weight.data.normal_(mean, std) 14 | 15 | 16 | def get_padding(kernel_size, dilation=1): 17 | return int((kernel_size * dilation - dilation) / 2) 18 | 19 | 20 | class ResBlock1(torch.nn.Module): 21 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 22 | super(ResBlock1, self).__init__() 23 | self.h = h 24 | self.convs1 = nn.ModuleList( 25 | [ 26 | weight_norm( 27 | Conv1d( 28 | channels, 29 | channels, 30 | kernel_size, 31 | 1, 32 | dilation=dilation[0], 33 | padding=get_padding(kernel_size, dilation[0]), 34 | ) 35 | ), 36 | weight_norm( 37 | Conv1d( 38 | channels, 39 | channels, 40 | kernel_size, 41 | 1, 42 | dilation=dilation[1], 43 | padding=get_padding(kernel_size, dilation[1]), 44 | ) 45 | ), 46 | weight_norm( 47 | Conv1d( 48 | channels, 49 | channels, 50 | kernel_size, 51 | 1, 52 | dilation=dilation[2], 53 | padding=get_padding(kernel_size, dilation[2]), 54 | ) 55 | ), 56 | ] 57 | ) 58 | self.convs1.apply(init_weights) 59 | 60 | self.convs2 = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | Conv1d( 64 | channels, 65 | channels, 66 | kernel_size, 67 | 1, 68 | dilation=1, 69 | padding=get_padding(kernel_size, 1), 70 | ) 71 | ), 72 | weight_norm( 73 | Conv1d( 74 | channels, 75 | channels, 76 | kernel_size, 77 | 1, 78 | dilation=1, 79 | padding=get_padding(kernel_size, 1), 80 | ) 81 | ), 82 | weight_norm( 83 | Conv1d( 84 | channels, 85 | channels, 86 | kernel_size, 87 | 1, 88 | dilation=1, 89 | padding=get_padding(kernel_size, 1), 90 | ) 91 | ), 92 | ] 93 | ) 94 | self.convs2.apply(init_weights) 95 | 96 | def forward(self, x): 97 | for c1, c2 in zip(self.convs1, self.convs2): 98 | xt = F.leaky_relu(x, 
LRELU_SLOPE) 99 | xt = c1(xt) 100 | xt = F.leaky_relu(xt, LRELU_SLOPE) 101 | xt = c2(xt) 102 | x = xt + x 103 | return x 104 | 105 | def remove_weight_norm(self): 106 | for l in self.convs1: 107 | remove_weight_norm(l) 108 | for l in self.convs2: 109 | remove_weight_norm(l) 110 | 111 | 112 | class ResBlock2(torch.nn.Module): 113 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): 114 | super(ResBlock2, self).__init__() 115 | self.h = h 116 | self.convs = nn.ModuleList( 117 | [ 118 | weight_norm( 119 | Conv1d( 120 | channels, 121 | channels, 122 | kernel_size, 123 | 1, 124 | dilation=dilation[0], 125 | padding=get_padding(kernel_size, dilation[0]), 126 | ) 127 | ), 128 | weight_norm( 129 | Conv1d( 130 | channels, 131 | channels, 132 | kernel_size, 133 | 1, 134 | dilation=dilation[1], 135 | padding=get_padding(kernel_size, dilation[1]), 136 | ) 137 | ), 138 | ] 139 | ) 140 | self.convs.apply(init_weights) 141 | 142 | def forward(self, x): 143 | for c in self.convs: 144 | xt = F.leaky_relu(x, LRELU_SLOPE) 145 | xt = c(xt) 146 | x = xt + x 147 | return x 148 | 149 | def remove_weight_norm(self): 150 | for l in self.convs: 151 | remove_weight_norm(l) 152 | 153 | 154 | class Generator_HiFiRes(torch.nn.Module): 155 | def __init__(self, h): 156 | super(Generator_HiFiRes, self).__init__() 157 | self.h = h 158 | self.num_kernels = len(h.resblock_kernel_sizes) 159 | self.num_upsamples = len(h.upsample_rates) 160 | self.conv_pre = weight_norm( 161 | Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3) 162 | ) 163 | resblock = ResBlock1 if h.resblock == "1" else ResBlock2 164 | 165 | self.ups = nn.ModuleList() 166 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 167 | self.ups.append( 168 | weight_norm( 169 | ConvTranspose1d( 170 | h.upsample_initial_channel // (2**i), 171 | h.upsample_initial_channel // (2 ** (i + 1)), 172 | u * 2, 173 | u, 174 | padding=u // 2 + u % 2, 175 | output_padding=u % 2, 176 | ) 177 | ) 178 | ) 179 | 180 | self.resblocks = nn.ModuleList() 181 | for i in range(len(self.ups)): 182 | ch = h.upsample_initial_channel // (2 ** (i + 1)) 183 | for j, (k, d) in enumerate( 184 | zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) 185 | ): 186 | self.resblocks.append(resblock(h, ch, k, d)) 187 | 188 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 189 | self.ups.apply(init_weights) 190 | self.conv_post.apply(init_weights) 191 | 192 | def forward(self, x): 193 | x = self.conv_pre(x) 194 | for i in range(self.num_upsamples): 195 | x = F.leaky_relu(x, LRELU_SLOPE) 196 | x = self.ups[i](x) 197 | xs = None 198 | for j in range(self.num_kernels): 199 | if xs is None: 200 | xs = self.resblocks[i * self.num_kernels + j](x) 201 | else: 202 | xs += self.resblocks[i * self.num_kernels + j](x) 203 | x = xs / self.num_kernels 204 | x = F.leaky_relu(x) 205 | x = self.conv_post(x) 206 | x = torch.tanh(x) 207 | 208 | return x 209 | 210 | def remove_weight_norm(self): 211 | print("Removing weight norm...") 212 | for l in self.ups: 213 | remove_weight_norm(l) 214 | for l in self.resblocks: 215 | l.remove_weight_norm() 216 | remove_weight_norm(self.conv_pre) 217 | remove_weight_norm(self.conv_post) 218 | -------------------------------------------------------------------------------- /drawspeech/modules/latent_diffusion/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HappyColor/DrawSpeech_PyTorch/9db245d68e87826a175494f037eb4b9482d3a836/drawspeech/modules/latent_diffusion/__init__.py -------------------------------------------------------------------------------- /drawspeech/modules/latent_diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler 2 | -------------------------------------------------------------------------------- /drawspeech/modules/latent_diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample( 23 | self, 24 | S, 25 | batch_size, 26 | shape, 27 | conditioning=None, 28 | callback=None, 29 | normals_sequence=None, 30 | img_callback=None, 31 | quantize_x0=False, 32 | eta=0.0, 33 | mask=None, 34 | x0=None, 35 | temperature=1.0, 36 | noise_dropout=0.0, 37 | score_corrector=None, 38 | corrector_kwargs=None, 39 | verbose=True, 40 | x_T=None, 41 | log_every_t=100, 42 | unconditional_guidance_scale=1.0, 43 | unconditional_conditioning=None, 44 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
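# the remaining keyword arguments match the PLMS sampler's signature for interface compatibility; most of them are not used by DPM-Solver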
45 | **kwargs, 46 | ): 47 | if conditioning is not None: 48 | if isinstance(conditioning, dict): 49 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 50 | if cbs != batch_size: 51 | print( 52 | f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" 53 | ) 54 | else: 55 | if conditioning.shape[0] != batch_size: 56 | print( 57 | f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" 58 | ) 59 | 60 | # sampling 61 | C, H, W = shape 62 | size = (batch_size, C, H, W) 63 | 64 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 65 | 66 | device = self.model.betas.device 67 | if x_T is None: 68 | img = torch.randn(size, device=device) 69 | else: 70 | img = x_T 71 | 72 | ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) 73 | 74 | model_fn = model_wrapper( 75 | lambda x, t, c: self.model.apply_model(x, t, c), 76 | ns, 77 | model_type="noise", 78 | guidance_type="classifier-free", 79 | condition=conditioning, 80 | unconditional_condition=unconditional_conditioning, 81 | guidance_scale=unconditional_guidance_scale, 82 | ) 83 | 84 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 85 | x = dpm_solver.sample( 86 | img, 87 | steps=S, 88 | skip_type="time_uniform", 89 | method="multistep", 90 | order=2, 91 | lower_order_final=True, 92 | ) 93 | 94 | return x.to(device), None 95 | -------------------------------------------------------------------------------- /drawspeech/modules/latent_diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from drawspeech.utilities.diffusion_util import ( 9 | make_ddim_sampling_parameters, 10 | make_ddim_timesteps, 11 | noise_like, 12 | ) 13 | 14 | 15 | class PLMSSampler(object): 16 | def __init__(self, model, schedule="linear", **kwargs): 17 | super().__init__() 18 | self.model = model 19 | self.ddpm_num_timesteps = model.num_timesteps 20 | self.schedule = schedule 21 | 22 | def register_buffer(self, name, attr): 23 | if type(attr) == torch.Tensor: 24 | if attr.device != torch.device("cuda"): 25 | attr = attr.to(torch.device("cuda")) 26 | setattr(self, name, attr) 27 | 28 | def make_schedule( 29 | self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True 30 | ): 31 | if ddim_eta != 0: 32 | ddim_eta = 0 33 | # raise ValueError('ddim_eta must be 0 for PLMS') 34 | 35 | self.ddim_timesteps = make_ddim_timesteps( 36 | ddim_discr_method=ddim_discretize, 37 | num_ddim_timesteps=ddim_num_steps, 38 | num_ddpm_timesteps=self.ddpm_num_timesteps, 39 | verbose=verbose, 40 | ) 41 | alphas_cumprod = self.model.alphas_cumprod 42 | assert ( 43 | alphas_cumprod.shape[0] == self.ddpm_num_timesteps 44 | ), "alphas have to be defined for each timestep" 45 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 46 | 47 | self.register_buffer("betas", to_torch(self.model.betas)) 48 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 49 | self.register_buffer( 50 | "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) 51 | ) 52 | 53 | # calculations for diffusion q(x_t | x_{t-1}) and others 54 | self.register_buffer( 55 | "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) 56 | ) 57 | self.register_buffer( 58 | "sqrt_one_minus_alphas_cumprod", 59 | to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), 60 | ) 61 | 
self.register_buffer( 62 | "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) 63 | ) 64 | self.register_buffer( 65 | "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) 66 | ) 67 | self.register_buffer( 68 | "sqrt_recipm1_alphas_cumprod", 69 | to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), 70 | ) 71 | 72 | # ddim sampling parameters 73 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( 74 | alphacums=alphas_cumprod.cpu(), 75 | ddim_timesteps=self.ddim_timesteps, 76 | eta=ddim_eta, 77 | verbose=verbose, 78 | ) 79 | self.register_buffer("ddim_sigmas", ddim_sigmas) 80 | self.register_buffer("ddim_alphas", ddim_alphas) 81 | self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) 82 | self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) 83 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 84 | (1 - self.alphas_cumprod_prev) 85 | / (1 - self.alphas_cumprod) 86 | * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) 87 | ) 88 | self.register_buffer( 89 | "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps 90 | ) 91 | 92 | @torch.no_grad() 93 | def sample( 94 | self, 95 | S, 96 | batch_size, 97 | shape, 98 | conditioning=None, 99 | callback=None, 100 | normals_sequence=None, 101 | img_callback=None, 102 | quantize_x0=False, 103 | eta=0.0, 104 | mask=None, 105 | x0=None, 106 | temperature=1.0, 107 | noise_dropout=0.0, 108 | score_corrector=None, 109 | corrector_kwargs=None, 110 | verbose=True, 111 | x_T=None, 112 | log_every_t=100, 113 | unconditional_guidance_scale=1.0, 114 | unconditional_conditioning=None, 115 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 116 | **kwargs, 117 | ): 118 | if conditioning is not None: 119 | if isinstance(conditioning, dict): 120 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 121 | if cbs != batch_size: 122 | print( 123 | f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" 124 | ) 125 | else: 126 | if conditioning.shape[0] != batch_size: 127 | print( 128 | f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" 129 | ) 130 | 131 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 132 | # sampling 133 | C, H, W = shape 134 | size = (batch_size, C, H, W) 135 | print(f"Data shape for PLMS sampling is {size}") 136 | 137 | samples, intermediates = self.plms_sampling( 138 | conditioning, 139 | size, 140 | callback=callback, 141 | img_callback=img_callback, 142 | quantize_denoised=quantize_x0, 143 | mask=mask, 144 | x0=x0, 145 | ddim_use_original_steps=False, 146 | noise_dropout=noise_dropout, 147 | temperature=temperature, 148 | score_corrector=score_corrector, 149 | corrector_kwargs=corrector_kwargs, 150 | x_T=x_T, 151 | log_every_t=log_every_t, 152 | unconditional_guidance_scale=unconditional_guidance_scale, 153 | unconditional_conditioning=unconditional_conditioning, 154 | ) 155 | return samples, intermediates 156 | 157 | @torch.no_grad() 158 | def plms_sampling( 159 | self, 160 | cond, 161 | shape, 162 | x_T=None, 163 | ddim_use_original_steps=False, 164 | callback=None, 165 | timesteps=None, 166 | quantize_denoised=False, 167 | mask=None, 168 | x0=None, 169 | img_callback=None, 170 | log_every_t=100, 171 | temperature=1.0, 172 | noise_dropout=0.0, 173 | score_corrector=None, 174 | corrector_kwargs=None, 175 | unconditional_guidance_scale=1.0, 176 | unconditional_conditioning=None, 177 | ): 178 | device = 
self.model.betas.device 179 | b = shape[0] 180 | if x_T is None: 181 | img = torch.randn(shape, device=device) 182 | else: 183 | img = x_T 184 | 185 | if timesteps is None: 186 | timesteps = ( 187 | self.ddpm_num_timesteps 188 | if ddim_use_original_steps 189 | else self.ddim_timesteps 190 | ) 191 | elif timesteps is not None and not ddim_use_original_steps: 192 | subset_end = ( 193 | int( 194 | min(timesteps / self.ddim_timesteps.shape[0], 1) 195 | * self.ddim_timesteps.shape[0] 196 | ) 197 | - 1 198 | ) 199 | timesteps = self.ddim_timesteps[:subset_end] 200 | 201 | intermediates = {"x_inter": [img], "pred_x0": [img]} 202 | time_range = ( 203 | list(reversed(range(0, timesteps))) 204 | if ddim_use_original_steps 205 | else np.flip(timesteps) 206 | ) 207 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 208 | print(f"Running PLMS Sampling with {total_steps} timesteps") 209 | 210 | iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps) 211 | old_eps = [] 212 | 213 | for i, step in enumerate(iterator): 214 | index = total_steps - i - 1 215 | ts = torch.full((b,), step, device=device, dtype=torch.long) 216 | ts_next = torch.full( 217 | (b,), 218 | time_range[min(i + 1, len(time_range) - 1)], 219 | device=device, 220 | dtype=torch.long, 221 | ) 222 | 223 | if mask is not None: 224 | assert x0 is not None 225 | img_orig = self.model.q_sample( 226 | x0, ts 227 | ) # TODO: deterministic forward pass? 228 | img = img_orig * mask + (1.0 - mask) * img 229 | 230 | outs = self.p_sample_plms( 231 | img, 232 | cond, 233 | ts, 234 | index=index, 235 | use_original_steps=ddim_use_original_steps, 236 | quantize_denoised=quantize_denoised, 237 | temperature=temperature, 238 | noise_dropout=noise_dropout, 239 | score_corrector=score_corrector, 240 | corrector_kwargs=corrector_kwargs, 241 | unconditional_guidance_scale=unconditional_guidance_scale, 242 | unconditional_conditioning=unconditional_conditioning, 243 | old_eps=old_eps, 244 | t_next=ts_next, 245 | ) 246 | img, pred_x0, e_t = outs 247 | old_eps.append(e_t) 248 | if len(old_eps) >= 4: 249 | old_eps.pop(0) 250 | if callback: 251 | callback(i) 252 | if img_callback: 253 | img_callback(pred_x0, i) 254 | 255 | if index % log_every_t == 0 or index == total_steps - 1: 256 | intermediates["x_inter"].append(img) 257 | intermediates["pred_x0"].append(pred_x0) 258 | 259 | return img, intermediates 260 | 261 | @torch.no_grad() 262 | def p_sample_plms( 263 | self, 264 | x, 265 | c, 266 | t, 267 | index, 268 | repeat_noise=False, 269 | use_original_steps=False, 270 | quantize_denoised=False, 271 | temperature=1.0, 272 | noise_dropout=0.0, 273 | score_corrector=None, 274 | corrector_kwargs=None, 275 | unconditional_guidance_scale=1.0, 276 | unconditional_conditioning=None, 277 | old_eps=None, 278 | t_next=None, 279 | ): 280 | b, *_, device = *x.shape, x.device 281 | 282 | def get_model_output(x, t): 283 | if ( 284 | unconditional_conditioning is None 285 | or unconditional_guidance_scale == 1.0 286 | ): 287 | e_t = self.model.apply_model(x, t, c) 288 | else: 289 | x_in = torch.cat([x] * 2) 290 | t_in = torch.cat([t] * 2) 291 | c_in = torch.cat([unconditional_conditioning, c]) 292 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 293 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 294 | 295 | if score_corrector is not None: 296 | assert self.model.parameterization == "eps" 297 | e_t = score_corrector.modify_score( 298 | self.model, e_t, x, t, c, **corrector_kwargs 299 | ) 300 | 301 | 
return e_t 302 | 303 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 304 | alphas_prev = ( 305 | self.model.alphas_cumprod_prev 306 | if use_original_steps 307 | else self.ddim_alphas_prev 308 | ) 309 | sqrt_one_minus_alphas = ( 310 | self.model.sqrt_one_minus_alphas_cumprod 311 | if use_original_steps 312 | else self.ddim_sqrt_one_minus_alphas 313 | ) 314 | sigmas = ( 315 | self.model.ddim_sigmas_for_original_num_steps 316 | if use_original_steps 317 | else self.ddim_sigmas 318 | ) 319 | 320 | def get_x_prev_and_pred_x0(e_t, index): 321 | # select parameters corresponding to the currently considered timestep 322 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 323 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 324 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 325 | sqrt_one_minus_at = torch.full( 326 | (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device 327 | ) 328 | 329 | # current prediction for x_0 330 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 331 | if quantize_denoised: 332 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 333 | # direction pointing to x_t 334 | dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t 335 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 336 | if noise_dropout > 0.0: 337 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 338 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 339 | return x_prev, pred_x0 340 | 341 | e_t = get_model_output(x, t) 342 | if len(old_eps) == 0: 343 | # Pseudo Improved Euler (2nd order) 344 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 345 | e_t_next = get_model_output(x_prev, t_next) 346 | e_t_prime = (e_t + e_t_next) / 2 347 | elif len(old_eps) == 1: 348 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 349 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 350 | elif len(old_eps) == 2: 351 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 352 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 353 | elif len(old_eps) >= 3: 354 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 355 | e_t_prime = ( 356 | 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] 357 | ) / 24 358 | 359 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 360 | 361 | return x_prev, pred_x0, e_t 362 | -------------------------------------------------------------------------------- /drawspeech/modules/latent_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HappyColor/DrawSpeech_PyTorch/9db245d68e87826a175494f037eb4b9482d3a836/drawspeech/modules/latent_encoder/__init__.py -------------------------------------------------------------------------------- /drawspeech/modules/text_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HappyColor/DrawSpeech_PyTorch/9db245d68e87826a175494f037eb4b9482d3a836/drawspeech/modules/text_encoder/__init__.py -------------------------------------------------------------------------------- /drawspeech/modules/text_encoder/attentions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | import drawspeech.modules.text_encoder.commons as commons 9 | 10 | LRELU_SLOPE = 0.1 11 | 
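# Attention building blocks for the text encoder, adapted from VITS: a channel-first LayerNorm, relative-position multi-head attention, and a convolutional feed-forward network.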
12 | 13 | class LayerNorm(nn.Module): 14 | def __init__(self, channels, eps=1e-5): 15 | super().__init__() 16 | self.channels = channels 17 | self.eps = eps 18 | 19 | self.gamma = nn.Parameter(torch.ones(channels)) 20 | self.beta = nn.Parameter(torch.zeros(channels)) 21 | 22 | def forward(self, x): 23 | x = x.transpose(1, -1) 24 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 25 | return x.transpose(1, -1) 26 | 27 | 28 | class Encoder(nn.Module): 29 | def __init__( 30 | self, 31 | hidden_channels, 32 | filter_channels, 33 | n_heads, 34 | n_layers, 35 | kernel_size=1, 36 | p_dropout=0.0, 37 | window_size=4, 38 | **kwargs 39 | ): 40 | super().__init__() 41 | self.hidden_channels = hidden_channels 42 | self.filter_channels = filter_channels 43 | self.n_heads = n_heads 44 | self.n_layers = n_layers 45 | self.kernel_size = kernel_size 46 | self.p_dropout = p_dropout 47 | self.window_size = window_size 48 | 49 | self.drop = nn.Dropout(p_dropout) 50 | self.attn_layers = nn.ModuleList() 51 | self.norm_layers_1 = nn.ModuleList() 52 | self.ffn_layers = nn.ModuleList() 53 | self.norm_layers_2 = nn.ModuleList() 54 | for i in range(self.n_layers): 55 | self.attn_layers.append( 56 | MultiHeadAttention( 57 | hidden_channels, 58 | hidden_channels, 59 | n_heads, 60 | p_dropout=p_dropout, 61 | window_size=window_size, 62 | ) 63 | ) 64 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 65 | self.ffn_layers.append( 66 | FFN( 67 | hidden_channels, 68 | hidden_channels, 69 | filter_channels, 70 | kernel_size, 71 | p_dropout=p_dropout, 72 | ) 73 | ) 74 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 75 | 76 | def forward(self, x, x_mask): 77 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 78 | x = x * x_mask 79 | for i in range(self.n_layers): 80 | y = self.attn_layers[i](x, x, attn_mask) 81 | y = self.drop(y) 82 | x = self.norm_layers_1[i](x + y) 83 | 84 | y = self.ffn_layers[i](x, x_mask) 85 | y = self.drop(y) 86 | x = self.norm_layers_2[i](x + y) 87 | x = x * x_mask 88 | return x 89 | 90 | 91 | class Decoder(nn.Module): 92 | def __init__( 93 | self, 94 | hidden_channels, 95 | filter_channels, 96 | n_heads, 97 | n_layers, 98 | kernel_size=1, 99 | p_dropout=0.0, 100 | proximal_bias=False, 101 | proximal_init=True, 102 | **kwargs 103 | ): 104 | super().__init__() 105 | self.hidden_channels = hidden_channels 106 | self.filter_channels = filter_channels 107 | self.n_heads = n_heads 108 | self.n_layers = n_layers 109 | self.kernel_size = kernel_size 110 | self.p_dropout = p_dropout 111 | self.proximal_bias = proximal_bias 112 | self.proximal_init = proximal_init 113 | 114 | self.drop = nn.Dropout(p_dropout) 115 | self.self_attn_layers = nn.ModuleList() 116 | self.norm_layers_0 = nn.ModuleList() 117 | self.encdec_attn_layers = nn.ModuleList() 118 | self.norm_layers_1 = nn.ModuleList() 119 | self.ffn_layers = nn.ModuleList() 120 | self.norm_layers_2 = nn.ModuleList() 121 | for i in range(self.n_layers): 122 | self.self_attn_layers.append( 123 | MultiHeadAttention( 124 | hidden_channels, 125 | hidden_channels, 126 | n_heads, 127 | p_dropout=p_dropout, 128 | proximal_bias=proximal_bias, 129 | proximal_init=proximal_init, 130 | ) 131 | ) 132 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 133 | self.encdec_attn_layers.append( 134 | MultiHeadAttention( 135 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout 136 | ) 137 | ) 138 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 139 | self.ffn_layers.append( 140 | FFN( 141 | hidden_channels, 142 | 
hidden_channels, 143 | filter_channels, 144 | kernel_size, 145 | p_dropout=p_dropout, 146 | causal=True, 147 | ) 148 | ) 149 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 150 | 151 | def forward(self, x, x_mask, h, h_mask): 152 | """ 153 | x: decoder input 154 | h: encoder output 155 | """ 156 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( 157 | device=x.device, dtype=x.dtype 158 | ) 159 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 160 | x = x * x_mask 161 | for i in range(self.n_layers): 162 | y = self.self_attn_layers[i](x, x, self_attn_mask) 163 | y = self.drop(y) 164 | x = self.norm_layers_0[i](x + y) 165 | 166 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 167 | y = self.drop(y) 168 | x = self.norm_layers_1[i](x + y) 169 | 170 | y = self.ffn_layers[i](x, x_mask) 171 | y = self.drop(y) 172 | x = self.norm_layers_2[i](x + y) 173 | x = x * x_mask 174 | return x 175 | 176 | 177 | class MultiHeadAttention(nn.Module): 178 | def __init__( 179 | self, 180 | channels, 181 | out_channels, 182 | n_heads, 183 | p_dropout=0.0, 184 | window_size=None, 185 | heads_share=True, 186 | block_length=None, 187 | proximal_bias=False, 188 | proximal_init=False, 189 | ): 190 | super().__init__() 191 | assert channels % n_heads == 0 192 | 193 | self.channels = channels 194 | self.out_channels = out_channels 195 | self.n_heads = n_heads 196 | self.p_dropout = p_dropout 197 | self.window_size = window_size 198 | self.heads_share = heads_share 199 | self.block_length = block_length 200 | self.proximal_bias = proximal_bias 201 | self.proximal_init = proximal_init 202 | self.attn = None 203 | 204 | self.k_channels = channels // n_heads 205 | self.conv_q = nn.Conv1d(channels, channels, 1) 206 | self.conv_k = nn.Conv1d(channels, channels, 1) 207 | self.conv_v = nn.Conv1d(channels, channels, 1) 208 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 209 | self.drop = nn.Dropout(p_dropout) 210 | 211 | if window_size is not None: 212 | n_heads_rel = 1 if heads_share else n_heads 213 | rel_stddev = self.k_channels**-0.5 214 | self.emb_rel_k = nn.Parameter( 215 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 216 | * rel_stddev 217 | ) 218 | self.emb_rel_v = nn.Parameter( 219 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) 220 | * rel_stddev 221 | ) 222 | 223 | nn.init.xavier_uniform_(self.conv_q.weight) 224 | nn.init.xavier_uniform_(self.conv_k.weight) 225 | nn.init.xavier_uniform_(self.conv_v.weight) 226 | if proximal_init: 227 | with torch.no_grad(): 228 | self.conv_k.weight.copy_(self.conv_q.weight) 229 | self.conv_k.bias.copy_(self.conv_q.bias) 230 | 231 | def forward(self, x, c, attn_mask=None): 232 | q = self.conv_q(x) 233 | k = self.conv_k(c) 234 | v = self.conv_v(c) 235 | 236 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 237 | 238 | x = self.conv_o(x) 239 | return x 240 | 241 | def attention(self, query, key, value, mask=None): 242 | # reshape [b, d, t] -> [b, n_h, t, d_k] 243 | b, d, t_s, t_t = (*key.size(), query.size(2)) 244 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 245 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 246 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 247 | 248 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 249 | if self.window_size is not None: 250 | assert ( 251 | t_s == t_t 252 | ), "Relative attention is only available for self-attention." 
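# windowed relative-position attention: logits computed from the learned relative key embeddings are added to the content-based scores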
253 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 254 | rel_logits = self._matmul_with_relative_keys( 255 | query / math.sqrt(self.k_channels), key_relative_embeddings 256 | ) 257 | scores_local = self._relative_position_to_absolute_position(rel_logits) 258 | scores = scores + scores_local 259 | if self.proximal_bias: 260 | assert t_s == t_t, "Proximal bias is only available for self-attention." 261 | scores = scores + self._attention_bias_proximal(t_s).to( 262 | device=scores.device, dtype=scores.dtype 263 | ) 264 | if mask is not None: 265 | scores = scores.masked_fill(mask == 0, -1e4) 266 | if self.block_length is not None: 267 | assert ( 268 | t_s == t_t 269 | ), "Local attention is only available for self-attention." 270 | block_mask = ( 271 | torch.ones_like(scores) 272 | .triu(-self.block_length) 273 | .tril(self.block_length) 274 | ) 275 | scores = scores.masked_fill(block_mask == 0, -1e4) 276 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 277 | p_attn = self.drop(p_attn) 278 | output = torch.matmul(p_attn, value) 279 | if self.window_size is not None: 280 | relative_weights = self._absolute_position_to_relative_position(p_attn) 281 | value_relative_embeddings = self._get_relative_embeddings( 282 | self.emb_rel_v, t_s 283 | ) 284 | output = output + self._matmul_with_relative_values( 285 | relative_weights, value_relative_embeddings 286 | ) 287 | output = ( 288 | output.transpose(2, 3).contiguous().view(b, d, t_t) 289 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t] 290 | return output, p_attn 291 | 292 | def _matmul_with_relative_values(self, x, y): 293 | """ 294 | x: [b, h, l, m] 295 | y: [h or 1, m, d] 296 | ret: [b, h, l, d] 297 | """ 298 | ret = torch.matmul(x, y.unsqueeze(0)) 299 | return ret 300 | 301 | def _matmul_with_relative_keys(self, x, y): 302 | """ 303 | x: [b, h, l, d] 304 | y: [h or 1, m, d] 305 | ret: [b, h, l, m] 306 | """ 307 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 308 | return ret 309 | 310 | def _get_relative_embeddings(self, relative_embeddings, length): 311 | max_relative_position = 2 * self.window_size + 1 312 | # Pad first before slice to avoid using cond ops. 313 | pad_length = max(length - (self.window_size + 1), 0) 314 | slice_start_position = max((self.window_size + 1) - length, 0) 315 | slice_end_position = slice_start_position + 2 * length - 1 316 | if pad_length > 0: 317 | padded_relative_embeddings = F.pad( 318 | relative_embeddings, 319 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), 320 | ) 321 | else: 322 | padded_relative_embeddings = relative_embeddings 323 | used_relative_embeddings = padded_relative_embeddings[ 324 | :, slice_start_position:slice_end_position 325 | ] 326 | return used_relative_embeddings 327 | 328 | def _relative_position_to_absolute_position(self, x): 329 | """ 330 | x: [b, h, l, 2*l-1] 331 | ret: [b, h, l, l] 332 | """ 333 | batch, heads, length, _ = x.size() 334 | # Concat columns of pad to shift from relative to absolute indexing. 335 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])) 336 | 337 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 338 | x_flat = x.view([batch, heads, length * 2 * length]) 339 | x_flat = F.pad( 340 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) 341 | ) 342 | 343 | # Reshape and slice out the padded elements. 
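# viewing the padded buffer as (length + 1, 2 * length - 1) and dropping the last row plus the first length - 1 columns yields the absolute-position scores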
344 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ 345 | :, :, :length, length - 1 : 346 | ] 347 | return x_final 348 | 349 | def _absolute_position_to_relative_position(self, x): 350 | """ 351 | x: [b, h, l, l] 352 | ret: [b, h, l, 2*l-1] 353 | """ 354 | batch, heads, length, _ = x.size() 355 | # padd along column 356 | x = F.pad( 357 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) 358 | ) 359 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) 360 | # add 0's in the beginning that will skew the elements after reshape 361 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 362 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] 363 | return x_final 364 | 365 | def _attention_bias_proximal(self, length): 366 | """Bias for self-attention to encourage attention to close positions. 367 | Args: 368 | length: an integer scalar. 369 | Returns: 370 | a Tensor with shape [1, 1, length, length] 371 | """ 372 | r = torch.arange(length, dtype=torch.float32) 373 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 374 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 375 | 376 | 377 | class FFN(nn.Module): 378 | def __init__( 379 | self, 380 | in_channels, 381 | out_channels, 382 | filter_channels, 383 | kernel_size, 384 | p_dropout=0.0, 385 | activation=None, 386 | causal=False, 387 | ): 388 | super().__init__() 389 | self.in_channels = in_channels 390 | self.out_channels = out_channels 391 | self.filter_channels = filter_channels 392 | self.kernel_size = kernel_size 393 | self.p_dropout = p_dropout 394 | self.activation = activation 395 | self.causal = causal 396 | 397 | if causal: 398 | self.padding = self._causal_padding 399 | else: 400 | self.padding = self._same_padding 401 | 402 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 403 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 404 | self.drop = nn.Dropout(p_dropout) 405 | 406 | def forward(self, x, x_mask): 407 | x = self.conv_1(self.padding(x * x_mask)) 408 | if self.activation == "gelu": 409 | x = x * torch.sigmoid(1.702 * x) 410 | else: 411 | x = torch.relu(x) 412 | x = self.drop(x) 413 | x = self.conv_2(self.padding(x * x_mask)) 414 | return x * x_mask 415 | 416 | def _causal_padding(self, x): 417 | if self.kernel_size == 1: 418 | return x 419 | pad_l = self.kernel_size - 1 420 | pad_r = 0 421 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 422 | x = F.pad(x, commons.convert_pad_shape(padding)) 423 | return x 424 | 425 | def _same_padding(self, x): 426 | if self.kernel_size == 1: 427 | return x 428 | pad_l = (self.kernel_size - 1) // 2 429 | pad_r = self.kernel_size // 2 430 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 431 | x = F.pad(x, commons.convert_pad_shape(padding)) 432 | return x 433 | -------------------------------------------------------------------------------- /drawspeech/modules/text_encoder/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size * dilation - dilation) / 2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = 
pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += ( 34 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) 35 | ) 36 | return kl 37 | 38 | 39 | def rand_gumbel(shape): 40 | """Sample from the Gumbel distribution, protect from overflows.""" 41 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 42 | return -torch.log(-torch.log(uniform_samples)) 43 | 44 | 45 | def rand_gumbel_like(x): 46 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 47 | return g 48 | 49 | 50 | def slice_segments(x, ids_str, segment_size=4): 51 | ret = torch.zeros_like(x[:, :, :segment_size]) 52 | for i in range(x.size(0)): 53 | idx_str = ids_str[i] 54 | idx_end = idx_str + segment_size 55 | ret[i] = x[i, :, idx_str:idx_end] 56 | return ret 57 | 58 | 59 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 60 | b, d, t = x.size() 61 | if x_lengths is None: 62 | x_lengths = t 63 | ids_str_max = x_lengths - segment_size + 1 64 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 65 | ret = slice_segments(x, ids_str, segment_size) 66 | return ret, ids_str 67 | 68 | 69 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): 70 | position = torch.arange(length, dtype=torch.float) 71 | num_timescales = channels // 2 72 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( 73 | num_timescales - 1 74 | ) 75 | inv_timescales = min_timescale * torch.exp( 76 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment 77 | ) 78 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 79 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 80 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 81 | signal = signal.view(1, channels, length) 82 | return signal 83 | 84 | 85 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 86 | b, channels, length = x.size() 87 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 88 | return x + signal.to(dtype=x.dtype, device=x.device) 89 | 90 | 91 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 92 | b, channels, length = x.size() 93 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 94 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 95 | 96 | 97 | def subsequent_mask(length): 98 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 99 | return mask 100 | 101 | 102 | @torch.jit.script 103 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 104 | n_channels_int = n_channels[0] 105 | in_act = input_a + input_b 106 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 107 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 108 | acts = t_act * s_act 109 | return acts 110 | 111 | 112 | def convert_pad_shape(pad_shape): 113 | l = pad_shape[::-1] 114 | pad_shape = [item for sublist in l for item in sublist] 115 | return pad_shape 116 | 117 | 118 | def shift_1d(x): 119 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 120 | return x 121 | 122 | 123 | def sequence_mask(length, max_length=None): 124 | if max_length is None: 125 | 
max_length = length.max() 126 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 127 | return x.unsqueeze(0) < length.unsqueeze(1) 128 | 129 | 130 | def generate_path(duration, mask): 131 | """ 132 | duration: [b, 1, t_x] 133 | mask: [b, 1, t_y, t_x] 134 | """ 135 | device = duration.device 136 | 137 | b, _, t_y, t_x = mask.shape 138 | cum_duration = torch.cumsum(duration, -1) 139 | 140 | cum_duration_flat = cum_duration.view(b * t_x) 141 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 142 | path = path.view(b, t_x, t_y) 143 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 144 | path = path.unsqueeze(1).transpose(2, 3) * mask 145 | return path 146 | 147 | 148 | def clip_grad_value_(parameters, clip_value, norm_type=2): 149 | if isinstance(parameters, torch.Tensor): 150 | parameters = [parameters] 151 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 152 | norm_type = float(norm_type) 153 | if clip_value is not None: 154 | clip_value = float(clip_value) 155 | 156 | total_norm = 0 157 | for p in parameters: 158 | param_norm = p.grad.data.norm(norm_type) 159 | total_norm += param_norm.item() ** norm_type 160 | if clip_value is not None: 161 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 162 | total_norm = total_norm ** (1.0 / norm_type) 163 | return total_norm 164 | -------------------------------------------------------------------------------- /drawspeech/modules/text_encoder/encoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | import drawspeech.modules.text_encoder.commons as commons 8 | import drawspeech.modules.text_encoder.attentions as attentions 9 | 10 | # refer to vits 11 | class TextEncoder(nn.Module): 12 | def __init__( 13 | self, 14 | n_vocab, 15 | out_channels=192, 16 | hidden_channels=192, 17 | filter_channels=768, 18 | n_heads=2, 19 | n_layers=6, 20 | kernel_size=3, 21 | p_dropout=0.1, 22 | ): 23 | super().__init__() 24 | self.n_vocab = n_vocab 25 | self.out_channels = out_channels 26 | self.hidden_channels = hidden_channels 27 | self.filter_channels = filter_channels 28 | self.n_heads = n_heads 29 | self.n_layers = n_layers 30 | self.kernel_size = kernel_size 31 | self.p_dropout = p_dropout 32 | 33 | self.emb = nn.Embedding(n_vocab, hidden_channels) 34 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 35 | 36 | self.encoder = attentions.Encoder( 37 | hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout 38 | ) 39 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 40 | 41 | def forward(self, x, x_lengths): 42 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 43 | x = torch.transpose(x, 1, -1) # [b, h, t] 44 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( 45 | x.dtype 46 | ) 47 | 48 | x = self.encoder(x * x_mask, x_mask) 49 | stats = self.proj(x) * x_mask 50 | 51 | m, logs = torch.split(stats, self.out_channels, dim=1) 52 | return x, m, logs, x_mask 53 | -------------------------------------------------------------------------------- /drawspeech/train/autoencoder.py: -------------------------------------------------------------------------------- 1 | # mainly from https://github.com/haoheliu/AudioLDM-training-finetuning 2 | 3 | 4 | import sys 5 | 6 | sys.path.append("src") 7 | 8 | import os 9 | import wandb 10 | 11 | import argparse 12 
| import yaml 13 | import torch 14 | from pytorch_lightning.strategies.ddp import DDPStrategy 15 | from drawspeech.utilities.data.dataset import AudioDataset 16 | from torch.utils.data import DataLoader 17 | from pytorch_lightning.loggers import WandbLogger 18 | from pytorch_lightning import Trainer 19 | from drawspeech.modules.latent_encoder.autoencoder import AutoencoderKL 20 | from pytorch_lightning.callbacks import ModelCheckpoint 21 | from drawspeech.utilities.tools import get_restore_step 22 | 23 | 24 | def main(configs, exp_group_name, exp_name): 25 | if "precision" in configs.keys(): 26 | torch.set_float32_matmul_precision(configs["precision"]) 27 | batch_size = configs["model"]["params"]["batchsize"] 28 | log_path = configs["log_directory"] 29 | 30 | if "dataloader_add_ons" in configs["data"].keys(): 31 | dataloader_add_ons = configs["data"]["dataloader_add_ons"] 32 | else: 33 | dataloader_add_ons = [] 34 | 35 | dataset = AudioDataset(configs, split="train", add_ons=dataloader_add_ons) 36 | 37 | loader = DataLoader( 38 | dataset, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True 39 | ) 40 | 41 | print( 42 | "The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s" 43 | % (len(dataset), len(loader), batch_size) 44 | ) 45 | 46 | val_dataset = AudioDataset(configs, split="val", add_ons=dataloader_add_ons) 47 | 48 | val_loader = DataLoader( 49 | val_dataset, 50 | batch_size=batch_size, 51 | num_workers=8, 52 | shuffle=True, 53 | ) 54 | 55 | model = AutoencoderKL( 56 | ddconfig=configs["model"]["params"]["ddconfig"], 57 | lossconfig=configs["model"]["params"]["lossconfig"], 58 | embed_dim=configs["model"]["params"]["embed_dim"], 59 | image_key=configs["model"]["params"]["image_key"], 60 | base_learning_rate=configs["model"]["base_learning_rate"], 61 | subband=configs["model"]["params"]["subband"], 62 | sampling_rate=configs["preprocessing"]["audio"]["sampling_rate"], 63 | ) 64 | 65 | try: 66 | config_reload_from_ckpt = configs["reload_from_ckpt"] 67 | except: 68 | config_reload_from_ckpt = None 69 | 70 | checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints") 71 | 72 | checkpoint_callback = ModelCheckpoint( 73 | dirpath=checkpoint_path, 74 | monitor="global_step", 75 | mode="max", 76 | filename="checkpoint-{global_step:.0f}", 77 | every_n_train_steps=configs["step"]["save_checkpoint_every_n_steps"], 78 | save_top_k=configs["step"]["save_top_k"], 79 | auto_insert_metric_name=False, 80 | save_last=configs["step"]["save_last"], 81 | ) 82 | 83 | wandb_path = os.path.join(log_path, exp_group_name, exp_name) 84 | 85 | model.set_log_dir(log_path, exp_group_name, exp_name) 86 | 87 | os.makedirs(checkpoint_path, exist_ok=True) 88 | 89 | if len(os.listdir(checkpoint_path)) > 0: 90 | print("Load checkpoint from path: %s" % checkpoint_path) 91 | restore_step, n_step = get_restore_step(checkpoint_path) 92 | resume_from_checkpoint = os.path.join(checkpoint_path, restore_step) 93 | print("Resume from checkpoint", resume_from_checkpoint) 94 | elif config_reload_from_ckpt is not None: 95 | resume_from_checkpoint = config_reload_from_ckpt 96 | print("Reload ckpt specified in the config file %s" % resume_from_checkpoint) 97 | else: 98 | print("Train from scratch") 99 | resume_from_checkpoint = None 100 | 101 | devices = torch.cuda.device_count() 102 | 103 | wandb_logger = WandbLogger( 104 | save_dir=wandb_path, 105 | project=configs["project"], 106 | config=configs, 107 | name="%s/%s" % (exp_group_name, exp_name), 108 | ) 109 | 
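    # Resume priority resolved above: an existing checkpoint under `checkpoint_path`
    # wins, then the optional `reload_from_ckpt` entry of the config, otherwise
    # training starts from scratch; the result is passed to trainer.fit() below.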
110 | trainer = Trainer( 111 | accelerator="gpu", 112 | devices=devices, 113 | logger=wandb_logger, 114 | max_steps=configs["step"]["max_steps"], 115 | limit_val_batches=configs["step"]["limit_val_batches"], 116 | callbacks=[checkpoint_callback], 117 | strategy=DDPStrategy(find_unused_parameters=True), 118 | # val_check_interval=configs["step"]["val_check_interval"], 119 | check_val_every_n_epoch=configs["step"]["validation_every_n_epochs"], 120 | ) 121 | 122 | # TRAINING 123 | trainer.fit(model, loader, val_loader, ckpt_path=resume_from_checkpoint) 124 | 125 | # EVALUTION 126 | # trainer.test(model, test_loader, ckpt_path=resume_from_checkpoint) 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument( 132 | "-c", 133 | "--autoencoder_config", 134 | type=str, 135 | required=True, 136 | help="path to autoencoder config .yam", 137 | ) 138 | 139 | args = parser.parse_args() 140 | 141 | config_yaml = args.autoencoder_config 142 | exp_name = os.path.basename(config_yaml.split(".")[0]) 143 | exp_group_name = os.path.basename(os.path.dirname(config_yaml)) 144 | 145 | config_yaml = os.path.join(config_yaml) 146 | 147 | config_yaml = yaml.load(open(config_yaml, "r"), Loader=yaml.FullLoader) 148 | 149 | main(config_yaml, exp_group_name, exp_name) 150 | -------------------------------------------------------------------------------- /drawspeech/train/latent_diffusion.py: -------------------------------------------------------------------------------- 1 | # mainly from https://github.com/haoheliu/AudioLDM-training-finetuning 2 | 3 | 4 | import sys 5 | 6 | sys.path.append("src") 7 | import shutil 8 | import os 9 | 10 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 11 | 12 | import argparse 13 | import yaml 14 | import torch 15 | 16 | from tqdm import tqdm 17 | from pytorch_lightning.strategies.ddp import DDPStrategy 18 | from drawspeech.utilities.data.dataset import AudioDataset 19 | 20 | from torch.utils.data import DataLoader 21 | from pytorch_lightning import Trainer, seed_everything 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from pytorch_lightning.loggers import WandbLogger 24 | from drawspeech.utilities.tools import ( 25 | get_restore_step, 26 | copy_test_subset_data, 27 | ) 28 | from drawspeech.utilities.model_util import instantiate_from_config 29 | import logging 30 | 31 | logging.basicConfig(level=logging.WARNING) 32 | 33 | 34 | def print_on_rank0(msg): 35 | if torch.distributed.get_rank() == 0: 36 | print(msg) 37 | 38 | 39 | def main(configs, config_yaml_path, exp_group_name, exp_name, perform_validation): 40 | if "seed" in configs.keys(): 41 | seed_everything(configs["seed"]) 42 | else: 43 | print("SEED EVERYTHING TO 0") 44 | seed_everything(0) 45 | 46 | if "precision" in configs.keys(): 47 | torch.set_float32_matmul_precision( 48 | configs["precision"] 49 | ) # highest, high, medium 50 | 51 | log_path = configs["log_directory"] 52 | batch_size = configs["model"]["params"]["batchsize"] 53 | 54 | if "dataloader_add_ons" in configs["data"].keys(): 55 | dataloader_add_ons = configs["data"]["dataloader_add_ons"] 56 | else: 57 | dataloader_add_ons = [] 58 | 59 | dataset = AudioDataset(configs, split="train", add_ons=dataloader_add_ons) 60 | 61 | loader = DataLoader( 62 | dataset, 63 | batch_size=batch_size, 64 | num_workers=32, 65 | pin_memory=True, 66 | shuffle=True, 67 | ) 68 | 69 | print( 70 | "The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s" 71 | % (len(dataset), 
len(loader), batch_size) 72 | ) 73 | 74 | val_dataset = AudioDataset(configs, split="test", add_ons=dataloader_add_ons) 75 | 76 | val_loader = DataLoader( 77 | val_dataset, 78 | batch_size=8, 79 | shuffle=True, 80 | ) 81 | 82 | # Copy test data 83 | test_data_subset_folder = os.path.join( 84 | os.path.dirname(configs["log_directory"]), 85 | "testset_data", 86 | "_".join(val_dataset.dataset_name) if isinstance(val_dataset.dataset_name, list) else val_dataset.dataset_name 87 | ) 88 | os.makedirs(test_data_subset_folder, exist_ok=True) 89 | copy_test_subset_data(val_dataset.data, test_data_subset_folder) 90 | 91 | try: 92 | config_reload_from_ckpt = configs["reload_from_ckpt"] 93 | except: 94 | config_reload_from_ckpt = None 95 | 96 | try: 97 | limit_val_batches = configs["step"]["limit_val_batches"] 98 | except: 99 | limit_val_batches = None 100 | 101 | validation_every_n_epochs = configs["step"]["validation_every_n_epochs"] 102 | # val_check_interval = configs["step"]["val_check_interval"] 103 | num_sanity_val_steps = configs["step"]["num_sanity_val_steps"] 104 | save_checkpoint_every_n_steps = configs["step"]["save_checkpoint_every_n_steps"] 105 | max_steps = configs["step"]["max_steps"] 106 | save_top_k = configs["step"]["save_top_k"] 107 | save_last = configs["step"]["save_last"] 108 | monitor = configs["model"]["params"]["monitor"] 109 | 110 | checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints") 111 | 112 | wandb_path = os.path.join(log_path, exp_group_name, exp_name) 113 | 114 | checkpoint_callback = ModelCheckpoint( 115 | dirpath=checkpoint_path, 116 | monitor=monitor, # "global_step" 117 | mode="min", 118 | filename="checkpoint-fad-{val/frechet_inception_distance:.2f}-global_step={global_step:.0f}", 119 | every_n_train_steps=save_checkpoint_every_n_steps, 120 | save_top_k=save_top_k, 121 | auto_insert_metric_name=False, 122 | save_last=save_last, 123 | ) 124 | 125 | os.makedirs(checkpoint_path, exist_ok=True) 126 | shutil.copy(config_yaml_path, wandb_path) 127 | 128 | is_external_checkpoints = False 129 | if len(os.listdir(checkpoint_path)) > 0: 130 | print("Load checkpoint from path: %s" % checkpoint_path) 131 | restore_step, n_step = get_restore_step(checkpoint_path) 132 | resume_from_checkpoint = os.path.join(checkpoint_path, restore_step) 133 | print("Resume from checkpoint", resume_from_checkpoint) 134 | elif config_reload_from_ckpt is not None: 135 | resume_from_checkpoint = config_reload_from_ckpt 136 | is_external_checkpoints = True 137 | print("Reload ckpt specified in the config file %s" % resume_from_checkpoint) 138 | else: 139 | print("Train from scratch") 140 | resume_from_checkpoint = None 141 | 142 | devices = torch.cuda.device_count() 143 | 144 | latent_diffusion = instantiate_from_config(configs["model"]) 145 | latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name) 146 | 147 | wandb_logger = WandbLogger( 148 | save_dir=wandb_path, 149 | project=configs["project"], 150 | config=configs, 151 | name="%s/%s" % (exp_group_name, exp_name), 152 | ) 153 | 154 | latent_diffusion.test_data_subset_path = test_data_subset_folder 155 | 156 | print("==> Save checkpoint every %s steps" % save_checkpoint_every_n_steps) 157 | print("==> Perform validation every %s epochs" % validation_every_n_epochs) 158 | # print("==> Perform validation every %s steps" % val_check_interval) 159 | 160 | trainer = Trainer( 161 | accelerator="gpu", 162 | devices=devices, 163 | logger=wandb_logger, 164 | max_steps=max_steps, 165 | 
num_sanity_val_steps=num_sanity_val_steps, 166 | limit_val_batches=limit_val_batches, 167 | check_val_every_n_epoch=validation_every_n_epochs, 168 | # val_check_interval=val_check_interval, 169 | strategy=DDPStrategy(find_unused_parameters=True), 170 | callbacks=[checkpoint_callback], 171 | ) 172 | 173 | if is_external_checkpoints: 174 | if resume_from_checkpoint is not None: 175 | ckpt = torch.load(resume_from_checkpoint)["state_dict"] 176 | 177 | key_not_in_model_state_dict = [] 178 | size_mismatch_keys = [] 179 | state_dict = latent_diffusion.state_dict() 180 | print("Filtering key for reloading:", resume_from_checkpoint) 181 | print( 182 | "State dict key size:", 183 | len(list(state_dict.keys())), 184 | len(list(ckpt.keys())), 185 | ) 186 | for key in tqdm(list(ckpt.keys())): 187 | if key not in state_dict.keys(): 188 | key_not_in_model_state_dict.append(key) 189 | del ckpt[key] 190 | continue 191 | if state_dict[key].size() != ckpt[key].size(): 192 | del ckpt[key] 193 | size_mismatch_keys.append(key) 194 | 195 | # if(len(key_not_in_model_state_dict) != 0 or len(size_mismatch_keys) != 0): 196 | # print("⛳", end=" ") 197 | 198 | # print("==> Warning: The following key in the checkpoint is not presented in the model:", key_not_in_model_state_dict) 199 | # print("==> Warning: These keys have different size between checkpoint and current model: ", size_mismatch_keys) 200 | 201 | latent_diffusion.load_state_dict(ckpt, strict=False) 202 | 203 | # if(perform_validation): 204 | # trainer.validate(latent_diffusion, val_loader) 205 | 206 | trainer.fit(latent_diffusion, loader, val_loader) 207 | else: 208 | trainer.fit( 209 | latent_diffusion, loader, val_loader, ckpt_path=resume_from_checkpoint 210 | ) 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument( 216 | "-c", 217 | "--config_yaml", 218 | type=str, 219 | required=False, 220 | help="path to config .yaml file", 221 | ) 222 | 223 | parser.add_argument( 224 | "--reload_from_ckpt", 225 | type=str, 226 | required=False, 227 | default=None, 228 | help="path to pretrained checkpoint", 229 | ) 230 | 231 | parser.add_argument("--val", action="store_true") 232 | 233 | args = parser.parse_args() 234 | 235 | perform_validation = args.val 236 | 237 | assert torch.cuda.is_available(), "CUDA is not available" 238 | 239 | config_yaml = args.config_yaml 240 | 241 | exp_name = os.path.basename(config_yaml.split(".")[0]) 242 | exp_group_name = os.path.basename(os.path.dirname(config_yaml)) 243 | 244 | config_yaml_path = os.path.join(config_yaml) 245 | config_yaml = yaml.load(open(config_yaml_path, "r"), Loader=yaml.FullLoader) 246 | 247 | if args.reload_from_ckpt is not None: 248 | config_yaml["reload_from_ckpt"] = args.reload_from_ckpt 249 | 250 | if perform_validation: 251 | config_yaml["step"]["limit_val_batches"] = None 252 | 253 | main(config_yaml, config_yaml_path, exp_group_name, exp_name, perform_validation) 254 | -------------------------------------------------------------------------------- /drawspeech/utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import * 2 | from .data import * 3 | from .model_util import * 4 | -------------------------------------------------------------------------------- /drawspeech/utilities/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio_processing import * 2 | from .stft import * 3 | from .tools import * 4 | 
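# The wildcard imports above re-export the STFT/mel utilities of this package.
# Minimal usage sketch (parameter values follow the LJSpeech 22.05 kHz settings in
# drawspeech/utilities/preprocessor/*.yaml; `wav` is assumed to be a float32 numpy
# array in [-1, 1]):
#
#   import drawspeech.utilities.audio as Audio
#
#   stft = Audio.stft.TacotronSTFT(
#       filter_length=1024, hop_length=256, win_length=1024,
#       n_mel_channels=80, sampling_rate=22050, mel_fmin=0, mel_fmax=8000,
#   )
#   mel, magnitudes, energy = Audio.tools.get_mel_from_wav(wav, stft)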
-------------------------------------------------------------------------------- /drawspeech/utilities/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return normalize_fun(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /drawspeech/utilities/audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from drawspeech.utilities.audio.audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | 
window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data, 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | 
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes, normalize_fun): 152 | output = dynamic_range_compression(magnitudes, normalize_fun) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y, normalize_fun=torch.log): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1, torch.min(y.data) 170 | assert torch.max(y.data) <= 1, torch.max(y.data) 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output, normalize_fun) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | return mel_output, magnitudes, phases, energy 179 | -------------------------------------------------------------------------------- /drawspeech/utilities/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.io.wavfile import write 4 | import torchaudio 5 | 6 | from drawspeech.utilities.audio.audio_processing import griffin_lim 7 | 8 | 9 | def get_mel_from_wav(audio, _stft): 10 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 11 | audio = torch.autograd.Variable(audio, requires_grad=False) 12 | melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio) 13 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 14 | magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32) 15 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 16 | return melspec, magnitudes, energy 17 | 18 | 19 | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60): 20 | mel = torch.stack([mel]) 21 | mel_decompress = _stft.spectral_de_normalize(mel) 22 | mel_decompress = mel_decompress.transpose(1, 2).data.cpu() 23 | spec_from_mel_scaling = 1000 24 | spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis) 25 | spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0) 26 | spec_from_mel = spec_from_mel * spec_from_mel_scaling 27 | 28 | audio = griffin_lim( 29 | torch.autograd.Variable(spec_from_mel[:, :, :-1]), 
_stft._stft_fn, griffin_iters 30 | ) 31 | 32 | audio = audio.squeeze() 33 | audio = audio.cpu().numpy() 34 | audio_path = out_filename 35 | write(audio_path, _stft.sampling_rate, audio) 36 | -------------------------------------------------------------------------------- /drawspeech/utilities/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | -------------------------------------------------------------------------------- /drawspeech/utilities/diffusion_util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from drawspeech.utilities.model_util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule( 22 | schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 23 | ): 24 | if schedule == "linear": 25 | betas = ( 26 | torch.linspace( 27 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 28 | ) 29 | ** 2 30 | ) 31 | 32 | elif schedule == "cosine": 33 | timesteps = ( 34 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 35 | ) 36 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 37 | alphas = torch.cos(alphas).pow(2) 38 | alphas = alphas / alphas[0] 39 | betas = 1 - alphas[1:] / alphas[:-1] 40 | betas = np.clip(betas, a_min=0, a_max=0.999) 41 | 42 | elif schedule == "sqrt_linear": 43 | betas = torch.linspace( 44 | linear_start, linear_end, n_timestep, dtype=torch.float64 45 | ) 46 | elif schedule == "sqrt": 47 | betas = ( 48 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 49 | ** 0.5 50 | ) 51 | else: 52 | raise ValueError(f"schedule '{schedule}' unknown.") 53 | return betas.numpy() 54 | 55 | 56 | def make_ddim_timesteps( 57 | ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True 58 | ): 59 | if ddim_discr_method == "uniform": 60 | c = num_ddpm_timesteps // num_ddim_timesteps 61 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 62 | elif ddim_discr_method == "quad": 63 | ddim_timesteps = ( 64 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 65 | ).astype(int) 66 | else: 67 | raise NotImplementedError( 68 | f'There is no ddim discretization method called "{ddim_discr_method}"' 69 | ) 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f"Selected timesteps for ddim sampler: {steps_out}") 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | alphas = alphacums[ddim_timesteps] 82 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 83 | 84 | # according the the formula provided in 
https://arxiv.org/abs/2010.02502 85 | sigmas = eta * np.sqrt( 86 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 87 | ) 88 | if verbose: 89 | print( 90 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 91 | ) 92 | print( 93 | f"For the chosen value of eta, which is {eta}, " 94 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 95 | ) 96 | return sigmas, alphas, alphas_prev 97 | 98 | 99 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 100 | """ 101 | Create a beta schedule that discretizes the given alpha_t_bar function, 102 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 103 | :param num_diffusion_timesteps: the number of betas to produce. 104 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 105 | produces the cumulative product of (1-beta) up to that 106 | part of the diffusion process. 107 | :param max_beta: the maximum beta to use; use values lower than 1 to 108 | prevent singularities. 109 | """ 110 | betas = [] 111 | for i in range(num_diffusion_timesteps): 112 | t1 = i / num_diffusion_timesteps 113 | t2 = (i + 1) / num_diffusion_timesteps 114 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 115 | return np.array(betas) 116 | 117 | 118 | def extract_into_tensor(a, t, x_shape): 119 | b, *_ = t.shape 120 | out = a.gather(-1, t).contiguous() 121 | return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | :param func: the function to evaluate. 129 | :param inputs: the argument sequence to pass to `func`. 130 | :param params: a sequence of parameters `func` depends on but does not 131 | explicitly take as arguments. 132 | :param flag: if False, disable gradient checkpointing. 133 | """ 134 | if flag: 135 | args = tuple(inputs) + tuple(params) 136 | return CheckpointFunction.apply(func, len(inputs), *args) 137 | else: 138 | return func(*inputs) 139 | 140 | 141 | class CheckpointFunction(torch.autograd.Function): 142 | @staticmethod 143 | def forward(ctx, run_function, length, *args): 144 | ctx.run_function = run_function 145 | ctx.input_tensors = list(args[:length]) 146 | ctx.input_params = list(args[length:]) 147 | 148 | with torch.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with torch.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = torch.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | 172 | 173 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 174 | """ 175 | Create sinusoidal timestep embeddings. 176 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 
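                      (Typically these are the sampled diffusion steps, e.g.
                      torch.randint(0, num_timesteps, (N,)) during training.)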
177 | These may be fractional. 178 | :param dim: the dimension of the output. 179 | :param max_period: controls the minimum frequency of the embeddings. 180 | :return: an [N x dim] Tensor of positional embeddings. 181 | """ 182 | if not repeat_only: 183 | half = dim // 2 184 | freqs = torch.exp( 185 | -math.log(max_period) 186 | * torch.arange(start=0, end=half, dtype=torch.float32) 187 | / half 188 | ).to(device=timesteps.device) 189 | args = timesteps[:, None].float() * freqs[None] 190 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 191 | if dim % 2: 192 | embedding = torch.cat( 193 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 194 | ) 195 | else: 196 | embedding = repeat(timesteps, "b -> b d", d=dim) 197 | return embedding 198 | 199 | 200 | def zero_module(module): 201 | """ 202 | Zero out the parameters of a module and return it. 203 | """ 204 | for p in module.parameters(): 205 | p.detach().zero_() 206 | return module 207 | 208 | 209 | def scale_module(module, scale): 210 | """ 211 | Scale the parameters of a module and return it. 212 | """ 213 | for p in module.parameters(): 214 | p.detach().mul_(scale) 215 | return module 216 | 217 | 218 | def mean_flat(tensor): 219 | """ 220 | Take the mean over all non-batch dimensions. 221 | """ 222 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 223 | 224 | 225 | def normalization(channels): 226 | """ 227 | Make a standard normalization layer. 228 | :param channels: number of input channels. 229 | :return: an nn.Module for normalization. 230 | """ 231 | return GroupNorm32(32, channels) 232 | 233 | 234 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 235 | class SiLU(nn.Module): 236 | def forward(self, x): 237 | return x * torch.sigmoid(x) 238 | 239 | 240 | class GroupNorm32(nn.GroupNorm): 241 | def forward(self, x): 242 | return super().forward(x.float()).type(x.dtype) 243 | 244 | 245 | def conv_nd(dims, *args, **kwargs): 246 | """ 247 | Create a 1D, 2D, or 3D convolution module. 248 | """ 249 | if dims == 1: 250 | return nn.Conv1d(*args, **kwargs) 251 | elif dims == 2: 252 | return nn.Conv2d(*args, **kwargs) 253 | elif dims == 3: 254 | return nn.Conv3d(*args, **kwargs) 255 | raise ValueError(f"unsupported dimensions: {dims}") 256 | 257 | 258 | def linear(*args, **kwargs): 259 | """ 260 | Create a linear module. 261 | """ 262 | return nn.Linear(*args, **kwargs) 263 | 264 | 265 | def avg_pool_nd(dims, *args, **kwargs): 266 | """ 267 | Create a 1D, 2D, or 3D average pooling module. 
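    For example, avg_pool_nd(2, kernel_size=2) returns nn.AvgPool2d(kernel_size=2);
    any other value of `dims` raises ValueError.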
268 | """ 269 | if dims == 1: 270 | return nn.AvgPool1d(*args, **kwargs) 271 | elif dims == 2: 272 | return nn.AvgPool2d(*args, **kwargs) 273 | elif dims == 3: 274 | return nn.AvgPool3d(*args, **kwargs) 275 | raise ValueError(f"unsupported dimensions: {dims}") 276 | 277 | 278 | class HybridConditioner(nn.Module): 279 | def __init__(self, c_concat_config, c_crossattn_config): 280 | super().__init__() 281 | self.concat_conditioner = instantiate_from_config(c_concat_config) 282 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 283 | 284 | def forward(self, c_concat, c_crossattn): 285 | c_concat = self.concat_conditioner(c_concat) 286 | c_crossattn = self.crossattn_conditioner(c_crossattn) 287 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 288 | 289 | 290 | def noise_like(shape, device, repeat=False): 291 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 292 | shape[0], *((1,) * (len(shape) - 1)) 293 | ) 294 | noise = lambda: torch.randn(shape, device=device) 295 | return repeat_noise() if repeat else noise() 296 | -------------------------------------------------------------------------------- /drawspeech/utilities/model_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | import numpy as np 6 | 7 | import drawspeech.modules.hifigan as hifigan 8 | 9 | import importlib 10 | 11 | import torch 12 | import numpy as np 13 | from collections import abc 14 | 15 | import multiprocessing as mp 16 | from threading import Thread 17 | from queue import Queue 18 | 19 | from inspect import isfunction 20 | from PIL import Image, ImageDraw, ImageFont 21 | 22 | 23 | def log_txt_as_img(wh, xc, size=10): 24 | # wh a tuple of (width, height) 25 | # xc a list of captions to plot 26 | b = len(xc) 27 | txts = list() 28 | for bi in range(b): 29 | txt = Image.new("RGB", wh, color="white") 30 | draw = ImageDraw.Draw(txt) 31 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 32 | nc = int(40 * (wh[0] / 256)) 33 | lines = "\n".join( 34 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 35 | ) 36 | 37 | try: 38 | draw.text((0, 0), lines, fill="black", font=font) 39 | except UnicodeEncodeError: 40 | print("Cant encode string for logging. Skipping.") 41 | 42 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 43 | txts.append(txt) 44 | txts = np.stack(txts) 45 | txts = torch.tensor(txts) 46 | return txts 47 | 48 | 49 | def ismap(x): 50 | if not isinstance(x, torch.Tensor): 51 | return False 52 | return (len(x.shape) == 4) and (x.shape[1] > 3) 53 | 54 | 55 | def isimage(x): 56 | if not isinstance(x, torch.Tensor): 57 | return False 58 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 59 | 60 | 61 | def int16_to_float32(x): 62 | return (x / 32767.0).astype(np.float32) 63 | 64 | 65 | def float32_to_int16(x): 66 | x = np.clip(x, a_min=-1.0, a_max=1.0) 67 | return (x * 32767.0).astype(np.int16) 68 | 69 | 70 | def exists(x): 71 | return x is not None 72 | 73 | 74 | def default(val, d): 75 | if exists(val): 76 | return val 77 | return d() if isfunction(d) else d 78 | 79 | 80 | def mean_flat(tensor): 81 | """ 82 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 83 | Take the mean over all non-batch dimensions. 
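    For example, a tensor of shape [B, C, T] is reduced to a tensor of shape [B].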
84 | """ 85 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 86 | 87 | 88 | def count_params(model, verbose=False): 89 | total_params = sum(p.numel() for p in model.parameters()) 90 | if verbose: 91 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 92 | return total_params 93 | 94 | 95 | def instantiate_from_config(config): 96 | if not "target" in config: 97 | if config == "__is_first_stage__": 98 | return None 99 | elif config == "__is_unconditional__": 100 | return None 101 | raise KeyError("Expected key `target` to instantiate.") 102 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 103 | 104 | 105 | def get_obj_from_str(string, reload=False): 106 | module, cls = string.rsplit(".", 1) 107 | if reload: 108 | module_imp = importlib.import_module(module) 109 | importlib.reload(module_imp) 110 | return getattr(importlib.import_module(module, package=None), cls) 111 | 112 | 113 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 114 | # create dummy dataset instance 115 | 116 | # run prefetching 117 | if idx_to_fn: 118 | res = func(data, worker_id=idx) 119 | else: 120 | res = func(data) 121 | Q.put([idx, res]) 122 | Q.put("Done") 123 | 124 | 125 | def parallel_data_prefetch( 126 | func: callable, 127 | data, 128 | n_proc, 129 | target_data_type="ndarray", 130 | cpu_intensive=True, 131 | use_worker_id=False, 132 | ): 133 | # if target_data_type not in ["ndarray", "list"]: 134 | # raise ValueError( 135 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 136 | # ) 137 | if isinstance(data, np.ndarray) and target_data_type == "list": 138 | raise ValueError("list expected but function got ndarray.") 139 | elif isinstance(data, abc.Iterable): 140 | if isinstance(data, dict): 141 | print( 142 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 143 | ) 144 | data = list(data.values()) 145 | if target_data_type == "ndarray": 146 | data = np.asarray(data) 147 | else: 148 | data = list(data) 149 | else: 150 | raise TypeError( 151 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
152 | ) 153 | 154 | if cpu_intensive: 155 | Q = mp.Queue(1000) 156 | proc = mp.Process 157 | else: 158 | Q = Queue(1000) 159 | proc = Thread 160 | # spawn processes 161 | if target_data_type == "ndarray": 162 | arguments = [ 163 | [func, Q, part, i, use_worker_id] 164 | for i, part in enumerate(np.array_split(data, n_proc)) 165 | ] 166 | else: 167 | step = ( 168 | int(len(data) / n_proc + 1) 169 | if len(data) % n_proc != 0 170 | else int(len(data) / n_proc) 171 | ) 172 | arguments = [ 173 | [func, Q, part, i, use_worker_id] 174 | for i, part in enumerate( 175 | [data[i : i + step] for i in range(0, len(data), step)] 176 | ) 177 | ] 178 | processes = [] 179 | for i in range(n_proc): 180 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 181 | processes += [p] 182 | 183 | # start processes 184 | print(f"Start prefetching...") 185 | import time 186 | 187 | start = time.time() 188 | gather_res = [[] for _ in range(n_proc)] 189 | try: 190 | for p in processes: 191 | p.start() 192 | 193 | k = 0 194 | while k < n_proc: 195 | # get result 196 | res = Q.get() 197 | if res == "Done": 198 | k += 1 199 | else: 200 | gather_res[res[0]] = res[1] 201 | 202 | except Exception as e: 203 | print("Exception: ", e) 204 | for p in processes: 205 | p.terminate() 206 | 207 | raise e 208 | finally: 209 | for p in processes: 210 | p.join() 211 | print(f"Prefetching complete. [{time.time() - start} sec.]") 212 | 213 | if target_data_type == "ndarray": 214 | if not isinstance(gather_res[0], np.ndarray): 215 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 216 | 217 | # order outputs 218 | return np.concatenate(gather_res, axis=0) 219 | elif target_data_type == "list": 220 | out = [] 221 | for r in gather_res: 222 | out.extend(r) 223 | return out 224 | else: 225 | return gather_res 226 | 227 | 228 | def get_available_checkpoint_keys(model, ckpt): 229 | print("==> Attemp to reload from %s" % ckpt) 230 | state_dict = torch.load(ckpt)["state_dict"] 231 | current_state_dict = model.state_dict() 232 | new_state_dict = {} 233 | for k in state_dict.keys(): 234 | if ( 235 | k in current_state_dict.keys() 236 | and current_state_dict[k].size() == state_dict[k].size() 237 | ): 238 | new_state_dict[k] = state_dict[k] 239 | else: 240 | print("==> WARNING: Skipping %s" % k) 241 | print( 242 | "%s out of %s keys are matched" 243 | % (len(new_state_dict.keys()), len(state_dict.keys())) 244 | ) 245 | return new_state_dict 246 | 247 | 248 | def get_param_num(model): 249 | num_param = sum(param.numel() for param in model.parameters()) 250 | return num_param 251 | 252 | 253 | def torch_version_orig_mod_remove(state_dict): 254 | new_state_dict = {} 255 | new_state_dict["generator"] = {} 256 | for key in state_dict["generator"].keys(): 257 | if "_orig_mod." 
in key: 258 | new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[ 259 | "generator" 260 | ][key] 261 | else: 262 | new_state_dict["generator"][key] = state_dict["generator"][key] 263 | return new_state_dict 264 | 265 | 266 | def get_vocoder(config, device, mel_bins, ckpt_path=None): 267 | ROOT = "data/checkpoints" 268 | 269 | if mel_bins == 64: 270 | model_path = os.path.join(ROOT, "hifigan_16k_64bins") 271 | with open(model_path + ".json", "r") as f: 272 | config = json.load(f) 273 | config = hifigan.AttrDict(config) 274 | vocoder = hifigan.Generator(config) 275 | ckpt = torch.load(model_path + ".ckpt") 276 | elif mel_bins == 256: 277 | model_path = os.path.join(ROOT, "hifigan_48k_256bins") 278 | with open(model_path + ".json", "r") as f: 279 | config = json.load(f) 280 | config = hifigan.AttrDict(config) 281 | vocoder = hifigan.Generator_HiFiRes(config) 282 | ckpt = torch.load(model_path + ".ckpt") 283 | elif mel_bins == 80: 284 | # load HiFi-GAN pretrained on LJSpeech 285 | print("==> Loading HiFi-GAN LJ_V1 (https://github.com/jik876/hifi-gan) pretrained on LJSpeech") 286 | model_path = os.path.join(ROOT, "LJ_V1", "generator_v1") 287 | 288 | config_file = os.path.join(os.path.split(model_path)[0], 'config.json') 289 | with open(config_file) as f: 290 | data = f.read() 291 | json_config = json.loads(data) 292 | config = hifigan.AttrDict(json_config) 293 | vocoder = hifigan.Generator(config) 294 | ckpt = torch.load(model_path) 295 | 296 | if ckpt_path is not None: 297 | ckpt = torch.load(ckpt_path) 298 | 299 | ckpt = torch_version_orig_mod_remove(ckpt) 300 | vocoder.load_state_dict(ckpt["generator"]) 301 | vocoder.eval() 302 | vocoder.remove_weight_norm() 303 | vocoder.to(device) 304 | return vocoder 305 | 306 | 307 | def vocoder_infer(mels, vocoder, lengths=None): 308 | with torch.no_grad(): 309 | wavs = vocoder(mels).squeeze(1) 310 | 311 | wavs = (wavs.cpu().numpy() * 32768).astype("int16") 312 | 313 | if lengths is not None: 314 | wavs = wavs[:, :lengths] 315 | 316 | return wavs 317 | -------------------------------------------------------------------------------- /drawspeech/utilities/preprocessor/preprocess_frame_level.yaml: -------------------------------------------------------------------------------- 1 | dataset: "LJSpeech" 2 | 3 | path: 4 | corpus_path: "data/dataset/LJSpeech-1.1" 5 | TextGrid_path: "data/dataset/LJSpeech-1.1/TextGrid/" 6 | raw_path: "data/dataset/LJSpeech-1.1/wavs" 7 | preprocessed_path: "data/dataset/metadata/ljspeech/frame_level" 8 | 9 | preprocessing: 10 | audio: 11 | sampling_rate: 22050 12 | max_wav_value: 32768.0 13 | stft: 14 | filter_length: 1024 15 | hop_length: 256 16 | win_length: 1024 17 | mel: 18 | n_mel_channels: 80 19 | mel_fmin: 0 20 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder 21 | pitch: 22 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 23 | normalization: True 24 | energy: 25 | feature: "frame_level" # support 'phoneme_level' or 'frame_level' 26 | normalization: True 27 | -------------------------------------------------------------------------------- /drawspeech/utilities/preprocessor/preprocess_one_sample.py: -------------------------------------------------------------------------------- 1 | 2 | from string import punctuation 3 | from g2p_en import G2p 4 | 5 | import numpy as np 6 | import re 7 | 8 | from drawspeech.utilities.text import text_to_sequence 9 | 10 | def read_lexicon(lex_path): 11 | lexicon = {} 12 | with open(lex_path) as f: 13 | for line 
in f: 14 | temp = re.split(r"\s+", line.strip("\n")) 15 | word = temp[0] 16 | phones = temp[1:] 17 | if word.lower() not in lexicon: 18 | lexicon[word.lower()] = phones 19 | return lexicon 20 | 21 | def preprocess_english(text, return_word_phone_alignment=False, always_use_g2p=False, g2p=None, verbose=True): 22 | text = text.rstrip(punctuation) 23 | 24 | if always_use_g2p: 25 | lexicon = None 26 | assert g2p is not None # initialize once can save time 27 | else: 28 | lexicon_path = "drawspeech/utilities/text/lexicon/librispeech-lexicon.txt" 29 | lexicon = read_lexicon(lexicon_path) 30 | g2p = G2p() 31 | 32 | phones = [] 33 | word_phone_alignment = [] 34 | words = re.split(r"([,;.\-\?\!\s+])", text) 35 | 36 | if always_use_g2p: 37 | phones = list(filter(lambda p: p != " ", g2p(text))) 38 | else: 39 | for w in words: 40 | if w.lower() in lexicon: 41 | # phones += lexicon[w.lower()] 42 | p = lexicon[w.lower()] 43 | phones += p 44 | else: 45 | # phones += list(filter(lambda p: p != " ", g2p(w))) 46 | p = list(filter(lambda p: p != " ", g2p(w))) 47 | phones += p 48 | if len(p) > 0: 49 | word_phone_alignment.append((w, p, len(p))) 50 | 51 | phones = "{" + "}{".join(phones) + "}" 52 | phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones) 53 | phones = phones.replace("}{", " ") 54 | 55 | if verbose: 56 | print("Raw Text Sequence: {}".format(text)) 57 | print("Phoneme Sequence: {}".format(phones)) 58 | sequence = text_to_sequence(phones, ["english_cleaners"]) 59 | 60 | if return_word_phone_alignment: 61 | return sequence, phones, word_phone_alignment 62 | else: 63 | return sequence, phones -------------------------------------------------------------------------------- /drawspeech/utilities/preprocessor/preprocess_phoneme_level.yaml: -------------------------------------------------------------------------------- 1 | dataset: "LJSpeech" 2 | 3 | path: 4 | corpus_path: "data/dataset/LJSpeech-1.1" 5 | TextGrid_path: "data/dataset/LJSpeech-1.1/TextGrid" 6 | raw_path: "data/dataset/LJSpeech-1.1/wavs" 7 | preprocessed_path: "data/dataset/metadata/ljspeech/phoneme_level" 8 | 9 | preprocessing: 10 | audio: 11 | sampling_rate: 22050 12 | max_wav_value: 32768.0 13 | stft: 14 | filter_length: 1024 15 | hop_length: 256 16 | win_length: 1024 17 | mel: 18 | n_mel_channels: 80 19 | mel_fmin: 0 20 | mel_fmax: 8000 # please set to 8000 for HiFi-GAN vocoder, set to null for MelGAN vocoder 21 | pitch: 22 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 23 | normalization: True 24 | energy: 25 | feature: "phoneme_level" # support 'phoneme_level' or 'frame_level' 26 | normalization: True 27 | -------------------------------------------------------------------------------- /drawspeech/utilities/preprocessor/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import argparse 5 | import yaml 6 | 7 | import tgt 8 | import librosa 9 | import numpy as np 10 | import pyworld as pw 11 | from scipy.interpolate import interp1d 12 | from sklearn.preprocessing import StandardScaler 13 | from tqdm import tqdm 14 | 15 | import sys 16 | 17 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../'))) 18 | import drawspeech.utilities.audio as Audio 19 | 20 | class Preprocessor: 21 | def __init__(self, config): 22 | self.config = config 23 | self.in_dir = config["path"]["raw_path"] 24 | self.out_dir = config["path"]["preprocessed_path"] 25 | # self.val_size = config["preprocessing"]["val_size"] 26 
| self.sampling_rate = config["preprocessing"]["audio"]["sampling_rate"] 27 | self.hop_length = config["preprocessing"]["stft"]["hop_length"] 28 | 29 | assert config["preprocessing"]["pitch"]["feature"] in [ 30 | "phoneme_level", 31 | "frame_level", 32 | ] 33 | assert config["preprocessing"]["energy"]["feature"] in [ 34 | "phoneme_level", 35 | "frame_level", 36 | ] 37 | self.pitch_phoneme_averaging = ( 38 | config["preprocessing"]["pitch"]["feature"] == "phoneme_level" 39 | ) 40 | self.energy_phoneme_averaging = ( 41 | config["preprocessing"]["energy"]["feature"] == "phoneme_level" 42 | ) 43 | 44 | self.pitch_normalization = config["preprocessing"]["pitch"]["normalization"] 45 | self.energy_normalization = config["preprocessing"]["energy"]["normalization"] 46 | 47 | self.STFT = Audio.stft.TacotronSTFT( 48 | config["preprocessing"]["stft"]["filter_length"], 49 | config["preprocessing"]["stft"]["hop_length"], 50 | config["preprocessing"]["stft"]["win_length"], 51 | config["preprocessing"]["mel"]["n_mel_channels"], 52 | config["preprocessing"]["audio"]["sampling_rate"], 53 | config["preprocessing"]["mel"]["mel_fmin"], 54 | config["preprocessing"]["mel"]["mel_fmax"], 55 | ) 56 | 57 | def build_from_path(self): 58 | # os.makedirs((os.path.join(self.out_dir, "mel")), exist_ok=True) 59 | os.makedirs((os.path.join(self.out_dir, "pitch")), exist_ok=True) 60 | os.makedirs((os.path.join(self.out_dir, "energy")), exist_ok=True) 61 | os.makedirs((os.path.join(self.out_dir, "duration")), exist_ok=True) 62 | 63 | print("Processing Data ...") 64 | out = list() 65 | n_frames = 0 66 | pitch_scaler = StandardScaler() 67 | energy_scaler = StandardScaler() 68 | 69 | # Compute pitch, energy, duration, and mel-spectrogram 70 | speaker = "LJSpeech" 71 | for wav_name in tqdm(os.listdir(self.in_dir)): 72 | if ".wav" not in wav_name: 73 | continue 74 | 75 | basename = wav_name.split(".")[0] 76 | tg_path = os.path.join( 77 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename) 78 | ) 79 | if os.path.exists(tg_path): 80 | ret = self.process_utterance(speaker, basename) 81 | if ret is None: 82 | continue 83 | else: 84 | info, pitch, energy, n = ret 85 | out.append(info) 86 | 87 | if len(pitch) > 0: 88 | pitch_scaler.partial_fit(pitch.reshape((-1, 1))) 89 | if len(energy) > 0: 90 | energy_scaler.partial_fit(energy.reshape((-1, 1))) 91 | 92 | n_frames += n 93 | 94 | print("Computing statistic quantities ...") 95 | # Perform normalization if necessary 96 | if self.pitch_normalization: 97 | pitch_mean = pitch_scaler.mean_[0] 98 | pitch_std = pitch_scaler.scale_[0] 99 | else: 100 | # A numerical trick to avoid normalization... 
101 | pitch_mean = 0 102 | pitch_std = 1 103 | if self.energy_normalization: 104 | energy_mean = energy_scaler.mean_[0] 105 | energy_std = energy_scaler.scale_[0] 106 | else: 107 | energy_mean = 0 108 | energy_std = 1 109 | 110 | pitch_min, pitch_max = self.normalize( 111 | os.path.join(self.out_dir, "pitch"), pitch_mean, pitch_std 112 | ) 113 | energy_min, energy_max = self.normalize( 114 | os.path.join(self.out_dir, "energy"), energy_mean, energy_std 115 | ) 116 | 117 | # Save files 118 | # with open(os.path.join(self.out_dir, "speakers.json"), "w") as f: 119 | # f.write(json.dumps(speakers)) 120 | 121 | with open(os.path.join(self.out_dir, "stats.json"), "w") as f: 122 | stats = { 123 | "pitch": [ 124 | float(pitch_min), 125 | float(pitch_max), 126 | float(pitch_mean), 127 | float(pitch_std), 128 | ], 129 | "energy": [ 130 | float(energy_min), 131 | float(energy_max), 132 | float(energy_mean), 133 | float(energy_std), 134 | ], 135 | } 136 | f.write(json.dumps(stats)) 137 | 138 | print( 139 | "Total time: {} hours".format( 140 | n_frames * self.hop_length / self.sampling_rate / 3600 141 | ) 142 | ) 143 | 144 | # random.shuffle(out) 145 | # out = [r for r in out if r is not None] 146 | 147 | # Write metadata 148 | with open(os.path.join(self.out_dir, "metadata.txt"), "w", encoding="utf-8") as f: 149 | for m in out: 150 | f.write(m + "\n") 151 | # with open(os.path.join(self.out_dir, "train.txt"), "w", encoding="utf-8") as f: 152 | # for m in out[self.val_size :]: 153 | # f.write(m + "\n") 154 | # with open(os.path.join(self.out_dir, "val.txt"), "w", encoding="utf-8") as f: 155 | # for m in out[: self.val_size]: 156 | # f.write(m + "\n") 157 | 158 | return out 159 | 160 | def process_utterance(self, speaker, basename): 161 | wav_path = os.path.join(self.in_dir, "{}.wav".format(basename)) 162 | # text_path = os.path.join(self.in_dir, speaker, "{}.lab".format(basename)) 163 | tg_path = os.path.join( 164 | self.out_dir, "TextGrid", speaker, "{}.TextGrid".format(basename) 165 | ) 166 | 167 | # Get alignments 168 | textgrid = tgt.io.read_textgrid(tg_path) 169 | phone, duration, start, end = self.get_alignment( 170 | textgrid.get_tier_by_name("phones") 171 | ) 172 | text = "{" + " ".join(phone) + "}" 173 | if start >= end: 174 | return None 175 | 176 | assert start == 0, "start time is not 0, which may cause misalignment" 177 | 178 | # Read and trim wav files 179 | wav, _ = librosa.load(wav_path) 180 | wav = wav[ 181 | int(self.sampling_rate * start) : int(self.sampling_rate * end) 182 | ].astype(np.float32) 183 | 184 | # Read raw text 185 | # with open(text_path, "r") as f: 186 | # raw_text = f.readline().strip("\n") 187 | raw_text = "" 188 | 189 | # Compute fundamental frequency 190 | pitch, t = pw.dio( 191 | wav.astype(np.float64), 192 | self.sampling_rate, 193 | frame_period=self.hop_length / self.sampling_rate * 1000, 194 | ) 195 | pitch = pw.stonemask(wav.astype(np.float64), pitch, t, self.sampling_rate) 196 | 197 | pitch = pitch[: sum(duration)] 198 | if np.sum(pitch != 0) <= 1: 199 | return None 200 | 201 | # Compute mel-scale spectrogram and energy 202 | # mel_spectrogram, energy = Audio.tools.get_mel_from_wav(wav, self.STFT) 203 | mel_spectrogram, magnitudes, energy = Audio.tools.get_mel_from_wav(wav, self.STFT) 204 | mel_spectrogram = mel_spectrogram[:, : sum(duration)] 205 | energy = energy[: sum(duration)] 206 | 207 | if self.pitch_phoneme_averaging: 208 | # perform linear interpolation 209 | nonzero_ids = np.where(pitch != 0)[0] 210 | interp_fn = interp1d( 211 | nonzero_ids, 212 | 
pitch[nonzero_ids], 213 | fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), 214 | bounds_error=False, 215 | ) 216 | pitch = interp_fn(np.arange(0, len(pitch))) 217 | 218 | # Phoneme-level average 219 | pos = 0 220 | for i, d in enumerate(duration): 221 | if d > 0: 222 | pitch[i] = np.mean(pitch[pos : pos + d]) 223 | else: 224 | pitch[i] = 0 225 | pos += d 226 | pitch = pitch[: len(duration)] 227 | 228 | if self.energy_phoneme_averaging: 229 | # Phoneme-level average 230 | pos = 0 231 | for i, d in enumerate(duration): 232 | if d > 0: 233 | energy[i] = np.mean(energy[pos : pos + d]) 234 | else: 235 | energy[i] = 0 236 | pos += d 237 | energy = energy[: len(duration)] 238 | 239 | # Save files 240 | dur_filename = "{}-duration-{}.npy".format(speaker, basename) 241 | np.save(os.path.join(self.out_dir, "duration", dur_filename), duration) 242 | 243 | pitch_filename = "{}-pitch-{}.npy".format(speaker, basename) 244 | np.save(os.path.join(self.out_dir, "pitch", pitch_filename), pitch) 245 | 246 | energy_filename = "{}-energy-{}.npy".format(speaker, basename) 247 | np.save(os.path.join(self.out_dir, "energy", energy_filename), energy) 248 | 249 | # mel_filename = "{}-mel-{}.npy".format(speaker, basename) 250 | # np.save( 251 | # os.path.join(self.out_dir, "mel", mel_filename), 252 | # mel_spectrogram.T, 253 | # ) 254 | 255 | return ( 256 | "|".join([basename, speaker, text, raw_text]), 257 | self.remove_outlier(pitch), 258 | self.remove_outlier(energy), 259 | mel_spectrogram.shape[1], 260 | ) 261 | 262 | def get_alignment(self, tier): 263 | sil_phones = ["sil", "sp", "spn"] 264 | 265 | phones = [] 266 | durations = [] 267 | start_time = 0 268 | end_time = 0 269 | end_idx = 0 270 | for t in tier._objects: 271 | s, e, p = t.start_time, t.end_time, t.text 272 | 273 | # Trim leading silences 274 | if phones == []: 275 | # if p in sil_phones: 276 | # continue 277 | # else: 278 | # start_time = s 279 | start_time = s 280 | 281 | # if p not in sil_phones: 282 | # # For ordinary phones 283 | # phones.append(p) 284 | # end_time = e 285 | # end_idx = len(phones) 286 | # else: 287 | # # For silent phones 288 | # phones.append(p) 289 | 290 | # Do not trim silences 291 | phones.append(p) 292 | end_time = e 293 | 294 | durations.append( 295 | int( 296 | np.round(e * self.sampling_rate / self.hop_length) 297 | - np.round(s * self.sampling_rate / self.hop_length) 298 | ) 299 | ) 300 | 301 | # Trim tailing silences 302 | # phones = phones[:end_idx] 303 | # durations = durations[:end_idx] 304 | 305 | return phones, durations, start_time, end_time 306 | 307 | def remove_outlier(self, values): 308 | values = np.array(values) 309 | p25 = np.percentile(values, 25) 310 | p75 = np.percentile(values, 75) 311 | lower = p25 - 1.5 * (p75 - p25) 312 | upper = p75 + 1.5 * (p75 - p25) 313 | normal_indices = np.logical_and(values > lower, values < upper) 314 | 315 | return values[normal_indices] 316 | 317 | def normalize(self, in_dir, mean, std): 318 | max_value = np.finfo(np.float64).min 319 | min_value = np.finfo(np.float64).max 320 | for filename in os.listdir(in_dir): 321 | filename = os.path.join(in_dir, filename) 322 | values = (np.load(filename) - mean) / std 323 | np.save(filename, values) 324 | 325 | max_value = max(max_value, max(values)) 326 | min_value = min(min_value, min(values)) 327 | 328 | return min_value, max_value 329 | 330 | if __name__ == "__main__": 331 | print("Extract pitch, energy, and duration.") 332 | parser = argparse.ArgumentParser() 333 | parser.add_argument("config", type=str, help="path 
to preprocess.yaml") 334 | args = parser.parse_args() 335 | 336 | config = yaml.load(open(args.config, "r"), Loader=yaml.FullLoader) 337 | 338 | if not os.path.exists(config["path"]["preprocessed_path"]): 339 | os.makedirs(config["path"]["preprocessed_path"]) 340 | 341 | if not os.path.exists(os.path.join(config["path"]["preprocessed_path"], "TextGrid")): 342 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')) 343 | cmd = "ln -s {} {}".format(os.path.join(root, config["path"]["TextGrid_path"]), os.path.join(config["path"]["preprocessed_path"], "TextGrid")) 344 | os.system(cmd) 345 | 346 | preprocessor = Preprocessor(config) 347 | preprocessor.build_from_path() 348 | -------------------------------------------------------------------------------- /drawspeech/utilities/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from . import cleaners 4 | from .symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | """ 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | 34 | if not m: 35 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 36 | break 37 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 38 | sequence += _arpabet_to_sequence(m.group(2)) 39 | text = m.group(3) 40 | 41 | return sequence 42 | 43 | 44 | def sequence_to_text(sequence): 45 | """Converts a sequence of IDs back to a string""" 46 | result = "" 47 | for symbol_id in sequence: 48 | if symbol_id in _id_to_symbol: 49 | s = _id_to_symbol[symbol_id] 50 | # Enclose ARPAbet back in curly braces: 51 | if len(s) > 1 and s[0] == "@": 52 | s = "{%s}" % s[1:] 53 | result += s 54 | return result.replace("}{", " ") 55 | 56 | 57 | def _clean_text(text, cleaner_names): 58 | for name in cleaner_names: 59 | cleaner = getattr(cleaners, name) 60 | if not cleaner: 61 | raise Exception("Unknown cleaner: %s" % name) 62 | text = cleaner(text) 63 | return text 64 | 65 | 66 | def _symbols_to_sequence(symbols): 67 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 68 | 69 | 70 | def _arpabet_to_sequence(text): 71 | return _symbols_to_sequence(["@" + s for s in text.split()]) 72 | 73 | 74 | def _should_keep_symbol(s): 75 | return s in _symbol_to_id and s != "_" and s != "~" 76 | -------------------------------------------------------------------------------- /drawspeech/utilities/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are 
transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | import re 18 | from unidecode import unidecode 19 | from .numbers import normalize_numbers 20 | _whitespace_re = re.compile(r'\s+') 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' 69 | text = lowercase(text) 70 | text = collapse_whitespace(text) 71 | return text 72 | 73 | 74 | def transliteration_cleaners(text): 75 | '''Pipeline for non-English text that transliterates to ASCII.''' 76 | text = convert_to_ascii(text) 77 | text = lowercase(text) 78 | text = collapse_whitespace(text) 79 | return text 80 | 81 | 82 | def english_cleaners(text): 83 | '''Pipeline for English text, including number and abbreviation expansion.''' 84 | text = convert_to_ascii(text) 85 | text = lowercase(text) 86 | text = expand_numbers(text) 87 | text = expand_abbreviations(text) 88 | text = collapse_whitespace(text) 89 | return text 90 | -------------------------------------------------------------------------------- /drawspeech/utilities/text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | "AA", 8 | "AA0", 9 | "AA1", 10 | "AA2", 11 | "AE", 12 | "AE0", 13 | "AE1", 14 | "AE2", 15 | "AH", 16 | "AH0", 17 | "AH1", 18 | "AH2", 19 | "AO", 20 | "AO0", 21 | "AO1", 22 | "AO2", 23 | "AW", 24 | "AW0", 25 | "AW1", 26 | "AW2", 27 | "AY", 28 | "AY0", 29 | "AY1", 30 | "AY2", 31 | "B", 32 | "CH", 33 | "D", 34 | "DH", 35 | "EH", 36 | "EH0", 37 | "EH1", 38 | "EH2", 39 | "ER", 40 | "ER0", 41 | "ER1", 42 | "ER2", 43 | "EY", 44 | "EY0", 45 | "EY1", 46 | "EY2", 47 | "F", 48 | "G", 49 | "HH", 50 | "IH", 51 | "IH0", 52 | "IH1", 53 | "IH2", 54 | "IY", 55 | "IY0", 56 | "IY1", 
57 | "IY2", 58 | "JH", 59 | "K", 60 | "L", 61 | "M", 62 | "N", 63 | "NG", 64 | "OW", 65 | "OW0", 66 | "OW1", 67 | "OW2", 68 | "OY", 69 | "OY0", 70 | "OY1", 71 | "OY2", 72 | "P", 73 | "R", 74 | "S", 75 | "SH", 76 | "T", 77 | "TH", 78 | "UH", 79 | "UH0", 80 | "UH1", 81 | "UH2", 82 | "UW", 83 | "UW0", 84 | "UW1", 85 | "UW2", 86 | "V", 87 | "W", 88 | "Y", 89 | "Z", 90 | "ZH", 91 | ] 92 | 93 | _valid_symbol_set = set(valid_symbols) 94 | 95 | 96 | class CMUDict: 97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 98 | 99 | def __init__(self, file_or_path, keep_ambiguous=True): 100 | if isinstance(file_or_path, str): 101 | with open(file_or_path, encoding="latin-1") as f: 102 | entries = _parse_cmudict(f) 103 | else: 104 | entries = _parse_cmudict(file_or_path) 105 | if not keep_ambiguous: 106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 107 | self._entries = entries 108 | 109 | def __len__(self): 110 | return len(self._entries) 111 | 112 | def lookup(self, word): 113 | """Returns list of ARPAbet pronunciations of the given word.""" 114 | return self._entries.get(word.upper()) 115 | 116 | 117 | _alt_re = re.compile(r"\([0-9]+\)") 118 | 119 | 120 | def _parse_cmudict(file): 121 | cmudict = {} 122 | for line in file: 123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 124 | parts = line.split(" ") 125 | word = re.sub(_alt_re, "", parts[0]) 126 | pronunciation = _get_pronunciation(parts[1]) 127 | if pronunciation: 128 | if word in cmudict: 129 | cmudict[word].append(pronunciation) 130 | else: 131 | cmudict[word] = [pronunciation] 132 | return cmudict 133 | 134 | 135 | def _get_pronunciation(s): 136 | parts = s.strip().split(" ") 137 | for part in parts: 138 | if part not in _valid_symbol_set: 139 | return None 140 | return " ".join(parts) 141 | -------------------------------------------------------------------------------- /drawspeech/utilities/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return "%s %s" % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return "%s %s" % (cents, cent_unit) 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def 
_expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words( 60 | num, andword="", zero="oh", group=2 61 | ).replace(", ", " ") 62 | else: 63 | return _inflect.number_to_words(num, andword="") 64 | 65 | 66 | def normalize_numbers(text): 67 | text = re.sub(_comma_number_re, _remove_commas, text) 68 | text = re.sub(_pounds_re, r"\1 pounds", text) 69 | text = re.sub(_dollars_re, _expand_dollars, text) 70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 71 | text = re.sub(_ordinal_re, _expand_ordinal, text) 72 | text = re.sub(_number_re, _expand_number, text) 73 | return text 74 | -------------------------------------------------------------------------------- /drawspeech/utilities/text/pinyin.py: -------------------------------------------------------------------------------- 1 | initials = [ 2 | "b", 3 | "c", 4 | "ch", 5 | "d", 6 | "f", 7 | "g", 8 | "h", 9 | "j", 10 | "k", 11 | "l", 12 | "m", 13 | "n", 14 | "p", 15 | "q", 16 | "r", 17 | "s", 18 | "sh", 19 | "t", 20 | "w", 21 | "x", 22 | "y", 23 | "z", 24 | "zh", 25 | ] 26 | finals = [ 27 | "a1", 28 | "a2", 29 | "a3", 30 | "a4", 31 | "a5", 32 | "ai1", 33 | "ai2", 34 | "ai3", 35 | "ai4", 36 | "ai5", 37 | "an1", 38 | "an2", 39 | "an3", 40 | "an4", 41 | "an5", 42 | "ang1", 43 | "ang2", 44 | "ang3", 45 | "ang4", 46 | "ang5", 47 | "ao1", 48 | "ao2", 49 | "ao3", 50 | "ao4", 51 | "ao5", 52 | "e1", 53 | "e2", 54 | "e3", 55 | "e4", 56 | "e5", 57 | "ei1", 58 | "ei2", 59 | "ei3", 60 | "ei4", 61 | "ei5", 62 | "en1", 63 | "en2", 64 | "en3", 65 | "en4", 66 | "en5", 67 | "eng1", 68 | "eng2", 69 | "eng3", 70 | "eng4", 71 | "eng5", 72 | "er1", 73 | "er2", 74 | "er3", 75 | "er4", 76 | "er5", 77 | "i1", 78 | "i2", 79 | "i3", 80 | "i4", 81 | "i5", 82 | "ia1", 83 | "ia2", 84 | "ia3", 85 | "ia4", 86 | "ia5", 87 | "ian1", 88 | "ian2", 89 | "ian3", 90 | "ian4", 91 | "ian5", 92 | "iang1", 93 | "iang2", 94 | "iang3", 95 | "iang4", 96 | "iang5", 97 | "iao1", 98 | "iao2", 99 | "iao3", 100 | "iao4", 101 | "iao5", 102 | "ie1", 103 | "ie2", 104 | "ie3", 105 | "ie4", 106 | "ie5", 107 | "ii1", 108 | "ii2", 109 | "ii3", 110 | "ii4", 111 | "ii5", 112 | "iii1", 113 | "iii2", 114 | "iii3", 115 | "iii4", 116 | "iii5", 117 | "in1", 118 | "in2", 119 | "in3", 120 | "in4", 121 | "in5", 122 | "ing1", 123 | "ing2", 124 | "ing3", 125 | "ing4", 126 | "ing5", 127 | "iong1", 128 | "iong2", 129 | "iong3", 130 | "iong4", 131 | "iong5", 132 | "iou1", 133 | "iou2", 134 | "iou3", 135 | "iou4", 136 | "iou5", 137 | "o1", 138 | "o2", 139 | "o3", 140 | "o4", 141 | "o5", 142 | "ong1", 143 | "ong2", 144 | "ong3", 145 | "ong4", 146 | "ong5", 147 | "ou1", 148 | "ou2", 149 | "ou3", 150 | "ou4", 151 | "ou5", 152 | "u1", 153 | "u2", 154 | "u3", 155 | "u4", 156 | "u5", 157 | "ua1", 158 | "ua2", 159 | "ua3", 160 | "ua4", 161 | "ua5", 162 | "uai1", 163 | "uai2", 164 | "uai3", 165 | "uai4", 166 | "uai5", 167 | "uan1", 168 | "uan2", 169 | "uan3", 170 | "uan4", 171 | "uan5", 172 | "uang1", 173 | "uang2", 174 | "uang3", 175 | "uang4", 176 | "uang5", 177 | "uei1", 178 | "uei2", 179 | "uei3", 180 | "uei4", 181 | "uei5", 182 | "uen1", 183 | "uen2", 184 | "uen3", 185 | "uen4", 186 | "uen5", 187 | "uo1", 188 | "uo2", 189 | "uo3", 190 | "uo4", 191 | "uo5", 192 | "v1", 193 | "v2", 194 | "v3", 195 | 
"v4", 196 | "v5", 197 | "van1", 198 | "van2", 199 | "van3", 200 | "van4", 201 | "van5", 202 | "ve1", 203 | "ve2", 204 | "ve3", 205 | "ve4", 206 | "ve5", 207 | "vn1", 208 | "vn2", 209 | "vn3", 210 | "vn4", 211 | "vn5", 212 | ] 213 | valid_symbols = initials + finals + ["rr"] -------------------------------------------------------------------------------- /drawspeech/utilities/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | """ 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """ 7 | 8 | from . import cmudict, pinyin 9 | 10 | _pad = "_" 11 | _punctuation = "!'(),.:;? " 12 | _special = "-" 13 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 14 | _silences = ["@sp", "@spn", "@sil"] 15 | 16 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 17 | _arpabet = ["@" + s for s in cmudict.valid_symbols] 18 | _pinyin = ["@" + s for s in pinyin.valid_symbols] 19 | 20 | # Export all symbols: 21 | symbols = ( 22 | [_pad] 23 | + list(_special) 24 | + list(_punctuation) 25 | + list(_letters) 26 | + _arpabet 27 | + _pinyin 28 | + _silences 29 | ) 30 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | import random 5 | import librosa 6 | from g2p_en import G2p 7 | from tqdm import tqdm 8 | import numpy as np 9 | import torch 10 | 11 | from drawspeech.utilities.preprocessor.preprocess_one_sample import preprocess_english 12 | from drawspeech.utilities.tools import sketch_extractor 13 | 14 | def format_ljspeech(dataset_root="/mnt/users/hccl.local/wdchen/dataset/LJSpeech-1.1"): 15 | ''' Create json file for ljspeech dataset. 16 | Provide the path where you save the LJSpeech dataset and let the program do the rest. 
17 | ''' 18 | 19 | metadata = os.path.join(dataset_root, "metadata.csv") 20 | wav_path = os.path.join(dataset_root, "wavs") 21 | 22 | # if not os.path.exists("data/dataset/ljspeech"): 23 | # os.makedirs("data/dataset/ljspeech") 24 | # if not os.path.exists("data/dataset/ljspeech/wavs"): 25 | # cmd = f"ln -s {wav_path} data/dataset/ljspeech/wavs" 26 | # os.system(cmd) 27 | 28 | data = [] 29 | print("Start formating ljspeech dataset.") 30 | with open(os.path.join(metadata), encoding="utf-8") as f: 31 | for line in f: 32 | id, text, norm_text = line.strip().split("|") 33 | norm_text = re.sub(r'\(|\)|\[|\]|<.*?>', '', norm_text) # remove (), [], 34 | 35 | file_path = f"wavs/{id}.wav" 36 | duration = librosa.get_duration(filename=os.path.join(dataset_root, file_path)) 37 | 38 | data.append( 39 | { 40 | "wav": file_path, 41 | "transcription": norm_text, 42 | "duration": duration 43 | } 44 | ) 45 | 46 | # print("Perform text to phoneme conversion.") 47 | # g2p = G2p() 48 | # for d in tqdm(data, total=len(data)): 49 | # phoneme_idx, phoneme = preprocess_english(d["transcription"], always_use_g2p=True, g2p=g2p, verbose=False) 50 | # d["phonemes"] = phoneme 51 | 52 | num_data = len(data) 53 | print(f"Total {num_data} data in ljspeech dataset.") 54 | ids = list(range(num_data)) 55 | random.shuffle(ids) 56 | 57 | train_ids = ids[:12500] 58 | val_ids = ids[12500:12800] 59 | test_ids = ids[12800:] 60 | 61 | if not os.path.exists("data/dataset/metadata/ljspeech"): 62 | os.makedirs("data/dataset/metadata/ljspeech") 63 | 64 | json.dump({"data": [data[i] for i in train_ids]}, open(os.path.join("data/dataset/metadata/ljspeech", "train.json"), "w"), indent=1, ensure_ascii=False) 65 | json.dump({"data": [data[i] for i in val_ids]}, open(os.path.join("data/dataset/metadata/ljspeech", "val.json"), "w"), indent=1, ensure_ascii=False) 66 | json.dump({"data": [data[i] for i in test_ids]}, open(os.path.join("data/dataset/metadata/ljspeech", "test.json"), "w"), indent=1, ensure_ascii=False) 67 | print("Finish formating ljspeech dataset.") 68 | 69 | 70 | def add_phoneme_for_ljspeech(ljspeech_json_path = "data/dataset/metadata/ljspeech", fs_files = ["data/dataset/metadata/ljspeech/phoneme_level/metadata.txt"]): 71 | print("Add phoneme for ljspeech dataset.") 72 | names, phones = [], [] 73 | for fs_file in fs_files: 74 | for line in open(fs_file, "r").readlines(): 75 | fname, speaker, p, raw_text = line.strip().split("|") 76 | names.append(fname) 77 | phones.append(p) 78 | 79 | json_files = os.listdir(ljspeech_json_path) 80 | json_files = [f for f in json_files if f.endswith(".json")] 81 | for json_file in json_files: 82 | data = json.load(open(os.path.join(ljspeech_json_path, json_file), "r"))["data"] 83 | for d in tqdm(data): 84 | name = d["wav"].split("/")[-1].replace(".wav", "") 85 | for n, p in zip(names, phones): 86 | if n == name: 87 | d["phonemes"] = p 88 | break 89 | 90 | json.dump({"data": data}, open(os.path.join(ljspeech_json_path, json_file), "w"), indent=1, ensure_ascii=False) 91 | 92 | def find_min_max_values_in_sketch(metadata_root="data/dataset/metadata/ljspeech/phoneme_level"): 93 | print("Find min and max values in pitch and energy sketch.") 94 | pitch_dir = os.path.join(metadata_root, "pitch") 95 | energy_dir = os.path.join(metadata_root, "energy") 96 | stats_json = os.path.join(metadata_root, "stats.json") 97 | 98 | pitch_files = os.listdir(pitch_dir) 99 | energy_files = os.listdir(energy_dir) 100 | 101 | l_min = [] 102 | l_max = [] 103 | pitch_sketch_global_min = 1000 104 | pitch_sketch_global_max 
= -1000 105 | for f in tqdm(pitch_files): 106 | pitch = np.load(os.path.join(pitch_dir, f)) 107 | pitch_sketch = sketch_extractor(pitch) 108 | p_min = pitch_sketch.min() 109 | p_max = pitch_sketch.max() 110 | if p_min < pitch_sketch_global_min: 111 | pitch_sketch_global_min = p_min 112 | if p_max > pitch_sketch_global_max: 113 | pitch_sketch_global_max = p_max 114 | l_min.append(p_min) 115 | l_max.append(p_max) 116 | print("pitch_sketch global min and max: ") 117 | print(pitch_sketch_global_min, pitch_sketch_global_max) 118 | 119 | 120 | l_min = [] 121 | l_max = [] 122 | energy_sketch_global_min = 1000 123 | energy_sketch_global_max = -1000 124 | for f in tqdm(energy_files): 125 | energy = np.load(os.path.join(energy_dir, f)) 126 | energy_sketch = sketch_extractor(energy) 127 | e_min = energy_sketch.min() 128 | e_max = energy_sketch.max() 129 | if e_min < energy_sketch_global_min: 130 | energy_sketch_global_min = e_min 131 | if e_max > energy_sketch_global_max: 132 | energy_sketch_global_max = e_max 133 | l_min.append(e_min) 134 | l_max.append(e_max) 135 | print("energy_sketch global min and max: ") 136 | print(energy_sketch_global_min, energy_sketch_global_max) 137 | 138 | with open(stats_json, "r") as f: 139 | stats = json.load(f) 140 | stats["pitch_sketch"] = [float(pitch_sketch_global_min), float(pitch_sketch_global_max)] 141 | stats["energy_sketch"] = [float(energy_sketch_global_min), float(energy_sketch_global_max)] 142 | 143 | with open(stats_json, "w") as f: 144 | f.write(json.dumps(stats)) 145 | 146 | if __name__ == "__main__": 147 | format_ljspeech("data/dataset/LJSpeech-1.1") 148 | 149 | cmd = "python drawspeech/utilities/preprocessor/preprocessor.py drawspeech/utilities/preprocessor/preprocess_phoneme_level.yaml" 150 | os.system(cmd) 151 | 152 | add_phoneme_for_ljspeech() 153 | 154 | find_min_max_values_in_sketch() 155 | 156 | --------------------------------------------------------------------------------
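
As a quick sanity check of the text frontend dumped above, the sketch below (illustrative only, not part of the repository) calls `preprocess_english` from `drawspeech/utilities/preprocessor/preprocess_one_sample.py` on an arbitrary sentence. It assumes `g2p_en` is installed and the repository root is on `PYTHONPATH`; with `always_use_g2p=True` the LibriSpeech lexicon file is not consulted, so no other data is needed. The printed phoneme string in the comment is an example of the expected shape, not a guaranteed output.

```python
# Hedged usage sketch: run the DrawSpeech text frontend on one sentence.
# Assumptions: g2p_en installed, repository root on PYTHONPATH.
from g2p_en import G2p

from drawspeech.utilities.preprocessor.preprocess_one_sample import preprocess_english

g2p = G2p()  # build the grapheme-to-phoneme model once and reuse it
sequence, phones = preprocess_english(
    "Printing, in the only sense with which we are at present concerned.",
    always_use_g2p=True,  # bypass the lexicon lookup and use g2p_en directly
    g2p=g2p,
    verbose=False,
)

print(phones)    # phoneme string passed to text_to_sequence, e.g. "{P R IH1 N T IH0 NG sp ...}"
print(sequence)  # list of integer symbol IDs consumed by the text encoder
```

Setting `return_word_phone_alignment=True` additionally returns the per-word phoneme groupings collected in the lexicon branch, which can be useful when mapping sketch segments back to words.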