├── .gitignore ├── LICENSE ├── README.md ├── args.json ├── cli.py ├── dance_diffusion ├── __init__.py ├── api.py ├── base │ ├── inference.py │ ├── model.py │ └── type.py └── dd │ ├── blocks.py │ ├── ddattnunet.py │ ├── inference.py │ ├── model.py │ └── utils.py ├── diffusion_library ├── __init__.py ├── sampler.py └── scheduler.py ├── environment-mac.yml ├── environment.yml ├── scripts └── trim_model.py ├── setup.py └── util ├── __init__.py ├── platform.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | audio/ 3 | models/ 4 | 5 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 6 | 7 | inputs 8 | audio_in 9 | outputs_from_discord_bot 10 | 11 | # dependencies 12 | node_modules 13 | .pnp 14 | .pnp.js 15 | 16 | # testing 17 | coverage 18 | 19 | # database 20 | prisma/db.sqlite 21 | prisma/db.sqlite-journal 22 | 23 | # next.js 24 | .next/ 25 | out/ 26 | 27 | # expo 28 | .expo/ 29 | dist/ 30 | 31 | # production 32 | build 33 | 34 | # misc 35 | .DS_Store 36 | *.pem 37 | 38 | # debug 39 | npm-debug.log* 40 | yarn-debug.log* 41 | yarn-error.log* 42 | .pnpm-debug.log* 43 | 44 | # local env files 45 | .env 46 | .env*.local 47 | 48 | # vercel 49 | .vercel 50 | 51 | # typescript 52 | *.tsbuildinfo 53 | 54 | # turbo 55 | .turbo 56 | 57 | # conda 58 | condaenv.*.requirements.txt 59 | models 60 | audio_out 61 | sample_diffusion.egg-info 62 | dance_diffusion.egg-info 63 | __pycache__ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022-2023 sample-diffusion contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sample-diffusion 2 | 3 | A Python library and CLI for generating audio samples using Harmonai [Dance Diffusion](https://github.com/Harmonai-org/sample-generator) models. 4 | 5 | 🚧 This project is early in development. Expect breaking changes! 🚧 6 | 7 | ## Features 8 | 9 | - A CLI for generating audio samples from the command line using Dance Diffusion models. (`cli.py`) 10 | - A script for reducing the file size of Dance Diffusion models by removing data that is only needed for training and not inference. 
(`scripts/trim_model.py`) 11 | 12 | ## Installation 13 | 14 | ### Requirements 15 | 16 | - [git](https://git-scm.com/downloads) (to clone the repo) 17 | - [conda](https://docs.conda.io/en/latest/) (to set up the python environment) 18 | 19 | `conda` can be installed through [Anaconda](https://www.anaconda.com) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html). To run on an Apple Silicon device, you will need to use a conda installation that includes Apple Silicon support, such as [Miniforge](https://github.com/conda-forge/miniforge). 20 | 21 | ### Cloning the repo 22 | 23 | Clone the repo and `cd` into it: 24 | 25 | ```sh 26 | git clone https://github.com/sudosilico/sample-diffusion 27 | cd sample-diffusion 28 | ``` 29 | 30 | ### Setting up the conda environment 31 | 32 | Create the `conda` environment: 33 | 34 | ```sh 35 | # If you're not running on an Apple Silicon machine: 36 | conda env create -f environment.yml 37 | 38 | # For Apple Silicon machines: 39 | conda env create -f environment-mac.yml 40 | ``` 41 | 42 | This may take a few minutes as it will install all the necessary Python dependencies so that they will be available to the CLI script. 43 | 44 | > Note: You must activate the `dd` conda environment after creating it. You can do this by running `conda activate dd` in the terminal. You will need to do this every time you open a new terminal window. Learn more about [conda environments.](https://docs.conda.io/projects/conda/en/latest/user-guide/concepts/environments.html) 45 | 46 | ```sh 47 | conda activate dd 48 | ``` 49 | 50 | ## Using the `cli.py` CLI 51 | 52 | ### Generating samples 53 | 54 | Make a `models` folder and place your model in `models/DD/model.ckpt`, then run the generator: 55 | 56 | ```sh 57 | python cli.py 58 | ``` 59 | 60 | Alternatively, you can pass a custom model path as an argument instead of using the `models/DD/model.ckpt` default path: 61 | 62 | ```sh 63 | python cli.py --model models/DD/some-other-model.ckpt 64 | ``` 65 | 66 | Your audio samples will then be in one of the following folders: 67 | 68 | - `audio/Output/DD/{mode}/{seed}_{steps}` 69 | 70 | ### `cli.py` Command Line Arguments 71 | 72 | | argument | type | default | desc | 73 | |---------------------------|------------------|------------------------|----------------------------------------------------------------------------------------| 74 | | `--argsfile` | str | None | Path to JSON file containing cli args. If used, other passed cli args are ignored. | 75 | | `--use_autocast` | bool | True | Use autocast. | 76 | | `--crop_offset` | int | 0 | The starting sample offset to crop input audio to. Use -1 for random cropping. | 77 | | `--device_accelerator` | str | None | Device of execution. | 78 | | `--device_offload` | str | `cpu` | Device to store models when not in use. | 79 | | `--model` | str | `models/dd/model.ckpt` | Path to the model checkpoint file to be used (default: models/dd/model.ckpt). | 80 | | `--sample_rate` | int | 48000 | The samplerate the model was trained on. | 81 | | `--chunk_size` | int | 65536 | The native chunk size of the model. | 82 | | `--mode` | RequestType | `Generation` | The mode of operation (Generation, Variation, Interpolation, Inpainting or Extension). | 83 | | `--seed` | int | -1 (Random) | The seed used for reproducable outputs. Leave empty for random seed. | 84 | | `--batch_size` | int | 1 | The maximal number of samples to be produced per batch. | 85 | | `--audio_source` | str | None | Path to the audio source. 
| 86 | `--audio_target` | str | None | Path to the audio target (used for interpolations). | 87 | `--mask` | str | None | Path to the mask tensor (used for inpainting). | 88 | `--noise_level` | float | 0.7 | The noise level used for variations & interpolations. | 89 | `--interpolations_linear` | int | 1 | The number of evenly spaced interpolations. | 90 | `--interpolations` | float or float[] | None | The interpolation positions. | 91 | `--keep_start` | bool | True | Keep the beginning of the provided audio (only applies to Extension mode). | 92 | `--tame` | bool | True | Decrease output by 3 dB, then clip. | 93 | `--steps` | int | 50 | The number of steps for the sampler. | 94 | `--sampler` | SamplerType | `V_IPLMS` | The sampler used for the diffusion model. | 95 | `--sampler_args` | JSON string | `{}` | Additional arguments for the DD sampler. | 96 | `--schedule` | SchedulerType | `V_CRASH` | The schedule used for the diffusion model. | 97 | `--schedule_args` | JSON string | `{}` | Additional arguments for the DD schedule. | 98 | `--inpainting_args` | JSON string | `{}` | Additional arguments for inpainting (currently unsupported). | 99 | 100 | ### Using args.json 101 | Instead of specifying all the necessary arguments each time, we encourage you to use the `args.json` file provided with this library: 102 | ```sh 103 | python cli.py --argsfile 'args.json' 104 | ``` 105 | To change any settings, edit the `args.json` file. 106 | 107 | ## Using the model trimming script 108 | 109 | `scripts/trim_model.py` can be used to reduce the file size of Dance Diffusion models by removing data that is only needed for training and not inference. For our first models, this reduced the model size by about 75% (from 3.46 GB to 0.87 GB). 110 | 111 | To use it, pass the path to the model you want to trim as an argument: 112 | 113 | ```sh 114 | python scripts/trim_model.py models/model.ckpt 115 | ``` 116 | 117 | This will create a new model file at `models/model_trim.ckpt`.
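A trimmed model can be used for inference like any other checkpoint. As a further illustration of the CLI arguments in the table above, a variation run might look like the following sketch (the input path `audio_in/input.wav` is a placeholder for your own audio file, and the flag values are only examples):

```sh
# Sketch: render 4 variations of an existing clip with a trimmed model.
python cli.py \
    --mode Variation \
    --model models/model_trim.ckpt \
    --audio_source audio_in/input.wav \
    --noise_level 0.7 \
    --batch_size 4 \
    --steps 50
```

As with generation, the results are written under `audio/Output/DD/Variation/`.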
118 | -------------------------------------------------------------------------------- /args.json: -------------------------------------------------------------------------------- 1 | { 2 | "optimize_memory_use": true, 3 | 4 | "device_accelerator": null, 5 | "device_offload": "cpu", 6 | 7 | "use_autocast": true, 8 | "use_autocrop": true, 9 | "crop_offset": 0, 10 | "tame": true, 11 | 12 | "model_type": "DD", 13 | "model": "models/DD/model.ckpt", 14 | 15 | "sample_rate": 48000, 16 | "chunk_size": 65536, 17 | 18 | "mode": "Generation", 19 | 20 | "seed": -1, 21 | "batch_size": 4, 22 | 23 | "audio_source": null, 24 | "audio_target": null, 25 | 26 | "noise_level": 1.0, 27 | 28 | "interpolations_linear": 4, 29 | 30 | "keep_start": true, 31 | 32 | "steps": 25, 33 | "sampler": "K_DPMPP2M", 34 | "sampler_args": { 35 | }, 36 | "schedule": "K_POLYEXPONENTIAL", 37 | "schedule_args": { 38 | "sigma_min": 0.15, 39 | "sigma_max": 50.0, 40 | "rho": 1.0 41 | }, 42 | 43 | "inpainting_args": { 44 | "method": "posterior_guidance", 45 | "posterior_guidance_scale": 12 46 | } 47 | } -------------------------------------------------------------------------------- /cli.py: -------------------------------------------------------------------------------- 1 | import json, torch, argparse, os, logging 2 | 3 | from util.util import load_audio, save_audio, crop_audio 4 | from util.platform import get_torch_device_type 5 | from dance_diffusion.api import RequestHandler, Request, Response, RequestType, ModelType 6 | from diffusion_library.sampler import SamplerType 7 | from diffusion_library.scheduler import SchedulerType 8 | from transformers import logging as transformers_logging 9 | 10 | def main(): 11 | args = parse_cli_args() 12 | 13 | device_type_accelerator = args.get('device_accelerator') if(args.get('device_accelerator') != None) else get_torch_device_type() 14 | device_accelerator = torch.device(device_type_accelerator) 15 | device_offload = torch.device(args.get('device_offload')) 16 | 17 | crop = lambda audio: crop_audio(audio, args.get('chunk_size'), args.get('crop_offset')) if args.get('crop_offset') is not None else audio 18 | load_input = lambda source: crop(load_audio(device_accelerator, source, args.get('sample_rate'))) if source is not None else None 19 | 20 | request_handler = RequestHandler(device_accelerator, device_offload, optimize_memory_use=args.get('optimize_memory_use'), use_autocast=args.get('use_autocast')) 21 | 22 | seed = args.get('seed') if(args.get('seed') != -1) else torch.randint(0, 4294967294, [1], device=device_type_accelerator).item() 23 | print(f"Using accelerator: {device_type_accelerator}, Seed: {seed}.") 24 | 25 | request = Request( 26 | request_type=args.get('mode'), 27 | model_path=args.get('model'), 28 | model_type=args.get('model_type'), 29 | model_chunk_size=args.get('chunk_size'), 30 | model_sample_rate=args.get('sample_rate'), 31 | 32 | seed=seed, 33 | batch_size=args.get('batch_size'), 34 | 35 | audio_source=load_input(args.get("audio_source")), 36 | audio_target=load_input(args.get("audio_target")), 37 | 38 | mask=torch.load(args.get('mask')) if(args.get('mask') != None) else None, 39 | 40 | noise_level=args.get('noise_level'), 41 | interpolation_positions=args.get('interpolations') if(args.get('interpolations_linear') == None) else torch.linspace(0, 1, args.get('interpolations_linear'), device=device_accelerator), 42 | keep_start=args.get('keep_start'), 43 | 44 | steps=args.get('steps'), 45 | 46 | sampler_type=args.get('sampler'), 47 | 
sampler_args=args.get('sampler_args'), 48 | 49 | scheduler_type=args.get('schedule'), 50 | scheduler_args=args.get('schedule_args'), 51 | 52 | inpainting_args=args.get('inpainting_args') 53 | ) 54 | 55 | response = request_handler.process_request(request) 56 | save_audio((0.5 * response.result).clamp(-1,1) if(args.get('tame') == True) else response.result, f"audio/Output/{args.get('model_type')}/{args.get('mode')}/", args.get('sample_rate'), f"{seed}") 57 | 58 | def str2bool(value): 59 | if value.lower() in ('yes', 'true', 't', 'y', '1'): 60 | return True 61 | elif value.lower() in ('no', 'false', 'f', 'n', '0'): 62 | return False 63 | else: 64 | raise argparse.ArgumentTypeError('Boolean value expected.') 65 | 66 | def parse_cli_args(): 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument( 69 | "--argsfile", 70 | type=str, 71 | default=None, 72 | help="When used, uses args from a provided .json file instead of using the passed cli args." 73 | ) 74 | parser.add_argument( 75 | "--crop_offset", 76 | type=int, 77 | default=0, 78 | help="The starting sample offset to crop input audio to. Use -1 for random cropping." 79 | ) 80 | parser.add_argument( 81 | "--optimize_memory_use", 82 | type=str2bool, 83 | default=True, 84 | help="Try to minimize memory use during execution, might decrease performance." 85 | ) 86 | parser.add_argument( 87 | "--use_autocast", 88 | type=str2bool, 89 | default=True, 90 | help="Use autocast." 91 | ) 92 | parser.add_argument( 93 | "--use_autocrop", 94 | type=str2bool, 95 | default=True, 96 | help="Use autocrop(automatically crops audio provided to chunk_size)." 97 | ) 98 | parser.add_argument( 99 | "--device_accelerator", 100 | type=str, 101 | default=None, 102 | help="Device of execution." 103 | ) 104 | parser.add_argument( 105 | "--device_offload", 106 | type=str, 107 | default="cpu", 108 | help="Device to store models when not in use." 109 | ) 110 | parser.add_argument( 111 | "--model", 112 | type=str, 113 | default="models/DD/model.ckpt", 114 | help="Path to the model checkpoint file to be used (default: models/DD/model.ckpt)." 115 | ) 116 | parser.add_argument( 117 | "--model_type", 118 | type=ModelType, 119 | choices=ModelType, 120 | default=ModelType.DD, 121 | help="The model type." 122 | ) 123 | parser.add_argument( 124 | "--sample_rate", 125 | type=int, 126 | default=48000, 127 | help="The samplerate the model was trained on." 128 | ) 129 | parser.add_argument( 130 | "--chunk_size", 131 | type=int, 132 | default=65536, 133 | help="The native chunk size of the model." 134 | ) 135 | parser.add_argument( 136 | "--mode", 137 | type=RequestType, 138 | choices=RequestType, 139 | default=RequestType.Generation, 140 | help="The mode of operation (Generation, Variation, Interpolation, Inpainting, Extension or Upscaling)." 141 | ) 142 | parser.add_argument( 143 | "--seed", 144 | type=int, 145 | default=-1, 146 | help="The seed used for reproducable outputs. Leave empty for random seed." 147 | ) 148 | parser.add_argument( 149 | "--batch_size", 150 | type=int, 151 | default=1, 152 | help="The maximal number of samples to be produced per batch." 153 | ) 154 | parser.add_argument( 155 | "--audio_source", 156 | type=str, 157 | default=None, 158 | help="Path to the audio source." 159 | ) 160 | parser.add_argument( 161 | "--audio_target", 162 | type=str, 163 | default=None, 164 | help="Path to the audio target (used for interpolations)." 
165 | ) 166 | parser.add_argument( 167 | "--mask", 168 | type=str, 169 | default=None, 170 | help="Path to the mask tensor (used for inpainting)." 171 | ) 172 | parser.add_argument( 173 | "--noise_level", 174 | type=float, 175 | default=0.7, 176 | help="The noise level used for variations & interpolations." 177 | ) 178 | parser.add_argument( 179 | "--interpolations_linear", 180 | type=int, 181 | default=1, 182 | help="The number of interpolations, even spacing." 183 | ) 184 | parser.add_argument( 185 | "--interpolations", 186 | nargs='+', 187 | type=float, 188 | default=None, 189 | help="The interpolation positions." 190 | ) 191 | parser.add_argument( 192 | "--keep_start", 193 | type=str2bool, 194 | default=True, 195 | help="Keep beginning of audio provided(only applies to mode Extension)." 196 | ) 197 | parser.add_argument( 198 | "--tame", 199 | type=str2bool, 200 | default=True, 201 | help="Decrease output by 3db, then clip." 202 | ) 203 | parser.add_argument( 204 | "--steps", 205 | type=int, 206 | default=50, 207 | help="The number of steps for the sampler." 208 | ) 209 | parser.add_argument( 210 | "--sampler", 211 | type=SamplerType, 212 | choices=SamplerType, 213 | default=SamplerType.V_IPLMS, 214 | help="The sampler used for the diffusion model." 215 | ) 216 | parser.add_argument( 217 | "--sampler_args", 218 | type=json.loads, 219 | default={}, 220 | help="Additional arguments of the DD sampler." 221 | ) 222 | parser.add_argument( 223 | "--schedule", 224 | type=SchedulerType, 225 | choices=SchedulerType, 226 | default=SchedulerType.V_CRASH, 227 | help="The schedule used for the diffusion model." 228 | ) 229 | parser.add_argument( 230 | "--schedule_args", 231 | type=json.loads, 232 | default={}, 233 | help="Additional arguments of the DD schedule." 234 | ) 235 | parser.add_argument( 236 | "--inpaint_args", 237 | type=json.loads, 238 | default={}, 239 | help="Arguments for inpainting." 
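        # NOTE: this flag is parsed into the key 'inpaint_args', while main() reads
        # 'inpainting_args', so values passed here are currently ignored; set
        # "inpainting_args" through an argsfile (see args.json) instead.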
240 | ) 241 | 242 | args = parser.parse_args() 243 | 244 | if args.argsfile is not None: 245 | if os.path.exists(args.argsfile): 246 | with open(args.argsfile, "r") as f: 247 | print(f"Using cli args from file: {args.argsfile}") 248 | args = json.load(f) 249 | 250 | # parse enum objects from strings & apply defaults 251 | args['sampler'] = SamplerType(args.get('sampler', SamplerType.V_IPLMS)) 252 | args['schedule'] = SchedulerType(args.get('schedule', SchedulerType.V_CRASH)) 253 | 254 | return args 255 | else: 256 | print(f"Could not locate argsfile: {args.argsfile}") 257 | 258 | return vars(args) 259 | 260 | if __name__ == '__main__': 261 | transformers_logging.set_verbosity_error() 262 | main() -------------------------------------------------------------------------------- /dance_diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sudosilico/sample-diffusion/0729aaa3e1f77401008bcb89713beaeacde76781/dance_diffusion/__init__.py -------------------------------------------------------------------------------- /dance_diffusion/api.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import enum 3 | 4 | from dataclasses import dataclass 5 | from typing import Callable 6 | 7 | from .base.type import ModelType 8 | from .base.model import ModelWrapperBase 9 | from .base.inference import InferenceBase 10 | 11 | from .dd.model import DDModelWrapper 12 | from .dd.inference import DDInference 13 | 14 | class RequestType(str, enum.Enum): 15 | Generation = 'Generation' 16 | Variation = 'Variation' 17 | Interpolation = 'Interpolation' 18 | Inpainting = 'Inpainting' 19 | Extension = 'Extension' 20 | 21 | class Request: 22 | def __init__( 23 | self, 24 | request_type: RequestType, 25 | model_path: str, 26 | model_type: ModelType, 27 | model_chunk_size: int, 28 | model_sample_rate: int, 29 | **kwargs 30 | ): 31 | self.request_type = request_type 32 | self.model_path = model_path 33 | self.model_type = model_type 34 | self.model_chunk_size = model_chunk_size 35 | self.model_sample_rate = model_sample_rate 36 | self.kwargs = kwargs 37 | 38 | 39 | class Response: 40 | def __init__( 41 | self, 42 | result: torch.Tensor 43 | ): 44 | self.result = result 45 | 46 | 47 | class RequestHandler: 48 | def __init__( 49 | self, 50 | device_accelerator: torch.device, 51 | device_offload: torch.device = None, 52 | optimize_memory_use: bool = False, 53 | use_autocast: bool = True 54 | ): 55 | self.device_accelerator = device_accelerator 56 | self.device_offload = device_offload 57 | 58 | self.model_wrapper: ModelWrapperBase = None 59 | self.inference: InferenceBase = None 60 | 61 | self.optimize_memory_use = optimize_memory_use 62 | self.use_autocast = use_autocast 63 | 64 | def process_request( 65 | self, 66 | request: Request, 67 | callback: Callable = None 68 | ) -> Response: 69 | # load the model from the request if it's not already loaded 70 | if (self.model_wrapper == None): 71 | self.load_model( 72 | request.model_type, 73 | request.model_path, 74 | request.model_chunk_size, 75 | request.model_sample_rate 76 | ) 77 | elif (request.model_path != self.model_wrapper.path): 78 | del self.model_wrapper, self.inference 79 | self.load_model( 80 | request.model_type, 81 | request.model_path, 82 | request.model_chunk_size, 83 | request.model_sample_rate 84 | ) 85 | 86 | handlers_by_request_type = { 87 | RequestType.Generation: self.handle_generation, 88 | RequestType.Variation: 
self.handle_variation, 89 | RequestType.Interpolation: self.handle_interpolation, 90 | RequestType.Inpainting: self.handle_inpainting, 91 | RequestType.Extension: self.handle_extension, 92 | } 93 | 94 | Handler = handlers_by_request_type.get(request.request_type) 95 | 96 | if Handler: 97 | tensor_result = Handler(request, callback) 98 | else: 99 | raise ValueError('Unexpected RequestType in process_request') 100 | 101 | return Response(tensor_result) 102 | 103 | def load_model(self, model_type, model_path, chunk_size, sample_rate): 104 | wrappers_by_model_type = { 105 | ModelType.DD: [DDModelWrapper, DDInference] 106 | } 107 | 108 | [Wrapper, Inference] = wrappers_by_model_type.get(model_type, [None, None]) 109 | 110 | if Wrapper: 111 | self.model_wrapper = Wrapper() 112 | self.model_wrapper.load( 113 | model_path, 114 | self.device_accelerator, 115 | self.optimize_memory_use, 116 | chunk_size, 117 | sample_rate 118 | ) 119 | self.inference = Inference( 120 | self.device_accelerator, 121 | self.device_offload, 122 | self.optimize_memory_use, 123 | self.use_autocast, 124 | self.model_wrapper 125 | ) 126 | else: 127 | raise ValueError("Unexpected ModelType in load_model") 128 | 129 | def handle_generation(self, request: Request, callback: Callable) -> Response: 130 | kwargs = request.kwargs.copy() 131 | 132 | if request.model_type in [ModelType.DD]: 133 | return self.inference.generate( 134 | callback=callback, 135 | scheduler=kwargs['scheduler_type'], 136 | sampler=kwargs['sampler_type'], 137 | **kwargs 138 | ) 139 | else: 140 | raise ValueError("Unexpected ModelType in handle_generation") 141 | 142 | def handle_variation(self, request: Request, callback: Callable) -> torch.Tensor: 143 | kwargs = request.kwargs.copy() 144 | kwargs.update( 145 | expansion_map = [kwargs['batch_size']], 146 | audio_source = kwargs['audio_source'][None,:,:] 147 | ) 148 | 149 | if request.model_type in [ModelType.DD]: 150 | return self.inference.generate_variation( 151 | callback=callback, 152 | scheduler=kwargs['scheduler_type'], 153 | sampler=kwargs['sampler_type'], 154 | **kwargs 155 | ) 156 | else: 157 | raise ValueError("Unexpected ModelType in handle_variation") 158 | 159 | def handle_interpolation(self, request: Request, callback: Callable) -> torch.Tensor: 160 | kwargs = request.kwargs.copy() 161 | kwargs.update( 162 | batch_size = len(kwargs['interpolation_positions']), 163 | audio_source = kwargs['audio_source'][None,:,:], 164 | audio_target = kwargs['audio_target'][None,:,:] 165 | ) 166 | 167 | if request.model_type in [ModelType.DD]: 168 | return self.inference.generate_interpolation( 169 | callback=callback, 170 | scheduler=kwargs['scheduler_type'], 171 | sampler=kwargs['sampler_type'], 172 | **kwargs 173 | ) 174 | else: 175 | raise ValueError("Unexpected ModelType in handle_interpolation") 176 | 177 | def handle_inpainting(self, request: Request, callback: Callable) -> torch.Tensor: 178 | kwargs = request.kwargs.copy() 179 | kwargs.update( 180 | expansion_map = [kwargs['batch_size']], 181 | audio_source = kwargs['audio_source'][None,:,:] 182 | ) 183 | 184 | if request.model_type == [ModelType.DD]: 185 | return self.inference.generate_inpainting( 186 | callback=callback, 187 | scheduler=kwargs['scheduler_type'], 188 | sampler=kwargs['sampler_type'], 189 | **kwargs 190 | ) 191 | else: 192 | raise ValueError("Unexpected ModelType in handle_inpainting") 193 | 194 | def handle_extension(self, request: Request, callback: Callable) -> torch.Tensor: 195 | kwargs = request.kwargs.copy() 196 | kwargs.update( 
197 | expansion_map = [kwargs['batch_size']], 198 | audio_source = kwargs['audio_source'][None,:,:] 199 | ) 200 | 201 | if request.model_type in [ModelType.DD]: 202 | return self.inference.generate_extension( 203 | callback=callback, 204 | scheduler=kwargs['scheduler_type'], 205 | sampler=kwargs['sampler_type'], 206 | **kwargs 207 | ) 208 | else: 209 | raise ValueError("Unexpected ModelType in handle_extension") -------------------------------------------------------------------------------- /dance_diffusion/base/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import enum 3 | import contextlib 4 | from contextlib import nullcontext 5 | from typing import Tuple 6 | 7 | from .model import ModelWrapperBase 8 | 9 | class InferenceBase(): 10 | def __init__( 11 | self, 12 | device_accelerator: torch.device, 13 | device_offload: torch.device, 14 | optimize_memory_use: bool, 15 | use_autocast: bool, 16 | model: ModelWrapperBase 17 | ): 18 | self.device_accelerator = device_accelerator 19 | self.device_offload = device_offload if(optimize_memory_use==True) else None 20 | self.optimize_memory_use = optimize_memory_use 21 | self.use_autocast = use_autocast 22 | self.model = model 23 | self.generator = torch.Generator(device_accelerator)# if (device_accelerator.type != 'mps') else torch.device('cpu')) 24 | self.rng_state = None 25 | 26 | def set_device_accelerator( 27 | self, 28 | device: torch.device = None 29 | ): 30 | self.device_accelerator = device 31 | 32 | def get_device_accelerator( 33 | self 34 | ) -> torch.device: 35 | return self.device_accelerator 36 | 37 | def set_model( 38 | self, 39 | model: ModelWrapperBase = None 40 | ): 41 | self.model = model 42 | 43 | def get_model( 44 | self 45 | ) -> ModelWrapperBase: 46 | return self.model 47 | 48 | def expand( 49 | self, 50 | tensor: torch.Tensor, 51 | expansion_map: list[int] 52 | ) -> torch.Tensor: 53 | out = torch.empty([0], device=self.device_accelerator) 54 | 55 | for i in range(tensor.shape[0]): 56 | out = torch.cat([out, tensor[i,:,:].expand(expansion_map[i], -1, -1)], 0) 57 | 58 | return out 59 | 60 | 61 | # def cc_randn(self, shape:tuple, seed:int, device:torch.device, dtype = None, rng_state_in:torch.Tensor = None): 62 | 63 | # initial_rng_state = self.generator.get_state() 64 | # rng_state_out = torch.empty([shape[0], shape[1]], dtype=torch.ByteTensor,device=self.generator.device) 65 | 66 | # rn = torch.empty(shape,device=device, dtype=dtype, device=device) 67 | 68 | # for sample in range(shape[0]): 69 | # for channel in range(shape[1]): 70 | # self.generator.manual_seed(seed + sample * shape[1] + channel) if(rng_state_in == None) else self.generator.set_state(rng_state_in[sample, channel]) 71 | # rn[sample, channel] = torch.randn([shape[2]], generator=self.generator, dtype=dtype, device=device) 72 | # rng_state_out[sample, channel] = self.generator.get_state() 73 | 74 | # self.rng_state = rng_state_out 75 | # self.generator.set_state(initial_rng_state) 76 | # return rn 77 | 78 | # def cc_randn_like(self, input:torch.Tensor, seed:int, rng_state_in:torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: 79 | 80 | # initial_rng_state = self.generator.get_state() 81 | # rng_state_out = torch.empty([input.shape[0], input.shape[1]], dtype=torch.ByteTensor,device=self.generator.device) 82 | 83 | # rn = torch.empty_like(input) 84 | 85 | # for sample in range(input.shape[0]): 86 | # for channel in range(input.shape[1]): 87 | # self.generator.manual_seed(seed + sample * 
input.shape[1] + channel) if(rng_state_in == None) else self.generator.set_state(rng_state_in[sample, channel]) 88 | # rn[sample, channel] = torch.randn([input.shape[2]], generator=self.generator, dtype=input.dtype, device=input.device) 89 | # rng_state_out[sample, channel] = self.generator.get_state() 90 | 91 | # self.rng_state = rng_state_out 92 | # self.generator.set_state(initial_rng_state) 93 | # return rn 94 | 95 | 96 | def autocast_context(self): 97 | if self.device_accelerator.type == 'cuda': 98 | return torch.cuda.amp.autocast() 99 | elif self.device_accelerator.type == 'cpu': 100 | return torch.cpu.amp.autocast() 101 | elif self.device_accelerator.type == 'mps': 102 | return nullcontext() 103 | else: 104 | return torch.autocast(self.device_accelerator.type, dtype=torch.float32) 105 | 106 | @contextlib.contextmanager 107 | def offload_context(self, model): 108 | """ 109 | Used by inference implementations, this context manager moves the 110 | passed model to the inference's `device_accelerator` device on enter, 111 | and then returns it to the `device_offload` device on exit. 112 | 113 | It also wraps the `inference.autocast_context()` context. 114 | """ 115 | 116 | autocast = self.autocast_context() if self.use_autocast else nullcontext() 117 | 118 | with autocast: 119 | if self.optimize_memory_use: 120 | model.to(self.device_accelerator) 121 | 122 | yield None 123 | 124 | if self.optimize_memory_use: 125 | model.to(self.device_offload) -------------------------------------------------------------------------------- /dance_diffusion/base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ModelWrapperBase(): 4 | 5 | def __init__(self): 6 | #self.uuid: str = None 7 | #self.name: str = None 8 | self.path: str = None 9 | 10 | self.device_accelerator: torch.device = None 11 | 12 | self.chunk_size: int = None 13 | self.sample_rate: int = None 14 | 15 | 16 | def load( 17 | self, 18 | path: str, 19 | device_accelerator: torch.device, 20 | optimize_memory_use:bool=False, 21 | chunk_size: int=131072, 22 | sample_rate: int=48000 23 | ): 24 | raise NotImplementedError -------------------------------------------------------------------------------- /dance_diffusion/base/type.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | class ModelType(str, enum.Enum): 4 | DD = 'DD' -------------------------------------------------------------------------------- /dance_diffusion/dd/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from . 
import utils 7 | 8 | class ResidualBlock(nn.Module): 9 | def __init__(self, main, skip=None): 10 | super().__init__() 11 | self.main = nn.Sequential(*main) 12 | self.skip = skip if skip else nn.Identity() 13 | 14 | def forward(self, input): 15 | return self.main(input) + self.skip(input) 16 | 17 | 18 | # Noise level (and other) conditioning 19 | 20 | class ConditionedModule(nn.Module): 21 | pass 22 | 23 | 24 | class UnconditionedModule(ConditionedModule): 25 | def __init__(self, module): 26 | self.module = module 27 | 28 | def forward(self, input, cond): 29 | return self.module(input) 30 | 31 | 32 | class ConditionedSequential(nn.Sequential, ConditionedModule): 33 | def forward(self, input, cond): 34 | for module in self: 35 | if isinstance(module, ConditionedModule): 36 | input = module(input, cond) 37 | else: 38 | input = module(input) 39 | return input 40 | 41 | 42 | class ConditionedResidualBlock(ConditionedModule): 43 | def __init__(self, *main, skip=None): 44 | super().__init__() 45 | self.main = ConditionedSequential(*main) 46 | self.skip = skip if skip else nn.Identity() 47 | 48 | def forward(self, input, cond): 49 | skip = self.skip(input, cond) if isinstance(self.skip, ConditionedModule) else self.skip(input) 50 | return self.main(input, cond) + skip 51 | 52 | 53 | class AdaGN(ConditionedModule): 54 | def __init__(self, feats_in, c_out, num_groups, eps=1e-5, cond_key='cond'): 55 | super().__init__() 56 | self.num_groups = num_groups 57 | self.eps = eps 58 | self.cond_key = cond_key 59 | self.mapper = nn.Linear(feats_in, c_out * 2) 60 | 61 | def forward(self, input, cond): 62 | weight, bias = self.mapper(cond[self.cond_key]).chunk(2, dim=-1) 63 | input = F.group_norm(input, self.num_groups, eps=self.eps) 64 | return torch.addcmul(utils.append_dims(bias, input.ndim), input, utils.append_dims(weight, input.ndim) + 1) 65 | 66 | class ResConvBlock(ResidualBlock): 67 | def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5): 68 | skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) 69 | super().__init__([ 70 | nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2), 71 | nn.GroupNorm(1, c_mid), 72 | nn.GELU(), 73 | nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2), 74 | nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), 75 | nn.GELU() if not is_last else nn.Identity(), 76 | ], skip) 77 | 78 | class OutConvBlock(nn.Sequential): 79 | def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5): 80 | super().__init__( 81 | nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2), 82 | nn.GroupNorm(1, c_mid), 83 | nn.GELU(), 84 | nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2), 85 | nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), 86 | nn.GELU() if not is_last else nn.Identity(), 87 | ) 88 | 89 | class ResModConvBlock(ConditionedResidualBlock): 90 | def __init__(self, feats_in, c_in, c_mid, c_out, group_size=32, dropout_rate=0.): 91 | skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) 92 | super().__init__( 93 | AdaGN(feats_in, c_in, max(1, c_in // group_size)), 94 | nn.GELU(), 95 | nn.Conv1d(c_in, c_mid, 3, padding=1), 96 | nn.Dropout(dropout_rate, inplace=True), 97 | AdaGN(feats_in, c_mid, max(1, c_mid // group_size)), 98 | nn.GELU(), 99 | nn.Conv1d(c_mid, c_out, 3, padding=1), 100 | nn.Dropout(dropout_rate, inplace=True), 101 | skip=skip) 102 | 103 | class SelfAttention1d(nn.Module): 104 | def __init__(self, c_in, n_head=1, dropout_rate=0.): 105 | super().__init__() 106 | assert 
c_in % n_head == 0 107 | self.norm = nn.GroupNorm(1, c_in) 108 | self.n_head = n_head 109 | self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) 110 | self.out_proj = nn.Conv1d(c_in, c_in, 1) 111 | self.dropout = nn.Dropout(dropout_rate, inplace=True) 112 | 113 | def forward(self, input): 114 | n, c, s = input.shape 115 | qkv = self.qkv_proj(self.norm(input)) 116 | qkv = qkv.view( 117 | [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) 118 | q, k, v = qkv.chunk(3, dim=1) 119 | scale = k.shape[3]**-0.25 120 | att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) 121 | y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) 122 | return input + self.dropout(self.out_proj(y)) 123 | 124 | class SelfAttentionMod1d(ConditionedModule): 125 | def __init__(self, c_in, n_head, norm, dropout_rate=0.): 126 | super().__init__() 127 | assert c_in % n_head == 0 128 | self.norm_in = norm(c_in) 129 | self.n_head = n_head 130 | self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) 131 | self.out_proj = nn.Conv1d(c_in, c_in, 1) 132 | self.dropout = nn.Dropout(dropout_rate) 133 | 134 | def forward(self, input, cond): 135 | n, c, s = input.shape 136 | qkv = self.qkv_proj(self.norm_in(input, cond)) 137 | qkv = qkv.view( 138 | [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) 139 | q, k, v = qkv.chunk(3, dim=1) 140 | scale = k.shape[3]**-0.25 141 | att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) 142 | att = self.dropout(att) 143 | y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) 144 | return input + self.out_proj(y) 145 | 146 | class SkipBlock(nn.Module): 147 | def __init__(self, *main): 148 | super().__init__() 149 | self.main = nn.Sequential(*main) 150 | 151 | def forward(self, input): 152 | return torch.cat([self.main(input), input], dim=1) 153 | 154 | class FourierFeatures(nn.Module): 155 | def __init__(self, in_features, out_features, std=1.): 156 | super().__init__() 157 | assert out_features % 2 == 0 158 | self.weight = nn.Parameter(torch.randn( 159 | [out_features // 2, in_features]) * std) 160 | 161 | def forward(self, input): 162 | f = 2 * math.pi * input @ self.weight.T 163 | return torch.cat([f.cos(), f.sin()], dim=-1) 164 | 165 | 166 | def expand_to_planes(input, shape): 167 | return input[..., None].repeat([1, 1, shape[2]]) 168 | 169 | _kernels = { 170 | 'linear': 171 | [1 / 8, 3 / 8, 3 / 8, 1 / 8], 172 | 'cubic': 173 | [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 174 | 0.43359375, 0.11328125, -0.03515625, -0.01171875], 175 | 'lanczos3': 176 | [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, 177 | -0.066637322306633, 0.13550527393817902, 0.44638532400131226, 178 | 0.44638532400131226, 0.13550527393817902, -0.066637322306633, 179 | -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] 180 | } 181 | 182 | 183 | class Downsample1d(nn.Module): 184 | def __init__(self, kernel='linear', pad_mode='reflect'): 185 | super().__init__() 186 | self.pad_mode = pad_mode 187 | kernel_1d = torch.tensor(_kernels[kernel]) 188 | self.pad = kernel_1d.shape[0] // 2 - 1 189 | self.register_buffer('kernel', kernel_1d) 190 | 191 | def forward(self, x): 192 | x = F.pad(x, (self.pad,) * 2, self.pad_mode) 193 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) 194 | indices = torch.arange(x.shape[1], device=x.device) 195 | weight[indices, indices] = self.kernel.to(weight) 196 | return F.conv1d(x, weight, stride=2) 197 | 198 | 199 | class Upsample1d(nn.Module): 200 | def __init__(self, kernel='linear', pad_mode='reflect'): 201 | 
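        # Fixed-kernel 2x upsampling: the FIR taps are scaled by 2 so the signal amplitude
        # stays roughly constant after the stride-2 transposed convolution in forward().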
super().__init__() 202 | self.pad_mode = pad_mode 203 | kernel_1d = torch.tensor(_kernels[kernel]) * 2 204 | self.pad = kernel_1d.shape[0] // 2 - 1 205 | self.register_buffer('kernel', kernel_1d) 206 | 207 | def forward(self, x): 208 | x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) 209 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) 210 | indices = torch.arange(x.shape[1], device=x.device) 211 | weight[indices, indices] = self.kernel.to(weight) 212 | return F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) 213 | 214 | def Downsample1d_2( 215 | in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 216 | ) -> nn.Module: 217 | assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" 218 | 219 | return nn.Conv1d( 220 | in_channels=in_channels, 221 | out_channels=out_channels, 222 | kernel_size=factor * kernel_multiplier + 1, 223 | stride=factor, 224 | padding=factor * (kernel_multiplier // 2), 225 | ) 226 | 227 | 228 | def Upsample1d_2( 229 | in_channels: int, out_channels: int, factor: int, use_nearest: bool = False 230 | ) -> nn.Module: 231 | 232 | if factor == 1: 233 | return nn.Conv1d( 234 | in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 235 | ) 236 | 237 | if use_nearest: 238 | return nn.Sequential( 239 | nn.Upsample(scale_factor=factor, mode="nearest"), 240 | nn.Conv1d( 241 | in_channels=in_channels, 242 | out_channels=out_channels, 243 | kernel_size=3, 244 | padding=1, 245 | ), 246 | ) 247 | else: 248 | return nn.ConvTranspose1d( 249 | in_channels=in_channels, 250 | out_channels=out_channels, 251 | kernel_size=factor * 2, 252 | stride=factor, 253 | padding=factor // 2 + factor % 2, 254 | output_padding=factor % 2, 255 | ) 256 | 257 | 258 | # U-Nets 259 | 260 | class DBlock(ConditionedSequential): 261 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., downsample_ratio=1, self_attn=False): 262 | modules = [Downsample1d_2(c_in, c_in, downsample_ratio)] if downsample_ratio > 1 else [] 263 | for i in range(n_layers): 264 | my_c_in = c_in if i == 0 else c_mid 265 | my_c_out = c_mid if i < n_layers - 1 else c_out 266 | modules.append(ResModConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 267 | if self_attn: 268 | norm = lambda c_in: AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 269 | modules.append(SelfAttentionMod1d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 270 | super().__init__(*modules) 271 | 272 | 273 | class UBlock(ConditionedSequential): 274 | def __init__(self, n_layers, feats_in, c_in, c_mid, c_out, group_size=32, head_size=64, dropout_rate=0., upsample_ratio=1, self_attn=False): 275 | modules = [] 276 | for i in range(n_layers): 277 | my_c_in = c_in if i == 0 else c_mid 278 | my_c_out = c_mid if i < n_layers - 1 else c_out 279 | modules.append(ResModConvBlock(feats_in, my_c_in, c_mid, my_c_out, group_size, dropout_rate)) 280 | if self_attn: 281 | norm = lambda c_in: AdaGN(feats_in, c_in, max(1, my_c_out // group_size)) 282 | modules.append(SelfAttentionMod1d(my_c_out, max(1, my_c_out // head_size), norm, dropout_rate)) 283 | if upsample_ratio > 1: 284 | modules.append(Upsample1d_2(c_out, c_out, upsample_ratio)) 285 | super().__init__(*modules) 286 | 287 | def forward(self, input, cond, skip=None): 288 | if skip is not None: 289 | input = torch.cat([input, skip], dim=1) 290 | return super().forward(input, cond) 291 | 292 | class MappingNet(nn.Sequential): 293 | def 
__init__(self, feats_in, feats_out, n_layers=2): 294 | layers = [] 295 | for i in range(n_layers): 296 | layers.append(nn.Linear(feats_in if i == 0 else feats_out, feats_out)) 297 | layers.append(nn.GELU()) 298 | super().__init__(*layers) 299 | for layer in self: 300 | if isinstance(layer, nn.Linear): 301 | nn.init.orthogonal_(layer.weight) 302 | 303 | class UNet(ConditionedModule): 304 | def __init__(self, d_blocks, u_blocks): 305 | super().__init__() 306 | self.d_blocks = nn.ModuleList(d_blocks) 307 | self.u_blocks = nn.ModuleList(u_blocks) 308 | 309 | def forward(self, x, cond): 310 | skips = [] 311 | for block in self.d_blocks: 312 | x = block(x, cond) 313 | skips.append(x) 314 | for i, (block, skip) in enumerate(zip(self.u_blocks, reversed(skips))): 315 | x = block(x, cond, skip if i > 0 else None) 316 | return x 317 | -------------------------------------------------------------------------------- /dance_diffusion/dd/ddattnunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .blocks import SkipBlock, FourierFeatures, SelfAttention1d, ResConvBlock, Downsample1d, Upsample1d 6 | 7 | 8 | class DiffusionAttnUnet1D(nn.Module): 9 | def __init__( 10 | self, 11 | global_args, 12 | io_channels = 2, 13 | depth=14, 14 | n_attn_layers = 6, 15 | c_mults = [128, 128, 256, 256] + [512] * 10 16 | ): 17 | super().__init__() 18 | 19 | self.timestep_embed = FourierFeatures(1, 16) 20 | 21 | attn_layer = depth - n_attn_layers - 1 22 | 23 | block = nn.Identity() 24 | 25 | conv_block = ResConvBlock 26 | 27 | for i in range(depth, 0, -1): 28 | c = c_mults[i - 1] 29 | if i > 1: 30 | c_prev = c_mults[i - 2] 31 | add_attn = i >= attn_layer and n_attn_layers > 0 32 | block = SkipBlock( 33 | Downsample1d("cubic"), 34 | conv_block(c_prev, c, c), 35 | SelfAttention1d( 36 | c, c // 32) if add_attn else nn.Identity(), 37 | conv_block(c, c, c), 38 | SelfAttention1d( 39 | c, c // 32) if add_attn else nn.Identity(), 40 | conv_block(c, c, c), 41 | SelfAttention1d( 42 | c, c // 32) if add_attn else nn.Identity(), 43 | block, 44 | conv_block(c * 2 if i != depth else c, c, c), 45 | SelfAttention1d( 46 | c, c // 32) if add_attn else nn.Identity(), 47 | conv_block(c, c, c), 48 | SelfAttention1d( 49 | c, c // 32) if add_attn else nn.Identity(), 50 | conv_block(c, c, c_prev), 51 | SelfAttention1d(c_prev, c_prev // 52 | 32) if add_attn else nn.Identity(), 53 | Upsample1d(kernel="cubic") 54 | # nn.Upsample(scale_factor=2, mode='linear', 55 | # align_corners=False), 56 | ) 57 | else: 58 | block = nn.Sequential( 59 | conv_block(io_channels + 16 + global_args.get('latent_dim'), c, c), 60 | conv_block(c, c, c), 61 | conv_block(c, c, c), 62 | block, 63 | conv_block(c * 2, c, c), 64 | conv_block(c, c, c), 65 | conv_block(c, c, io_channels, is_last=True), 66 | ) 67 | self.net = block 68 | 69 | with torch.no_grad(): 70 | for param in self.net.parameters(): 71 | param *= 0.5 72 | 73 | def forward(self, input, t, cond=None): 74 | timestep_embed = self.timestep_embed(t[:, None])[..., None].repeat([1, 1, input.shape[2]]) 75 | 76 | inputs = [input, timestep_embed] 77 | 78 | if cond is not None: 79 | cond = F.interpolate(cond, (input.shape[2], ), mode='linear', align_corners=False) 80 | inputs.append(cond) 81 | 82 | return self.net(torch.cat(inputs, dim=1)) 83 | -------------------------------------------------------------------------------- /dance_diffusion/dd/inference.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tqdm.auto import trange 4 | 5 | from diffusion.utils import t_to_alpha_sigma 6 | from k_diffusion.external import VDenoiser 7 | 8 | from typing import Tuple, Callable 9 | from diffusion_library.scheduler import SchedulerType 10 | from diffusion_library.sampler import SamplerType 11 | from dance_diffusion.base.model import ModelWrapperBase 12 | from dance_diffusion.base.inference import InferenceBase 13 | 14 | from util.util import tensor_slerp_2D, PosteriorSampling 15 | 16 | class DDInference(InferenceBase): 17 | 18 | def __init__( 19 | self, 20 | device_accelerator: torch.device = None, 21 | device_offload: torch.device = None, 22 | optimize_memory_use: bool = False, 23 | use_autocast: bool = True, 24 | model: ModelWrapperBase = None 25 | ): 26 | super().__init__(device_accelerator, device_offload, optimize_memory_use, use_autocast, model) 27 | 28 | def generate( 29 | self, 30 | callback: Callable = None, 31 | batch_size: int = None, 32 | seed: int = None, 33 | steps: int = None, 34 | scheduler: SchedulerType = None, 35 | scheduler_args: dict = None, 36 | sampler: SamplerType = None, 37 | sampler_args: dict = None, 38 | **kwargs 39 | ): 40 | self.generator.manual_seed(seed) 41 | 42 | step_list = scheduler.get_step_list(steps, self.device_accelerator.type, **scheduler_args)#step_list = step_list[:-1] if sampler in [SamplerType.V_PRK, SamplerType.V_PLMS, SamplerType.V_PIE, SamplerType.V_PLMS2, SamplerType.V_IPLMS] else step_list 43 | 44 | if SamplerType.is_v_sampler(sampler): 45 | x_T = torch.randn([batch_size, 2, self.model.chunk_size], generator=self.generator, device=self.device_accelerator) 46 | model = self.model.model 47 | else: 48 | x_T = step_list[0] * torch.randn([batch_size, 2, self.model.chunk_size], generator=self.generator, device=self.device_accelerator) 49 | model = VDenoiser(self.model.model) 50 | 51 | with self.offload_context(self.model.model): 52 | return sampler.sample( 53 | model, 54 | x_T, 55 | step_list, 56 | callback, 57 | **sampler_args 58 | ).float() 59 | 60 | 61 | def generate_variation( 62 | self, 63 | callback: Callable = None, 64 | batch_size: int = None, 65 | seed: int = None, 66 | audio_source: torch.Tensor = None, 67 | expansion_map: list[int] = None, 68 | noise_level: float = None, 69 | steps: int = None, 70 | scheduler: SchedulerType = None, 71 | scheduler_args = None, 72 | sampler: SamplerType = None, 73 | sampler_args = None, 74 | **kwargs 75 | ) -> torch.Tensor: 76 | self.generator.manual_seed(seed) 77 | 78 | audio_source = self.expand(audio_source, expansion_map) 79 | 80 | if SamplerType.is_v_sampler(sampler): 81 | step_list = scheduler.get_step_list(steps, self.device_accelerator.type, **scheduler_args) 82 | step_list = step_list[step_list < noise_level] 83 | alpha_T, sigma_T = t_to_alpha_sigma(step_list[0]) 84 | x_T = alpha_T * audio_source + sigma_T * torch.randn(audio_source.shape, device=audio_source.device, generator=self.generator) 85 | model = self.model.model 86 | else: 87 | scheduler_args.update(sigma_max = scheduler_args.get('sigma_max', 1.0) * noise_level) 88 | step_list = scheduler.get_step_list(steps, self.device_accelerator.type, **scheduler_args) 89 | x_T = audio_source + step_list[0] * torch.randn(audio_source.shape, device=audio_source.device, generator=self.generator) 90 | model = VDenoiser(self.model.model) 91 | 92 | with self.offload_context(self.model.model): 93 | return sampler.sample( 94 | model, 95 | x_T, 96 | 
step_list, 97 | callback, 98 | **sampler_args 99 | ).float() 100 | 101 | 102 | def generate_interpolation( 103 | self, 104 | callback: Callable = None, 105 | batch_size: int = None, 106 | # seed: int = None, 107 | interpolation_positions: list[float] = None, 108 | audio_source: torch.Tensor = None, 109 | audio_target: torch.Tensor = None, 110 | expansion_map: list[int] = None, 111 | noise_level: float = None, 112 | steps: int = None, 113 | scheduler: SchedulerType = None, 114 | scheduler_args = None, 115 | sampler: SamplerType = None, 116 | sampler_args = None, 117 | **kwargs 118 | ) -> torch.Tensor: 119 | 120 | if SamplerType.is_v_sampler(sampler): 121 | step_list = scheduler.get_step_list(steps, self.device_accelerator.type, **scheduler_args) 122 | step_list = step_list[step_list < noise_level] 123 | step_list[-1] += 1e-7 #HACK avoid division by 0 in reverse sampling 124 | model = self.model.model 125 | else: 126 | scheduler_args.update(sigma_max = scheduler_args.get('sigma_max', 1.0) * noise_level) 127 | step_list = scheduler.get_step_list(steps, self.device_accelerator.type, **scheduler_args) 128 | step_list = step_list[:-1] #HACK avoid division by 0 in reverse sampling 129 | model = VDenoiser(self.model.model) 130 | 131 | if self.optimize_memory_use and batch_size < 2: 132 | x_0_source = audio_source 133 | x_0_target = audio_target 134 | 135 | with self.offload_context(self.model.model): 136 | x_T_source = sampler.sample( 137 | model, 138 | x_0_source, 139 | step_list.flip(0), 140 | callback, 141 | **sampler_args 142 | ) 143 | 144 | with self.offload_context(self.model.model): 145 | x_T_target = sampler.sample( 146 | model, 147 | x_0_target, 148 | step_list.flip(0), 149 | callback, 150 | **sampler_args 151 | ) 152 | 153 | x_T = torch.cat([x_T_source, x_T_target], dim=0) 154 | else: 155 | x_0 = torch.cat([audio_source, audio_target], dim=0) 156 | 157 | with self.offload_context(self.model.model): 158 | x_T = sampler.sample( 159 | model, 160 | x_0, 161 | step_list.flip(0), 162 | callback, 163 | **sampler_args 164 | ) 165 | 166 | if SamplerType.is_v_sampler(sampler): #HACK reset schedule after hack 167 | step_list[-1] = 0.0 168 | else: 169 | step_list = torch.cat([step_list, step_list.new_zeros([1])]) 170 | 171 | x_Int = torch.empty([batch_size, 2, self.model.chunk_size], device=self.device_accelerator) 172 | 173 | for pos in range(len(interpolation_positions)): 174 | x_Int[pos] = tensor_slerp_2D(x_T[0], x_T[1], interpolation_positions[pos]) 175 | 176 | with self.offload_context(self.model.model): 177 | return sampler.sample( 178 | model, 179 | x_Int, 180 | step_list, 181 | callback, 182 | **sampler_args 183 | ).float() 184 | 185 | 186 | def generate_inpainting( 187 | self, 188 | callback: Callable = None, 189 | batch_size: int = None, 190 | seed: int = None, 191 | audio_source: torch.Tensor = None, 192 | expansion_map: list[int] = None, 193 | mask: torch.Tensor = None, 194 | steps: int = None, 195 | scheduler: SchedulerType = None, 196 | scheduler_args = None, 197 | sampler: SamplerType = None, 198 | sampler_args = None, 199 | inpainting_args = None, 200 | **kwargs 201 | ) -> torch.Tensor: 202 | 203 | self.generator.manual_seed(seed) 204 | 205 | method = inpainting_args.get('method') 206 | 207 | if(method == 'repaint'): 208 | raise Exception("Repaint currently not supported due to changed requirements") 209 | 210 | elif(method == 'posterior_guidance'): 211 | step_list = scheduler.get_step_list(steps, self.device_accelerator.type, **scheduler_args) 212 | 213 | if 
SamplerType.is_v_sampler(sampler): 214 | raise Exception('V-Sampler currently not supported for posterior guidance. Please choose a K-Sampler.') 215 | else: 216 | x_T = audio_source + step_list[0] * torch.randn([batch_size, 2, self.model.chunk_size], generator=self.generator, device=self.device_accelerator) 217 | model = PosteriorSampling( 218 | VDenoiser(self.model.model), 219 | x_T, 220 | audio_source, 221 | mask, 222 | inpainting_args.get('posterior_guidance_scale') 223 | ) 224 | 225 | with self.offload_context(self.model.model): 226 | return sampler.sample( 227 | model, 228 | x_T, 229 | step_list, 230 | callback, 231 | **sampler_args 232 | ).float() 233 | 234 | 235 | def generate_extension( 236 | self, 237 | callback: Callable = None, 238 | batch_size: int = None, 239 | seed: int = None, 240 | audio_source: torch.Tensor = None, 241 | expansion_map: list[int] = None, 242 | steps: int = None, 243 | scheduler: SchedulerType = None, 244 | scheduler_args = None, 245 | sampler: SamplerType = None, 246 | sampler_args = None, 247 | inpainting_args = None, 248 | keep_start: bool = None, 249 | **kwargs 250 | ) -> torch.Tensor: 251 | 252 | half_chunk_size = self.model.chunk_size // 2 253 | chunk = torch.cat([audio_source[:, :, -half_chunk_size:], torch.zeros([batch_size, 2, half_chunk_size], device=self.device_accelerator)], dim=2) 254 | #chunk = audio_source 255 | 256 | mask = torch.cat( 257 | [torch.ones([batch_size, 2, half_chunk_size], dtype=torch.bool, device=self.device_accelerator), 258 | torch.zeros([batch_size, 2, half_chunk_size], dtype=torch.bool, device=self.device_accelerator)], 259 | dim=2 260 | ) 261 | 262 | output = self.generate_inpainting( 263 | callback, 264 | batch_size, 265 | seed, 266 | chunk, 267 | expansion_map, 268 | mask, 269 | steps, 270 | scheduler, 271 | scheduler_args, 272 | sampler, 273 | sampler_args, 274 | inpainting_args 275 | ) 276 | 277 | if (keep_start): 278 | return torch.cat( 279 | [audio_source, 280 | output[:, :, -half_chunk_size:]], 281 | dim=2 282 | ) 283 | else: 284 | return output[:, :, -half_chunk_size:] -------------------------------------------------------------------------------- /dance_diffusion/dd/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Callable 4 | 5 | from .ddattnunet import DiffusionAttnUnet1D 6 | from dance_diffusion.base.model import ModelWrapperBase 7 | from dance_diffusion.base.type import ModelType 8 | 9 | 10 | class DanceDiffusionInference(nn.Module): 11 | def __init__(self, n_attn_layers:int = 4, **kwargs): 12 | super().__init__() 13 | 14 | self.diffusion_ema = DiffusionAttnUnet1D(kwargs, n_attn_layers=n_attn_layers) 15 | 16 | class DDModelWrapper(ModelWrapperBase): 17 | def __init__(self): 18 | 19 | super().__init__() 20 | 21 | self.module:DanceDiffusionInference = None 22 | self.model:Callable = None 23 | 24 | def load( 25 | self, 26 | path:str, 27 | device_accelerator:torch.device, 28 | optimize_memory_use:bool=False, 29 | chunk_size:int=None, 30 | sample_rate:int=None 31 | ): 32 | 33 | default_model_config = dict( 34 | version = [0, 0, 1], 35 | model_info = dict( 36 | name = 'Dance Diffusion Model', 37 | description = 'v1.0', 38 | type = ModelType.DD, 39 | native_chunk_size = 65536, 40 | sample_rate = 48000, 41 | ), 42 | diffusion_config = dict( 43 | n_attn_layers = 4 44 | ) 45 | ) 46 | 47 | file = torch.load(path, map_location='cpu') 48 | 49 | model_config = file.get('model_config') 50 | if not model_config: 51 | 
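            # Fallback for checkpoints that were not processed by the conversion script:
            # use the default DD config below (65536-sample chunks, 48 kHz, 4 attention
            # layers), which may be inaccurate for other models.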
print(f"Model file {path} is invalid. Please run the conversion script.") 52 | print(f" - Default model config will be used, which may be inaccurate.") 53 | model_config = default_model_config 54 | 55 | model_info = model_config.get('model_info') 56 | diffusion_config = model_config.get('diffusion_config') 57 | 58 | self.path = path 59 | self.chunk_size = model_info.get('native_chunk_size')if not chunk_size else chunk_size 60 | self.sample_rate = model_info.get('sample_rate')if not sample_rate else sample_rate 61 | 62 | self.module = DanceDiffusionInference( 63 | n_attn_layers=diffusion_config.get('n_attn_layers'), 64 | sample_size=chunk_size, 65 | sample_rate=sample_rate, 66 | latent_dim=0, 67 | ) 68 | 69 | self.module.load_state_dict( 70 | file["state_dict"], 71 | strict=False 72 | ) 73 | self.module.eval().requires_grad_(False) 74 | 75 | self.model = self.module.diffusion_ema if (optimize_memory_use) else self.module.diffusion_ema.to(device_accelerator) -------------------------------------------------------------------------------- /dance_diffusion/dd/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import warnings 4 | 5 | from torch import optim 6 | from contextlib import contextmanager 7 | 8 | def append_dims(x, target_dims): 9 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 10 | dims_to_append = target_dims - x.ndim 11 | if dims_to_append < 0: 12 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 13 | return x[(...,) + (None,) * dims_to_append] 14 | 15 | 16 | def n_params(module): 17 | """Returns the number of trainable parameters in a module.""" 18 | return sum(p.numel() for p in module.parameters()) 19 | 20 | 21 | @contextmanager 22 | def train_mode(model, mode=True): 23 | """A context manager that places a model into training mode and restores 24 | the previous mode on exit.""" 25 | modes = [module.training for module in model.modules()] 26 | try: 27 | yield model.train(mode) 28 | finally: 29 | for i, module in enumerate(model.modules()): 30 | module.training = modes[i] 31 | 32 | 33 | def eval_mode(model): 34 | """A context manager that places a model into evaluation mode and restores 35 | the previous mode on exit.""" 36 | return train_mode(model, False) 37 | 38 | 39 | @torch.no_grad() 40 | def ema_update(model, averaged_model, decay): 41 | """Incorporates updated model parameters into an exponential moving averaged 42 | version of a model. It should be called after each optimizer step.""" 43 | model_params = dict(model.named_parameters()) 44 | averaged_params = dict(averaged_model.named_parameters()) 45 | assert model_params.keys() == averaged_params.keys() 46 | 47 | for name, param in model_params.items(): 48 | averaged_params[name].mul_(decay).add_(param, alpha=1 - decay) 49 | 50 | model_buffers = dict(model.named_buffers()) 51 | averaged_buffers = dict(averaged_model.named_buffers()) 52 | assert model_buffers.keys() == averaged_buffers.keys() 53 | 54 | for name, buf in model_buffers.items(): 55 | averaged_buffers[name].copy_(buf) 56 | 57 | 58 | class EMAWarmup: 59 | """Implements an EMA warmup using an inverse decay schedule. 60 | If inv_gamma=1 and power=1, implements a simple average. 
inv_gamma=1, power=2/3 are 61 | good values for models you plan to train for a million or more steps (reaches decay 62 | factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models 63 | you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 64 | 215.4k steps). 65 | Args: 66 | inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. 67 | power (float): Exponential factor of EMA warmup. Default: 1. 68 | min_value (float): The minimum EMA decay rate. Default: 0. 69 | max_value (float): The maximum EMA decay rate. Default: 1. 70 | start_at (int): The epoch to start averaging at. Default: 0. 71 | last_epoch (int): The index of last epoch. Default: 0. 72 | """ 73 | 74 | def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0, 75 | last_epoch=0): 76 | self.inv_gamma = inv_gamma 77 | self.power = power 78 | self.min_value = min_value 79 | self.max_value = max_value 80 | self.start_at = start_at 81 | self.last_epoch = last_epoch 82 | 83 | def state_dict(self): 84 | """Returns the state of the class as a :class:`dict`.""" 85 | return dict(self.__dict__.items()) 86 | 87 | def load_state_dict(self, state_dict): 88 | """Loads the class's state. 89 | Args: 90 | state_dict (dict): scaler state. Should be an object returned 91 | from a call to :meth:`state_dict`. 92 | """ 93 | self.__dict__.update(state_dict) 94 | 95 | def get_value(self): 96 | """Gets the current EMA decay rate.""" 97 | epoch = max(0, self.last_epoch - self.start_at) 98 | value = 1 - (1 + epoch / self.inv_gamma) ** -self.power 99 | return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value)) 100 | 101 | def step(self): 102 | """Updates the step count.""" 103 | self.last_epoch += 1 104 | 105 | 106 | class InverseLR(optim.lr_scheduler._LRScheduler): 107 | """Implements an inverse decay learning rate schedule with an optional exponential 108 | warmup. When last_epoch=-1, sets initial lr as lr. 109 | inv_gamma is the number of steps/epochs required for the learning rate to decay to 110 | (1 / 2)**power of its original value. 111 | Args: 112 | optimizer (Optimizer): Wrapped optimizer. 113 | inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. 114 | power (float): Exponential factor of learning rate decay. Default: 1. 115 | warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) 116 | Default: 0. 117 | final_lr (float): The final learning rate. Default: 0. 118 | last_epoch (int): The index of last epoch. Default: -1. 119 | verbose (bool): If ``True``, prints a message to stdout for 120 | each update. Default: ``False``. 121 | """ 122 | 123 | def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., 124 | last_epoch=-1, verbose=False): 125 | self.inv_gamma = inv_gamma 126 | self.power = power 127 | if not 0. 
<= warmup < 1: 128 | raise ValueError('Invalid value for warmup') 129 | self.warmup = warmup 130 | self.final_lr = final_lr 131 | super().__init__(optimizer, last_epoch, verbose) 132 | 133 | def get_lr(self): 134 | if not self._get_lr_called_within_step: 135 | warnings.warn("To get the last learning rate computed by the scheduler, " 136 | "please use `get_last_lr()`.") 137 | 138 | return self._get_closed_form_lr() 139 | 140 | def _get_closed_form_lr(self): 141 | warmup = 1 - self.warmup ** (self.last_epoch + 1) 142 | lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power 143 | return [warmup * max(self.final_lr, base_lr * lr_mult) 144 | for base_lr in self.base_lrs] 145 | 146 | 147 | def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32): 148 | """Draws samples from a lognormal distribution.""" 149 | return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp() 150 | 151 | 152 | def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32): 153 | """Draws samples from an optionally truncated log-logistic distribution.""" 154 | min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64) 155 | max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64) 156 | min_cdf = min_value.log().sub(loc).div(scale).sigmoid() 157 | max_cdf = max_value.log().sub(loc).div(scale).sigmoid() 158 | u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf 159 | return u.logit().mul(scale).add(loc).exp().to(dtype) 160 | 161 | 162 | def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32): 163 | """Draws samples from a log-uniform distribution.""" 164 | min_value = math.log(min_value) 165 | max_value = math.log(max_value) 166 | return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() 167 | -------------------------------------------------------------------------------- /diffusion_library/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sudosilico/sample-diffusion/0729aaa3e1f77401008bcb89713beaeacde76781/diffusion_library/__init__.py -------------------------------------------------------------------------------- /diffusion_library/sampler.py: -------------------------------------------------------------------------------- 1 | import enum, torch 2 | from diffusion import sampling as vsampling 3 | from k_diffusion import sampling as ksampling 4 | 5 | 6 | class SamplerType(str, enum.Enum): 7 | V_DDPM = 'V_DDPM' 8 | V_DDIM = 'V_DDIM' 9 | V_PRK = 'V_PRK' 10 | V_PIE = 'V_PIE' 11 | V_PLMS = 'V_PLMS' 12 | V_PLMS2 = 'V_PLMS2' 13 | V_IPLMS = 'V_IPLMS' 14 | 15 | K_EULER = 'K_EULER' 16 | K_EULERA = 'K_EULERA' 17 | K_HEUN = 'K_HEUN' 18 | K_DPM2 = 'K_DPM2' 19 | K_DPM2A = 'K_DPM2A' 20 | K_LMS = 'K_LMS' 21 | K_DPMF = 'K_DPMF' 22 | K_DPMA = 'K_DPMA' 23 | K_DPMPP2SA = 'K_DPMPP2SA' 24 | K_DPMPP2M = 'K_DPMPP2M' 25 | K_DPMPPSDE = 'K_DPMPPSDE' 26 | 27 | @classmethod 28 | def is_v_sampler(cls, value): 29 | return value[0] == 'V' 30 | 31 | def sample(self, model_fn, x_t, steps, callback, **sampler_args) -> torch.Tensor: 32 | if self == SamplerType.V_DDPM: 33 | if sampler_args.get('is_reverse'): 34 | return vsampling.reverse_sample( 35 | model_fn, 36 | x_t, 37 | steps, 38 | 0.0, 39 | sampler_args.get('extra_args', {}), 40 | callback 41 | ) 42 | else: 43 | return vsampling.sample( 44 | model_fn, 45 | x_t, 46 | steps, 47 |
0.0, 48 | sampler_args.get('extra_args', {}), 49 | callback 50 | ) 51 | elif self == SamplerType.V_DDIM: 52 | if sampler_args.get('is_reverse'): # HACK: Technically incorrect since DDIM implies eta > 0.0 53 | return vsampling.reverse_sample( 54 | model_fn, 55 | x_t, 56 | steps, 57 | 0.0, 58 | sampler_args.get('extra_args', {}), 59 | callback 60 | ) 61 | else: 62 | return vsampling.sample( 63 | model_fn, 64 | x_t, 65 | steps, 66 | sampler_args.get('eta', 0.1), 67 | sampler_args.get('extra_args', {}), 68 | callback 69 | ) 70 | elif self == SamplerType.V_PRK: 71 | return vsampling.prk_sample( 72 | model_fn, 73 | x_t, 74 | steps, 75 | sampler_args.get('extra_args', {}), 76 | True, 77 | callback 78 | ) 79 | elif self == SamplerType.V_PIE: 80 | return vsampling.pie_sample( 81 | model_fn, 82 | x_t, 83 | steps, 84 | sampler_args.get('extra_args', {}), 85 | True, 86 | callback 87 | ) 88 | elif self == SamplerType.V_PLMS: 89 | return vsampling.plms_sample( 90 | model_fn, 91 | x_t, 92 | steps, 93 | sampler_args.get('extra_args', {}), 94 | True, 95 | callback 96 | ) 97 | elif self == SamplerType.V_PLMS2: 98 | return vsampling.plms2_sample( 99 | model_fn, 100 | x_t, 101 | steps, 102 | sampler_args.get('extra_args', {}), 103 | True, 104 | callback 105 | ) 106 | elif self == SamplerType.V_IPLMS: 107 | return vsampling.iplms_sample( 108 | model_fn, 109 | x_t, 110 | steps, 111 | sampler_args.get('extra_args', {}), 112 | True, 113 | callback 114 | ) 115 | elif self == SamplerType.K_EULER: 116 | return ksampling.sample_euler( 117 | model_fn, 118 | x_t, 119 | steps, 120 | sampler_args.get('extra_args', {}), 121 | callback, 122 | sampler_args.get('disable', False), 123 | sampler_args.get('s_churn', 0.0), 124 | sampler_args.get('s_tmin', 0.0), 125 | sampler_args.get('s_tmax',float('inf')), 126 | sampler_args.get('s_noise', 1.0) 127 | ) 128 | elif self == SamplerType.K_EULERA: 129 | return ksampling.sample_euler_ancestral( 130 | model_fn, 131 | x_t, 132 | steps, 133 | sampler_args.get('extra_args', {}), 134 | callback, 135 | sampler_args.get('disable', False), 136 | sampler_args.get('eta', 0.1), 137 | sampler_args.get('s_noise', 1.0), 138 | sampler_args.get('noise_sampler', None) 139 | ) 140 | elif self == SamplerType.K_HEUN: 141 | return ksampling.sample_heun( 142 | model_fn, 143 | x_t, 144 | steps, 145 | sampler_args.get('extra_args', {}), 146 | callback, 147 | sampler_args.get('disable', False), 148 | sampler_args.get('s_churn', 0.0), 149 | sampler_args.get('s_tmin', 0.0), 150 | sampler_args.get('s_tmax',float('inf')), 151 | sampler_args.get('s_noise', 1.0) 152 | ) 153 | elif self == SamplerType.K_DPM2: 154 | return ksampling.sample_dpm_2( 155 | model_fn, 156 | x_t, 157 | steps, 158 | sampler_args.get('extra_args', {}), 159 | callback, 160 | sampler_args.get('disable', False), 161 | sampler_args.get('s_churn', 0.0), 162 | sampler_args.get('s_tmin', 0.0), 163 | sampler_args.get('s_tmax',float('inf')), 164 | sampler_args.get('s_noise', 1.0) 165 | ) 166 | elif self == SamplerType.K_DPM2A: 167 | return ksampling.sample_dpm_2_ancestral( 168 | model_fn, 169 | x_t, 170 | steps, 171 | sampler_args.get('extra_args', {}), 172 | callback, 173 | sampler_args.get('disable', False), 174 | sampler_args.get('eta', 0.1), 175 | sampler_args.get('s_noise', 1.0), 176 | sampler_args.get('noise_sampler', None) 177 | ) 178 | elif self == SamplerType.K_LMS: 179 | return ksampling.sample_lms( 180 | model_fn, 181 | x_t, 182 | steps, 183 | sampler_args.get('extra_args', {}), 184 | callback, 185 | sampler_args.get('disable', False), 186 
| sampler_args.get('order', 4) 187 | ) 188 | elif self == SamplerType.K_DPMF:# sample_dpm_fast 189 | return ksampling.sample_dpm_fast( 190 | model_fn, 191 | x_t, 192 | sampler_args.get('sigma_min', 0.001), 193 | sampler_args.get('sigma_max', 1.0), 194 | sampler_args.get('n', 3), 195 | sampler_args.get('extra_args', {}), 196 | callback, 197 | sampler_args.get('disable', False), 198 | sampler_args.get('eta', 0.1), 199 | sampler_args.get('s_noise', 1.0), 200 | sampler_args.get('noise_sampler', None) 201 | ) 202 | elif self == SamplerType.K_DPMA: 203 | return ksampling.sample_dpm_adaptive( 204 | model_fn, 205 | x_t, 206 | sampler_args.get('sigma_min', 0.001), 207 | sampler_args.get('sigma_max', 1.0), 208 | sampler_args.get('extra_args', {}), 209 | callback, 210 | sampler_args.get('disable', False), 211 | sampler_args.get('order', 3), 212 | sampler_args.get('rtol', 0.05), 213 | sampler_args.get('atol', 0.0078), 214 | sampler_args.get('h_init', 0.05), 215 | sampler_args.get('pcoeff', 0.0), 216 | sampler_args.get('icoeff', 1.0), 217 | sampler_args.get('dcoeff', 0.0), 218 | sampler_args.get('accept_safety', 0.81), 219 | sampler_args.get('eta', 0.1), 220 | sampler_args.get('s_noise', 1.0), 221 | sampler_args.get('noise_sampler', None), 222 | sampler_args.get('return_info', False) 223 | ) 224 | elif self == SamplerType.K_DPMPP2SA: 225 | return ksampling.sample_dpmpp_2s_ancestral( 226 | model_fn, 227 | x_t, 228 | steps, 229 | sampler_args.get('extra_args', {}), 230 | callback, 231 | sampler_args.get('disable', False), 232 | sampler_args.get('eta', 0.1), 233 | sampler_args.get('s_noise', 1.0), 234 | sampler_args.get('noise_sampler', None) 235 | ) 236 | elif self == SamplerType.K_DPMPP2M: 237 | return ksampling.sample_dpmpp_2m( 238 | model_fn, 239 | x_t, 240 | steps, 241 | sampler_args.get('extra_args', {}), 242 | callback, 243 | sampler_args.get('disable', False) 244 | ) 245 | elif self == SamplerType.K_DPMPPSDE: 246 | return ksampling.sample_dpmpp_sde( 247 | model_fn, 248 | x_t, 249 | steps, 250 | sampler_args.get('extra_args', {}), 251 | callback, 252 | sampler_args.get('disable', False), 253 | sampler_args.get('eta', 0.1), 254 | sampler_args.get('s_noise', 1.0), 255 | sampler_args.get('noise_sampler', None), 256 | sampler_args.get('r', 1/2) 257 | ) 258 | else: 259 | raise Exception(f"No sample implementation for sampler_type '{self}'") 260 | 261 | -------------------------------------------------------------------------------- /diffusion_library/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import enum 3 | import torch 4 | 5 | from diffusion import utils as vscheduling 6 | from k_diffusion import sampling as kscheduling 7 | 8 | 9 | class SchedulerType(str, enum.Enum): 10 | V_DDPM = 'V_DDPM' 11 | V_SPLICED_DDPM_COSINE = 'V_SPLICED_DDPM_COSINE' 12 | V_LOG = 'V_LOG' 13 | V_CRASH = 'V_CRASH' 14 | 15 | K_KARRAS = 'K_KARRAS' 16 | K_EXPONENTIAL = 'K_EXPONENTIAL' 17 | K_POLYEXPONENTIAL = 'K_POLYEXPONENTIAL' 18 | K_VP = 'K_VP' 19 | 20 | @classmethod 21 | def is_v_scheduler(cls, value): 22 | return value[0] == 'V' 23 | 24 | def get_step_list(self, n: int, device: str, **schedule_args): 25 | #if SchedulerType.is_v_scheduler(self): 26 | # n -= 1 27 | 28 | if self == SchedulerType.V_DDPM: 29 | return torch.nn.functional.pad(vscheduling.get_ddpm_schedule(torch.linspace(1, 0, n)), [0,1], value=0.0).to(device) 30 | elif self == SchedulerType.V_SPLICED_DDPM_COSINE: 31 | return vscheduling.get_spliced_ddpm_cosine_schedule(torch.linspace(1, 0, n + 
1)).to(device) 32 | elif self == SchedulerType.V_LOG: 33 | return torch.nn.functional.pad( 34 | vscheduling.get_log_schedule( 35 | torch.linspace(1, 0, n), 36 | schedule_args.get('min_log_snr', -10.0), 37 | schedule_args.get('max_log_snr', 10.0) 38 | ), 39 | [0,1], 40 | value=0.0 41 | ).to(device) 42 | elif self == SchedulerType.V_CRASH: 43 | sigma = torch.sin(torch.linspace(1, 0, n + 1) * math.pi / 2) ** 2 44 | alpha = (1 - sigma ** 2) ** 0.5 45 | return vscheduling.alpha_sigma_to_t(alpha, sigma).to(device) 46 | elif self == SchedulerType.K_KARRAS: 47 | return kscheduling.get_sigmas_karras( 48 | n, 49 | schedule_args.get('sigma_min', 0.001), 50 | schedule_args.get('sigma_max', 1.0), 51 | schedule_args.get('rho', 7.0), 52 | device = device 53 | ) 54 | elif self == SchedulerType.K_EXPONENTIAL: 55 | return kscheduling.get_sigmas_exponential( 56 | n, 57 | schedule_args.get('sigma_min', 0.001), 58 | schedule_args.get('sigma_max', 1.0), 59 | device = device 60 | ) 61 | elif self == SchedulerType.K_POLYEXPONENTIAL: 62 | return kscheduling.get_sigmas_polyexponential( 63 | n, 64 | schedule_args.get('sigma_min', 0.001), 65 | schedule_args.get('sigma_max', 1.0), 66 | schedule_args.get('rho', 1.0), 67 | device = device 68 | ) 69 | elif self == SchedulerType.K_VP: 70 | return kscheduling.get_sigmas_vp( 71 | n, 72 | schedule_args.get('beta_d', 1.205), 73 | schedule_args.get('beta_min', 0.09), 74 | schedule_args.get('eps_s', 0.001), 75 | device = device 76 | ) 77 | else: 78 | raise Exception(f"No get_step_list implementation for scheduler_type '{self}'") 79 | -------------------------------------------------------------------------------- /environment-mac.yml: -------------------------------------------------------------------------------- 1 | name: dd 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - git 7 | - pip=22.2.2 8 | - python=3.10.5 9 | - setuptools 10 | - pytorch 11 | - torchvision 12 | - torchaudio 13 | - pip: 14 | - -e . 15 | - v-diffusion-pytorch 16 | - k-diffusion 17 | - black 18 | - diffusers 19 | variables: 20 | PYTORCH_ENABLE_MPS_FALLBACK: 1 21 | TRANSFORMERS_OFFLINE: 1 22 | HF_DATASETS_OFFLINE: 1 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dd 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - git 8 | - python=3.10 9 | - pip=22.2.2 10 | - pytorch=1.12.1 11 | - pytorch-cuda=11.7 12 | - torchvision 13 | - torchaudio 14 | - setuptools 15 | - pip: 16 | - -e . 
17 | - v-diffusion-pytorch 18 | - k-diffusion 19 | - PySoundFile 20 | - black 21 | - diffusers 22 | variables: 23 | TRANSFORMERS_OFFLINE: 1 24 | HF_DATASETS_OFFLINE: 1 -------------------------------------------------------------------------------- /scripts/trim_model.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import time 4 | 5 | # original code by @Twobob 6 | 7 | def main(): 8 | if len(sys.argv) != 2: 9 | print("Usage: python trim_model.py <path_to_model>") 10 | sys.exit(1) 11 | 12 | model_path = sys.argv[1] 13 | 14 | if not os.path.isfile(model_path): 15 | print("No file was found at the given path.") 16 | sys.exit(1) 17 | 18 | print(f"Trimming model at '{model_path}'...\n") 19 | 20 | start_time = time.process_time() 21 | 22 | untrimmed_size = os.path.getsize(model_path) 23 | untrimmed = torch.load(model_path, map_location="cpu") 24 | 25 | trimmed = trim_model(untrimmed) 26 | 27 | output_path = model_path.replace(".ckpt", "_trim.ckpt") 28 | torch.save(trimmed, output_path) 29 | 30 | end_time = time.process_time() 31 | elapsed = end_time - start_time 32 | 33 | trimmed_size = os.path.getsize(output_path) 34 | 35 | bytes = untrimmed_size - trimmed_size 36 | megabytes = bytes / 1024.0 / 1024.0 37 | 38 | print(f"Untrimmed: {untrimmed_size} B, {untrimmed_size / 1024.0 / 1024.0} MB") 39 | print(f"Trimmed: {trimmed_size} B, {trimmed_size / 1024.0 / 1024.0} MB") 40 | 41 | print( 42 | f"\nDone! Trimmed {untrimmed_size - trimmed_size} B, or {megabytes} MB, in {elapsed} seconds." 43 | ) 44 | 45 | 46 | def trim_model(untrimmed): 47 | trimmed = dict() 48 | 49 | for k in untrimmed.keys(): 50 | if k != "optimizer_states":  # drop training-only optimizer state 51 | trimmed[k] = untrimmed[k] 52 | 53 | if "global_step" in untrimmed: 54 | print(f"Global step: {untrimmed['global_step']}.") 55 | 56 | temp = trimmed["state_dict"].copy() 57 | 58 | trimmed_model = dict() 59 | 60 | for k in temp: 61 | trimmed_model[k] = temp[k].half()  # cast weights to float16 to halve the checkpoint size 62 | 63 | trimmed["state_dict"] = trimmed_model 64 | 65 | return trimmed 66 | 67 | 68 | if __name__ == "__main__": 69 | main() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="sample-diffusion", 5 | version="0.0.1", 6 | description="", 7 | packages=find_packages(), 8 | install_requires=[ 9 | "torch", 10 | "tqdm", 11 | ] 12 | ) 13 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sudosilico/sample-diffusion/0729aaa3e1f77401008bcb89713beaeacde76781/util/__init__.py -------------------------------------------------------------------------------- /util/platform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_torch_device_type(): 5 | if is_mps_available(): 6 | return "mps" 7 | 8 | if torch.cuda.is_available(): 9 | return "cuda" 10 | 11 | return "cpu" 12 | 13 | 14 | def is_mps_available(): 15 | try: 16 | return torch.backends.mps.is_available() 17 | except: 18 | return False 19 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | from
k_diffusion.utils import append_dims 5 | 6 | 7 | def tensor_slerp_2D(a: torch.Tensor, b: torch.Tensor, t: float): 8 | slerped = torch.empty_like(a) 9 | 10 | for channel in range(a.size(0)): 11 | slerped[channel] = tensor_slerp(a[channel], b[channel], t) 12 | 13 | return slerped 14 | 15 | 16 | def tensor_slerp(a: torch.Tensor, b: torch.Tensor, t: float): 17 | omega = torch.arccos(torch.dot(a / torch.linalg.norm(a), b / torch.linalg.norm(b))) 18 | so = torch.sin(omega) 19 | return torch.sin((1.0 - t) * omega) / so * a + torch.sin(t * omega) / so * b 20 | 21 | 22 | def load_audio(device, audio_path: str, sample_rate): 23 | 24 | if not os.path.exists(audio_path): 25 | raise RuntimeError(f"Audio file not found: {audio_path}") 26 | 27 | audio, file_sample_rate = torchaudio.load(audio_path) 28 | 29 | if file_sample_rate != sample_rate: 30 | resample = torchaudio.transforms.Resample(file_sample_rate, sample_rate) 31 | audio = resample(audio) 32 | 33 | return audio.to(device) 34 | 35 | 36 | def save_audio(audio_out, output_path: str, sample_rate, id_str:str = None): 37 | 38 | if not os.path.exists(output_path): 39 | os.makedirs(output_path) 40 | 41 | for ix, sample in enumerate(audio_out): 42 | output_file = os.path.join(output_path, f"sample_{id_str}_{ix + 1}.wav" if(id_str!=None) else f"sample_{ix + 1}.wav") 43 | open(output_file, "a").close() 44 | 45 | output = sample.cpu() 46 | 47 | torchaudio.save(output_file, output, sample_rate) 48 | 49 | 50 | def crop_audio(source: torch.Tensor, chunk_size: int, crop_offset: int = 0) -> torch.Tensor: 51 | n_channels, n_samples = source.shape 52 | 53 | offset = 0 54 | if (crop_offset > 0): 55 | offset = min(crop_offset, n_samples - chunk_size) 56 | elif (crop_offset == -1): 57 | offset = torch.randint(0, max(0, n_samples - chunk_size) + 1, []).item() 58 | 59 | chunk = source.new_zeros([n_channels, chunk_size]) 60 | chunk [:, :min(n_samples, chunk_size)] = source[:, offset:offset + chunk_size] 61 | 62 | return chunk 63 | 64 | 65 | class PosteriorSampling(torch.nn.Module): 66 | def __init__(self, model, x_T, measurement, mask, scale): 67 | super().__init__() 68 | self.model = model 69 | self.x_prev = x_T 70 | self.measurement = measurement 71 | self.mask = mask 72 | self.scale = scale 73 | 74 | @torch.enable_grad() 75 | def forward(self, input, sigma, **kwargs): 76 | x_t = input.detach().requires_grad_() 77 | out = self.model(x_t, sigma, **kwargs) 78 | difference = (self.measurement - out) * self.mask 79 | norm = torch.linalg.norm(difference) 80 | norm_grad = torch.autograd.grad(outputs=norm, inputs=x_t)[0].detach() 81 | 82 | return out.detach() - self.scale * norm_grad 83 | 84 | # x_t = input.detach().requires_grad_() 85 | # x_0_hat = self.model(input, sigma, **kwargs).detach().requires_grad_() 86 | 87 | # difference = (self.measurement - x_0_hat) * self.mask 88 | # norm = torch.linalg.norm(difference) 89 | # norm_grad = torch.autograd.grad(outputs=norm, inputs=self.x_prev)[0].detach() 90 | 91 | # self.x_prev = x_t.detach().requires_grad_() 92 | 93 | # return x_t.detach() - norm_grad * self.scale 94 | 95 | # class PosteriorSampling(torch.nn.Module): 96 | # def __init__(self, model, measurement, mask, strength): 97 | # super().__init__() 98 | # self.model = model 99 | # self.mask = mask 100 | # self.measurement = measurement 101 | # self.strength = strength 102 | 103 | # @torch.enable_grad() 104 | # def forward(self, input, sigma, **kwargs): 105 | # input = input.detach().requires_grad_() 106 | # out = self.model(input, sigma, **kwargs) 107 | # difference = 
(self.measurement - out) * self.mask 108 | # norm = torch.linalg.norm(difference) 109 | # norm_grad = torch.autograd.grad(outputs=norm, inputs=input)[0].detach() 110 | # N = self.measurement.shape[-1]**0.5 111 | # step_size = -self.strength * N / (torch.linalg.norm(norm_grad) + 1e-4) * append_dims(sigma**2, input.ndim) 112 | # print('Norm:', norm.detach()) 113 | # print('Step size:', step_size) 114 | # return out.detach() + step_size * norm_grad --------------------------------------------------------------------------------
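As a quick orientation to the helpers in `util/`, here is a minimal usage sketch. It is not a file in the repository: the input path, output directory, and the 48000 Hz / 65536-sample values are assumptions taken from the default model config in `dance_diffusion/dd/model.py`; only the imported functions come from the sources above.

```python
# Hypothetical usage sketch of the util helpers; paths are placeholders.
import torch

from util.platform import get_torch_device_type
from util.util import load_audio, crop_audio, save_audio

device = torch.device(get_torch_device_type())  # "mps", "cuda", or "cpu"

# Load an input file and resample it to the model's sample rate (48 kHz by default).
audio = load_audio(device, "audio_in/source.wav", 48000)

# Crop (or zero-pad) to the model's native chunk size; crop_offset=-1 picks a random offset.
chunk = crop_audio(audio, chunk_size=65536, crop_offset=-1)

# save_audio iterates over a batch, so wrap the single chunk in a leading batch dimension.
save_audio(chunk.unsqueeze(0), "audio_out/example", 48000)
```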