├── diffusion
│   ├── model
│   │   ├── __init__.py
│   │   ├── nets
│   │   │   ├── __init__.py
│   │   │   ├── DiT2.py
│   │   │   ├── DiT.py
│   │   │   └── DiT_blocks.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── diffusion_utils.py
│   └── respace.py
├── assets
│   ├── dog.jpg
│   ├── park.jpg
│   ├── beach.jpg
│   ├── church.png
│   ├── tiger.png
│   ├── pipeline.png
│   ├── female_voice.wav
│   ├── male_voice.wav
│   └── bird_chirping.wav
├── i2a_translator
│   └── __init__.py
├── voicedit
│   ├── audio
│   │   ├── __init__.py
│   │   ├── audio_processing.py
│   │   ├── tools.py
│   │   └── stft.py
│   ├── monotonic_align
│   │   ├── build
│   │   │   ├── temp.linux-x86_64-cpython-310
│   │   │   │   └── core.o
│   │   │   └── lib.linux-x86_64-cpython-310
│   │   │       └── monotonic_align
│   │   │           └── core.cpython-310-x86_64-linux-gnu.so
│   │   ├── monotonic_align
│   │   │   └── core.cpython-310-x86_64-linux-gnu.so
│   │   ├── setup.py
│   │   ├── __init__.py
│   │   └── core.pyx
│   ├── __init__.py
│   ├── utils.py
│   ├── data_ft.py
│   ├── data.py
│   ├── modules.py
│   ├── text_encoder.py
│   └── pipeline.py
├── requirements.txt
├── text
│   ├── symbols.py
│   ├── LICENSE
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── __init__.py
├── README.md
└── generate.py
/diffusion/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .nets import *
2 |
--------------------------------------------------------------------------------
/diffusion/model/nets/__init__.py:
--------------------------------------------------------------------------------
1 | from .DiT import DiT, DiT_XL_2
--------------------------------------------------------------------------------
/assets/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/dog.jpg
--------------------------------------------------------------------------------
/assets/park.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/park.jpg
--------------------------------------------------------------------------------
/assets/beach.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/beach.jpg
--------------------------------------------------------------------------------
/assets/church.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/church.png
--------------------------------------------------------------------------------
/assets/tiger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/tiger.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/female_voice.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/female_voice.wav
--------------------------------------------------------------------------------
/assets/male_voice.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/male_voice.wav
--------------------------------------------------------------------------------
/assets/bird_chirping.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/assets/bird_chirping.wav
--------------------------------------------------------------------------------
/i2a_translator/__init__.py:
--------------------------------------------------------------------------------
1 | from i2a_translator.i2a_translator import DiffusionPriorNetwork, DiffusionPrior
--------------------------------------------------------------------------------
/voicedit/audio/__init__.py:
--------------------------------------------------------------------------------
1 | from .tools import get_mel_from_wav, wav_to_fbank, read_wav_file, raw_waveform_to_fbank
2 | from .stft import TacotronSTFT
3 |
--------------------------------------------------------------------------------
/voicedit/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/voicedit/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o
--------------------------------------------------------------------------------
/voicedit/monotonic_align/monotonic_align/core.cpython-310-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/voicedit/monotonic_align/monotonic_align/core.cpython-310-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/voicedit/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import AudioDataset, CollateFn
2 | from .data_ft import AudioDataset_ft, CollateFn_ft
3 | from .modules import DiTWrapper
4 | from .text_encoder import TextEncoder
5 | from .pipeline import VoiceDiTPipeline, EvalPipeline, VoiceUNetPipeline
--------------------------------------------------------------------------------
/voicedit/monotonic_align/build/lib.linux-x86_64-cpython-310/monotonic_align/core.cpython-310-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kaistmm/VoiceDiT/HEAD/voicedit/monotonic_align/build/lib.linux-x86_64-cpython-310/monotonic_align/core.cpython-310-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/voicedit/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | from Cython.Build import cythonize
3 | import numpy
4 |
5 | setup(
6 | name = 'monotonic_align',
7 | ext_modules = cythonize("core.pyx"),
8 | include_dirs=[numpy.get_include()]
9 | )
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.1.1
2 | torchaudio==2.1.1
3 | torchvision==0.16.1
4 | diffusers
5 | timm==0.6.12
6 | accelerate
7 | tensorboard
8 | transformers
9 | sentencepiece~=0.1.99
10 | ftfy
11 | beautifulsoup4
12 | opencv-python
13 | einops
14 | xformers==0.0.23
15 | librosa
16 | pandas
17 | scipy
18 | tqdm
19 | thop
--------------------------------------------------------------------------------
/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | from text import cmudict
4 |
5 | _pad = '_'
6 | _punctuation = '!\'(),.:;? '
7 | _special = '-'
8 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
9 |
10 | # Prepend "@" to ARPAbet symbols to ensure uniqueness:
11 | _arpabet = ['@' + s for s in cmudict.valid_symbols]
12 |
13 | # Export all symbols:
14 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
15 |
--------------------------------------------------------------------------------
/voicedit/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from .monotonic_align.core import maximum_path_c
4 |
5 |
6 | def maximum_path(value, mask):
7 | """ Cython optimised version.
8 | value: [b, t_x, t_y]
9 | mask: [b, t_x, t_y]
10 | """
11 | value = value * mask
12 | device = value.device
13 | dtype = value.dtype
14 | value = value.data.cpu().numpy().astype(np.float32)
15 | path = np.zeros_like(value).astype(np.int32)
16 | mask = mask.data.cpu().numpy()
17 |
18 | t_x_max = mask.sum(1)[:, 0].astype(np.int32)
19 | t_y_max = mask.sum(2)[:, 0].astype(np.int32)
20 | maximum_path_c(path, value, t_x_max, t_y_max)
21 | return torch.from_numpy(path).to(device=device, dtype=dtype)
22 |
--------------------------------------------------------------------------------
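For reference, a minimal usage sketch of `maximum_path` with toy tensors (shapes follow the docstring above). It assumes the Cython extension is available — the repository ships prebuilt `core.*.so` artifacts, and `setup.py` can rebuild them with the standard `python setup.py build_ext --inplace` invocation (Cython itself is not listed in requirements.txt).

```python
import torch
from voicedit.monotonic_align import maximum_path  # needs the compiled core extension

b, t_x, t_y = 2, 5, 8                    # batch, text length, mel length
value = torch.randn(b, t_x, t_y)         # e.g. per-(token, frame) log-likelihoods
mask = torch.ones(b, t_x, t_y)           # all positions valid in this toy case

path = maximum_path(value, mask)         # hard monotonic alignment, same shape as value
print(path.shape)                        # torch.Size([2, 5, 8])
print(path.sum(dim=1))                   # every mel frame is assigned to exactly one token
```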
/voicedit/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | cimport numpy as np
3 | cimport cython
4 | from cython.parallel import prange
5 |
6 |
7 | @cython.boundscheck(False)
8 | @cython.wraparound(False)
9 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
10 | cdef int x
11 | cdef int y
12 | cdef float v_prev
13 | cdef float v_cur
14 | cdef float tmp
15 | cdef int index = t_x - 1
16 |
17 | for y in range(t_y):
18 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
19 | if x == y:
20 | v_cur = max_neg_val
21 | else:
22 | v_cur = value[x, y-1]
23 | if x == 0:
24 | if y == 0:
25 | v_prev = 0.
26 | else:
27 | v_prev = max_neg_val
28 | else:
29 | v_prev = value[x-1, y-1]
30 | value[x, y] = max(v_cur, v_prev) + value[x, y]
31 |
32 | for y in range(t_y - 1, -1, -1):
33 | path[index, y] = 1
34 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
35 | index = index - 1
36 |
37 |
38 | @cython.boundscheck(False)
39 | @cython.wraparound(False)
40 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
41 | cdef int b = values.shape[0]
42 |
43 | cdef int i
44 | for i in prange(b, nogil=True):
45 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
46 |
--------------------------------------------------------------------------------
/text/LICENSE:
--------------------------------------------------------------------------------
1 | CMUdict
2 | -------
3 |
4 | CMUdict (the Carnegie Mellon Pronouncing Dictionary) is a free
5 | pronouncing dictionary of English, suitable for uses in speech
6 | technology and is maintained by the Speech Group in the School of
7 | Computer Science at Carnegie Mellon University.
8 |
9 | The Carnegie Mellon Speech Group does not guarantee the accuracy of
10 | this dictionary, nor its suitability for any specific purpose. In
11 | fact, we expect a number of errors, omissions and inconsistencies to
12 | remain in the dictionary. We intend to continually update the
13 | dictionary by correction existing entries and by adding new ones. From
14 | time to time a new major version will be released.
15 |
16 | We welcome input from users: Please send email to Alex Rudnicky
17 | (air+cmudict@cs.cmu.edu).
18 |
19 | The Carnegie Mellon Pronouncing Dictionary, in its current and
20 | previous versions is Copyright (C) 1993-2014 by Carnegie Mellon
21 | University. Use of this dictionary for any research or commercial
22 | purpose is completely unrestricted. If you make use of or
23 | redistribute this material we request that you acknowledge its
24 | origin in your descriptions.
25 |
26 | If you add words to or correct words in your version of this
27 | dictionary, we would appreciate it if you could send these additions
28 | and corrections to us (air+cmudict@cs.cmu.edu) for consideration in a
29 | subsequent version. All submissions will be reviewed and approved by
30 | the current maintainer, Alex Rudnicky at Carnegie Mellon.
31 |
--------------------------------------------------------------------------------
/voicedit/utils.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | import torch
4 |
5 |
6 | def sequence_mask(length, max_length=None):
7 | if max_length is None:
8 | max_length = length.max()
9 | x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
10 | return x.unsqueeze(0) < length.unsqueeze(1)
11 |
12 |
13 | def fix_len_compatibility(length, num_downsamplings_in_unet=2):
14 | while True:
15 | if length % (2**num_downsamplings_in_unet) == 0:
16 | return length
17 | length += 1
18 |
19 |
20 | def convert_pad_shape(pad_shape):
21 | l = pad_shape[::-1]
22 | pad_shape = [item for sublist in l for item in sublist]
23 | return pad_shape
24 |
25 |
26 | def generate_path(duration, mask):
27 | device = duration.device
28 |
29 | b, t_x, t_y = mask.shape
30 | cum_duration = torch.cumsum(duration, 1)
31 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
32 |
33 | cum_duration_flat = cum_duration.view(b * t_x)
34 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
35 | path = path.view(b, t_x, t_y)
36 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
37 | [1, 0], [0, 0]]))[:, :-1]
38 | path = path * mask
39 | return path
40 |
41 |
42 | def duration_loss(logw, logw_, lengths):
43 | loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
44 | return loss
45 |
46 |
47 | def intersperse(lst, item):
48 | # Adds blank symbol
49 | result = [item] * (len(lst) * 2 + 1)
50 | result[1::2] = lst
51 | return result
--------------------------------------------------------------------------------
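A small sketch of the glow-tts helpers above on toy tensors; the shapes and semantics are read off the function bodies, not any documented API.

```python
import torch
from voicedit.utils import sequence_mask, generate_path, intersperse

lengths = torch.tensor([3, 5])                 # per-example sequence lengths
print(sequence_mask(lengths).int())            # (2, 5) mask, True where position < length

# Expand integer durations into a hard alignment path of shape (b, t_x, t_y).
duration = torch.tensor([[1, 2, 2, 0, 0],
                         [1, 1, 1, 1, 1]])
attn_mask = torch.ones(2, 5, 5)
path = generate_path(duration, attn_mask)
print(path[0])                                 # token 0 -> frame 0, token 1 -> frames 1-2, token 2 -> frames 3-4

print(intersperse([4, 7, 9], 0))               # [0, 4, 0, 7, 0, 9, 0]: blank id between every symbol
```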
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 | from unidecode import unidecode
5 | from .numbers import normalize_numbers
6 |
7 |
8 | _whitespace_re = re.compile(r'\s+')
9 |
10 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
11 | ('mrs', 'misess'),
12 | ('mr', 'mister'),
13 | ('dr', 'doctor'),
14 | ('st', 'saint'),
15 | ('co', 'company'),
16 | ('jr', 'junior'),
17 | ('maj', 'major'),
18 | ('gen', 'general'),
19 | ('drs', 'doctors'),
20 | ('rev', 'reverend'),
21 | ('lt', 'lieutenant'),
22 | ('hon', 'honorable'),
23 | ('sgt', 'sergeant'),
24 | ('capt', 'captain'),
25 | ('esq', 'esquire'),
26 | ('ltd', 'limited'),
27 | ('col', 'colonel'),
28 | ('ft', 'fort'),
29 | ]]
30 |
31 |
32 | def expand_abbreviations(text):
33 | for regex, replacement in _abbreviations:
34 | text = re.sub(regex, replacement, text)
35 | return text
36 |
37 |
38 | def expand_numbers(text):
39 | return normalize_numbers(text)
40 |
41 |
42 | def lowercase(text):
43 | return text.lower()
44 |
45 |
46 | def collapse_whitespace(text):
47 | return re.sub(_whitespace_re, ' ', text)
48 |
49 |
50 | def convert_to_ascii(text):
51 | return unidecode(text)
52 |
53 |
54 | def basic_cleaners(text):
55 | text = lowercase(text)
56 | text = collapse_whitespace(text)
57 | return text
58 |
59 |
60 | def transliteration_cleaners(text):
61 | text = convert_to_ascii(text)
62 | text = lowercase(text)
63 | text = collapse_whitespace(text)
64 | return text
65 |
66 |
67 | def english_cleaners(text):
68 | text = convert_to_ascii(text)
69 | text = lowercase(text)
70 | text = expand_numbers(text)
71 | text = expand_abbreviations(text)
72 | text = collapse_whitespace(text)
73 | return text
74 |
--------------------------------------------------------------------------------
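A brief example of `english_cleaners` for illustration (note that `unidecode` and `inflect`, imported by these text modules, are not listed in requirements.txt):

```python
from text.cleaners import english_cleaners

print(english_cleaners("Dr. Smith paid $16 on March 3rd, 2020."))
# -> "doctor smith paid sixteen dollars on march third, twenty twenty."
```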
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | pred_sigma=True,
18 | rescale_learned_sigmas=False,
19 | diffusion_steps=1000,
20 | snr=False,
21 | return_startx=False,
22 | ):
23 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
24 | if use_kl:
25 | loss_type = gd.LossType.RESCALED_KL
26 | elif rescale_learned_sigmas:
27 | loss_type = gd.LossType.RESCALED_MSE
28 | else:
29 | loss_type = gd.LossType.MSE
30 | if timestep_respacing is None or timestep_respacing == "":
31 | timestep_respacing = [diffusion_steps]
32 | return SpacedDiffusion(
33 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
34 | betas=betas,
35 | model_mean_type=(
36 | gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON
37 | ),
38 | model_var_type=(
39 | (gd.ModelVarType.LEARNED_RANGE if learn_sigma else (
40 | gd.ModelVarType.FIXED_LARGE
41 | if not sigma_small
42 | else gd.ModelVarType.FIXED_SMALL
43 | )
44 | )
45 | if pred_sigma
46 | else None
47 | ),
48 | loss_type=loss_type,
49 | snr=snr,
50 | return_startx=return_startx,
51 | # rescale_timesteps=rescale_timesteps,
52 | )
--------------------------------------------------------------------------------
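`gaussian_diffusion.py` is imported here but not shown in this listing, so the sketch below only illustrates how `create_diffusion` is typically constructed with the defaults above (epsilon prediction, learned sigma range, MSE loss):

```python
from diffusion import create_diffusion

# Full 1000-step training diffusion with the defaults above.
train_diffusion = create_diffusion(timestep_respacing="")

# Respaced sampler that keeps 100 evenly strided steps (DDIM-style "ddimN" spec).
sample_diffusion = create_diffusion(timestep_respacing="ddim100")
```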
/text/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2',
8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2',
9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY',
10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1',
11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0',
12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW',
13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH'
14 | ]
15 |
16 | _valid_symbol_set = set(valid_symbols)
17 |
18 |
19 | class CMUDict:
20 | def __init__(self, file_or_path, keep_ambiguous=True):
21 | if isinstance(file_or_path, str):
22 | with open(file_or_path, encoding='latin-1') as f:
23 | entries = _parse_cmudict(f)
24 | else:
25 | entries = _parse_cmudict(file_or_path)
26 | if not keep_ambiguous:
27 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
28 | self._entries = entries
29 |
30 | def __len__(self):
31 | return len(self._entries)
32 |
33 | def lookup(self, word):
34 | return self._entries.get(word.upper())
35 |
36 |
37 | _alt_re = re.compile(r'\([0-9]+\)')
38 |
39 |
40 | def _parse_cmudict(file):
41 | cmudict = {}
42 | for line in file:
43 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
44 | parts = line.split(' ')
45 | word = re.sub(_alt_re, '', parts[0])
46 | pronunciation = _get_pronunciation(parts[1])
47 | if pronunciation:
48 | if word in cmudict:
49 | cmudict[word].append(pronunciation)
50 | else:
51 | cmudict[word] = [pronunciation]
52 | return cmudict
53 |
54 |
55 | def _get_pronunciation(s):
56 | parts = s.strip().split(' ')
57 | for part in parts:
58 | if part not in _valid_symbol_set:
59 | return None
60 | return ' '.join(parts)
61 |
--------------------------------------------------------------------------------
/text/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
13 | _number_re = re.compile(r'[0-9]+')
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(',', '')
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace('.', ' point ')
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split('.')
27 | if len(parts) > 2:
28 | return match + ' dollars'
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
33 | cent_unit = 'cent' if cents == 1 else 'cents'
34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
37 | return '%s %s' % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = 'cent' if cents == 1 else 'cents'
40 | return '%s %s' % (cents, cent_unit)
41 | else:
42 | return 'zero dollars'
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return 'two thousand'
54 | elif num > 2000 and num < 2010:
55 | return 'two thousand ' + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + ' hundred'
58 | else:
59 | return _inflect.number_to_words(num, andword='', zero='oh',
60 | group=2).replace(', ', ' ')
61 | else:
62 | return _inflect.number_to_words(num, andword='')
63 |
64 |
65 | def normalize_numbers(text):
66 | text = re.sub(_comma_number_re, _remove_commas, text)
67 | text = re.sub(_pounds_re, r'\1 pounds', text)
68 | text = re.sub(_dollars_re, _expand_dollars, text)
69 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
70 | text = re.sub(_ordinal_re, _expand_ordinal, text)
71 | text = re.sub(_number_re, _expand_number, text)
72 | return text
73 |
--------------------------------------------------------------------------------
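A quick sketch of the number-normalization rules above (requires `inflect`):

```python
from text.numbers import normalize_numbers

print(normalize_numbers("$3.50"))    # "three dollars, fifty cents"
print(normalize_numbers("15th"))     # "fifteenth"
print(normalize_numbers("in 1969"))  # "in nineteen sixty-nine" (years are read in pairs)
```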
/README.md:
--------------------------------------------------------------------------------
1 | # VoiceDiT
2 |
3 | This is the repository for the ICASSP 2025 paper [VoiceDiT: Dual-Condition Diffusion Transformer for Environment-Aware Speech Synthesis](https://arxiv.org/pdf/2412.19259).
4 |
5 |
6 |
7 |
8 |
9 | VoiceDiT is a multi-modal generative model for producing environment-aware speech and audio from text and visual prompts.
10 |
11 |
12 |
13 |
14 |
15 | ## 🔧 Installation
16 |
17 | ### Install from source
18 | ```shell
19 | git clone https://github.com/kaistmm/VoiceDiT.git
20 | cd VoiceDiT
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ## 📖 Usage
25 |
26 | - Generate audio with description prompt and content prompt:
27 | ```shell
28 | python generate.py --desc_prompt "She is talking in a park." --cont_prompt "Good morning! How are you feeling today?"
29 | ```
30 |
31 | - Generate audio with audio prompt and content prompt:
32 | ```shell
33 | python generate.py --modality "audio" --desc_prompt "assets/bird_chirping.wav" --cont_prompt "Good morning! How are you feeling today?"
34 | ```
35 |
36 | - Generate audio with image prompt and content prompt:
37 | ```shell
38 | python generate.py --modality "image" --desc_prompt "assets/park.jpg" --cont_prompt "Good morning! How are you feeling today?"
39 | ```
40 |
41 | - Text-to-Speech Example:
42 | ```shell
43 | python generate.py --desc_prompt "clean speech" --cont_prompt "Good morning! How are you feeling today?" --desc_guidance_scale 1 --cont_guidance_scale 9
44 | ```
45 |
46 | - Text-to-Audio Example:
47 | ```shell
48 | python generate.py --desc_prompt "trumpet" --cont_prompt "_" --desc_guidance_scale 9 --cont_guidance_scale 1
49 | ```
50 |
51 | - Image-to-Audio Example:
52 | ```shell
53 | python generate.py --desc_prompt "assets/tiger.png" --cont_prompt "_" --v2a_guidance_scale 2 --desc_guidance_scale 9 --cont_guidance_scale 1
54 | ```
55 |
56 | Generated audio files are saved to the default output folder `./outputs`.
57 |
58 |
59 | ## ⚙️ Full List of Options
60 | View the full list of options with the following command:
61 | ```console
62 | python generate.py -h
63 | ```
64 |
65 |
66 | ## 🙏 Acknowledgements
67 | This work would not have been possible without the following repositories:
68 |
69 | [VoiceLDM](https://github.com/glory20h/VoiceLDM)
70 |
71 | [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha)
72 |
73 | [HuggingFace Diffusers](https://github.com/huggingface/diffusers)
74 |
75 | [HuggingFace Transformers](https://github.com/huggingface/transformers)
76 |
77 | [AudioLDM](https://github.com/haoheliu/AudioLDM)
78 |
79 | [naturalspeech](https://github.com/heatz123/naturalspeech)
80 |
81 | [audioldm_eval](https://github.com/haoheliu/audioldm_eval)
82 |
83 | [AudioLDM2](https://github.com/haoheliu/AudioLDM2)
--------------------------------------------------------------------------------
/voicedit/audio/audio_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import librosa.util as librosa_util
4 | from scipy.signal import get_window
5 |
6 |
7 | def window_sumsquare(
8 | window,
9 | n_frames,
10 | hop_length,
11 | win_length,
12 | n_fft,
13 | dtype=np.float32,
14 | norm=None,
15 | ):
16 | """
17 | # from librosa 0.6
18 | Compute the sum-square envelope of a window function at a given hop length.
19 |
20 | This is used to estimate modulation effects induced by windowing
21 | observations in short-time fourier transforms.
22 |
23 | Parameters
24 | ----------
25 | window : string, tuple, number, callable, or list-like
26 | Window specification, as in `get_window`
27 |
28 | n_frames : int > 0
29 | The number of analysis frames
30 |
31 | hop_length : int > 0
32 | The number of samples to advance between frames
33 |
34 | win_length : [optional]
35 | The length of the window function. By default, this matches `n_fft`.
36 |
37 | n_fft : int > 0
38 | The length of each analysis frame.
39 |
40 | dtype : np.dtype
41 | The data type of the output
42 |
43 | Returns
44 | -------
45 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46 | The sum-squared envelope of the window function
47 | """
48 | if win_length is None:
49 | win_length = n_fft
50 |
51 | n = n_fft + hop_length * (n_frames - 1)
52 | x = np.zeros(n, dtype=dtype)
53 |
54 | # Compute the squared window at the desired length
55 | win_sq = get_window(window, win_length, fftbins=True)
56 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57 |     win_sq = librosa_util.pad_center(win_sq, size=n_fft)  # keyword form required by librosa >= 0.10 (matches the call in stft.py)
58 |
59 | # Fill the envelope
60 | for i in range(n_frames):
61 | sample = i * hop_length
62 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63 | return x
64 |
65 |
66 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
67 | """
68 | PARAMS
69 | ------
70 | magnitudes: spectrogram magnitudes
71 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72 | """
73 |
74 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75 | angles = angles.astype(np.float32)
76 | angles = torch.autograd.Variable(torch.from_numpy(angles))
77 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78 |
79 | for i in range(n_iters):
80 | _, angles = stft_fn.transform(signal)
81 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82 | return signal
83 |
84 |
85 | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86 | """
87 | PARAMS
88 | ------
89 | C: compression factor
90 | """
91 | return normalize_fun(torch.clamp(x, min=clip_val) * C)
92 |
93 |
94 | def dynamic_range_decompression(x, C=1):
95 | """
96 | PARAMS
97 | ------
98 | C: compression factor used to compress
99 | """
100 | return torch.exp(x) / C
101 |
--------------------------------------------------------------------------------
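A minimal round-trip sketch of the compression helpers above:

```python
import torch
from voicedit.audio.audio_processing import (
    dynamic_range_compression,
    dynamic_range_decompression,
)

mel = torch.rand(64, 100) + 0.1                    # fake mel magnitudes, well above the 1e-5 clip
log_mel = dynamic_range_compression(mel)           # log(clamp(x, min=1e-5))
mel_back = dynamic_range_decompression(log_mel)    # exp(x)
print(torch.allclose(mel, mel_back, atol=1e-5))    # True: exp undoes log above the clip value
```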
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 | from text import cleaners
5 | from text.symbols import symbols
6 |
7 |
8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
10 |
11 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
12 |
13 |
14 | def get_arpabet(word, dictionary):
15 | word_arpabet = dictionary.lookup(word)
16 | if word_arpabet is not None:
17 | return "{" + word_arpabet[0] + "}"
18 | else:
19 | return word
20 |
21 |
22 | def text_to_sequence(text, cleaner_names=["english_cleaners"], dictionary=None):
23 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
24 |
25 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
26 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
27 |
28 | Args:
29 | text: string to convert to a sequence
30 | cleaner_names: names of the cleaner functions to run the text through
31 | dictionary: arpabet class with arpabet dictionary
32 |
33 | Returns:
34 | List of integers corresponding to the symbols in the text
35 | '''
36 | sequence = []
37 | space = _symbols_to_sequence(' ')
38 | # Check for curly braces and treat their contents as ARPAbet:
39 | while len(text):
40 | m = _curly_re.match(text)
41 | if not m:
42 | clean_text = _clean_text(text, cleaner_names)
43 | if dictionary is not None:
44 | clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
45 | for i in range(len(clean_text)):
46 | t = clean_text[i]
47 | if t.startswith("{"):
48 | sequence += _arpabet_to_sequence(t[1:-1])
49 | else:
50 | sequence += _symbols_to_sequence(t)
51 | sequence += space
52 | else:
53 | sequence += _symbols_to_sequence(clean_text)
54 | break
55 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
56 | sequence += _arpabet_to_sequence(m.group(2))
57 | text = m.group(3)
58 |
59 | # remove trailing space
60 | if dictionary is not None:
61 | sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
62 | return sequence
63 |
64 |
65 | def sequence_to_text(sequence):
66 | '''Converts a sequence of IDs back to a string'''
67 | result = ''
68 | for symbol_id in sequence:
69 | if symbol_id in _id_to_symbol:
70 | s = _id_to_symbol[symbol_id]
71 | # Enclose ARPAbet back in curly braces:
72 | if len(s) > 1 and s[0] == '@':
73 | s = '{%s}' % s[1:]
74 | result += s
75 | return result.replace('}{', ' ')
76 |
77 |
78 | def _clean_text(text, cleaner_names):
79 | for name in cleaner_names:
80 | cleaner = getattr(cleaners, name)
81 | if not cleaner:
82 | raise Exception('Unknown cleaner: %s' % name)
83 | text = cleaner(text)
84 | return text
85 |
86 |
87 | def _symbols_to_sequence(symbols):
88 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
89 |
90 |
91 | def _arpabet_to_sequence(text):
92 | return _symbols_to_sequence(['@' + s for s in text.split()])
93 |
94 |
95 | def _should_keep_symbol(s):
96 | return s in _symbol_to_id and s != '_' and s != '~'
97 |
--------------------------------------------------------------------------------
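A small sketch of `text_to_sequence` without a CMU dictionary (so no ARPAbet lookup is performed):

```python
from text import text_to_sequence, sequence_to_text

seq = text_to_sequence("Hello, world!", cleaner_names=["english_cleaners"])
print(seq)                    # list of integer symbol ids
print(sequence_to_text(seq))  # "hello, world!"  (the cleaned text round-trips)
```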
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = next(
17 | (
18 | obj
19 | for obj in (mean1, logvar1, mean2, logvar2)
20 | if isinstance(obj, th.Tensor)
21 | ),
22 | None,
23 | )
24 | assert tensor is not None, "at least one argument must be a Tensor"
25 |
26 | # Force variances to be Tensors. Broadcasting helps convert scalars to
27 | # Tensors, but it does not work for th.exp().
28 | logvar1, logvar2 = [
29 | x if isinstance(x, th.Tensor) else th.tensor(x, device=tensor.device)
30 | for x in (logvar1, logvar2)
31 | ]
32 |
33 | return 0.5 * (
34 | -1.0
35 | + logvar2
36 | - logvar1
37 | + th.exp(logvar1 - logvar2)
38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39 | )
40 |
41 |
42 | def approx_standard_normal_cdf(x):
43 | """
44 | A fast approximation of the cumulative distribution function of the
45 | standard normal.
46 | """
47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48 |
49 |
50 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
51 | """
52 | Compute the log-likelihood of a continuous Gaussian distribution.
53 | :param x: the targets
54 | :param means: the Gaussian mean Tensor.
55 | :param log_scales: the Gaussian log stddev Tensor.
56 | :return: a tensor like x of log probabilities (in nats).
57 | """
58 | centered_x = x - means
59 | inv_stdv = th.exp(-log_scales)
60 | normalized_x = centered_x * inv_stdv
61 | return th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(
62 | normalized_x
63 | )
64 |
65 |
66 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
67 | """
68 | Compute the log-likelihood of a Gaussian distribution discretizing to a
69 | given image.
70 | :param x: the target images. It is assumed that this was uint8 values,
71 | rescaled to the range [-1, 1].
72 | :param means: the Gaussian mean Tensor.
73 | :param log_scales: the Gaussian log stddev Tensor.
74 | :return: a tensor like x of log probabilities (in nats).
75 | """
76 | assert x.shape == means.shape == log_scales.shape
77 | centered_x = x - means
78 | inv_stdv = th.exp(-log_scales)
79 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
80 | cdf_plus = approx_standard_normal_cdf(plus_in)
81 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
82 | cdf_min = approx_standard_normal_cdf(min_in)
83 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
84 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
85 | cdf_delta = cdf_plus - cdf_min
86 | log_probs = th.where(
87 | x < -0.999,
88 | log_cdf_plus,
89 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
90 | )
91 | assert log_probs.shape == x.shape
92 | return log_probs
93 |
--------------------------------------------------------------------------------
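A short numeric sanity check of the helpers above; the KL value is easy to verify by hand:

```python
import torch as th
from diffusion.diffusion_utils import normal_kl, discretized_gaussian_log_likelihood

mean1, logvar1 = th.zeros(4), th.zeros(4)
mean2, logvar2 = th.full((4,), 0.5), th.zeros(4)
print(normal_kl(mean1, logvar1, mean2, logvar2))   # elementwise KL; 0.5 * 0.5**2 = 0.125 each

x = th.rand(2, 3) * 2 - 1                          # fake image values in [-1, 1]
ll = discretized_gaussian_log_likelihood(x, means=th.zeros_like(x), log_scales=th.zeros_like(x))
print(ll.shape)                                    # same shape as x, log-probabilities in nats
```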
/voicedit/audio/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchaudio
4 |
5 |
6 | def get_mel_from_wav(audio, _stft):
7 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
8 | audio = torch.autograd.Variable(audio, requires_grad=False)
9 | melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
10 | melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
11 | log_magnitudes_stft = (
12 | torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
13 | )
14 | energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
15 | return melspec, log_magnitudes_stft, energy
16 |
17 |
18 | def _pad_spec(fbank, target_length=1024):
19 | n_frames = fbank.shape[0]
20 | p = target_length - n_frames
21 | # cut and pad
22 | if p > 0:
23 | m = torch.nn.ZeroPad2d((0, 0, 0, p))
24 | fbank = m(fbank)
25 | elif p < 0:
26 | fbank = fbank[0:target_length, :]
27 |
28 | if fbank.size(-1) % 2 != 0:
29 | fbank = fbank[..., :-1]
30 |
31 | return fbank
32 |
33 |
34 | def pad_wav(waveform, segment_length):
35 | waveform_length = waveform.shape[-1]
36 | assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
37 | if segment_length is None or waveform_length == segment_length:
38 | return waveform
39 | elif waveform_length > segment_length:
40 | return waveform[:, :segment_length]
41 | elif waveform_length < segment_length:
42 | temp_wav = np.zeros((1, segment_length))
43 | temp_wav[:, :waveform_length] = waveform
44 | return temp_wav
45 |
46 | def normalize_wav(waveform):
47 | waveform = waveform - np.mean(waveform)
48 | waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
49 | return waveform * 0.5
50 |
51 |
52 | def read_wav_file(filename, segment_length):
53 | # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
54 | waveform, sr = torchaudio.load(filename) # Faster!!!
55 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
56 | waveform = waveform.numpy()[0, ...]
57 | waveform = normalize_wav(waveform)
58 | waveform = waveform[None, ...]
59 | waveform = pad_wav(waveform, segment_length)
60 |
61 | waveform = waveform / np.max(np.abs(waveform))
62 | waveform = 0.5 * waveform
63 |
64 | return waveform
65 |
66 |
67 | def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
68 | assert fn_STFT is not None
69 |
70 | # mixup
71 | waveform = read_wav_file(filename, target_length * 160) # hop size is 160
72 |
73 | waveform = waveform[0, ...]
74 | waveform = torch.FloatTensor(waveform)
75 |
76 | fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
77 |
78 | fbank = torch.FloatTensor(fbank.T)
79 | log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
80 |
81 | fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
82 | log_magnitudes_stft, target_length
83 | )
84 |
85 | return fbank, log_magnitudes_stft, waveform
86 |
87 |
88 | def raw_waveform_to_fbank(waveform, target_length=1024, fn_STFT=None):
89 | assert fn_STFT is not None
90 |
91 | fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
92 |
93 | fbank = torch.FloatTensor(fbank.T)
94 | log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
95 |
96 | fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
97 | log_magnitudes_stft, target_length
98 | )
99 |
100 | return fbank, log_magnitudes_stft, waveform
101 |
--------------------------------------------------------------------------------
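A hedged sketch of `wav_to_fbank`, mirroring the STFT settings used in `voicedit/data_ft.py` and one of the shipped assets:

```python
from voicedit.audio import TacotronSTFT, wav_to_fbank

# Same STFT configuration as in voicedit/data_ft.py (16 kHz, 64 mel bins, hop 160).
stft = TacotronSTFT(
    filter_length=1024, hop_length=160, win_length=1024,
    n_mel_channels=64, sampling_rate=16000, mel_fmin=0, mel_fmax=8000,
)

fbank, log_mag, waveform = wav_to_fbank("assets/bird_chirping.wav", target_length=1024, fn_STFT=stft)
print(fbank.shape)   # torch.Size([1024, 64]): padded/cropped frames x mel bins
```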
/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import warnings
4 |
5 | import torch
6 | import torchaudio
7 |
8 | from voicedit import VoiceDiTPipeline
9 |
10 | warnings.filterwarnings("ignore") # ignore warning
11 |
12 |
13 | def parse_args(parser):
14 | parser.add_argument("--modality", type=str, default="text", help="modality for generation. 'text' or 'audio' or 'image'")
15 | parser.add_argument('--desc_prompt', '-d', type=str, help='description prompt')
16 | parser.add_argument('--cont_prompt', '-c', type=str, help='content prompt')
17 | parser.add_argument('--speaker_audio', '-spk', type=str, help='speaker audio')
18 |
19 | parser.add_argument("--ckpt_path", type=str, required=True, help="checkpoint file path for VoiceDiT")
20 | parser.add_argument("--v2a_ckpt_path", type=str, help='checkpoint file path for V2A-Mapper')
21 | parser.add_argument("--output_dir", type=str, default="./outputs", help="directory to save generated audio")
22 | parser.add_argument("--file_name", type=str, help="filename for the generated audio")
23 |
24 | parser.add_argument('--num_inference_steps', type=int, default=100, help='number of inference steps for DDIM sampling')
25 | parser.add_argument('--audio_length_in_s', type=float, default=10, help='duration of the audio for generation')
26 | parser.add_argument('--v2a_guidance_scale', type=float, default=2.0, help='guidance weight for v2a-mapper classifier-free guidance')
27 | parser.add_argument('--guidance_scale', type=float, help='guidance weight for single classifier-free guidance')
28 | parser.add_argument('--desc_guidance_scale', type=float, default=5, required=False, help='desc guidance weight for dual classifier-free guidance')
29 | parser.add_argument('--cont_guidance_scale', type=float, default=5, required=False, help='cont guidance weight for dual classifier-free guidance')
30 | parser.add_argument('--female_voice_list', type=str, default='libri_test_clean_female.txt')
31 | parser.add_argument('--male_voice_list', type=str, default='libri_test_clean_male.txt')
32 |
33 | parser.add_argument("--device", type=str, default="auto", help="device to use for audio generation")
34 | parser.add_argument('--seed', type=int, help='random seed for generation')
35 |
36 | return parser.parse_args()
37 |
38 |
39 | def main():
40 | parser = argparse.ArgumentParser()
41 | args = parse_args(parser)
42 |
43 | if args.device == "auto":
44 | if torch.cuda.is_available():
45 | args.device = torch.device("cuda:0")
46 | else:
47 | args.device = torch.device("cpu")
48 | elif args.device is not None:
49 | args.device = torch.device(args.device)
50 | else:
51 | args.device = torch.device("cpu")
52 |
53 | pipe = VoiceDiTPipeline(
54 | args.ckpt_path,
55 |         t2a_ckpt_path = getattr(args, "t2a_ckpt_path", None),  # parse_args defines no --t2a_ckpt_path flag; fall back to None to avoid an AttributeError
56 | v2a_ckpt_path = args.v2a_ckpt_path,
57 | device = args.device,
58 | male_voice = args.male_voice_list,
59 | female_voice = args.female_voice_list,
60 | )
61 |
62 | audio, _ = pipe(
63 | modality = args.modality,
64 | env_prompt = args.desc_prompt,
65 | cont_prompt = args.cont_prompt,
66 | batch_size = 1,
67 | num_inference_steps = args.num_inference_steps,
68 |         audio_length_in_s = args.audio_length_in_s,
69 | do_classifier_free_guidance = True,
70 | desc_guidance_scale = args.desc_guidance_scale,
71 | cont_guidance_scale = args.cont_guidance_scale,
72 | v2a_guidance_scale = args.v2a_guidance_scale,
73 | device=args.device,
74 | seed=args.seed,
75 | progress=True,
76 | speaker_audio=args.speaker_audio,
77 | )
78 |
79 | file_name = args.file_name
80 | if file_name is None:
81 | if args.modality == "text":
82 | file_name = args.desc_prompt[:10]
83 | else:
84 | file_name = os.path.basename(args.desc_prompt[:-4])
85 | file_name = file_name + "-" + args.cont_prompt[:10] + ".wav"
86 |
87 | os.makedirs(args.output_dir, exist_ok=True)
88 | save_path = os.path.join(args.output_dir, file_name)
89 |
90 | torchaudio.save(save_path, src=audio, sample_rate=16000)
91 |
92 |
93 | if __name__ == "__main__":
94 | main()
--------------------------------------------------------------------------------
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 |                 f"cannot create exactly {desired_count} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | frac_stride = 1 if section_count <= 1 else (size - 1) / (section_count - 1)
52 | cur_idx = 0.0
53 | taken_steps = []
54 | for _ in range(section_count):
55 | taken_steps.append(start_idx + round(cur_idx))
56 | cur_idx += frac_stride
57 | all_steps += taken_steps
58 | start_idx += size
59 | return set(all_steps)
60 |
61 |
62 | class SpacedDiffusion(GaussianDiffusion):
63 | """
64 | A diffusion process which can skip steps in a base diffusion process.
65 | :param use_timesteps: a collection (sequence or set) of timesteps from the
66 | original diffusion process to retain.
67 | :param kwargs: the kwargs to create the base diffusion process.
68 | """
69 |
70 | def __init__(self, use_timesteps, **kwargs):
71 | self.use_timesteps = set(use_timesteps)
72 | self.timestep_map = []
73 | self.original_num_steps = len(kwargs["betas"])
74 |
75 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
76 | last_alpha_cumprod = 1.0
77 | new_betas = []
78 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
79 | if i in self.use_timesteps:
80 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
81 | last_alpha_cumprod = alpha_cumprod
82 | self.timestep_map.append(i)
83 | kwargs["betas"] = np.array(new_betas)
84 | super().__init__(**kwargs)
85 |
86 | def p_mean_variance(
87 | self, model, *args, **kwargs
88 | ): # pylint: disable=signature-differs
89 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
90 |
91 | def training_losses(
92 | self, model, *args, **kwargs
93 | ): # pylint: disable=signature-differs
94 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
95 |
96 | def training_losses_diffusers(
97 | self, model, *args, **kwargs
98 | ): # pylint: disable=signature-differs
99 | return super().training_losses_diffusers(self._wrap_model(model), *args, **kwargs)
100 |
101 | def condition_mean(self, cond_fn, *args, **kwargs):
102 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def condition_score(self, cond_fn, *args, **kwargs):
105 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
106 |
107 | def _wrap_model(self, model):
108 | if isinstance(model, _WrappedModel):
109 | return model
110 | return _WrappedModel(
111 | model, self.timestep_map, self.original_num_steps
112 | )
113 |
114 | def _scale_timesteps(self, t):
115 | # Scaling is done by the wrapped model.
116 | return t
117 |
118 |
119 | class _WrappedModel:
120 | def __init__(self, model, timestep_map, original_num_steps):
121 | self.model = model
122 | self.timestep_map = timestep_map
123 | # self.rescale_timesteps = rescale_timesteps
124 | self.original_num_steps = original_num_steps
125 |
126 | def __call__(self, x, timestep, **kwargs):
127 | map_tensor = th.tensor(self.timestep_map, device=timestep.device, dtype=timestep.dtype)
128 | new_ts = map_tensor[timestep]
129 | # if self.rescale_timesteps:
130 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
131 | return self.model(x, timestep=new_ts, **kwargs)
132 |
--------------------------------------------------------------------------------
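A small sketch of `space_timesteps`, matching the docstring's own example:

```python
from diffusion.respace import space_timesteps

# 300 original steps, taking 10/15/20 steps from its three equal thirds (docstring example).
steps = space_timesteps(300, [10, 15, 20])
print(len(steps))                             # 45

# Fixed DDIM striding: exactly 50 of the original 1000 steps.
print(len(space_timesteps(1000, "ddim50")))   # 50
```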
/voicedit/audio/stft.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy.signal import get_window
5 | from librosa.util import pad_center, tiny
6 | from librosa.filters import mel as librosa_mel_fn
7 |
8 | from .audio_processing import (
9 | dynamic_range_compression,
10 | dynamic_range_decompression,
11 | window_sumsquare,
12 | )
13 |
14 |
15 | class STFT(torch.nn.Module):
16 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17 |
18 | def __init__(self, filter_length, hop_length, win_length, window="hann"):
19 | super(STFT, self).__init__()
20 | self.filter_length = filter_length
21 | self.hop_length = hop_length
22 | self.win_length = win_length
23 | self.window = window
24 | self.forward_transform = None
25 | scale = self.filter_length / self.hop_length
26 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
27 |
28 | cutoff = int((self.filter_length / 2 + 1))
29 | fourier_basis = np.vstack(
30 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31 | )
32 |
33 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34 | inverse_basis = torch.FloatTensor(
35 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36 | )
37 |
38 | if window is not None:
39 | assert filter_length >= win_length
40 | # get window and zero center pad it to filter_length
41 | fft_window = get_window(window, win_length, fftbins=True)
42 | fft_window = pad_center(fft_window, size=filter_length)
43 | fft_window = torch.from_numpy(fft_window).float()
44 |
45 | # window the bases
46 | forward_basis *= fft_window
47 | inverse_basis *= fft_window
48 |
49 | self.register_buffer("forward_basis", forward_basis.float())
50 | self.register_buffer("inverse_basis", inverse_basis.float())
51 |
52 | def transform(self, input_data):
53 | num_batches = input_data.size(0)
54 | num_samples = input_data.size(1)
55 |
56 | self.num_samples = num_samples
57 |
58 | # similar to librosa, reflect-pad the input
59 | input_data = input_data.view(num_batches, 1, num_samples)
60 | input_data = F.pad(
61 | input_data.unsqueeze(1),
62 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63 | mode="reflect",
64 | )
65 | input_data = input_data.squeeze(1)
66 |
67 | forward_transform = F.conv1d(
68 | input_data,
69 | torch.autograd.Variable(self.forward_basis, requires_grad=False),
70 | stride=self.hop_length,
71 | padding=0,
72 | ).cpu()
73 |
74 | cutoff = int((self.filter_length / 2) + 1)
75 | real_part = forward_transform[:, :cutoff, :]
76 | imag_part = forward_transform[:, cutoff:, :]
77 |
78 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
79 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80 |
81 | return magnitude, phase
82 |
83 | def inverse(self, magnitude, phase):
84 | recombine_magnitude_phase = torch.cat(
85 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86 | )
87 |
88 | inverse_transform = F.conv_transpose1d(
89 | recombine_magnitude_phase,
90 | torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91 | stride=self.hop_length,
92 | padding=0,
93 | )
94 |
95 | if self.window is not None:
96 | window_sum = window_sumsquare(
97 | self.window,
98 | magnitude.size(-1),
99 | hop_length=self.hop_length,
100 | win_length=self.win_length,
101 | n_fft=self.filter_length,
102 | dtype=np.float32,
103 | )
104 | # remove modulation effects
105 | approx_nonzero_indices = torch.from_numpy(
106 | np.where(window_sum > tiny(window_sum))[0]
107 | )
108 | window_sum = torch.autograd.Variable(
109 | torch.from_numpy(window_sum), requires_grad=False
110 | )
111 |             window_sum = window_sum.to(inverse_transform.device)  # keep devices consistent before the in-place division below
112 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113 | approx_nonzero_indices
114 | ]
115 |
116 | # scale by hop ratio
117 | inverse_transform *= float(self.filter_length) / self.hop_length
118 |
119 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121 |
122 | return inverse_transform
123 |
124 | def forward(self, input_data):
125 | self.magnitude, self.phase = self.transform(input_data)
126 | reconstruction = self.inverse(self.magnitude, self.phase)
127 | return reconstruction
128 |
129 |
130 | class TacotronSTFT(torch.nn.Module):
131 | def __init__(
132 | self,
133 | filter_length,
134 | hop_length,
135 | win_length,
136 | n_mel_channels,
137 | sampling_rate,
138 | mel_fmin,
139 | mel_fmax,
140 | ):
141 | super(TacotronSTFT, self).__init__()
142 | self.n_mel_channels = n_mel_channels
143 | self.sampling_rate = sampling_rate
144 | self.stft_fn = STFT(filter_length, hop_length, win_length)
145 | mel_basis = librosa_mel_fn(
146 | sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
147 | )
148 | mel_basis = torch.from_numpy(mel_basis).float()
149 | self.register_buffer("mel_basis", mel_basis)
150 |
151 | def spectral_normalize(self, magnitudes, normalize_fun):
152 | output = dynamic_range_compression(magnitudes, normalize_fun)
153 | return output
154 |
155 | def spectral_de_normalize(self, magnitudes):
156 | output = dynamic_range_decompression(magnitudes)
157 | return output
158 |
159 | def mel_spectrogram(self, y, normalize_fun=torch.log):
160 | """Computes mel-spectrograms from a batch of waves
161 | PARAMS
162 | ------
163 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164 |
165 | RETURNS
166 | -------
167 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168 | """
169 | assert torch.min(y.data) >= -1, torch.min(y.data)
170 | assert torch.max(y.data) <= 1, torch.max(y.data)
171 |
172 | magnitudes, phases = self.stft_fn.transform(y)
173 | magnitudes = magnitudes.data
174 | mel_output = torch.matmul(self.mel_basis, magnitudes)
175 | mel_output = self.spectral_normalize(mel_output, normalize_fun)
176 | energy = torch.norm(magnitudes, dim=1)
177 |
178 | log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
179 |
180 | return mel_output, log_magnitudes, energy
181 |
--------------------------------------------------------------------------------
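A minimal sketch of `TacotronSTFT.mel_spectrogram` on a synthetic tone (inputs must lie in [-1, 1], as asserted above):

```python
import torch
from voicedit.audio.stft import TacotronSTFT

stft = TacotronSTFT(
    filter_length=1024, hop_length=160, win_length=1024,
    n_mel_channels=64, sampling_rate=16000, mel_fmin=0, mel_fmax=8000,
)

t = torch.linspace(0, 1, 16000)
wave = 0.5 * torch.sin(2 * torch.pi * 440.0 * t).unsqueeze(0)   # 1 s, 440 Hz tone, batch of 1

mel, log_mag, energy = stft.mel_spectrogram(wave)
print(mel.shape)     # torch.Size([1, 64, 101]): batch x mel bins x frames
```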
/voicedit/data_ft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 |
5 | import math
6 | import torch
7 | import torch.nn.functional as F
8 | import torchaudio
9 | from torch.utils.data import Dataset
10 |
11 | from .audio import get_mel_from_wav, raw_waveform_to_fbank, TacotronSTFT
12 | from text import text_to_sequence, cmudict
13 | from text.symbols import symbols
14 | from .utils import intersperse
15 |
16 |
17 | class AudioDataset_ft(Dataset):
18 | def __init__(self, args, df, df_noise, clap_processor, rir_path=None, random_pad=False, cmudict_path='voicedit/cmu_dictionary', add_blank=True):
19 | self.df = df
20 | self.df_noise = df_noise
21 |
22 | self.paths = args.paths
23 | self.noise_paths = args.noise_paths
24 |
25 | self.uncond_text_prob = args.uncond_text_prob
26 | self.add_noise_prob = args.add_noise_prob
27 | self.reverb_prob = args.reverb_prob
28 | self.random_pad = random_pad
29 | self.only_noise_prob = args.only_noise_prob
30 |
31 | self.duration = 10
32 | self.target_length = int(self.duration * 102.4)
33 | self.stft = TacotronSTFT(
34 | filter_length=1024,
35 | hop_length=160,
36 | win_length=1024,
37 | n_mel_channels=64,
38 | sampling_rate=16000,
39 | mel_fmin=0,
40 | mel_fmax=8000,
41 | )
42 |
43 | self.clap_processor = clap_processor
44 |
45 | self.cmudict = cmudict.CMUDict(cmudict_path)
46 | self.add_blank = add_blank
47 |
48 | if rir_path is not None:
49 | self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav'))
50 |
51 | def get_mel(self, audio, _stft):
52 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
53 | audio = torch.autograd.Variable(audio, requires_grad=False)
54 | melspec, _, _ = _stft.mel_spectrogram(audio)
55 | return torch.squeeze(melspec, 0).float()
56 |
57 | def get_text(self, text, add_blank=True):
58 | text_norm = text_to_sequence(text, dictionary=self.cmudict)
59 | if add_blank:
60 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols)
61 | text_norm = torch.IntTensor(text_norm)
62 | return text_norm
63 |
64 | def reverberate(self, audio):
65 |
66 | rir_file = random.choice(self.rir_files)
67 | rir, fs = torchaudio.load(rir_file)
68 | rir = rir.to(dtype=audio.dtype)
69 | rir = rir / torch.linalg.vector_norm(rir, ord=2, dim=-1)
70 |
71 | return torchaudio.functional.fftconvolve(audio, rir)[:,:self.target_length * 160]
72 |
73 | def pad_wav(self, wav, target_len, random_cut=False, random_pad=False):
74 | n_channels, wav_len = wav.shape
75 | if n_channels == 2:
76 | wav = wav.mean(-2, keepdim=True)
77 |
78 | if wav_len > target_len:
79 | if random_cut:
80 | i = random.randint(0, wav_len - target_len)
81 | return wav[:, i:i+target_len], 0
82 | return wav[:, :target_len], 0
83 | elif wav_len < target_len:
84 | if random_pad:
85 | i = random.randint(0, (target_len - wav_len) // 640)
86 | pre_pad = i * 640
87 | return F.pad(wav, (pre_pad, target_len-wav_len-pre_pad)), i
88 | wav = F.pad(wav, (0, target_len-wav_len))
89 | return wav, 0
90 |
91 | def normalize_wav(self, waveform):
92 | waveform = waveform - torch.mean(waveform)
93 | waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
94 | return waveform
95 |
96 | def __getitem__(self, index):
97 | if random.random() > self.only_noise_prob:
98 | row = self.df.iloc[index]
99 | text = row.text
100 | file_path = os.path.join(self.paths[row.data], row.file_path)
101 |
102 | waveform, sr = torchaudio.load(file_path)
103 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
104 | waveform, speech_start = self.pad_wav(waveform, self.target_length * 160, random_pad=self.random_pad)
105 |
106 | if row.data in ['as_speech_en']:
107 |                 spk_emb = torch.load(file_path.replace('/dataset/', '/spk_emb/').replace('.wav', '.pt'), map_location=torch.device('cpu'))
108 |
109 | if row.data in ['cv', 'voxceleb', 'libri'] and len(self.df_noise) > 0:
110 | if random.random() < self.add_noise_prob:
111 | noise_row = self.df_noise.iloc[random.randint(0, len(self.df_noise)-1)]
112 | noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path))
113 | noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000)
114 | noise, _ = self.pad_wav(noise, self.target_length * 160, random_cut=True)
115 | if torch.linalg.vector_norm(noise, ord=2, dim=-1).item() != 0.0:
116 | snr = torch.Tensor(1).uniform_(2, 10)
117 | waveform = torchaudio.functional.add_noise(waveform, noise, snr)
118 |
119 | if random.random() < self.reverb_prob and hasattr(self, 'rir_files'):
120 | waveform = self.reverberate(waveform)
121 |
122 |                 spk_emb = torch.load(file_path.replace('/LibriTTS_R_16k/', '/LibriTTS_R_spk/').replace('.wav', '.pt'), map_location=torch.device('cpu'))
123 |
124 | else:
125 | text = '_'
126 | noise_row = self.df_noise.iloc[random.randint(0, len(self.df_noise)-1)]
127 | noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path))
128 | noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000)
129 | waveform, _ = self.pad_wav(noise, self.target_length * 160, random_cut=True)
130 | spk_emb = torch.zeros(1, 192)
131 | speech_start = 0
132 |
133 |         if not isinstance(spk_emb, torch.Tensor):
134 | spk_emb = spk_emb[0]
135 |
136 | fbank, _, waveform = raw_waveform_to_fbank(
137 | waveform[0],
138 | target_length=self.target_length,
139 | fn_STFT=self.stft
140 | )
141 |
142 | tokenized_text = self.get_text(text, add_blank=self.add_blank)
143 |
144 | # resample to 48k for clap
145 | wav_48k = torchaudio.functional.resample(waveform, orig_freq=16000, new_freq=48000)
146 | clap_inputs = self.clap_processor(audios=wav_48k, return_tensors="pt", sampling_rate=48000)
147 |
148 | return text, tokenized_text, fbank, spk_emb, clap_inputs, self.normalize_wav(waveform), speech_start
149 |
150 | def __len__(self):
151 | return len(self.df)
152 |
153 |
154 | class CollateFn_ft(object):
155 |
156 | def __call__(self, examples):
157 |
158 | B = len(examples)
159 |
160 | fbank = torch.stack([example[2] for example in examples])
161 | spk_embs = torch.cat([example[3] for example in examples])
162 | clap_input_features = torch.cat([example[4].input_features for example in examples])
163 | clap_is_longer = torch.cat([example[4].is_longer for example in examples])
164 | audios = [example[5] for example in examples]
165 | # speech_starts = torch.LongTensor([example[6] for example in examples])
166 |
167 | x_max_length = max([example[1].shape[-1] for example in examples])
168 |
169 | x = torch.zeros((B, x_max_length), dtype=torch.long)
170 | x_lengths = []
171 |
172 | for i, example in enumerate(examples):
173 | x_ = example[1]
174 | x_lengths.append(x_.shape[-1])
175 | x[i, :x_.shape[-1]] = x_
176 |
177 | x_lengths = torch.LongTensor(x_lengths)
178 |
179 | return {
180 | "audio": audios,
181 | "fbank": fbank,
182 | "spk": spk_embs,
183 | "text": x,
184 | "text_lengths": x_lengths,
185 | "clap_input_features": clap_input_features,
186 | "clap_is_longer": clap_is_longer,
187 | "texts": [example[0] for example in examples],
188 | # "speech_starts": speech_starts
189 | }
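190 | 
191 | 
192 | # --- Hedged sketch (illustrative, not part of the original module) ---
193 | # How a training script might wire AudioDataset_ft and CollateFn_ft into a
194 | # DataLoader. `_example_dataloader`, the dataframes and the CLAP checkpoint id
195 | # are placeholders, not names used elsewhere in this repository.
196 | def _example_dataloader(args, df, df_noise, rir_path=None):
197 |     from torch.utils.data import DataLoader
198 |     from transformers import ClapProcessor
199 | 
200 |     # `args` must expose the attributes read in __init__ above:
201 |     # paths, noise_paths, uncond_text_prob, add_noise_prob, reverb_prob, only_noise_prob.
202 |     clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
203 |     dataset = AudioDataset_ft(args, df, df_noise, clap_processor, rir_path=rir_path, random_pad=True)
204 |     loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=CollateFn_ft())
205 |     return next(iter(loader))  # dict with "fbank", "text", "spk", "clap_input_features", ...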
--------------------------------------------------------------------------------
/voicedit/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import torchaudio
7 | from torch.utils.data import Dataset
8 |
9 | from .audio import get_mel_from_wav, raw_waveform_to_fbank, TacotronSTFT
10 | from text import text_to_sequence, cmudict
11 | from text.symbols import symbols
12 | from .utils import intersperse
13 |
14 |
15 | class AudioDataset(Dataset):
16 | def __init__(self, args, df, df_noise, clap_processor, rir_path=None, cmudict_path='voicedit/cmu_dictionary', add_blank=True, train=True):
17 | self.df = df
18 | self.df_noise = df_noise
19 |
20 | self.paths = args.paths
21 | self.noise_paths = args.noise_paths
22 |
23 | self.uncond_text_prob = args.uncond_text_prob
24 | self.add_noise_prob = args.add_noise_prob
25 | self.reverb_prob = args.reverb_prob
26 |
27 | self.duration = 10
28 | self.target_length = int(self.duration * 102.4)
29 | self.stft = TacotronSTFT(
30 | filter_length=1024,
31 | hop_length=160,
32 | win_length=1024,
33 | n_mel_channels=64,
34 | sampling_rate=16000,
35 | mel_fmin=0,
36 | mel_fmax=8000,
37 | )
38 |
39 | self.clap_processor = clap_processor
40 | self.train = train
41 | if not train:
42 | self.noise_dict = dict()
43 |
44 | self.cmudict = cmudict.CMUDict(cmudict_path)
45 | self.add_blank = add_blank
46 |
47 | def get_mel(self, audio, _stft):
48 | audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
49 |         audio = audio.detach()  # torch.autograd.Variable is deprecated; a plain tensor is enough here
50 | melspec, _, _ = _stft.mel_spectrogram(audio)
51 | return torch.squeeze(melspec, 0).float()
52 |
53 | def get_text(self, text, add_blank=True):
54 | text_norm = text_to_sequence(text, dictionary=self.cmudict)
55 | if add_blank:
56 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols)
57 | text_norm = torch.IntTensor(text_norm)
58 | return text_norm
59 |
60 | def reverberate(self, audio):
61 |
62 | rir_file = random.choice(self.rir_files)
63 | rir, fs = torchaudio.load(rir_file)
64 | rir = rir.to(dtype=audio.dtype)
65 |         rir = rir / torch.linalg.vector_norm(rir, ord=2, dim=-1, keepdim=True)
66 |
67 | return torchaudio.functional.fftconvolve(audio, rir)[:,:self.target_length * 160]
68 |
69 | def pad_wav(self, wav, target_len, random_cut=False):
70 | n_channels, wav_len = wav.shape
71 |         if n_channels > 1:  # downmix multi-channel audio to mono
72 |             wav = wav.mean(-2, keepdim=True)
73 |
74 | if wav_len > target_len:
75 | if random_cut:
76 | i = random.randint(0, wav_len - target_len)
77 | return wav[:, i:i+target_len]
78 | return wav[:, :target_len]
79 | elif wav_len < target_len:
80 | wav = F.pad(wav, (0, target_len-wav_len))
81 | return wav
82 |
83 | def normalize_wav(self, waveform):
84 | waveform = waveform - torch.mean(waveform)
85 | waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
86 | return waveform
87 |
88 | def __getitem__(self, index):
89 | row = self.df.iloc[index]
90 | file_path = os.path.join(self.paths[row.data], row.file_path)
91 |
92 | waveform, sr = torchaudio.load(file_path)
93 | waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
94 | wav_len = waveform.shape[1]
95 | if wav_len > self.target_length * 160:
96 | wav_len = self.target_length * 160
97 |
98 | fbank_clean = self.get_mel(waveform[0, :wav_len], self.stft)
99 |
100 | waveform = self.pad_wav(waveform, self.target_length * 160)
101 |
102 | if row.data in ['as_speech_en']:
103 |             spk_emb = torch.load(file_path.replace('/dataset/', '/spk_emb/').replace('.wav', '.pt'), map_location=torch.device('cpu'))
104 |
105 | if row.data in ['cv', 'voxceleb', 'libri'] and len(self.df_noise) > 0:
106 | if random.random() < self.reverb_prob and hasattr(self, 'rir_files'):
107 | waveform = self.reverberate(waveform)
108 |
109 | if random.random() < self.add_noise_prob:
110 | noise_row = self.df_noise.iloc[random.randint(0, len(self.df_noise)-1)]
111 | noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path))
112 | noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000)
113 | noise = self.pad_wav(noise, self.target_length * 160, random_cut=True)
114 | if torch.linalg.vector_norm(noise, ord=2, dim=-1).item() != 0.0:
115 | snr = torch.Tensor(1).uniform_(2, 10)
116 | waveform = torchaudio.functional.add_noise(waveform, noise, snr)
117 | # else:
118 | # if index not in self.noise_dict:
119 | # noise_idx = random.randint(0, len(self.df_noise)-1)
120 | # noise_row = self.df_noise.iloc[noise_idx]
121 | # noise, sr = torchaudio.load(os.path.join(self.noise_paths[noise_row.data], noise_row.file_path))
122 | # noise = torchaudio.functional.resample(noise, orig_freq=sr, new_freq=16000)
123 | # noise = self.pad_wav(noise, self.target_length * 160, random_cut=True)
124 | # if torch.linalg.vector_norm(noise, ord=2, dim=-1).item() == 0.0:
125 | # noise = torch.zeros_like(noise)
126 | # snr = torch.Tensor(1).uniform_(4, 20)
127 | # self.noise_dict[index] = (noise, snr)
128 | # else:
129 | # noise, snr = self.noise_dict[index]
130 | # waveform = torchaudio.functional.add_noise(waveform, noise, snr)
131 |
132 |             spk_emb = torch.load(file_path.replace('/LibriTTS_R_16k/', '/LibriTTS_R_spk/').replace('.wav', '.pt'), map_location=torch.device('cpu'))
133 |
134 |         if not isinstance(spk_emb, torch.Tensor):
135 | spk_emb = spk_emb[0]
136 |
137 | fbank, _, waveform = raw_waveform_to_fbank(
138 | waveform[0],
139 | target_length=self.target_length,
140 | fn_STFT=self.stft
141 | )
142 |
143 | text = row.text
144 | tokenized_text = self.get_text(text, add_blank=self.add_blank)
145 |
146 | # resample to 48k for clap
147 | wav_48k = torchaudio.functional.resample(waveform, orig_freq=16000, new_freq=48000)
148 | clap_inputs = self.clap_processor(audios=wav_48k, return_tensors="pt", sampling_rate=48000)
149 |
150 | return text, tokenized_text, fbank, spk_emb, clap_inputs, self.normalize_wav(waveform), fbank_clean
151 |
152 | def __len__(self):
153 | return len(self.df)
154 |
155 |
156 | class CollateFn(object):
157 |
158 | def __call__(self, examples):
159 | B = len(examples)
160 |
161 | fbank = torch.stack([example[2] for example in examples])
162 | spk_embs = torch.cat([example[3] for example in examples])
163 | clap_input_features = torch.cat([example[4].input_features for example in examples])
164 | clap_is_longer = torch.cat([example[4].is_longer for example in examples])
165 | audios = [example[5] for example in examples]
166 |
167 | y_max_length = max([example[6].shape[-1] for example in examples])
168 | x_max_length = max([example[1].shape[-1] for example in examples])
169 | n_feats = examples[0][6].shape[-2]
170 |
171 | y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float)
172 | x = torch.zeros((B, x_max_length), dtype=torch.long)
173 | y_lengths, x_lengths = [], []
174 |
175 | for i, example in enumerate(examples):
176 | x_, y_ = example[1], example[6]
177 | y_lengths.append(y_.shape[-1])
178 | x_lengths.append(x_.shape[-1])
179 | y[i, :, :y_.shape[-1]] = y_
180 | x[i, :x_.shape[-1]] = x_
181 |
182 | y_lengths = torch.LongTensor(y_lengths)
183 | x_lengths = torch.LongTensor(x_lengths)
184 |
185 | return {
186 | "audio": audios,
187 | "fbank": fbank,
188 | "spk": spk_embs,
189 | "text": x,
190 | "text_lengths": x_lengths,
191 | "y": y,
192 | "y_lengths": y_lengths,
193 | "clap_input_features": clap_input_features,
194 | "clap_is_longer": clap_is_longer,
195 | "texts": [example[0] for example in examples],
196 | }
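197 | 
198 | 
199 | # --- Hedged sketch (illustrative, not part of the original module) ---
200 | # The waveform augmentations in __getitem__ reduce to two torchaudio calls:
201 | # mix in a noise clip at a random SNR, then convolve with a unit-energy room
202 | # impulse response. Synthetic tensors stand in for real files here.
203 | if __name__ == "__main__":
204 |     speech = torch.rand(1, 16000) * 2 - 1       # 1 s of fake "speech"
205 |     noise = torch.rand(1, 16000) * 2 - 1        # 1 s of fake "noise"
206 |     snr = torch.empty(1).uniform_(2, 10)        # dB, same range as above
207 |     noisy = torchaudio.functional.add_noise(speech, noise, snr)
208 | 
209 |     rir = torch.rand(1, 1600)                   # fake 0.1 s impulse response
210 |     rir = rir / torch.linalg.vector_norm(rir, ord=2, dim=-1, keepdim=True)
211 |     reverberant = torchaudio.functional.fftconvolve(noisy, rir)[:, :16000]
212 |     print(noisy.shape, reverberant.shape)       # both torch.Size([1, 16000])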
--------------------------------------------------------------------------------
/voicedit/modules.py:
--------------------------------------------------------------------------------
1 | # Codes adapted from https://github.com/heatz123/naturalspeech/blob/main/models/models.py
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from voicedit.utils import sequence_mask, generate_path, duration_loss, fix_len_compatibility
9 | from voicedit import monotonic_align
10 |
11 | from voicedit.text_encoder import TextEncoder
12 | from text.symbols import symbols
13 |
14 | class DiTWrapper(nn.Module):
15 | def __init__(self, dit, cond_speaker=False, concat_y=False, uncond_prob=0.0, dit_arch=True):
16 | super().__init__()
17 | self.dit = dit
18 | self.cond_speaker = cond_speaker
19 | self.uncond_prob = uncond_prob
20 | self.n_feats = 64
21 | self.concat_y = concat_y
22 | self.dit_arch = dit_arch
23 |
24 | self.text_encoder = TextEncoder(len(symbols)+1, n_feats = 64, n_channels = 192,
25 | filter_channels = 768, filter_channels_dp = 256, n_heads = 2,
26 | n_layers = 4, kernel_size = 3, p_dropout = 0.1, window_size = 4, spk_emb_dim = self.n_feats, cond_spk=cond_speaker)
27 |
28 | if cond_speaker:
29 | self.spk_embedding = torch.nn.Sequential(torch.nn.Linear(192, 192 * 4), nn.ReLU(),
30 | torch.nn.Linear(192 * 4, self.n_feats))
31 | if concat_y:
32 | self.proj = nn.Sequential(
33 | nn.ConstantPad2d((0, 1, 0, 1), value=0),
34 | nn.Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2)),
35 | nn.ReLU(),
36 | nn.ConstantPad2d((0, 1, 0, 1), value=0),
37 | nn.Conv2d(8, 8, kernel_size=(3, 3), stride=(2, 2)),
38 | )
39 |
40 | if not dit_arch:
41 | self.class_embedding = nn.Linear(512, 1024)
42 |
43 |         self.register_buffer("y_embedding", torch.randn(8, 256, 16) if concat_y
44 |                              else torch.randn(1, self.n_feats) / self.n_feats ** 0.5)
45 |
46 |
47 | def forward(self, sample, timestep, text, text_lengths=None, cond=None, y=None, y_lengths=None, y_latent=None, speech_starts=None, spk=None, train_text_encoder=True, **kwargs):
48 |
49 | if train_text_encoder:
50 | if spk is not None:
51 | spk = self.spk_embedding(spk)
52 |
53 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
54 | mu_x, logw, x_mask = self.text_encoder(text, text_lengths, spk)
55 | y_max_length = y.shape[-1]
56 |
57 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask)
58 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
59 |
60 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
61 | with torch.no_grad():
62 | const = -0.5 * math.log(2 * math.pi) * self.n_feats
63 | factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
64 | y_square = torch.matmul(factor.transpose(1, 2), y ** 2)
65 | y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
66 | mu_square = torch.sum(factor * (mu_x ** 2), 1).unsqueeze(-1)
67 | log_prior = y_square - y_mu_double + mu_square + const
68 |
69 | attn = monotonic_align.maximum_path(log_prior, attn_mask.squeeze(1))
70 | attn = attn.detach()
71 |
72 | # Compute loss between predicted log-scaled durations and those obtained from MAS
73 | logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
74 | dur_loss = duration_loss(logw, logw_, text_lengths)
75 |
76 | # Align encoded text with mel-spectrogram and get mu_y segment
77 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
78 |
79 | # Compute loss between aligned encoder outputs and mel-spectrogram
80 | prior_loss = torch.sum(0.5 * ((y - mu_y.transpose(1,2)) ** 2 + math.log(2 * math.pi)) * y_mask)
81 | prior_loss = prior_loss / (torch.sum(y_mask) * self.n_feats)
82 |
83 | # spk_emb = spk_emb.unsqueeze(1).repeat(1, mu_y.shape[1], 1)
84 | # mu_y_vae = torch.stack([mu_y, spk_emb], dim=1)
85 |
86 | if self.concat_y:
87 | condition_y = self.proj(mu_y.unsqueeze(1))
88 |
89 | # y_latent_lengths = torch.ceil(y_lengths / 4).long()
90 | # y_latent_mask = sequence_mask(y_latent_lengths, y_latent.shape[-2]).unsqueeze(1).to(mu_y_latent.dtype)
91 | # latent_prior_loss = torch.sum(0.5 * ((y_latent - mu_y_latent) ** 2 + math.log(2 * math.pi)) * y_latent_mask.unsqueeze(-1))
92 | # prior_loss += latent_prior_loss / (torch.sum(y_latent_mask) * y_latent.shape[1] * y_latent.shape[-1])
93 |
94 | condition_y = torch.nn.functional.pad(condition_y, (0, 0, 0, sample.shape[-2] - condition_y.shape[-2]), value=0)
95 | else:
96 | condition_y = mu_y.unsqueeze(1)
97 |
98 | else:
99 |
100 | with torch.no_grad():
101 | mu_y, y_mask = self.process_content(text, text_lengths, spk)
102 |
103 | if self.concat_y:
104 | condition_y = self.proj(mu_y.unsqueeze(1))
105 |
106 | # y_latent_lengths = torch.ceil(y_lengths / 4).long()
107 | # y_latent_mask = sequence_mask(y_latent_lengths, y_latent.shape[-2]).unsqueeze(1).to(mu_y_latent.dtype)
108 | # latent_prior_loss = torch.sum(0.5 * ((y_latent - mu_y_latent) ** 2 + math.log(2 * math.pi)) * y_latent_mask.unsqueeze(-1))
109 | # prior_loss += latent_prior_loss / (torch.sum(y_latent_mask) * y_latent.shape[1] * y_latent.shape[-1])
110 |
111 | # condition_temp = torch.zeros(sample.size()).to(condition_y.device, dtype=condition_y.dtype)
112 |
113 | if speech_starts is not None and torch.any(speech_starts):
114 | condition_temp = []
115 | for start in speech_starts:
116 | condition_temp.append(torch.nn.functional.pad(condition_y[0], (0, 0, start, sample.shape[-2] - condition_y[0].shape[-2] - start), value=0))
117 | condition_y = torch.stack(condition_temp, dim=0)
118 | # condition_temp_y = torch.zeros_like(sample)
119 | # indices = speech_starts.unsqueeze(1).unsqueeze(2).unsqueeze(3) + torch.arange(condition_y.shape[-2]).unsqueeze(0).unsqueeze(1).unsqueeze(3).to(speech_starts.device)
120 | # condition_temp_y.scatter_(2, indices.expand(condition_y.size()), condition_y)
121 | # condition_y = condition_temp_y
122 | else:
123 | condition_y = torch.nn.functional.pad(condition_y, (0, 0, 0, sample.shape[-2] - condition_y.shape[-2]), value=0)
124 | else:
125 | condition_y = mu_y.unsqueeze(1)
126 |
127 | if self.concat_y and self.uncond_prob > 0.0:
128 |             drop_ids = torch.rand(condition_y.shape[0], device=condition_y.device) < self.uncond_prob
129 | condition_y = torch.where(drop_ids[:, None, None, None], self.y_embedding, condition_y)
130 |
131 | if self.dit_arch:
132 | sample = self.dit(
133 | sample,
134 | timestep,
135 | y=condition_y,
136 | mask=y_mask.long() if not self.concat_y else None,
137 | class_labels=cond,
138 | )
139 | else:
140 | sample = torch.cat([sample, condition_y], dim=1)
141 | cond = self.class_embedding(cond)
142 | sample = self.dit(
143 | sample,
144 | timestep,
145 | class_labels=cond,
146 | ).sample
147 |
148 | if sample.isnan().any():
149 | print("NAN detected from sample")
150 |
151 | if train_text_encoder:
152 | return dict(
153 | x=sample,
154 | dur_loss=dur_loss,
155 | prior_loss=prior_loss,
156 | )
157 | else:
158 | return sample
159 |
160 | def process_content(self, x, x_lengths, spk=None, length_scale=1.0):
161 |
162 | if spk is not None:
163 | spk = self.spk_embedding(spk)
164 |
165 |         if spk is not None and spk.shape[0] != x.shape[0]:
166 |             spk = spk.repeat(x.shape[0], 1)
167 |
168 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
169 | mu_x, logw, x_mask = self.text_encoder(x, x_lengths, spk)
170 |
171 | w = torch.exp(logw) * x_mask
172 | w_ceil = torch.ceil(w) * length_scale
173 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
174 | y_max_length = int(y_lengths.max())
175 | y_max_length_ = fix_len_compatibility(y_max_length)
176 |
177 | # Using obtained durations `w` construct alignment map `attn`
178 | y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
179 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
180 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
181 |
182 | # Align encoded text and get mu_y
183 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
184 |
185 | # spk_emb = spk_emb.unsqueeze(1).repeat(1, mu_y.shape[1], 1)
186 | # mu_y_vae = torch.stack([mu_y, spk_emb], dim=1)
187 |
188 | # if self.concat_y == 'vae':
189 | # assert vae is not None
190 | # mu_y_latent = vae.encode(mu_y.unsqueeze(1)).latent_dist.sample()
191 | # mu_y_latent = mu_y_latent * vae.config.scaling_factor
192 | # if self.concat_y:
193 | # condition_y = self.proj(mu_y.unsqueeze(1))
194 | # else:
195 | # condition_y = mu_y.unsqueeze(1)
196 |
197 | return mu_y, y_mask
198 |
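199 | 
200 | # --- Hedged sketch (illustrative, not part of the original module) ---
201 | # The vectorized `log_prior` built in DiTWrapper.forward is the log-likelihood
202 | # of every mel frame under a unit-variance Gaussian centred on every encoder
203 | # state, summed over the feature channels. A small numeric check:
204 | if __name__ == "__main__":
205 |     B, D, Tx, Ty = 2, 64, 5, 7
206 |     mu_x, y = torch.randn(B, D, Tx), torch.randn(B, D, Ty)
207 | 
208 |     const = -0.5 * math.log(2 * math.pi) * D
209 |     factor = -0.5 * torch.ones_like(mu_x)
210 |     log_prior = (torch.matmul(factor.transpose(1, 2), y ** 2)
211 |                  - torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
212 |                  + torch.sum(factor * mu_x ** 2, 1).unsqueeze(-1) + const)
213 | 
214 |     # explicit reference: log_prior[b, i, j] = sum_d log N(y[b, d, j]; mu_x[b, d, i], 1)
215 |     ref = -0.5 * ((y.unsqueeze(2) - mu_x.unsqueeze(3)) ** 2).sum(1) + const
216 |     print(torch.allclose(log_prior, ref, atol=1e-4))  # True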
--------------------------------------------------------------------------------
/diffusion/model/nets/DiT2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 | import math
12 | import torch
13 | import torch.nn as nn
14 | import os
15 | import numpy as np
16 | from timm.models.layers import DropPath
17 | from timm.models.vision_transformer import Mlp, Attention
18 |
19 | import sys
20 | sys.path.insert(1, os.getcwd())
21 |
22 | from diffusion.model.utils import auto_grad_checkpoint, to_2tuple, PatchEmbed
23 | from diffusion.model.nets.DiT_blocks import t2i_modulate, CaptionEmbedder, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer, modulate
24 | # from diffusion.utils.logger import get_root_logger
25 |
26 |
27 | class DiTBlock(nn.Module):
28 | """
29 | A DiT block with adaptive layer norm (adaLN-single) conditioning.
30 | """
31 |
32 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, cross_class=False, **block_kwargs):
33 | super().__init__()
34 | self.hidden_size = hidden_size
35 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
36 | self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
37 | input_size=input_size if window_size == 0 else (window_size, window_size),
38 | use_rel_pos=use_rel_pos, **block_kwargs)
39 | if cross_class:
40 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
41 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
42 |         # to be compatible with lower versions of PyTorch
43 | approx_gelu = lambda: nn.GELU(approximate="tanh")
44 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
45 | self.window_size = window_size
46 | self.adaLN_modulation = nn.Sequential(
47 | nn.SiLU(),
48 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
49 | )
50 |
51 | def forward(self, x, t, c=None, mask=None, **kwargs):
52 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(6, dim=1)
53 | x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
54 | if hasattr(self, 'cross_attn'):
55 | x = x + self.cross_attn(x, c, mask)
56 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
57 |
58 | return x
59 |
60 |
61 | #################################################################################
62 | # Core DiT Model #
63 | #################################################################################
64 | class DiT(nn.Module):
65 | """
66 | Diffusion model with a Transformer backbone.
67 | """
68 |
69 | def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, concat_y=False, cross_class=False, class_dropout_prob=0.0, pred_sigma=True, projection_class_embeddings_input_dim=512, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, caption_channels=4096, lewei_scale=1.0, config=None, **kwargs):
70 | if window_block_indexes is None:
71 | window_block_indexes = []
72 | super().__init__()
73 | self.pred_sigma = pred_sigma
74 | self.in_channels = in_channels
75 | self.out_channels = in_channels * 2 if pred_sigma else in_channels
76 | self.patch_size = patch_size
77 | self.num_heads = num_heads
78 |         self.lewei_scale = lewei_scale
79 | self.cross_class = cross_class
80 |
81 | self.concat_y = concat_y
82 | if isinstance(input_size, int):
83 | input_size = to_2tuple(input_size)
84 |
85 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels * 2 if concat_y else in_channels, hidden_size, bias=True)
86 | self.t_embedder = TimestepEmbedder(hidden_size)
87 | self.c_embedder = nn.Linear(projection_class_embeddings_input_dim, hidden_size)
88 | num_patches = self.x_embedder.num_patches
89 | # Will use fixed sin-cos embedding:
90 | self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
91 |
92 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
93 | self.blocks = nn.ModuleList([
94 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
95 | input_size=(input_size[0] // patch_size, input_size[1] // patch_size),
96 | window_size=window_size if i in window_block_indexes else 0,
97 | use_rel_pos=use_rel_pos if i in window_block_indexes else False,
98 | cross_class=cross_class)
99 | for i in range(depth)
100 | ])
101 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
102 |
103 | self.initialize_weights()
104 |
105 | def forward(self, x, timestep, class_labels, y, mask=None, **kwargs):
106 | """
107 | Forward pass of DiT.
108 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
109 |         timestep: (N,) tensor of diffusion timesteps
110 | y: (N, C, H, W) tensor of caption labels
111 | class_labels: (N, C') tensor of class labels
112 | """
113 |
114 | x = x.to(self.dtype)
115 | timestep = timestep.to(self.dtype)
116 | y = y.to(self.dtype)
117 | pos_embed = self.pos_embed.to(self.dtype)
118 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
119 | if self.concat_y:
120 | x = torch.cat([x, y], dim=1)
121 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
122 | t = self.t_embedder(timestep) # (N, D)
123 | class_embed = self.c_embedder(class_labels)
124 | t0 = t + class_embed
125 | if self.cross_class:
126 | class_lens = [1] * class_embed.shape[0]
127 | class_embed = class_embed.unsqueeze(1).view(1, -1, x.shape[-1])
128 | else:
129 | class_lens = None
130 | for block in self.blocks:
131 | x = auto_grad_checkpoint(block, x, t0, class_embed, class_lens)
132 | x = self.final_layer(x, t0) # (N, T, patch_size ** 2 * out_channels)
133 | x = self.unpatchify(x) # (N, out_channels, H, W)
134 | return x
135 |
136 | def forward_with_dpmsolver(self, x, timestep, y, class_labels=None, mask=None, **kwargs):
137 | """
138 |         dpm solver does not need variance prediction
139 |         """
140 |         # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
141 |         model_out = self.forward(x, timestep, class_labels, y, mask)
142 | return model_out.chunk(2, dim=1)[0]
143 |
144 | def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, class_labels=None, guidance="single", **kwargs):
145 | """
146 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
147 | """
148 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
149 | if guidance == "dual":
150 | num_chunk = 4
151 | elif guidance == "single":
152 | num_chunk = 2
153 | chunk = x[: len(x) // num_chunk]
154 | combined = torch.cat([chunk] * num_chunk, dim=0)
155 | model_out = self.forward(combined, timestep, class_labels, y, mask)
156 | model_out = model_out['x'] if isinstance(model_out, dict) else model_out
157 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
158 | eps = eps.chunk(num_chunk, dim=0)
159 | if guidance == "dual":
160 | pred_eps = eps[0] + cfg_scale[0] * (eps[1] - eps[3]) + cfg_scale[1] * (eps[2] - eps[3])
161 | elif guidance == "single":
162 | pred_eps = eps[1] + cfg_scale * (eps[0] - eps[1])
163 | eps = torch.cat([pred_eps] * num_chunk, dim=0)
164 | return torch.cat([eps, rest], dim=1)
165 |
166 | def unpatchify(self, x):
167 | """
168 | x: (N, T, patch_size**2 * C)
169 |         imgs: (N, C, H, W)
170 | """
171 | c = self.out_channels
172 | p1, p2 = self.x_embedder.patch_size
173 | h, w = self.x_embedder.grid_size
174 | assert h * w == x.shape[1]
175 |
176 | x = x.reshape(shape=(x.shape[0], h, w, p1, p2, c))
177 | x = torch.einsum('nhwpqc->nchpwq', x)
178 | return x.reshape(shape=(x.shape[0], c, h * p1, w * p2))
179 |
180 | def initialize_weights(self):
181 | # Initialize transformer layers:
182 | def _basic_init(module):
183 | if isinstance(module, nn.Linear):
184 | torch.nn.init.xavier_uniform_(module.weight)
185 | if module.bias is not None:
186 | nn.init.constant_(module.bias, 0)
187 |
188 | self.apply(_basic_init)
189 |
190 | # Initialize (and freeze) pos_embed by sin-cos embedding:
191 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size)
192 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
193 |
194 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
195 | w = self.x_embedder.proj.weight.data
196 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
197 | nn.init.constant_(self.x_embedder.proj.bias, 0)
198 |
199 | # Initialize timestep embedding MLP:
200 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
201 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
202 |
203 | # Zero-out adaLN modulation layers in DiT blocks:
204 | for block in self.blocks:
205 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
206 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
207 | if hasattr(block, 'cross_attn'):
208 | nn.init.constant_(block.cross_attn.proj.weight, 0)
209 | nn.init.constant_(block.cross_attn.proj.bias, 0)
210 |
211 | # Zero-out output layers:
212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
214 | nn.init.constant_(self.final_layer.linear.weight, 0)
215 | nn.init.constant_(self.final_layer.linear.bias, 0)
216 |
217 | @property
218 | def dtype(self):
219 | return next(self.parameters()).dtype
220 |
221 |
222 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
223 | """
224 | grid_size: int of the grid height and width
225 | return:
226 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
227 | """
228 | if isinstance(grid_size, int):
229 | grid_size = to_2tuple(grid_size)
230 | grid_h = np.arange(grid_size[0], dtype=np.float32)
231 | grid_w = np.arange(grid_size[1], dtype=np.float32)
232 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
233 | grid = np.stack(grid, axis=0)
234 |
235 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
236 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
237 | if cls_token and extra_tokens > 0:
238 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
239 | return pos_embed
240 |
241 |
242 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
243 | assert embed_dim % 2 == 0
244 |
245 | # use half of dimensions to encode grid_h
246 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
247 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
248 |
249 | return np.concatenate([emb_h, emb_w], axis=1)
250 |
251 |
252 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
253 | """
254 | embed_dim: output dimension for each position
255 | pos: a list of positions to be encoded: size (M,)
256 | out: (M, D)
257 | """
258 | assert embed_dim % 2 == 0
259 | omega = np.arange(embed_dim // 2, dtype=np.float64)
260 | omega /= embed_dim / 2.
261 | omega = 1. / 10000 ** omega # (D/2,)
262 |
263 | pos = pos.reshape(-1) # (M,)
264 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
265 |
266 | emb_sin = np.sin(out) # (M, D/2)
267 | emb_cos = np.cos(out) # (M, D/2)
268 |
269 | return np.concatenate([emb_sin, emb_cos], axis=1)
270 |
271 |
272 | #################################################################################
273 | # DiT Configs #
274 | #################################################################################
275 | def DiT_XL_2(**kwargs):
276 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
277 |
278 | def DiT_L_2(**kwargs):
279 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
280 |
281 | def DiT_B_2(**kwargs):
282 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
283 |
284 | def DiT_S_2(**kwargs):
285 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
286 |
287 |
288 | DiT_models = {
289 | 'DiT_XL/2': DiT_XL_2,
290 | 'DiT_L/2': DiT_L_2,
291 | 'DiT_B/2': DiT_B_2,
292 | 'DiT_S/2': DiT_S_2,
293 | }
294 |
295 | if __name__ == '__main__':
296 |
297 | from thop import profile, clever_format
298 |
299 | device = "cuda:0"
300 | model = DiT_models['DiT_L/2'](
301 | input_size = (256, 16),
302 | in_channels = 1,
303 | projection_class_embeddings_input_dim=512,
304 | concat_y = True,
305 | cross_class = True,
306 | ).to(device)
307 | model.eval()
308 | model.requires_grad_(False)
309 |
310 | x = torch.randn(1, 1, 256, 16).to(device)
311 | timestep = torch.Tensor([999]).to(device)
312 | y = torch.randn(1, 1, 256, 16).to(device)
313 | mask = torch.ones(1, 1, 1, 1).to(device)
314 | class_labels = torch.randn(1, 512).to(device)
315 |
316 | macs, params = profile(model, inputs=(x, timestep, class_labels, y))
317 | flops = macs * 2
318 | macs, params, flops = clever_format([macs, params, flops], "%.4f")
319 | print("Params:", params)
320 | print("MACs:", macs)
321 | print("FLOPs:", flops)
322 | import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
/voicedit/text_encoder.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .utils import sequence_mask, convert_pad_shape
9 |
10 |
11 | class LayerNorm(nn.Module):
12 | def __init__(self, channels, eps=1e-4):
13 | super(LayerNorm, self).__init__()
14 | self.channels = channels
15 | self.eps = eps
16 |
17 | self.gamma = torch.nn.Parameter(torch.ones(channels))
18 | self.beta = torch.nn.Parameter(torch.zeros(channels))
19 |
20 | def forward(self, x):
21 | n_dims = len(x.shape)
22 | mean = torch.mean(x, 1, keepdim=True)
23 | variance = torch.mean((x - mean)**2, 1, keepdim=True)
24 |
25 | x = (x - mean) * torch.rsqrt(variance + self.eps)
26 |
27 | shape = [1, -1] + [1] * (n_dims - 2)
28 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
29 | return x
30 |
31 |
32 | class ConvReluNorm(nn.Module):
33 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
34 | n_layers, p_dropout):
35 | super(ConvReluNorm, self).__init__()
36 | self.in_channels = in_channels
37 | self.hidden_channels = hidden_channels
38 | self.out_channels = out_channels
39 | self.kernel_size = kernel_size
40 | self.n_layers = n_layers
41 | self.p_dropout = p_dropout
42 |
43 | self.conv_layers = torch.nn.ModuleList()
44 | self.norm_layers = torch.nn.ModuleList()
45 | self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
46 | kernel_size, padding=kernel_size//2))
47 | self.norm_layers.append(LayerNorm(hidden_channels))
48 | self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
49 | for _ in range(n_layers - 1):
50 | self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
51 | kernel_size, padding=kernel_size//2))
52 | self.norm_layers.append(LayerNorm(hidden_channels))
53 | self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
54 | self.proj.weight.data.zero_()
55 | self.proj.bias.data.zero_()
56 |
57 | def forward(self, x, x_mask):
58 | x_org = x
59 | for i in range(self.n_layers):
60 | x = self.conv_layers[i](x * x_mask)
61 | x = self.norm_layers[i](x)
62 | x = self.relu_drop(x)
63 | x = x_org + self.proj(x)
64 | return x * x_mask
65 |
66 |
67 | class DurationPredictor(nn.Module):
68 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
69 | super(DurationPredictor, self).__init__()
70 | self.in_channels = in_channels
71 | self.filter_channels = filter_channels
72 | self.p_dropout = p_dropout
73 |
74 | self.drop = torch.nn.Dropout(p_dropout)
75 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
76 | kernel_size, padding=kernel_size//2)
77 | self.norm_1 = LayerNorm(filter_channels)
78 | self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
79 | kernel_size, padding=kernel_size//2)
80 | self.norm_2 = LayerNorm(filter_channels)
81 | self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
82 |
83 | def forward(self, x, x_mask):
84 | x = self.conv_1(x * x_mask)
85 | x = torch.relu(x)
86 | x = self.norm_1(x)
87 | x = self.drop(x)
88 | x = self.conv_2(x * x_mask)
89 | x = torch.relu(x)
90 | x = self.norm_2(x)
91 | x = self.drop(x)
92 | x = self.proj(x * x_mask)
93 | return x * x_mask
94 |
95 |
96 | class MultiHeadAttention(nn.Module):
97 | def __init__(self, channels, out_channels, n_heads, window_size=None,
98 | heads_share=True, p_dropout=0.0, proximal_bias=False,
99 | proximal_init=False):
100 | super(MultiHeadAttention, self).__init__()
101 | assert channels % n_heads == 0
102 |
103 | self.channels = channels
104 | self.out_channels = out_channels
105 | self.n_heads = n_heads
106 | self.window_size = window_size
107 | self.heads_share = heads_share
108 | self.proximal_bias = proximal_bias
109 | self.p_dropout = p_dropout
110 | self.attn = None
111 |
112 | self.k_channels = channels // n_heads
113 | self.conv_q = torch.nn.Conv1d(channels, channels, 1)
114 | self.conv_k = torch.nn.Conv1d(channels, channels, 1)
115 | self.conv_v = torch.nn.Conv1d(channels, channels, 1)
116 | if window_size is not None:
117 | n_heads_rel = 1 if heads_share else n_heads
118 | rel_stddev = self.k_channels**-0.5
119 | self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
120 | window_size * 2 + 1, self.k_channels) * rel_stddev)
121 | self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
122 | window_size * 2 + 1, self.k_channels) * rel_stddev)
123 | self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
124 | self.drop = torch.nn.Dropout(p_dropout)
125 |
126 | torch.nn.init.xavier_uniform_(self.conv_q.weight)
127 | torch.nn.init.xavier_uniform_(self.conv_k.weight)
128 | if proximal_init:
129 | self.conv_k.weight.data.copy_(self.conv_q.weight.data)
130 | self.conv_k.bias.data.copy_(self.conv_q.bias.data)
131 | torch.nn.init.xavier_uniform_(self.conv_v.weight)
132 |
133 | def forward(self, x, c, attn_mask=None):
134 | q = self.conv_q(x)
135 | k = self.conv_k(c)
136 | v = self.conv_v(c)
137 |
138 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
139 |
140 | x = self.conv_o(x)
141 | return x
142 |
143 | def attention(self, query, key, value, mask=None):
144 | b, d, t_s, t_t = (*key.size(), query.size(2))
145 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
146 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
147 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
148 |
149 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
150 | if self.window_size is not None:
151 | assert t_s == t_t, "Relative attention is only available for self-attention."
152 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
153 | rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
154 | rel_logits = self._relative_position_to_absolute_position(rel_logits)
155 | scores_local = rel_logits / math.sqrt(self.k_channels)
156 | scores = scores + scores_local
157 | if self.proximal_bias:
158 | assert t_s == t_t, "Proximal bias is only available for self-attention."
159 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
160 | dtype=scores.dtype)
161 | if mask is not None:
162 | scores = scores.masked_fill(mask == 0, -1e4)
163 | p_attn = torch.nn.functional.softmax(scores, dim=-1)
164 | p_attn = self.drop(p_attn)
165 | output = torch.matmul(p_attn, value)
166 | if self.window_size is not None:
167 | relative_weights = self._absolute_position_to_relative_position(p_attn)
168 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
169 | output = output + self._matmul_with_relative_values(relative_weights,
170 | value_relative_embeddings)
171 | output = output.transpose(2, 3).contiguous().view(b, d, t_t)
172 | return output, p_attn
173 |
174 | def _matmul_with_relative_values(self, x, y):
175 | ret = torch.matmul(x, y.unsqueeze(0))
176 | return ret
177 |
178 | def _matmul_with_relative_keys(self, x, y):
179 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
180 | return ret
181 |
182 | def _get_relative_embeddings(self, relative_embeddings, length):
183 | pad_length = max(length - (self.window_size + 1), 0)
184 | slice_start_position = max((self.window_size + 1) - length, 0)
185 | slice_end_position = slice_start_position + 2 * length - 1
186 | if pad_length > 0:
187 | padded_relative_embeddings = torch.nn.functional.pad(
188 | relative_embeddings, convert_pad_shape([[0, 0],
189 | [pad_length, pad_length], [0, 0]]))
190 | else:
191 | padded_relative_embeddings = relative_embeddings
192 | used_relative_embeddings = padded_relative_embeddings[:,
193 | slice_start_position:slice_end_position]
194 | return used_relative_embeddings
195 |
196 | def _relative_position_to_absolute_position(self, x):
197 | batch, heads, length, _ = x.size()
198 | x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
199 | x_flat = x.view([batch, heads, length * 2 * length])
200 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
201 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
202 | return x_final
203 |
204 | def _absolute_position_to_relative_position(self, x):
205 | batch, heads, length, _ = x.size()
206 | x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
207 | x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
208 | x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
209 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
210 | return x_final
211 |
212 | def _attention_bias_proximal(self, length):
213 | r = torch.arange(length, dtype=torch.float32)
214 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
215 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
216 |
217 |
218 | class FFN(nn.Module):
219 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
220 | p_dropout=0.0):
221 | super(FFN, self).__init__()
222 | self.in_channels = in_channels
223 | self.out_channels = out_channels
224 | self.filter_channels = filter_channels
225 | self.kernel_size = kernel_size
226 | self.p_dropout = p_dropout
227 |
228 | self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
229 | padding=kernel_size//2)
230 | self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
231 | padding=kernel_size//2)
232 | self.drop = torch.nn.Dropout(p_dropout)
233 |
234 | def forward(self, x, x_mask):
235 | x = self.conv_1(x * x_mask)
236 | x = torch.relu(x)
237 | x = self.drop(x)
238 | x = self.conv_2(x * x_mask)
239 | return x * x_mask
240 |
241 |
242 | class Encoder(nn.Module):
243 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
244 | kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
245 | super(Encoder, self).__init__()
246 | self.hidden_channels = hidden_channels
247 | self.filter_channels = filter_channels
248 | self.n_heads = n_heads
249 | self.n_layers = n_layers
250 | self.kernel_size = kernel_size
251 | self.p_dropout = p_dropout
252 | self.window_size = window_size
253 |
254 | self.drop = torch.nn.Dropout(p_dropout)
255 | self.attn_layers = torch.nn.ModuleList()
256 | self.norm_layers_1 = torch.nn.ModuleList()
257 | self.ffn_layers = torch.nn.ModuleList()
258 | self.norm_layers_2 = torch.nn.ModuleList()
259 | for _ in range(self.n_layers):
260 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
261 | n_heads, window_size=window_size, p_dropout=p_dropout))
262 | self.norm_layers_1.append(LayerNorm(hidden_channels))
263 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
264 | filter_channels, kernel_size, p_dropout=p_dropout))
265 | self.norm_layers_2.append(LayerNorm(hidden_channels))
266 |
267 | def forward(self, x, x_mask):
268 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
269 | for i in range(self.n_layers):
270 | x = x * x_mask
271 | y = self.attn_layers[i](x, x, attn_mask)
272 | y = self.drop(y)
273 | x = self.norm_layers_1[i](x + y)
274 | y = self.ffn_layers[i](x, x_mask)
275 | y = self.drop(y)
276 | x = self.norm_layers_2[i](x + y)
277 | x = x * x_mask
278 | return x
279 |
280 |
281 | class TextEncoder(nn.Module):
282 | def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
283 | filter_channels_dp, n_heads, n_layers, kernel_size,
284 | p_dropout, window_size=None, spk_emb_dim=64, cond_spk=False):
285 | super(TextEncoder, self).__init__()
286 | self.n_vocab = n_vocab
287 | self.n_feats = n_feats
288 | self.n_channels = n_channels
289 | self.filter_channels = filter_channels
290 | self.filter_channels_dp = filter_channels_dp
291 | self.n_heads = n_heads
292 | self.n_layers = n_layers
293 | self.kernel_size = kernel_size
294 | self.p_dropout = p_dropout
295 | self.window_size = window_size
296 | self.spk_emb_dim = spk_emb_dim
297 | self.cond_spk = cond_spk
298 |
299 | self.emb = torch.nn.Embedding(n_vocab, n_channels)
300 | torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
301 |
302 | self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
303 | kernel_size=5, n_layers=3, p_dropout=0.5)
304 |
305 | self.encoder = Encoder(n_channels + (spk_emb_dim if cond_spk else 0), filter_channels, n_heads, n_layers,
306 | kernel_size, p_dropout, window_size=window_size)
307 |
308 | self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if cond_spk else 0), n_feats, 1)
309 | self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if cond_spk else 0), filter_channels_dp,
310 | kernel_size, p_dropout)
311 |
312 | def forward(self, x, x_lengths, spk=None):
313 | x = self.emb(x) * math.sqrt(self.n_channels)
314 | x = torch.transpose(x, 1, -1)
315 | x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
316 |
317 | x = self.prenet(x, x_mask)
318 | if self.cond_spk:
319 | x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
320 | x = self.encoder(x, x_mask)
321 | mu = self.proj_m(x) * x_mask
322 |
323 | x_dp = torch.detach(x)
324 | logw = self.proj_w(x_dp, x_mask)
325 |
326 | return mu, logw, x_mask
327 |
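328 | 
329 | # --- Hedged sketch (illustrative, not part of the original module) ---
330 | # Shape check for TextEncoder with the hyper-parameters DiTWrapper uses in
331 | # voicedit/modules.py. The vocabulary size 150 is a stand-in for len(symbols) + 1.
332 | if __name__ == "__main__":
333 |     enc = TextEncoder(n_vocab=150, n_feats=64, n_channels=192, filter_channels=768,
334 |                       filter_channels_dp=256, n_heads=2, n_layers=4, kernel_size=3,
335 |                       p_dropout=0.1, window_size=4)
336 |     tokens = torch.randint(0, 150, (2, 37))
337 |     lengths = torch.tensor([37, 25])
338 |     mu, logw, x_mask = enc(tokens, lengths)
339 |     print(mu.shape, logw.shape, x_mask.shape)  # (2, 64, 37), (2, 1, 37), (2, 1, 37)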
--------------------------------------------------------------------------------
/diffusion/model/nets/DiT.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 | import math
12 | import torch
13 | import torch.nn as nn
14 | import os
15 | import numpy as np
16 | from timm.models.layers import DropPath
17 | from timm.models.vision_transformer import Mlp
18 |
19 | import sys
20 | sys.path.insert(1, os.getcwd())
21 |
22 | from diffusion.model.utils import auto_grad_checkpoint, to_2tuple, PatchEmbed
23 | from diffusion.model.nets.DiT_blocks import t2i_modulate, CaptionEmbedder, WindowAttention, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer, modulate
24 | # from diffusion.utils.logger import get_root_logger
25 |
26 |
27 | class DiTBlock(nn.Module):
28 | """
29 | A DiT block with adaptive layer norm (adaLN-single) conditioning.
30 | """
31 |
32 | def __init__(self, hidden_size, num_heads, cross_attn_dim=None, mlp_ratio=4.0, drop_path=0., window_size=0, input_size=None, use_rel_pos=False, **block_kwargs):
33 | super().__init__()
34 | self.hidden_size = hidden_size
35 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
36 | self.attn = WindowAttention(hidden_size, num_heads=num_heads, qkv_bias=True,
37 | input_size=input_size if window_size == 0 else (window_size, window_size),
38 | use_rel_pos=use_rel_pos, **block_kwargs)
39 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, cross_attn_dim, **block_kwargs)
40 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
41 |         # to be compatible with lower versions of PyTorch
42 | approx_gelu = lambda: nn.GELU(approximate="tanh")
43 | self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
44 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
45 | self.window_size = window_size
46 | self.adaLN_modulation = nn.Sequential(
47 | nn.SiLU(),
48 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
49 | )
50 |
51 | def forward(self, x, y, t, mask=None, **kwargs):
52 | B, N, C = x.shape
53 |
54 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(6, dim=1)
55 | x = x + self.drop_path(gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
56 | x = x + self.cross_attn(x, y, mask)
57 | x = x + self.drop_path(gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)))
58 |
59 | return x
60 |
61 |
62 | #################################################################################
63 | # Core DiT Model #
64 | #################################################################################
65 | class DiT(nn.Module):
66 | """
67 | Diffusion model with a Transformer backbone.
68 | """
69 |
70 | def __init__(self, input_size=32, patch_size=2, in_channels=4, hidden_size=1152, depth=28, num_heads=16, mlp_ratio=4.0, class_dropout_prob=0.0, pred_sigma=True, projection_class_embeddings_input_dim=512, drop_path: float = 0., window_size=0, window_block_indexes=None, use_rel_pos=False, caption_channels=4096, lewei_scale=1.0, config=None, **kwargs):
71 | if window_block_indexes is None:
72 | window_block_indexes = []
73 | super().__init__()
74 | self.pred_sigma = pred_sigma
75 | self.in_channels = in_channels
76 | self.out_channels = in_channels * 2 if pred_sigma else in_channels
77 | self.patch_size = patch_size
78 | self.num_heads = num_heads
79 |         self.lewei_scale = lewei_scale
80 |
81 | if isinstance(input_size, int):
82 | input_size = to_2tuple(input_size)
83 |
84 | self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
85 | self.t_embedder = TimestepEmbedder(hidden_size)
86 | self.c_embedder = nn.Linear(projection_class_embeddings_input_dim, hidden_size)
87 | num_patches = self.x_embedder.num_patches
88 | # Will use fixed sin-cos embedding:
89 | self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
90 |
91 | approx_gelu = lambda: nn.GELU(approximate="tanh")
92 | self.y_embedder = CaptionEmbedder(in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, act_layer=approx_gelu)
93 | drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
94 | self.blocks = nn.ModuleList([
95 | DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
96 | input_size=(input_size[0] // patch_size, input_size[1] // patch_size),
97 | window_size=window_size if i in window_block_indexes else 0,
98 | use_rel_pos=use_rel_pos if i in window_block_indexes else False)
99 | for i in range(depth)
100 | ])
101 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
102 |
103 | self.initialize_weights()
104 |
105 | def forward(self, x, timestep, y, mask=None, class_labels=None, data_info=None, **kwargs):
106 | """
107 | Forward pass of DiT.
108 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
109 |         timestep: (N,) tensor of diffusion timesteps
110 | y: (N, 1, T, C) tensor of caption labels
111 | class_labels: (N, C') tensor of class labels
112 | """
113 | x = x.to(self.dtype)
114 | timestep = timestep.to(self.dtype)
115 | y = y.to(self.dtype)
116 | pos_embed = self.pos_embed.to(self.dtype)
117 | self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
118 | x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
119 | t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
120 | class_embed = self.c_embedder(class_labels)
121 | t0 = t + class_embed
122 | y = self.y_embedder(y) # (N, 1, L, D)
123 | if mask is not None:
124 | if mask.shape[0] != y.shape[0]:
125 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
126 | mask = mask.squeeze(1).squeeze(1)
127 | y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
128 | y_lens = mask.sum(dim=1).tolist()
129 | else:
130 | y_lens = [y.shape[2]] * y.shape[0]
131 | y = y.squeeze(1).view(1, -1, x.shape[-1])
132 | for block in self.blocks:
133 | x = auto_grad_checkpoint(block, x, y, t0, y_lens) # (N, T, D) #support grad checkpoint
134 | x = self.final_layer(x, t0) # (N, T, patch_size ** 2 * out_channels)
135 | x = self.unpatchify(x) # (N, out_channels, H, W)
136 | return x
137 |
138 | def forward_with_dpmsolver(self, x, timestep, y, mask=None, class_labels=None, **kwargs):
139 | """
140 |         dpm solver does not need variance prediction
141 | """
142 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
143 | model_out = self.forward(x, timestep, y, mask, class_labels)
144 | return model_out.chunk(2, dim=1)[0]
145 |
146 | def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, class_labels=None, guidance="single", **kwargs):
147 | """
148 | Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
149 | """
150 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
151 | if guidance == "dual":
152 | num_chunk = 4
153 | elif guidance == "single":
154 | num_chunk = 2
155 | chunk = x[: len(x) // num_chunk]
156 | combined = torch.cat([chunk] * num_chunk, dim=0)
157 |         model_out = self.forward(combined, timestep, y, mask, class_labels, **kwargs)
158 | model_out = model_out['x'] if isinstance(model_out, dict) else model_out
159 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
160 | eps = eps.chunk(num_chunk, dim=0)
161 | if guidance == "dual":
162 | pred_eps = eps[0] + cfg_scale[0] * (eps[1] - eps[3]) + cfg_scale[1] * (eps[2] - eps[3])
163 | elif guidance == "single":
164 | pred_eps = eps[1] + cfg_scale * (eps[0] - eps[1])
165 | eps = torch.cat([pred_eps] * num_chunk, dim=0)
166 | return torch.cat([eps, rest], dim=1)
167 |
168 | def unpatchify(self, x):
169 | """
170 | x: (N, T, patch_size**2 * C)
171 |         imgs: (N, C, H, W)
172 | """
173 | c = self.out_channels
174 | p1, p2 = self.x_embedder.patch_size
175 | h, w = self.x_embedder.grid_size
176 | assert h * w == x.shape[1]
177 |
178 | x = x.reshape(shape=(x.shape[0], h, w, p1, p2, c))
179 | x = torch.einsum('nhwpqc->nchpwq', x)
180 | return x.reshape(shape=(x.shape[0], c, h * p1, w * p2))
181 |
182 | def initialize_weights(self):
183 | # Initialize transformer layers:
184 | def _basic_init(module):
185 | if isinstance(module, nn.Linear):
186 | torch.nn.init.xavier_uniform_(module.weight)
187 | if module.bias is not None:
188 | nn.init.constant_(module.bias, 0)
189 |
190 | self.apply(_basic_init)
191 |
192 | # Initialize (and freeze) pos_embed by sin-cos embedding:
193 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.x_embedder.grid_size, lewei_scale=self.lewei_scale)
194 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
195 |
196 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
197 | w = self.x_embedder.proj.weight.data
198 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
199 | nn.init.constant_(self.x_embedder.proj.bias, 0)
200 |
201 | # Initialize timestep embedding MLP:
202 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
203 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
204 |
205 | # Initialize caption embedding MLP:
206 | nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
207 | nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
208 |
209 | # Zero-out cross-attention output projections and adaLN modulation layers in DiT blocks:
210 | for block in self.blocks:
211 | nn.init.constant_(block.cross_attn.proj.weight, 0)
212 | nn.init.constant_(block.cross_attn.proj.bias, 0)
213 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
214 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
215 |
216 | # Zero-out output layers:
217 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
218 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
219 | nn.init.constant_(self.final_layer.linear.weight, 0)
220 | nn.init.constant_(self.final_layer.linear.bias, 0)
221 |
222 | @property
223 | def dtype(self):
224 | return next(self.parameters()).dtype
225 |
226 |
227 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, lewei_scale=1.0):
228 | """
229 | grid_size: int or (height, width) tuple of the grid
230 | return:
231 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
232 | """
233 | if isinstance(grid_size, int):
234 | grid_size = to_2tuple(grid_size)
235 | grid_h = np.arange(grid_size[0], dtype=np.float32) / lewei_scale
236 | grid_w = np.arange(grid_size[1], dtype=np.float32) / lewei_scale
237 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
238 | grid = np.stack(grid, axis=0)
239 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
240 |
241 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
242 | if cls_token and extra_tokens > 0:
243 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
244 | return pos_embed
245 |
246 |
247 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
248 | assert embed_dim % 2 == 0
249 |
250 | # use half of dimensions to encode grid_h
251 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
252 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
253 |
254 | return np.concatenate([emb_h, emb_w], axis=1)
255 |
256 |
257 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
258 | """
259 | embed_dim: output dimension for each position
260 | pos: a list of positions to be encoded: size (M,)
261 | out: (M, D)
262 | """
263 | assert embed_dim % 2 == 0
264 | omega = np.arange(embed_dim // 2, dtype=np.float64)
265 | omega /= embed_dim / 2.
266 | omega = 1. / 10000 ** omega # (D/2,)
267 |
268 | pos = pos.reshape(-1) # (M,)
269 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
270 |
271 | emb_sin = np.sin(out) # (M, D/2)
272 | emb_cos = np.cos(out) # (M, D/2)
273 |
274 | return np.concatenate([emb_sin, emb_cos], axis=1)
275 |
276 |
277 | #################################################################################
278 | # DiT Configs #
279 | #################################################################################
280 | def DiT_XL_2(**kwargs):
281 | return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
282 |
283 | def DiT_L_2(**kwargs):
284 | return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
285 |
286 | def DiT_B_2(**kwargs):
287 | return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
288 |
289 | def DiT_S_2(**kwargs):
290 | return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
291 |
292 |
293 | DiT_cross_models = {
294 | 'DiT_XL/2': DiT_XL_2,
295 | 'DiT_L/2': DiT_L_2,
296 | 'DiT_B/2': DiT_B_2,
297 | 'DiT_S/2': DiT_S_2,
298 | }
299 |
300 | if __name__ == '__main__':
301 |
302 | from thop import profile, clever_format
303 |
304 | device = "cuda:0"
305 |
306 | model = DiT_cross_models['DiT_L/2'](
307 | input_size = (256, 16),
308 | in_channels = 8,
309 | projection_class_embeddings_input_dim=512,
310 | caption_channels = 64,
311 | concat_y = False,
312 | cross_class = False,
313 | ).to(device)
314 |
315 | model.eval()
316 | model.requires_grad_(False)
317 | x = torch.randn(1, 8, 256, 16).to(device)
318 | timestep = torch.Tensor([999]).to(device)
319 | y = torch.randn(1, 1, 1024, 64).to(device)
320 | mask = torch.ones(1, 1, 1, 1024).to(device)
321 | class_labels = torch.randn(1, 512).to(device)
322 |
323 | macs, params = profile(model, inputs=(x, timestep, y, mask, class_labels))
324 | flops = macs * 2
325 | macs, params, flops = clever_format([macs, params, flops], "%.4f")
326 |
327 | print("Params:", params)
328 | print("MACs:", macs)
329 | print("FLOPs:", flops)
--------------------------------------------------------------------------------
/diffusion/model/nets/DiT_blocks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 | import math
12 | import torch
13 | import torch.nn as nn
14 | from timm.models.vision_transformer import Mlp, Attention as Attention_
15 | from einops import rearrange, repeat
16 | import xformers.ops
17 |
18 | from diffusion.model.utils import add_decomposed_rel_pos
19 |
20 |
21 | def modulate(x, shift, scale):
22 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
23 |
24 |
25 | def t2i_modulate(x, shift, scale):
26 | return x * (1 + scale) + shift
27 |
28 |
29 | class MultiHeadCrossAttention(nn.Module):
30 | def __init__(self, d_model, num_heads, cross_attn_dim=None, attn_drop=0., proj_drop=0., **block_kwargs):
31 | super(MultiHeadCrossAttention, self).__init__()
32 | assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
33 |
34 | self.d_model = d_model
35 | self.num_heads = num_heads
36 | self.head_dim = d_model // num_heads
37 | if cross_attn_dim is None:
38 | cross_attn_dim = d_model
39 | self.cross_attn_dim = cross_attn_dim
40 |
41 | self.q_linear = nn.Linear(d_model, d_model)
42 | self.kv_linear = nn.Linear(cross_attn_dim, d_model*2)
43 | self.attn_drop = nn.Dropout(attn_drop)
44 | self.proj = nn.Linear(d_model, d_model)
45 | self.proj_drop = nn.Dropout(proj_drop)
46 |
47 | def forward(self, x, cond, mask=None):
48 | # query: img tokens; key/value: condition; mask: if padding tokens
49 | B, N, C = x.shape
50 |
51 | q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
52 | kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
53 | k, v = kv.unbind(2)
54 | attn_bias = None
55 | if mask is not None:
56 | attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
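# q and kv above are flattened into single sequences of length B*N and sum(mask) respectively;
# the block-diagonal bias restricts each sample's image tokens to its own condition tokens
# (mask is the per-sample list of condition lengths, passed from DiT.forward as y_lens).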
57 | x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
58 | x = x.view(B, -1, C)
59 | x = self.proj(x)
60 | x = self.proj_drop(x)
61 |
62 | # q = self.q_linear(x).reshape(B, -1, self.num_heads, self.head_dim)
63 | # kv = self.kv_linear(cond).reshape(B, -1, 2, self.num_heads, self.head_dim)
64 | # k, v = kv.unbind(2)
65 | # attn_bias = None
66 | # if mask is not None:
67 | # attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
68 | # attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
69 | # x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
70 | # x = x.contiguous().reshape(B, -1, C)
71 | # x = self.proj(x)
72 | # x = self.proj_drop(x)
73 |
74 | return x
75 |
76 |
77 | class WindowAttention(Attention_):
78 | """Multi-head Attention block with relative position embeddings."""
79 |
80 | def __init__(
81 | self,
82 | dim,
83 | num_heads=8,
84 | qkv_bias=True,
85 | use_rel_pos=False,
86 | rel_pos_zero_init=True,
87 | input_size=None,
88 | **block_kwargs,
89 | ):
90 | """
91 | Args:
92 | dim (int): Number of input channels.
93 | num_heads (int): Number of attention heads.
94 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
95 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
96 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
97 | input_size (int or None): Input resolution for calculating the relative positional
98 | parameter size.
99 | """
100 | super().__init__(dim, num_heads=num_heads, qkv_bias=qkv_bias, **block_kwargs)
101 |
102 | self.use_rel_pos = use_rel_pos
103 | if self.use_rel_pos:
104 | # initialize relative positional embeddings
105 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
106 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
107 |
108 | if not rel_pos_zero_init:
109 | nn.init.trunc_normal_(self.rel_pos_h, std=0.02)
110 | nn.init.trunc_normal_(self.rel_pos_w, std=0.02)
111 |
112 | def forward(self, x, mask=None):
113 | B, N, C = x.shape
114 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
115 | q, k, v = qkv.unbind(2)
116 | if use_fp32_attention := getattr(self, 'fp32_attention', False):
117 | q, k, v = q.float(), k.float(), v.float()
118 |
119 | attn_bias = None
120 | if mask is not None:
121 | attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
122 | attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
123 | x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
124 |
125 | x = x.view(B, N, C)
126 | x = self.proj(x)
127 | x = self.proj_drop(x)
128 | return x
129 |
130 |
131 | #################################################################################
132 | # AMP attention with fp32 softmax to fix loss NaN problem during training #
133 | #################################################################################
134 | class Attention(Attention_):
135 | def forward(self, x):
136 | B, N, C = x.shape
137 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
138 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
139 | use_fp32_attention = getattr(self, 'fp32_attention', False)
140 | if use_fp32_attention:
141 | q, k = q.float(), k.float()
142 | with torch.cuda.amp.autocast(enabled=not use_fp32_attention):
143 | attn = (q @ k.transpose(-2, -1)) * self.scale
144 | attn = attn.softmax(dim=-1)
145 |
146 | attn = self.attn_drop(attn)
147 |
148 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
149 | x = self.proj(x)
150 | x = self.proj_drop(x)
151 | return x
152 |
153 |
154 | class FinalLayer(nn.Module):
155 | """
156 | The final layer of DiT.
157 | """
158 |
159 | def __init__(self, hidden_size, patch_size, out_channels):
160 | super().__init__()
161 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
162 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
163 | self.adaLN_modulation = nn.Sequential(
164 | nn.SiLU(),
165 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
166 | )
167 |
168 | def forward(self, x, c):
169 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
170 | x = modulate(self.norm_final(x), shift, scale)
171 | x = self.linear(x)
172 | return x
173 |
174 |
175 | class T2IFinalLayer(nn.Module):
176 | """
177 | The final layer of DiT.
178 | """
179 |
180 | def __init__(self, hidden_size, patch_size, out_channels):
181 | super().__init__()
182 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
183 | self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
184 | self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
185 | self.out_channels = out_channels
186 |
187 | def forward(self, x, t):
188 | shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
189 | x = t2i_modulate(self.norm_final(x), shift, scale)
190 | x = self.linear(x)
191 | return x
192 |
193 |
194 | class MaskFinalLayer(nn.Module):
195 | """
196 | The final layer of DiT.
197 | """
198 |
199 | def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
200 | super().__init__()
201 | self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
202 | self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
203 | self.adaLN_modulation = nn.Sequential(
204 | nn.SiLU(),
205 | nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
206 | )
207 | def forward(self, x, t):
208 | shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
209 | x = modulate(self.norm_final(x), shift, scale)
210 | x = self.linear(x)
211 | return x
212 |
213 |
214 | class DecoderLayer(nn.Module):
215 | """
216 | Decoder projection layer of DiT with adaLN modulation.
217 | """
218 |
219 | def __init__(self, hidden_size, decoder_hidden_size):
220 | super().__init__()
221 | self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
222 | self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
223 | self.adaLN_modulation = nn.Sequential(
224 | nn.SiLU(),
225 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
226 | )
227 | def forward(self, x, t):
228 | shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
229 | x = modulate(self.norm_decoder(x), shift, scale)
230 | x = self.linear(x)
231 | return x
232 |
233 |
234 | #################################################################################
235 | # Embedding Layers for Timesteps and Class Labels #
236 | #################################################################################
237 | class TimestepEmbedder(nn.Module):
238 | """
239 | Embeds scalar timesteps into vector representations.
240 | """
241 |
242 | def __init__(self, hidden_size, frequency_embedding_size=256):
243 | super().__init__()
244 | self.mlp = nn.Sequential(
245 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
246 | nn.SiLU(),
247 | nn.Linear(hidden_size, hidden_size, bias=True),
248 | )
249 | self.frequency_embedding_size = frequency_embedding_size
250 |
251 | @staticmethod
252 | def timestep_embedding(t, dim, max_period=10000):
253 | """
254 | Create sinusoidal timestep embeddings.
255 | :param t: a 1-D Tensor of N indices, one per batch element.
256 | These may be fractional.
257 | :param dim: the dimension of the output.
258 | :param max_period: controls the minimum frequency of the embeddings.
259 | :return: an (N, D) Tensor of positional embeddings.
260 | """
261 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
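# Worked example: with dim=256 and max_period=10000, the 128 frequencies decay geometrically
# from 1 down to roughly 1e-4, and each scalar t is mapped to
# [cos(t*f_0), ..., cos(t*f_127), sin(t*f_0), ..., sin(t*f_127)].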
262 | half = dim // 2
263 | freqs = torch.exp(
264 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
265 | args = t[:, None].float() * freqs[None]
266 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
267 | if dim % 2:
268 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
269 | return embedding
270 |
271 | def forward(self, t):
272 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(self.dtype)
273 | return self.mlp(t_freq)
274 |
275 | @property
276 | def dtype(self):
277 | # Return the dtype of the model parameters
278 | return next(self.parameters()).dtype
279 |
280 |
281 | class SizeEmbedder(TimestepEmbedder):
282 | """
283 | Embeds scalar size values (e.g. resolution) into vector representations.
284 | """
285 |
286 | def __init__(self, hidden_size, frequency_embedding_size=256):
287 | super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
288 | self.mlp = nn.Sequential(
289 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
290 | nn.SiLU(),
291 | nn.Linear(hidden_size, hidden_size, bias=True),
292 | )
293 | self.frequency_embedding_size = frequency_embedding_size
294 | self.outdim = hidden_size
295 |
296 | def forward(self, s, bs):
297 | if s.ndim == 1:
298 | s = s[:, None]
299 | assert s.ndim == 2
300 | if s.shape[0] != bs:
301 | s = s.repeat(bs//s.shape[0], 1)
302 | assert s.shape[0] == bs
303 | b, dims = s.shape[0], s.shape[1]
304 | s = rearrange(s, "b d -> (b d)")
305 | s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
306 | s_emb = self.mlp(s_freq)
307 | s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
308 | return s_emb
309 |
310 | @property
311 | def dtype(self):
312 | # Return the dtype of the model parameters
313 | return next(self.parameters()).dtype
314 |
315 |
316 | class LabelEmbedder(nn.Module):
317 | """
318 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
319 | """
320 |
321 | def __init__(self, num_classes, hidden_size, dropout_prob):
322 | super().__init__()
323 | use_cfg_embedding = dropout_prob > 0
324 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
325 | self.num_classes = num_classes
326 | self.dropout_prob = dropout_prob
327 |
328 | def token_drop(self, labels, force_drop_ids=None):
329 | """
330 | Drops labels to enable classifier-free guidance.
331 | """
332 | if force_drop_ids is None:
333 | drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
334 | else:
335 | drop_ids = force_drop_ids == 1
336 | labels = torch.where(drop_ids, self.num_classes, labels)
337 | return labels
338 |
339 | def forward(self, labels, train, force_drop_ids=None):
340 | use_dropout = self.dropout_prob > 0
341 | if (train and use_dropout) or (force_drop_ids is not None):
342 | labels = self.token_drop(labels, force_drop_ids)
343 | return self.embedding_table(labels)
344 |
345 |
346 | class CaptionEmbedder2(nn.Module):
347 | """
348 | Projects caption embeddings into the model's hidden size.
349 | """
350 |
351 | def __init__(self, in_channels, hidden_size, act_layer=nn.GELU(approximate='tanh')):
352 | super().__init__()
353 | self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
354 |
355 |
356 | def forward(self, caption):
357 | caption = self.y_proj(caption)
358 | return caption
359 |
360 |
361 | class CaptionEmbedder(nn.Module):
362 | """
363 | Embeds caption features into vector representations. Also handles caption dropout for classifier-free guidance.
364 | """
365 |
366 | def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh')):
367 | super().__init__()
368 | self.y_proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
369 | self.register_buffer("y_embedding", nn.Parameter(torch.randn(1, in_channels) / in_channels ** 0.5))
370 | self.uncond_prob = uncond_prob
371 |
372 | def token_drop(self, caption, force_drop_ids=None):
373 | """
374 | Drops labels to enable classifier-free guidance.
375 | """
376 | if force_drop_ids is None:
377 | drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
378 | else:
379 | drop_ids = force_drop_ids == 1
380 | null_y = self.y_embedding.repeat(caption.shape[1],1)
381 | caption = torch.where(drop_ids[:, None, None, None], null_y, caption)
382 | return caption
383 |
384 | def forward(self, caption, force_drop_ids=None):
385 | use_dropout = self.uncond_prob > 0
386 | if use_dropout or (force_drop_ids is not None):
387 | caption = self.token_drop(caption, force_drop_ids)
388 | caption = self.y_proj(caption)
389 | return caption
390 |
391 |
392 | class CaptionEmbedderDoubleBr(nn.Module):
393 | """
394 | Embeds captions into a global vector and token-level features. Also handles caption dropout for classifier-free guidance.
395 | """
396 |
397 | def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120):
398 | super().__init__()
399 | self.proj = Mlp(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0)
400 | self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
401 | self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
402 | self.uncond_prob = uncond_prob
403 |
404 | def token_drop(self, global_caption, caption, force_drop_ids=None):
405 | """
406 | Drops labels to enable classifier-free guidance.
407 | """
408 | if force_drop_ids is None:
409 | drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
410 | else:
411 | drop_ids = force_drop_ids == 1
412 | global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
413 | caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
414 | return global_caption, caption
415 |
416 | def forward(self, caption, train, force_drop_ids=None):
417 | assert caption.shape[2: ] == self.y_embedding.shape
418 | global_caption = caption.mean(dim=2).squeeze()
419 | use_dropout = self.uncond_prob > 0
420 | if (train and use_dropout) or (force_drop_ids is not None):
421 | global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
422 | y_embed = self.proj(global_caption)
423 | return y_embed, caption
--------------------------------------------------------------------------------
/diffusion/model/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch.nn as nn
4 | from torch.utils.checkpoint import checkpoint, checkpoint_sequential
5 | import torch.nn.functional as F
6 | import torch
7 | import torch.distributed as dist
8 | import re
9 | import math
10 | from collections.abc import Iterable
11 | from itertools import repeat
12 | from torchvision import transforms as T
13 | import random
14 | from PIL import Image
15 | from torch import _assert
16 |
17 |
18 | def _ntuple(n):
19 | def parse(x):
20 | if isinstance(x, Iterable) and not isinstance(x, str):
21 | return x
22 | return tuple(repeat(x, n))
23 | return parse
24 |
25 |
26 | to_1tuple = _ntuple(1)
27 | to_2tuple = _ntuple(2)
28 |
29 | def set_grad_checkpoint(model, use_fp32_attention=False, gc_step=1):
30 | assert isinstance(model, nn.Module)
31 |
32 | def set_attr(module):
33 | module.grad_checkpointing = True
34 | module.fp32_attention = use_fp32_attention
35 | module.grad_checkpointing_step = gc_step
36 | model.apply(set_attr)
37 |
38 |
39 | def auto_grad_checkpoint(module, *args, **kwargs):
40 | if getattr(module, 'grad_checkpointing', False):
41 | if not isinstance(module, Iterable):
42 | return checkpoint(module, *args, **kwargs)
43 | gc_step = module[0].grad_checkpointing_step
44 | return checkpoint_sequential(module, gc_step, *args, **kwargs)
45 | return module(*args, **kwargs)
46 |
47 |
48 | def checkpoint_sequential(functions, step, input, *args, **kwargs):
49 |
50 | # Hack for keyword-only parameter in a python 2.7-compliant way
51 | preserve = kwargs.pop('preserve_rng_state', True)
52 | if kwargs:
53 | raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
54 |
55 | def run_function(start, end, functions):
56 | def forward(input):
57 | for j in range(start, end + 1):
58 | input = functions[j](input, *args)
59 | return input
60 | return forward
61 |
62 | if isinstance(functions, torch.nn.Sequential):
63 | functions = list(functions.children())
64 |
65 | # the last chunk has to be non-volatile
66 | end = -1
67 | segment = len(functions) // step
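# Example: 24 blocks with step=2 give segment=12; pairs (0,1) ... (20,21) are recomputed
# under checkpointing, and the final pair (22,23) runs outside checkpointing so its outputs
# retain grad for the backward pass.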
68 | for start in range(0, step * (segment - 1), step):
69 | end = start + step - 1
70 | input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
71 | return run_function(end + 1, len(functions) - 1, functions)(input)
72 |
73 |
74 | def window_partition(x, window_size):
75 | """
76 | Partition into non-overlapping windows with padding if needed.
77 | Args:
78 | x (tensor): input tokens with [B, H, W, C].
79 | window_size (int): window size.
80 |
81 | Returns:
82 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
83 | (Hp, Wp): padded height and width before partition
84 | """
85 | B, H, W, C = x.shape
86 |
87 | pad_h = (window_size - H % window_size) % window_size
88 | pad_w = (window_size - W % window_size) % window_size
89 | if pad_h > 0 or pad_w > 0:
90 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
91 | Hp, Wp = H + pad_h, W + pad_w
92 |
93 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
94 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
95 | return windows, (Hp, Wp)
96 |
97 |
98 | def window_unpartition(windows, window_size, pad_hw, hw):
99 | """
100 | Window unpartition into original sequences and removing padding.
101 | Args:
102 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
103 | window_size (int): window size.
104 | pad_hw (Tuple): padded height and width (Hp, Wp).
105 | hw (Tuple): original height and width (H, W) before padding.
106 |
107 | Returns:
108 | x: unpartitioned sequences with [B, H, W, C].
109 | """
110 | Hp, Wp = pad_hw
111 | H, W = hw
112 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
113 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
114 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
115 |
116 | if Hp > H or Wp > W:
117 | x = x[:, :H, :W, :].contiguous()
118 | return x
119 |
120 |
121 | def get_rel_pos(q_size, k_size, rel_pos):
122 | """
123 | Get relative positional embeddings according to the relative positions of
124 | query and key sizes.
125 | Args:
126 | q_size (int): size of query q.
127 | k_size (int): size of key k.
128 | rel_pos (Tensor): relative position embeddings (L, C).
129 |
130 | Returns:
131 | Extracted positional embeddings according to relative positions.
132 | """
133 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
134 | # Interpolate rel pos if needed.
135 | if rel_pos.shape[0] != max_rel_dist:
136 | # Interpolate rel pos.
137 | rel_pos_resized = F.interpolate(
138 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
139 | size=max_rel_dist,
140 | mode="linear",
141 | )
142 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
143 | else:
144 | rel_pos_resized = rel_pos
145 |
146 | # Scale the coords with short length if shapes for q and k are different.
147 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
148 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
149 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
150 |
151 | return rel_pos_resized[relative_coords.long()]
152 |
153 |
154 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
155 | """
156 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
157 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
158 | Args:
159 | attn (Tensor): attention map.
160 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
161 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
162 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
163 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
164 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
165 |
166 | Returns:
167 | attn (Tensor): attention map with added relative positional embeddings.
168 | """
169 | q_h, q_w = q_size
170 | k_h, k_w = k_size
171 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
172 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
173 |
174 | B, _, dim = q.shape
175 | r_q = q.reshape(B, q_h, q_w, dim)
176 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
177 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
178 |
179 | attn = (
180 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
181 | ).view(B, q_h * q_w, k_h * k_w)
182 |
183 | return attn
184 |
185 | def mean_flat(tensor):
186 | return tensor.mean(dim=list(range(1, tensor.ndim)))
187 |
188 |
189 | #################################################################################
190 | # Token Masking and Unmasking #
191 | #################################################################################
192 | def get_mask(batch, length, mask_ratio, device, mask_type=None, data_info=None, extra_len=0):
193 | """
194 | Get the binary mask for the input sequence.
195 | Args:
196 | - batch: batch size
197 | - length: sequence length
198 | - mask_ratio: ratio of tokens to mask
199 | - data_info: dictionary with info for reconstruction
200 | return:
201 | mask_dict with following keys:
202 | - mask: binary mask, 0 is keep, 1 is remove
203 | - ids_keep: indices of tokens to keep
204 | - ids_restore: indices to restore the original order
205 | """
206 | assert mask_type in ['random', 'fft', 'laplacian', 'group']
207 | mask = torch.ones([batch, length], device=device)
208 | len_keep = int(length * (1 - mask_ratio)) - extra_len
209 |
210 | if mask_type in ['random', 'group']:
211 | noise = torch.rand(batch, length, device=device) # noise in [0, 1]
212 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
213 | ids_restore = torch.argsort(ids_shuffle, dim=1)
214 | # keep the first subset
215 | ids_keep = ids_shuffle[:, :len_keep]
216 | ids_removed = ids_shuffle[:, len_keep:]
217 |
218 | elif mask_type in ['fft', 'laplacian']:
219 | if 'strength' in data_info:
220 | strength = data_info['strength']
221 |
222 | else:
223 | N = data_info['N'][0]
224 | img = data_info['ori_img']
225 | # Get the spatial dimensions of the original image
226 | _, C, H, W = img.shape
227 | if mask_type == 'fft':
228 | # Reshape the image into patches: (B, C, H/N, N, W/N, N)
229 | reshaped_image = img.reshape((batch, -1, H // N, N, W // N, N))
230 | fft_image = torch.fft.fftn(reshaped_image, dim=(3, 5))
231 | # Take the magnitude and sum it to get the per-patch frequency strength
232 | strength = torch.sum(torch.abs(fft_image), dim=(1, 3, 5)).reshape((batch, -1,))
233 | elif mask_type == 'laplacian':
234 | laplacian_kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32).reshape(1, 1, 3, 3)
235 | laplacian_kernel = laplacian_kernel.repeat(C, 1, 1, 1)
236 | # Reshape the image into patches: (B * H/N * W/N, C, N, N)
237 | reshaped_image = img.reshape(-1, C, H // N, N, W // N, N).permute(0, 2, 4, 1, 3, 5).reshape(-1, C, N, N)
238 | laplacian_response = F.conv2d(reshaped_image, laplacian_kernel, padding=1, groups=C)
239 | strength = laplacian_response.sum(dim=[1, 2, 3]).reshape((batch, -1,))
240 |
241 | # Normalize the strengths, then sample patch indices with torch.multinomial
242 | probabilities = strength / (strength.max(dim=1)[0][:, None]+1e-5)
243 | ids_shuffle = torch.multinomial(probabilities.clip(1e-5, 1), length, replacement=False)
244 | ids_keep = ids_shuffle[:, :len_keep]
245 | ids_restore = torch.argsort(ids_shuffle, dim=1)
246 | ids_removed = ids_shuffle[:, len_keep:]
247 |
248 | mask[:, :len_keep] = 0
249 | mask = torch.gather(mask, dim=1, index=ids_restore)
250 |
251 | return {'mask': mask,
252 | 'ids_keep': ids_keep,
253 | 'ids_restore': ids_restore,
254 | 'ids_removed': ids_removed}
255 |
256 |
257 | def mask_out_token(x, ids_keep, ids_removed=None):
258 | """
259 | Mask out the tokens specified by ids_keep.
260 | Args:
261 | - x: input sequence, [N, L, D]
262 | - ids_keep: indices of tokens to keep
263 | return:
264 | - x_masked: masked sequence
265 | """
266 | N, L, D = x.shape # batch, length, dim
267 | x_remain = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
268 | if ids_removed is not None:
269 | x_masked = torch.gather(x, dim=1, index=ids_removed.unsqueeze(-1).repeat(1, 1, D))
270 | return x_remain, x_masked
271 | else:
272 | return x_remain
273 |
274 |
275 | def mask_tokens(x, mask_ratio):
276 | """
277 | Perform per-sample random masking by per-sample shuffling.
278 | Per-sample shuffling is done by argsort random noise.
279 | x: [N, L, D], sequence
280 | """
281 | N, L, D = x.shape # batch, length, dim
282 | len_keep = int(L * (1 - mask_ratio))
283 |
284 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
285 |
286 | # sort noise for each sample
287 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
288 | ids_restore = torch.argsort(ids_shuffle, dim=1)
289 |
290 | # keep the first subset
291 | ids_keep = ids_shuffle[:, :len_keep]
292 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
293 |
294 | # generate the binary mask: 0 is keep, 1 is remove
295 | mask = torch.ones([N, L], device=x.device)
296 | mask[:, :len_keep] = 0
297 | mask = torch.gather(mask, dim=1, index=ids_restore)
298 |
299 | return x_masked, mask, ids_restore
300 |
301 |
302 | def unmask_tokens(x, ids_restore, mask_token):
303 | # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D]
304 | mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
305 | x = torch.cat([x, mask_tokens], dim=1)
306 | x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
307 | return x
308 |
309 |
310 | # Parse 'None' to None and others to float value
311 | def parse_float_none(s):
312 | assert isinstance(s, str)
313 | return None if s == 'None' else float(s)
314 |
315 |
316 | #----------------------------------------------------------------------------
317 | # Parse a comma separated list of numbers or ranges and return a list of ints.
318 | # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
319 |
320 | def parse_int_list(s):
321 | if isinstance(s, list): return s
322 | ranges = []
323 | range_re = re.compile(r'^(\d+)-(\d+)$')
324 | for p in s.split(','):
325 | if m := range_re.match(p):
326 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
327 | else:
328 | ranges.append(int(p))
329 | return ranges
330 |
331 |
332 | def init_processes(fn, args):
333 | """ Initialize the distributed environment. """
334 | os.environ['MASTER_ADDR'] = args.master_address
335 | os.environ['MASTER_PORT'] = str(random.randint(2000, 6000))
336 | print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}')
337 | print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}')
338 | torch.cuda.set_device(args.local_rank)
339 | dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size)
340 | fn(args)
341 | if args.global_size > 1:
342 | cleanup()
343 |
344 |
345 | def mprint(*args, **kwargs):
346 | """
347 | Print only from rank 0.
348 | """
349 | if dist.get_rank() == 0:
350 | print(*args, **kwargs)
351 |
352 |
353 | def cleanup():
354 | """
355 | End DDP training.
356 | """
357 | dist.barrier()
358 | mprint("Done!")
359 | dist.barrier()
360 | dist.destroy_process_group()
361 |
362 |
363 | #----------------------------------------------------------------------------
364 | # logging info.
365 | class Logger(object):
366 | """
367 | Redirect stderr to stdout, optionally print stdout to a file,
368 | and optionally force flushing on both stdout and the file.
369 | """
370 |
371 | def __init__(self, file_name=None, file_mode="w", should_flush=True):
372 | self.file = None
373 |
374 | if file_name is not None:
375 | self.file = open(file_name, file_mode)
376 |
377 | self.should_flush = should_flush
378 | self.stdout = sys.stdout
379 | self.stderr = sys.stderr
380 |
381 | sys.stdout = self
382 | sys.stderr = self
383 |
384 | def __enter__(self):
385 | return self
386 |
387 | def __exit__(self, exc_type, exc_value, traceback):
388 | self.close()
389 |
390 | def write(self, text):
391 | """Write text to stdout (and a file) and optionally flush."""
392 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
393 | return
394 |
395 | if self.file is not None:
396 | self.file.write(text)
397 |
398 | self.stdout.write(text)
399 |
400 | if self.should_flush:
401 | self.flush()
402 |
403 | def flush(self):
404 | """Flush written text to both stdout and a file, if open."""
405 | if self.file is not None:
406 | self.file.flush()
407 |
408 | self.stdout.flush()
409 |
410 | def close(self):
411 | """Flush, close possible files, and remove stdout/stderr mirroring."""
412 | self.flush()
413 |
414 | # if using multiple loggers, prevent closing in wrong order
415 | if sys.stdout is self:
416 | sys.stdout = self.stdout
417 | if sys.stderr is self:
418 | sys.stderr = self.stderr
419 |
420 | if self.file is not None:
421 | self.file.close()
422 |
423 |
424 | class StackedRandomGenerator:
425 | def __init__(self, device, seeds):
426 | super().__init__()
427 | self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
428 |
429 | def randn(self, size, **kwargs):
430 | assert size[0] == len(self.generators)
431 | return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
432 |
433 | def randn_like(self, input):
434 | return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
435 |
436 | def randint(self, *args, size, **kwargs):
437 | assert size[0] == len(self.generators)
438 | return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
439 |
440 |
441 | def prepare_prompt_ar(prompt, ratios, device='cpu', show=True):
442 | # get aspect_ratio or ar
443 | aspect_ratios = re.findall(r"--aspect_ratio\s+(\d+:\d+)", prompt)
444 | ars = re.findall(r"--ar\s+(\d+:\d+)", prompt)
445 | custom_hw = re.findall(r"--hw\s+(\d+:\d+)", prompt)
446 | if show:
447 | print("aspect_ratios:", aspect_ratios, "ars:", ars, "hws:", custom_hw)
448 | prompt_clean = prompt.split("--aspect_ratio")[0].split("--ar")[0].split("--hw")[0]
449 | if len(aspect_ratios) + len(ars) + len(custom_hw) == 0 and show:
450 | print( "Wrong prompt format. Set to default ar: 1. change your prompt into format '--ar h:w or --hw h:w' for correct generating")
451 | if len(aspect_ratios) != 0:
452 | ar = float(aspect_ratios[0].split(':')[0]) / float(aspect_ratios[0].split(':')[1])
453 | elif len(ars) != 0:
454 | ar = float(ars[0].split(':')[0]) / float(ars[0].split(':')[1])
455 | else:
456 | ar = 1.
457 | closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
458 | if len(custom_hw) != 0:
459 | custom_hw = [float(custom_hw[0].split(':')[0]), float(custom_hw[0].split(':')[1])]
460 | else:
461 | custom_hw = ratios[closest_ratio]
462 | default_hw = ratios[closest_ratio]
463 | prompt_show = f'prompt: {prompt_clean.strip()}\nSize: --ar {closest_ratio}, --bin hw {ratios[closest_ratio]}, --custom hw {custom_hw}'
464 | return prompt_clean, prompt_show, torch.tensor(default_hw, device=device)[None], torch.tensor([float(closest_ratio)], device=device)[None], torch.tensor(custom_hw, device=device)[None]
465 |
466 |
467 | def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int):
468 | orig_hw = torch.tensor([samples.shape[2], samples.shape[3]], dtype=torch.int)
469 | custom_hw = torch.tensor([int(new_height), int(new_width)], dtype=torch.int)
470 |
471 | if (orig_hw != custom_hw).all():
472 | ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1])
473 | resized_width = int(orig_hw[1] * ratio)
474 | resized_height = int(orig_hw[0] * ratio)
475 |
476 | transform = T.Compose([
477 | T.Resize((resized_height, resized_width)),
478 | T.CenterCrop(custom_hw.tolist())
479 | ])
480 | return transform(samples)
481 | else:
482 | return samples
483 |
484 |
485 | def resize_and_crop_img(img: Image, new_width, new_height):
486 | orig_width, orig_height = img.size
487 |
488 | ratio = max(new_width/orig_width, new_height/orig_height)
489 | resized_width = int(orig_width * ratio)
490 | resized_height = int(orig_height * ratio)
491 |
492 | img = img.resize((resized_width, resized_height), Image.LANCZOS)
493 |
494 | left = (resized_width - new_width)/2
495 | top = (resized_height - new_height)/2
496 | right = (resized_width + new_width)/2
497 | bottom = (resized_height + new_height)/2
498 |
499 | img = img.crop((left, top, right, bottom))
500 |
501 | return img
502 |
503 |
504 |
505 | def mask_feature(emb, mask):
506 | if emb.shape[0] == 1:
507 | keep_index = mask.sum().item()
508 | return emb[:, :, :keep_index, :], keep_index
509 | else:
510 | masked_feature = emb * mask[:, None, :, None]
511 | return masked_feature, emb.shape[2]
512 |
513 |
514 | from enum import Enum
515 | class Format(str, Enum):
516 | NCHW = 'NCHW'
517 | NHWC = 'NHWC'
518 | NCL = 'NCL'
519 | NLC = 'NLC'
520 |
521 | class PatchEmbed(nn.Module):
522 | """ 2D Image to Patch Embedding
523 | """
524 | output_fmt: Format
525 | dynamic_img_pad: torch.jit.Final[bool]
526 |
527 | def __init__(
528 | self,
529 | img_size=224,
530 | patch_size=16,
531 | in_chans=3,
532 | embed_dim=768,
533 | norm_layer=None,
534 | flatten=True,
535 | bias=True,
536 | ):
537 | super().__init__()
538 | if isinstance(img_size, int):
539 | img_size = to_2tuple(img_size)
540 | patch_size = to_2tuple(patch_size)
541 | self.img_size = img_size
542 | self.patch_size = patch_size
543 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
544 | self.num_patches = self.grid_size[0] * self.grid_size[1]
545 | self.flatten = flatten
546 |
547 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
548 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
549 |
550 | def forward(self, x):
551 | B, C, H, W = x.shape
552 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
553 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
554 | x = self.proj(x)
555 | if self.flatten:
556 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
557 | x = self.norm(x)
558 | return x
--------------------------------------------------------------------------------
/voicedit/pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import torch
4 | import torchaudio
5 | import numpy as np
6 | from tqdm.auto import tqdm
7 | from PIL import Image
8 | from copy import deepcopy
9 | import random
10 |
11 | from transformers import (
12 | ClapModel,
13 | ClapProcessor,
14 | CLIPModel,
15 | CLIPProcessor,
16 | SpeechT5HifiGan,
17 | )
18 | from diffusers import (
19 | AutoencoderKL,
20 | DDIMScheduler,
21 | UNet2DModel,
22 | )
23 | try:
24 | from diffusers.utils import randn_tensor
25 | except ImportError:  # randn_tensor moved to diffusers.utils.torch_utils in newer diffusers versions
26 | from diffusers.utils.torch_utils import randn_tensor
27 | from huggingface_hub import hf_hub_download
28 |
29 | from .utils import sequence_mask, generate_path, fix_len_compatibility
30 | from espnet2.bin.spk_inference import Speech2Embedding
31 | from i2a_translator import DiffusionPrior
32 |
33 | from diffusion.model.nets.DiT2 import DiT_models
34 | from diffusion.model.nets.DiT import DiT_cross_models
35 | from diffusion import create_diffusion
36 |
37 | from text import text_to_sequence, cmudict
38 | from text.symbols import symbols
39 | from .modules import DiTWrapper
40 | from .utils import intersperse
41 |
42 | # helper function
43 | def exists(val):
44 | return val is not None
45 |
46 |
47 | class VoiceDiTPipeline():
48 | def __init__(
49 | self,
50 | ckpt_path,
51 | v2a_ckpt_path = None,
52 | t2a_ckpt_path = None,
53 | device = None,
54 | cmudict_path='voicedit/cmu_dictionary',
55 | male_voice = None,
56 | female_voice = None,
57 | ):
58 |
59 | self.vae = AutoencoderKL.from_pretrained("cvssp/audioldm-m-full", subfolder="vae").eval()
60 | self.vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm-m-full", subfolder="vocoder").eval()
61 | self.clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").eval()
62 | self.clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
63 | self.speech2spk_embed = Speech2Embedding.from_pretrained(model_tag="espnet/voxcelebs12_ecapa_wavlm_joint", device=str(device))
64 | self.cmudict = cmudict.CMUDict(cmudict_path)
65 |
66 | if exists(v2a_ckpt_path):
67 | self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
68 | self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
69 | self.v2a_mapper = DiffusionPrior.from_pretrained(v2a_ckpt_path).eval()
70 |
71 | if exists(t2a_ckpt_path):
72 | self.t2a_mapper = DiffusionPrior.from_pretrained(t2a_ckpt_path).eval()
73 |
74 | model_config = ckpt_path.rsplit('/', 2)[0] + '/config.yaml'
75 | with open(model_config, 'r') as f:
76 | self.config = yaml.safe_load(f)
77 |
78 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
79 | if self.config['concat_y']:
80 | dit_model = DiT_models[self.config['model']]
81 | else:
82 | dit_model = DiT_cross_models[self.config['model']]
83 | dit = dit_model(
84 | input_size = (1024 // vae_scale_factor, 64 // vae_scale_factor),
85 | in_channels = 8,
86 | projection_class_embeddings_input_dim=512,
87 | caption_channels = 64,
88 | concat_y = self.config.get("concat_y", False),
89 | cross_class = self.config.get("cross_class", False),
90 | )
91 |
92 | # TODO: Get checkpoints
93 | def load_ckpt(model, ckpt_path):
94 | print(f"Loading checkpoint from {ckpt_path}")
95 | ckpt = torch.load(ckpt_path, map_location="cpu")
96 | model.load_state_dict(ckpt)
97 | print(f"Loaded checkpoint from {ckpt_path}")
98 | return model
99 |
100 | self.model = load_ckpt(DiTWrapper(dit, cond_speaker=True, concat_y=self.config.get("concat_y", False)), ckpt_path)
101 | self.model.eval()
102 |
103 | self.device = device
104 | self.vae.to(device)
105 | self.vocoder.to(device)
106 | self.clap_model.to(device)
107 |
108 | if exists(v2a_ckpt_path):
109 | self.clip_model.to(device)
110 | self.v2a_mapper.to(device)
111 | if exists(t2a_ckpt_path):
112 | self.t2a_mapper.to(device)
113 | self.model.to(device)
114 |
115 | # if exists(male_voice):
116 | # wav, sr = torchaudio.load(male_voice)
117 | # if sr != 16000:
118 | # wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=16000)
119 | # self.male_voice = self.speech2spk_embed(wav.squeeze(0).to(device))
120 |
121 | # if exists(female_voice):
122 | # wav, sr = torchaudio.load(female_voice)
123 | # if sr != 16000:
124 | # wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=16000)
125 | # self.female_voice = self.speech2spk_embed(wav.squeeze(0).to(device))
126 |
127 | if exists(male_voice):
128 | with open(male_voice, 'r') as f:
129 | self.male_voice = f.readlines()
130 |
131 | if exists(female_voice):
132 | with open(female_voice, 'r') as f:
133 | self.female_voice = f.readlines()
134 |
135 | def get_text(self, text, add_blank=True):
136 | text_norm = text_to_sequence(text, dictionary=self.cmudict)
137 | if add_blank:
138 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols)
139 | text_norm = torch.IntTensor(text_norm)
140 | return text_norm
141 |
142 | def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
143 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
144 | shape = (
145 | batch_size,
146 | num_channels_latents,
147 | height // vae_scale_factor,
148 | self.vocoder.config.model_in_dim // vae_scale_factor,
149 | )
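# Illustrative shapes (assuming the audioldm-m VAE's scale factor of 4 and a vocoder
# model_in_dim of 64): a 10 s request gives height=1024, so latents come out as
# (batch_size, num_channels_latents, 256, 16), matching the DiT input_size configured in __init__.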
150 | if isinstance(generator, list) and len(generator) != batch_size:
151 | raise ValueError(
152 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
153 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
154 | )
155 |
156 | if latents is None:
157 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
158 | else:
159 | latents = latents.to(device)
160 |
161 | # latents = latents * self.noise_scheduler.init_noise_sigma
162 | return latents
163 |
164 | def decode_latents(self, latents):
165 | latents = 1 / self.vae.config.scaling_factor * latents
166 | mel_spectrogram = self.vae.decode(latents).sample
167 | return mel_spectrogram
168 |
169 | def mel_spectrogram_to_waveform(self, mel_spectrogram):
170 | if mel_spectrogram.dim() == 4:
171 | mel_spectrogram = mel_spectrogram.squeeze(1)
172 |
173 | waveform = self.vocoder(mel_spectrogram)
174 | waveform = waveform.cpu().float()
175 | return waveform
176 |
177 | def normalize_wav(self, waveform):
178 | waveform = waveform - torch.mean(waveform)
179 | waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
180 | return waveform
181 |
182 | @torch.no_grad()
183 | def __call__(
184 | self,
185 | modality,
186 | env_prompt,
187 | cont_prompt,
188 | style_prompt = None,
189 | speaker_audio = None,
190 | gender = None,
191 | batch_size = 1,
192 | num_inference_steps = 100,
193 | audio_length_in_s = 10,
194 | do_classifier_free_guidance = True,
195 | v2a_guidance_scale = None,
196 | guidance_scale = None,
197 | desc_guidance_scale = None,
198 | cont_guidance_scale = None,
199 | device=None,
200 | seed=None,
201 | progress=True,
202 | **kwargs,
203 | ):
204 |
205 | if guidance_scale is None and desc_guidance_scale is None and cont_guidance_scale is None:
206 | do_classifier_free_guidance = False
207 |
208 | guidance = None
209 | if guidance_scale is None:
210 | guidance = "dual"
211 | else:
212 | guidance = "single"
213 |
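# "dual" guidance applies separate scales to the environment description and the speech
# content (desc_guidance_scale, cont_guidance_scale); "single" collapses both into one
# guidance_scale. The matching conditional/unconditional batches are assembled below and
# consumed by DiT.forward_with_cfg.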
214 | # description condition
215 | if modality == 'text':
216 | if do_classifier_free_guidance:
217 | if guidance == "dual":
218 | env_prompt = [env_prompt] * 2 + [""] * 2
219 | if guidance == "single":
220 | env_prompt = [env_prompt] + [""]
221 | else:
222 | env_prompt = [env_prompt]
223 |
224 | clap_inputs = self.clap_processor(
225 | text=env_prompt,
226 | return_tensors="pt",
227 | padding=True
228 | ).to(self.device)
229 |
230 | c_desc = self.clap_model.get_text_features(**clap_inputs)
231 |
232 | elif modality == 'audio':
233 | audio_sample, sr = torchaudio.load(env_prompt)
234 | if sr != 48000:
235 | audio_sample = torchaudio.functional.resample(audio_sample, orig_freq=sr, new_freq=48000)
236 | audio_sample = audio_sample[0]
237 |
238 | clap_inputs = self.clap_processor(audios=audio_sample, sampling_rate=48000, return_tensors="pt", padding=True).to(self.device)
239 | c_desc = self.clap_model.get_audio_features(**clap_inputs)
240 |
241 | if do_classifier_free_guidance:
242 | clap_inputs = self.clap_processor(text=[""], return_tensors="pt", padding=True).to(self.device)
243 | # uncond_embeds = self.clap_model.text_model(**clap_inputs).pooler_output
244 | # uc_desc = self.clap_model.text_projection(uncond_embeds)
245 | uc_desc = self.clap_model.get_text_features(**clap_inputs)
246 |
247 | if guidance == "dual":
248 | c_desc = torch.cat((c_desc, c_desc, uc_desc, uc_desc))
249 | if guidance == "single":
250 | c_desc = torch.cat((c_desc, uc_desc))
251 |
252 | elif modality == 'image':
253 | image_sample = Image.open(env_prompt)
254 |
255 | clip_inputs = self.clip_processor(images=image_sample, return_tensors="pt").to(self.device)
256 | # clip_embeds = self.clip_model.vision_model(**clip_inputs).pooler_output
257 | # clip_embeds = self.clip_model.visual_projection(clip_embeds)
258 | # clip_embeds = torch.nn.functional.normalize(clip_embeds, dim=-1)
259 | clip_embeds = self.clip_model.get_image_features(**clip_inputs)
260 | # clip_embeds = env_prompt
261 | text_cond = dict(text_embed = clip_embeds)
262 |
263 | c_desc = self.v2a_mapper.p_sample_loop(
264 | clip_embeds.shape,
265 | text_cond = text_cond,
266 | cond_scale = v2a_guidance_scale,
267 | timesteps = 100
268 | )
269 |
270 | c_desc = torch.nn.functional.normalize(c_desc, dim=-1)
271 |
272 | if do_classifier_free_guidance:
273 | clap_inputs = self.clap_processor(text=[""], return_tensors="pt", padding=True).to(self.device)
274 | uc_desc = self.clap_model.get_text_features(**clap_inputs)
275 |
276 | if guidance == "dual":
277 | c_desc = torch.cat((c_desc, c_desc, uc_desc, uc_desc))
278 | if guidance == "single":
279 | c_desc = torch.cat((c_desc, uc_desc))
280 |
281 | # speaker style condition
282 | if style_prompt is not None:
283 | if do_classifier_free_guidance:
284 | if guidance == "dual":
285 | style_prompt = [style_prompt] * 2 + [""] * 2
286 | if guidance == "single":
287 | style_prompt = [style_prompt] + [""]
288 | else:
289 | style_prompt = [style_prompt]
290 |
291 | clap_text_tokens = self.clap_processor(text=style_prompt, return_tensors="pt", padding=True).to(self.device)
292 | c_style = self.clap_model.get_text_features(**clap_text_tokens)
293 |
294 | elif gender is not None:
295 |
296 | if gender == 'man':
297 | spk_audio = random.choice(self.male_voice).strip()
298 | elif gender == 'woman':
299 | spk_audio = random.choice(self.female_voice).strip()
300 | else:
301 | spk_audio = random.choice(self.male_voice + self.female_voice).strip()
302 | spk_audio, sr = torchaudio.load(spk_audio)
303 |
304 | if sr != 16000:
305 | spk_audio = torchaudio.functional.resample(spk_audio, orig_freq=sr, new_freq=16000)
306 | if spk_audio.shape[0] > 1:
307 | spk_audio = spk_audio.mean(0, keepdim=True)
308 | c_style = self.speech2spk_embed(spk_audio.squeeze(0))
309 |
310 | elif speaker_audio is not None:
311 | spk_audio, sr = torchaudio.load(speaker_audio)
312 | if sr != 16000:
313 | spk_audio = torchaudio.functional.resample(spk_audio, orig_freq=sr, new_freq=16000)
314 | if spk_audio.shape[0] > 1:
315 | spk_audio = spk_audio.mean(0, keepdim=True)
316 | c_style=self.speech2spk_embed(spk_audio.squeeze(0))
317 | else:
318 | c_style = torch.zeros(192).to(self.device)
319 |
320 | cont_tokens = self.get_text(cont_prompt, add_blank=True)
321 | cont_tokens = cont_tokens.unsqueeze(0).to(self.device)
322 | cont_lengths = torch.LongTensor([cont_tokens.shape[-1]]).to(self.device)
323 |
324 | mu_y, y_mask = self.model.process_content(cont_tokens, cont_lengths, c_style)
325 |
326 | if self.model.concat_y:
327 | mu_y_vae = self.model.proj(mu_y.unsqueeze(1))
328 | else:
329 | mu_y_vae = mu_y.unsqueeze(1)
330 |
331 | vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
332 | height = int(audio_length_in_s * 1.024 / vocoder_upsample_factor)
333 | original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
334 |
335 | # prepare latent variables
336 | num_channels_latents = self.model.dit.in_channels
337 | latents = self.prepare_latents(
338 | batch_size,
339 | num_channels_latents,
340 | height,
341 | c_desc.dtype,
342 | device=device,
343 | generator=torch.manual_seed(seed) if seed else None,
344 | latents=None,
345 | )
346 |
347 | # denoising loop
348 |
349 | if guidance == "dual":
350 | z = torch.cat([latents] * 4) if do_classifier_free_guidance else latents
351 | cfg_scale = (desc_guidance_scale, cont_guidance_scale)
352 | elif guidance == "single":
353 | z = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
354 | cfg_scale = guidance_scale
355 |
356 | if self.model.concat_y:
357 | mu_y_vae = torch.nn.functional.pad(mu_y_vae, (0, 0, 0, z.shape[-2] - mu_y_vae.shape[-2]), value=0)
358 | if do_classifier_free_guidance:
359 | if self.config['concat_y']:
360 | if guidance == "dual":
361 | mu_y_vae = torch.cat([mu_y_vae, self.model.y_embedding.unsqueeze(0)] * 2, dim=0)
362 | elif guidance == "single":
363 | mu_y_vae = torch.cat([mu_y_vae, self.model.y_embedding.unsqueeze(0)], dim=0)
364 | else:
365 | null_y = self.model.dit.y_embedder.y_embedding.unsqueeze(1).repeat(1, mu_y_vae.shape[-2], 1)
366 | if guidance == "dual":
367 | mu_y_vae = torch.cat([mu_y_vae, null_y.unsqueeze(0)] * 2, dim=0)
368 | elif guidance == "single":
369 | mu_y_vae = torch.cat([mu_y_vae, null_y.unsqueeze(0)], dim=0)
370 |
371 | model_kwargs = dict(
372 | y=mu_y_vae,
373 | mask=y_mask.long().squeeze(1) if not self.model.concat_y else None,
374 | class_labels=c_desc,
375 | guidance=guidance,
376 | cfg_scale = cfg_scale,
377 | )
378 |
379 | diffusion = create_diffusion(str(num_inference_steps))
380 |
381 | # Sample mel-spectrogram latents:
382 | samples = diffusion.p_sample_loop(
383 | self.model.dit.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=progress,
384 | device=self.device
385 | )
386 | if guidance == "dual":
387 | samples = samples[: len(samples) // 4]
388 | elif guidance == "single":
389 | samples = samples[: len(samples) // 2]
390 |
391 | mel_spectrogram = self.decode_latents(samples)
392 | audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
393 | audio = audio[:, :original_waveform_length]
394 | audio = self.normalize_wav(audio)
395 |
396 | return audio, mel_spectrogram
397 |
398 |
399 | class EvalPipeline():
400 | def __init__(
401 | self,
402 | vae,
403 | clap_model: ClapModel,
404 | clap_processor: ClapProcessor,
405 | device,
406 | cmudict_path='voicedit/cmu_dictionary'
407 | ):
408 | self.vae = vae
409 | self.vocoder = SpeechT5HifiGan.from_pretrained("cvssp/audioldm-m-full", subfolder="vocoder").eval()
410 | self.clap_model = clap_model
411 | self.clap_processor = clap_processor
412 | self.speech2spk_embed = Speech2Embedding.from_pretrained(model_tag="espnet/voxcelebs12_ecapa_wavlm_joint")
413 | self.device = device
414 | self.vocoder.to(device)
415 | self.cmudict = cmudict.CMUDict(cmudict_path)
416 |
417 | def get_text(self, text, add_blank=True):
418 | text_norm = text_to_sequence(text, dictionary=self.cmudict)
419 | if add_blank:
420 | text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols)
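    |             # e.g., assuming the usual Grad-TTS-style intersperse: [t1, t2] becomes
    |             # [blank, t1, blank, t2, blank], so the length grows to 2 * len(text_norm) + 1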
421 | text_norm = torch.IntTensor(text_norm)
422 | return text_norm
423 |
424 | def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
425 | vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
426 | shape = (
427 | batch_size,
428 | num_channels_latents,
429 | height // vae_scale_factor,
430 | self.vocoder.config.model_in_dim // vae_scale_factor,
431 | )
432 | if isinstance(generator, list) and len(generator) != batch_size:
433 | raise ValueError(
434 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
435 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
436 | )
437 |
438 | if latents is None:
439 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
440 | else:
441 | latents = latents.to(device)
442 |
443 | return latents
444 |
445 | def decode_latents(self, latents):
446 | latents = 1 / self.vae.config.scaling_factor * latents
447 | mel_spectrogram = self.vae.decode(latents).sample
448 | return mel_spectrogram
449 |
450 | def mel_spectrogram_to_waveform(self, mel_spectrogram):
451 | if mel_spectrogram.dim() == 4:
452 | mel_spectrogram = mel_spectrogram.squeeze(1)
453 |
454 | waveform = self.vocoder(mel_spectrogram)
455 | waveform = waveform.cpu().float()
456 | return waveform
457 |
458 | def normalize_wav(self, waveform):
459 | waveform = waveform - torch.mean(waveform)
460 | waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
461 | return waveform
462 |
463 | @torch.autocast('cuda', dtype=torch.float16)
464 | @torch.no_grad()
465 | def __call__(
466 | self,
467 | model,
468 | clap_input_features,
469 | clap_is_longer,
470 | cont_prompt,
471 | spk_embeds,
472 | batch_size = 1,
473 | num_inference_steps = 100,
474 | audio_length_in_s = 10,
475 | do_classifier_free_guidance = True,
476 | guidance_scale = None,
477 | desc_guidance_scale = None,
478 | cont_guidance_scale = None,
479 | seed=None,
480 | concat_y = None,
481 | **kwargs,
482 | ):
483 | if guidance_scale is None and desc_guidance_scale is None and cont_guidance_scale is None:
484 | do_classifier_free_guidance = False
485 |
486 |         # Dual CFG (separate description / content scales) unless a single guidance_scale is given
487 |         if guidance_scale is None:
488 |             guidance = "dual"
489 |         else:
490 |             guidance = "single"
491 |
492 | # embeds = self.clap_model.audio_model(input_features=clap_input_features, is_longer=clap_is_longer).pooler_output
493 | # c_desc = self.clap_model.audio_projection(embeds)
494 |
495 | c_desc = self.clap_model.get_audio_features(input_features=clap_input_features, is_longer=clap_is_longer)
496 |
497 | if do_classifier_free_guidance:
498 | clap_inputs = self.clap_processor(text=[""], return_tensors="pt", padding=True).to(self.device)
499 | # uncond_embeds = self.clap_model.text_model(**clap_inputs).pooler_output
500 | # uc_desc = self.clap_model.text_projection(uncond_embeds)
501 | uc_desc = self.clap_model.get_text_features(**clap_inputs)
502 |
503 | if guidance == "dual":
504 | c_desc = torch.cat((c_desc, c_desc, uc_desc, uc_desc))
505 |             elif guidance == "single":
506 |                 c_desc = torch.cat((c_desc, uc_desc))
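    |             # With "dual" guidance the class labels are ordered [cond, cond, uncond, uncond],
    |             # while the content condition assembled below is ordered [content, null, content, null],
    |             # giving the four (description x content) combinations that the two guidance scales
    |             # are presumably applied to inside forward_with_cfg.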
507 |
508 | # content condition
509 | # if do_classifier_free_guidance:
510 | # if guidance == "dual":
511 | # cont_prompt = ([cont_prompt] + ["_"]) * 2
512 | # if guidance == "single":
513 | # cont_prompt = [cont_prompt] + ["_"]
514 |
515 | # cont_tokens = self.text_processor(
516 | # text=[cont_prompt],
517 | # padding=True,
518 | # truncation=True,
519 | # max_length=1000,
520 | # return_tensors="pt"
521 | # ).to(self.device)
522 | # cont_embed_mask = cont_tokens.attention_mask
523 |
524 | cont_tokens = self.get_text(cont_prompt, add_blank=True)
525 | cont_tokens = cont_tokens.unsqueeze(0).to(self.device)
526 | cont_lengths = torch.LongTensor([cont_tokens.shape[-1]]).to(self.device)
527 |
528 | mu_y, y_mask = model.process_content(cont_tokens, cont_lengths, spk_embeds)
529 | mu_y_vae = mu_y.unsqueeze(1)
530 | if concat_y:
531 | mu_y_vae = model.proj(mu_y_vae)
532 |
533 | vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
534 | height = int(audio_length_in_s * 1.024 / vocoder_upsample_factor)
535 | original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
536 |
537 | # prepare latent variables
538 | num_channels_latents = model.dit.in_channels
539 | latents = self.prepare_latents(
540 | batch_size,
541 | num_channels_latents,
542 | height,
543 | c_desc.dtype,
544 | device=self.device,
545 | generator=torch.manual_seed(seed) if seed else None,
546 | latents=None,
547 | )
548 |
549 | # denoising loop
550 | if guidance == "dual":
551 | z = torch.cat([latents] * 4) if do_classifier_free_guidance else latents
552 | cfg_scale = (desc_guidance_scale, cont_guidance_scale)
553 | elif guidance == "single":
554 | z = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
555 | cfg_scale = guidance_scale
556 |
557 | if concat_y:
558 | mu_y_vae = torch.nn.functional.pad(mu_y_vae, (0, 0, 0, z.shape[-2] - mu_y_vae.shape[-2]), value=0)
559 | if do_classifier_free_guidance:
560 | if guidance == "dual":
561 | mu_y_vae = torch.cat([mu_y_vae, model.y_embedding.unsqueeze(0)] * 2, dim=0)
562 | elif guidance == "single":
563 | mu_y_vae = torch.cat([mu_y_vae, model.y_embedding.unsqueeze(0)], dim=0)
564 |
565 | model_kwargs = dict(
566 | y=mu_y_vae,
567 | class_labels=c_desc,
568 | guidance=guidance,
569 | cfg_scale=cfg_scale,
570 | )
571 |
572 | else:
573 | null_y = model.dit.y_embedder.y_embedding.repeat(mu_y_vae.shape[-2], 1)
574 | null_y = null_y.unsqueeze(0).unsqueeze(0)
575 | if do_classifier_free_guidance:
576 | if guidance == "dual":
577 | mu_y_vae = torch.cat([mu_y_vae, null_y] * 2, dim=0)
578 | elif guidance == "single":
579 | mu_y_vae = torch.cat([mu_y_vae, null_y], dim=0)
580 |
581 | model_kwargs = dict(
582 | y=mu_y_vae,
583 | mask=y_mask.squeeze(1).bool(),
584 | class_labels=c_desc,
585 | guidance=guidance,
586 | cfg_scale=cfg_scale,
587 | )
588 |
589 | diffusion = create_diffusion(str(num_inference_steps))
590 |
591 |         # Run the reverse-diffusion sampling loop over the latents
592 | samples = diffusion.p_sample_loop(
593 | model.dit.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
594 | device=self.device
595 | )
596 | if guidance == "dual":
597 | samples = samples[: len(samples) // 4]
598 |         elif guidance == "single":
599 | samples = samples[: len(samples) // 2]
600 |
601 | mel_spectrogram = self.decode_latents(samples)
602 | audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
603 | audio = audio[:, :original_waveform_length]
604 | audio = self.normalize_wav(audio)
605 |
606 |         return mu_y[0].transpose(0, 1), mel_spectrogram.transpose(2, 3), audio
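    |
    | # Usage sketch (illustrative only, with assumptions): the trained VoiceDiT `model` object
    | # (providing .dit, .process_content, .proj and .y_embedding) is assumed to be built and
    | # loaded elsewhere, and the CLAP / AudioLDM checkpoint names below are the public ones,
    | # which may differ from the checkpoints actually used in this project.
    | #
    | #   import torch, torchaudio
    | #   from diffusers import AutoencoderKL
    | #   from transformers import ClapModel, ClapProcessor
    | #
    | #   device = "cuda"
    | #   clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").to(device).eval()
    | #   clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
    | #   vae = AutoencoderKL.from_pretrained("cvssp/audioldm-m-full", subfolder="vae").to(device).eval()
    | #   pipe = EvalPipeline(vae, clap_model, clap_processor, device)
    | #
    | #   # description condition: CLAP features of a reference environment clip (CLAP expects 48 kHz)
    | #   env_wav, sr = torchaudio.load("assets/bird_chirping.wav")
    | #   env_wav = torchaudio.functional.resample(env_wav.mean(0), sr, 48000)
    | #   clap_inputs = clap_processor(audios=env_wav.numpy(), sampling_rate=48000, return_tensors="pt").to(device)
    | #
    | #   spk_embeds = torch.zeros(192).to(device)  # or pipe.speech2spk_embed(<16 kHz mono speech tensor>)
    | #   mu_y, mel, audio = pipe(
    | #       model, clap_inputs["input_features"], clap_inputs["is_longer"],
    | #       "Hello there, welcome to the park.", spk_embeds,
    | #       desc_guidance_scale=3.0, cont_guidance_scale=3.0,
    | #   )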
--------------------------------------------------------------------------------