├── .gitignore ├── LOGO.png ├── audio_diffusion_pytorch ├── __init__.py ├── utils.py ├── components.py ├── models.py └── diffusion.py ├── setup.py ├── LICENSE ├── .pre-commit-config.yaml ├── .github └── workflows │ └── python-publish.yml └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache 3 | -------------------------------------------------------------------------------- /LOGO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vaibhavs10/audio-diffusion-pytorch/main/LOGO.png -------------------------------------------------------------------------------- /audio_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .components import LTPlugin, MelSpectrogram, UNetV0, XUNet 2 | from .diffusion import ( 3 | Diffusion, 4 | Distribution, 5 | LinearSchedule, 6 | Sampler, 7 | Schedule, 8 | UniformDistribution, 9 | VDiffusion, 10 | VSampler, 11 | ) 12 | from .models import ( 13 | DiffusionAE, 14 | DiffusionAR, 15 | DiffusionModel, 16 | DiffusionUpsampler, 17 | DiffusionVocoder, 18 | ) 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="audio-diffusion-pytorch", 5 | packages=find_packages(exclude=[]), 6 | version="0.1.1", 7 | license="MIT", 8 | description="Audio Diffusion - PyTorch", 9 | long_description_content_type="text/markdown", 10 | author="Flavio Schneider", 11 | author_email="archinetai@protonmail.com", 12 | url="https://github.com/archinetai/audio-diffusion-pytorch", 13 | keywords=["artificial intelligence", "deep learning", "audio generation"], 14 | install_requires=[ 15 | "tqdm", 16 | "torch>=1.6", 17 | "torchaudio", 18 | "data-science-types>=0.2", 19 | "einops>=0.6", 20 | "a-unet", 21 | ], 22 | classifiers=[ 23 | "Development Status :: 4 - Beta", 24 | "Intended Audience :: Developers", 25 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 26 | "License :: OSI Approved :: MIT License", 27 | "Programming Language :: Python :: 3.6", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 archinet.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: end-of-file-fixer
6 | - id: trailing-whitespace
7 |
8 | # Formats code correctly
9 | - repo: https://github.com/psf/black
10 | rev: 21.12b0
11 | hooks:
12 | - id: black
13 | args: [
14 | '--experimental-string-processing'
15 | ]
16 |
17 | # Sorts imports
18 | - repo: https://github.com/pycqa/isort
19 | rev: 5.10.1
20 | hooks:
21 | - id: isort
22 | name: isort (python)
23 | args: ["--profile", "black"]
24 |
25 | # Checks for unused imports, line lengths, etc.
26 | - repo: https://gitlab.com/pycqa/flake8
27 | rev: 4.0.0
28 | hooks:
29 | - id: flake8
30 | args: [
31 | '--per-file-ignores=__init__.py:F401',
32 | '--max-line-length=88',
33 | '--ignore=E203,W503'
34 | ]
35 |
36 | # Checks types
37 | - repo: https://github.com/pre-commit/mirrors-mypy
38 | rev: 'v0.971'
39 | hooks:
40 | - id: mypy
41 | additional_dependencies: [data-science-types>=0.2, torch>=1.6]
42 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /audio_diffusion_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from inspect import isfunction 3 | from math import ceil, floor, log2, pi 4 | from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | from torch import Generator, Tensor 10 | from typing_extensions import TypeGuard 11 | 12 | T = TypeVar("T") 13 | 14 | 15 | def exists(val: Optional[T]) -> TypeGuard[T]: 16 | return val is not None 17 | 18 | 19 | def iff(condition: bool, value: T) -> Optional[T]: 20 | return value if condition else None 21 | 22 | 23 | def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]: 24 | return isinstance(obj, list) or isinstance(obj, tuple) 25 | 26 | 27 | def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: 28 | if exists(val): 29 | return val 30 | return d() if isfunction(d) else d 31 | 32 | 33 | def to_list(val: Union[T, Sequence[T]]) -> List[T]: 34 | if isinstance(val, tuple): 35 | return list(val) 36 | if isinstance(val, list): 37 | return val 38 | return [val] # type: ignore 39 | 40 | 41 | def prod(vals: Sequence[int]) -> int: 42 | return reduce(lambda x, y: x * y, vals) 43 | 44 | 45 | def closest_power_2(x: float) -> int: 46 | exponent = log2(x) 47 | distance_fn = lambda z: abs(x - 2 ** z) # noqa 48 | exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) 49 | return 2 ** int(exponent_closest) 50 | 51 | 52 | """ 53 | Kwargs Utils 54 | """ 55 | 56 | 57 | def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: 58 | return_dicts: Tuple[Dict, Dict] = ({}, {}) 59 | for key in d.keys(): 60 | no_prefix = int(not key.startswith(prefix)) 61 | return_dicts[no_prefix][key] = d[key] 62 | return return_dicts 63 | 64 | 65 | def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: 66 | kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) 67 | if keep_prefix: 68 | return kwargs_with_prefix, kwargs 69 | kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} 70 | return kwargs_no_prefix, kwargs 71 | 72 | 73 | def prefix_dict(prefix: str, d: Dict) -> Dict: 74 | return {prefix + str(k): v for k, v in d.items()} 75 | 76 | 77 | """ 78 | DSP Utils 79 | """ 80 | 81 | 82 | def resample( 83 | waveforms: Tensor, 84 | factor_in: int, 85 | factor_out: int, 86 | rolloff: float = 0.99, 87 | lowpass_filter_width: int = 6, 88 | ) -> Tensor: 89 | """Resamples a waveform using sinc interpolation, adapted from torchaudio""" 90 | b, _, length = waveforms.shape 91 | length_target = int(factor_out * length / factor_in) 92 | d = dict(device=waveforms.device, 
dtype=waveforms.dtype) 93 | 94 | base_factor = min(factor_in, factor_out) * rolloff 95 | width = ceil(lowpass_filter_width * factor_in / base_factor) 96 | idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in # type: ignore # noqa 97 | t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx # type: ignore # noqa 98 | t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * pi 99 | 100 | window = torch.cos(t / lowpass_filter_width / 2) ** 2 101 | scale = base_factor / factor_in 102 | kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t) 103 | kernels *= window * scale 104 | 105 | waveforms = rearrange(waveforms, "b c t -> (b c) t") 106 | waveforms = F.pad(waveforms, (width, width + factor_in)) 107 | resampled = F.conv1d(waveforms[:, None], kernels, stride=factor_in) 108 | resampled = rearrange(resampled, "(b c) k l -> b c (l k)", b=b) 109 | return resampled[..., :length_target] 110 | 111 | 112 | def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor: 113 | return resample(waveforms, factor_in=factor, factor_out=1, **kwargs) 114 | 115 | 116 | def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor: 117 | return resample(waveforms, factor_in=1, factor_out=factor, **kwargs) 118 | 119 | 120 | """ Torch Utils """ 121 | 122 | 123 | def randn_like(tensor: Tensor, *args, generator: Optional[Generator] = None, **kwargs): 124 | """randn_like that supports generator""" 125 | return torch.randn(tensor.shape, *args, generator=generator, **kwargs).to(tensor) 126 | -------------------------------------------------------------------------------- /audio_diffusion_pytorch/components.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Sequence 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from a_unet import ( 6 | ClassifierFreeGuidancePlugin, 7 | Conv, 8 | Module, 9 | TextConditioningPlugin, 10 | TimeConditioningPlugin, 11 | default, 12 | exists, 13 | ) 14 | from a_unet.apex import ( 15 | AttentionItem, 16 | CrossAttentionItem, 17 | InjectChannelsItem, 18 | ModulationItem, 19 | ResnetItem, 20 | SkipCat, 21 | SkipModulate, 22 | XBlock, 23 | XUNet, 24 | ) 25 | from einops import pack, unpack 26 | from torch import Tensor, nn 27 | from torchaudio import transforms 28 | 29 | """ 30 | UNets (built with a-unet: https://github.com/archinetai/a-unet) 31 | """ 32 | 33 | 34 | def UNetV0( 35 | dim: int, 36 | in_channels: int, 37 | channels: Sequence[int], 38 | factors: Sequence[int], 39 | items: Sequence[int], 40 | attentions: Optional[Sequence[int]] = None, 41 | cross_attentions: Optional[Sequence[int]] = None, 42 | context_channels: Optional[Sequence[int]] = None, 43 | attention_features: Optional[int] = None, 44 | attention_heads: Optional[int] = None, 45 | embedding_features: Optional[int] = None, 46 | resnet_groups: int = 8, 47 | use_modulation: bool = True, 48 | modulation_features: int = 1024, 49 | embedding_max_length: Optional[int] = None, 50 | use_time_conditioning: bool = True, 51 | use_embedding_cfg: bool = False, 52 | use_text_conditioning: bool = False, 53 | out_channels: Optional[int] = None, 54 | ): 55 | # Set defaults and check lengths 56 | num_layers = len(channels) 57 | attentions = default(attentions, [0] * num_layers) 58 | cross_attentions = default(cross_attentions, [0] * num_layers) 59 | context_channels = default(context_channels, [0] * num_layers) 60 | xs = (channels, factors, items, attentions, 
cross_attentions, context_channels) 61 | assert all(len(x) == num_layers for x in xs) # type: ignore 62 | 63 | # Define UNet type 64 | UNetV0 = XUNet 65 | 66 | if use_embedding_cfg: 67 | msg = "use_embedding_cfg requires embedding_max_length" 68 | assert exists(embedding_max_length), msg 69 | UNetV0 = ClassifierFreeGuidancePlugin(UNetV0, embedding_max_length) 70 | 71 | if use_text_conditioning: 72 | UNetV0 = TextConditioningPlugin(UNetV0) 73 | 74 | if use_time_conditioning: 75 | assert use_modulation, "use_time_conditioning requires use_modulation=True" 76 | UNetV0 = TimeConditioningPlugin(UNetV0) 77 | 78 | # Build 79 | return UNetV0( 80 | dim=dim, 81 | in_channels=in_channels, 82 | out_channels=out_channels, 83 | blocks=[ 84 | XBlock( 85 | channels=channels, 86 | factor=factor, 87 | context_channels=ctx_channels, 88 | items=( 89 | [ResnetItem] 90 | + [ModulationItem] * use_modulation 91 | + [InjectChannelsItem] * (ctx_channels > 0) 92 | + [AttentionItem] * att 93 | + [CrossAttentionItem] * cross 94 | ) 95 | * items, 96 | ) 97 | for channels, factor, items, att, cross, ctx_channels in zip(*xs) # type: ignore # noqa 98 | ], 99 | skip_t=SkipModulate if use_modulation else SkipCat, 100 | attention_features=attention_features, 101 | attention_heads=attention_heads, 102 | embedding_features=embedding_features, 103 | modulation_features=modulation_features, 104 | resnet_groups=resnet_groups, 105 | ) 106 | 107 | 108 | """ 109 | Plugins 110 | """ 111 | 112 | 113 | def LTPlugin( 114 | net_t: Callable, num_filters: int, window_length: int, stride: int 115 | ) -> Callable[..., nn.Module]: 116 | """Learned Transform Plugin""" 117 | 118 | def Net( 119 | dim: int, in_channels: int, out_channels: Optional[int] = None, **kwargs 120 | ) -> nn.Module: 121 | out_channels = default(out_channels, in_channels) 122 | in_channel_transform = in_channels * num_filters 123 | out_channel_transform = out_channels * num_filters # type: ignore 124 | 125 | padding = window_length // 2 - stride // 2 126 | encode = Conv( 127 | dim=dim, 128 | in_channels=in_channels, 129 | out_channels=in_channel_transform, 130 | kernel_size=window_length, 131 | stride=stride, 132 | padding=padding, 133 | padding_mode="reflect", 134 | bias=False, 135 | ) 136 | decode = nn.ConvTranspose1d( 137 | in_channels=out_channel_transform, 138 | out_channels=out_channels, # type: ignore 139 | kernel_size=window_length, 140 | stride=stride, 141 | padding=padding, 142 | bias=False, 143 | ) 144 | net = net_t( # type: ignore 145 | dim=dim, 146 | in_channels=in_channel_transform, 147 | out_channels=out_channel_transform, 148 | **kwargs 149 | ) 150 | 151 | def forward(x: Tensor, *args, **kwargs): 152 | x = encode(x) 153 | x = net(x, *args, **kwargs) 154 | x = decode(x) 155 | return x 156 | 157 | return Module([encode, decode, net], forward) 158 | 159 | return Net 160 | 161 | 162 | def AppendChannelsPlugin( 163 | net_t: Callable, 164 | channels: int, 165 | ): 166 | def Net( 167 | in_channels: int, out_channels: Optional[int] = None, **kwargs 168 | ) -> nn.Module: 169 | out_channels = default(out_channels, in_channels) 170 | net = net_t( # type: ignore 171 | in_channels=in_channels + channels, out_channels=out_channels, **kwargs 172 | ) 173 | 174 | def forward(x: Tensor, *args, append_channels: Tensor, **kwargs): 175 | x = torch.cat([x, append_channels], dim=1) 176 | return net(x, *args, **kwargs) 177 | 178 | return Module([net], forward) 179 | 180 | return Net 181 | 182 | 183 | """ 184 | Other 185 | """ 186 | 187 | 188 | class MelSpectrogram(nn.Module): 189 
| def __init__( 190 | self, 191 | n_fft: int, 192 | hop_length: int, 193 | win_length: int, 194 | sample_rate: int, 195 | n_mel_channels: int, 196 | center: bool = False, 197 | normalize: bool = False, 198 | normalize_log: bool = False, 199 | ): 200 | super().__init__() 201 | self.padding = (n_fft - hop_length) // 2 202 | self.normalize = normalize 203 | self.normalize_log = normalize_log 204 | self.hop_length = hop_length 205 | 206 | self.to_spectrogram = transforms.Spectrogram( 207 | n_fft=n_fft, 208 | hop_length=hop_length, 209 | win_length=win_length, 210 | center=center, 211 | power=None, 212 | ) 213 | 214 | self.to_mel_scale = transforms.MelScale( 215 | n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate 216 | ) 217 | 218 | def forward(self, waveform: Tensor) -> Tensor: 219 | # Pack non-time dimension 220 | waveform, ps = pack([waveform], "* t") 221 | # Pad waveform 222 | waveform = F.pad(waveform, [self.padding] * 2, mode="reflect") 223 | # Compute STFT 224 | spectrogram = self.to_spectrogram(waveform) 225 | # Compute magnitude 226 | spectrogram = torch.abs(spectrogram) 227 | # Convert to mel scale 228 | mel_spectrogram = self.to_mel_scale(spectrogram) 229 | # Normalize 230 | if self.normalize: 231 | mel_spectrogram = mel_spectrogram / torch.max(mel_spectrogram) 232 | mel_spectrogram = 2 * torch.pow(mel_spectrogram, 0.25) - 1 233 | if self.normalize_log: 234 | mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5)) 235 | # Unpack non-spectrogram dimension 236 | return unpack(mel_spectrogram, ps, "* f l")[0] 237 | -------------------------------------------------------------------------------- /audio_diffusion_pytorch/models.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from math import floor 3 | from typing import Any, Callable, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | from einops import pack, rearrange, unpack 7 | from torch import Generator, Tensor, nn 8 | 9 | from .components import AppendChannelsPlugin, MelSpectrogram 10 | from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler 11 | from .utils import closest_power_2, default, downsample, groupby, randn_like, upsample 12 | 13 | 14 | class DiffusionModel(nn.Module): 15 | def __init__( 16 | self, 17 | net_t: Callable, 18 | diffusion_t: Callable = VDiffusion, 19 | sampler_t: Callable = VSampler, 20 | dim: int = 1, 21 | **kwargs, 22 | ): 23 | super().__init__() 24 | diffusion_kwargs, kwargs = groupby("diffusion_", kwargs) 25 | sampler_kwargs, kwargs = groupby("sampler_", kwargs) 26 | 27 | self.net = net_t(dim=dim, **kwargs) 28 | self.diffusion = diffusion_t(net=self.net, **diffusion_kwargs) 29 | self.sampler = sampler_t(net=self.net, **sampler_kwargs) 30 | 31 | def forward(self, *args, **kwargs) -> Tensor: 32 | return self.diffusion(*args, **kwargs) 33 | 34 | @torch.no_grad() 35 | def sample(self, *args, **kwargs) -> Tensor: 36 | return self.sampler(*args, **kwargs) 37 | 38 | 39 | class EncoderBase(nn.Module, ABC): 40 | """Abstract class for DiffusionAE encoder""" 41 | 42 | @abstractmethod 43 | def __init__(self): 44 | super().__init__() 45 | self.out_channels = None 46 | self.downsample_factor = None 47 | 48 | 49 | class DiffusionAE(DiffusionModel): 50 | """Diffusion Auto Encoder""" 51 | 52 | def __init__( 53 | self, 54 | in_channels: int, 55 | channels: Sequence[int], 56 | encoder: EncoderBase, 57 | inject_depth: int, 58 | **kwargs, 59 | ): 60 | context_channels = [0] * len(channels) 61 | 
context_channels[inject_depth] = encoder.out_channels 62 | super().__init__( 63 | in_channels=in_channels, 64 | channels=channels, 65 | context_channels=context_channels, 66 | **kwargs, 67 | ) 68 | self.in_channels = in_channels 69 | self.encoder = encoder 70 | self.inject_depth = inject_depth 71 | 72 | def forward( # type: ignore 73 | self, x: Tensor, with_info: bool = False, **kwargs 74 | ) -> Union[Tensor, Tuple[Tensor, Any]]: 75 | latent, info = self.encode(x, with_info=True) 76 | channels = [None] * self.inject_depth + [latent] 77 | loss = super().forward(x, channels=channels, **kwargs) 78 | return (loss, info) if with_info else loss 79 | 80 | def encode(self, *args, **kwargs): 81 | return self.encoder(*args, **kwargs) 82 | 83 | @torch.no_grad() 84 | def decode( 85 | self, latent: Tensor, generator: Optional[Generator] = None, **kwargs 86 | ) -> Tensor: 87 | b = latent.shape[0] 88 | length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor) 89 | # Compute noise by inferring shape from latent length 90 | noise = torch.randn( 91 | (b, self.in_channels, length), 92 | device=latent.device, 93 | dtype=latent.dtype, 94 | generator=generator, 95 | ) 96 | # Compute context from latent 97 | channels = [None] * self.inject_depth + [latent] # type: ignore 98 | # Decode by sampling while conditioning on latent channels 99 | return super().sample(noise, channels=channels, **kwargs) 100 | 101 | 102 | class DiffusionUpsampler(DiffusionModel): 103 | def __init__( 104 | self, 105 | in_channels: int, 106 | upsample_factor: int, 107 | net_t: Callable, 108 | **kwargs, 109 | ): 110 | self.upsample_factor = upsample_factor 111 | super().__init__( 112 | net_t=AppendChannelsPlugin(net_t, channels=in_channels), 113 | in_channels=in_channels, 114 | **kwargs, 115 | ) 116 | 117 | def reupsample(self, x: Tensor) -> Tensor: 118 | x = x.clone() 119 | x = downsample(x, factor=self.upsample_factor) 120 | x = upsample(x, factor=self.upsample_factor) 121 | return x 122 | 123 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore 124 | reupsampled = self.reupsample(x) 125 | return super().forward(x, *args, append_channels=reupsampled, **kwargs) 126 | 127 | @torch.no_grad() 128 | def sample( # type: ignore 129 | self, downsampled: Tensor, generator: Optional[Generator] = None, **kwargs 130 | ) -> Tensor: 131 | reupsampled = upsample(downsampled, factor=self.upsample_factor) 132 | noise = randn_like(reupsampled, generator=generator) 133 | return super().sample(noise, append_channels=reupsampled, **kwargs) 134 | 135 | 136 | class DiffusionVocoder(DiffusionModel): 137 | def __init__( 138 | self, 139 | net_t: Callable, 140 | mel_channels: int, 141 | mel_n_fft: int, 142 | mel_hop_length: Optional[int] = None, 143 | mel_win_length: Optional[int] = None, 144 | in_channels: int = 1, # Ignored: channels are automatically batched. 
145 | **kwargs,
146 | ):
147 | mel_hop_length = default(mel_hop_length, floor(mel_n_fft) // 4)
148 | mel_win_length = default(mel_win_length, mel_n_fft)
149 | mel_kwargs, kwargs = groupby("mel_", kwargs)
150 | super().__init__(
151 | net_t=AppendChannelsPlugin(net_t, channels=1),
152 | in_channels=1,
153 | **kwargs,
154 | )
155 | self.to_spectrogram = MelSpectrogram(
156 | n_fft=mel_n_fft,
157 | hop_length=mel_hop_length,
158 | win_length=mel_win_length,
159 | n_mel_channels=mel_channels,
160 | **mel_kwargs,
161 | )
162 | self.to_flat = nn.ConvTranspose1d(
163 | in_channels=mel_channels,
164 | out_channels=1,
165 | kernel_size=mel_win_length,
166 | stride=mel_hop_length,
167 | padding=(mel_win_length - mel_hop_length) // 2,
168 | bias=False,
169 | )
170 |
171 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: # type: ignore
172 | # Get spectrogram, pack channels and flatten
173 | spectrogram = rearrange(self.to_spectrogram(x), "b c f l -> (b c) f l")
174 | spectrogram_flat = self.to_flat(spectrogram)
175 | # Pack wave channels
176 | x = rearrange(x, "b c t -> (b c) 1 t")
177 | return super().forward(x, *args, append_channels=spectrogram_flat, **kwargs)
178 |
179 | @torch.no_grad()
180 | def sample( # type: ignore
181 | self, spectrogram: Tensor, generator: Optional[Generator] = None, **kwargs
182 | ) -> Tensor: # type: ignore
183 | # Pack channels and flatten spectrogram
184 | spectrogram, ps = pack([spectrogram], "* f l")
185 | spectrogram_flat = self.to_flat(spectrogram)
186 | # Get start noise and sample
187 | noise = randn_like(spectrogram_flat, generator=generator)
188 | waveform = super().sample(noise, append_channels=spectrogram_flat, **kwargs)
189 | # Unpack wave channels
190 | waveform = rearrange(waveform, "... 1 t -> ... t")
191 | waveform = unpack(waveform, ps, "* t")[0]
192 | return waveform
193 |
194 |
195 | class DiffusionAR(DiffusionModel):
196 | def __init__(
197 | self,
198 | in_channels: int,
199 | length: int,
200 | num_splits: int,
201 | diffusion_t: Callable = ARVDiffusion,
202 | sampler_t: Callable = ARVSampler,
203 | **kwargs,
204 | ):
205 | super().__init__(
206 | in_channels=in_channels + 1,
207 | out_channels=in_channels,
208 | diffusion_t=diffusion_t,
209 | diffusion_length=length,
210 | diffusion_num_splits=num_splits,
211 | sampler_t=sampler_t,
212 | sampler_in_channels=in_channels,
213 | sampler_length=length,
214 | sampler_num_splits=num_splits,
215 | use_time_conditioning=False,
216 | use_modulation=False,
217 | **kwargs,
218 | )
219 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | A fully featured audio diffusion library, for PyTorch. Includes models for unconditional audio generation, text-conditional audio generation, diffusion autoencoding, upsampling, and vocoding. The provided models are waveform-based; however, the U-Net (built using [`a-unet`](https://github.com/archinetai/a-unet)), `DiffusionModel`, diffusion method, and diffusion samplers are all generic to any dimension and highly customizable to work on other formats.
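For instance, because `DiffusionModel` and `UNetV0` expose a `dim` argument, the same components can be instantiated for 2-dimensional data such as spectrograms. The snippet below is a minimal, untested sketch of that idea; the channel, factor, and item values are illustrative placeholders rather than a recommended configuration.

```py
import torch
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

# 2D diffusion model, e.g. for spectrogram-like inputs (illustrative sketch)
model = DiffusionModel(
    net_t=UNetV0,               # The model type used for diffusion
    dim=2,                      # Operate on 2D inputs instead of 1D waveforms
    in_channels=1,              # Number of input/output channels
    channels=[8, 32, 64, 128],  # U-Net: channels at each layer (placeholder values)
    factors=[1, 2, 2, 2],       # U-Net: downsampling/upsampling factors at each layer
    items=[1, 2, 2, 2],         # U-Net: number of repeating items at each layer
    diffusion_t=VDiffusion,     # The diffusion method used
    sampler_t=VSampler,         # The diffusion sampler used
)

x = torch.randn(1, 1, 64, 64)  # [batch, channels, height, width]
loss = model(x)                # v-objective diffusion loss
loss.backward()

sample = model.sample(torch.randn(1, 1, 64, 64), num_steps=10)
```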
4 |
5 |
6 | ## Install
7 |
8 | ```bash
9 | pip install audio-diffusion-pytorch
10 | ```
11 |
12 | [![PyPI - Python Version](https://img.shields.io/pypi/v/audio-diffusion-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/audio-diffusion-pytorch/)
13 | [![Downloads](https://static.pepy.tech/personalized-badge/audio-diffusion-pytorch?period=total&units=international_system&left_color=black&right_color=black&left_text=Downloads)](https://pepy.tech/project/audio-diffusion-pytorch)
14 |
15 |
16 | ## Usage
17 |
18 | ### Unconditional Generator
19 |
20 | ```py
21 | from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
22 |
23 | model = DiffusionModel(
24 | net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
25 | in_channels=2, # U-Net: number of input/output (audio) channels
26 | channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
27 | factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
28 | items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
29 | attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
30 | attention_heads=8, # U-Net: number of attention heads per attention item
31 | attention_features=64, # U-Net: number of attention features per attention item
32 | diffusion_t=VDiffusion, # The diffusion method used
33 | sampler_t=VSampler, # The diffusion sampler used
34 | )
35 |
36 | # Train model with audio waveforms
37 | audio = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length]
38 | loss = model(audio)
39 | loss.backward()
40 |
41 | # Turn noise into new audio sample with diffusion
42 | noise = torch.randn(1, 2, 2**18) # [batch_size, in_channels, length]
43 | sample = model.sample(noise, num_steps=10) # Suggested num_steps 10-100
44 | ```
45 |
46 | ### Text-Conditional Generator
47 | A text-to-audio diffusion model that conditions the generation with `t5-base` text embeddings; requires `pip install transformers`.
48 | ```py
49 | from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
50 |
51 | model = DiffusionModel(
52 | # ... same as unconditional model
53 | use_text_conditioning=True, # U-Net: enables text conditioning (default T5-base)
54 | use_embedding_cfg=True, # U-Net: enables classifier-free guidance
55 | embedding_max_length=64, # U-Net: text embedding maximum length (default for T5-base)
56 | embedding_features=768, # U-Net: text embedding features (default for T5-base)
57 | cross_attentions=[0, 0, 0, 1, 1, 1, 1, 1, 1], # U-Net: cross-attention enabled/disabled at each layer
58 | )
59 |
60 | # Train model with audio waveforms
61 | audio_wave = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
62 | loss = model(
63 | audio_wave,
64 | text=['The audio description'], # Text conditioning, one element per batch
65 | embedding_mask_proba=0.1 # Probability of masking text with learned embedding (Classifier-Free Guidance Mask)
66 | )
67 | loss.backward()
68 |
69 | # Turn noise into new audio sample with diffusion
70 | noise = torch.randn(1, 2, 2**18)
71 | sample = model.sample(
72 | noise,
73 | text=['The audio description'],
74 | embedding_scale=5.0, # Higher for more text importance, suggested range: 1-15 (Classifier-Free Guidance Scale)
75 | num_steps=2 # Higher for better quality, suggested num_steps: 10-100
76 | )
77 | ```
78 |
79 | ### Diffusion Upsampler
80 | Upsample audio from a lower sample rate to a higher sample rate using diffusion, e.g. 3kHz to 48kHz.
81 | ```py
82 | from audio_diffusion_pytorch import DiffusionUpsampler, UNetV0, VDiffusion, VSampler
83 |
84 | upsampler = DiffusionUpsampler(
85 | net_t=UNetV0, # The model type used for diffusion
86 | upsample_factor=16, # The upsample factor (e.g. 16 can be used for 3kHz to 48kHz)
87 | in_channels=2, # U-Net: number of input/output (audio) channels
88 | channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
89 | factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
90 | items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
91 | diffusion_t=VDiffusion, # The diffusion method used
92 | sampler_t=VSampler, # The diffusion sampler used
93 | )
94 |
95 | # Train model with high sample rate audio waveforms
96 | audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
97 | loss = upsampler(audio)
98 | loss.backward()
99 |
100 | # Turn low sample rate audio into high sample rate
101 | downsampled_audio = torch.randn(1, 2, 2**14) # [batch, in_channels, length]
102 | sample = upsampler.sample(downsampled_audio, num_steps=10) # Output has shape: [1, 2, 2**18]
103 | ```
104 |
105 | ### Diffusion Vocoder
106 | Convert a mel-spectrogram to waveform using diffusion.
107 | ```py
108 | from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler
109 |
110 | vocoder = DiffusionVocoder(
111 | mel_n_fft=1024, # Mel-spectrogram n_fft
112 | mel_channels=80, # Mel-spectrogram channels
113 | mel_sample_rate=48000, # Mel-spectrogram sample rate
114 | mel_normalize_log=True, # Mel-spectrogram log normalization (alternative is mel_normalize=True for [-1,1] power normalization)
115 | net_t=UNetV0, # The model type used for diffusion vocoding
116 | channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
117 | factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
118 | items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
119 | diffusion_t=VDiffusion, # The diffusion method used
120 | sampler_t=VSampler, # The diffusion sampler used
121 | )
122 |
123 | # Train model on waveforms (automatically converted to mel internally)
124 | audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
125 | loss = vocoder(audio)
126 | loss.backward()
127 |
128 | # Turn mel spectrogram into waveform
129 | mel_spectrogram = torch.randn(1, 2, 80, 1024) # [batch, in_channels, mel_channels, mel_length]
130 | sample = vocoder.sample(mel_spectrogram, num_steps=10) # Output has shape: [1, 2, 2**18]
131 | ```
132 |
133 | ### Diffusion Autoencoder
134 | Autoencode audio into a compressed latent using diffusion. Any encoder can be provided as long as it subclasses the `EncoderBase` class or contains `out_channels` and `downsample_factor` fields.
135 | ```py
136 | from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
137 | from audio_encoders_pytorch import MelE1d, TanhBottleneck
138 |
139 | autoencoder = DiffusionAE(
140 | encoder=MelE1d( # The encoder used, in this case a mel-spectrogram encoder
141 | in_channels=2,
142 | channels=512,
143 | multipliers=[1, 1],
144 | factors=[2],
145 | num_blocks=[12],
146 | out_channels=32,
147 | mel_channels=80,
148 | mel_sample_rate=48000,
149 | mel_normalize_log=True,
150 | bottleneck=TanhBottleneck(),
151 | ),
152 | inject_depth=6,
153 | net_t=UNetV0, # The model type used for diffusion
154 | in_channels=2, # U-Net: number of input/output (audio) channels
155 | channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
156 | factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
157 | items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
158 | diffusion_t=VDiffusion, # The diffusion method used
159 | sampler_t=VSampler, # The diffusion sampler used
160 | )
161 |
162 | # Train autoencoder with audio samples
163 | audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
164 | loss = autoencoder(audio)
165 | loss.backward()
166 |
167 | # Encode/decode audio
168 | audio = torch.randn(1, 2, 2**18) # [batch, in_channels, length]
169 | latent = autoencoder.encode(audio) # Encode
170 | sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling the diffusion model conditioned on the latent
171 | ```
172 |
173 | ## Appreciation
174 |
175 | * [StabilityAI](https://stability.ai/) for the compute, [Zach Evans](https://github.com/zqevans) and everyone else from [HarmonAI](https://www.harmonai.org/) for the interesting research discussions.
176 | * [ETH Zurich](https://inf.ethz.ch/) for the resources, [Zhijing Jin](https://zhijing-jin.com/), [Bernhard Schoelkopf](https://is.mpg.de/~bs), and [Mrinmaya Sachan](http://www.mrinmaya.io/) for supervising this Thesis. 177 | * [Phil Wang](https://github.com/lucidrains) for the beautiful open source contributions on [diffusion](https://github.com/lucidrains/denoising-diffusion-pytorch) and [Imagen](https://github.com/lucidrains/imagen-pytorch). 178 | * [Katherine Crowson](https://github.com/crowsonkb) for the experiments with [k-diffusion](https://github.com/crowsonkb/k-diffusion) and the insane collection of samplers. 179 | 180 | ## Citations 181 | 182 | DDPM Diffusion 183 | ```bibtex 184 | @misc{2006.11239, 185 | Author = {Jonathan Ho and Ajay Jain and Pieter Abbeel}, 186 | Title = {Denoising Diffusion Probabilistic Models}, 187 | Year = {2020}, 188 | Eprint = {arXiv:2006.11239}, 189 | } 190 | ``` 191 | 192 | DDIM (V-Sampler) 193 | ```bibtex 194 | @misc{2010.02502, 195 | Author = {Jiaming Song and Chenlin Meng and Stefano Ermon}, 196 | Title = {Denoising Diffusion Implicit Models}, 197 | Year = {2020}, 198 | Eprint = {arXiv:2010.02502}, 199 | } 200 | ``` 201 | 202 | V-Diffusion 203 | ```bibtex 204 | @misc{2202.00512, 205 | Author = {Tim Salimans and Jonathan Ho}, 206 | Title = {Progressive Distillation for Fast Sampling of Diffusion Models}, 207 | Year = {2022}, 208 | Eprint = {arXiv:2202.00512}, 209 | } 210 | ``` 211 | 212 | Imagen (T5 Text Conditioning) 213 | ```bibtex 214 | @misc{2205.11487, 215 | Author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi}, 216 | Title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding}, 217 | Year = {2022}, 218 | Eprint = {arXiv:2205.11487}, 219 | } 220 | ``` 221 | -------------------------------------------------------------------------------- /audio_diffusion_pytorch/diffusion.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | from typing import Any, Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from einops import rearrange, repeat 8 | from torch import Tensor 9 | from tqdm import tqdm 10 | 11 | """ Distributions """ 12 | 13 | 14 | class Distribution: 15 | """Interface used by different distributions""" 16 | 17 | def __call__(self, num_samples: int, device: torch.device): 18 | raise NotImplementedError() 19 | 20 | 21 | class UniformDistribution(Distribution): 22 | def __init__(self, vmin: float = 0.0, vmax: float = 1.0): 23 | super().__init__() 24 | self.vmin, self.vmax = vmin, vmax 25 | 26 | def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")): 27 | vmax, vmin = self.vmax, self.vmin 28 | return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin 29 | 30 | 31 | """ Diffusion Methods """ 32 | 33 | 34 | def pad_dims(x: Tensor, ndim: int) -> Tensor: 35 | # Pads additional ndims to the right of the tensor 36 | return x.view(*x.shape, *((1,) * ndim)) 37 | 38 | 39 | def clip(x: Tensor, dynamic_threshold: float = 0.0): 40 | if dynamic_threshold == 0.0: 41 | return x.clamp(-1.0, 1.0) 42 | else: 43 | # Dynamic thresholding 44 | # Find dynamic threshold quantile for each batch 45 | x_flat = rearrange(x, "b ... 
-> b (...)") 46 | scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1) 47 | # Clamp to a min of 1.0 48 | scale.clamp_(min=1.0) 49 | # Clamp all values and scale 50 | scale = pad_dims(scale, ndim=x.ndim - scale.ndim) 51 | x = x.clamp(-scale, scale) / scale 52 | return x 53 | 54 | 55 | def extend_dim(x: Tensor, dim: int): 56 | # e.g. if dim = 4: shape [b] => [b, 1, 1, 1], 57 | return x.view(*x.shape + (1,) * (dim - x.ndim)) 58 | 59 | 60 | class Diffusion(nn.Module): 61 | """Interface used by different diffusion methods""" 62 | 63 | pass 64 | 65 | 66 | class VDiffusion(Diffusion): 67 | def __init__( 68 | self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution() 69 | ): 70 | super().__init__() 71 | self.net = net 72 | self.sigma_distribution = sigma_distribution 73 | 74 | def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: 75 | angle = sigmas * pi / 2 76 | alpha, beta = torch.cos(angle), torch.sin(angle) 77 | return alpha, beta 78 | 79 | def forward(self, x: Tensor, **kwargs) -> Tensor: # type: ignore 80 | batch_size, device = x.shape[0], x.device 81 | # Sample amount of noise to add for each batch element 82 | sigmas = self.sigma_distribution(num_samples=batch_size, device=device) 83 | sigmas_batch = extend_dim(sigmas, dim=x.ndim) 84 | # Get noise 85 | noise = torch.randn_like(x) 86 | # Combine input and noise weighted by half-circle 87 | alphas, betas = self.get_alpha_beta(sigmas_batch) 88 | x_noisy = alphas * x + betas * noise 89 | v_target = alphas * noise - betas * x 90 | # Predict velocity and return loss 91 | v_pred = self.net(x_noisy, sigmas, **kwargs) 92 | return F.mse_loss(v_pred, v_target) 93 | 94 | 95 | class ARVDiffusion(Diffusion): 96 | def __init__(self, net: nn.Module, length: int, num_splits: int): 97 | super().__init__() 98 | assert length % num_splits == 0, "length must be divisible by num_splits" 99 | self.net = net 100 | self.length = length 101 | self.num_splits = num_splits 102 | self.split_length = length // num_splits 103 | 104 | def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: 105 | angle = sigmas * pi / 2 106 | alpha, beta = torch.cos(angle), torch.sin(angle) 107 | return alpha, beta 108 | 109 | def forward(self, x: Tensor, **kwargs) -> Tensor: 110 | """Returns diffusion loss of v-objective with different noises per split""" 111 | b, _, t, device, dtype = *x.shape, x.device, x.dtype 112 | assert t == self.length, "input length must match length" 113 | # Sample amount of noise to add for each split 114 | sigmas = torch.rand((b, 1, self.num_splits), device=device, dtype=dtype) 115 | sigmas = repeat(sigmas, "b 1 n -> b 1 (n l)", l=self.split_length) 116 | # Get noise 117 | noise = torch.randn_like(x) 118 | # Combine input and noise weighted by half-circle 119 | alphas, betas = self.get_alpha_beta(sigmas) 120 | x_noisy = alphas * x + betas * noise 121 | v_target = alphas * noise - betas * x 122 | # Sigmas will be provided as additional channel 123 | channels = torch.cat([x_noisy, sigmas], dim=1) 124 | # Predict velocity and return loss 125 | v_pred = self.net(channels, **kwargs) 126 | return F.mse_loss(v_pred, v_target) 127 | 128 | 129 | """ Schedules """ 130 | 131 | 132 | class Schedule(nn.Module): 133 | """Interface used by different sampling schedules""" 134 | 135 | def forward(self, num_steps: int, device: torch.device) -> Tensor: 136 | raise NotImplementedError() 137 | 138 | 139 | class LinearSchedule(Schedule): 140 | def __init__(self, start: float = 1.0, end: float = 0.0): 141 | super().__init__() 
142 | self.start, self.end = start, end 143 | 144 | def forward(self, num_steps: int, device: Any) -> Tensor: 145 | return torch.linspace(self.start, self.end, num_steps, device=device) 146 | 147 | 148 | """ Samplers """ 149 | 150 | 151 | class Sampler(nn.Module): 152 | pass 153 | 154 | 155 | class VSampler(Sampler): 156 | 157 | diffusion_types = [VDiffusion] 158 | 159 | def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()): 160 | super().__init__() 161 | self.net = net 162 | self.schedule = schedule 163 | 164 | def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: 165 | angle = sigmas * pi / 2 166 | alpha, beta = torch.cos(angle), torch.sin(angle) 167 | return alpha, beta 168 | 169 | def forward( # type: ignore 170 | self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs 171 | ) -> Tensor: 172 | b = x_noisy.shape[0] 173 | sigmas = self.schedule(num_steps + 1, device=x_noisy.device) 174 | sigmas = repeat(sigmas, "i -> i b", b=b) 175 | sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1) 176 | alphas, betas = self.get_alpha_beta(sigmas_batch) 177 | progress_bar = tqdm(range(num_steps), disable=not show_progress) 178 | 179 | for i in progress_bar: 180 | v_pred = self.net(x_noisy, sigmas[i], **kwargs) 181 | x_pred = alphas[i] * x_noisy - betas[i] * v_pred 182 | noise_pred = betas[i] * x_noisy + alphas[i] * v_pred 183 | x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred 184 | progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})") 185 | 186 | return x_noisy 187 | 188 | 189 | class ARVSampler(Sampler): 190 | def __init__(self, net: nn.Module, in_channels: int, length: int, num_splits: int): 191 | super().__init__() 192 | assert length % num_splits == 0, "length must be divisible by num_splits" 193 | self.length = length 194 | self.in_channels = in_channels 195 | self.num_splits = num_splits 196 | self.split_length = length // num_splits 197 | self.net = net 198 | 199 | @property 200 | def device(self): 201 | return next(self.net.parameters()).device 202 | 203 | def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]: 204 | angle = sigmas * pi / 2 205 | alpha = torch.cos(angle) 206 | beta = torch.sin(angle) 207 | return alpha, beta 208 | 209 | def get_sigmas_ladder(self, num_items: int, num_steps_per_split: int) -> Tensor: 210 | b, n, l, i = num_items, self.num_splits, self.split_length, num_steps_per_split 211 | n_half = n // 2 # Only half ladder, rest is zero, to leave some context 212 | sigmas = torch.linspace(1, 0, i * n_half, device=self.device) 213 | sigmas = repeat(sigmas, "(n i) -> i b 1 (n l)", b=b, l=l, n=n_half) 214 | sigmas = torch.flip(sigmas, dims=[-1]) # Lowest noise level first 215 | sigmas = F.pad(sigmas, pad=[0, 0, 0, 0, 0, 0, 0, 1]) # Add index i+1 216 | sigmas[-1, :, :, l:] = sigmas[0, :, :, :-l] # Loop back at index i+1 217 | return torch.cat([torch.zeros_like(sigmas), sigmas], dim=-1) 218 | 219 | def sample_loop( 220 | self, current: Tensor, sigmas: Tensor, show_progress: bool = False, **kwargs 221 | ) -> Tensor: 222 | num_steps = sigmas.shape[0] - 1 223 | alphas, betas = self.get_alpha_beta(sigmas) 224 | progress_bar = tqdm(range(num_steps), disable=not show_progress) 225 | 226 | for i in progress_bar: 227 | channels = torch.cat([current, sigmas[i]], dim=1) 228 | v_pred = self.net(channels, **kwargs) 229 | x_pred = alphas[i] * current - betas[i] * v_pred 230 | noise_pred = betas[i] * current + alphas[i] * v_pred 231 | current = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred 232 
| progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0,0,0]:.2f})") 233 | 234 | return current 235 | 236 | def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor: 237 | b, c, t = num_items, self.in_channels, self.length 238 | # Same sigma schedule over all chunks 239 | sigmas = torch.linspace(1, 0, num_steps + 1, device=self.device) 240 | sigmas = repeat(sigmas, "i -> i b 1 t", b=b, t=t) 241 | noise = torch.randn((b, c, t), device=self.device) * sigmas[0] 242 | # Sample start 243 | return self.sample_loop(current=noise, sigmas=sigmas, **kwargs) 244 | 245 | def forward( 246 | self, 247 | num_items: int, 248 | num_chunks: int, 249 | num_steps: int, 250 | start: Optional[Tensor] = None, 251 | show_progress: bool = False, 252 | **kwargs, 253 | ) -> Tensor: 254 | assert_message = f"required at least {self.num_splits} chunks" 255 | assert num_chunks >= self.num_splits, assert_message 256 | 257 | # Sample initial chunks 258 | start = self.sample_start(num_items=num_items, num_steps=num_steps, **kwargs) 259 | # Return start if only num_splits chunks 260 | if num_chunks == self.num_splits: 261 | return start 262 | 263 | # Get sigmas for autoregressive ladder 264 | b, n = num_items, self.num_splits 265 | assert num_steps >= n, "num_steps must be greater than num_splits" 266 | sigmas = self.get_sigmas_ladder( 267 | num_items=b, 268 | num_steps_per_split=num_steps // self.num_splits, 269 | ) 270 | alphas, betas = self.get_alpha_beta(sigmas) 271 | 272 | # Noise start to match ladder and set starting chunks 273 | start_noise = alphas[0] * start + betas[0] * torch.randn_like(start) 274 | chunks = list(start_noise.chunk(chunks=n, dim=-1)) 275 | 276 | # Loop over ladder shifts 277 | num_shifts = num_chunks # - self.num_splits 278 | progress_bar = tqdm(range(num_shifts), disable=not show_progress) 279 | 280 | for j in progress_bar: 281 | # Decrease ladder noise of last n chunks 282 | updated = self.sample_loop( 283 | current=torch.cat(chunks[-n:], dim=-1), sigmas=sigmas, **kwargs 284 | ) 285 | # Update chunks 286 | chunks[-n:] = list(updated.chunk(chunks=n, dim=-1)) 287 | # Add fresh noise chunk 288 | shape = (b, self.in_channels, self.split_length) 289 | chunks += [torch.randn(shape, device=self.device)] 290 | 291 | return torch.cat(chunks[:num_chunks], dim=-1) 292 | --------------------------------------------------------------------------------
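The `DiffusionAR` model defined in `models.py` above (backed by the `ARVDiffusion` and `ARVSampler` classes in this file) generates audio autoregressively in fixed-length chunks. Below is a minimal, untested sketch of how it could be wired up; the U-Net configuration, `length`, `num_splits`, `num_chunks`, and `num_steps` values are illustrative placeholders, not recommended settings.

```py
import torch
from audio_diffusion_pytorch import DiffusionAR, UNetV0

model = DiffusionAR(
    net_t=UNetV0,                   # The model type used for diffusion
    in_channels=2,                  # Number of audio channels
    length=2**16,                   # Length of each training chunk
    num_splits=8,                   # Noise regions per chunk (placeholder value)
    channels=[128, 256, 512, 512],  # U-Net: channels at each layer (placeholder values)
    factors=[1, 2, 2, 2],           # U-Net: downsampling/upsampling factors at each layer
    items=[2, 2, 2, 2],             # U-Net: number of repeating items at each layer
)

# Train on fixed-length waveform chunks
audio = torch.randn(1, 2, 2**16)  # [batch, in_channels, length]
loss = model(audio)
loss.backward()

# Generate a longer waveform chunk by chunk
sample = model.sample(num_items=1, num_chunks=16, num_steps=64)
# Output shape: [num_items, in_channels, num_chunks * (length // num_splits)]
```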