├── mask.png ├── inputs.png ├── future-bach.png ├── masked-inputs.png ├── util.py ├── cog.yaml ├── README.md ├── data.py ├── output.py ├── predict.py └── train.py /mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/mask.png -------------------------------------------------------------------------------- /inputs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/inputs.png -------------------------------------------------------------------------------- /future-bach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/future-bach.png -------------------------------------------------------------------------------- /masked-inputs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreasjansson/cantable-diffuguesion/HEAD/masked-inputs.png -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | 4 | args = {} 5 | 6 | 7 | # TODO: this doesn't work 8 | @contextmanager 9 | def define_args(): 10 | global args 11 | 12 | before = {k: v for k, v in globals().items()} 13 | yield 14 | after = globals() 15 | args = {k: after[k] for k in set(after) - set(before)} 16 | print(args) 17 | 18 | 19 | def get_args(): 20 | return args 21 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | cuda: "11.7" 3 | python_version: "3.7" 4 | gpu: true 5 | python_packages: 6 | - torch==1.13.1 7 | - diffusers==0.11.1 8 | - decorator==4.3.0 9 | - flask==1.0.2 10 | - flask-cors==3.0.7 11 | - itsdangerous==1.1.0 12 | - jinja2==2.10 13 | - markupsafe==1.1.0 14 | - music21==6.7.1 15 | - tqdm==4.29.0 16 | - werkzeug==0.14.1 17 | - numpy==1.19.5 18 | - midi2audio==0.1.1 19 | - jupyterlab==3.5.2 20 | - matplotlib==3.5.3 21 | - pretty_midi==0.2.9 22 | - midiSynth==0.3 23 | - pyfluidsynth==1.3.1 24 | - wandb==0.13.7 25 | - mypy==0.991 26 | - soundfile==0.11.0 27 | system_packages: 28 | - musescore 29 | - fluidsynth --fix-missing 30 | - ffmpeg 31 | - lilypond 32 | 33 | predict: predict.py:Predictor 34 | image: r8.im/andreasjansson/cantable-diffuguesion 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cantable Diffuguesion 2 | 3 | [![Replicate](https://replicate.com/andreasjansson/cantable-diffuguesion/badge)](https://replicate.com/andreasjansson/cantable-diffuguesion) 4 | 5 | _Bach chorale generation and harmonization_ 6 | 7 | ![future bach](future-bach.png) 8 | 9 | ## Usage 10 | 11 | You can use Cantable Diffuguesion to generate Bach chorales unconditionally, or harmonize melodies or parts of melodies. 12 | 13 | For harmonization we use [tinyNotation](https://web.mit.edu/music21/doc/moduleReference/moduleTinyNotation.html), with a few modifications: 14 | * The `?` symbol followed by a duration denotes a section that the model should in-paint, e.g. `?2` will in-paint a half note duration. 
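  * For instance, `c4 d4 ?2 | e4 f4 g2` would keep the written notes and have the model fill in a half note's worth of music between `d4` and `e4`.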
15 | * The `?*` symbol will in-paint everything between a defined beginning and end, e.g. `c2 ?* B4 c2` will start the piece with `c2`, generate notes until the requested duration is filled, and end the melody with `B4 c2`.
16 | * Bar lines (`|`) are optional and ignored; they can be used to make the melody notation easier to read.
17 | 
18 | ## Training
19 | 
20 | Cantable Diffuguesion is a diffusion model trained to generate Bach chorales. Four-part chorales are presented to the network as 4-channel arrays. The pitches of the individual parts are activated in the corresponding channel of the array. Here is a plot of a single input example, with the four channels plotted as separate images:
21 | 
22 | ![inputs](inputs.png)
23 | 
24 | As in Stable Diffusion, a U-Net is trained to predict the noise residual.
25 | 
26 | After training the generative model we add 8 channels to the inputs: the middle four channels represent a mask, and the last four channels contain the masked chorale. We randomly mask the four channels individually, as opposed to [Stable Diffusion Inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting), which uses a one-channel mask.
27 | 
28 | The two plots below show a mask and a masked input array:
29 | 
30 | ![mask](mask.png)
31 | ![masked inputs](masked-inputs.png)
32 | ## Dataset
33 | 
34 | We use all four-part pieces in the [Music21 Bach Chorales corpus](https://web.mit.edu/music21/doc/moduleReference/moduleCorpusChorales.html). 80% are used for training, the rest for validation and testing.
35 | 
36 | ## Inspiration
37 | 
38 | * [Riffusion](https://github.com/riffusion/riffusion)
39 | * [DeepBach](https://arxiv.org/abs/1612.01010)
40 | * [Dreambooth Inpainting](https://github.com/huggingface/diffusers/blob/50b6513531da7e258204871a9c675a56875d9e69/examples/research_projects/dreambooth_inpaint/README.md)
41 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pathlib import Path
3 | import music21
4 | import torch
5 | from torch import nn
6 | from torch.utils.data import Dataset, DataLoader, Subset
7 | 
8 | 
9 | MIN_PITCH = 30
10 | MAX_PITCH = 94
11 | MIN_TRANSPOSE = -6
12 | MAX_TRANSPOSE = 6
13 | NUM_PARTS = 4
14 | RESOLUTION = 4 # steps per quarter-note, e.g.
4 == 16th note resolution 15 | WIDTH = 64 16 | HEIGHT = MAX_PITCH - MIN_PITCH 17 | BATCH_SIZE = 32 18 | 19 | 20 | def train_val_test_dataloaders(): 21 | ds = BachDataset() 22 | 23 | rng = random.Random(0) 24 | idx = list(range(len(ds))) 25 | rng.shuffle(idx) 26 | 27 | train_idx = idx[: int(len(ds) * 0.8)] 28 | val_idx = idx[len(train_idx) : int(len(ds) * 0.9)] 29 | test_idx = idx[len(train_idx) + len(val_idx) :] 30 | 31 | train_ds = Subset(ds, indices=train_idx) 32 | val_ds = Subset(ds, indices=test_idx) 33 | test_ds = Subset(ds, indices=val_idx) 34 | 35 | train_ds = TransformDataset(train_ds, transform=RandomCropAndTranspose()) 36 | train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0) 37 | val_dl = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0) 38 | test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0) 39 | 40 | return train_dl, val_dl, test_dl 41 | 42 | 43 | class BachDataset(Dataset): 44 | def __init__(self): 45 | self.data = load_data() 46 | 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | def __getitem__(self, idx): 51 | return self.data[idx] 52 | 53 | 54 | def load_data(): 55 | cache_path = Path("dataset.cache.pt") 56 | if cache_path.exists(): 57 | return torch.load(cache_path) 58 | 59 | data = [] 60 | corpuses = music21.corpus.search("bach") 61 | for i, piece in enumerate(corpuses): 62 | if i % 10 == 0: 63 | print(f"{i+1}/{len(corpuses)}") 64 | piece = piece.parse() 65 | if len(piece.parts) == NUM_PARTS: 66 | data.append(piece_to_array(piece)) 67 | torch.save(data, cache_path) 68 | return data 69 | 70 | 71 | def piece_to_array(piece): 72 | duration = int(piece.expandRepeats().duration.quarterLength * RESOLUTION) 73 | arr = torch.zeros([NUM_PARTS, duration, MAX_PITCH - MIN_PITCH]) 74 | 75 | for part_i, part in enumerate(piece.parts): 76 | notes = part.expandRepeats().flat.notes 77 | for note in notes: 78 | next_note = note.next("Note") 79 | next_is_same = next_note and note.pitch.midi == next_note.pitch.midi 80 | subtract = 1 if next_is_same else 0 81 | start_column = int(note.offset * RESOLUTION) 82 | end_column = int( 83 | (note.offset + note.duration.quarterLength) * RESOLUTION - subtract 84 | ) 85 | pitch_row = note.pitch.midi - MIN_PITCH 86 | arr[part_i, start_column:end_column, pitch_row] = 1 87 | 88 | return arr 89 | 90 | 91 | class RandomCropAndTranspose(nn.Module): 92 | def forward(self, arr): 93 | duration = arr.shape[1] 94 | t = torch.randint(0, duration - WIDTH, size=(1,)).item() 95 | t = 4 * (t // 16) 96 | cropped = arr[:, t : t + WIDTH] 97 | transpose = torch.randint(MIN_TRANSPOSE, MAX_TRANSPOSE, size=(1,)).item() 98 | transposed = torch.roll(cropped, transpose, dims=2) 99 | return transposed 100 | 101 | 102 | class TransformDataset(Dataset): 103 | def __init__(self, subset, transform=None): 104 | self.subset = subset 105 | self.transform = transform 106 | 107 | def __getitem__(self, index): 108 | x = self.subset[index] 109 | if self.transform: 110 | x = self.transform(x) 111 | return x 112 | 113 | def __len__(self): 114 | return len(self.subset) 115 | -------------------------------------------------------------------------------- /output.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import subprocess 4 | import tempfile 5 | from PIL import Image 6 | import torch 7 | import pretty_midi as pm 8 | import matplotlib.pyplot as plt 9 | from midiSynth.synth import MidiSynth 10 | 11 | from data import RESOLUTION, 
MIN_PITCH 12 | 13 | 14 | PALETTE = torch.tensor( 15 | [ 16 | [0, 0, 0], # black 17 | [255, 0, 0], # red 18 | [0, 255, 0], # green 19 | [0, 0, 255], # blue 20 | [255, 0, 255], # magenta 21 | [255, 255, 0], # yellow 22 | [255, 255, 255], # white 23 | ] 24 | ) 25 | 26 | 27 | def array_to_plot(arr): 28 | color_matrix = PALETTE[ 29 | ((arr.cpu() > 0.75).permute([1, 2, 0]) * torch.arange(1, 5)) 30 | .max(axis=2)[0] 31 | .T.to(int) 32 | ] 33 | resize = 0.1 34 | height, width, _ = color_matrix.shape 35 | figsize = width * resize, height * resize 36 | 37 | fig = plt.figure(figsize=figsize) 38 | ax = fig.add_axes([0, 0, 1, 1]) 39 | ax.axis("off") 40 | 41 | ax.imshow(color_matrix, origin="lower") 42 | fig.canvas.draw() 43 | print(fig.canvas.get_width_height()) 44 | 45 | img = Image.frombytes( 46 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 47 | ) 48 | plt.close(fig) 49 | return img 50 | 51 | 52 | def array_to_plots(arr): 53 | fig, axs = plt.subplots(2, 2, figsize=[8, 8]) 54 | for i in range(2): 55 | for j in range(2): 56 | ax = axs[i][j] 57 | ax.imshow(arr[i * 2 + j].T, origin="lower") 58 | fig.tight_layout() 59 | fig.canvas.draw() 60 | 61 | img = Image.frombytes( 62 | "RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb() 63 | ) 64 | plt.close(fig) 65 | return img 66 | 67 | 68 | def midi_to_wav(midi_path, wav_path): 69 | MidiSynth().midi2audio(str(midi_path), str(wav_path)) 70 | return str(wav_path) 71 | 72 | 73 | def midi_to_mp3(midi_path, mp3_path): 74 | with tempfile.TemporaryDirectory() as temp_dir: 75 | wav_path = temp_dir + "/audio.wav" 76 | MidiSynth().midi2audio(str(midi_path), str(wav_path)) 77 | subprocess.check_output( 78 | [ 79 | "ffmpeg", 80 | "-y", 81 | "-i", 82 | str(wav_path), 83 | "-af", 84 | "silenceremove=1:0:-50dB,aformat=dblp,areverse,silenceremove=1:0:-50dB,aformat=dblp,areverse", # strip silence 85 | str(mp3_path), 86 | ], 87 | ) 88 | return mp3_path 89 | 90 | 91 | def midi_to_score(midi_path, score_path): 92 | with tempfile.TemporaryDirectory() as temp_dir: 93 | lilypond_path = temp_dir + "/score.ly" 94 | subprocess.check_output( 95 | ["midi2ly", str(midi_path), "--output", str(lilypond_path)], 96 | ) 97 | subprocess.check_output( 98 | [ 99 | "lilypond", 100 | "-fpng", 101 | "-dresolution=300", 102 | '-dpaper-size="a5landscape"', 103 | "-dcrop", 104 | "-o", 105 | str(Path(score_path).with_suffix("")), 106 | str(lilypond_path), 107 | ] 108 | ) 109 | cropped_path = str(Path(score_path).with_suffix("")) + ".cropped." 
+ Path(score_path).suffix.lstrip(".") 110 | if Path(cropped_path).exists(): 111 | os.rename(cropped_path, str(score_path)) 112 | return score_path 113 | 114 | 115 | def array_to_midi( 116 | arr, midi_path, instrument_name="Lead 6 (voice)", tempo=90, time_sig=4 117 | ): 118 | sec_per_beat = 60 / tempo 119 | track = pm.PrettyMIDI(initial_tempo=tempo) 120 | track.time_signature_changes.append(pm.TimeSignature(time_sig, 4, 0)) 121 | 122 | for mat in arr: 123 | instrument = pm.Instrument(pm.instrument_name_to_program(instrument_name)) 124 | write_notes(instrument, mat, sec_per_beat) 125 | track.instruments.append(instrument) 126 | 127 | track.write(str(midi_path)) 128 | return midi_path 129 | 130 | 131 | def write_notes(instrument, mat, sec_per_beat): 132 | def append_note(pitch, start_beat, end_beat): 133 | note = pm.Note( 134 | pitch=pitch, 135 | velocity=120, 136 | start=start_beat * sec_per_beat, 137 | end=end_beat * sec_per_beat, 138 | ) 139 | instrument.notes.append(note) 140 | 141 | cur_pitch = None 142 | start_beat = None 143 | for t, vec in enumerate(mat): 144 | beat = t / RESOLUTION 145 | 146 | max_i = int(torch.argmax(vec).item()) 147 | if vec[max_i] > 0.75: 148 | pitch = max_i + MIN_PITCH 149 | if pitch != cur_pitch: 150 | if cur_pitch: 151 | append_note(cur_pitch, start_beat, beat) 152 | cur_pitch = pitch 153 | start_beat = beat 154 | if start_beat is None: 155 | start_beat = beat 156 | cur_pitch = pitch 157 | else: 158 | if cur_pitch is not None: 159 | append_note(cur_pitch, start_beat, beat) 160 | start_beat = None 161 | cur_pitch = None 162 | 163 | if cur_pitch is not None: 164 | append_note(cur_pitch, start_beat, mat.shape[0] / RESOLUTION) 165 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | import sys 4 | import torch 5 | import torch.nn.functional as F 6 | import music21 7 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel 8 | from cog import BaseModel, BasePredictor, Input, Path 9 | 10 | from data import RESOLUTION, MIN_PITCH, MAX_PITCH 11 | from output import array_to_midi, midi_to_score, midi_to_mp3 12 | 13 | 14 | class Output(BaseModel): 15 | mp3: Optional[Path] 16 | #score: Optional[Path] 17 | midi: Optional[Path] 18 | 19 | 20 | class Predictor(BasePredictor): 21 | def setup(self): 22 | self.model = model = UNet2DModel.from_pretrained("checkpoints/unet").to("cuda") 23 | 24 | def predict( 25 | self, 26 | duration: int = Input( 27 | description="Duration in quarter notes", 28 | choices=( 29 | 64 // RESOLUTION, 30 | 128 // RESOLUTION, 31 | 256 // RESOLUTION, 32 | 512 // RESOLUTION, 33 | 1024 // RESOLUTION, 34 | ), 35 | default=128 // RESOLUTION, 36 | ), 37 | tempo: float = Input( 38 | description="Tempo in quarter notes per minute", default=90, ge=40, le=200 39 | ), 40 | melody: str = Input( 41 | description="Melody in tinyNotation format. Accepts ? for inpainting a single note, and ?* for inpainting between two melodic parts", 42 | default="", 43 | ), 44 | # not working :( 45 | # return_score: bool = Input( 46 | # description="Return sheet music score", default=True 47 | # ), 48 | return_mp3: bool = Input(description="Return mp3 audio", default=True), 49 | return_midi: bool = Input(description="Return midi", default=True), 50 | seed: int = Input(description="Random seed. 
Random if seed == -1", default=-1), 51 | ) -> Output: 52 | num_outputs = 1 53 | 54 | #if not return_score and not return_mp3 and not return_midi: 55 | if not return_mp3 and not return_midi: 56 | raise Exception( 57 | "At least one of return_score, return_mp3, return_midi must be true" 58 | ) 59 | 60 | if seed == -1: 61 | seed = random.randint(0, 100000) 62 | 63 | length = duration * RESOLUTION 64 | 65 | if melody: 66 | mel_inputs, mel_mask, length = parse_melody(melody, length, num_outputs) 67 | else: 68 | mel_inputs = torch.zeros([num_outputs, 4, length, 64]).to("cuda") 69 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda") 70 | 71 | generator = torch.Generator(device="cuda") 72 | generator.manual_seed(seed) 73 | array = sample(self.model, generator, mel_inputs, mel_mask, 1000, length)[0] 74 | midi = array_to_midi(array, "/tmp/midi.mid", tempo=tempo) 75 | 76 | output = Output( 77 | midi=Path(midi) if return_midi else None, 78 | #score=Path(midi_to_score(midi, "/tmp/score.png")) if return_score else None, 79 | mp3=Path(midi_to_mp3(midi, "/tmp/audio.mp3")) if return_mp3 else None, 80 | ) 81 | 82 | return output 83 | 84 | 85 | def parse_melody(text, length, num_outputs): 86 | text = text.replace("|", "") 87 | if "?*" in text: 88 | if text.count("?*") > 1: 89 | raise Exception("Can only have on '?*' in the input") 90 | text1, text2 = text.split("?*") 91 | mel_inputs1, mel_mask1 = parse_notes(text1, num_outputs) 92 | mel_inputs2, mel_mask2 = parse_notes(text2, num_outputs) 93 | 94 | notes_length = mel_inputs1.shape[2] + mel_inputs2.shape[2] 95 | if notes_length > length: 96 | length = notes_length 97 | mel_inputs = torch.zeros([num_outputs, 4, length, 64]).to("cuda") 98 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda") 99 | 100 | mel_inputs[:, :, : mel_inputs1.shape[2]] = mel_inputs1 101 | mel_mask[:, :, : mel_mask1.shape[2]] = mel_mask1 102 | mel_inputs[:, :, -mel_inputs2.shape[2] :] = mel_inputs2 103 | mel_mask[:, :, -mel_mask2.shape[2] :] = mel_mask2 104 | else: 105 | mel_inputs1, mel_mask1 = parse_notes(text, num_outputs) 106 | 107 | notes_length = mel_inputs1.shape[2] 108 | if notes_length > length: 109 | length = notes_length 110 | mel_inputs = torch.zeros([num_outputs, 4, length, 64]).to("cuda") 111 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda") 112 | 113 | mel_inputs[:, :, : mel_inputs1.shape[2]] = mel_inputs1 114 | mel_mask[:, :, : mel_mask1.shape[2]] = mel_mask1 115 | 116 | if length % 64 != 0: 117 | new_length = length - (length % 64) + 64 118 | pad = new_length - length 119 | mel_inputs = F.pad(mel_inputs, (0, 0, 0, pad), "constant", 0) 120 | mel_mask = F.pad(mel_mask, (0, 0, 0, pad), "constant", True) 121 | 122 | return mel_inputs, mel_mask, length 123 | 124 | 125 | def parse_notes(text, num_outputs): 126 | text = text.replace("?", "CC") # hack for inpainting masks 127 | 128 | notes = music21.converter.parse("tinyNotation: 4/4 " + text).flat.notes 129 | if len(notes) > 0: 130 | mel_length = int((notes[-1].offset + notes[-1].duration.quarterLength) * RESOLUTION) 131 | else: 132 | mel_length = 0 133 | 134 | mel_inputs = torch.zeros([num_outputs, 4, mel_length, 64]).to("cuda") 135 | mel_mask = torch.zeros_like(mel_inputs, dtype=torch.bool).to("cuda") 136 | 137 | for note in notes: 138 | if note.pitch.midi != 36: # == CC == "?" 
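# tinyNotation parses the "CC" placeholder (substituted for "?" in parse_notes above) as the pitch C2, i.e. MIDI note 36, so placeholder steps are skipped here and left unmasked for the model to in-paint.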
139 | pitch = note.pitch.midi + 12 140 | if pitch < MIN_PITCH: 141 | raise Exception(f"Pitch is too low: {note}") 142 | if pitch > MAX_PITCH: 143 | raise Exception(f"Pitch is too high: {note}") 144 | start_index = int(note.offset * RESOLUTION) 145 | end_index = int( 146 | note.offset * RESOLUTION + note.duration.quarterLength * RESOLUTION 147 | ) 148 | mel_inputs[0, 0, start_index:end_index, pitch - MIN_PITCH] = 1 149 | mel_mask[:, 0, start_index:end_index] = True 150 | 151 | # staccato 152 | mel_inputs[0, 0, start_index - 1, pitch - MIN_PITCH] = 0 153 | 154 | return mel_inputs, mel_mask 155 | 156 | 157 | @torch.no_grad() 158 | def sample(model, generator, inputs, mask, num_inference_steps, length): 159 | num_outputs = inputs.shape[0] 160 | length = inputs.shape[2] 161 | noise_scheduler = DDPMScheduler(num_train_timesteps=num_inference_steps) 162 | image = torch.randn( 163 | (num_outputs, 4, length, model.sample_size), 164 | generator=generator, 165 | device="cuda", 166 | ) 167 | noise_scheduler.set_timesteps(num_inference_steps) 168 | 169 | for t in noise_scheduler.timesteps: 170 | model_input = torch.cat([image, mask, inputs], dim=1) 171 | model_output = model(model_input, t).sample 172 | 173 | image = noise_scheduler.step( 174 | model_output, t, image, generator=generator 175 | ).prev_sample 176 | 177 | image = (image / 2 + 0.5).clamp(0, 1) 178 | image[mask] = inputs[mask] 179 | return image[:, :, :length] 180 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import wandb 5 | import torch 6 | from torch import nn 7 | from tqdm.auto import tqdm 8 | from torch.nn.utils import clip_grad_norm_ 9 | import torch.nn.functional as F 10 | from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel 11 | from diffusers.optimization import get_scheduler 12 | from diffusers.training_utils import EMAModel 13 | 14 | from data import train_val_test_dataloaders 15 | from output import array_to_plot, array_to_midi, midi_to_mp3 16 | from util import define_args, get_args 17 | 18 | with define_args(): 19 | from data import BATCH_SIZE, WIDTH, HEIGHT 20 | 21 | EVAL_BATCH_SIZE = 4 22 | MIXED_PRECISION = "no" 23 | LEARNING_RATE = 1e-4 24 | ADAM_BETA1 = 0.95 25 | ADAM_BETA2 = 0.999 26 | ADAM_WEIGHT_DECAY = 1e-6 27 | ADAM_EPSILON = 1e-08 28 | LR_SCHEDULER = "cosine" 29 | LR_WARMUP_STEPS = 500 30 | NUM_EPOCHS = 100000 31 | GRADIENT_ACCUMULATION_STEPS = 10 32 | USE_EMA = True 33 | EMA_INV_GAMMA = 1.0 34 | EMA_POWER = 3 / 4 35 | EMA_MAX_DECAY = 0.9999 36 | SAVE_MEDIA_EPOCHS = 100 37 | SAVE_MODEL_EPOCHS = 200 38 | OUTPUT_DIR = "checkpoints" 39 | RESUME_FROM = "checkpoints-570000" 40 | TRAIN_INPAINTER = True 41 | SAMPLE_COUNT = 4 42 | 43 | 44 | def main(): 45 | os.makedirs(OUTPUT_DIR, exist_ok=True) 46 | 47 | if RESUME_FROM: 48 | model = UNet2DModel.from_pretrained(RESUME_FROM + "/unet").to("cuda") 49 | 50 | if model.conv_in.weight.shape[1] == 4 and TRAIN_INPAINTER: 51 | new_conv = nn.Conv2d(12, 128, kernel_size=3, padding=(1, 1)).to("cuda") 52 | new_conv.weight.data[:, :4, :, :] = model.conv_in.weight 53 | new_conv.bias.data = model.conv_in.bias 54 | model.conv_in = new_conv 55 | else: 56 | model = UNet2DModel( 57 | sample_size=64, 58 | in_channels=12 if TRAIN_INPAINTER else 4, 59 | out_channels=4, 60 | layers_per_block=2, 61 | block_out_channels=(128, 128, 256, 256, 512, 512), 62 | down_block_types=( 63 | "DownBlock2D", 64 | "DownBlock2D", 65 | 
"DownBlock2D", 66 | "DownBlock2D", 67 | "AttnDownBlock2D", 68 | "DownBlock2D", 69 | ), 70 | up_block_types=( 71 | "UpBlock2D", 72 | "AttnUpBlock2D", 73 | "UpBlock2D", 74 | "UpBlock2D", 75 | "UpBlock2D", 76 | "UpBlock2D", 77 | ), 78 | ).to("cuda") 79 | 80 | noise_scheduler = DDPMScheduler(num_train_timesteps=1000) 81 | optimizer = torch.optim.AdamW( 82 | model.parameters(), 83 | lr=LEARNING_RATE, 84 | betas=(ADAM_BETA1, ADAM_BETA2), 85 | weight_decay=ADAM_WEIGHT_DECAY, 86 | eps=ADAM_EPSILON, 87 | ) 88 | 89 | train_dl, test_dl, val_dl = train_val_test_dataloaders() 90 | 91 | lr_scheduler = get_scheduler( 92 | LR_SCHEDULER, 93 | optimizer=optimizer, 94 | num_warmup_steps=LR_WARMUP_STEPS, 95 | num_training_steps=(len(train_dl) * NUM_EPOCHS) // GRADIENT_ACCUMULATION_STEPS, 96 | ) 97 | 98 | num_update_steps_per_epoch = math.ceil(len(train_dl) / GRADIENT_ACCUMULATION_STEPS) 99 | 100 | ema_model = EMAModel( 101 | model, inv_gamma=EMA_INV_GAMMA, power=EMA_POWER, max_value=EMA_MAX_DECAY 102 | ) 103 | 104 | print("args", get_args()) 105 | wandb_run = wandb.init(project="cantable-diffuguesion", config=get_args()) 106 | 107 | pipeline = DDPMPipeline( 108 | unet=ema_model.averaged_model if USE_EMA else model, 109 | scheduler=noise_scheduler, 110 | ).to("cuda") 111 | 112 | global_step = 0 113 | for epoch in range(NUM_EPOCHS): 114 | model.train() 115 | progress_bar = tqdm( 116 | total=num_update_steps_per_epoch, 117 | ) 118 | progress_bar.set_description(f"Epoch {epoch}") 119 | 120 | # Generate sample images for visual inspection 121 | if epoch % SAVE_MEDIA_EPOCHS == 0: 122 | log_media( 123 | model, 124 | noise_scheduler, 125 | global_step, 126 | val_dl, 127 | ) 128 | 129 | if epoch % SAVE_MODEL_EPOCHS == 0: 130 | # save the model 131 | pipeline.save_pretrained(OUTPUT_DIR) 132 | artifact = wandb.Artifact("checkpoints", type="checkpoints") 133 | artifact.add_dir("checkpoints") # Adds multiple files to artifact 134 | wandb_run.log_artifact(artifact) 135 | 136 | for step, batch in enumerate(train_dl): 137 | clean_images = batch.to("cuda") 138 | 139 | # Sample noise that we'll add to the images 140 | noise = torch.randn(clean_images.shape).to(clean_images.device) 141 | bsz = clean_images.shape[0] 142 | # Sample a random timestep for each image 143 | timesteps = torch.randint( 144 | 0, 145 | noise_scheduler.config.num_train_timesteps, 146 | (bsz,), 147 | device=clean_images.device, 148 | ).long() 149 | 150 | # Add noise to the clean images according to the noise magnitude at each timestep 151 | # (this is the forward diffusion process) 152 | noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) 153 | 154 | if TRAIN_INPAINTER: 155 | mask = random_mask(clean_images.shape[0]) 156 | masked_images = clean_images.clone() 157 | masked_images[~mask] *= 0 158 | model_input = torch.cat([noisy_images, mask, masked_images], dim=1) 159 | else: 160 | model_input = noisy_images 161 | 162 | # Predict the noise residual 163 | noise_pred = model(model_input, timesteps).sample 164 | loss = F.mse_loss(noise_pred, noise) 165 | loss.backward() 166 | clip_grad_norm_(model.parameters(), 1.0) 167 | optimizer.step() 168 | lr_scheduler.step() 169 | if USE_EMA: 170 | ema_model.step(model) 171 | optimizer.zero_grad() 172 | 173 | progress_bar.update(1) 174 | global_step += 1 175 | 176 | logs = {"loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]} 177 | if USE_EMA: 178 | logs["ema_decay"] = ema_model.decay 179 | 180 | if global_step % 10 == 0: 181 | wandb.log(logs, step=global_step) 182 | 183 | 
progress_bar.set_postfix(**logs) 184 | 185 | progress_bar.close() 186 | log_media( 187 | ema_model.averaged_model if USE_EMA else model, 188 | noise_scheduler, 189 | global_step, 190 | val_dl, 191 | ) 192 | pipeline.save_pretrained(OUTPUT_DIR) 193 | wandb.save(OUTPUT_DIR) 194 | 195 | 196 | @torch.no_grad() 197 | def sample(model, noise_scheduler, mask=None, masked_images=None): 198 | generator = torch.Generator(device="cuda") 199 | generator.manual_seed(0) 200 | num_inference_steps = 1000 201 | noise_scheduler = DDPMScheduler(num_train_timesteps=num_inference_steps) 202 | image = torch.randn( 203 | (SAMPLE_COUNT, model.in_channels, WIDTH, model.sample_size), 204 | generator=generator, 205 | device="cuda", 206 | ) 207 | noise_scheduler.set_timesteps(num_inference_steps) 208 | 209 | for t in noise_scheduler.timesteps: 210 | # 1. predict noise model_output 211 | if TRAIN_INPAINTER: 212 | model_input = torch.cat([image.to("cuda"), mask.to("cuda"), masked_images.to("cuda")], dim=1).to("cuda") 213 | else: 214 | model_input = image 215 | model_output = model(model_input, t).sample 216 | 217 | # 2. compute previous image: x_t -> x_t-1 218 | image = noise_scheduler.step(model_output, t, image).prev_sample 219 | 220 | return (image / 2 + 0.5).clamp(0, 1) 221 | 222 | 223 | def random_mask(count): 224 | mask = torch.zeros([count, 4, WIDTH, HEIGHT], dtype=bool).to("cuda") 225 | for i in range(count): 226 | for c in range(4): 227 | start_index, end_index = torch.randint(0, WIDTH, (2,)) 228 | if start_index > end_index: 229 | start_index, end_index = end_index, start_index 230 | mask[i, c, start_index:end_index] = True 231 | return mask 232 | 233 | 234 | def log_media(model, noise_scheduler, global_step, val_dl=None): 235 | logs = {} 236 | # run pipeline in inference (sample random noise and denoise) 237 | if TRAIN_INPAINTER: 238 | mask = random_mask(SAMPLE_COUNT) 239 | masked_images = torch.zeros_like(mask, dtype=torch.float32).to("cuda") 240 | for i, batch in enumerate(val_dl): 241 | if i == SAMPLE_COUNT: 242 | break 243 | masked_images[i] = batch[0, :, :WIDTH] 244 | masked_images[~mask] *= 0 245 | logs["masked_inputs"] = [wandb.Image(array_to_plot(i)) for i in masked_images] 246 | 247 | arrays = sample(model, noise_scheduler, mask.to("cuda"), masked_images.to("cuda")) 248 | else: 249 | arrays = sample(model, noise_scheduler) 250 | #arrays = torch.from_numpy(arrays).permute([0, 3, 1, 2]) 251 | 252 | plots = [array_to_plot(a) for a in arrays] 253 | midis = [array_to_midi(a, f"/tmp/midi-output-{i}.mid") for i, a in enumerate(arrays)] 254 | audios = [ 255 | midi_to_mp3(midi, f"/tmp/audio-output-{i}.mp3") for i, midi in enumerate(midis) 256 | ] 257 | 258 | logs["plots"] = [wandb.Image(plot) for plot in plots] 259 | logs["audios"] = [wandb.Audio(audio) for audio in audios] 260 | 261 | wandb.log(logs, step=global_step) 262 | 263 | 264 | if __name__ == "__main__": 265 | main() 266 | --------------------------------------------------------------------------------