├── diffusion
│   ├── model
│   │   ├── __init__.py
│   │   ├── nets
│   │   │   ├── __init__.py
│   │   │   ├── DiT2.py
│   │   │   ├── DiT.py
│   │   │   └── DiT_blocks.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── diffusion_utils.py
│   └── respace.py
├── assets
│   ├── dog.jpg
│   ├── park.jpg
│   ├── beach.jpg
│   ├── church.png
│   ├── tiger.png
│   ├── pipeline.png
│   ├── female_voice.wav
│   ├── male_voice.wav
│   └── bird_chirping.wav
├── i2a_translator
│   └── __init__.py
├── voicedit
│   ├── audio
│   │   ├── __init__.py
│   │   ├── audio_processing.py
│   │   ├── tools.py
│   │   └── stft.py
│   ├── monotonic_align
│   │   ├── build
│   │   │   ├── temp.linux-x86_64-cpython-310
│   │   │   │   └── core.o
│   │   │   └── lib.linux-x86_64-cpython-310
│   │   │       └── monotonic_align
│   │   │           └── core.cpython-310-x86_64-linux-gnu.so
│   │   ├── monotonic_align
│   │   │   └── core.cpython-310-x86_64-linux-gnu.so
│   │   ├── setup.py
│   │   ├── __init__.py
│   │   └── core.pyx
│   ├── __init__.py
│   ├── utils.py
│   ├── data_ft.py
│   ├── data.py
│   ├── modules.py
│   ├── text_encoder.py
│   └── pipeline.py
├── requirements.txt
├── text
│   ├── symbols.py
│   ├── LICENSE
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── __init__.py
├── README.md
└── generate.py
/diffusion/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .nets import * 2 | -------------------------------------------------------------------------------- /diffusion/model/nets/__init__.py: -------------------------------------------------------------------------------- 1 | from .DiT import DiT, DiT_XL_2 -------------------------------------------------------------------------------- /assets/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/dog.jpg -------------------------------------------------------------------------------- /assets/park.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/park.jpg -------------------------------------------------------------------------------- /assets/beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/beach.jpg -------------------------------------------------------------------------------- /assets/church.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/church.png -------------------------------------------------------------------------------- /assets/tiger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/tiger.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/pipeline.png -------------------------------------------------------------------------------- /assets/female_voice.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/female_voice.wav -------------------------------------------------------------------------------- /assets/male_voice.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/male_voice.wav
-------------------------------------------------------------------------------- /assets/bird_chirping.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/bird_chirping.wav -------------------------------------------------------------------------------- /i2a_translator/__init__.py: -------------------------------------------------------------------------------- 1 | from i2a_translator.i2a_translator import DiffusionPriorNetwork, DiffusionPrior -------------------------------------------------------------------------------- /voicedit/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import get_mel_from_wav, wav_to_fbank, read_wav_file, raw_waveform_to_fbank 2 | from .stft import TacotronSTFT 3 | -------------------------------------------------------------------------------- /voicedit/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/voicedit/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o -------------------------------------------------------------------------------- /voicedit/monotonic_align/monotonic_align/core.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/voicedit/monotonic_align/monotonic_align/core.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /voicedit/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import AudioDataset, CollateFn 2 | from .data_ft import AudioDataset_ft, CollateFn_ft 3 | from .modules import DiTWrapper 4 | from .text_encoder import TextEncoder 5 | from .pipeline import VoiceDiTPipeline, EvalPipeline, VoiceUNetPipeline -------------------------------------------------------------------------------- /voicedit/monotonic_align/build/lib.linux-x86_64-cpython-310/monotonic_align/core.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/voicedit/monotonic_align/build/lib.linux-x86_64-cpython-310/monotonic_align/core.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /voicedit/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | 5 | setup( 6 | name = 'monotonic_align', 7 | ext_modules = cythonize("core.pyx"), 8 | include_dirs=[numpy.get_include()] 9 | ) 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.1 2 | torchaudio==2.1.1 3 | torchvision==0.16.1 4 | diffusers 5 | timm==0.6.12 6 | accelerate 7 | tensorboard 8 | transformers 9 | sentencepiece~=0.1.99 10 | ftfy 11 | beautifulsoup4 12 | opencv-python 13 | einops 14 | xformers==0.0.23 15 | librosa 16 | pandas 17 | scipy 18 | tqdm 19 | thop -------------------------------------------------------------------------------- /text/symbols.py: 
-------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | from text import cmudict 4 | 5 | _pad = '_' 6 | _punctuation = '!\'(),.:;? ' 7 | _special = '-' 8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 9 | 10 | # Prepend "@" to ARPAbet symbols to ensure uniqueness: 11 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 15 | -------------------------------------------------------------------------------- /voicedit/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .monotonic_align.core import maximum_path_c 4 | 5 | 6 | def maximum_path(value, mask): 7 | """ Cython optimised version. 8 | value: [b, t_x, t_y] 9 | mask: [b, t_x, t_y] 10 | """ 11 | value = value * mask 12 | device = value.device 13 | dtype = value.dtype 14 | value = value.data.cpu().numpy().astype(np.float32) 15 | path = np.zeros_like(value).astype(np.int32) 16 | mask = mask.data.cpu().numpy() 17 | 18 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 19 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 20 | maximum_path_c(path, value, t_x_max, t_y_max) 21 | return torch.from_numpy(path).to(device=device, dtype=dtype) 22 | -------------------------------------------------------------------------------- /voicedit/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | cimport cython 4 | from cython.parallel import prange 5 | 6 | 7 | @cython.boundscheck(False) 8 | @cython.wraparound(False) 9 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 10 | cdef int x 11 | cdef int y 12 | cdef float v_prev 13 | cdef float v_cur 14 | cdef float tmp 15 | cdef int index = t_x - 1 16 | 17 | for y in range(t_y): 18 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 19 | if x == y: 20 | v_cur = max_neg_val 21 | else: 22 | v_cur = value[x, y-1] 23 | if x == 0: 24 | if y == 0: 25 | v_prev = 0. 26 | else: 27 | v_prev = max_neg_val 28 | else: 29 | v_prev = value[x-1, y-1] 30 | value[x, y] = max(v_cur, v_prev) + value[x, y] 31 | 32 | for y in range(t_y - 1, -1, -1): 33 | path[index, y] = 1 34 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 35 | index = index - 1 36 | 37 | 38 | @cython.boundscheck(False) 39 | @cython.wraparound(False) 40 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 41 | cdef int b = values.shape[0] 42 | 43 | cdef int i 44 | for i in prange(b, nogil=True): 45 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 46 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | CMUdict 2 | ------- 3 | 4 | CMUdict (the Carnegie Mellon Pronouncing Dictionary) is a free 5 | pronouncing dictionary of English, suitable for uses in speech 6 | technology and is maintained by the Speech Group in the School of 7 | Computer Science at Carnegie Mellon University. 8 | 9 | The Carnegie Mellon Speech Group does not guarantee the accuracy of 10 | this dictionary, nor its suitability for any specific purpose. 
In 11 | fact, we expect a number of errors, omissions and inconsistencies to 12 | remain in the dictionary. We intend to continually update the 13 | dictionary by correction existing entries and by adding new ones. From 14 | time to time a new major version will be released. 15 | 16 | We welcome input from users: Please send email to Alex Rudnicky 17 | (air+cmudict@cs.cmu.edu). 18 | 19 | The Carnegie Mellon Pronouncing Dictionary, in its current and 20 | previous versions is Copyright (C) 1993-2014 by Carnegie Mellon 21 | University. Use of this dictionary for any research or commercial 22 | purpose is completely unrestricted. If you make use of or 23 | redistribute this material we request that you acknowledge its 24 | origin in your descriptions. 25 | 26 | If you add words to or correct words in your version of this 27 | dictionary, we would appreciate it if you could send these additions 28 | and corrections to us (air+cmudict@cs.cmu.edu) for consideration in a 29 | subsequent version. All submissions will be reviewed and approved by 30 | the current maintainer, Alex Rudnicky at Carnegie Mellon. 31 | -------------------------------------------------------------------------------- /voicedit/utils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import torch 4 | 5 | 6 | def sequence_mask(length, max_length=None): 7 | if max_length is None: 8 | max_length = length.max() 9 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device) 10 | return x.unsqueeze(0) < length.unsqueeze(1) 11 | 12 | 13 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 14 | while True: 15 | if length % (2**num_downsamplings_in_unet) == 0: 16 | return length 17 | length += 1 18 | 19 | 20 | def convert_pad_shape(pad_shape): 21 | l = pad_shape[::-1] 22 | pad_shape = [item for sublist in l for item in sublist] 23 | return pad_shape 24 | 25 | 26 | def generate_path(duration, mask): 27 | device = duration.device 28 | 29 | b, t_x, t_y = mask.shape 30 | cum_duration = torch.cumsum(duration, 1) 31 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 32 | 33 | cum_duration_flat = cum_duration.view(b * t_x) 34 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 35 | path = path.view(b, t_x, t_y) 36 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], 37 | [1, 0], [0, 0]]))[:, :-1] 38 | path = path * mask 39 | return path 40 | 41 | 42 | def duration_loss(logw, logw_, lengths): 43 | loss = torch.sum((logw - logw_)**2) / torch.sum(lengths) 44 | return loss 45 | 46 | 47 | def intersperse(lst, item): 48 | # Adds blank symbol 49 | result = [item] * (len(lst) * 2 + 1) 50 | result[1::2] = lst 51 | return result -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | from unidecode import unidecode 5 | from .numbers import normalize_numbers 6 | 7 | 8 | _whitespace_re = re.compile(r'\s+') 9 | 10 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 11 | ('mrs', 'misess'), 12 | ('mr', 'mister'), 13 | ('dr', 'doctor'), 14 | ('st', 'saint'), 15 | ('co', 'company'), 16 | ('jr', 'junior'), 17 | ('maj', 'major'), 18 | ('gen', 'general'), 19 | ('drs', 'doctors'), 20 | ('rev', 'reverend'), 21 | ('lt', 'lieutenant'), 22 | ('hon', 'honorable'), 23 | ('sgt', 'sergeant'), 24 | ('capt', 'captain'), 25 | ('esq', 'esquire'), 26 | ('ltd', 'limited'), 27 | ('col', 'colonel'), 28 | ('ft', 'fort'), 29 | ]] 30 | 31 | 32 | def expand_abbreviations(text): 33 | for regex, replacement in _abbreviations: 34 | text = re.sub(regex, replacement, text) 35 | return text 36 | 37 | 38 | def expand_numbers(text): 39 | return normalize_numbers(text) 40 | 41 | 42 | def lowercase(text): 43 | return text.lower() 44 | 45 | 46 | def collapse_whitespace(text): 47 | return re.sub(_whitespace_re, ' ', text) 48 | 49 | 50 | def convert_to_ascii(text): 51 | return unidecode(text) 52 | 53 | 54 | def basic_cleaners(text): 55 | text = lowercase(text) 56 | text = collapse_whitespace(text) 57 | return text 58 | 59 | 60 | def transliteration_cleaners(text): 61 | text = convert_to_ascii(text) 62 | text = lowercase(text) 63 | text = collapse_whitespace(text) 64 | return text 65 | 66 | 67 | def english_cleaners(text): 68 | text = convert_to_ascii(text) 69 | text = lowercase(text) 70 | text = expand_numbers(text) 71 | text = expand_abbreviations(text) 72 | text = collapse_whitespace(text) 73 | return text 74 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | pred_sigma=True, 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000, 20 | snr=False, 21 | return_startx=False, 22 | ): 23 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 24 | if use_kl: 25 | loss_type = gd.LossType.RESCALED_KL 26 | elif rescale_learned_sigmas: 27 | loss_type = gd.LossType.RESCALED_MSE 28 | else: 29 | loss_type = gd.LossType.MSE 30 | if timestep_respacing is None or timestep_respacing == "": 31 | timestep_respacing = [diffusion_steps] 32 | return SpacedDiffusion( 33 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 34 | betas=betas, 35 | model_mean_type=( 36 | gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON 37 | ), 38 | model_var_type=( 39 | (gd.ModelVarType.LEARNED_RANGE if learn_sigma else ( 40 | gd.ModelVarType.FIXED_LARGE 41 | if not sigma_small 42 | else gd.ModelVarType.FIXED_SMALL 43 | ) 44 | ) 45 | if pred_sigma 46 | else None 47 | ), 48 | loss_type=loss_type, 49 | snr=snr, 50 | return_startx=return_startx, 51 | # rescale_timesteps=rescale_timesteps, 52 | ) -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | def __init__(self, file_or_path, keep_ambiguous=True): 21 | if isinstance(file_or_path, str): 22 | with open(file_or_path, encoding='latin-1') as f: 23 | entries = _parse_cmudict(f) 24 | else: 25 | entries = _parse_cmudict(file_or_path) 26 | if not keep_ambiguous: 27 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 28 | self._entries = entries 29 | 30 | def __len__(self): 31 | return len(self._entries) 32 | 33 | def lookup(self, word): 34 | return self._entries.get(word.upper()) 35 | 36 | 37 | _alt_re = re.compile(r'\([0-9]+\)') 38 | 39 | 40 | def _parse_cmudict(file): 41 | cmudict = {} 42 | for line in file: 43 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 44 | parts = line.split(' ') 45 | word = re.sub(_alt_re, '', parts[0]) 46 | pronunciation = _get_pronunciation(parts[1]) 47 | if pronunciation: 48 | if word in cmudict: 49 | cmudict[word].append(pronunciation) 50 | else: 51 | cmudict[word] = [pronunciation] 52 | return cmudict 53 | 54 | 55 | def _get_pronunciation(s): 56 | parts = s.strip().split(' ') 57 | for part in parts: 58 | if part not in _valid_symbol_set: 59 | return None 60 | return ' '.join(parts) 61 | 
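A minimal usage sketch of the `CMUDict` class above, together with `text_to_sequence` from `text/__init__.py`. The dictionary file itself is not part of this listing; the path below is an assumption taken from the default `cmudict_path='voicedit/cmu_dictionary'` used by the datasets in `voicedit/data.py` and `voicedit/data_ft.py`, and the printed pronunciation is illustrative. Run from the repository root with the text-cleaning dependencies (`unidecode`, `inflect`) installed.

```python
from text.cmudict import CMUDict
from text import text_to_sequence

# Assumed dictionary location (default cmudict_path used elsewhere in this repo).
cmu = CMUDict("voicedit/cmu_dictionary")

# lookup() returns a list of pronunciations, or None for out-of-vocabulary words.
print(cmu.lookup("morning"))   # e.g. ['M AO1 R N IH0 NG']

# text_to_sequence() cleans the text, swaps known words for their ARPAbet form,
# and maps every symbol to an integer id (see text/__init__.py).
ids = text_to_sequence("Good morning", dictionary=cmu)
print(ids)
```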
-------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', 60 | group=2).replace(', ', ' ') 61 | else: 62 | return _inflect.number_to_words(num, andword='') 63 | 64 | 65 | def normalize_numbers(text): 66 | text = re.sub(_comma_number_re, _remove_commas, text) 67 | text = re.sub(_pounds_re, r'\1 pounds', text) 68 | text = re.sub(_dollars_re, _expand_dollars, text) 69 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 70 | text = re.sub(_ordinal_re, _expand_ordinal, text) 71 | text = re.sub(_number_re, _expand_number, text) 72 | return text 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VoiceDiT 2 | 3 | This is a repository for the paper, [VoiceDiT: Dual-Condition Diffusion Transformer for Environment-Aware Speech Synthesis](https://arxiv.org/pdf/2412.19259), ICASSP 2025. 4 | 5 |
6 | 7 |
8 | 9 | VoiceDiT is a multi-modal generative model for producing environment-aware speech and audio from text and visual prompts. 10 | 11 | 12 | 13 | 14 | 15 | ## 🔧 Installation 16 | 17 | ### Install from source 18 | ```shell 19 | git clone https://github.com/kaistmm/VoiceDiT.git 20 | cd VoiceDiT 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ## 📖 Usage 25 | 26 | - Generate audio with description prompt and content prompt: 27 | ```shell 28 | python generate.py --desc_prompt "She is talking in a park." --cont_prompt "Good morning! How are you feeling today?" 29 | ``` 30 | 31 | - Generate audio with audio prompt and content prompt: 32 | ```shell 33 | python generate.py --modality "audio" --desc_prompt "assets/bird_chirping.wav" --cont_prompt "Good morning! How are you feeling today?" 34 | ``` 35 | 36 | - Generate audio with image prompt and content prompt: 37 | ```shell 38 | python generate.py --modality "image" --desc_prompt "assets/park.jpg" --cont_prompt "Good morning! How are you feeling today?" 39 | ``` 40 | 41 | - Text-to-Speech Example: 42 | ```shell 43 | python generate.py --desc_prompt "clean speech" --cont_prompt "Good morning! How are you feeling today?" --desc_guidance_scale 1 --cont_guidance_scale 9 44 | ``` 45 | 46 | - Text-to-Audio Example: 47 | ```shell 48 | python generate.py --desc_prompt "trumpet" --cont_prompt "_" --desc_guidance_scale 9 --cont_guidance_scale 1 49 | ``` 50 | 51 | - Image-to-Audio Example: 52 | ```shell 53 | python generate.py --desc_prompt "assets/tiger.png" --cont_prompt "_" --v2a_guidance_scale 2 --desc_guidance_scale 9 --cont_guidance_scale 1 54 | ``` 55 | 56 | Generated audios will be saved at the default output folder `./outputs`. 57 | 58 | 59 | ## ⚙️ Full List of Options 60 | View the full list of options with the following command: 61 | ```console 62 | python generate.py -h 63 | ``` 64 | 65 | 66 | ## 🙏 Acknowledgements 67 | This work would not have been possible without the following repositories: 68 | 69 | [VoiceLDM](https://github.com/glory20h/VoiceLDM) 70 | 71 | [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha) 72 | 73 | [HuggingFace Diffusers](https://github.com/huggingface/diffusers) 74 | 75 | [HuggingFace Transformers](https://github.com/huggingface/transformers) 76 | 77 | [AudioLDM](https://github.com/haoheliu/AudioLDM) 78 | 79 | [naturalspeech](https://github.com/heatz123/naturalspeech) 80 | 81 | [audioldm_eval](https://github.com/haoheliu/audioldm_eval) 82 | 83 | [AudioLDM2](https://github.com/haoheliu/AudioLDM2) -------------------------------------------------------------------------------- /voicedit/audio/audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa.util as librosa_util 4 | from scipy.signal import get_window 5 | 6 | 7 | def window_sumsquare( 8 | window, 9 | n_frames, 10 | hop_length, 11 | win_length, 12 | n_fft, 13 | dtype=np.float32, 14 | norm=None, 15 | ): 16 | """ 17 | # from librosa 0.6 18 | Compute the sum-square envelope of a window function at a given hop length. 19 | 20 | This is used to estimate modulation effects induced by windowing 21 | observations in short-time fourier transforms. 
22 | 23 | Parameters 24 | ---------- 25 | window : string, tuple, number, callable, or list-like 26 | Window specification, as in `get_window` 27 | 28 | n_frames : int > 0 29 | The number of analysis frames 30 | 31 | hop_length : int > 0 32 | The number of samples to advance between frames 33 | 34 | win_length : [optional] 35 | The length of the window function. By default, this matches `n_fft`. 36 | 37 | n_fft : int > 0 38 | The length of each analysis frame. 39 | 40 | dtype : np.dtype 41 | The data type of the output 42 | 43 | Returns 44 | ------- 45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 46 | The sum-squared envelope of the window function 47 | """ 48 | if win_length is None: 49 | win_length = n_fft 50 | 51 | n = n_fft + hop_length * (n_frames - 1) 52 | x = np.zeros(n, dtype=dtype) 53 | 54 | # Compute the squared window at the desired length 55 | win_sq = get_window(window, win_length, fftbins=True) 56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 57 | win_sq = librosa_util.pad_center(win_sq, n_fft) 58 | 59 | # Fill the envelope 60 | for i in range(n_frames): 61 | sample = i * hop_length 62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] 63 | return x 64 | 65 | 66 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 67 | """ 68 | PARAMS 69 | ------ 70 | magnitudes: spectrogram magnitudes 71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 72 | """ 73 | 74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 75 | angles = angles.astype(np.float32) 76 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 78 | 79 | for i in range(n_iters): 80 | _, angles = stft_fn.transform(signal) 81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 82 | return signal 83 | 84 | 85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): 86 | """ 87 | PARAMS 88 | ------ 89 | C: compression factor 90 | """ 91 | return normalize_fun(torch.clamp(x, min=clip_val) * C) 92 | 93 | 94 | def dynamic_range_decompression(x, C=1): 95 | """ 96 | PARAMS 97 | ------ 98 | C: compression factor used to compress 99 | """ 100 | return torch.exp(x) / C 101 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | from text import cleaners 5 | from text.symbols import symbols 6 | 7 | 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 12 | 13 | 14 | def get_arpabet(word, dictionary): 15 | word_arpabet = dictionary.lookup(word) 16 | if word_arpabet is not None: 17 | return "{" + word_arpabet[0] + "}" 18 | else: 19 | return word 20 | 21 | 22 | def text_to_sequence(text, cleaner_names=["english_cleaners"], dictionary=None): 23 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 24 | 25 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 26 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
27 | 28 | Args: 29 | text: string to convert to a sequence 30 | cleaner_names: names of the cleaner functions to run the text through 31 | dictionary: arpabet class with arpabet dictionary 32 | 33 | Returns: 34 | List of integers corresponding to the symbols in the text 35 | ''' 36 | sequence = [] 37 | space = _symbols_to_sequence(' ') 38 | # Check for curly braces and treat their contents as ARPAbet: 39 | while len(text): 40 | m = _curly_re.match(text) 41 | if not m: 42 | clean_text = _clean_text(text, cleaner_names) 43 | if dictionary is not None: 44 | clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")] 45 | for i in range(len(clean_text)): 46 | t = clean_text[i] 47 | if t.startswith("{"): 48 | sequence += _arpabet_to_sequence(t[1:-1]) 49 | else: 50 | sequence += _symbols_to_sequence(t) 51 | sequence += space 52 | else: 53 | sequence += _symbols_to_sequence(clean_text) 54 | break 55 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 56 | sequence += _arpabet_to_sequence(m.group(2)) 57 | text = m.group(3) 58 | 59 | # remove trailing space 60 | if dictionary is not None: 61 | sequence = sequence[:-1] if sequence[-1] == space[0] else sequence 62 | return sequence 63 | 64 | 65 | def sequence_to_text(sequence): 66 | '''Converts a sequence of IDs back to a string''' 67 | result = '' 68 | for symbol_id in sequence: 69 | if symbol_id in _id_to_symbol: 70 | s = _id_to_symbol[symbol_id] 71 | # Enclose ARPAbet back in curly braces: 72 | if len(s) > 1 and s[0] == '@': 73 | s = '{%s}' % s[1:] 74 | result += s 75 | return result.replace('}{', ' ') 76 | 77 | 78 | def _clean_text(text, cleaner_names): 79 | for name in cleaner_names: 80 | cleaner = getattr(cleaners, name) 81 | if not cleaner: 82 | raise Exception('Unknown cleaner: %s' % name) 83 | text = cleaner(text) 84 | return text 85 | 86 | 87 | def _symbols_to_sequence(symbols): 88 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 89 | 90 | 91 | def _arpabet_to_sequence(text): 92 | return _symbols_to_sequence(['@' + s for s in text.split()]) 93 | 94 | 95 | def _should_keep_symbol(s): 96 | return s in _symbol_to_id and s != '_' and s != '~' 97 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = next( 17 | ( 18 | obj 19 | for obj in (mean1, logvar1, mean2, logvar2) 20 | if isinstance(obj, th.Tensor) 21 | ), 22 | None, 23 | ) 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 
28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a continuous Gaussian distribution. 53 | :param x: the targets 54 | :param means: the Gaussian mean Tensor. 55 | :param log_scales: the Gaussian log stddev Tensor. 56 | :return: a tensor like x of log probabilities (in nats). 57 | """ 58 | centered_x = x - means 59 | inv_stdv = th.exp(-log_scales) 60 | normalized_x = centered_x * inv_stdv 61 | return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob( 62 | normalized_x 63 | ) 64 | 65 | 66 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 67 | """ 68 | Compute the log-likelihood of a Gaussian distribution discretizing to a 69 | given image. 70 | :param x: the target images. It is assumed that this was uint8 values, 71 | rescaled to the range [-1, 1]. 72 | :param means: the Gaussian mean Tensor. 73 | :param log_scales: the Gaussian log stddev Tensor. 74 | :return: a tensor like x of log probabilities (in nats). 75 | """ 76 | assert x.shape == means.shape == log_scales.shape 77 | centered_x = x - means 78 | inv_stdv = th.exp(-log_scales) 79 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 80 | cdf_plus = approx_standard_normal_cdf(plus_in) 81 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 82 | cdf_min = approx_standard_normal_cdf(min_in) 83 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 84 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 85 | cdf_delta = cdf_plus - cdf_min 86 | log_probs = th.where( 87 | x < -0.999, 88 | log_cdf_plus, 89 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 90 | ) 91 | assert log_probs.shape == x.shape 92 | return log_probs 93 | -------------------------------------------------------------------------------- /voicedit/audio/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torchaudio 4 | 5 | 6 | def get_mel_from_wav(audio, _stft): 7 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 8 | audio = torch.autograd.Variable(audio, requires_grad=False) 9 | melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) 10 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) 11 | log_magnitudes_stft = ( 12 | torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) 13 | ) 14 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32) 15 | return melspec, log_magnitudes_stft, energy 16 | 17 | 18 | def _pad_spec(fbank, target_length=1024): 19 | n_frames = fbank.shape[0] 20 | p = target_length - n_frames 21 | # cut and pad 22 | if p > 0: 23 | m = torch.nn.ZeroPad2d((0, 0, 0, p)) 24 | fbank = m(fbank) 25 | elif p < 0: 26 | fbank = fbank[0:target_length, :] 27 | 28 | if fbank.size(-1) % 2 != 0: 29 | fbank = fbank[..., :-1] 30 | 31 | return fbank 32 | 33 | 34 | def pad_wav(waveform, segment_length): 35 | waveform_length = waveform.shape[-1] 36 | 
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length 37 | if segment_length is None or waveform_length == segment_length: 38 | return waveform 39 | elif waveform_length > segment_length: 40 | return waveform[:, :segment_length] 41 | elif waveform_length < segment_length: 42 | temp_wav = np.zeros((1, segment_length)) 43 | temp_wav[:, :waveform_length] = waveform 44 | return temp_wav 45 | 46 | def normalize_wav(waveform): 47 | waveform = waveform - np.mean(waveform) 48 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) 49 | return waveform * 0.5 50 | 51 | 52 | def read_wav_file(filename, segment_length): 53 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower 54 | waveform, sr = torchaudio.load(filename) # Faster!!! 55 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 56 | waveform = waveform.numpy()[0, ...] 57 | waveform = normalize_wav(waveform) 58 | waveform = waveform[None, ...] 59 | waveform = pad_wav(waveform, segment_length) 60 | 61 | waveform = waveform / np.max(np.abs(waveform)) 62 | waveform = 0.5 * waveform 63 | 64 | return waveform 65 | 66 | 67 | def wav_to_fbank(filename, target_length=1024, fn_STFT=None): 68 | assert fn_STFT is not None 69 | 70 | # mixup 71 | waveform = read_wav_file(filename, target_length * 160) # hop size is 160 72 | 73 | waveform = waveform[0, ...] 74 | waveform = torch.FloatTensor(waveform) 75 | 76 | fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) 77 | 78 | fbank = torch.FloatTensor(fbank.T) 79 | log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T) 80 | 81 | fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( 82 | log_magnitudes_stft, target_length 83 | ) 84 | 85 | return fbank, log_magnitudes_stft, waveform 86 | 87 | 88 | def raw_waveform_to_fbank(waveform, target_length=1024, fn_STFT=None): 89 | assert fn_STFT is not None 90 | 91 | fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) 92 | 93 | fbank = torch.FloatTensor(fbank.T) 94 | log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T) 95 | 96 | fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( 97 | log_magnitudes_stft, target_length 98 | ) 99 | 100 | return fbank, log_magnitudes_stft, waveform 101 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import warnings 4 | 5 | import torch 6 | import torchaudio 7 | 8 | from voicedit import VoiceDiTPipeline 9 | 10 | warnings.filterwarnings("ignore") # ignore warning 11 | 12 | 13 | def parse_args(parser): 14 | parser.add_argument("--modality", type=str, default="text", help="modality for generation. 
'text' or 'audio' or 'image'") 15 | parser.add_argument('--desc_prompt', '-d', type=str, help='description prompt') 16 | parser.add_argument('--cont_prompt', '-c', type=str, help='content prompt') 17 | parser.add_argument('--speaker_audio', '-spk', type=str, help='speaker audio') 18 | 19 | parser.add_argument("--ckpt_path", type=str, required=True, help="checkpoint file path for VoiceDiT") 20 | parser.add_argument("--v2a_ckpt_path", type=str, help='checkpoint file path for V2A-Mapper') 21 | parser.add_argument("--output_dir", type=str, default="./outputs", help="directory to save generated audio") 22 | parser.add_argument("--file_name", type=str, help="filename for the generated audio") 23 | 24 | parser.add_argument('--num_inference_steps', type=int, default=100, help='number of inference steps for DDIM sampling') 25 | parser.add_argument('--audio_length_in_s', type=float, default=10, help='duration of the audio for generation') 26 | parser.add_argument('--v2a_guidance_scale', type=float, default=2.0, help='guidance weight for v2a-mapper classifier-free guidance') 27 | parser.add_argument('--guidance_scale', type=float, help='guidance weight for single classifier-free guidance') 28 | parser.add_argument('--desc_guidance_scale', type=float, default=5, required=False, help='desc guidance weight for dual classifier-free guidance') 29 | parser.add_argument('--cont_guidance_scale', type=float, default=5, required=False, help='cont guidance weight for dual classifier-free guidance') 30 | parser.add_argument('--female_voice_list', type=str, default='libri_test_clean_female.txt') 31 | parser.add_argument('--male_voice_list', type=str, default='libri_test_clean_male.txt') 32 | 33 | parser.add_argument("--device", type=str, default="auto", help="device to use for audio generation") 34 | parser.add_argument('--seed', type=int, help='random seed for generation') 35 | 36 | return parser.parse_args() 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser() 41 | args = parse_args(parser) 42 | 43 | if args.device == "auto": 44 | if torch.cuda.is_available(): 45 | args.device = torch.device("cuda:0") 46 | else: 47 | args.device = torch.device("cpu") 48 | elif args.device is not None: 49 | args.device = torch.device(args.device) 50 | else: 51 | args.device = torch.device("cpu") 52 | 53 | pipe = VoiceDiTPipeline( 54 | args.ckpt_path, 55 | t2a_ckpt_path = args.t2a_ckpt_path, 56 | v2a_ckpt_path = args.v2a_ckpt_path, 57 | device = args.device, 58 | male_voice = args.male_voice_list, 59 | female_voice = args.female_voice_list, 60 | ) 61 | 62 | audio, _ = pipe( 63 | modality = args.modality, 64 | env_prompt = args.desc_prompt, 65 | cont_prompt = args.cont_prompt, 66 | batch_size = 1, 67 | num_inference_steps = args.num_inference_steps, 68 | audio_length_in_s = 10, 69 | do_classifier_free_guidance = True, 70 | desc_guidance_scale = args.desc_guidance_scale, 71 | cont_guidance_scale = args.cont_guidance_scale, 72 | v2a_guidance_scale = args.v2a_guidance_scale, 73 | device=args.device, 74 | seed=args.seed, 75 | progress=True, 76 | speaker_audio=args.speaker_audio, 77 | ) 78 | 79 | file_name = args.file_name 80 | if file_name is None: 81 | if args.modality == "text": 82 | file_name = args.desc_prompt[:10] 83 | else: 84 | file_name = os.path.basename(args.desc_prompt[:-4]) 85 | file_name = file_name + "-" + args.cont_prompt[:10] + ".wav" 86 | 87 | os.makedirs(args.output_dir, exist_ok=True) 88 | save_path = os.path.join(args.output_dir, file_name) 89 | 90 | torchaudio.save(save_path, src=audio, 
sample_rate=16000) 91 | 92 | 93 | if __name__ == "__main__": 94 | main() -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1) 52 | cur_idx = 0.0 53 | taken_steps = [] 54 | for _ in range(section_count): 55 | taken_steps.append(start_idx + round(cur_idx)) 56 | cur_idx += frac_stride 57 | all_steps += taken_steps 58 | start_idx += size 59 | return set(all_steps) 60 | 61 | 62 | class SpacedDiffusion(GaussianDiffusion): 63 | """ 64 | A diffusion process which can skip steps in a base diffusion process. 65 | :param use_timesteps: a collection (sequence or set) of timesteps from the 66 | original diffusion process to retain. 67 | :param kwargs: the kwargs to create the base diffusion process. 
68 | """ 69 | 70 | def __init__(self, use_timesteps, **kwargs): 71 | self.use_timesteps = set(use_timesteps) 72 | self.timestep_map = [] 73 | self.original_num_steps = len(kwargs["betas"]) 74 | 75 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 76 | last_alpha_cumprod = 1.0 77 | new_betas = [] 78 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 79 | if i in self.use_timesteps: 80 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 81 | last_alpha_cumprod = alpha_cumprod 82 | self.timestep_map.append(i) 83 | kwargs["betas"] = np.array(new_betas) 84 | super().__init__(**kwargs) 85 | 86 | def p_mean_variance( 87 | self, model, *args, **kwargs 88 | ): # pylint: disable=signature-differs 89 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 90 | 91 | def training_losses( 92 | self, model, *args, **kwargs 93 | ): # pylint: disable=signature-differs 94 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 95 | 96 | def training_losses_diffusers( 97 | self, model, *args, **kwargs 98 | ): # pylint: disable=signature-differs 99 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs) 100 | 101 | def condition_mean(self, cond_fn, *args, **kwargs): 102 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def condition_score(self, cond_fn, *args, **kwargs): 105 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 106 | 107 | def _wrap_model(self, model): 108 | if isinstance(model, _WrappedModel): 109 | return model 110 | return _WrappedModel( 111 | model, self.timestep_map, self.original_num_steps 112 | ) 113 | 114 | def _scale_timesteps(self, t): 115 | # Scaling is done by the wrapped model. 
116 | return t 117 | 118 | 119 | class _WrappedModel: 120 | def __init__(self, model, timestep_map, original_num_steps): 121 | self.model = model 122 | self.timestep_map = timestep_map 123 | # self.rescale_timesteps = rescale_timesteps 124 | self.original_num_steps = original_num_steps 125 | 126 | def __call__(self, x, timestep, **kwargs): 127 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype) 128 | new_ts = map_tensor[timestep] 129 | # if self.rescale_timesteps: 130 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 131 | return self.model(x, timestep=new_ts, **kwargs) 132 | -------------------------------------------------------------------------------- /voicedit/audio/stft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy.signal import get_window 5 | from librosa.util import pad_center, tiny 6 | from librosa.filters import mel as librosa_mel_fn 7 | 8 | from .audio_processing import ( 9 | dynamic_range_compression, 10 | dynamic_range_decompression, 11 | window_sumsquare, 12 | ) 13 | 14 | 15 | class STFT(torch.nn.Module): 16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 17 | 18 | def __init__(self, filter_length, hop_length, win_length, window="hann"): 19 | super(STFT, self).__init__() 20 | self.filter_length = filter_length 21 | self.hop_length = hop_length 22 | self.win_length = win_length 23 | self.window = window 24 | self.forward_transform = None 25 | scale = self.filter_length / self.hop_length 26 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 27 | 28 | cutoff = int((self.filter_length / 2 + 1)) 29 | fourier_basis = np.vstack( 30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] 31 | ) 32 | 33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 34 | inverse_basis = torch.FloatTensor( 35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :] 36 | ) 37 | 38 | if window is not None: 39 | assert filter_length >= win_length 40 | # get window and zero center pad it to filter_length 41 | fft_window = get_window(window, win_length, fftbins=True) 42 | fft_window = pad_center(fft_window, size=filter_length) 43 | fft_window = torch.from_numpy(fft_window).float() 44 | 45 | # window the bases 46 | forward_basis *= fft_window 47 | inverse_basis *= fft_window 48 | 49 | self.register_buffer("forward_basis", forward_basis.float()) 50 | self.register_buffer("inverse_basis", inverse_basis.float()) 51 | 52 | def transform(self, input_data): 53 | num_batches = input_data.size(0) 54 | num_samples = input_data.size(1) 55 | 56 | self.num_samples = num_samples 57 | 58 | # similar to librosa, reflect-pad the input 59 | input_data = input_data.view(num_batches, 1, num_samples) 60 | input_data = F.pad( 61 | input_data.unsqueeze(1), 62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 63 | mode="reflect", 64 | ) 65 | input_data = input_data.squeeze(1) 66 | 67 | forward_transform = F.conv1d( 68 | input_data, 69 | torch.autograd.Variable(self.forward_basis, requires_grad=False), 70 | stride=self.hop_length, 71 | padding=0, 72 | ).cpu() 73 | 74 | cutoff = int((self.filter_length / 2) + 1) 75 | real_part = forward_transform[:, :cutoff, :] 76 | imag_part = forward_transform[:, cutoff:, :] 77 | 78 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) 80 | 81 | return 
magnitude, phase 82 | 83 | def inverse(self, magnitude, phase): 84 | recombine_magnitude_phase = torch.cat( 85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 86 | ) 87 | 88 | inverse_transform = F.conv_transpose1d( 89 | recombine_magnitude_phase, 90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False), 91 | stride=self.hop_length, 92 | padding=0, 93 | ) 94 | 95 | if self.window is not None: 96 | window_sum = window_sumsquare( 97 | self.window, 98 | magnitude.size(-1), 99 | hop_length=self.hop_length, 100 | win_length=self.win_length, 101 | n_fft=self.filter_length, 102 | dtype=np.float32, 103 | ) 104 | # remove modulation effects 105 | approx_nonzero_indices = torch.from_numpy( 106 | np.where(window_sum > tiny(window_sum))[0] 107 | ) 108 | window_sum = torch.autograd.Variable( 109 | torch.from_numpy(window_sum), requires_grad=False 110 | ) 111 | window_sum = window_sum 112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ 113 | approx_nonzero_indices 114 | ] 115 | 116 | # scale by hop ratio 117 | inverse_transform *= float(self.filter_length) / self.hop_length 118 | 119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] 120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] 121 | 122 | return inverse_transform 123 | 124 | def forward(self, input_data): 125 | self.magnitude, self.phase = self.transform(input_data) 126 | reconstruction = self.inverse(self.magnitude, self.phase) 127 | return reconstruction 128 | 129 | 130 | class TacotronSTFT(torch.nn.Module): 131 | def __init__( 132 | self, 133 | filter_length, 134 | hop_length, 135 | win_length, 136 | n_mel_channels, 137 | sampling_rate, 138 | mel_fmin, 139 | mel_fmax, 140 | ): 141 | super(TacotronSTFT, self).__init__() 142 | self.n_mel_channels = n_mel_channels 143 | self.sampling_rate = sampling_rate 144 | self.stft_fn = STFT(filter_length, hop_length, win_length) 145 | mel_basis = librosa_mel_fn( 146 | sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax 147 | ) 148 | mel_basis = torch.from_numpy(mel_basis).float() 149 | self.register_buffer("mel_basis", mel_basis) 150 | 151 | def spectral_normalize(self, magnitudes, normalize_fun): 152 | output = dynamic_range_compression(magnitudes, normalize_fun) 153 | return output 154 | 155 | def spectral_de_normalize(self, magnitudes): 156 | output = dynamic_range_decompression(magnitudes) 157 | return output 158 | 159 | def mel_spectrogram(self, y, normalize_fun=torch.log): 160 | """Computes mel-spectrograms from a batch of waves 161 | PARAMS 162 | ------ 163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 164 | 165 | RETURNS 166 | ------- 167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 168 | """ 169 | assert torch.min(y.data) >= -1, torch.min(y.data) 170 | assert torch.max(y.data) <= 1, torch.max(y.data) 171 | 172 | magnitudes, phases = self.stft_fn.transform(y) 173 | magnitudes = magnitudes.data 174 | mel_output = torch.matmul(self.mel_basis, magnitudes) 175 | mel_output = self.spectral_normalize(mel_output, normalize_fun) 176 | energy = torch.norm(magnitudes, dim=1) 177 | 178 | log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) 179 | 180 | return mel_output, log_magnitudes, energy 181 | -------------------------------------------------------------------------------- /voicedit/data_ft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 
3 | import random 4 | 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | import torchaudio 9 | from torch.utils.data import Dataset 10 | 11 | from .audio import get_mel_from_wav, raw_waveform_to_fbank, TacotronSTFT 12 | from text import text_to_sequence, cmudict 13 | from text.symbols import symbols 14 | from .utils import intersperse 15 | 16 | 17 | class AudioDataset_ft(Dataset): 18 | def __init__(self, args, df, df_noise, clap_processor, rir_path=None, random_pad=False, cmudict_path='voicedit/cmu_dictionary', add_blank=True): 19 | self.df = df 20 | self.df_noise = df_noise 21 | 22 | self.paths = args.paths 23 | self.noise_paths = args.noise_paths 24 | 25 | self.uncond_text_prob = args.uncond_text_prob 26 | self.add_noise_prob = args.add_noise_prob 27 | self.reverb_prob = args.reverb_prob 28 | self.random_pad = random_pad 29 | self.only_noise_prob = args.only_noise_prob 30 | 31 | self.duration = 10 32 | self.target_length = int(self.duration * 102.4) 33 | self.stft = TacotronSTFT( 34 | filter_length=1024, 35 | hop_length=160, 36 | win_length=1024, 37 | n_mel_channels=64, 38 | sampling_rate=16000, 39 | mel_fmin=0, 40 | mel_fmax=8000, 41 | ) 42 | 43 | self.clap_processor = clap_processor 44 | 45 | self.cmudict = cmudict.CMUDict(cmudict_path) 46 | self.add_blank = add_blank 47 | 48 | if rir_path is not None: 49 | self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav')) 50 | 51 | def get_mel(self, audio, _stft): 52 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 53 | audio = torch.autograd.Variable(audio, requires_grad=False) 54 | melspec, _, _ = _stft.mel_spectrogram(audio) 55 | return torch.squeeze(melspec, 0).float() 56 | 57 | def get_text(self, text, add_blank=True): 58 | text_norm = text_to_sequence(text, dictionary=self.cmudict) 59 | if add_blank: 60 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols) 61 | text_norm = torch.IntTensor(text_norm) 62 | return text_norm 63 | 64 | def reverberate(self, audio): 65 | 66 | rir_file = random.choice(self.rir_files) 67 | rir, fs = torchaudio.load(rir_file) 68 | rir = rir.to(dtype=audio.dtype) 69 | rir = rir / torch.linalg.vector_norm(rir, ord=2, dim=-1) 70 | 71 | return torchaudio.functional.fftconvolve(audio, rir)[:,:self.target_length * 160] 72 | 73 | def pad_wav(self, wav, target_len, random_cut=False, random_pad=False): 74 | n_channels, wav_len = wav.shape 75 | if n_channels == 2: 76 | wav = wav.mean(-2, keepdim=True) 77 | 78 | if wav_len > target_len: 79 | if random_cut: 80 | i = random.randint(0, wav_len - target_len) 81 | return wav[:, i:i+target_len], 0 82 | return wav[:, :target_len], 0 83 | elif wav_len < target_len: 84 | if random_pad: 85 | i = random.randint(0, (target_len - wav_len) // 640) 86 | pre_pad = i * 640 87 | return F.pad(wav, (pre_pad, target_len-wav_len-pre_pad)), i 88 | wav = F.pad(wav, (0, target_len-wav_len)) 89 | return wav, 0 90 | 91 | def normalize_wav(self, waveform): 92 | waveform = waveform - torch.mean(waveform) 93 | waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8) 94 | return waveform 95 | 96 | def __getitem__(self, index): 97 | if random.random() > self.only_noise_prob: 98 | row = self.df.iloc[index] 99 | text = row.text 100 | file_path = os.path.join(self.paths[row.data], row.file_path) 101 | 102 | waveform, sr = torchaudio.load(file_path) 103 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 104 | waveform, speech_start = self.pad_wav(waveform, self.target_length * 
160, random_pad=self.random_pad) 105 | 106 | if row.data in ['as_speech_en']: 107 | spk_emb=torch.load(file_path.replace('/dataset/', '/spk_emb/').replace('.wav', '.pt'), map_location=torch.device('cpu')) 108 | 109 | if row.data in ['cv', 'voxceleb', 'libri'] and len(self.df_noise) > 0: 110 | if random.random() < self.add_noise_prob: 111 | noise_row = self.df_noise.iloc[random.randint(0, len(self.df_noise)-1)] 112 | noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path)) 113 | noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000) 114 | noise, _ = self.pad_wav(noise, self.target_length * 160, random_cut=True) 115 | if torch.linalg.vector_norm(noise, ord=2, dim=-1).item() != 0.0: 116 | snr = torch.Tensor(1).uniform_(2, 10) 117 | waveform = torchaudio.functional.add_noise(waveform, noise, snr) 118 | 119 | if random.random() < self.reverb_prob and hasattr(self, 'rir_files'): 120 | waveform = self.reverberate(waveform) 121 | 122 | spk_emb=torch.load(file_path.replace('/LibriTTS_R_16k/', '/LibriTTS_R_spk/').replace('.wav', '.pt'), map_location=torch.device('cpu')) 123 | 124 | else: 125 | text = '_' 126 | noise_row = self.df_noise.iloc[random.randint(0, len(self.df_noise)-1)] 127 | noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path)) 128 | noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000) 129 | waveform, _ = self.pad_wav(noise, self.target_length * 160, random_cut=True) 130 | spk_emb = torch.zeros(1, 192) 131 | speech_start = 0 132 | 133 | if type(spk_emb) != torch.Tensor: 134 | spk_emb = spk_emb[0] 135 | 136 | fbank, _, waveform = raw_waveform_to_fbank( 137 | waveform[0], 138 | target_length=self.target_length, 139 | fn_STFT=self.stft 140 | ) 141 | 142 | tokenized_text = self.get_text(text, add_blank=self.add_blank) 143 | 144 | # resample to 48k for clap 145 | wav_48k = torchaudio.functional.resample(waveform, orig_freq=16000, new_freq=48000) 146 | clap_inputs = self.clap_processor(audios=wav_48k, return_tensors="pt", sampling_rate=48000) 147 | 148 | return text, tokenized_text, fbank, spk_emb, clap_inputs, self.normalize_wav(waveform), speech_start 149 | 150 | def __len__(self): 151 | return len(self.df) 152 | 153 | 154 | class CollateFn_ft(object): 155 | 156 | def __call__(self, examples): 157 | 158 | B = len(examples) 159 | 160 | fbank = torch.stack([example[2] for example in examples]) 161 | spk_embs = torch.cat([example[3] for example in examples]) 162 | clap_input_features = torch.cat([example[4].input_features for example in examples]) 163 | clap_is_longer = torch.cat([example[4].is_longer for example in examples]) 164 | audios = [example[5] for example in examples] 165 | # speech_starts = torch.LongTensor([example[6] for example in examples]) 166 | 167 | x_max_length = max([example[1].shape[-1] for example in examples]) 168 | 169 | x = torch.zeros((B, x_max_length), dtype=torch.long) 170 | x_lengths = [] 171 | 172 | for i, example in enumerate(examples): 173 | x_ = example[1] 174 | x_lengths.append(x_.shape[-1]) 175 | x[i, :x_.shape[-1]] = x_ 176 | 177 | x_lengths = torch.LongTensor(x_lengths) 178 | 179 | return { 180 | "audio": audios, 181 | "fbank": fbank, 182 | "spk": spk_embs, 183 | "text": x, 184 | "text_lengths": x_lengths, 185 | "clap_input_features": clap_input_features, 186 | "clap_is_longer": clap_is_longer, 187 | "texts": [example[0] for example in examples], 188 | # "speech_starts": speech_starts 189 | } 
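A minimal wiring sketch for AudioDataset_ft and CollateFn_ft (illustrative only, not a file in this repository). The args field values, CSV names, dataset roots and the CLAP checkpoint are assumptions; the sketch also assumes the CMU dictionary at the default cmudict_path and the precomputed speaker-embedding .pt files referenced in __getitem__ are in place. Only the constructor signature and the dataframe columns (data, file_path, text) follow from the code above.

from types import SimpleNamespace
import pandas as pd
from torch.utils.data import DataLoader
from transformers import ClapProcessor
from voicedit import AudioDataset_ft, CollateFn_ft

args = SimpleNamespace(
    paths={'libri': '/data/LibriTTS_R_16k'},           # hypothetical speech roots, keyed by df['data']
    noise_paths={'noise': '/data/background_noise'},   # hypothetical noise roots, keyed by df_noise['data']
    uncond_text_prob=0.1,
    add_noise_prob=0.5,
    reverb_prob=0.25,
    only_noise_prob=0.1,
)
df = pd.read_csv('train_ft.csv')         # must provide columns: data, file_path, text
df_noise = pd.read_csv('noise_ft.csv')   # must provide columns: data, file_path
clap_processor = ClapProcessor.from_pretrained('laion/clap-htsat-unfused')

# For 'libri' rows, __getitem__ loads speaker embeddings from a parallel
# /data/LibriTTS_R_spk tree obtained by path substitution.
dataset = AudioDataset_ft(args, df, df_noise, clap_processor, random_pad=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=CollateFn_ft())
batch = next(iter(loader))  # dict with keys: audio, fbank, spk, text, text_lengths, clap_input_features, clap_is_longer, texts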
-------------------------------------------------------------------------------- /voicedit/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import torchaudio 7 | from torch.utils.data import Dataset 8 | 9 | from .audio import get_mel_from_wav, raw_waveform_to_fbank, TacotronSTFT 10 | from text import text_to_sequence, cmudict 11 | from text.symbols import symbols 12 | from .utils import intersperse 13 | 14 | 15 | class AudioDataset(Dataset): 16 | def __init__(self, args, df, df_noise, clap_processor, rir_path=None, cmudict_path='voicedit/cmu_dictionary', add_blank=True, train=True): 17 | self.df = df 18 | self.df_noise = df_noise 19 | 20 | self.paths = args.paths 21 | self.noise_paths = args.noise_paths 22 | 23 | self.uncond_text_prob = args.uncond_text_prob 24 | self.add_noise_prob = args.add_noise_prob 25 | self.reverb_prob = args.reverb_prob 26 | 27 | self.duration = 10 28 | self.target_length = int(self.duration * 102.4) 29 | self.stft = TacotronSTFT( 30 | filter_length=1024, 31 | hop_length=160, 32 | win_length=1024, 33 | n_mel_channels=64, 34 | sampling_rate=16000, 35 | mel_fmin=0, 36 | mel_fmax=8000, 37 | ) 38 | 39 | self.clap_processor = clap_processor 40 | self.train = train 41 | if not train: 42 | self.noise_dict = dict() 43 | 44 | self.cmudict = cmudict.CMUDict(cmudict_path) 45 | self.add_blank = add_blank 46 | 47 | def get_mel(self, audio, _stft): 48 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) 49 | audio = torch.autograd.Variable(audio, requires_grad=False) 50 | melspec, _, _ = _stft.mel_spectrogram(audio) 51 | return torch.squeeze(melspec, 0).float() 52 | 53 | def get_text(self, text, add_blank=True): 54 | text_norm = text_to_sequence(text, dictionary=self.cmudict) 55 | if add_blank: 56 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols) 57 | text_norm = torch.IntTensor(text_norm) 58 | return text_norm 59 | 60 | def reverberate(self, audio): 61 | 62 | rir_file = random.choice(self.rir_files) 63 | rir, fs = torchaudio.load(rir_file) 64 | rir = rir.to(dtype=audio.dtype) 65 | rir = rir / torch.linalg.vector_norm(rir, ord=2, dim=-1) 66 | 67 | return torchaudio.functional.fftconvolve(audio, rir)[:,:self.target_length * 160] 68 | 69 | def pad_wav(self, wav, target_len, random_cut=False): 70 | n_channels, wav_len = wav.shape 71 | if n_channels == 2: 72 | wav = wav.mean(-2, keepdim=True) 73 | 74 | if wav_len > target_len: 75 | if random_cut: 76 | i = random.randint(0, wav_len - target_len) 77 | return wav[:, i:i+target_len] 78 | return wav[:, :target_len] 79 | elif wav_len < target_len: 80 | wav = F.pad(wav, (0, target_len-wav_len)) 81 | return wav 82 | 83 | def normalize_wav(self, waveform): 84 | waveform = waveform - torch.mean(waveform) 85 | waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8) 86 | return waveform 87 | 88 | def __getitem__(self, index): 89 | row = self.df.iloc[index] 90 | file_path = os.path.join(self.paths[row.data], row.file_path) 91 | 92 | waveform, sr = torchaudio.load(file_path) 93 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) 94 | wav_len = waveform.shape[1] 95 | if wav_len > self.target_length * 160: 96 | wav_len = self.target_length * 160 97 | 98 | fbank_clean = self.get_mel(waveform[0, :wav_len], self.stft) 99 | 100 | waveform = self.pad_wav(waveform, self.target_length * 160) 101 | 102 | if row.data in 
['as_speech_en']: 103 | spk_emb=torch.load(file_path.replace('/dataset/', '/spk_emb/').replace('.wav', '.pt'), map_location=torch.device('cpu')) 104 | 105 | if row.data in ['cv', 'voxceleb', 'libri'] and len(self.df_noise) > 0: 106 | if random.random() < self.reverb_prob and hasattr(self, 'rir_files'): 107 | waveform = self.reverberate(waveform) 108 | 109 | if random.random() < self.add_noise_prob: 110 | noise_row = self.df_noise.iloc[random.randint(0, len(self.df_noise)-1)] 111 | noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path)) 112 | noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000) 113 | noise = self.pad_wav(noise, self.target_length * 160, random_cut=True) 114 | if torch.linalg.vector_norm(noise, ord=2, dim=-1).item() != 0.0: 115 | snr = torch.Tensor(1).uniform_(2, 10) 116 | waveform = torchaudio.functional.add_noise(waveform, noise, snr) 117 | # else: 118 | # if index not in self.noise_dict: 119 | # noise_idx = random.randint(0, len(self.df_noise)-1) 120 | # noise_row = self.df_noise.iloc[noise_idx] 121 | # noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path)) 122 | # noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000) 123 | # noise = self.pad_wav(noise, self.target_length * 160, random_cut=True) 124 | # if torch.linalg.vector_norm(noise, ord=2, dim=-1).item() == 0.0: 125 | # noise = torch.zeros_like(noise) 126 | # snr = torch.Tensor(1).uniform_(4, 20) 127 | # self.noise_dict[index] = (noise, snr) 128 | # else: 129 | # noise, snr = self.noise_dict[index] 130 | # waveform = torchaudio.functional.add_noise(waveform, noise, snr) 131 | 132 | spk_emb=torch.load(file_path.replace('/LibriTTS_R_16k/', '/LibriTTS_R_spk/').replace('.wav', '.pt'), map_location=torch.device('cpu')) 133 | 134 | if type(spk_emb) != torch.Tensor: 135 | spk_emb = spk_emb[0] 136 | 137 | fbank, _, waveform = raw_waveform_to_fbank( 138 | waveform[0], 139 | target_length=self.target_length, 140 | fn_STFT=self.stft 141 | ) 142 | 143 | text = row.text 144 | tokenized_text = self.get_text(text, add_blank=self.add_blank) 145 | 146 | # resample to 48k for clap 147 | wav_48k = torchaudio.functional.resample(waveform, orig_freq=16000, new_freq=48000) 148 | clap_inputs = self.clap_processor(audios=wav_48k, return_tensors="pt", sampling_rate=48000) 149 | 150 | return text, tokenized_text, fbank, spk_emb, clap_inputs, self.normalize_wav(waveform), fbank_clean 151 | 152 | def __len__(self): 153 | return len(self.df) 154 | 155 | 156 | class CollateFn(object): 157 | 158 | def __call__(self, examples): 159 | B = len(examples) 160 | 161 | fbank = torch.stack([example[2] for example in examples]) 162 | spk_embs = torch.cat([example[3] for example in examples]) 163 | clap_input_features = torch.cat([example[4].input_features for example in examples]) 164 | clap_is_longer = torch.cat([example[4].is_longer for example in examples]) 165 | audios = [example[5] for example in examples] 166 | 167 | y_max_length = max([example[6].shape[-1] for example in examples]) 168 | x_max_length = max([example[1].shape[-1] for example in examples]) 169 | n_feats = examples[0][6].shape[-2] 170 | 171 | y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float) 172 | x = torch.zeros((B, x_max_length), dtype=torch.long) 173 | y_lengths, x_lengths = [], [] 174 | 175 | for i, example in enumerate(examples): 176 | x_, y_ = example[1], example[6] 177 | y_lengths.append(y_.shape[-1]) 178 | x_lengths.append(x_.shape[-1]) 179 
| y[i, :, :y_.shape[-1]] = y_ 180 | x[i, :x_.shape[-1]] = x_ 181 | 182 | y_lengths = torch.LongTensor(y_lengths) 183 | x_lengths = torch.LongTensor(x_lengths) 184 | 185 | return { 186 | "audio": audios, 187 | "fbank": fbank, 188 | "spk": spk_embs, 189 | "text": x, 190 | "text_lengths": x_lengths, 191 | "y": y, 192 | "y_lengths": y_lengths, 193 | "clap_input_features": clap_input_features, 194 | "clap_is_longer": clap_is_longer, 195 | "texts": [example[0] for example in examples], 196 | } -------------------------------------------------------------------------------- /voicedit/modules.py: -------------------------------------------------------------------------------- 1 | # Codes adapted from https://github.com/heatz123/naturalspeech/blob/main/models/models.py 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from voicedit.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility 9 | from voicedit import monotonic_align 10 | 11 | from voicedit.text_encoder import TextEncoder 12 | from text.symbols import symbols 13 | 14 | class DiTWrapper(nn.Module): 15 | def __init__(self, dit, cond_speaker=False, concat_y=False, uncond_prob=0.0, dit_arch=True): 16 | super().__init__() 17 | self.dit = dit 18 | self.cond_speaker = cond_speaker 19 | self.uncond_prob = uncond_prob 20 | self.n_feats = 64 21 | self.concat_y = concat_y 22 | self.dit_arch = dit_arch 23 | 24 | self.text_encoder = TextEncoder(len(symbols)+1, n_feats = 64, n_channels = 192, 25 | filter_channels = 768, filter_channels_dp = 256, n_heads = 2, 26 | n_layers = 4, kernel_size = 3, p_dropout = 0.1, window_size = 4, spk_emb_dim = self.n_feats, cond_spk=cond_speaker) 27 | 28 | if cond_speaker: 29 | self.spk_embedding = torch.nn.Sequential(torch.nn.Linear(192, 192 * 4), nn.ReLU(), 30 | torch.nn.Linear(192 * 4, self.n_feats)) 31 | if concat_y: 32 | self.proj = nn.Sequential( 33 | nn.ConstantPad2d((0, 1, 0, 1), value=0), 34 | nn.Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2)), 35 | nn.ReLU(), 36 | nn.ConstantPad2d((0, 1, 0, 1), value=0), 37 | nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2)), 38 | ) 39 | 40 | if not dit_arch: 41 | self.class_embedding = nn.Linear(512, 1024) 42 | 43 | self.register_buffer("y_embedding", nn.Parameter(torch.randn(8, 256, 16)) if concat_y 44 | else nn.Parameter(torch.randn(1, self.n_feats) / self.n_feats ** 0.5)) 45 | 46 | 47 | def forward(self, sample, timestep, text, text_lengths=None, cond=None, y=None, y_lengths=None, y_latent=None, speech_starts=None, spk=None, train_text_encoder=True, **kwargs): 48 | 49 | if train_text_encoder: 50 | if spk is not None: 51 | spk = self.spk_embedding(spk) 52 | 53 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 54 | mu_x, logw, x_mask = self.text_encoder(text, text_lengths, spk) 55 | y_max_length = y.shape[-1] 56 | 57 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask) 58 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 59 | 60 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram 61 | with torch.no_grad(): 62 | const = -0.5 * math.log(2 * math.pi) * self.n_feats 63 | factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device) 64 | y_square = torch.matmul(factor.transpose(1, 2), y ** 2) 65 | y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y) 66 | mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1) 67 | log_prior = y_square - y_mu_double + mu_square + const 68 | 69 | attn = 
monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1)) 70 | attn = attn.detach() 71 | 72 | # Compute loss between predicted log-scaled durations and those obtained from MAS 73 | logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask 74 | dur_loss = duration_loss(logw, logw_, text_lengths) 75 | 76 | # Align encoded text with mel-spectrogram and get mu_y segment 77 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 78 | 79 | # Compute loss between aligned encoder outputs and mel-spectrogram 80 | prior_loss = torch.sum(0.5 * ((y - mu_y.transpose(1,2)) ** 2 + math.log(2 * math.pi)) * y_mask) 81 | prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats) 82 | 83 | # spk_emb = spk_emb.unsqueeze(1).repeat(1, mu_y.shape[1], 1) 84 | # mu_y_vae = torch.stack([mu_y, spk_emb], dim=1) 85 | 86 | if self.concat_y: 87 | condition_y = self.proj(mu_y.unsqueeze(1)) 88 | 89 | # y_latent_lengths = torch.ceil(y_lengths / 4).long() 90 | # y_latent_mask = sequence_mask(y_latent_lengths, y_latent.shape[-2]).unsqueeze(1).to(mu_y_latent.dtype) 91 | # latent_prior_loss = torch.sum(0.5 * ((y_latent - mu_y_latent) ** 2 + math.log(2 * math.pi)) * y_latent_mask.unsqueeze(-1)) 92 | # prior_loss += latent_prior_loss / (torch.sum(y_latent_mask) * y_latent.shape[1] * y_latent.shape[-1]) 93 | 94 | condition_y = torch.nn.functional.pad(condition_y, (0, 0, 0, sample.shape[-2] - condition_y.shape[-2]), value=0) 95 | else: 96 | condition_y = mu_y.unsqueeze(1) 97 | 98 | else: 99 | 100 | with torch.no_grad(): 101 | mu_y, y_mask = self.process_content(text, text_lengths, spk) 102 | 103 | if self.concat_y: 104 | condition_y = self.proj(mu_y.unsqueeze(1)) 105 | 106 | # y_latent_lengths = torch.ceil(y_lengths / 4).long() 107 | # y_latent_mask = sequence_mask(y_latent_lengths, y_latent.shape[-2]).unsqueeze(1).to(mu_y_latent.dtype) 108 | # latent_prior_loss = torch.sum(0.5 * ((y_latent - mu_y_latent) ** 2 + math.log(2 * math.pi)) * y_latent_mask.unsqueeze(-1)) 109 | # prior_loss += latent_prior_loss / (torch.sum(y_latent_mask) * y_latent.shape[1] * y_latent.shape[-1]) 110 | 111 | # condition_temp = torch.zeros(sample.size()).to(condition_y.device, dtype=condition_y.dtype) 112 | 113 | if speech_starts is not None and torch.any(speech_starts): 114 | condition_temp = [] 115 | for start in speech_starts: 116 | condition_temp.append(torch.nn.functional.pad(condition_y[0], (0, 0, start, sample.shape[-2] - condition_y[0].shape[-2] - start), value=0)) 117 | condition_y = torch.stack(condition_temp, dim=0) 118 | # condition_temp_y = torch.zeros_like(sample) 119 | # indices = speech_starts.unsqueeze(1).unsqueeze(2).unsqueeze(3) + torch.arange(condition_y.shape[-2]).unsqueeze(0).unsqueeze(1).unsqueeze(3).to(speech_starts.device) 120 | # condition_temp_y.scatter_(2, indices.expand(condition_y.size()), condition_y) 121 | # condition_y = condition_temp_y 122 | else: 123 | condition_y = torch.nn.functional.pad(condition_y, (0, 0, 0, sample.shape[-2] - condition_y.shape[-2]), value=0) 124 | else: 125 | condition_y = mu_y.unsqueeze(1) 126 | 127 | if self.concat_y and self.uncond_prob > 0.0: 128 | drop_ids = torch.rand(condition_y.shape[0]).cuda() < self.uncond_prob 129 | condition_y = torch.where(drop_ids[:, None, None, None], self.y_embedding, condition_y) 130 | 131 | if self.dit_arch: 132 | sample = self.dit( 133 | sample, 134 | timestep, 135 | y=condition_y, 136 | mask=y_mask.long() if not self.concat_y else None, 137 | class_labels=cond, 138 | ) 139 | else: 140 | sample = torch.cat([sample, 
condition_y], dim=1) 141 | cond = self.class_embedding(cond) 142 | sample = self.dit( 143 | sample, 144 | timestep, 145 | class_labels=cond, 146 | ).sample 147 | 148 | if sample.isnan().any(): 149 | print("NAN detected from sample") 150 | 151 | if train_text_encoder: 152 | return dict( 153 | x=sample, 154 | dur_loss=dur_loss, 155 | prior_loss=prior_loss, 156 | ) 157 | else: 158 | return sample 159 | 160 | def process_content(self, x, x_lengths, spk=None, length_scale=1.0): 161 | 162 | if spk is not None: 163 | spk = self.spk_embedding(spk) 164 | 165 | if spk.shape[0] != x.shape[0]: 166 | spk = spk.repeat(x.shape[0], 1) 167 | 168 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 169 | mu_x, logw, x_mask = self.text_encoder(x, x_lengths, spk) 170 | 171 | w = torch.exp(logw) * x_mask 172 | w_ceil = torch.ceil(w) * length_scale 173 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 174 | y_max_length = int(y_lengths.max()) 175 | y_max_length_ = fix_len_compatibility(y_max_length) 176 | 177 | # Using obtained durations `w` construct alignment map `attn` 178 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) 179 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 180 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) 181 | 182 | # Align encoded text and get mu_y 183 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 184 | 185 | # spk_emb = spk_emb.unsqueeze(1).repeat(1, mu_y.shape[1], 1) 186 | # mu_y_vae = torch.stack([mu_y, spk_emb], dim=1) 187 | 188 | # if self.concat_y == 'vae': 189 | # assert vae is not None 190 | # mu_y_latent = vae.encode(mu_y.unsqueeze(1)).latent_dist.sample() 191 | # mu_y_latent = mu_y_latent * vae.config.scaling_factor 192 | # if self.concat_y: 193 | # condition_y = self.proj(mu_y.unsqueeze(1)) 194 | # else: 195 | # condition_y = mu_y.unsqueeze(1) 196 | 197 | return mu_y, y_mask 198 | -------------------------------------------------------------------------------- /diffusion/model/nets/DiT2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import os 15 | import numpy as np 16 | from timm.models.layers import DropPath 17 | from timm.models.vision_transformer import Mlp, Attention 18 | 19 | import sys 20 | sys.path.insert(1, os.getcwd()) 21 | 22 | from diffusion.model.utils import auto_grad_checkpoint, to_2tuple, PatchEmbed 23 | from diffusion.model.nets.DiT_blocks import t2i_modulate, CaptionEmbedder, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer, modulate 24 | # from diffusion.utils.logger import get_root_logger 25 | 26 | 27 | class DiTBlock(nn.Module): 28 | """ 29 | A DiT block with adaptive layer norm (adaLN-single) conditioning. 
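    The conditioning embedding t (timestep plus class embedding) is projected to six modulation parameters (shift, scale and gate for the self-attention branch and for the MLP branch); when cross_class is enabled, an additional cross-attention layer lets the tokens attend to the class embedding c.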
30 | """ 31 | 32 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, cross_class=False, **block_kwargs): 33 | super().__init__() 34 | self.hidden_size = hidden_size 35 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 36 | self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, 37 | input_size=input_size if window_size == 0 else (window_size, window_size), 38 | use_rel_pos=use_rel_pos, **block_kwargs) 39 | if cross_class: 40 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) 41 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 42 | # to be compatible with lower version pytorch 43 | approx_gelu = lambda: nn.GELU(approximate="tanh") 44 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) 45 | self.window_size = window_size 46 | self.adaLN_modulation = nn.Sequential( 47 | nn.SiLU(), 48 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 49 | ) 50 | 51 | def forward(self, x, t, c=None, mask=None, **kwargs): 52 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(6, dim=1) 53 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) 54 | if hasattr(self, 'cross_attn'): 55 | x = x + self.cross_attn(x, c, mask) 56 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 57 | 58 | return x 59 | 60 | 61 | ############################################################################# 62 | # Core DiT Model # 63 | ################################################################################# 64 | class DiT(nn.Module): 65 | """ 66 | Diffusion model with a Transformer backbone. 
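    When concat_y=True, the aligned text/mel condition y is concatenated channel-wise with the noisy input before patch embedding (the patch embedder then takes 2 * in_channels), and the class embedding produced by c_embedder is added to the timestep embedding; cross_class additionally routes the class embedding through cross-attention in every block.

    A rough usage sketch, with shapes borrowed from the profiling block at the bottom of this file (illustrative only):
        model = DiT_L_2(input_size=(256, 16), in_channels=1, projection_class_embeddings_input_dim=512, concat_y=True, cross_class=True)
        out = model(x, timestep, class_labels, y)  # x, y: (N, 1, 256, 16); class_labels: (N, 512); out: (N, 2, 256, 16) because pred_sigma=True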
67 | """ 68 | 69 | def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, concat_y=False, cross_class=False, class_dropout_prob=0.0, pred_sigma=True, projection_class_embeddings_input_dim=512, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, caption_channels=4096, lewei_scale=1.0, config=None, **kwargs): 70 | if window_block_indexes is None: 71 | window_block_indexes = [] 72 | super().__init__() 73 | self.pred_sigma = pred_sigma 74 | self.in_channels = in_channels 75 | self.out_channels = in_channels * 2 if pred_sigma else in_channels 76 | self.patch_size = patch_size 77 | self.num_heads = num_heads 78 | self.lewei_scale = lewei_scale, 79 | self.cross_class = cross_class 80 | 81 | self.concat_y = concat_y 82 | if isinstance(input_size, int): 83 | input_size = to_2tuple(input_size) 84 | 85 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels * 2 if concat_y else in_channels, hidden_size, bias=True) 86 | self.t_embedder = TimestepEmbedder(hidden_size) 87 | self.c_embedder = nn.Linear(projection_class_embeddings_input_dim, hidden_size) 88 | num_patches = self.x_embedder.num_patches 89 | # Will use fixed sin-cos embedding: 90 | self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) 91 | 92 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule 93 | self.blocks = nn.ModuleList([ 94 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], 95 | input_size=(input_size[0] // patch_size, input_size[1] // patch_size), 96 | window_size=window_size if i in window_block_indexes else 0, 97 | use_rel_pos=use_rel_pos if i in window_block_indexes else False, 98 | cross_class=cross_class) 99 | for i in range(depth) 100 | ]) 101 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 102 | 103 | self.initialize_weights() 104 | 105 | def forward(self, x, timestep, class_labels, y, mask=None, **kwargs): 106 | """ 107 | Forward pass of DiT. 
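    When concat_y is enabled, y is concatenated with x along the channel dimension before patch embedding. Returns an (N, out_channels, H, W) tensor, where out_channels is 2 * in_channels if pred_sigma=True.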
108 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 109 | t: (N,) tensor of diffusion timesteps 110 | y: (N, C, H, W) tensor of caption labels 111 | class_labels: (N, C') tensor of class labels 112 | """ 113 | 114 | x = x.to(self.dtype) 115 | timestep = timestep.to(self.dtype) 116 | y = y.to(self.dtype) 117 | pos_embed = self.pos_embed.to(self.dtype) 118 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size 119 | if self.concat_y: 120 | x = torch.cat([x, y], dim=1) 121 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 122 | t = self.t_embedder(timestep) # (N, D) 123 | class_embed = self.c_embedder(class_labels) 124 | t0 = t + class_embed 125 | if self.cross_class: 126 | class_lens = [1] * class_embed.shape[0] 127 | class_embed = class_embed.unsqueeze(1).view(1, -1, x.shape[-1]) 128 | else: 129 | class_lens = None 130 | for block in self.blocks: 131 | x = auto_grad_checkpoint(block, x, t0, class_embed, class_lens) 132 | x = self.final_layer(x, t0) # (N, T, patch_size ** 2 * out_channels) 133 | x = self.unpatchify(x) # (N, out_channels, H, W) 134 | return x 135 |
136 | def forward_with_dpmsolver(self, x, timestep, y, class_labels=None, mask=None, **kwargs): 137 | """ 138 | dpm solver does not need variance prediction 139 | """ 140 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 141 | model_out = self.forward(x, timestep, class_labels, y, mask) 142 | return model_out.chunk(2, dim=1)[0] 143 |
144 | def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, class_labels=None, guidance="single", **kwargs): 145 | """ 146 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 147 | """ 148 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 149 | if guidance == "dual": 150 | num_chunk = 4 151 | elif guidance == "single": 152 | num_chunk = 2 153 | chunk = x[: len(x) // num_chunk] 154 | combined = torch.cat([chunk] * num_chunk, dim=0) 155 | model_out = self.forward(combined, timestep, class_labels, y, mask) 156 | model_out = model_out['x'] if isinstance(model_out, dict) else model_out 157 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 158 | eps = eps.chunk(num_chunk, dim=0) 159 | if guidance == "dual": 160 | pred_eps = eps[0] + cfg_scale[0] * (eps[1] - eps[3]) + cfg_scale[1] * (eps[2] - eps[3]) 161 | elif guidance == "single": 162 | pred_eps = eps[1] + cfg_scale * (eps[0] - eps[1]) 163 | eps = torch.cat([pred_eps] * num_chunk, dim=0) 164 | return torch.cat([eps, rest], dim=1) 165 |
166 | def unpatchify(self, x): 167 | """ 168 | x: (N, T, patch_size**2 * C) 169 | imgs: (N, H, W, C) 170 | """ 171 | c = self.out_channels 172 | p1, p2 = self.x_embedder.patch_size 173 | h, w = self.x_embedder.grid_size 174 | assert h * w == x.shape[1] 175 | 176 | x = x.reshape(shape=(x.shape[0], h, w, p1, p2, c)) 177 | x = torch.einsum('nhwpqc->nchpwq', x) 178 | return x.reshape(shape=(x.shape[0], c, h * p1, w * p2)) 179 |
180 | def initialize_weights(self): 181 | # Initialize transformer layers: 182 | def _basic_init(module): 183 | if isinstance(module, nn.Linear): 184 | torch.nn.init.xavier_uniform_(module.weight) 185 | if module.bias is not None: 186 | nn.init.constant_(module.bias, 0) 187 | 188 | self.apply(_basic_init) 189 | 190 | # Initialize (and freeze) pos_embed by sin-cos embedding: 191 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size) 192
| self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 193 | 194 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 195 | w = self.x_embedder.proj.weight.data 196 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 197 | nn.init.constant_(self.x_embedder.proj.bias, 0) 198 | 199 | # Initialize timestep embedding MLP: 200 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 201 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 202 | 203 | # Zero-out adaLN modulation layers in DiT blocks: 204 | for block in self.blocks: 205 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 206 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 207 | if hasattr(block, 'cross_attn'): 208 | nn.init.constant_(block.cross_attn.proj.weight, 0) 209 | nn.init.constant_(block.cross_attn.proj.bias, 0) 210 | 211 | # Zero-out output layers: 212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 214 | nn.init.constant_(self.final_layer.linear.weight, 0) 215 | nn.init.constant_(self.final_layer.linear.bias, 0) 216 | 217 | @property 218 | def dtype(self): 219 | return next(self.parameters()).dtype 220 | 221 | 222 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 223 | """ 224 | grid_size: int of the grid height and width 225 | return: 226 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 227 | """ 228 | if isinstance(grid_size, int): 229 | grid_size = to_2tuple(grid_size) 230 | grid_h = np.arange(grid_size[0], dtype=np.float32) 231 | grid_w = np.arange(grid_size[1], dtype=np.float32) 232 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 233 | grid = np.stack(grid, axis=0) 234 | 235 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) 236 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 237 | if cls_token and extra_tokens > 0: 238 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 239 | return pos_embed 240 | 241 | 242 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 243 | assert embed_dim % 2 == 0 244 | 245 | # use half of dimensions to encode grid_h 246 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 247 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 248 | 249 | return np.concatenate([emb_h, emb_w], axis=1) 250 | 251 | 252 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 253 | """ 254 | embed_dim: output dimension for each position 255 | pos: a list of positions to be encoded: size (M,) 256 | out: (M, D) 257 | """ 258 | assert embed_dim % 2 == 0 259 | omega = np.arange(embed_dim // 2, dtype=np.float64) 260 | omega /= embed_dim / 2. 261 | omega = 1. 
/ 10000 ** omega # (D/2,) 262 | 263 | pos = pos.reshape(-1) # (M,) 264 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 265 | 266 | emb_sin = np.sin(out) # (M, D/2) 267 | emb_cos = np.cos(out) # (M, D/2) 268 | 269 | return np.concatenate([emb_sin, emb_cos], axis=1) 270 | 271 | 272 | ################################################################################# 273 | # DiT Configs # 274 | ################################################################################# 275 | def DiT_XL_2(**kwargs): 276 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 277 | 278 | def DiT_L_2(**kwargs): 279 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 280 | 281 | def DiT_B_2(**kwargs): 282 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 283 | 284 | def DiT_S_2(**kwargs): 285 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 286 | 287 | 288 | DiT_models = { 289 | 'DiT_XL/2': DiT_XL_2, 290 | 'DiT_L/2': DiT_L_2, 291 | 'DiT_B/2': DiT_B_2, 292 | 'DiT_S/2': DiT_S_2, 293 | } 294 | 295 | if __name__ == '__main__': 296 | 297 | from thop import profile, clever_format 298 | 299 | device = "cuda:0" 300 | model = DiT_models['DiT_L/2']( 301 | input_size = (256, 16), 302 | in_channels = 1, 303 | projection_class_embeddings_input_dim=512, 304 | concat_y = True, 305 | cross_class = True, 306 | ).to(device) 307 | model.eval() 308 | model.requires_grad_(False) 309 | 310 | x = torch.randn(1, 1, 256, 16).to(device) 311 | timestep = torch.Tensor([999]).to(device) 312 | y = torch.randn(1, 1, 256, 16).to(device) 313 | mask = torch.ones(1, 1, 1, 1).to(device) 314 | class_labels = torch.randn(1, 512).to(device) 315 | 316 | macs, params = profile(model, inputs=(x, timestep, class_labels, y)) 317 | flops = macs * 2 318 | macs, params, flops = clever_format([macs, params, flops], "%.4f") 319 | print("Params:", params) 320 | print("MACs:", macs) 321 | print("FLOPs:", flops) 322 | import pdb; pdb.set_trace() -------------------------------------------------------------------------------- /voicedit/text_encoder.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .utils import sequence_mask, convert_pad_shape 9 | 10 | 11 | class LayerNorm(nn.Module): 12 | def __init__(self, channels, eps=1e-4): 13 | super(LayerNorm, self).__init__() 14 | self.channels = channels 15 | self.eps = eps 16 | 17 | self.gamma = torch.nn.Parameter(torch.ones(channels)) 18 | self.beta = torch.nn.Parameter(torch.zeros(channels)) 19 | 20 | def forward(self, x): 21 | n_dims = len(x.shape) 22 | mean = torch.mean(x, 1, keepdim=True) 23 | variance = torch.mean((x - mean)**2, 1, keepdim=True) 24 | 25 | x = (x - mean) * torch.rsqrt(variance + self.eps) 26 | 27 | shape = [1, -1] + [1] * (n_dims - 2) 28 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 29 | return x 30 | 31 | 32 | class ConvReluNorm(nn.Module): 33 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, 34 | n_layers, p_dropout): 35 | super(ConvReluNorm, self).__init__() 36 | self.in_channels = in_channels 37 | self.hidden_channels = hidden_channels 38 | self.out_channels = out_channels 39 | self.kernel_size = kernel_size 40 | self.n_layers = n_layers 41 | self.p_dropout = p_dropout 42 | 43 | self.conv_layers = torch.nn.ModuleList() 44 | self.norm_layers = 
torch.nn.ModuleList() 45 | self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, 46 | kernel_size, padding=kernel_size//2)) 47 | self.norm_layers.append(LayerNorm(hidden_channels)) 48 | self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) 49 | for _ in range(n_layers - 1): 50 | self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels, 51 | kernel_size, padding=kernel_size//2)) 52 | self.norm_layers.append(LayerNorm(hidden_channels)) 53 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) 54 | self.proj.weight.data.zero_() 55 | self.proj.bias.data.zero_() 56 | 57 | def forward(self, x, x_mask): 58 | x_org = x 59 | for i in range(self.n_layers): 60 | x = self.conv_layers[i](x * x_mask) 61 | x = self.norm_layers[i](x) 62 | x = self.relu_drop(x) 63 | x = x_org + self.proj(x) 64 | return x * x_mask 65 | 66 | 67 | class DurationPredictor(nn.Module): 68 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): 69 | super(DurationPredictor, self).__init__() 70 | self.in_channels = in_channels 71 | self.filter_channels = filter_channels 72 | self.p_dropout = p_dropout 73 | 74 | self.drop = torch.nn.Dropout(p_dropout) 75 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, 76 | kernel_size, padding=kernel_size//2) 77 | self.norm_1 = LayerNorm(filter_channels) 78 | self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, 79 | kernel_size, padding=kernel_size//2) 80 | self.norm_2 = LayerNorm(filter_channels) 81 | self.proj = torch.nn.Conv1d(filter_channels, 1, 1) 82 | 83 | def forward(self, x, x_mask): 84 | x = self.conv_1(x * x_mask) 85 | x = torch.relu(x) 86 | x = self.norm_1(x) 87 | x = self.drop(x) 88 | x = self.conv_2(x * x_mask) 89 | x = torch.relu(x) 90 | x = self.norm_2(x) 91 | x = self.drop(x) 92 | x = self.proj(x * x_mask) 93 | return x * x_mask 94 | 95 | 96 | class MultiHeadAttention(nn.Module): 97 | def __init__(self, channels, out_channels, n_heads, window_size=None, 98 | heads_share=True, p_dropout=0.0, proximal_bias=False, 99 | proximal_init=False): 100 | super(MultiHeadAttention, self).__init__() 101 | assert channels % n_heads == 0 102 | 103 | self.channels = channels 104 | self.out_channels = out_channels 105 | self.n_heads = n_heads 106 | self.window_size = window_size 107 | self.heads_share = heads_share 108 | self.proximal_bias = proximal_bias 109 | self.p_dropout = p_dropout 110 | self.attn = None 111 | 112 | self.k_channels = channels // n_heads 113 | self.conv_q = torch.nn.Conv1d(channels, channels, 1) 114 | self.conv_k = torch.nn.Conv1d(channels, channels, 1) 115 | self.conv_v = torch.nn.Conv1d(channels, channels, 1) 116 | if window_size is not None: 117 | n_heads_rel = 1 if heads_share else n_heads 118 | rel_stddev = self.k_channels**-0.5 119 | self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, 120 | window_size * 2 + 1, self.k_channels) * rel_stddev) 121 | self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, 122 | window_size * 2 + 1, self.k_channels) * rel_stddev) 123 | self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) 124 | self.drop = torch.nn.Dropout(p_dropout) 125 | 126 | torch.nn.init.xavier_uniform_(self.conv_q.weight) 127 | torch.nn.init.xavier_uniform_(self.conv_k.weight) 128 | if proximal_init: 129 | self.conv_k.weight.data.copy_(self.conv_q.weight.data) 130 | self.conv_k.bias.data.copy_(self.conv_q.bias.data) 131 | torch.nn.init.xavier_uniform_(self.conv_v.weight) 132 | 133 | def forward(self, x, c, attn_mask=None): 134 | q = 
self.conv_q(x) 135 | k = self.conv_k(c) 136 | v = self.conv_v(c) 137 | 138 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 139 | 140 | x = self.conv_o(x) 141 | return x 142 | 143 | def attention(self, query, key, value, mask=None): 144 | b, d, t_s, t_t = (*key.size(), query.size(2)) 145 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 146 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 147 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 148 | 149 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) 150 | if self.window_size is not None: 151 | assert t_s == t_t, "Relative attention is only available for self-attention." 152 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 153 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) 154 | rel_logits = self._relative_position_to_absolute_position(rel_logits) 155 | scores_local = rel_logits / math.sqrt(self.k_channels) 156 | scores = scores + scores_local 157 | if self.proximal_bias: 158 | assert t_s == t_t, "Proximal bias is only available for self-attention." 159 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, 160 | dtype=scores.dtype) 161 | if mask is not None: 162 | scores = scores.masked_fill(mask == 0, -1e4) 163 | p_attn = torch.nn.functional.softmax(scores, dim=-1) 164 | p_attn = self.drop(p_attn) 165 | output = torch.matmul(p_attn, value) 166 | if self.window_size is not None: 167 | relative_weights = self._absolute_position_to_relative_position(p_attn) 168 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 169 | output = output + self._matmul_with_relative_values(relative_weights, 170 | value_relative_embeddings) 171 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) 172 | return output, p_attn 173 | 174 | def _matmul_with_relative_values(self, x, y): 175 | ret = torch.matmul(x, y.unsqueeze(0)) 176 | return ret 177 | 178 | def _matmul_with_relative_keys(self, x, y): 179 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 180 | return ret 181 | 182 | def _get_relative_embeddings(self, relative_embeddings, length): 183 | pad_length = max(length - (self.window_size + 1), 0) 184 | slice_start_position = max((self.window_size + 1) - length, 0) 185 | slice_end_position = slice_start_position + 2 * length - 1 186 | if pad_length > 0: 187 | padded_relative_embeddings = torch.nn.functional.pad( 188 | relative_embeddings, convert_pad_shape([[0, 0], 189 | [pad_length, pad_length], [0, 0]])) 190 | else: 191 | padded_relative_embeddings = relative_embeddings 192 | used_relative_embeddings = padded_relative_embeddings[:, 193 | slice_start_position:slice_end_position] 194 | return used_relative_embeddings 195 | 196 | def _relative_position_to_absolute_position(self, x): 197 | batch, heads, length, _ = x.size() 198 | x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 199 | x_flat = x.view([batch, heads, length * 2 * length]) 200 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]])) 201 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 202 | return x_final 203 | 204 | def _absolute_position_to_relative_position(self, x): 205 | batch, heads, length, _ = x.size() 206 | x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 207 | x_flat = x.view([batch, heads, length**2 + 
length*(length - 1)]) 208 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 209 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 210 | return x_final 211 | 212 | def _attention_bias_proximal(self, length): 213 | r = torch.arange(length, dtype=torch.float32) 214 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 215 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 216 | 217 | 218 | class FFN(nn.Module): 219 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, 220 | p_dropout=0.0): 221 | super(FFN, self).__init__() 222 | self.in_channels = in_channels 223 | self.out_channels = out_channels 224 | self.filter_channels = filter_channels 225 | self.kernel_size = kernel_size 226 | self.p_dropout = p_dropout 227 | 228 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, 229 | padding=kernel_size//2) 230 | self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, 231 | padding=kernel_size//2) 232 | self.drop = torch.nn.Dropout(p_dropout) 233 | 234 | def forward(self, x, x_mask): 235 | x = self.conv_1(x * x_mask) 236 | x = torch.relu(x) 237 | x = self.drop(x) 238 | x = self.conv_2(x * x_mask) 239 | return x * x_mask 240 | 241 | 242 | class Encoder(nn.Module): 243 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, 244 | kernel_size=1, p_dropout=0.0, window_size=None, **kwargs): 245 | super(Encoder, self).__init__() 246 | self.hidden_channels = hidden_channels 247 | self.filter_channels = filter_channels 248 | self.n_heads = n_heads 249 | self.n_layers = n_layers 250 | self.kernel_size = kernel_size 251 | self.p_dropout = p_dropout 252 | self.window_size = window_size 253 | 254 | self.drop = torch.nn.Dropout(p_dropout) 255 | self.attn_layers = torch.nn.ModuleList() 256 | self.norm_layers_1 = torch.nn.ModuleList() 257 | self.ffn_layers = torch.nn.ModuleList() 258 | self.norm_layers_2 = torch.nn.ModuleList() 259 | for _ in range(self.n_layers): 260 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, 261 | n_heads, window_size=window_size, p_dropout=p_dropout)) 262 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 263 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, 264 | filter_channels, kernel_size, p_dropout=p_dropout)) 265 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 266 | 267 | def forward(self, x, x_mask): 268 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 269 | for i in range(self.n_layers): 270 | x = x * x_mask 271 | y = self.attn_layers[i](x, x, attn_mask) 272 | y = self.drop(y) 273 | x = self.norm_layers_1[i](x + y) 274 | y = self.ffn_layers[i](x, x_mask) 275 | y = self.drop(y) 276 | x = self.norm_layers_2[i](x + y) 277 | x = x * x_mask 278 | return x 279 | 280 | 281 | class TextEncoder(nn.Module): 282 | def __init__(self, n_vocab, n_feats, n_channels, filter_channels, 283 | filter_channels_dp, n_heads, n_layers, kernel_size, 284 | p_dropout, window_size=None, spk_emb_dim=64, cond_spk=False): 285 | super(TextEncoder, self).__init__() 286 | self.n_vocab = n_vocab 287 | self.n_feats = n_feats 288 | self.n_channels = n_channels 289 | self.filter_channels = filter_channels 290 | self.filter_channels_dp = filter_channels_dp 291 | self.n_heads = n_heads 292 | self.n_layers = n_layers 293 | self.kernel_size = kernel_size 294 | self.p_dropout = p_dropout 295 | self.window_size = window_size 296 | self.spk_emb_dim = spk_emb_dim 297 | self.cond_spk = cond_spk 
= cond_spk
298 | 299 | self.emb = torch.nn.Embedding(n_vocab, n_channels) 300 | torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5) 301 | 302 | self.prenet = ConvReluNorm(n_channels, n_channels, n_channels, 303 | kernel_size=5, n_layers=3, p_dropout=0.5) 304 | 305 | self.encoder = Encoder(n_channels + (spk_emb_dim if cond_spk else 0), filter_channels, n_heads, n_layers, 306 | kernel_size, p_dropout, window_size=window_size) 307 | 308 | self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if cond_spk else 0), n_feats, 1) 309 | self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if cond_spk else 0), filter_channels_dp, 310 | kernel_size, p_dropout) 311 | 312 | def forward(self, x, x_lengths, spk=None): 313 | x = self.emb(x) * math.sqrt(self.n_channels) 314 | x = torch.transpose(x, 1, -1) 315 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 316 | 317 | x = self.prenet(x, x_mask) 318 | if self.cond_spk: 319 | x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) 320 | x = self.encoder(x, x_mask) 321 | mu = self.proj_m(x) * x_mask 322 | 323 | x_dp = torch.detach(x) 324 | logw = self.proj_w(x_dp, x_mask) 325 | 326 | return mu, logw, x_mask 327 | -------------------------------------------------------------------------------- /diffusion/model/nets/DiT.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | import os 15 | import numpy as np 16 | from timm.models.layers import DropPath 17 | from timm.models.vision_transformer import Mlp 18 | 19 | import sys 20 | sys.path.insert(1, os.getcwd()) 21 | 22 | from diffusion.model.utils import auto_grad_checkpoint, to_2tuple, PatchEmbed 23 | from diffusion.model.nets.DiT_blocks import t2i_modulate, CaptionEmbedder, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer, modulate 24 | # from diffusion.utils.logger import get_root_logger 25 | 26 | 27 | class DiTBlock(nn.Module): 28 | """ 29 | A DiT block with adaptive layer norm (adaLN-single) conditioning. 
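    Each block applies adaLN-modulated windowed self-attention, cross-attention from the latent tokens to the packed caption condition y, and an adaLN-modulated MLP, with optional stochastic depth via drop_path.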
30 | """ 31 | 32 | def __init__(self, hidden_size, num_heads, cross_attn_dim=None, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs): 33 | super().__init__() 34 | self.hidden_size = hidden_size 35 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 36 | self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True, 37 | input_size=input_size if window_size == 0 else (window_size, window_size), 38 | use_rel_pos=use_rel_pos, **block_kwargs) 39 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, cross_attn_dim, **block_kwargs) 40 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 41 | # to be compatible with lower version pytorch 42 | approx_gelu = lambda: nn.GELU(approximate="tanh") 43 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) 44 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 45 | self.window_size = window_size 46 | self.adaLN_modulation = nn.Sequential( 47 | nn.SiLU(), 48 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 49 | ) 50 | 51 | def forward(self, x, y, t, mask=None, **kwargs): 52 | B, N, C = x.shape 53 | 54 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(6, dim=1) 55 | x = x + self.drop_path(gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) 56 | x = x + self.cross_attn(x, y, mask) 57 | x = x + self.drop_path(gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))) 58 | 59 | return x 60 | 61 | 62 | ############################################################################# 63 | # Core DiT Model # 64 | ################################################################################# 65 | class DiT(nn.Module): 66 | """ 67 | Diffusion model with a Transformer backbone. 
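    Here the condition y is a token sequence consumed through cross-attention (via CaptionEmbedder and MultiHeadCrossAttention), while class_labels is a pooled embedding added to the timestep embedding.

    A rough usage sketch, with shapes borrowed from the profiling block at the bottom of this file (illustrative only):
        model = DiT_L_2(input_size=(256, 16), in_channels=8, caption_channels=64, projection_class_embeddings_input_dim=512)
        out = model(x, timestep, y, mask, class_labels)  # x: (N, 8, 256, 16); y: (N, 1, 1024, 64); mask: (N, 1, 1, 1024); out: (N, 16, 256, 16) because pred_sigma=True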
68 | """ 69 | 70 | def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.0, pred_sigma=True, projection_class_embeddings_input_dim=512, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, caption_channels=4096, lewei_scale=1.0, config=None, **kwargs): 71 | if window_block_indexes is None: 72 | window_block_indexes = [] 73 | super().__init__() 74 | self.pred_sigma = pred_sigma 75 | self.in_channels = in_channels 76 | self.out_channels = in_channels * 2 if pred_sigma else in_channels 77 | self.patch_size = patch_size 78 | self.num_heads = num_heads 79 | self.lewei_scale = lewei_scale, 80 | 81 | if isinstance(input_size, int): 82 | input_size = to_2tuple(input_size) 83 | 84 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) 85 | self.t_embedder = TimestepEmbedder(hidden_size) 86 | self.c_embedder = nn.Linear(projection_class_embeddings_input_dim, hidden_size) 87 | num_patches = self.x_embedder.num_patches 88 | # Will use fixed sin-cos embedding: 89 | self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) 90 | 91 | approx_gelu = lambda: nn.GELU(approximate="tanh") 92 | self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu) 93 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule 94 | self.blocks = nn.ModuleList([ 95 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], 96 | input_size=(input_size[0] // patch_size, input_size[1] // patch_size), 97 | window_size=window_size if i in window_block_indexes else 0, 98 | use_rel_pos=use_rel_pos if i in window_block_indexes else False) 99 | for i in range(depth) 100 | ]) 101 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 102 | 103 | self.initialize_weights() 104 | 105 | def forward(self, x, timestep, y, mask=None, class_labels=None, data_info=None, **kwargs): 106 | """ 107 | Forward pass of DiT. 
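    The caption condition y is embedded by CaptionEmbedder and packed into one flat sequence (mask removes padding tokens) for the blocks' cross-attention, while the class embedding is added to the timestep embedding. Returns an (N, out_channels, H, W) tensor.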
108 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 109 | t: (N,) tensor of diffusion timesteps 110 | y: (N, 1, T, C) tensor of caption labels 111 | class_labels: (N, C') tensor of class labels 112 | """ 113 | x = x.to(self.dtype) 114 | timestep = timestep.to(self.dtype) 115 | y = y.to(self.dtype) 116 | pos_embed = self.pos_embed.to(self.dtype) 117 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size 118 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 119 | t = self.t_embedder(timestep.to(x.dtype)) # (N, D) 120 | class_embed = self.c_embedder(class_labels) 121 | t0 = t + class_embed 122 | y = self.y_embedder(y) # (N, 1, L, D) 123 | if mask is not None: 124 | if mask.shape[0] != y.shape[0]: 125 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1) 126 | mask = mask.squeeze(1).squeeze(1) 127 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) 128 | y_lens = mask.sum(dim=1).tolist() 129 | else: 130 | y_lens = [y.shape[2]] * y.shape[0] 131 | y = y.squeeze(1).view(1, -1, x.shape[-1]) 132 | for block in self.blocks: 133 | x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint 134 | x = self.final_layer(x, t0) # (N, T, patch_size ** 2 * out_channels) 135 | x = self.unpatchify(x) # (N, out_channels, H, W) 136 | return x 137 | 138 | def forward_with_dpmsolver(self, x, timestep, y, mask=None, class_labels=None, **kwargs): 139 | """ 140 | dpm solver donnot need variance prediction 141 | """ 142 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 143 | model_out = self.forward(x, timestep, y, mask, class_labels) 144 | return model_out.chunk(2, dim=1)[0] 145 | 146 | def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, class_labels=None, guidance="single", **kwargs): 147 | """ 148 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. 
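    Only the first 1/num_chunk slice of the batch carries unique content; it is repeated num_chunk times and the chunked noise predictions are recombined as eps[1] + cfg_scale * (eps[0] - eps[1]) for "single" guidance, or as eps[0] + cfg_scale[0] * (eps[1] - eps[3]) + cfg_scale[1] * (eps[2] - eps[3]) for "dual" guidance; the chunk ordering is assumed to match the conditioning order prepared by the sampling pipeline.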
149 | """ 150 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 151 | if guidance == "dual": 152 | num_chunk = 4 153 | elif guidance == "single": 154 | num_chunk = 2 155 | chunk = x[: len(x) // num_chunk] 156 | combined = torch.cat([chunk] * num_chunk, dim=0) 157 | model_out = self.forward(combined, timestep, y, mask, class_labels, kwargs) 158 | model_out = model_out['x'] if isinstance(model_out, dict) else model_out 159 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 160 | eps = eps.chunk(num_chunk, dim=0) 161 | if guidance == "dual": 162 | pred_eps = eps[0] + cfg_scale[0] * (eps[1] - eps[3]) + cfg_scale[1] * (eps[2] - eps[3]) 163 | elif guidance == "single": 164 | pred_eps = eps[1] + cfg_scale * (eps[0] - eps[1]) 165 | eps = torch.cat([pred_eps] * num_chunk, dim=0) 166 | return torch.cat([eps, rest], dim=1) 167 | 168 | def unpatchify(self, x): 169 | """ 170 | x: (N, T, patch_size**2 * C) 171 | imgs: (N, H, W, C) 172 | """ 173 | c = self.out_channels 174 | p1, p2 = self.x_embedder.patch_size 175 | h, w = self.x_embedder.grid_size 176 | assert h * w == x.shape[1] 177 | 178 | x = x.reshape(shape=(x.shape[0], h, w, p1, p2, c)) 179 | x = torch.einsum('nhwpqc->nchpwq', x) 180 | return x.reshape(shape=(x.shape[0], c, h * p1, w * p2)) 181 | 182 | def initialize_weights(self): 183 | # Initialize transformer layers: 184 | def _basic_init(module): 185 | if isinstance(module, nn.Linear): 186 | torch.nn.init.xavier_uniform_(module.weight) 187 | if module.bias is not None: 188 | nn.init.constant_(module.bias, 0) 189 | 190 | self.apply(_basic_init) 191 | 192 | # Initialize (and freeze) pos_embed by sin-cos embedding: 193 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale) 194 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 195 | 196 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 197 | w = self.x_embedder.proj.weight.data 198 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 199 | nn.init.constant_(self.x_embedder.proj.bias, 0) 200 | 201 | # Initialize timestep embedding MLP: 202 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 203 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 204 | 205 | # Initialize caption embedding MLP: 206 | nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) 207 | nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) 208 | 209 | # Zero-out adaLN modulation layers in DiT blocks: 210 | for block in self.blocks: 211 | nn.init.constant_(block.cross_attn.proj.weight, 0) 212 | nn.init.constant_(block.cross_attn.proj.bias, 0) 213 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 214 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 215 | 216 | # Zero-out output layers: 217 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 218 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 219 | nn.init.constant_(self.final_layer.linear.weight, 0) 220 | nn.init.constant_(self.final_layer.linear.bias, 0) 221 | 222 | @property 223 | def dtype(self): 224 | return next(self.parameters()).dtype 225 | 226 | 227 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0): 228 | """ 229 | grid_size: int of the grid height and width 230 | return: 231 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 232 | """ 233 | if isinstance(grid_size, int): 234 | 
grid_size = to_2tuple(grid_size) 235 | grid_h = np.arange(grid_size[0], dtype=np.float32) / lewei_scale 236 | grid_w = np.arange(grid_size[1], dtype=np.float32) / lewei_scale 237 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 238 | grid = np.stack(grid, axis=0) 239 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) 240 | 241 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 242 | if cls_token and extra_tokens > 0: 243 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 244 | return pos_embed 245 | 246 | 247 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 248 | assert embed_dim % 2 == 0 249 | 250 | # use half of dimensions to encode grid_h 251 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 252 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 253 | 254 | return np.concatenate([emb_h, emb_w], axis=1) 255 | 256 | 257 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 258 | """ 259 | embed_dim: output dimension for each position 260 | pos: a list of positions to be encoded: size (M,) 261 | out: (M, D) 262 | """ 263 | assert embed_dim % 2 == 0 264 | omega = np.arange(embed_dim // 2, dtype=np.float64) 265 | omega /= embed_dim / 2. 266 | omega = 1. / 10000 ** omega # (D/2,) 267 | 268 | pos = pos.reshape(-1) # (M,) 269 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 270 | 271 | emb_sin = np.sin(out) # (M, D/2) 272 | emb_cos = np.cos(out) # (M, D/2) 273 | 274 | return np.concatenate([emb_sin, emb_cos], axis=1) 275 | 276 | 277 | ################################################################################# 278 | # DiT Configs # 279 | ################################################################################# 280 | def DiT_XL_2(**kwargs): 281 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 282 | 283 | def DiT_L_2(**kwargs): 284 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 285 | 286 | def DiT_B_2(**kwargs): 287 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 288 | 289 | def DiT_S_2(**kwargs): 290 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 291 | 292 | 293 | DiT_cross_models = { 294 | 'DiT_XL/2': DiT_XL_2, 295 | 'DiT_L/2': DiT_L_2, 296 | 'DiT_B/2': DiT_B_2, 297 | 'DiT_S/2': DiT_S_2, 298 | } 299 | 300 | if __name__ == '__main__': 301 | 302 | from thop import profile, clever_format 303 | 304 | device = "cuda:0" 305 | 306 | model = DiT_cross_models['DiT_L/2']( 307 | input_size = (256, 16), 308 | in_channels = 8, 309 | projection_class_embeddings_input_dim=512, 310 | caption_channels = 64, 311 | concat_y = False, 312 | cross_class = False, 313 | ).to(device) 314 | 315 | model.eval() 316 | model.requires_grad_(False) 317 | x = torch.randn(1, 8, 256, 16).to(device) 318 | timestep = torch.Tensor([999]).to(device) 319 | y = torch.randn(1, 1, 1024, 64).to(device) 320 | mask = torch.ones(1, 1, 1, 1024).to(device) 321 | class_labels = torch.randn(1, 512).to(device) 322 | 323 | macs, params = profile(model, inputs=(x, timestep, y, mask, class_labels)) 324 | flops = macs * 2 325 | macs, params, flops = clever_format([macs, params, flops], "%.4f") 326 | 327 | print("Params:", params) 328 | print("MACs:", macs) 329 | print("FLOPs:", flops) -------------------------------------------------------------------------------- /diffusion/model/nets/DiT_blocks.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | import math 12 | import torch 13 | import torch.nn as nn 14 | from timm.models.vision_transformer import Mlp, Attention as Attention_ 15 | from einops import rearrange, repeat 16 | import xformers.ops 17 | 18 | from diffusion.model.utils import add_decomposed_rel_pos 19 | 20 | 21 | def modulate(x, shift, scale): 22 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 23 | 24 | 25 | def t2i_modulate(x, shift, scale): 26 | return x * (1 + scale) + shift 27 | 28 | 29 | class MultiHeadCrossAttention(nn.Module): 30 | def __init__(self, d_model, num_heads, cross_attn_dim=None, attn_drop=0., proj_drop=0., **block_kwargs): 31 | super(MultiHeadCrossAttention, self).__init__() 32 | assert d_model % num_heads == 0, "d_model must be divisible by num_heads" 33 | 34 | self.d_model = d_model 35 | self.num_heads = num_heads 36 | self.head_dim = d_model // num_heads 37 | if cross_attn_dim is None: 38 | cross_attn_dim = d_model 39 | self.cross_attn_dim = cross_attn_dim 40 | 41 | self.q_linear = nn.Linear(d_model, d_model) 42 | self.kv_linear = nn.Linear(cross_attn_dim, d_model*2) 43 | self.attn_drop = nn.Dropout(attn_drop) 44 | self.proj = nn.Linear(d_model, d_model) 45 | self.proj_drop = nn.Dropout(proj_drop) 46 | 47 | def forward(self, x, cond, mask=None): 48 | # query: img tokens; key/value: condition; mask: if padding tokens 49 | B, N, C = x.shape 50 | 51 | q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) 52 | kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) 53 | k, v = kv.unbind(2) 54 | attn_bias = None 55 | if mask is not None: 56 | attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask) 57 | x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) 58 | x = x.view(B, -1, C) 59 | x = self.proj(x) 60 | x = self.proj_drop(x) 61 | 62 | # q = self.q_linear(x).reshape(B, -1, self.num_heads, self.head_dim) 63 | # kv = self.kv_linear(cond).reshape(B, -1, 2, self.num_heads, self.head_dim) 64 | # k, v = kv.unbind(2) 65 | # attn_bias = None 66 | # if mask is not None: 67 | # attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) 68 | # attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf')) 69 | # x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) 70 | # x = x.contiguous().reshape(B, -1, C) 71 | # x = self.proj(x) 72 | # x = self.proj_drop(x) 73 | 74 | return x 75 | 76 | 77 | class WindowAttention(Attention_): 78 | """Multi-head Attention block with relative position embeddings.""" 79 | 80 | def __init__( 81 | self, 82 | dim, 83 | num_heads=8, 84 | qkv_bias=True, 85 | use_rel_pos=False, 86 | rel_pos_zero_init=True, 87 | input_size=None, 88 | **block_kwargs, 89 | ): 90 | """ 91 | Args: 92 | dim (int): Number of input channels. 93 | num_heads (int): Number of attention heads. 
94 | qkv_bias (bool: If True, add a learnable bias to query, key, value. 95 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 96 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 97 | input_size (int or None): Input resolution for calculating the relative positional 98 | parameter size. 99 | """ 100 | super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs) 101 | 102 | self.use_rel_pos = use_rel_pos 103 | if self.use_rel_pos: 104 | # initialize relative positional embeddings 105 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim)) 106 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim)) 107 | 108 | if not rel_pos_zero_init: 109 | nn.init.trunc_normal_(self.rel_pos_h, std=0.02) 110 | nn.init.trunc_normal_(self.rel_pos_w, std=0.02) 111 | 112 | def forward(self, x, mask=None): 113 | B, N, C = x.shape 114 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 115 | q, k, v = qkv.unbind(2) 116 | if use_fp32_attention := getattr(self, 'fp32_attention', False): 117 | q, k, v = q.float(), k.float(), v.float() 118 | 119 | attn_bias = None 120 | if mask is not None: 121 | attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device) 122 | attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf')) 123 | x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias) 124 | 125 | x = x.view(B, N, C) 126 | x = self.proj(x) 127 | x = self.proj_drop(x) 128 | return x 129 | 130 | 131 | ################################################################################# 132 | # AMP attention with fp32 softmax to fix loss NaN problem during training # 133 | ################################################################################# 134 | class Attention(Attention_): 135 | def forward(self, x): 136 | B, N, C = x.shape 137 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 138 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 139 | use_fp32_attention = getattr(self, 'fp32_attention', False) 140 | if use_fp32_attention: 141 | q, k = q.float(), k.float() 142 | with torch.cuda.amp.autocast(enabled=not use_fp32_attention): 143 | attn = (q @ k.transpose(-2, -1)) * self.scale 144 | attn = attn.softmax(dim=-1) 145 | 146 | attn = self.attn_drop(attn) 147 | 148 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 149 | x = self.proj(x) 150 | x = self.proj_drop(x) 151 | return x 152 | 153 | 154 | class FinalLayer(nn.Module): 155 | """ 156 | The final layer of DiT. 157 | """ 158 | 159 | def __init__(self, hidden_size, patch_size, out_channels): 160 | super().__init__() 161 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 162 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 163 | self.adaLN_modulation = nn.Sequential( 164 | nn.SiLU(), 165 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 166 | ) 167 | 168 | def forward(self, x, c): 169 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 170 | x = modulate(self.norm_final(x), shift, scale) 171 | x = self.linear(x) 172 | return x 173 | 174 | 175 | class T2IFinalLayer(nn.Module): 176 | """ 177 | The final layer of DiT. 
178 | """ 179 | 180 | def __init__(self, hidden_size, patch_size, out_channels): 181 | super().__init__() 182 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 183 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) 184 | self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) 185 | self.out_channels = out_channels 186 | 187 | def forward(self, x, t): 188 | shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) 189 | x = t2i_modulate(self.norm_final(x), shift, scale) 190 | x = self.linear(x) 191 | return x 192 | 193 | 194 | class MaskFinalLayer(nn.Module): 195 | """ 196 | The final layer of DiT. 197 | """ 198 | 199 | def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): 200 | super().__init__() 201 | self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) 202 | self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) 203 | self.adaLN_modulation = nn.Sequential( 204 | nn.SiLU(), 205 | nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) 206 | ) 207 | def forward(self, x, t): 208 | shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) 209 | x = modulate(self.norm_final(x), shift, scale) 210 | x = self.linear(x) 211 | return x 212 | 213 | 214 | class DecoderLayer(nn.Module): 215 | """ 216 | The final layer of DiT. 217 | """ 218 | 219 | def __init__(self, hidden_size, decoder_hidden_size): 220 | super().__init__() 221 | self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 222 | self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True) 223 | self.adaLN_modulation = nn.Sequential( 224 | nn.SiLU(), 225 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 226 | ) 227 | def forward(self, x, t): 228 | shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) 229 | x = modulate(self.norm_decoder(x), shift, scale) 230 | x = self.linear(x) 231 | return x 232 | 233 | 234 | ################################################################################# 235 | # Embedding Layers for Timesteps and Class Labels # 236 | ################################################################################# 237 | class TimestepEmbedder(nn.Module): 238 | """ 239 | Embeds scalar timesteps into vector representations. 240 | """ 241 | 242 | def __init__(self, hidden_size, frequency_embedding_size=256): 243 | super().__init__() 244 | self.mlp = nn.Sequential( 245 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 246 | nn.SiLU(), 247 | nn.Linear(hidden_size, hidden_size, bias=True), 248 | ) 249 | self.frequency_embedding_size = frequency_embedding_size 250 | 251 | @staticmethod 252 | def timestep_embedding(t, dim, max_period=10000): 253 | """ 254 | Create sinusoidal timestep embeddings. 255 | :param t: a 1-D Tensor of N indices, one per batch element. 256 | These may be fractional. 257 | :param dim: the dimension of the output. 258 | :param max_period: controls the minimum frequency of the embeddings. 259 | :return: an (N, D) Tensor of positional embeddings. 
260 |         """
261 |         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
262 |         half = dim // 2
263 |         freqs = torch.exp(
264 |             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
265 |         args = t[:, None].float() * freqs[None]
266 |         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
267 |         if dim % 2:
268 |             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
269 |         return embedding
270 | 
271 |     def forward(self, t):
272 |         t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype)
273 |         return self.mlp(t_freq)
274 | 
275 |     @property
276 |     def dtype(self):
277 |         # Return the data type of the model parameters
278 |         return next(self.parameters()).dtype
279 | 
280 | 
281 | class SizeEmbedder(TimestepEmbedder):
282 |     """
283 |     Embeds scalar size information (e.g., resolution) into vector representations.
284 |     """
285 | 
286 |     def __init__(self, hidden_size, frequency_embedding_size=256):
287 |         super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
288 |         self.mlp = nn.Sequential(
289 |             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
290 |             nn.SiLU(),
291 |             nn.Linear(hidden_size, hidden_size, bias=True),
292 |         )
293 |         self.frequency_embedding_size = frequency_embedding_size
294 |         self.outdim = hidden_size
295 | 
296 |     def forward(self, s, bs):
297 |         if s.ndim == 1:
298 |             s = s[:, None]
299 |         assert s.ndim == 2
300 |         if s.shape[0] != bs:
301 |             s = s.repeat(bs//s.shape[0], 1)
302 |             assert s.shape[0] == bs
303 |         b, dims = s.shape[0], s.shape[1]
304 |         s = rearrange(s, "b d -> (b d)")
305 |         s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
306 |         s_emb = self.mlp(s_freq)
307 |         s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
308 |         return s_emb
309 | 
310 |     @property
311 |     def dtype(self):
312 |         # Return the data type of the model parameters
313 |         return next(self.parameters()).dtype
314 | 
315 | 
316 | class LabelEmbedder(nn.Module):
317 |     """
318 |     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
319 |     """
320 | 
321 |     def __init__(self, num_classes, hidden_size, dropout_prob):
322 |         super().__init__()
323 |         use_cfg_embedding = dropout_prob > 0
324 |         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
325 |         self.num_classes = num_classes
326 |         self.dropout_prob = dropout_prob
327 | 
328 |     def token_drop(self, labels, force_drop_ids=None):
329 |         """
330 |         Drops labels to enable classifier-free guidance.
331 |         """
332 |         if force_drop_ids is None:
333 |             drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
334 |         else:
335 |             drop_ids = force_drop_ids == 1
336 |         labels = torch.where(drop_ids, self.num_classes, labels)
337 |         return labels
338 | 
339 |     def forward(self, labels, train, force_drop_ids=None):
340 |         use_dropout = self.dropout_prob > 0
341 |         if (train and use_dropout) or (force_drop_ids is not None):
342 |             labels = self.token_drop(labels, force_drop_ids)
343 |         return self.embedding_table(labels)
344 | 
345 | 
346 | class CaptionEmbedder2(nn.Module):
347 |     """
348 |     Projects caption embeddings into vector representations.
349 | """ 350 | 351 | def __init__(self, in_channels, hidden_size, act_layer=nn.GELU(approximate='tanh')): 352 | super().__init__() 353 | self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0) 354 | 355 | 356 | def forward(self, caption): 357 | caption = self.y_proj(caption) 358 | return caption 359 | 360 | 361 | class CaptionEmbedder(nn.Module): 362 | """ 363 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 364 | """ 365 | 366 | def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh')): 367 | super().__init__() 368 | self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0) 369 | self.register_buffer("y_embedding", nn.Parameter(torch.randn(1, in_channels) / in_channels ** 0.5)) 370 | self.uncond_prob = uncond_prob 371 | 372 | def token_drop(self, caption, force_drop_ids=None): 373 | """ 374 | Drops labels to enable classifier-free guidance. 375 | """ 376 | if force_drop_ids is None: 377 | drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob 378 | else: 379 | drop_ids = force_drop_ids == 1 380 | null_y = self.y_embedding.repeat(caption.shape[1],1) 381 | caption = torch.where(drop_ids[:, None, None, None], null_y, caption) 382 | return caption 383 | 384 | def forward(self, caption, force_drop_ids=None): 385 | use_dropout = self.uncond_prob > 0 386 | if use_dropout or (force_drop_ids is not None): 387 | caption = self.token_drop(caption, force_drop_ids) 388 | caption = self.y_proj(caption) 389 | return caption 390 | 391 | 392 | class CaptionEmbedderDoubleBr(nn.Module): 393 | """ 394 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 395 | """ 396 | 397 | def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120): 398 | super().__init__() 399 | self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0) 400 | self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5) 401 | self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5) 402 | self.uncond_prob = uncond_prob 403 | 404 | def token_drop(self, global_caption, caption, force_drop_ids=None): 405 | """ 406 | Drops labels to enable classifier-free guidance. 
407 | """ 408 | if force_drop_ids is None: 409 | drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob 410 | else: 411 | drop_ids = force_drop_ids == 1 412 | global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption) 413 | caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) 414 | return global_caption, caption 415 | 416 | def forward(self, caption, train, force_drop_ids=None): 417 | assert caption.shape[2: ] == self.y_embedding.shape 418 | global_caption = caption.mean(dim=2).squeeze() 419 | use_dropout = self.uncond_prob > 0 420 | if (train and use_dropout) or (force_drop_ids is not None): 421 | global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids) 422 | y_embed = self.proj(global_caption) 423 | return y_embed, caption -------------------------------------------------------------------------------- /diffusion/model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch.nn as nn 4 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential 5 | import torch.nn.functional as F 6 | import torch 7 | import torch.distributed as dist 8 | import re 9 | import math 10 | from collections.abc import Iterable 11 | from itertools import repeat 12 | from torchvision import transforms as T 13 | import random 14 | from PIL import Image 15 | from torch import _assert 16 | 17 | 18 | def _ntuple(n): 19 | def parse(x): 20 | if isinstance(x, Iterable) and not isinstance(x, str): 21 | return x 22 | return tuple(repeat(x, n)) 23 | return parse 24 | 25 | 26 | to_1tuple = _ntuple(1) 27 | to_2tuple = _ntuple(2) 28 | 29 | def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1): 30 | assert isinstance(model, nn.Module) 31 | 32 | def set_attr(module): 33 | module.grad_checkpointing = True 34 | module.fp32_attention = use_fp32_attention 35 | module.grad_checkpointing_step = gc_step 36 | model.apply(set_attr) 37 | 38 | 39 | def auto_grad_checkpoint(module, *args, **kwargs): 40 | if getattr(module, 'grad_checkpointing', False): 41 | if not isinstance(module, Iterable): 42 | return checkpoint(module, *args, **kwargs) 43 | gc_step = module[0].grad_checkpointing_step 44 | return checkpoint_sequential(module, gc_step, *args, **kwargs) 45 | return module(*args, **kwargs) 46 | 47 | 48 | def checkpoint_sequential(functions, step, input, *args, **kwargs): 49 | 50 | # Hack for keyword-only parameter in a python 2.7-compliant way 51 | preserve = kwargs.pop('preserve_rng_state', True) 52 | if kwargs: 53 | raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs)) 54 | 55 | def run_function(start, end, functions): 56 | def forward(input): 57 | for j in range(start, end + 1): 58 | input = functions[j](input, *args) 59 | return input 60 | return forward 61 | 62 | if isinstance(functions, torch.nn.Sequential): 63 | functions = list(functions.children()) 64 | 65 | # the last chunk has to be non-volatile 66 | end = -1 67 | segment = len(functions) // step 68 | for start in range(0, step * (segment - 1), step): 69 | end = start + step - 1 70 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve) 71 | return run_function(end + 1, len(functions) - 1, functions)(input) 72 | 73 | 74 | def window_partition(x, window_size): 75 | """ 76 | Partition into non-overlapping windows with padding if needed. 77 | Args: 78 | x (tensor): input tokens with [B, H, W, C]. 79 | window_size (int): window size. 
80 | 81 | Returns: 82 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 83 | (Hp, Wp): padded height and width before partition 84 | """ 85 | B, H, W, C = x.shape 86 | 87 | pad_h = (window_size - H % window_size) % window_size 88 | pad_w = (window_size - W % window_size) % window_size 89 | if pad_h > 0 or pad_w > 0: 90 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 91 | Hp, Wp = H + pad_h, W + pad_w 92 | 93 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 94 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 95 | return windows, (Hp, Wp) 96 | 97 | 98 | def window_unpartition(windows, window_size, pad_hw, hw): 99 | """ 100 | Window unpartition into original sequences and removing padding. 101 | Args: 102 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 103 | window_size (int): window size. 104 | pad_hw (Tuple): padded height and width (Hp, Wp). 105 | hw (Tuple): original height and width (H, W) before padding. 106 | 107 | Returns: 108 | x: unpartitioned sequences with [B, H, W, C]. 109 | """ 110 | Hp, Wp = pad_hw 111 | H, W = hw 112 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 113 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 114 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 115 | 116 | if Hp > H or Wp > W: 117 | x = x[:, :H, :W, :].contiguous() 118 | return x 119 | 120 | 121 | def get_rel_pos(q_size, k_size, rel_pos): 122 | """ 123 | Get relative positional embeddings according to the relative positions of 124 | query and key sizes. 125 | Args: 126 | q_size (int): size of query q. 127 | k_size (int): size of key k. 128 | rel_pos (Tensor): relative position embeddings (L, C). 129 | 130 | Returns: 131 | Extracted positional embeddings according to relative positions. 132 | """ 133 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 134 | # Interpolate rel pos if needed. 135 | if rel_pos.shape[0] != max_rel_dist: 136 | # Interpolate rel pos. 137 | rel_pos_resized = F.interpolate( 138 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 139 | size=max_rel_dist, 140 | mode="linear", 141 | ) 142 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 143 | else: 144 | rel_pos_resized = rel_pos 145 | 146 | # Scale the coords with short length if shapes for q and k are different. 147 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 148 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 149 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 150 | 151 | return rel_pos_resized[relative_coords.long()] 152 | 153 | 154 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): 155 | """ 156 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 157 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 158 | Args: 159 | attn (Tensor): attention map. 160 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 161 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 162 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 163 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 164 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 
165 | 
166 |     Returns:
167 |         attn (Tensor): attention map with added relative positional embeddings.
168 |     """
169 |     q_h, q_w = q_size
170 |     k_h, k_w = k_size
171 |     Rh = get_rel_pos(q_h, k_h, rel_pos_h)
172 |     Rw = get_rel_pos(q_w, k_w, rel_pos_w)
173 | 
174 |     B, _, dim = q.shape
175 |     r_q = q.reshape(B, q_h, q_w, dim)
176 |     rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
177 |     rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
178 | 
179 |     attn = (
180 |         attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
181 |     ).view(B, q_h * q_w, k_h * k_w)
182 | 
183 |     return attn
184 | 
185 | def mean_flat(tensor):
186 |     return tensor.mean(dim=list(range(1, tensor.ndim)))
187 | 
188 | 
189 | #################################################################################
190 | #                          Token Masking and Unmasking                          #
191 | #################################################################################
192 | def get_mask(batch, length, mask_ratio, device, mask_type=None, data_info=None, extra_len=0):
193 |     """
194 |     Get the binary mask for the input sequence.
195 |     Args:
196 |         - batch: batch size
197 |         - length: sequence length
198 |         - mask_ratio: ratio of tokens to mask
199 |         - data_info: dictionary with info for reconstruction
200 |     return:
201 |         mask_dict with following keys:
202 |         - mask: binary mask, 0 is keep, 1 is remove
203 |         - ids_keep: indices of tokens to keep
204 |         - ids_restore: indices to restore the original order
205 |     """
206 |     assert mask_type in ['random', 'fft', 'laplacian', 'group']
207 |     mask = torch.ones([batch, length], device=device)
208 |     len_keep = int(length * (1 - mask_ratio)) - extra_len
209 | 
210 |     if mask_type in ['random', 'group']:
211 |         noise = torch.rand(batch, length, device=device)  # noise in [0, 1]
212 |         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
213 |         ids_restore = torch.argsort(ids_shuffle, dim=1)
214 |         # keep the first subset
215 |         ids_keep = ids_shuffle[:, :len_keep]
216 |         ids_removed = ids_shuffle[:, len_keep:]
217 | 
218 |     elif mask_type in ['fft', 'laplacian']:
219 |         if 'strength' in data_info:
220 |             strength = data_info['strength']
221 | 
222 |         else:
223 |             N = data_info['N'][0]
224 |             img = data_info['ori_img']
225 |             # get the spatial size of the original image
226 |             _, C, H, W = img.shape
227 |             if mask_type == 'fft':
228 |                 # reshape the image into patches: (3, H/N, N, W/N, N)
229 |                 reshaped_image = img.reshape((batch, -1, H // N, N, W // N, N))
230 |                 fft_image = torch.fft.fftn(reshaped_image, dim=(3, 5))
231 |                 # take the magnitude and sum it to get the per-patch frequency strength
232 |                 strength = torch.sum(torch.abs(fft_image), dim=(1, 3, 5)).reshape((batch, -1,))
233 |             elif mask_type == 'laplacian':
234 |                 laplacian_kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32).reshape(1, 1, 3, 3)
235 |                 laplacian_kernel = laplacian_kernel.repeat(C, 1, 1, 1)
236 |                 # reshape the image into patches: (3, H/N, N, W/N, N)
237 |                 reshaped_image = img.reshape(-1, C, H // N, N, W // N, N).permute(0, 2, 4, 1, 3, 5).reshape(-1, C, N, N)
238 |                 laplacian_response = F.conv2d(reshaped_image, laplacian_kernel, padding=1, groups=C)
239 |                 strength = laplacian_response.sum(dim=[1, 2, 3]).reshape((batch, -1,))
240 | 
241 |         # normalize the strength values, then sample token indices with torch.multinomial
242 |         probabilities = strength / (strength.max(dim=1)[0][:, None]+1e-5)
243 |         ids_shuffle = torch.multinomial(probabilities.clip(1e-5, 1), length, replacement=False)
244 |         ids_keep = ids_shuffle[:, :len_keep]
245 |         ids_restore = torch.argsort(ids_shuffle, dim=1)
246 |         ids_removed = ids_shuffle[:, len_keep:]
247 | 
248 |     mask[:, :len_keep] = 0
249 |     mask = 
torch.gather(mask, dim=1, index=ids_restore) 250 | 251 | return {'mask': mask, 252 | 'ids_keep': ids_keep, 253 | 'ids_restore': ids_restore, 254 | 'ids_removed': ids_removed} 255 | 256 | 257 | def mask_out_token(x, ids_keep, ids_removed=None): 258 | """ 259 | Mask out the tokens specified by ids_keep. 260 | Args: 261 | - x: input sequence, [N, L, D] 262 | - ids_keep: indices of tokens to keep 263 | return: 264 | - x_masked: masked sequence 265 | """ 266 | N, L, D = x.shape # batch, length, dim 267 | x_remain = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 268 | if ids_removed is not None: 269 | x_masked = torch.gather(x, dim=1, index=ids_removed.unsqueeze(-1).repeat(1, 1, D)) 270 | return x_remain, x_masked 271 | else: 272 | return x_remain 273 | 274 | 275 | def mask_tokens(x, mask_ratio): 276 | """ 277 | Perform per-sample random masking by per-sample shuffling. 278 | Per-sample shuffling is done by argsort random noise. 279 | x: [N, L, D], sequence 280 | """ 281 | N, L, D = x.shape # batch, length, dim 282 | len_keep = int(L * (1 - mask_ratio)) 283 | 284 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 285 | 286 | # sort noise for each sample 287 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 288 | ids_restore = torch.argsort(ids_shuffle, dim=1) 289 | 290 | # keep the first subset 291 | ids_keep = ids_shuffle[:, :len_keep] 292 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 293 | 294 | # generate the binary mask: 0 is keep, 1 is remove 295 | mask = torch.ones([N, L], device=x.device) 296 | mask[:, :len_keep] = 0 297 | mask = torch.gather(mask, dim=1, index=ids_restore) 298 | 299 | return x_masked, mask, ids_restore 300 | 301 | 302 | def unmask_tokens(x, ids_restore, mask_token): 303 | # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D] 304 | mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) 305 | x = torch.cat([x, mask_tokens], dim=1) 306 | x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 307 | return x 308 | 309 | 310 | # Parse 'None' to None and others to float value 311 | def parse_float_none(s): 312 | assert isinstance(s, str) 313 | return None if s == 'None' else float(s) 314 | 315 | 316 | #---------------------------------------------------------------------------- 317 | # Parse a comma separated list of numbers or ranges and return a list of ints. 318 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10] 319 | 320 | def parse_int_list(s): 321 | if isinstance(s, list): return s 322 | ranges = [] 323 | range_re = re.compile(r'^(\d+)-(\d+)$') 324 | for p in s.split(','): 325 | if m := range_re.match(p): 326 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 327 | else: 328 | ranges.append(int(p)) 329 | return ranges 330 | 331 | 332 | def init_processes(fn, args): 333 | """ Initialize the distributed environment. """ 334 | os.environ['MASTER_ADDR'] = args.master_address 335 | os.environ['MASTER_PORT'] = str(random.randint(2000, 6000)) 336 | print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}') 337 | print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}') 338 | torch.cuda.set_device(args.local_rank) 339 | dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size) 340 | fn(args) 341 | if args.global_size > 1: 342 | cleanup() 343 | 344 | 345 | def mprint(*args, **kwargs): 346 | """ 347 | Print only from rank 0. 
348 | """ 349 | if dist.get_rank() == 0: 350 | print(*args, **kwargs) 351 | 352 | 353 | def cleanup(): 354 | """ 355 | End DDP training. 356 | """ 357 | dist.barrier() 358 | mprint("Done!") 359 | dist.barrier() 360 | dist.destroy_process_group() 361 | 362 | 363 | #---------------------------------------------------------------------------- 364 | # logging info. 365 | class Logger(object): 366 | """ 367 | Redirect stderr to stdout, optionally print stdout to a file, 368 | and optionally force flushing on both stdout and the file. 369 | """ 370 | 371 | def __init__(self, file_name=None, file_mode="w", should_flush=True): 372 | self.file = None 373 | 374 | if file_name is not None: 375 | self.file = open(file_name, file_mode) 376 | 377 | self.should_flush = should_flush 378 | self.stdout = sys.stdout 379 | self.stderr = sys.stderr 380 | 381 | sys.stdout = self 382 | sys.stderr = self 383 | 384 | def __enter__(self): 385 | return self 386 | 387 | def __exit__(self, exc_type, exc_value, traceback): 388 | self.close() 389 | 390 | def write(self, text): 391 | """Write text to stdout (and a file) and optionally flush.""" 392 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 393 | return 394 | 395 | if self.file is not None: 396 | self.file.write(text) 397 | 398 | self.stdout.write(text) 399 | 400 | if self.should_flush: 401 | self.flush() 402 | 403 | def flush(self): 404 | """Flush written text to both stdout and a file, if open.""" 405 | if self.file is not None: 406 | self.file.flush() 407 | 408 | self.stdout.flush() 409 | 410 | def close(self): 411 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 412 | self.flush() 413 | 414 | # if using multiple loggers, prevent closing in wrong order 415 | if sys.stdout is self: 416 | sys.stdout = self.stdout 417 | if sys.stderr is self: 418 | sys.stderr = self.stderr 419 | 420 | if self.file is not None: 421 | self.file.close() 422 | 423 | 424 | class StackedRandomGenerator: 425 | def __init__(self, device, seeds): 426 | super().__init__() 427 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds] 428 | 429 | def randn(self, size, **kwargs): 430 | assert size[0] == len(self.generators) 431 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]) 432 | 433 | def randn_like(self, input): 434 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device) 435 | 436 | def randint(self, *args, size, **kwargs): 437 | assert size[0] == len(self.generators) 438 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators]) 439 | 440 | 441 | def prepare_prompt_ar(prompt, ratios, device='cpu', show=True): 442 | # get aspect_ratio or ar 443 | aspect_ratios = re.findall(r"--aspect_ratio\s+(\d+:\d+)", prompt) 444 | ars = re.findall(r"--ar\s+(\d+:\d+)", prompt) 445 | custom_hw = re.findall(r"--hw\s+(\d+:\d+)", prompt) 446 | if show: 447 | print("aspect_ratios:", aspect_ratios, "ars:", ars, "hws:", custom_hw) 448 | prompt_clean = prompt.split("--aspect_ratio")[0].split("--ar")[0].split("--hw")[0] 449 | if len(aspect_ratios) + len(ars) + len(custom_hw) == 0 and show: 450 | print( "Wrong prompt format. Set to default ar: 1. 
change your prompt into format '--ar h:w or --hw h:w' for correct generating") 451 | if len(aspect_ratios) != 0: 452 | ar = float(aspect_ratios[0].split(':')[0]) / float(aspect_ratios[0].split(':')[1]) 453 | elif len(ars) != 0: 454 | ar = float(ars[0].split(':')[0]) / float(ars[0].split(':')[1]) 455 | else: 456 | ar = 1. 457 | closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) 458 | if len(custom_hw) != 0: 459 | custom_hw = [float(custom_hw[0].split(':')[0]), float(custom_hw[0].split(':')[1])] 460 | else: 461 | custom_hw = ratios[closest_ratio] 462 | default_hw = ratios[closest_ratio] 463 | prompt_show = f'prompt: {prompt_clean.strip()}\nSize: --ar {closest_ratio}, --bin hw {ratios[closest_ratio]}, --custom hw {custom_hw}' 464 | return prompt_clean, prompt_show, torch.tensor(default_hw, device=device)[None], torch.tensor([float(closest_ratio)], device=device)[None], torch.tensor(custom_hw, device=device)[None] 465 | 466 | 467 | def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int): 468 | orig_hw = torch.tensor([samples.shape[2], samples.shape[3]], dtype=torch.int) 469 | custom_hw = torch.tensor([int(new_height), int(new_width)], dtype=torch.int) 470 | 471 | if (orig_hw != custom_hw).all(): 472 | ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1]) 473 | resized_width = int(orig_hw[1] * ratio) 474 | resized_height = int(orig_hw[0] * ratio) 475 | 476 | transform = T.Compose([ 477 | T.Resize((resized_height, resized_width)), 478 | T.CenterCrop(custom_hw.tolist()) 479 | ]) 480 | return transform(samples) 481 | else: 482 | return samples 483 | 484 | 485 | def resize_and_crop_img(img: Image, new_width, new_height): 486 | orig_width, orig_height = img.size 487 | 488 | ratio = max(new_width/orig_width, new_height/orig_height) 489 | resized_width = int(orig_width * ratio) 490 | resized_height = int(orig_height * ratio) 491 | 492 | img = img.resize((resized_width, resized_height), Image.LANCZOS) 493 | 494 | left = (resized_width - new_width)/2 495 | top = (resized_height - new_height)/2 496 | right = (resized_width + new_width)/2 497 | bottom = (resized_height + new_height)/2 498 | 499 | img = img.crop((left, top, right, bottom)) 500 | 501 | return img 502 | 503 | 504 | 505 | def mask_feature(emb, mask): 506 | if emb.shape[0] == 1: 507 | keep_index = mask.sum().item() 508 | return emb[:, :, :keep_index, :], keep_index 509 | else: 510 | masked_feature = emb * mask[:, None, :, None] 511 | return masked_feature, emb.shape[2] 512 | 513 | 514 | from enum import Enum 515 | class Format(str, Enum): 516 | NCHW = 'NCHW' 517 | NHWC = 'NHWC' 518 | NCL = 'NCL' 519 | NLC = 'NLC' 520 | 521 | class PatchEmbed(nn.Module): 522 | """ 2D Image to Patch Embedding 523 | """ 524 | output_fmt: Format 525 | dynamic_img_pad: torch.jit.Final[bool] 526 | 527 | def __init__( 528 | self, 529 | img_size=224, 530 | patch_size=16, 531 | in_chans=3, 532 | embed_dim=768, 533 | norm_layer=None, 534 | flatten=True, 535 | bias=True, 536 | ): 537 | super().__init__() 538 | if isinstance(img_size, int): 539 | img_size = to_2tuple(img_size) 540 | patch_size = to_2tuple(patch_size) 541 | self.img_size = img_size 542 | self.patch_size = patch_size 543 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 544 | self.num_patches = self.grid_size[0] * self.grid_size[1] 545 | self.flatten = flatten 546 | 547 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 548 | self.norm = norm_layer(embed_dim) if 
norm_layer else nn.Identity() 549 | 550 | def forward(self, x): 551 | B, C, H, W = x.shape 552 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 553 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 554 | x = self.proj(x) 555 | if self.flatten: 556 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 557 | x = self.norm(x) 558 | return x -------------------------------------------------------------------------------- /voicedit/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import torchaudio 5 | import numpy as np 6 | from tqdm.auto import tqdm 7 | from PIL import Image 8 | from copy import deepcopy 9 | import random 10 | 11 | from transformers import ( 12 | ClapModel, 13 | ClapProcessor, 14 | CLIPModel, 15 | CLIPProcessor, 16 | SpeechT5HifiGan, 17 | ) 18 | from diffusers import ( 19 | AutoencoderKL, 20 | DDIMScheduler, 21 | UNet2DModel, 22 | ) 23 | try: 24 | from diffusers.utils import randn_tensor 25 | except: 26 | from diffusers.utils.torch_utils import randn_tensor 27 | from huggingface_hub import hf_hub_download 28 | 29 | from .utils import sequence_mask, generate_path, fix_len_compatibility 30 | from espnet2.bin.spk_inference import Speech2Embedding 31 | from i2a_translator import DiffusionPrior 32 | 33 | from diffusion.model.nets.DiT2 import DiT_models 34 | from diffusion.model.nets.DiT import DiT_cross_models 35 | from diffusion import create_diffusion 36 | 37 | from text import text_to_sequence, cmudict 38 | from text.symbols import symbols 39 | from .modules import DiTWrapper 40 | from .utils import intersperse 41 | 42 | # helper function 43 | def exists(val): 44 | return val is not None 45 | 46 | 47 | class VoiceDiTPipeline(): 48 | def __init__( 49 | self, 50 | ckpt_path, 51 | v2a_ckpt_path = None, 52 | t2a_ckpt_path = None, 53 | device = None, 54 | cmudict_path='voicedit/cmu_dictionary', 55 | male_voice = None, 56 | female_voice = None, 57 | ): 58 | 59 | self.vae = AutoencoderKL.from_pretrained("cvssp/audioldm-m-full", subfolder="vae").eval() 60 | self.vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm-m-full", subfolder="vocoder").eval() 61 | self.clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").eval() 62 | self.clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused") 63 | self.speech2spk_embed = Speech2Embedding.from_pretrained(model_tag="espnet/voxcelebs12_ecapa_wavlm_joint", device=str(device)) 64 | self.cmudict = cmudict.CMUDict(cmudict_path) 65 | 66 | if exists(v2a_ckpt_path): 67 | self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval() 68 | self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 69 | self.v2a_mapper = DiffusionPrior.from_pretrained(v2a_ckpt_path).eval() 70 | 71 | if exists(t2a_ckpt_path): 72 | self.t2a_mapper = DiffusionPrior.from_pretrained(t2a_ckpt_path).eval() 73 | 74 | model_config = ckpt_path.rsplit('/', 2)[0] + '/config.yaml' 75 | with open(model_config, 'r') as f: 76 | self.config = yaml.safe_load(f) 77 | 78 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 79 | if self.config['concat_y']: 80 | dit_model = DiT_models[self.config['model']] 81 | else: 82 | dit_model = DiT_cross_models[self.config['model']] 83 | dit = dit_model( 84 | input_size = (1024 // vae_scale_factor, 64 // vae_scale_factor), 85 | in_channels = 8, 86 | 
projection_class_embeddings_input_dim=512, 87 | caption_channels = 64, 88 | concat_y = self.config.get("concat_y", False), 89 | cross_class = self.config.get("cross_class", False), 90 | ) 91 | 92 | # TODO: Get checkpoints 93 | def load_ckpt(model, ckpt_path): 94 | print(f"Loading checkpoint from {ckpt_path}") 95 | ckpt = torch.load(ckpt_path, map_location="cpu") 96 | model.load_state_dict(ckpt) 97 | print(f"Loaded checkpoint from {ckpt_path}") 98 | return model 99 | 100 | self.model = load_ckpt(DiTWrapper(dit, cond_speaker=True, concat_y=self.config.get("concat_y", False)), ckpt_path) 101 | self.model.eval() 102 | 103 | self.device = device 104 | self.vae.to(device) 105 | self.vocoder.to(device) 106 | self.clap_model.to(device) 107 | 108 | if exists(v2a_ckpt_path): 109 | self.clip_model.to(device) 110 | self.v2a_mapper.to(device) 111 | if exists(t2a_ckpt_path): 112 | self.t2a_mapper.to(device) 113 | self.model.to(device) 114 | 115 | # if exists(male_voice): 116 | # wav, sr = torchaudio.load(male_voice) 117 | # if sr != 16000: 118 | # wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=16000) 119 | # self.male_voice = self.speech2spk_embed(wav.squeeze(0).to(device)) 120 | 121 | # if exists(female_voice): 122 | # wav, sr = torchaudio.load(female_voice) 123 | # if sr != 16000: 124 | # wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=16000) 125 | # self.female_voice = self.speech2spk_embed(wav.squeeze(0).to(device)) 126 | 127 | if exists(male_voice): 128 | with open(male_voice, 'r') as f: 129 | self.male_voice = f.readlines() 130 | 131 | if exists(female_voice): 132 | with open(female_voice, 'r') as f: 133 | self.female_voice = f.readlines() 134 | 135 | def get_text(self, text, add_blank=True): 136 | text_norm = text_to_sequence(text, dictionary=self.cmudict) 137 | if add_blank: 138 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols) 139 | text_norm = torch.IntTensor(text_norm) 140 | return text_norm 141 | 142 | def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): 143 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 144 | shape = ( 145 | batch_size, 146 | num_channels_latents, 147 | height // vae_scale_factor, 148 | self.vocoder.config.model_in_dim // vae_scale_factor, 149 | ) 150 | if isinstance(generator, list) and len(generator) != batch_size: 151 | raise ValueError( 152 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 153 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
154 | ) 155 | 156 | if latents is None: 157 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 158 | else: 159 | latents = latents.to(device) 160 | 161 | # latents = latents * self.noise_scheduler.init_noise_sigma 162 | return latents 163 | 164 | def decode_latents(self, latents): 165 | latents = 1 / self.vae.config.scaling_factor * latents 166 | mel_spectrogram = self.vae.decode(latents).sample 167 | return mel_spectrogram 168 | 169 | def mel_spectrogram_to_waveform(self, mel_spectrogram): 170 | if mel_spectrogram.dim() == 4: 171 | mel_spectrogram = mel_spectrogram.squeeze(1) 172 | 173 | waveform = self.vocoder(mel_spectrogram) 174 | waveform = waveform.cpu().float() 175 | return waveform 176 | 177 | def normalize_wav(self, waveform): 178 | waveform = waveform - torch.mean(waveform) 179 | waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8) 180 | return waveform 181 | 182 | @torch.no_grad() 183 | def __call__( 184 | self, 185 | modality, 186 | env_prompt, 187 | cont_prompt, 188 | style_prompt = None, 189 | speaker_audio = None, 190 | gender = None, 191 | batch_size = 1, 192 | num_inference_steps = 100, 193 | audio_length_in_s = 10, 194 | do_classifier_free_guidance = True, 195 | v2a_guidance_scale = None, 196 | guidance_scale = None, 197 | desc_guidance_scale = None, 198 | cont_guidance_scale = None, 199 | device=None, 200 | seed=None, 201 | progress=True, 202 | **kwargs, 203 | ): 204 | 205 | if guidance_scale is None and desc_guidance_scale is None and cont_guidance_scale is None: 206 | do_classifier_free_guidance = False 207 | 208 | guidance = None 209 | if guidance_scale is None: 210 | guidance = "dual" 211 | else: 212 | guidance = "single" 213 | 214 | # description condition 215 | if modality == 'text': 216 | if do_classifier_free_guidance: 217 | if guidance == "dual": 218 | env_prompt = [env_prompt] * 2 + [""] * 2 219 | if guidance == "single": 220 | env_prompt = [env_prompt] + [""] 221 | else: 222 | env_prompt = [env_prompt] 223 | 224 | clap_inputs = self.clap_processor( 225 | text=env_prompt, 226 | return_tensors="pt", 227 | padding=True 228 | ).to(self.device) 229 | 230 | c_desc = self.clap_model.get_text_features(**clap_inputs) 231 | 232 | elif modality == 'audio': 233 | audio_sample, sr = torchaudio.load(env_prompt) 234 | if sr != 48000: 235 | audio_sample = torchaudio.functional.resample(audio_sample, orig_freq=sr, new_freq=48000) 236 | audio_sample = audio_sample[0] 237 | 238 | clap_inputs = self.clap_processor(audios=audio_sample, sampling_rate=48000, return_tensors="pt", padding=True).to(self.device) 239 | c_desc = self.clap_model.get_audio_features(**clap_inputs) 240 | 241 | if do_classifier_free_guidance: 242 | clap_inputs = self.clap_processor(text=[""], return_tensors="pt", padding=True).to(self.device) 243 | # uncond_embeds = self.clap_model.text_model(**clap_inputs).pooler_output 244 | # uc_desc = self.clap_model.text_projection(uncond_embeds) 245 | uc_desc = self.clap_model.get_text_features(**clap_inputs) 246 | 247 | if guidance == "dual": 248 | c_desc = torch.cat((c_desc, c_desc, uc_desc, uc_desc)) 249 | if guidance == "single": 250 | c_desc = torch.cat((c_desc, uc_desc)) 251 | 252 | elif modality == 'image': 253 | image_sample = Image.open(env_prompt) 254 | 255 | clip_inputs = self.clip_processor(images=image_sample, return_tensors="pt").to(self.device) 256 | # clip_embeds = self.clip_model.vision_model(**clip_inputs).pooler_output 257 | # clip_embeds = self.clip_model.visual_projection(clip_embeds) 258 | # clip_embeds = 
torch.nn.functional.normalize(clip_embeds, dim=-1) 259 | clip_embeds = self.clip_model.get_image_features(**clip_inputs) 260 | # clip_embeds = env_prompt 261 | text_cond = dict(text_embed = clip_embeds) 262 | 263 | c_desc = self.v2a_mapper.p_sample_loop( 264 | clip_embeds.shape, 265 | text_cond = text_cond, 266 | cond_scale = v2a_guidance_scale, 267 | timesteps = 100 268 | ) 269 | 270 | c_desc = torch.nn.functional.normalize(c_desc, dim=-1) 271 | 272 | if do_classifier_free_guidance: 273 | clip_inputs = self.clap_processor(text=[""], return_tensors="pt", padding=True).to(self.device) 274 | uc_desc = self.clap_model.get_text_features(**clip_inputs) 275 | 276 | if guidance == "dual": 277 | c_desc = torch.cat((c_desc, c_desc, uc_desc, uc_desc)) 278 | if guidance == "single": 279 | c_desc = torch.cat((c_desc, uc_desc)) 280 | 281 | # speaker style conditon 282 | if style_prompt is not None: 283 | if do_classifier_free_guidance: 284 | if guidance == "dual": 285 | style_prompt = [style_prompt] * 2 + [""] * 2 286 | if guidance == "single": 287 | style_prompt = [style_prompt] + [""] 288 | else: 289 | style_prompt = [style_prompt] 290 | 291 | clap_text_tokens = self.clap_tokenizer(style_prompt, return_tensors="pt", padding=True).to(self.device) 292 | c_style = self.clap_model.get_text_features(**clap_text_tokens) 293 | 294 | elif gender is not None: 295 | 296 | if gender == 'man': 297 | spk_audio = random.choice(self.male_voice).strip() 298 | elif gender == 'woman': 299 | spk_audio = random.choice(self.female_voice).strip() 300 | else: 301 | spk_audio = random.choice(self.male_voice + self.female_voice).strip() 302 | spk_audio, sr = torchaudio.load(spk_audio) 303 | 304 | if sr != 16000: 305 | spk_audio = torchaudio.functional.resample(spk_audio, orig_freq=sr, new_freq=16000) 306 | if spk_audio.shape[0] > 1: 307 | spk_audio = spk_audio.mean(0, keepdim=True) 308 | c_style = self.speech2spk_embed(spk_audio.squeeze(0)) 309 | 310 | elif speaker_audio is not None: 311 | spk_audio, sr = torchaudio.load(speaker_audio) 312 | if sr != 16000: 313 | spk_audio = torchaudio.functional.resample(spk_audio, orig_freq=sr, new_freq=16000) 314 | if spk_audio.shape[0] > 1: 315 | spk_audio = spk_audio.mean(0, keepdim=True) 316 | c_style=self.speech2spk_embed(spk_audio.squeeze(0)) 317 | else: 318 | c_style = torch.zeros(192).to(self.device) 319 | 320 | cont_tokens = self.get_text(cont_prompt, add_blank=True) 321 | cont_tokens = cont_tokens.unsqueeze(0).to(self.device) 322 | cont_lengths = torch.LongTensor([cont_tokens.shape[-1]]).to(self.device) 323 | 324 | mu_y, y_mask = self.model.process_content(cont_tokens, cont_lengths, c_style) 325 | 326 | if self.model.concat_y: 327 | mu_y_vae = self.model.proj(mu_y.unsqueeze(1)) 328 | else: 329 | mu_y_vae = mu_y.unsqueeze(1) 330 | 331 | vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate 332 | height = int(audio_length_in_s * 1.024 / vocoder_upsample_factor) 333 | original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate) 334 | 335 | # prepare latent variables 336 | num_channels_latents = self.model.dit.in_channels 337 | latents = self.prepare_latents( 338 | batch_size, 339 | num_channels_latents, 340 | height, 341 | c_desc.dtype, 342 | device=device, 343 | generator=torch.manual_seed(seed) if seed else None, 344 | latents=None, 345 | ) 346 | 347 | # denoising loop 348 | 349 | if guidance == "dual": 350 | z = torch.cat([latents] * 4) if do_classifier_free_guidance else latents 351 | cfg_scale = 
(desc_guidance_scale, cont_guidance_scale) 352 | elif guidance == "single": 353 | z = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 354 | cfg_scale = guidance_scale 355 | 356 | if self.model.concat_y: 357 | mu_y_vae = torch.nn.functional.pad(mu_y_vae, (0, 0, 0, z.shape[-2] - mu_y_vae.shape[-2]), value=0) 358 | if do_classifier_free_guidance: 359 | if self.config['concat_y']: 360 | if guidance == "dual": 361 | mu_y_vae = torch.cat([mu_y_vae, self.model.y_embedding.unsqueeze(0)] * 2, dim=0) 362 | elif guidance == "single": 363 | mu_y_vae = torch.cat([mu_y_vae, self.model.y_embedding.unsqueeze(0)], dim=0) 364 | else: 365 | null_y = self.model.dit.y_embedder.y_embedding.unsqueeze(1).repeat(1, mu_y_vae.shape[-2], 1) 366 | if guidance == "dual": 367 | mu_y_vae = torch.cat([mu_y_vae, null_y.unsqueeze(0)] * 2, dim=0) 368 | elif guidance == "single": 369 | mu_y_vae = torch.cat([mu_y_vae, null_y.unsqueeze(0)], dim=0) 370 | 371 | model_kwargs = dict( 372 | y=mu_y_vae, 373 | mask=y_mask.long().squeeze(1) if not self.model.concat_y else None, 374 | class_labels=c_desc, 375 | guidance=guidance, 376 | cfg_scale = cfg_scale, 377 | ) 378 | 379 | diffusion = create_diffusion(str(num_inference_steps)) 380 | 381 | # Sample images: 382 | samples = diffusion.p_sample_loop( 383 | self.model.dit.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress, 384 | device=self.device 385 | ) 386 | if guidance == "dual": 387 | samples = samples[: len(samples) // 4] 388 | elif guidance == "single": 389 | samples = samples[: len(samples) // 2] 390 | 391 | mel_spectrogram = self.decode_latents(samples) 392 | audio = self.mel_spectrogram_to_waveform(mel_spectrogram) 393 | audio = audio[:, :original_waveform_length] 394 | audio = self.normalize_wav(audio) 395 | 396 | return audio, mel_spectrogram 397 | 398 | 399 | class EvalPipeline(): 400 | def __init__( 401 | self, 402 | vae, 403 | clap_model: ClapModel, 404 | clap_processor: ClapProcessor, 405 | device, 406 | cmudict_path='voicedit/cmu_dictionary' 407 | ): 408 | self.vae = vae 409 | self.vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm-m-full", subfolder="vocoder").eval() 410 | self.clap_model = clap_model 411 | self.clap_processor = clap_processor 412 | self.speech2spk_embed = Speech2Embedding.from_pretrained(model_tag="espnet/voxcelebs12_ecapa_wavlm_joint") 413 | self.device = device 414 | self.vocoder.to(device) 415 | self.cmudict = cmudict.CMUDict(cmudict_path) 416 | 417 | def get_text(self, text, add_blank=True): 418 | text_norm = text_to_sequence(text, dictionary=self.cmudict) 419 | if add_blank: 420 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols) 421 | text_norm = torch.IntTensor(text_norm) 422 | return text_norm 423 | 424 | def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None): 425 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 426 | shape = ( 427 | batch_size, 428 | num_channels_latents, 429 | height // vae_scale_factor, 430 | self.vocoder.config.model_in_dim // vae_scale_factor, 431 | ) 432 | if isinstance(generator, list) and len(generator) != batch_size: 433 | raise ValueError( 434 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 435 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
436 |             )
437 | 
438 |         if latents is None:
439 |             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
440 |         else:
441 |             latents = latents.to(device)
442 | 
443 |         return latents
444 | 
445 |     def decode_latents(self, latents):
446 |         latents = 1 / self.vae.config.scaling_factor * latents
447 |         mel_spectrogram = self.vae.decode(latents).sample
448 |         return mel_spectrogram
449 | 
450 |     def mel_spectrogram_to_waveform(self, mel_spectrogram):
451 |         if mel_spectrogram.dim() == 4:
452 |             mel_spectrogram = mel_spectrogram.squeeze(1)
453 | 
454 |         waveform = self.vocoder(mel_spectrogram)
455 |         waveform = waveform.cpu().float()
456 |         return waveform
457 | 
458 |     def normalize_wav(self, waveform):
459 |         waveform = waveform - torch.mean(waveform)
460 |         waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
461 |         return waveform
462 | 
463 |     @torch.autocast('cuda', dtype=torch.float16)
464 |     @torch.no_grad()
465 |     def __call__(
466 |         self,
467 |         model,
468 |         clap_input_features,
469 |         clap_is_longer,
470 |         cont_prompt,
471 |         spk_embeds,
472 |         batch_size=1,
473 |         num_inference_steps=100,
474 |         audio_length_in_s=10,
475 |         do_classifier_free_guidance=True,
476 |         guidance_scale=None,
477 |         desc_guidance_scale=None,
478 |         cont_guidance_scale=None,
479 |         seed=None,
480 |         concat_y=None,
481 |         **kwargs,
482 |     ):
483 |         if guidance_scale is None and desc_guidance_scale is None and cont_guidance_scale is None:
484 |             do_classifier_free_guidance = False
485 | 
486 |         guidance = None
487 |         if guidance_scale is None:
488 |             guidance = "dual"
489 |         else:
490 |             guidance = "single"
491 | 
492 |         # embeds = self.clap_model.audio_model(input_features=clap_input_features, is_longer=clap_is_longer).pooler_output
493 |         # c_desc = self.clap_model.audio_projection(embeds)
494 | 
495 |         c_desc = self.clap_model.get_audio_features(input_features=clap_input_features, is_longer=clap_is_longer)
496 | 
497 |         if do_classifier_free_guidance:
498 |             clap_inputs = self.clap_processor(text=[""], return_tensors="pt", padding=True).to(self.device)
499 |             # uncond_embeds = self.clap_model.text_model(**clap_inputs).pooler_output
500 |             # uc_desc = self.clap_model.text_projection(uncond_embeds)
501 |             uc_desc = self.clap_model.get_text_features(**clap_inputs)
502 | 
503 |             if guidance == "dual":
504 |                 c_desc = torch.cat((c_desc, c_desc, uc_desc, uc_desc))
505 |             if guidance == "single":
506 |                 c_desc = torch.cat((c_desc, uc_desc))
507 | 
508 |         # content condition
509 |         # if do_classifier_free_guidance:
510 |         #     if guidance == "dual":
511 |         #         cont_prompt = ([cont_prompt] + ["_"]) * 2
512 |         #     if guidance == "single":
513 |         #         cont_prompt = [cont_prompt] + ["_"]
514 | 
515 |         # cont_tokens = self.text_processor(
516 |         #     text=[cont_prompt],
517 |         #     padding=True,
518 |         #     truncation=True,
519 |         #     max_length=1000,
520 |         #     return_tensors="pt"
521 |         # ).to(self.device)
522 |         # cont_embed_mask = cont_tokens.attention_mask
523 | 
524 |         cont_tokens = self.get_text(cont_prompt, add_blank=True)
525 |         cont_tokens = cont_tokens.unsqueeze(0).to(self.device)
526 |         cont_lengths = torch.LongTensor([cont_tokens.shape[-1]]).to(self.device)
527 | 
528 |         mu_y, y_mask = model.process_content(cont_tokens, cont_lengths, spk_embeds)
529 |         mu_y_vae = mu_y.unsqueeze(1)
530 |         if concat_y:
531 |             mu_y_vae = model.proj(mu_y_vae)
532 | 
533 |         vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
534 |         height = int(audio_length_in_s * 1.024 / vocoder_upsample_factor)
535 |         original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
536 | 
537 |         # prepare latent variables
538 |         num_channels_latents = model.dit.in_channels
539 |         latents = self.prepare_latents(
540 |             batch_size,
541 |             num_channels_latents,
542 |             height,
543 |             c_desc.dtype,
544 |             device=self.device,
545 |             generator=torch.manual_seed(seed) if seed is not None else None,  # treat seed=0 as a valid seed
546 |             latents=None,
547 |         )
548 | 
549 |         # denoising loop
550 |         if guidance == "dual":
551 |             z = torch.cat([latents] * 4) if do_classifier_free_guidance else latents
552 |             cfg_scale = (desc_guidance_scale, cont_guidance_scale)
553 |         elif guidance == "single":
554 |             z = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
555 |             cfg_scale = guidance_scale
556 | 
557 |         if concat_y:
558 |             mu_y_vae = torch.nn.functional.pad(mu_y_vae, (0, 0, 0, z.shape[-2] - mu_y_vae.shape[-2]), value=0)
559 |             if do_classifier_free_guidance:
560 |                 if guidance == "dual":
561 |                     mu_y_vae = torch.cat([mu_y_vae, model.y_embedding.unsqueeze(0)] * 2, dim=0)
562 |                 elif guidance == "single":
563 |                     mu_y_vae = torch.cat([mu_y_vae, model.y_embedding.unsqueeze(0)], dim=0)
564 | 
565 |             model_kwargs = dict(
566 |                 y=mu_y_vae,
567 |                 class_labels=c_desc,
568 |                 guidance=guidance,
569 |                 cfg_scale=cfg_scale,
570 |             )
571 | 
572 |         else:
573 |             null_y = model.dit.y_embedder.y_embedding.repeat(mu_y_vae.shape[-2], 1)
574 |             null_y = null_y.unsqueeze(0).unsqueeze(0)
575 |             if do_classifier_free_guidance:
576 |                 if guidance == "dual":
577 |                     mu_y_vae = torch.cat([mu_y_vae, null_y] * 2, dim=0)
578 |                 elif guidance == "single":
579 |                     mu_y_vae = torch.cat([mu_y_vae, null_y], dim=0)
580 | 
581 |             model_kwargs = dict(
582 |                 y=mu_y_vae,
583 |                 mask=y_mask.squeeze(1).bool(),
584 |                 class_labels=c_desc,
585 |                 guidance=guidance,
586 |                 cfg_scale=cfg_scale,
587 |             )
588 | 
589 |         diffusion = create_diffusion(str(num_inference_steps))
590 | 
591 |         # Sample latents
592 |         samples = diffusion.p_sample_loop(
593 |             model.dit.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
594 |             device=self.device
595 |         )
596 |         if guidance == "dual":
597 |             samples = samples[: len(samples) // 4]
598 |         if guidance == "single":
599 |             samples = samples[: len(samples) // 2]
600 | 
601 |         mel_spectrogram = self.decode_latents(samples)
602 |         audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
603 |         audio = audio[:, :original_waveform_length]
604 |         audio = self.normalize_wav(audio)
605 | 
606 |         return mu_y[0].transpose(0, 1), mel_spectrogram.transpose(2, 3), audio
--------------------------------------------------------------------------------
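Note on the "dual" guidance construction used in both `__call__` methods above: the pipelines batch all four combinations of the description condition (`c_desc` vs. the unconditional `uc_desc`) and the content condition (`mu_y_vae` vs. the null embedding), so the denoiser runs once per timestep instead of four times. The snippet below is a minimal, self-contained sketch of that batch layout and of one plausible way the two guidance scales might be combined; the tensor shapes and the variable names `desc`, `cont`, `e_cc`, etc. are invented for illustration, and the actual blending lives in `forward_with_cfg` inside the DiT code, not in this file.

import torch

# Hypothetical shapes, chosen only for illustration.
B, C, H, W = 1, 8, 256, 16
D_desc = 512                                                           # CLAP embedding dim
latents = torch.randn(B, C, H, W)
c_desc, uc_desc = torch.randn(B, D_desc), torch.randn(B, D_desc)       # cond / uncond description
mu_y_vae, null_y = torch.randn(B, 1, H, W), torch.randn(B, 1, H, W)    # cond / null content

# Same 4-way layout as the pipelines build when guidance == "dual":
#   row 0: (desc cond,   cont cond)
#   row 1: (desc cond,   cont null)
#   row 2: (desc uncond, cont cond)
#   row 3: (desc uncond, cont null)
z = torch.cat([latents] * 4)
desc = torch.cat((c_desc, c_desc, uc_desc, uc_desc))
cont = torch.cat([mu_y_vae, null_y] * 2, dim=0)

# A denoiser run on this batch returns four predictions that can be split back
# with chunk(4); one plausible dual-scale combination (the real one is inside
# forward_with_cfg) measures each condition against the fully unconditional row:
eps = torch.randn_like(z)                        # stand-in for the DiT output
e_cc, e_cn, e_uc, e_uu = eps.chunk(4)
desc_scale, cont_scale = 3.0, 1.5                # i.e. desc_guidance_scale, cont_guidance_scale
guided = e_uu + desc_scale * (e_cn - e_uu) + cont_scale * (e_uc - e_uu)

# The pipelines then keep only the first quarter of the sampled batch
# (samples[: len(samples) // 4]), which corresponds to row 0 here.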